relearn/envs/wrappers/
mod.rs1mod step_limit;
2
3pub use step_limit::{
4 LatentStepLimit, VisibleStepLimit, WithLatentStepLimit, WithVisibleStepLimit,
5};
6
7use super::{
8 BuildEnv, BuildEnvDist, BuildEnvError, EnvDistribution, EnvStructure, Environment,
9 StructuredEnvDist, StructuredEnvironment,
10};
11use crate::Prng;
12use serde::{Deserialize, Serialize};
13
14pub trait Wrap: Sized {
16 #[inline]
18 fn wrap<W>(self, wrapper: W) -> Wrapped<Self, W> {
19 Wrapped {
20 inner: self,
21 wrapper,
22 }
23 }
24}
25
26impl<T> Wrap for T {}
27
28#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
38pub struct Wrapped<T, W> {
39 pub inner: T,
41 pub wrapper: W,
43}
44
45impl<T, W> Wrapped<T, W> {
46 pub const fn new(inner: T, wrapper: W) -> Self {
47 Self { inner, wrapper }
48 }
49}
50
51pub trait StructurePreservingWrapper {}
53
54impl<E, W> EnvStructure for Wrapped<E, W>
55where
56 E: EnvStructure,
57 W: StructurePreservingWrapper,
58{
59 type ObservationSpace = E::ObservationSpace;
60 type ActionSpace = E::ActionSpace;
61 type FeedbackSpace = E::FeedbackSpace;
62
63 #[inline]
64 fn observation_space(&self) -> Self::ObservationSpace {
65 self.inner.observation_space()
66 }
67 #[inline]
68 fn action_space(&self) -> Self::ActionSpace {
69 self.inner.action_space()
70 }
71 #[inline]
72 fn feedback_space(&self) -> Self::FeedbackSpace {
73 self.inner.feedback_space()
74 }
75 #[inline]
76 fn discount_factor(&self) -> f64 {
77 self.inner.discount_factor()
78 }
79}
80
81impl<EC, W> BuildEnv for Wrapped<EC, W>
82where
83 EC: BuildEnv,
84 W: Clone,
85 Wrapped<EC::Environment, W>: StructuredEnvironment,
86{
87 type Observation = <Self::Environment as Environment>::Observation;
88 type Action = <Self::Environment as Environment>::Action;
89 type Feedback = <Self::Environment as Environment>::Feedback;
90 type ObservationSpace = <Self::Environment as EnvStructure>::ObservationSpace;
91 type ActionSpace = <Self::Environment as EnvStructure>::ActionSpace;
92 type FeedbackSpace = <Self::Environment as EnvStructure>::FeedbackSpace;
93 type Environment = Wrapped<EC::Environment, W>;
94
95 #[inline]
96 fn build_env(&self, rng: &mut Prng) -> Result<Self::Environment, BuildEnvError> {
97 Ok(Wrapped {
98 inner: self.inner.build_env(rng)?,
99 wrapper: self.wrapper.clone(),
100 })
101 }
102}
103
104impl<ED, W> EnvDistribution for Wrapped<ED, W>
105where
106 ED: EnvDistribution,
107 W: Clone,
108 Wrapped<ED::Environment, W>: Environment,
109{
110 type State = <Self::Environment as Environment>::State;
111 type Observation = <Self::Environment as Environment>::Observation;
112 type Action = <Self::Environment as Environment>::Action;
113 type Feedback = <Self::Environment as Environment>::Feedback;
114 type Environment = Wrapped<ED::Environment, W>;
115
116 #[inline]
117 fn sample_environment(&self, rng: &mut Prng) -> Self::Environment {
118 Wrapped {
119 inner: self.inner.sample_environment(rng),
120 wrapper: self.wrapper.clone(),
121 }
122 }
123}
124
125impl<EDC, W> BuildEnvDist for Wrapped<EDC, W>
126where
127 EDC: BuildEnvDist,
128 W: Clone,
129 Wrapped<EDC::EnvDistribution, W>: StructuredEnvDist,
130{
131 type Observation = <Self::EnvDistribution as EnvDistribution>::Observation;
132 type Action = <Self::EnvDistribution as EnvDistribution>::Action;
133 type Feedback = <Self::EnvDistribution as EnvDistribution>::Feedback;
134 type ObservationSpace = <Self::EnvDistribution as EnvStructure>::ObservationSpace;
135 type ActionSpace = <Self::EnvDistribution as EnvStructure>::ActionSpace;
136 type FeedbackSpace = <Self::EnvDistribution as EnvStructure>::FeedbackSpace;
137 type EnvDistribution = Wrapped<EDC::EnvDistribution, W>;
138
139 #[inline]
140 fn build_env_dist(&self) -> Self::EnvDistribution {
141 Wrapped {
142 inner: self.inner.build_env_dist(),
143 wrapper: self.wrapper.clone(),
144 }
145 }
146}