1use crate::dtype::DType;
2
/// A single `low`/`high` bound pair stored as `f64`.
///
/// NOTE(review): presumably the same pair applies to every element of a box space
/// (contrast [`ElementwiseBounds`]). Whether the bounds are inclusive is not established
/// here; confirm against the consumer of `BoxBounds::Uniform`.
#[derive(Debug, Clone, PartialEq, Default)]
pub struct UniformBounds {
    /// Lower bound.
    pub low: f64,
    /// Upper bound.
    pub high: f64,
}
8
/// Lower/upper bounds stored as `f64` vectors.
///
/// NOTE(review): presumably one entry per element of the space's shape, with `low` and
/// `high` of equal length. Nothing in this type enforces that; confirm with the producer.
#[derive(Debug, Clone, PartialEq, Default)]
pub struct ElementwiseBounds {
    /// Lower bounds.
    pub low: Vec<f64>,
    /// Upper bounds.
    pub high: Vec<f64>,
}
14
/// A single `low`/`high` bound pair stored as raw bytes instead of `f64`.
///
/// NOTE(review): presumably each byte vector encodes one scalar in the owning space's
/// `dtype` (see `SpaceSpec::dtype`). The byte encoding and endianness are not established
/// in this file; confirm before decoding.
#[derive(Debug, Clone, PartialEq, Eq, Default)]
pub struct TypedUniformBounds {
    /// Raw bytes of the lower bound.
    pub low: Vec<u8>,
    /// Raw bytes of the upper bound.
    pub high: Vec<u8>,
}
26
/// Per-element lower/upper bounds stored as raw bytes instead of `f64` vectors.
///
/// NOTE(review): presumably each byte vector is a packed array of the owning space's
/// `dtype`, one value per element. The encoding and endianness are not established in
/// this file; confirm before decoding.
#[derive(Debug, Clone, PartialEq, Eq, Default)]
pub struct TypedElementwiseBounds {
    /// Raw bytes of the lower bounds.
    pub low: Vec<u8>,
    /// Raw bytes of the upper bounds.
    pub high: Vec<u8>,
}
36
/// Parameters carried by a `SpaceKind::Box` space.
#[derive(Debug, Clone, PartialEq, Default)]
pub struct BoxSpec {
    /// The bounds representation, or `None` when none was supplied.
    /// NOTE(review): how `None` differs from `Some(BoxBounds::Unbounded(_))` is not
    /// established here; confirm with the producer.
    pub bounds: Option<BoxBounds>,
}
41
/// The alternative representations of a box space's bounds.
#[derive(Debug, Clone, PartialEq)]
pub enum BoxBounds {
    /// Presumably marks the box as having no bounds.
    /// NOTE(review): the meaning of the carried `bool` is not evident from this file
    /// (possibly a presence flag from a serialized `oneof`); confirm before relying on it.
    Unbounded(bool),
    /// One `f64` `low`/`high` pair; see [`UniformBounds`].
    Uniform(UniformBounds),
    /// Vector `f64` bounds; see [`ElementwiseBounds`].
    Elementwise(ElementwiseBounds),
    /// One `low`/`high` pair stored as raw bytes; see [`TypedUniformBounds`].
    TypedUniform(TypedUniformBounds),
    /// Vector bounds stored as raw bytes; see [`TypedElementwiseBounds`].
    TypedElementwise(TypedElementwiseBounds),
}
51
/// Parameters carried by a `SpaceKind::Discrete` space.
///
/// NOTE(review): presumably mirrors Gymnasium's `Discrete(n, start)`, i.e. the valid
/// values are `start, start + 1, ..., start + n - 1`. Confirm against the producer.
#[derive(Debug, Clone, PartialEq, Eq, Default)]
pub struct DiscreteSpec {
    /// Number of values (presumably; see the note above).
    pub n: i64,
    /// First value (presumably; see the note above).
    pub start: i64,
}
57
/// Parameters carried by a `SpaceKind::MultiBinary` space.
#[derive(Debug, Clone, PartialEq, Eq, Default)]
pub struct MultiBinarySpec {
    /// Size or dimensions of the binary array, or `None` when not supplied.
    pub n: Option<MultiBinaryDims>,
}
62
/// The two ways a multi-binary space's extent can be given.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum MultiBinaryDims {
    /// A single count, presumably the length of a 1-D binary vector.
    Size(i64),
    /// A list of dimensions, presumably the shape of an N-D binary array.
    Dims(Vec<i64>),
}
69
/// Parameters carried by a `SpaceKind::MultiDiscrete` space.
#[derive(Debug, Clone, PartialEq, Eq, Default)]
pub struct MultiDiscreteSpec {
    /// The per-position value counts, or `None` when not supplied.
    pub nvec: Option<MultiDiscreteNvec>,
}
74
/// The `nvec` of a multi-discrete space, in flat or nested form.
///
/// NOTE(review): presumably, as in Gymnasium's `MultiDiscrete`, each entry is the number
/// of categories at that position. Confirm against the producer.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum MultiDiscreteNvec {
    /// One-dimensional list of counts.
    Flat(Vec<i64>),
    /// Two-level nested list of counts (rows of possibly differing length; not enforced).
    Shaped(Vec<Vec<i64>>),
}
81
/// Parameters carried by a `SpaceKind::Text` space.
///
/// NOTE(review): presumably `min_length`/`max_length` bound the text length and `charset`
/// lists the allowed characters. Whether the bounds are inclusive, and whether length
/// counts bytes or chars, is not established here; confirm.
#[derive(Debug, Clone, PartialEq, Eq, Default)]
pub struct TextSpec {
    /// Minimum text length.
    pub min_length: i64,
    /// Maximum text length.
    pub max_length: i64,
    /// Allowed characters, stored as a single string.
    pub charset: String,
}
88
/// Parameters carried by a `SpaceKind::Dict` space, as two parallel vectors.
///
/// NOTE(review): presumably `keys[i]` names `spaces[i]` and both vectors have the same
/// length. The type does not enforce this; confirm with the producer.
#[derive(Debug, Clone, PartialEq, Default)]
pub struct DictSpec {
    /// Sub-space names.
    pub keys: Vec<String>,
    /// Sub-space specifications (recursive).
    pub spaces: Vec<SpaceSpec>,
}
94
/// Parameters carried by a `SpaceKind::Tuple` space.
#[derive(Debug, Clone, PartialEq, Default)]
pub struct TupleSpec {
    /// Ordered sub-space specifications (recursive).
    pub spaces: Vec<SpaceSpec>,
}
99
/// A space description: shape, element dtype, and an optional kind-specific payload.
#[derive(Debug, Clone, PartialEq, Default)]
pub struct SpaceSpec {
    /// Dimensions of the space, stored as signed 64-bit integers.
    pub shape: Vec<i64>,
    /// Element data type.
    pub dtype: DType,
    /// Kind-specific parameters. `None` makes `space_type()` report
    /// `SpaceType::Unspecified`.
    pub spec: Option<SpaceKind>,
}
106
107impl SpaceSpec {
108 pub fn space_type(&self) -> SpaceType {
109 match self.spec {
110 Some(SpaceKind::Box(_)) => SpaceType::Box,
111 Some(SpaceKind::Discrete(_)) => SpaceType::Discrete,
112 Some(SpaceKind::MultiBinary(_)) => SpaceType::MultiBinary,
113 Some(SpaceKind::MultiDiscrete(_)) => SpaceType::MultiDiscrete,
114 Some(SpaceKind::Text(_)) => SpaceType::Text,
115 Some(SpaceKind::Dict(_)) => SpaceType::Dict,
116 Some(SpaceKind::Tuple(_)) => SpaceType::Tuple,
117 None => SpaceType::Unspecified,
118 }
119 }
120}
121
/// Kind-specific parameters of a space. Each variant maps to the [`SpaceType`] of the
/// same name via `SpaceSpec::space_type`.
#[derive(Debug, Clone, PartialEq)]
pub enum SpaceKind {
    /// Box space parameters.
    Box(BoxSpec),
    /// Discrete space parameters.
    Discrete(DiscreteSpec),
    /// Multi-binary space parameters.
    MultiBinary(MultiBinarySpec),
    /// Multi-discrete space parameters.
    MultiDiscrete(MultiDiscreteSpec),
    /// Text space parameters.
    Text(TextSpec),
    /// Dict space: named sub-spaces.
    Dict(DictSpec),
    /// Tuple space: ordered sub-spaces.
    Tuple(TupleSpec),
}
132
/// Fieldless tag naming the kind of a space, encoded as `i32` (`#[repr(i32)]`).
///
/// Discriminants are explicit; 6–9 are unassigned. The default is `Unspecified` (0).
/// NOTE(review): the explicit values presumably mirror an external wire/proto enum. Keep
/// them stable and check that schema before renumbering.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
#[repr(i32)]
pub enum SpaceType {
    /// No kind specified; reported by `SpaceSpec::space_type` when `spec` is `None`.
    #[default]
    Unspecified = 0,
    /// Tag for `SpaceKind::Box`.
    Box = 1,
    /// Tag for `SpaceKind::Discrete`.
    Discrete = 2,
    /// Tag for `SpaceKind::MultiBinary`.
    MultiBinary = 3,
    /// Tag for `SpaceKind::MultiDiscrete`.
    MultiDiscrete = 4,
    /// Tag for `SpaceKind::Text`.
    Text = 5,
    /// Tag for `SpaceKind::Dict`.
    Dict = 10,
    /// Tag for `SpaceKind::Tuple`.
    Tuple = 11,
}
146
147impl TryFrom<i32> for SpaceType {
148 type Error = &'static str;
149
150 fn try_from(value: i32) -> Result<Self, Self::Error> {
151 match value {
152 0 => Ok(Self::Unspecified),
153 1 => Ok(Self::Box),
154 2 => Ok(Self::Discrete),
155 3 => Ok(Self::MultiBinary),
156 4 => Ok(Self::MultiDiscrete),
157 5 => Ok(Self::Text),
158 10 => Ok(Self::Dict),
159 11 => Ok(Self::Tuple),
160 _ => Err("invalid space type"),
161 }
162 }
163}
164
165impl From<SpaceType> for i32 {
166 fn from(value: SpaceType) -> Self {
167 value as i32
168 }
169}
170
/// Autoreset policy reported for an environment (see `EnvContract::autoreset_mode`).
///
/// Encoded as `i32` (`#[repr(i32)]`) with explicit discriminants 1–3; the default is
/// `Disabled`. The `TryFrom<i32>` impl also decodes `0` as `Disabled`, so `0` is accepted
/// on input but never produced on output.
/// NOTE(review): the variant names presumably mirror Gymnasium's vector-env
/// `AutoresetMode` (NEXT_STEP / SAME_STEP / DISABLED); confirm the exact semantics there.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
#[repr(i32)]
pub enum AutoresetMode {
    /// Encoded as `1`.
    NextStep = 1,
    /// Encoded as `2`.
    SameStep = 2,
    /// Encoded as `3`; the `Default` variant.
    #[default]
    Disabled = 3,
}
195
/// Error returned when an `i32` does not decode to any [`AutoresetMode`].
///
/// The wrapped value is the rejected input.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct UnknownAutoresetMode(pub i32);
199
200impl std::fmt::Display for UnknownAutoresetMode {
201 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
202 write!(f, "unknown autoreset mode {}", self.0)
203 }
204}
205
// Empty impl: relies on the default `source()` (no underlying cause). Together with the
// `Debug` and `Display` impls, this lets the error be used wherever `std::error::Error`
// is expected, e.g. boxed as `Box<dyn std::error::Error>`.
impl std::error::Error for UnknownAutoresetMode {}
207
208impl TryFrom<i32> for AutoresetMode {
209 type Error = UnknownAutoresetMode;
210 fn try_from(value: i32) -> Result<Self, Self::Error> {
211 match value {
212 1 => Ok(Self::NextStep),
213 2 => Ok(Self::SameStep),
214 0 | 3 => Ok(Self::Disabled),
217 other => Err(UnknownAutoresetMode(other)),
218 }
219 }
220}
221
222impl From<AutoresetMode> for i32 {
223 fn from(value: AutoresetMode) -> Self {
224 value as i32
225 }
226}
227
/// Descriptive fields for an environment: identifier, action/observation spaces,
/// metadata, render mode, environment count, and autoreset policy.
///
/// `Default` yields empty strings, `None` spaces/metadata, `num_envs == 0`, and
/// `AutoresetMode::Disabled`.
#[derive(Debug, Clone, PartialEq, Default)]
pub struct EnvContract {
    /// Environment identifier.
    pub id: String,
    /// Action space, or `None` if not provided.
    pub action_space: Option<SpaceSpec>,
    /// Observation space, or `None` if not provided.
    pub observation_space: Option<SpaceSpec>,
    /// Optional metadata (type defined in `crate::meta`).
    pub metadata: Option<crate::meta::MetaMap>,
    /// Render mode as a free-form string.
    /// NOTE(review): presumably an empty string means "no render mode"; confirm.
    pub render_mode: String,
    /// NOTE(review): presumably the number of sub-environments in a vectorized
    /// environment; confirm with the producer.
    pub num_envs: u32,
    /// Autoreset policy; defaults to `AutoresetMode::Disabled`.
    pub autoreset_mode: AutoresetMode,
}
240
#[cfg(test)]
mod tests {
    use super::*;

