//! ember_rl/encoding.rs
use burn::prelude::*;

3/// Converts environment observations into Burn tensors.
4///
5/// This is the primary bridge between rl-traits' generic world and Burn's
6/// tensor world. Users implement this for their specific observation type —
7/// for CartPole it's 4 floats stacked into a 1D tensor; for Atari it would
8/// be image preprocessing.
9///
10/// # Why this is separate from `Environment`
11///
12/// `rl-traits` deliberately knows nothing about tensors or ML backends.
13/// This trait lives in ember-rl as the adapter layer. A user can implement
14/// the same `Environment` for both headless training (with this encoder)
15/// and Bevy visualisation (with no encoder at all).
16///
17/// # Batching
18///
19/// `encode_batch` has a default implementation that calls `encode` in a loop,
20/// which is correct but slow. Override it with a vectorised implementation
21/// if your observation type allows it — which for simple flat observations
22/// (like CartPole) it always does.
23pub trait ObservationEncoder<O, B: Backend> {
24 /// The number of features in the encoded observation vector.
25 ///
26 /// Used to determine the Q-network's input layer size automatically.
27 fn obs_size(&self) -> usize;
28
29 /// Encode a single observation into a 1D tensor of shape `[obs_size]`.
30 fn encode(&self, obs: &O, device: &B::Device) -> Tensor<B, 1>;
31
32 /// Encode a batch of observations into a 2D tensor of shape `[batch, obs_size]`.
33 ///
34 /// The default implementation calls `encode` in a loop and stacks results.
35 /// Override with a vectorised implementation for performance.
36 fn encode_batch(&self, obs: &[O], device: &B::Device) -> Tensor<B, 2> {
37 let encoded: Vec<Tensor<B, 1>> = obs.iter()
38 .map(|o| self.encode(o, device))
39 .collect();
40 Tensor::stack(encoded, 0)
41 }
42}
43
/// Maps between environment action types and integer indices.
///
/// DQN is a discrete-action algorithm. The Q-network outputs one Q-value
/// per action, indexed 0..N. This trait bridges that integer world and the
/// environment's `Action` type, which may be an enum or something richer.
///
/// Implementations should make `action_to_index` and `index_to_action`
/// mutual inverses over `0..num_actions()`.
///
/// # Example
///
/// ```rust
/// use ember_rl::encoding::DiscreteActionMapper;
/// // CartPole: action is just usize (push left = 0, push right = 1)
/// struct CartPoleActions;
/// impl DiscreteActionMapper<usize> for CartPoleActions {
///     fn num_actions(&self) -> usize { 2 }
///     fn action_to_index(&self, action: &usize) -> usize { *action }
///     fn index_to_action(&self, index: usize) -> usize { index }
/// }
/// ```
pub trait DiscreteActionMapper<A> {
    /// Total number of discrete actions available.
    ///
    /// Determines the Q-network's output layer size.
    fn num_actions(&self) -> usize;

    /// Convert an action to its integer index in `0..num_actions()`.
    ///
    /// Used when storing experience — we record the index, not the action.
    fn action_to_index(&self, action: &A) -> usize;

    /// Convert an integer index back to an action.
    ///
    /// Used when the Q-network selects an action — it returns an argmax
    /// index that we convert back to the environment's action type.
    fn index_to_action(&self, index: usize) -> A;
}
79
80/// A trivial encoder for environments whose observations are already `Vec<f32>`.
81///
82/// Useful for getting something running quickly without boilerplate.
83#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
84pub struct VecEncoder {
85 size: usize,
86}
87
88impl VecEncoder {
89 pub fn new(size: usize) -> Self {
90 Self { size }
91 }
92}
93
94impl<B: Backend> ObservationEncoder<Vec<f32>, B> for VecEncoder {
95 fn obs_size(&self) -> usize {
96 self.size
97 }
98
99 fn encode(&self, obs: &Vec<f32>, device: &B::Device) -> Tensor<B, 1> {
100 Tensor::from_floats(obs.as_slice(), device)
101 }
102
103 fn encode_batch(&self, obs: &[Vec<f32>], device: &B::Device) -> Tensor<B, 2> {
104 let flat: Vec<f32> = obs.iter().flat_map(|o| o.iter().copied()).collect();
105 let batch = obs.len();
106 Tensor::<B, 1>::from_floats(flat.as_slice(), device)
107 .reshape([batch, self.size])
108 }
109}
110
111/// A trivial action mapper for environments whose actions are plain `usize`.
112#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
113pub struct UsizeActionMapper {
114 num_actions: usize,
115}
116
117impl UsizeActionMapper {
118 pub fn new(num_actions: usize) -> Self {
119 Self { num_actions }
120 }
121}
122
123impl DiscreteActionMapper<usize> for UsizeActionMapper {
124 fn num_actions(&self) -> usize {
125 self.num_actions
126 }
127
128 fn action_to_index(&self, action: &usize) -> usize {
129 *action
130 }
131
132 fn index_to_action(&self, index: usize) -> usize {
133 index
134 }
135}