relearn/envs/wrappers/
mod.rs

1mod step_limit;
2
3pub use step_limit::{
4    LatentStepLimit, VisibleStepLimit, WithLatentStepLimit, WithVisibleStepLimit,
5};
6
7use super::{
8    BuildEnv, BuildEnvDist, BuildEnvError, EnvDistribution, EnvStructure, Environment,
9    StructuredEnvDist, StructuredEnvironment,
10};
11use crate::Prng;
12use serde::{Deserialize, Serialize};
13
14/// Trait providing a `wrap` method for all sized types.
15pub trait Wrap: Sized {
16    /// Wrap in the given wrapper.
17    #[inline]
18    fn wrap<W>(self, wrapper: W) -> Wrapped<Self, W> {
19        Wrapped {
20            inner: self,
21            wrapper,
22        }
23    }
24}
25
26impl<T> Wrap for T {}
27
28/// A basic wrapped object.
29///
30/// Consists of the inner object and the wrapper state.
31///
32/// # Implementation
33/// To implement a wrapper type, define `struct MyWrapper` and implement
34/// `impl<T: Environment> Environment for Wrapped<T, MyWrapper>` and
35/// `impl<T: EnvStructure> EnvStructure for Wrapped<T, MyWrapper>`.
36///
37#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
38pub struct Wrapped<T, W> {
39    /// Wrapped object
40    pub inner: T,
41    /// The wrapper
42    pub wrapper: W,
43}
44
45impl<T, W> Wrapped<T, W> {
46    pub const fn new(inner: T, wrapper: W) -> Self {
47        Self { inner, wrapper }
48    }
49}
50
/// Marker trait for wrappers that leave the environment structure unchanged.
///
/// The trait has no methods. Implementing it for a wrapper type `W` opts
/// `Wrapped<E, W>` into the blanket `EnvStructure` implementation, which reports
/// the spaces and discount factor of the inner `E` unmodified.
pub trait StructurePreservingWrapper {}
53
54impl<E, W> EnvStructure for Wrapped<E, W>
55where
56    E: EnvStructure,
57    W: StructurePreservingWrapper,
58{
59    type ObservationSpace = E::ObservationSpace;
60    type ActionSpace = E::ActionSpace;
61    type FeedbackSpace = E::FeedbackSpace;
62
63    #[inline]
64    fn observation_space(&self) -> Self::ObservationSpace {
65        self.inner.observation_space()
66    }
67    #[inline]
68    fn action_space(&self) -> Self::ActionSpace {
69        self.inner.action_space()
70    }
71    #[inline]
72    fn feedback_space(&self) -> Self::FeedbackSpace {
73        self.inner.feedback_space()
74    }
75    #[inline]
76    fn discount_factor(&self) -> f64 {
77        self.inner.discount_factor()
78    }
79}
80
81impl<EC, W> BuildEnv for Wrapped<EC, W>
82where
83    EC: BuildEnv,
84    W: Clone,
85    Wrapped<EC::Environment, W>: StructuredEnvironment,
86{
87    type Observation = <Self::Environment as Environment>::Observation;
88    type Action = <Self::Environment as Environment>::Action;
89    type Feedback = <Self::Environment as Environment>::Feedback;
90    type ObservationSpace = <Self::Environment as EnvStructure>::ObservationSpace;
91    type ActionSpace = <Self::Environment as EnvStructure>::ActionSpace;
92    type FeedbackSpace = <Self::Environment as EnvStructure>::FeedbackSpace;
93    type Environment = Wrapped<EC::Environment, W>;
94
95    #[inline]
96    fn build_env(&self, rng: &mut Prng) -> Result<Self::Environment, BuildEnvError> {
97        Ok(Wrapped {
98            inner: self.inner.build_env(rng)?,
99            wrapper: self.wrapper.clone(),
100        })
101    }
102}
103
104impl<ED, W> EnvDistribution for Wrapped<ED, W>
105where
106    ED: EnvDistribution,
107    W: Clone,
108    Wrapped<ED::Environment, W>: Environment,
109{
110    type State = <Self::Environment as Environment>::State;
111    type Observation = <Self::Environment as Environment>::Observation;
112    type Action = <Self::Environment as Environment>::Action;
113    type Feedback = <Self::Environment as Environment>::Feedback;
114    type Environment = Wrapped<ED::Environment, W>;
115
116    #[inline]
117    fn sample_environment(&self, rng: &mut Prng) -> Self::Environment {
118        Wrapped {
119            inner: self.inner.sample_environment(rng),
120            wrapper: self.wrapper.clone(),
121        }
122    }
123}
124
125impl<EDC, W> BuildEnvDist for Wrapped<EDC, W>
126where
127    EDC: BuildEnvDist,
128    W: Clone,
129    Wrapped<EDC::EnvDistribution, W>: StructuredEnvDist,
130{
131    type Observation = <Self::EnvDistribution as EnvDistribution>::Observation;
132    type Action = <Self::EnvDistribution as EnvDistribution>::Action;
133    type Feedback = <Self::EnvDistribution as EnvDistribution>::Feedback;
134    type ObservationSpace = <Self::EnvDistribution as EnvStructure>::ObservationSpace;
135    type ActionSpace = <Self::EnvDistribution as EnvStructure>::ActionSpace;
136    type FeedbackSpace = <Self::EnvDistribution as EnvStructure>::FeedbackSpace;
137    type EnvDistribution = Wrapped<EDC::EnvDistribution, W>;
138
139    #[inline]
140    fn build_env_dist(&self) -> Self::EnvDistribution {
141        Wrapped {
142            inner: self.inner.build_env_dist(),
143            wrapper: self.wrapper.clone(),
144        }
145    }
146}