    /// `AutoresetMode` encodes to its declared discriminants and decodes them back.
    /// It also accepts `0` as `Disabled` on input and rejects everything else.
    #[test]
    fn autoreset_mode_i32_roundtrip() {
        assert_eq!(i32::from(AutoresetMode::NextStep), 1);
        assert_eq!(i32::from(AutoresetMode::SameStep), 2);
        assert_eq!(i32::from(AutoresetMode::Disabled), 3);

        assert_eq!(AutoresetMode::try_from(0), Ok(AutoresetMode::Disabled));
        assert_eq!(AutoresetMode::try_from(1), Ok(AutoresetMode::NextStep));
        assert_eq!(AutoresetMode::try_from(2), Ok(AutoresetMode::SameStep));
        assert_eq!(AutoresetMode::try_from(3), Ok(AutoresetMode::Disabled));

        assert_eq!(AutoresetMode::try_from(99), Err(UnknownAutoresetMode(99)));
        assert!(AutoresetMode::try_from(-1).is_err());

        for v in [
            AutoresetMode::NextStep,
            AutoresetMode::SameStep,
            AutoresetMode::Disabled,
        ] {
            assert_eq!(AutoresetMode::try_from(i32::from(v)), Ok(v));
        }
    }

    /// Both the enum default and a default `EnvContract` use `Disabled`.
    #[test]
    fn autoreset_mode_defaults_to_disabled() {
        assert_eq!(AutoresetMode::default(), AutoresetMode::Disabled);
        assert_eq!(
            EnvContract::default().autoreset_mode,
            AutoresetMode::Disabled
        );
    }

