Skip to main content

rlmesh_spaces/
types.rs

1use crate::dtype::DType;
2
/// A single low/high bound pair stored as `f64`; the payload of
/// [`BoxBounds::Uniform`].
///
/// `Eq` is not derived because `f64` does not implement it.
// NOTE(review): presumably the same pair applies to every element of the
// space, and whether the bounds are inclusive is not visible here -- confirm.
#[derive(Debug, Clone, PartialEq, Default)]
pub struct UniformBounds {
    /// Lower bound.
    pub low: f64,
    /// Upper bound.
    pub high: f64,
}
8
/// Low/high bounds stored as `f64` vectors; the payload of
/// [`BoxBounds::Elementwise`].
// NOTE(review): presumably one entry per element of the enclosing space, with
// `low` and `high` of equal length; nothing in this type enforces that --
// confirm against the validation code.
#[derive(Debug, Clone, PartialEq, Default)]
pub struct ElementwiseBounds {
    /// Lower bounds.
    pub low: Vec<f64>,
    /// Upper bounds.
    pub high: Vec<f64>,
}
14
/// A single low/high bound pair carried as raw little-endian bytes in the
/// enclosing space's dtype (one scalar each).
///
/// `low`/`high` must each be exactly `dtype_size(dtype)` bytes. Used so that
/// integer dtypes (notably `int64`/`uint64`) carry exact bounds instead of
/// losing precision through `f64`.
///
/// The length requirement is a documented contract only; this type itself is
/// two plain byte vectors and performs no validation.
#[derive(Debug, Clone, PartialEq, Eq, Default)]
pub struct TypedUniformBounds {
    /// Lower bound scalar as little-endian bytes.
    pub low: Vec<u8>,
    /// Upper bound scalar as little-endian bytes.
    pub high: Vec<u8>,
}
26
/// Per-element low/high bounds carried as raw little-endian bytes in the
/// enclosing space's dtype, in row-major (C) order.
///
/// `low`/`high` must each be exactly `numel * dtype_size(dtype)` bytes.
///
/// The length requirement is a documented contract only; this type itself is
/// two plain byte vectors and performs no validation.
#[derive(Debug, Clone, PartialEq, Eq, Default)]
pub struct TypedElementwiseBounds {
    /// Lower bounds, one little-endian scalar per element.
    pub low: Vec<u8>,
    /// Upper bounds, one little-endian scalar per element.
    pub high: Vec<u8>,
}
36
/// Parameters specific to a Box space; the payload of [`SpaceKind::Box`].
#[derive(Debug, Clone, PartialEq, Default)]
pub struct BoxSpec {
    /// The bounds of the space, in one of the [`BoxBounds`] encodings.
    /// `None` is the `Default`.
    // NOTE(review): `None` presumably means the proto oneof was left unset;
    // how consumers treat it is not visible here -- confirm.
    pub bounds: Option<BoxBounds>,
}
41
/// Bounds for a Box space (the proto `BoxBounds.bounds` oneof).
///
/// Each variant is one alternative encoding of the bounds. There is no
/// `Default` impl; [`BoxSpec`] wraps this in an `Option` instead.
#[derive(Debug, Clone, PartialEq)]
pub enum BoxBounds {
    /// No bounds.
    // NOTE(review): the meaning of the `bool` payload (in particular `false`)
    // is not visible in this file; presumably it mirrors the proto field --
    // confirm.
    Unbounded(bool),
    /// A single low/high pair as `f64`.
    Uniform(UniformBounds),
    /// Low/high vectors as `f64`.
    Elementwise(ElementwiseBounds),
    /// A single low/high pair as raw little-endian bytes in the space's dtype.
    TypedUniform(TypedUniformBounds),
    /// Per-element low/high as raw little-endian bytes in the space's dtype.
    TypedElementwise(TypedElementwiseBounds),
}
51
/// Parameters specific to a Discrete space; the payload of
/// [`SpaceKind::Discrete`]. Both fields default to `0`.
// NOTE(review): presumably mirrors gymnasium's `Discrete(n, start)`, i.e. the
// values `start..start + n` -- confirm. No range validation happens here.
#[derive(Debug, Clone, PartialEq, Eq, Default)]
pub struct DiscreteSpec {
    /// Number of values (presumed; see note above).
    pub n: i64,
    /// First value (presumed; see note above).
    pub start: i64,
}
57
/// Parameters specific to a MultiBinary space; the payload of
/// [`SpaceKind::MultiBinary`].
#[derive(Debug, Clone, PartialEq, Eq, Default)]
pub struct MultiBinarySpec {
    /// Size description in one of the [`MultiBinaryDims`] forms. `None` is the
    /// `Default`.
    // NOTE(review): `None` presumably means the proto oneof was unset -- confirm.
    pub n: Option<MultiBinaryDims>,
}
62
/// Size description for a MultiBinary space (the proto `n` oneof).
///
/// There is no `Default` impl; [`MultiBinarySpec`] wraps this in an `Option`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum MultiBinaryDims {
    /// A single scalar size.
    Size(i64),
    /// A list of sizes.
    // NOTE(review): presumably one entry per dimension of a multi-dimensional
    // space -- confirm.
    Dims(Vec<i64>),
}
69
/// Parameters specific to a MultiDiscrete space; the payload of
/// [`SpaceKind::MultiDiscrete`].
#[derive(Debug, Clone, PartialEq, Eq, Default)]
pub struct MultiDiscreteSpec {
    /// Count layout in one of the [`MultiDiscreteNvec`] forms. `None` is the
    /// `Default`.
    // NOTE(review): `None` presumably means the proto oneof was unset -- confirm.
    pub nvec: Option<MultiDiscreteNvec>,
}
74
/// Count layout for a MultiDiscrete space (the proto `nvec` oneof).
///
/// There is no `Default` impl; [`MultiDiscreteSpec`] wraps this in an `Option`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum MultiDiscreteNvec {
    /// Counts as a flat list.
    Flat(Vec<i64>),
    /// Counts as a list of lists.
    // NOTE(review): presumably rows of a 2-D `nvec`; whether rows must have
    // equal length is not enforced or visible here -- confirm.
    Shaped(Vec<Vec<i64>>),
}
81
/// Parameters specific to a Text space; the payload of [`SpaceKind::Text`].
/// Defaults are `0`, `0` and the empty string.
#[derive(Debug, Clone, PartialEq, Eq, Default)]
pub struct TextSpec {
    /// Minimum length.
    // NOTE(review): unit (characters vs. bytes) and inclusivity are not
    // visible here -- confirm.
    pub min_length: i64,
    /// Maximum length (same caveats as `min_length`).
    pub max_length: i64,
    /// Character set.
    // NOTE(review): presumably the allowed characters concatenated into one
    // string; the meaning of an empty charset is not visible here -- confirm.
    pub charset: String,
}
88
/// Parameters specific to a Dict space; the payload of [`SpaceKind::Dict`].
///
/// Holds nested [`SpaceSpec`] values, so specs can form a tree.
#[derive(Debug, Clone, PartialEq, Default)]
pub struct DictSpec {
    /// Entry names.
    // NOTE(review): presumably parallel to `spaces` (`keys[i]` names
    // `spaces[i]`); equal length is not enforced by this type -- confirm.
    pub keys: Vec<String>,
    /// Nested space specs.
    pub spaces: Vec<SpaceSpec>,
}
94
/// Parameters specific to a Tuple space; the payload of [`SpaceKind::Tuple`].
///
/// Holds nested [`SpaceSpec`] values, so specs can form a tree.
#[derive(Debug, Clone, PartialEq, Default)]
pub struct TupleSpec {
    /// Nested space specs, in order.
    pub spaces: Vec<SpaceSpec>,
}
99
/// Description of one space: a shape, a dtype, and kind-specific parameters.
///
/// `Eq` is not derived; some kinds carry `f64` bounds.
#[derive(Debug, Clone, PartialEq, Default)]
pub struct SpaceSpec {
    /// Shape of the space as a list of dimension sizes.
    // NOTE(review): how `shape` applies to composite kinds (Dict/Tuple) is not
    // visible here -- confirm.
    pub shape: Vec<i64>,
    /// Element dtype. The typed box bounds are encoded in this dtype.
    pub dtype: DType,
    /// Kind-specific parameters. `None` is the `Default`, and
    /// `space_type()` reports it as `SpaceType::Unspecified`.
    pub spec: Option<SpaceKind>,
}
106
107impl SpaceSpec {
108    pub fn space_type(&self) -> SpaceType {
109        match self.spec {
110            Some(SpaceKind::Box(_)) => SpaceType::Box,
111            Some(SpaceKind::Discrete(_)) => SpaceType::Discrete,
112            Some(SpaceKind::MultiBinary(_)) => SpaceType::MultiBinary,
113            Some(SpaceKind::MultiDiscrete(_)) => SpaceType::MultiDiscrete,
114            Some(SpaceKind::Text(_)) => SpaceType::Text,
115            Some(SpaceKind::Dict(_)) => SpaceType::Dict,
116            Some(SpaceKind::Tuple(_)) => SpaceType::Tuple,
117            None => SpaceType::Unspecified,
118        }
119    }
120}
121
/// Kind-specific parameters of a [`SpaceSpec`]: one variant per space kind,
/// each wrapping that kind's spec struct.
///
/// `SpaceSpec::space_type` maps each variant to the [`SpaceType`] of the same
/// name. There is no `Default` impl; `SpaceSpec` wraps this in an `Option`.
#[derive(Debug, Clone, PartialEq)]
pub enum SpaceKind {
    /// Box space parameters (bounds).
    Box(BoxSpec),
    /// Discrete space parameters.
    Discrete(DiscreteSpec),
    /// MultiBinary space parameters.
    MultiBinary(MultiBinarySpec),
    /// MultiDiscrete space parameters.
    MultiDiscrete(MultiDiscreteSpec),
    /// Text space parameters.
    Text(TextSpec),
    /// Dict space parameters (nested specs with names).
    Dict(DictSpec),
    /// Tuple space parameters (nested specs in order).
    Tuple(TupleSpec),
}
132
/// Flat tag naming which kind of space a [`SpaceSpec`] describes, as returned
/// by [`SpaceSpec::space_type`].
///
/// The enum is `#[repr(i32)]` and its explicit discriminants are exactly the
/// integers the `i32` conversions below accept and produce. Values 6 through
/// 9 are not assigned to any variant.
// NOTE(review): the discriminants presumably mirror a proto enum -- confirm
// before renumbering.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
#[repr(i32)]
pub enum SpaceType {
    /// No kind set; this is the `Default` value.
    #[default]
    Unspecified = 0,
    Box = 1,
    Discrete = 2,
    MultiBinary = 3,
    MultiDiscrete = 4,
    Text = 5,
    Dict = 10,
    Tuple = 11,
}

