// ember_rl/encoding.rs

1use burn::prelude::*;
2
3/// Converts environment observations into Burn tensors.
4///
5/// This is the primary bridge between rl-traits' generic world and Burn's
6/// tensor world. Users implement this for their specific observation type --
7/// for CartPole it's 4 floats stacked into a 1D tensor; for Atari it would
8/// be image preprocessing.
9///
10/// # Why this is separate from `Environment`
11///
12/// `rl-traits` deliberately knows nothing about tensors or ML backends.
13/// This trait lives in ember-rl as the adapter layer. A user can implement
14/// the same `Environment` for both headless training (with this encoder)
15/// and Bevy visualisation (with no encoder at all).
16///
17/// # Batching
18///
19/// `encode_batch` has a default implementation that calls `encode` in a loop,
20/// which is correct but slow. Override it with a vectorised implementation
21/// if your observation type allows it -- which for simple flat observations
22/// (like CartPole) it always does.
23pub trait ObservationEncoder<O, B: Backend> {
24    /// The number of features in the encoded observation vector.
25    ///
26    /// Used to determine the Q-network's input layer size automatically.
27    fn obs_size(&self) -> usize;
28
29    /// Encode a single observation into a 1D tensor of shape `[obs_size]`.
30    fn encode(&self, obs: &O, device: &B::Device) -> Tensor<B, 1>;
31
32    /// Encode a batch of observations into a 2D tensor of shape `[batch, obs_size]`.
33    ///
34    /// The default implementation calls `encode` in a loop and stacks results.
35    /// Override with a vectorised implementation for performance.
36    fn encode_batch(&self, obs: &[O], device: &B::Device) -> Tensor<B, 2> {
37        let encoded: Vec<Tensor<B, 1>> = obs.iter()
38            .map(|o| self.encode(o, device))
39            .collect();
40        Tensor::stack(encoded, 0)
41    }
42}
43
/// Maps between environment action types and integer indices.
///
/// DQN is a discrete-action algorithm. The Q-network outputs one Q-value
/// per action, indexed 0..N. This trait bridges that integer world and the
/// environment's `Action` type, which may be an enum or something richer.
///
/// Implementations should make `action_to_index` and `index_to_action`
/// mutual inverses over the range `0..num_actions()`.
///
/// # Example
///
/// ```rust
/// use ember_rl::encoding::DiscreteActionMapper;
/// // CartPole: action is just usize (push left = 0, push right = 1)
/// struct CartPoleActions;
/// impl DiscreteActionMapper<usize> for CartPoleActions {
///     fn num_actions(&self) -> usize { 2 }
///     fn action_to_index(&self, action: &usize) -> usize { *action }
///     fn index_to_action(&self, index: usize) -> usize { index }
/// }
/// ```
pub trait DiscreteActionMapper<A> {
    /// Total number of discrete actions available.
    ///
    /// Determines the Q-network's output layer size.
    fn num_actions(&self) -> usize;

    /// Convert an action to its integer index.
    ///
    /// Used when storing experience -- we record the index, not the action.
    /// The returned index should lie in `0..num_actions()`.
    fn action_to_index(&self, action: &A) -> usize;

    /// Convert an integer index to an action.
    ///
    /// Used when the Q-network selects an action -- it returns an argmax
    /// index (in `0..num_actions()`) that we convert back to the
    /// environment's action type.
    fn index_to_action(&self, index: usize) -> A;
}
79
80/// A trivial encoder for environments whose observations are already `Vec<f32>`.
81///
82/// Useful for getting something running quickly without boilerplate.
83#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
84pub struct VecEncoder {
85    size: usize,
86}
87
88impl VecEncoder {
89    pub fn new(size: usize) -> Self {
90        Self { size }
91    }
92}
93
94impl<B: Backend> ObservationEncoder<Vec<f32>, B> for VecEncoder {
95    fn obs_size(&self) -> usize {
96        self.size
97    }
98
99    fn encode(&self, obs: &Vec<f32>, device: &B::Device) -> Tensor<B, 1> {
100        Tensor::from_floats(obs.as_slice(), device)
101    }
102
103    fn encode_batch(&self, obs: &[Vec<f32>], device: &B::Device) -> Tensor<B, 2> {
104        let flat: Vec<f32> = obs.iter().flat_map(|o| o.iter().copied()).collect();
105        let batch = obs.len();
106        Tensor::<B, 1>::from_floats(flat.as_slice(), device)
107            .reshape([batch, self.size])
108    }
109}
110
111/// A trivial action mapper for environments whose actions are plain `usize`.
112#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
113pub struct UsizeActionMapper {
114    num_actions: usize,
115}
116
117impl UsizeActionMapper {
118    pub fn new(num_actions: usize) -> Self {
119        Self { num_actions }
120    }
121}
122
123impl DiscreteActionMapper<usize> for UsizeActionMapper {
124    fn num_actions(&self) -> usize {
125        self.num_actions
126    }
127
128    fn action_to_index(&self, action: &usize) -> usize {
129        *action
130    }
131
132    fn index_to_action(&self, index: usize) -> usize {
133        index
134    }
135}