    /// Every `SpaceType` survives an `i32` round trip, including across the
    /// non-contiguous 5 -> 10 jump. The unassigned and out-of-range values are rejected.
    #[test]
    fn space_type_i32_roundtrip() {
        let expected = [
            (0, SpaceType::Unspecified),
            (1, SpaceType::Box),
            (2, SpaceType::Discrete),
            (3, SpaceType::MultiBinary),
            (4, SpaceType::MultiDiscrete),
            (5, SpaceType::Text),
            (10, SpaceType::Dict),
            (11, SpaceType::Tuple),
        ];
        for (raw, ty) in expected {
            assert_eq!(i32::from(ty), raw);
            assert_eq!(SpaceType::try_from(raw), Ok(ty));
        }

        for raw in [-1, 6, 7, 8, 9, 12, i32::MIN, i32::MAX] {
            assert_eq!(SpaceType::try_from(raw), Err("invalid space type"));
        }

        assert_eq!(SpaceType::default(), SpaceType::Unspecified);
    }

    /// `SpaceSpec::space_type` reports the tag matching whichever `SpaceKind` is set.
    #[test]
    fn space_type_follows_spec_variant() {
        let tag = |kind: Option<SpaceKind>| {
            SpaceSpec {
                spec: kind,
                ..SpaceSpec::default()
            }
            .space_type()
        };

        assert_eq!(tag(None), SpaceType::Unspecified);
        assert_eq!(
            tag(Some(SpaceKind::Box(BoxSpec::default()))),
            SpaceType::Box
        );
        assert_eq!(
            tag(Some(SpaceKind::Discrete(DiscreteSpec::default()))),
            SpaceType::Discrete
        );
        assert_eq!(
            tag(Some(SpaceKind::MultiBinary(MultiBinarySpec::default()))),
            SpaceType::MultiBinary
        );
        assert_eq!(
            tag(Some(SpaceKind::MultiDiscrete(MultiDiscreteSpec::default()))),
            SpaceType::MultiDiscrete
        );
        assert_eq!(
            tag(Some(SpaceKind::Text(TextSpec::default()))),
            SpaceType::Text
        );
        assert_eq!(
            tag(Some(SpaceKind::Dict(DictSpec::default()))),
            SpaceType::Dict
        );
        assert_eq!(
            tag(Some(SpaceKind::Tuple(TupleSpec::default()))),
            SpaceType::Tuple
        );
    }
}