ember_rl/algorithms/dqn/agent.rs

1use std::path::Path;
2
3use burn::prelude::*;
4use burn::tensor::backend::AutodiffBackend;
5use burn::module::AutodiffModule;
6use burn::optim::{Adam, AdamConfig, GradientsParams, Optimizer};
7use burn::optim::adaptor::OptimizerAdaptor;
8use burn::nn::loss::{HuberLossConfig, Reduction};
9use burn::record::{CompactRecorder, RecorderError};
10use rand::{Rng, SeedableRng};
11use rand::rngs::SmallRng;
12use rl_traits::{Environment, Experience, Policy};
13
14use crate::encoding::{DiscreteActionMapper, ObservationEncoder};
15use super::config::DqnConfig;
16use super::network::QNetwork;
17use super::replay::CircularBuffer;
18use rl_traits::ReplayBuffer;
19
/// A DQN agent.
///
/// Implements ε-greedy action selection, experience replay, and TD learning
/// with a target network. Generic over:
///
/// - `E`: the environment type (must satisfy `rl_traits::Environment`)
/// - `Enc`: the observation encoder (converts `E::Observation` to tensors)
/// - `Act`: the action mapper (converts `E::Action` to/from integer indices)
/// - `B`: the Burn backend (e.g. `NdArray`, `Wgpu`)
/// - `Buf`: the replay buffer (defaults to `CircularBuffer` — swap for PER etc.)
///
/// # Usage
///
/// ```rust,ignore
/// let agent = DqnAgent::new(encoder, action_mapper, config, device, seed);
/// ```
///
/// Then hand it to `DqnRunner`, which drives the training loop.
pub struct DqnAgent<E, Enc, Act, B, Buf = CircularBuffer<
    <E as Environment>::Observation,
    <E as Environment>::Action,