impl TryFrom<i32> for SpaceType {
    type Error = &'static str;

    /// Converts a raw discriminant into its variant.
    ///
    /// # Errors
    ///
    /// Returns `"invalid space type"` when `value` is not the discriminant of
    /// any variant.
    fn try_from(value: i32) -> Result<Self, Self::Error> {
        // Every variant; the accepted integers are read from the enum's own
        // discriminants instead of being repeated here.
        const KNOWN: [SpaceType; 8] = [
            SpaceType::Unspecified,
            SpaceType::Box,
            SpaceType::Discrete,
            SpaceType::MultiBinary,
            SpaceType::MultiDiscrete,
            SpaceType::Text,
            SpaceType::Dict,
            SpaceType::Tuple,
        ];

        KNOWN
            .iter()
            .copied()
            .find(|kind| *kind as i32 == value)
            .ok_or("invalid space type")
    }
}

impl From<SpaceType> for i32 {
    /// Returns the variant's `#[repr(i32)]` discriminant.
    fn from(value: SpaceType) -> Self {
        // Lossless: the enum is declared `#[repr(i32)]`.
        value as i32
    }
}
170
/// Autoreset convention an environment applies to each lane. Mirrors the proto
/// `AutoresetMode` and gymnasium's `AutoresetMode`.
///
/// No `Unspecified` variant exists on purpose: proto `UNSPECIFIED` (0) is
/// decoded as [`AutoresetMode::Disabled`], the safe explicit-reset default.
/// Values this build does not know are never mapped to a default;
/// [`AutoresetMode::try_from`] returns an error, so a mode introduced by a
/// newer peer fails loudly rather than silently altering lifecycle semantics.
///
/// The discriminants must match the proto `AutoresetMode`
/// (UNSPECIFIED=0, NEXT_STEP=1, SAME_STEP=2, DISABLED=3). The
/// `autoreset_mode_i32_roundtrip` test below pins them so that any drift is
/// caught.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
#[repr(i32)]
pub enum AutoresetMode {
    /// The terminal observation is reported at step `t`; the env resets the
    /// done lane internally and hands out the fresh observation at `t+1`.
    NextStep = 1,
    /// The reset observation arrives in the very step that reports the lane
    /// done. Reserved; the runtime does not honor it yet.
    SameStep = 2,
    /// Autoreset is off: a done lane remains inactive until it is reset
    /// explicitly. This is the `Default`.
    #[default]
    Disabled = 3,
}

