burn_rl/environment/base.rs
/// The outcome of advancing an environment by one step.
///
/// This is the value returned by `Environment::step`.
#[derive(Debug, Clone, PartialEq)]
pub struct StepResult<S> {
    /// The environment's state after the action was applied.
    pub next_state: S,
    /// The reward obtained for this step.
    pub reward: f64,
    /// Whether the environment reached a terminal state.
    pub done: bool,
    /// Whether the environment reached its maximum episode length.
    pub truncated: bool,
}
12
13/// Trait to be implemented for a RL environment.
14pub trait Environment {
15 /// The type of the state.
16 type State;
17 /// The type of actions.
18 type Action;
19
20 /// The maximum number of step for one episode.
21 const MAX_STEPS: usize;
22
23 /// Returns the current state.
24 fn state(&self) -> Self::State;
25 /// Take a step in the environment given an action.
26 fn step(&mut self, action: Self::Action) -> StepResult<Self::State>;
27 /// Reset the environment to an initial state.
28 fn reset(&mut self);
29}
30
31/// Trait to define how to initialize an environment.
32/// By default, any function returning an environment implements it.
33pub trait EnvironmentInit<E: Environment>: Clone {
34 /// Initialize the environment.
35 fn init(&self) -> E;
36}
37
38impl<F, E> EnvironmentInit<E> for F
39where
40 F: Fn() -> E + Clone,
41 E: Environment,
42{
43 fn init(&self) -> E {
44 (self)()
45 }
46}