1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
mod step_limit;
pub use step_limit::{
LatentStepLimit, VisibleStepLimit, WithLatentStepLimit, WithVisibleStepLimit,
};
use super::{
BuildEnv, BuildEnvDist, BuildEnvError, EnvDistribution, EnvStructure, Environment,
StructuredEnvDist, StructuredEnvironment,
};
use crate::Prng;
use serde::{Deserialize, Serialize};
/// Extension trait that adds a by-value `wrap` method.
///
/// `Sized` is required because `wrap` consumes `self` and stores it inside the returned
/// `Wrapped` value.
pub trait Wrap: Sized {
    /// Consumes `self` and pairs it with `wrapper`.
    ///
    /// The receiver becomes the `inner` field and `wrapper` becomes the `wrapper` field of the
    /// returned `Wrapped`.
    #[inline]
    fn wrap<W>(self, wrapper: W) -> Wrapped<Self, W> {
        let inner = self;
        Wrapped { inner, wrapper }
    }
}
// Blanket implementation: every sized type picks up the default `wrap` method of `Wrap`.
impl<T> Wrap for T {}
/// An inner value paired with a wrapper value.
///
/// Both fields are public, so a `Wrapped` can be built with a struct literal as well as through
/// `Wrapped::new` or `Wrap::wrap`.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, Hash)]
#[derive(Serialize, Deserialize)]
pub struct Wrapped<T, W> {
    /// The value being wrapped.
    pub inner: T,
    /// The wrapper stored alongside `inner`.
    pub wrapper: W,
}
impl<T, W> Wrapped<T, W> {
    /// Creates a `Wrapped` from its two parts.
    ///
    /// This is a `const fn`, so it can be used to initialize constants and statics.
    pub const fn new(inner: T, wrapper: W) -> Self {
        Self { wrapper, inner }
    }
}
/// Marker trait (no methods) that opts a wrapper type into the forwarding `EnvStructure`
/// implementation for `Wrapped`.
///
/// When `W` implements this trait and `E: EnvStructure`, `Wrapped<E, W>` reports exactly the
/// spaces and discount factor of the inner `E`.
pub trait StructurePreservingWrapper {}
/// Forwards the environment structure of the inner value.
///
/// Only available when the wrapper is marked `StructurePreservingWrapper`; every query is
/// answered by the inner `E` and the wrapper itself is never consulted.
impl<E, W> EnvStructure for Wrapped<E, W>
where
    E: EnvStructure,
    W: StructurePreservingWrapper,
{
    type ObservationSpace = E::ObservationSpace;
    type ActionSpace = E::ActionSpace;
    type FeedbackSpace = E::FeedbackSpace;

    /// Returns the observation space of the inner value.
    #[inline]
    fn observation_space(&self) -> Self::ObservationSpace {
        E::observation_space(&self.inner)
    }

    /// Returns the action space of the inner value.
    #[inline]
    fn action_space(&self) -> Self::ActionSpace {
        E::action_space(&self.inner)
    }

    /// Returns the feedback space of the inner value.
    #[inline]
    fn feedback_space(&self) -> Self::FeedbackSpace {
        E::feedback_space(&self.inner)
    }

    /// Returns the discount factor of the inner value.
    #[inline]
    fn discount_factor(&self) -> f64 {
        E::discount_factor(&self.inner)
    }
}
/// Builds wrapped environments from a wrapped environment configuration.
///
/// The inner configuration builds the environment and the wrapper value is cloned next to it,
/// so `self` is left untouched and can be used to build again.
impl<EC, W> BuildEnv for Wrapped<EC, W>
where
    EC: BuildEnv,
    W: Clone,
    Wrapped<EC::Environment, W>: StructuredEnvironment,
{
    type Environment = Wrapped<EC::Environment, W>;

    // The remaining associated types are read off the wrapped environment type rather than
    // off `EC`, so they reflect whatever the wrapped environment's own impls declare.
    type Observation = <Self::Environment as Environment>::Observation;
    type Action = <Self::Environment as Environment>::Action;
    type Feedback = <Self::Environment as Environment>::Feedback;
    type ObservationSpace = <Self::Environment as EnvStructure>::ObservationSpace;
    type ActionSpace = <Self::Environment as EnvStructure>::ActionSpace;
    type FeedbackSpace = <Self::Environment as EnvStructure>::FeedbackSpace;

    /// Builds the inner environment with `rng` and pairs it with a clone of the wrapper.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by the inner `build_env`; the wrapper is not cloned in
    /// that case.
    #[inline]
    fn build_env(&self, rng: &mut Prng) -> Result<Self::Environment, BuildEnvError> {
        let inner = self.inner.build_env(rng)?;
        let wrapper = self.wrapper.clone();
        Ok(Wrapped { inner, wrapper })
    }
}
/// Samples wrapped environments from a wrapped environment distribution.
///
/// Each sample is drawn from the inner distribution and paired with a clone of the wrapper.
impl<ED, W> EnvDistribution for Wrapped<ED, W>
where
    ED: EnvDistribution,
    W: Clone,
    Wrapped<ED::Environment, W>: Environment,
{
    type Environment = Wrapped<ED::Environment, W>;

    // The remaining associated types are read off the wrapped environment type.
    type State = <Self::Environment as Environment>::State;
    type Observation = <Self::Environment as Environment>::Observation;
    type Action = <Self::Environment as Environment>::Action;
    type Feedback = <Self::Environment as Environment>::Feedback;

    /// Samples an environment from the inner distribution using `rng`, then pairs it with a
    /// clone of the wrapper.
    #[inline]
    fn sample_environment(&self, rng: &mut Prng) -> Self::Environment {
        let inner = self.inner.sample_environment(rng);
        let wrapper = self.wrapper.clone();
        Wrapped { inner, wrapper }
    }
}
/// Builds wrapped environment distributions from a wrapped distribution configuration.
///
/// The inner configuration builds the distribution and the wrapper value is cloned next to it.
impl<EDC, W> BuildEnvDist for Wrapped<EDC, W>
where
    EDC: BuildEnvDist,
    W: Clone,
    Wrapped<EDC::EnvDistribution, W>: StructuredEnvDist,
{
    type EnvDistribution = Wrapped<EDC::EnvDistribution, W>;

    // The remaining associated types are read off the wrapped distribution type.
    type Observation = <Self::EnvDistribution as EnvDistribution>::Observation;
    type Action = <Self::EnvDistribution as EnvDistribution>::Action;
    type Feedback = <Self::EnvDistribution as EnvDistribution>::Feedback;
    type ObservationSpace = <Self::EnvDistribution as EnvStructure>::ObservationSpace;
    type ActionSpace = <Self::EnvDistribution as EnvStructure>::ActionSpace;
    type FeedbackSpace = <Self::EnvDistribution as EnvStructure>::FeedbackSpace;

    /// Builds the inner distribution and pairs it with a clone of the wrapper.
    #[inline]
    fn build_env_dist(&self) -> Self::EnvDistribution {
        let inner = self.inner.build_env_dist();
        let wrapper = self.wrapper.clone();
        Wrapped { inner, wrapper }
    }
}