Skip to main content

burn_rl/environment/
base.rs

/// The outcome of taking a single step in an environment.
///
/// Derives [`Debug`], [`Clone`] and [`PartialEq`] (each bounded on `S`) so that
/// step results can be logged, stored in buffers and compared in tests.
#[derive(Debug, Clone, PartialEq)]
pub struct StepResult<S> {
    /// The state of the environment after the step.
    pub next_state: S,
    /// The reward obtained for this step.
    pub reward: f64,
    /// Whether the environment reached a terminal state.
    pub done: bool,
    /// Whether the episode was cut short because it reached its maximum length.
    pub truncated: bool,
}
12
13/// Trait to be implemented for a RL environment.
14pub trait Environment {
15    /// The type of the state.
16    type State;
17    /// The type of actions.
18    type Action;
19
20    /// The maximum number of step for one episode.
21    const MAX_STEPS: usize;
22
23    /// Returns the current state.
24    fn state(&self) -> Self::State;
25    /// Take a step in the environment given an action.
26    fn step(&mut self, action: Self::Action) -> StepResult<Self::State>;
27    /// Reset the environment to an initial state.
28    fn reset(&mut self);
29}
30
31/// Trait to define how to initialize an environment.
32/// By default, any function returning an environment implements it.
33pub trait EnvironmentInit<E: Environment>: Clone {
34    /// Initialize the environment.
35    fn init(&self) -> E;
36}
37
38impl<F, E> EnvironmentInit<E> for F
39where
40    F: Fn() -> E + Clone,
41    E: Environment,
42{
43    fn init(&self) -> E {
44        (self)()
45    }
46}