/// Error returned by [`AutoresetMode::try_from`] for an `i32` that does not
/// name a known mode. Carries the rejected value.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct UnknownAutoresetMode(pub i32);

impl std::fmt::Display for UnknownAutoresetMode {
    /// Writes `unknown autoreset mode <value>`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self(raw) = self;
        write!(f, "unknown autoreset mode {}", raw)
    }
}

impl std::error::Error for UnknownAutoresetMode {}

impl TryFrom<i32> for AutoresetMode {
    type Error = UnknownAutoresetMode;

    /// Decodes a wire value into a mode.
    ///
    /// # Errors
    ///
    /// Returns [`UnknownAutoresetMode`] holding `value` when it is none of
    /// 0, 1, 2 or 3.
    fn try_from(value: i32) -> Result<Self, Self::Error> {
        let mode = match value {
            // Proto UNSPECIFIED (0) shares the safe explicit-reset default
            // with DISABLED (3).
            0 | 3 => Self::Disabled,
            1 => Self::NextStep,
            2 => Self::SameStep,
            // Anything else is surfaced to the caller, never defaulted.
            unknown => return Err(UnknownAutoresetMode(unknown)),
        };
        Ok(mode)
    }
}

impl From<AutoresetMode> for i32 {
    /// Returns the variant's `#[repr(i32)]` discriminant. `Disabled` encodes
    /// as 3, never as 0.
    fn from(value: AutoresetMode) -> Self {
        value as i32
    }
}
227
/// Description of an environment: its id, action and observation spaces,
/// metadata, render mode, lane count and autoreset convention.
///
/// `Default` yields empty strings, `None` for every optional field,
/// `num_envs == 0` and `AutoresetMode::Disabled`.
#[derive(Debug, Clone, PartialEq, Default)]
pub struct EnvContract {
    /// Environment identifier.
    pub id: String,
    /// Spec of the action space, if provided.
    pub action_space: Option<SpaceSpec>,
    /// Spec of the observation space, if provided.
    pub observation_space: Option<SpaceSpec>,
    /// Environment metadata, if provided.
    pub metadata: Option<crate::meta::MetaMap>,
    /// Render mode name.
    // NOTE(review): presumably an empty string means "no render mode" --
    // confirm against the producer.
    pub render_mode: String,
    /// Number of environments.
    // NOTE(review): presumably the number of lanes the autoreset convention
    // applies to -- confirm.
    pub num_envs: u32,
    /// Per-lane autoreset convention the runtime honors. Derived at construction
    /// from the env's `metadata["autoreset_mode"]`; defaults to `Disabled`.
    pub autoreset_mode: AutoresetMode,
}
240
#[cfg(test)]
mod tests {
    use super::*;

