// rl_traits/environment.rs
1use std::collections::HashMap;
2
3use rand::Rng;
4
5use crate::episode::StepResult;
6
7/// The core environment trait.
8///
9/// Defines the contract that all RL environments must satisfy, regardless
10/// of whether they run headless in ember-rl or are visualised via bevy-gym.
11///
12/// # Design principles
13///
14/// - **Type-safe observation and action spaces**: `Observation` and `Action`
15///   are associated types. The compiler enforces correctness; there are no
16///   runtime Box/Discrete/Dict space objects.
17///
18/// - **Typed `Info`**: auxiliary data is `Self::Info`, not `dict[str, Any]`.
19///   If you don't need it, use `()` and get `Default` for free.
20///
21/// - **No `render()`**: visualisation is entirely bevy-gym's concern.
22///   rl-traits knows nothing about rendering.
23///
24/// - **No `close()`**: implement `Drop` if your environment holds resources.
25///
26/// - **Bevy-compatible by design**: `Send + Sync + 'static` bounds on
27///   associated types mean implementations can be used as Bevy `Component`s
28///   directly, enabling free ECS-based parallelisation in bevy-gym via
29///   `Query::par_iter_mut()`.
30///
31/// # Example
32///
33/// ```rust
34/// use rl_traits::{Environment, StepResult, EpisodeStatus};
35/// use rand::Rng;
36///
37/// struct BanditsEnv {
38///     arms: [f64; 4],
39///     rng: rand::rngs::SmallRng,
40/// }
41///
42/// impl Environment for BanditsEnv {
43///     type Observation = ();      // stateless — observation is always ()
44///     type Action = usize;        // pull arm 0..3
45///     type Info = ();
46///
47///     fn step(&mut self, action: usize) -> StepResult<(), ()> {
48///         let reward = self.rng.gen::<f64>() * self.arms[action];
49///         StepResult::new((), reward, EpisodeStatus::Continuing, ())
50///     }
51///
52///     fn reset(&mut self, _seed: Option<u64>) -> ((), ()) {
53///         ((), ())
54///     }
55///
56///     fn sample_action(&self, rng: &mut impl Rng) -> usize {
57///         rng.gen_range(0..4)
58///     }
59/// }
60/// ```
61pub trait Environment {
62    /// The observation type produced by `step()` and `reset()`.
63    ///
64    /// `Send + Sync + 'static` are required for Bevy ECS compatibility.
65    type Observation: Clone + Send + Sync + 'static;
66
67    /// The action type consumed by `step()`.
68    type Action: Clone + Send + Sync + 'static;
69
70    /// Auxiliary information returned alongside observations.
71    ///
72    /// Use `()` if you don't need it — `Default` is implemented for `()`.
73    type Info: Default + Clone + Send + Sync + 'static;
74
75    /// Advance the environment by one timestep.
76    ///
77    /// The caller is responsible for checking `StepResult::is_done()` and
78    /// calling `reset()` before the next episode.
79    fn step(&mut self, action: Self::Action) -> StepResult<Self::Observation, Self::Info>;
80
81    /// Reset the environment to an initial state, starting a new episode.
82    ///
83    /// If `seed` is `Some(u64)`, the environment should use it to seed its
84    /// internal RNG for deterministic reproduction of episodes.
85    fn reset(&mut self, seed: Option<u64>) -> (Self::Observation, Self::Info);
86
87    /// Sample a random action from this environment's action space.
88    ///
89    /// Used by random exploration agents and for initial data collection.
90    /// The `rng` is caller-supplied so exploration randomness can be seeded
91    /// and tracked independently from environment randomness.
92    fn sample_action(&self, rng: &mut impl Rng) -> Self::Action;
93
94    /// Per-episode scalar metrics reported at episode end.
95    ///
96    /// Override this to expose environment-specific statistics (e.g. collisions,
97    /// tiles explored, distance travelled). The default returns an empty map.
98    /// These are merged into training records alongside algorithm-level extras.
99    fn episode_extras(&self) -> HashMap<String, f64> {
100        HashMap::new()
101    }
102}