>>
where
    E: Environment,
    B: AutodiffBackend,
    Buf: ReplayBuffer<E::Observation, E::Action>,
{
    // Network pair: the online net is trained by gradient descent; the target
    // net lives on the non-autodiff inner backend and is a periodically
    // hard-synced copy used to compute stable TD targets.
    online_net: QNetwork<B>,
    target_net: QNetwork<B::InnerBackend>,

    // Adam optimiser state for the online network.
    optimiser: OptimizerAdaptor<Adam, QNetwork<B>, B>,

    // Experience replay storage (any `ReplayBuffer` implementation).
    buffer: Buf,

    // Encoding: observation -> tensor, and action <-> integer index.
    encoder: Enc,
    action_mapper: Act,

    // Config and runtime state. `total_steps` drives both epsilon decay and
    // the periodic target-network sync.
    config: DqnConfig,
    device: B::Device,
    total_steps: usize,

    // Seeded RNG used for replay sampling in `train_step`; exploration in
    // `act_epsilon_greedy` uses a caller-supplied RNG instead.
    rng: SmallRng,

    // Zero-sized marker tying the agent to its environment type `E`,
    // which is otherwise only referenced through associated types.
    _env: std::marker::PhantomData<E>,
}
70
71impl<E, Enc, Act, B> DqnAgent<E, Enc, Act, B>
72where
73    E: Environment,
74    E::Observation: Clone + Send + Sync + 'static,
75    E::Action: Clone + Send + Sync + 'static,
76    Enc: ObservationEncoder<E::Observation, B> + ObservationEncoder<E::Observation, B::InnerBackend>,
77    Act: DiscreteActionMapper<E::Action>,
78    B: AutodiffBackend,
79{
80    /// Create a new agent using the default `CircularBuffer` replay buffer.
81    ///
82    /// Buffer capacity is taken from `config.buffer_capacity`.
83    pub fn new(encoder: Enc, action_mapper: Act, config: DqnConfig, device: B::Device, seed: u64) -> Self {
84        let buffer = CircularBuffer::new(config.buffer_capacity);
85        Self::new_with_buffer(encoder, action_mapper, config, device, seed, buffer)
86    }
87}
88
89impl<E, Enc, Act, B, Buf> DqnAgent<E, Enc, Act, B, Buf>
90where
91    E: Environment,
92    E::Observation: Clone + Send + Sync + 'static,
93    E::Action: Clone + Send + Sync + 'static,
94    Enc: ObservationEncoder<E::Observation, B> + ObservationEncoder<E::Observation, B::InnerBackend>,
95    Act: DiscreteActionMapper<E::Action>,
96    B: AutodiffBackend,
97    Buf: ReplayBuffer<E::Observation, E::Action>,
98{
99    /// Create a new agent with a custom replay buffer.
100    ///
101    /// Use this to swap in prioritised experience replay or any other
102    /// `ReplayBuffer` implementation in place of the default `CircularBuffer`.
103    pub fn new_with_buffer(
104        encoder: Enc,
105        action_mapper: Act,
106        config: DqnConfig,
107        device: B::Device,
108        seed: u64,
109        buffer: Buf,
110    ) -> Self {
111        let obs_size = <Enc as ObservationEncoder<E::Observation, B>>::obs_size(&encoder);
112        let num_actions = action_mapper.num_actions();
113
114        // Build layer sizes: obs_size -> hidden... -> num_actions
115        let mut layer_sizes = vec![obs_size];
116        layer_sizes.extend_from_slice(&config.hidden_sizes);
117        layer_sizes.push(num_actions);
118
119        let online_net = QNetwork::new(&layer_sizes, &device);
120        let target_net = QNetwork::new(&layer_sizes, &device.clone());
121
122        let optimiser = AdamConfig::new()
123            .with_epsilon(1e-8)
124            .init::<B, QNetwork<B>>();
125
126        Self {
127            online_net,
128            target_net,
129            optimiser,
130            buffer,
131            encoder,
132            action_mapper,
133            device: device.clone(),
134            total_steps: 0,
135            config,
136            rng: SmallRng::seed_from_u64(seed),
137            _env: std::marker::PhantomData,
138        }
139    }
140
141    /// Store a transition in the replay buffer and potentially run a gradient update.
142    ///
143    /// Called by the runner after every environment step. Returns `true` if
144    /// a gradient update was performed this step.
145    pub fn observe(&mut self, experience: Experience<E::Observation, E::Action>) -> bool {
146        self.buffer.push(experience);
147        self.total_steps += 1;
148
149        // Sync target network periodically
150        if self.total_steps.is_multiple_of(self.config.target_update_freq) {
151            self.sync_target();
152        }
153
154        // Don't train until warm-up is complete
155        if !self.buffer.ready_for(self.config.batch_size) {
156            return false;
157        }
158        if self.buffer.len() < self.config.min_replay_size {
159            return false;
160        }
161
162        self.train_step();
163        true
164    }
165
166    /// The current exploration probability.
167    pub fn epsilon(&self) -> f64 {
168        self.config.epsilon_at(self.total_steps)
169    }
170
171    /// Total environment steps observed so far.
172    pub fn total_steps(&self) -> usize {
173        self.total_steps
174    }
175
176    /// Select an action using ε-greedy policy.
177    pub fn act_epsilon_greedy(&self, obs: &E::Observation, rng: &mut impl Rng) -> E::Action {
178        let epsilon = self.epsilon();
179        if rng.gen::<f64>() < epsilon {
180            // Random action — sample by index and convert
181            let idx = rng.gen_range(0..self.action_mapper.num_actions());
182            self.action_mapper.index_to_action(idx)
183        } else {
184            self.act(obs)
185        }
186    }
187
188    /// Save the online network weights to a file.
189    ///
190    /// Uses Burn's `CompactRecorder` (MessagePack format). The recorder appends
191    /// its own extension to the path, so `save("run/cartpole")` produces
192    /// `run/cartpole.mpk`.
193    ///
194    /// Only the online network weights are saved — the target network,
195    /// replay buffer, and optimizer state are not included. This is sufficient
196    /// for inference. To resume training, call `load` followed by
197    /// `set_total_steps` to restore the correct epsilon.
198    pub fn save(&self, path: impl AsRef<Path>) -> Result<(), RecorderError> {
199        self.online_net
200            .clone()
201            .save_file(path.as_ref().to_path_buf(), &CompactRecorder::new())
202            .map(|_| ())
203    }
204
205    /// Load network weights from a file into this agent.
206    ///
207    /// Loads into the online network and immediately syncs the target network.
208    /// Takes `self` by value and returns the updated agent so you can chain
209    /// with the constructor:
210    ///
211    /// ```rust,ignore
212    /// let agent = DqnAgent::new(...).load("run/cartpole")?;
213    /// ```
214    pub fn load(mut self, path: impl AsRef<Path>) -> Result<Self, RecorderError> {
215        self.online_net = self
216            .online_net
217            .load_file(path.as_ref().to_path_buf(), &CompactRecorder::new(), &self.device)?;
218        self.target_net = self.online_net.valid();
219        Ok(self)
220    }
221
222    /// Convert this trained agent into an inference-only `DqnPolicy`.
223    ///
224    /// Strips all training state (optimizer, buffer, RNG) and downcasts the
225    /// network to `B::InnerBackend` (no autodiff). Use this when training is
226    /// complete and you want a lightweight policy for evaluation or deployment.
227    ///
228    /// ```rust,ignore
229    /// let policy = runner.into_agent().into_policy();
230    /// let action = policy.act(&obs);
231    /// ```
232    pub fn into_policy(self) -> super::inference::DqnPolicy<E, Enc, Act, B::InnerBackend> {
233        super::inference::DqnPolicy::from_network(
234            self.online_net.valid(),
235            self.encoder,
236            self.action_mapper,
237            self.device,
238        )
239    }
240
241    /// Override the internal step counter.
242    ///
243    /// Useful when resuming training — restores epsilon to the correct value
244    /// for the point in training where the checkpoint was saved.
245    pub fn set_total_steps(&mut self, steps: usize) {
246        self.total_steps = steps;
247    }
248
249    /// Sync target network weights from the online network.
250    ///
251    /// Hard update: copies all parameters exactly. The target network
252    /// provides stable TD targets — if it updated every step it would
253    /// chase itself and training would diverge.
254    fn sync_target(&mut self) {
255        // Burn's valid() converts an autodiff module to its inner (non-diff) counterpart.
256        self.target_net = self.online_net.valid();
257    }
258
259    /// One gradient update step.
260    fn train_step(&mut self) {
261        let batch = self.buffer.sample(self.config.batch_size, &mut self.rng);
262
263        let batch_size = batch.len();
264
265        // Encode observations and next observations
266        let obs_batch: Vec<_> = batch.iter().map(|e| &e.observation).cloned().collect();
267        let next_obs_batch: Vec<_> = batch.iter().map(|e| &e.next_observation).cloned().collect();
268
269        let obs_tensor = self.encoder.encode_batch(&obs_batch, &self.device);
270        // next_obs is encoded on the inner (non-autodiff) backend — no gradients needed
271        let next_obs_tensor = self.encoder
272            .encode_batch(&next_obs_batch, &self.device.clone());
273
274        // Action indices, rewards, bootstrap masks
275        let action_indices: Vec<usize> = batch.iter()
276            .map(|e| self.action_mapper.action_to_index(&e.action))
277            .collect();
278
279        let rewards: Vec<f32> = batch.iter()
280            .map(|e| e.reward as f32)
281            .collect();
282
283        let masks: Vec<f32> = batch.iter()
284            .map(|e| e.bootstrap_mask() as f32)
285            .collect();
286
287        let rewards_t: Tensor<B, 1> = Tensor::from_floats(rewards.as_slice(), &self.device);
288        let masks_t: Tensor<B, 1> = Tensor::from_floats(masks.as_slice(), &self.device);
289
290        // Target Q-values (no gradients — computed on the target network)
291        let next_q_values = self.target_net.forward(next_obs_tensor);
292        let max_next_q: Tensor<B::InnerBackend, 1> = next_q_values.max_dim(1).squeeze::<1>();
293        let max_next_q_autodiff: Tensor<B, 1> = Tensor::from_inner(max_next_q);
294
295        // TD target: r + γ * mask * max_a Q_target(s', a)
296        let targets = rewards_t + masks_t * max_next_q_autodiff * self.config.gamma as f32;
297
298        // Online Q-values for the actions actually taken
299        let q_values = self.online_net.forward(obs_tensor);
300        let action_indices_t = Tensor::<B, 1, Int>::from_ints(
301            action_indices.iter().map(|&i| i as i32).collect::<Vec<_>>().as_slice(),
302            &self.device,
303        );
304        // Gather Q-values at the taken actions: q_values[i, action_indices[i]]
305        let q_taken = q_values
306            .gather(1, action_indices_t.reshape([batch_size, 1]))
307            .squeeze::<1>();
308
309        // Huber loss (more robust to outlier rewards than MSE)
310        let loss = HuberLossConfig::new(1.0).init().forward(q_taken, targets.detach(), Reduction::Mean);
311
312        // Gradient update
313        let grads = loss.backward();
314        let grads = GradientsParams::from_grads(grads, &self.online_net);
315        self.online_net = self.optimiser.step(
316            self.config.learning_rate,
317            self.online_net.clone(),
318            grads,
319        );
320    }
321}
322
323impl<E, Enc, Act, B, Buf> Policy<E::Observation, E::Action> for DqnAgent<E, Enc, Act, B, Buf>
324where
325    E: Environment,
326    Enc: ObservationEncoder<E::Observation, B>,
327    Act: DiscreteActionMapper<E::Action>,
328    B: AutodiffBackend,
329    Buf: ReplayBuffer<E::Observation, E::Action>,
330{
331    /// Greedy action selection (no exploration).
332    ///
333    /// Use this for evaluation. For training, use `act_epsilon_greedy`.
334    fn act(&self, obs: &E::Observation) -> E::Action {
335        let obs_tensor = self.encoder.encode(obs, &self.device).unsqueeze_dim(0);
336        let q_values = self.online_net.forward(obs_tensor);
337        let best_action_idx: usize = q_values
338            .argmax(1)
339            .into_data()
340            .to_vec::<i64>()
341            .unwrap()[0] as usize;
342        self.action_mapper.index_to_action(best_action_idx)
343    }
344}