    /// Pins the `AutoresetMode` <-> `i32` mapping to the proto `AutoresetMode`
    /// values (UNSPECIFIED=0, NEXT_STEP=1, SAME_STEP=2, DISABLED=3).
    #[test]
    fn autoreset_mode_i32_roundtrip() {
        // Each native variant paired with the wire value it must encode to.
        let encodings = [
            (AutoresetMode::NextStep, 1),
            (AutoresetMode::SameStep, 2),
            (AutoresetMode::Disabled, 3),
        ];

        for (mode, wire) in encodings {
            // Native -> i32 discriminant.
            assert_eq!(i32::from(mode), wire);
            // i32 -> native for the same wire value.
            assert_eq!(AutoresetMode::try_from(wire), Ok(mode));
            // Full round trip through i32 and back.
            assert_eq!(AutoresetMode::try_from(i32::from(mode)), Ok(mode));
        }

        // Proto UNSPECIFIED (0) decodes to the safe Disabled default.
        assert_eq!(AutoresetMode::try_from(0), Ok(AutoresetMode::Disabled));

        // Unknown values are rejected loudly, never folded to a default.
        assert_eq!(AutoresetMode::try_from(99), Err(UnknownAutoresetMode(99)));
        assert!(AutoresetMode::try_from(-1).is_err());
    }
}