1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
mod step_limit;

pub use step_limit::{
    LatentStepLimit, VisibleStepLimit, WithLatentStepLimit, WithVisibleStepLimit,
};

use super::{
    BuildEnv, BuildEnvDist, BuildEnvError, EnvDistribution, EnvStructure, Environment,
    StructuredEnvDist, StructuredEnvironment,
};
use crate::Prng;
use serde::{Deserialize, Serialize};

/// Extension trait that lets any sized value be wrapped via `.wrap(wrapper)`.
///
/// A blanket implementation below covers every sized type, so the method is
/// available on all such values whenever this trait is in scope.
pub trait Wrap: Sized {
    /// Consume `self` and pair it with `wrapper`.
    ///
    /// The result is a [`Wrapped`] whose `inner` field is `self` and whose
    /// `wrapper` field is `wrapper`.
    #[inline]
    fn wrap<W>(self, wrapper: W) -> Wrapped<Self, W> {
        let inner = self;
        Wrapped { inner, wrapper }
    }
}

// Every sized type gets `wrap` for free; the trait has no required items.
impl<T> Wrap for T {}

/// An inner object paired with wrapper state.
///
/// `inner` holds the object being wrapped and `wrapper` holds whatever state
/// the wrapper itself carries.
///
/// # Implementation
/// To define a new wrapper, declare `struct MyWrapper` and then:
///
/// * implement `impl<T: Environment> Environment for Wrapped<T, MyWrapper>`, and
/// * either implement `impl<T: EnvStructure> EnvStructure for Wrapped<T, MyWrapper>`
///   yourself, or, when the wrapper leaves the observation/action/feedback spaces
///   and the discount factor unchanged, implement [`StructurePreservingWrapper`]
///   for `MyWrapper` so that those are forwarded to `inner` automatically.
///   Doing both would produce conflicting `EnvStructure` implementations.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct Wrapped<T, W> {
    /// The object being wrapped.
    pub inner: T,
    /// State belonging to the wrapper itself.
    pub wrapper: W,
}

impl<T, W> Wrapped<T, W> {
    pub const fn new(inner: T, wrapper: W) -> Self {
        Self { inner, wrapper }
    }
}

/// Marker for wrappers that leave the wrapped environment's structure untouched.
///
/// Implementing this for a wrapper type `W` gives `Wrapped<E, W>` an
/// `EnvStructure` implementation (for any `E: EnvStructure`) that forwards the
/// observation, action and feedback spaces and the discount factor to `E`.
pub trait StructurePreservingWrapper {}

/// Structure-preserving wrappers expose exactly the structure of the inner
/// environment: every space and the discount factor are read from `inner`.
impl<E, W> EnvStructure for Wrapped<E, W>
where
    E: EnvStructure,
    W: StructurePreservingWrapper,
{
    type ObservationSpace = <E as EnvStructure>::ObservationSpace;
    type ActionSpace = <E as EnvStructure>::ActionSpace;
    type FeedbackSpace = <E as EnvStructure>::FeedbackSpace;

    #[inline]
    fn observation_space(&self) -> Self::ObservationSpace {
        <E as EnvStructure>::observation_space(&self.inner)
    }

    #[inline]
    fn action_space(&self) -> Self::ActionSpace {
        <E as EnvStructure>::action_space(&self.inner)
    }

    #[inline]
    fn feedback_space(&self) -> Self::FeedbackSpace {
        <E as EnvStructure>::feedback_space(&self.inner)
    }

    #[inline]
    fn discount_factor(&self) -> f64 {
        <E as EnvStructure>::discount_factor(&self.inner)
    }
}

/// A wrapped environment builder builds the inner environment and attaches a
/// clone of the wrapper to it.
impl<EC, W> BuildEnv for Wrapped<EC, W>
where
    EC: BuildEnv,
    W: Clone,
    Wrapped<EC::Environment, W>: StructuredEnvironment,
{
    type Environment = Wrapped<EC::Environment, W>;
    type Observation = <Self::Environment as Environment>::Observation;
    type Action = <Self::Environment as Environment>::Action;
    type Feedback = <Self::Environment as Environment>::Feedback;
    type ObservationSpace = <Self::Environment as EnvStructure>::ObservationSpace;
    type ActionSpace = <Self::Environment as EnvStructure>::ActionSpace;
    type FeedbackSpace = <Self::Environment as EnvStructure>::FeedbackSpace;

    #[inline]
    fn build_env(&self, rng: &mut Prng) -> Result<Self::Environment, BuildEnvError> {
        // An error from the inner builder is passed through as-is; the wrapper
        // is cloned only once the inner environment was built successfully.
        self.inner.build_env(rng).map(|built| Wrapped {
            inner: built,
            wrapper: self.wrapper.clone(),
        })
    }
}

impl<ED, W> EnvDistribution for Wrapped<ED, W>
where
    ED: EnvDistribution,
    W: Clone,
    Wrapped<ED::Environment, W>: Environment,
{
    type State = <Self::Environment as Environment>::State;
    type Observation = <Self::Environment as Environment>::Observation;
    type Action = <Self::Environment as Environment>::Action;
    type Feedback = <Self::Environment as Environment>::Feedback;
    type Environment = Wrapped<ED::Environment, W>;

    #[inline]
    fn sample_environment(&self, rng: &mut Prng) -> Self::Environment {
        Wrapped {
            inner: self.inner.sample_environment(rng),
            wrapper: self.wrapper.clone(),
        }
    }
}

/// A wrapped distribution builder builds the inner distribution and attaches a
/// clone of the wrapper to it.
impl<EDC, W> BuildEnvDist for Wrapped<EDC, W>
where
    EDC: BuildEnvDist,
    W: Clone,
    Wrapped<EDC::EnvDistribution, W>: StructuredEnvDist,
{
    type EnvDistribution = Wrapped<EDC::EnvDistribution, W>;
    type Observation = <Self::EnvDistribution as EnvDistribution>::Observation;
    type Action = <Self::EnvDistribution as EnvDistribution>::Action;
    type Feedback = <Self::EnvDistribution as EnvDistribution>::Feedback;
    type ObservationSpace = <Self::EnvDistribution as EnvStructure>::ObservationSpace;
    type ActionSpace = <Self::EnvDistribution as EnvStructure>::ActionSpace;
    type FeedbackSpace = <Self::EnvDistribution as EnvStructure>::FeedbackSpace;

    #[inline]
    fn build_env_dist(&self) -> Self::EnvDistribution {
        let dist = self.inner.build_env_dist();
        let wrapper = self.wrapper.clone();
        Wrapped {
            inner: dist,
            wrapper,
        }
    }
}