
ember_rl/algorithms/dqn/agent.rs

use std::collections::HashMap;
use std::path::Path;

use burn::prelude::*;
use burn::tensor::backend::AutodiffBackend;
use burn::module::AutodiffModule;
use burn::optim::{Adam, AdamConfig, GradientsParams, Optimizer};
use burn::optim::adaptor::OptimizerAdaptor;
use burn::nn::loss::{HuberLossConfig, Reduction};
use burn::record::CompactRecorder;
use rand::{Rng, SeedableRng};
use rand::rngs::SmallRng;
use rl_traits::{Environment, Experience, Policy, ReplayBuffer};

use crate::encoding::{DiscreteActionMapper, ObservationEncoder};
use crate::stats::{Aggregator, Max, Mean, Std};
use crate::traits::{ActMode, Checkpointable, LearningAgent};
use super::config::DqnConfig;
use super::network::QNetwork;
use super::replay::CircularBuffer;

/// A DQN agent.
///
/// Implements ε-greedy action selection, experience replay, and TD learning
/// with a target network. Generic over:
///
/// - `E`: the environment type (must implement `rl_traits::Environment`)
/// - `Enc`: the observation encoder (converts `E::Observation` to tensors)
/// - `Act`: the action mapper (converts `E::Action` to/from integer indices)
/// - `B`: the Burn autodiff backend (e.g. `Autodiff<NdArray>`, `Autodiff<Wgpu>`)
/// - `Buf`: the replay buffer (defaults to `CircularBuffer` — swap for PER etc.)
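///
/// # Examples
///
/// A construction sketch (not compiled by rustdoc): `GridEnv`, `GridEncoder`, and
/// `GridActions` are hypothetical placeholders for your own environment, encoder,
/// and action mapper, and `DqnConfig::default()` is assumed only for brevity.
///
/// ```ignore
/// use burn::backend::{Autodiff, NdArray};
///
/// type B = Autodiff<NdArray>;
///
/// let device = Default::default();
/// let mut agent: DqnAgent<GridEnv, GridEncoder, GridActions, B> = DqnAgent::new(
///     GridEncoder::new(),
///     GridActions::new(),
///     DqnConfig::default(),
///     device,
///     42, // RNG seed
/// );
///
/// // During training: ε-greedy selection, then feed the transition back in.
/// // let action = agent.act(&obs, ActMode::Explore);
/// // agent.observe(experience);
/// ```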
pub struct DqnAgent<E, Enc, Act, B, Buf = CircularBuffer<
    <E as Environment>::Observation,
    <E as Environment>::Action,
>>
where
    E: Environment,
    B: AutodiffBackend,
    Buf: ReplayBuffer<E::Observation, E::Action>,
{
    // Network pair
    online_net: QNetwork<B>,
    target_net: QNetwork<B::InnerBackend>,

    // Optimiser
    optimiser: OptimizerAdaptor<Adam, QNetwork<B>, B>,

    // Experience replay
    buffer: Buf,

    // Encoding
    encoder: Enc,
    action_mapper: Act,

    // Config and runtime state
    config: DqnConfig,
    device: B::Device,
    total_steps: usize,

    // Two decoupled RNGs — sharing one causes correlated sampling/exploration.
    explore_rng: SmallRng,  // drives ε-greedy action selection
    sample_rng: SmallRng,   // drives replay buffer sampling

    // Per-episode loss aggregators (reset each episode via on_episode_start)
    ep_loss_mean: Mean,
    ep_loss_std: Std,
    ep_loss_max: Max,

    _env: std::marker::PhantomData<E>,
}

impl<E, Enc, Act, B> DqnAgent<E, Enc, Act, B>
where
    E: Environment,
    E::Observation: Clone + Send + Sync + 'static,
    E::Action: Clone + Send + Sync + 'static,
    Enc: ObservationEncoder<E::Observation, B> + ObservationEncoder<E::Observation, B::InnerBackend>,
    Act: DiscreteActionMapper<E::Action>,
    B: AutodiffBackend,
{
    /// Create a new agent using the default `CircularBuffer` replay buffer.
    pub fn new(encoder: Enc, action_mapper: Act, config: DqnConfig, device: B::Device, seed: u64) -> Self {
        let buffer = CircularBuffer::new(config.buffer_capacity);
        Self::new_with_buffer(encoder, action_mapper, config, device, seed, buffer)
    }
}

impl<E, Enc, Act, B, Buf> DqnAgent<E, Enc, Act, B, Buf>
where
    E: Environment,
    E::Observation: Clone + Send + Sync + 'static,
    E::Action: Clone + Send + Sync + 'static,
    Enc: ObservationEncoder<E::Observation, B> + ObservationEncoder<E::Observation, B::InnerBackend>,
    Act: DiscreteActionMapper<E::Action>,
    B: AutodiffBackend,
    Buf: ReplayBuffer<E::Observation, E::Action>,
{
    /// Create a new agent with a custom replay buffer.
    pub fn new_with_buffer(
        encoder: Enc,
        action_mapper: Act,
        config: DqnConfig,
        device: B::Device,
        seed: u64,
        buffer: Buf,
    ) -> Self {
        let obs_size = <Enc as ObservationEncoder<E::Observation, B>>::obs_size(&encoder);
        let num_actions = action_mapper.num_actions();

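        // MLP topology: encoded observation size in, the configured hidden layers,
        // then one output Q-value per discrete action.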
        let mut layer_sizes = vec![obs_size];
        layer_sizes.extend_from_slice(&config.hidden_sizes);
        layer_sizes.push(num_actions);

        let online_net = QNetwork::new(&layer_sizes, &device);
        // Start the target network as an exact copy of the online network so the
        // first TD targets are computed against the same initial weights.
        let target_net = online_net.valid();

        let optimiser = AdamConfig::new()
            .with_epsilon(1e-8)
            .init::<B, QNetwork<B>>();

        Self {
            online_net,
            target_net,
            optimiser,
            buffer,
            encoder,
            action_mapper,
            device: device.clone(),
            total_steps: 0,
            config,
            // Offset the two seeds so they produce independent streams.
            explore_rng: SmallRng::seed_from_u64(seed),
            sample_rng: SmallRng::seed_from_u64(seed.wrapping_add(0x9e37_79b9_7f4a_7c15)),
            ep_loss_mean: Mean::default(),
            ep_loss_std: Std::default(),
            ep_loss_max: Max::default(),
            _env: std::marker::PhantomData,
        }
    }

    /// The current exploration probability.
    pub fn epsilon(&self) -> f64 {
        self.config.epsilon_at(self.total_steps)
    }

    /// Override the internal step counter (use when resuming training).
    pub fn set_total_steps(&mut self, steps: usize) {
        self.total_steps = steps;
    }

    /// Convert this trained agent into an inference-only `DqnPolicy`.
    pub fn into_policy(self) -> super::inference::DqnPolicy<E, Enc, Act, B::InnerBackend> {
        super::inference::DqnPolicy::from_network(
            self.online_net.valid(),
            self.encoder,
            self.action_mapper,
            self.device,
        )
    }

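    /// Copy the current online-network weights into the target network.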
    fn sync_target(&mut self) {
        self.target_net = self.online_net.valid();
    }

    /// One gradient update step. Returns the scalar loss value.
    fn train_step(&mut self) -> f64 {
        let batch = self.buffer.sample(self.config.batch_size, &mut self.sample_rng);
        let batch_size = batch.len();

        let obs_batch: Vec<_> = batch.iter().map(|e| &e.observation).cloned().collect();
        let next_obs_batch: Vec<_> = batch.iter().map(|e| &e.next_observation).cloned().collect();

        let obs_tensor = self.encoder.encode_batch(&obs_batch, &self.device);
        let next_obs_tensor = self.encoder.encode_batch(&next_obs_batch, &self.device);

        let action_indices: Vec<usize> = batch.iter()
            .map(|e| self.action_mapper.action_to_index(&e.action))
            .collect();
        let rewards: Vec<f32> = batch.iter().map(|e| e.reward as f32).collect();
        let masks: Vec<f32> = batch.iter().map(|e| e.bootstrap_mask() as f32).collect();

        let rewards_t: Tensor<B, 1> = Tensor::from_floats(rewards.as_slice(), &self.device);
        let masks_t: Tensor<B, 1> = Tensor::from_floats(masks.as_slice(), &self.device);

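        // TD target for each transition i:
        //   y_i = r_i + γ · m_i · max_a' Q_target(s'_i, a')
        // where m_i is the bootstrap mask (zero when no future value should be
        // bootstrapped, i.e. at episode termination).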
        let next_q_values = self.target_net.forward(next_obs_tensor);
        let max_next_q: Tensor<B::InnerBackend, 1> = next_q_values.max_dim(1).squeeze::<1>();
        let max_next_q_autodiff: Tensor<B, 1> = Tensor::from_inner(max_next_q);

        let targets = rewards_t + masks_t * max_next_q_autodiff * self.config.gamma as f32;

        let q_values = self.online_net.forward(obs_tensor);
        let action_indices_t = Tensor::<B, 1, Int>::from_ints(
            action_indices.iter().map(|&i| i as i32).collect::<Vec<_>>().as_slice(),
            &self.device,
        );
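        // Select Q(s_i, a_i): gather the predicted Q-value of the action actually taken.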
        let q_taken = q_values
            .gather(1, action_indices_t.reshape([batch_size, 1]))
            .squeeze::<1>();

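        // Huber (smooth L1) loss between predictions and targets; the targets are
        // detached so gradients flow only through the online network.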
        let loss = HuberLossConfig::new(1.0)
            .init()
            .forward(q_taken, targets.detach(), Reduction::Mean);

        let loss_val = loss.clone().into_scalar().elem::<f64>();

        let grads = loss.backward();
        let grads = GradientsParams::from_grads(grads, &self.online_net);
        self.online_net = self.optimiser.step(
            self.config.learning_rate,
            self.online_net.clone(),
            grads,
        );

        loss_val
    }
}

// ── Checkpointable ────────────────────────────────────────────────────────────

impl<E, Enc, Act, B, Buf> Checkpointable for DqnAgent<E, Enc, Act, B, Buf>
where
    E: Environment,
    E::Observation: Clone + Send + Sync + 'static,
    E::Action: Clone + Send + Sync + 'static,
    Enc: ObservationEncoder<E::Observation, B> + ObservationEncoder<E::Observation, B::InnerBackend>,
    Act: DiscreteActionMapper<E::Action>,
    B: AutodiffBackend,
    Buf: ReplayBuffer<E::Observation, E::Action>,
{
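    // Only the online network's weights are persisted; optimiser state and the
    // replay buffer are not part of the checkpoint.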
    fn save(&self, path: &Path) -> anyhow::Result<()> {
        self.online_net
            .clone()
            .save_file(path.to_path_buf(), &CompactRecorder::new())
            .map(|_| ())
            .map_err(|e| anyhow::anyhow!(e))
    }

    fn load(mut self, path: &Path) -> anyhow::Result<Self> {
        self.online_net = self.online_net
            .load_file(path.to_path_buf(), &CompactRecorder::new(), &self.device)
            .map_err(|e| anyhow::anyhow!(e))?;
        self.target_net = self.online_net.valid();
        Ok(self)
    }
}

// ── LearningAgent ─────────────────────────────────────────────────────────────

impl<E, Enc, Act, B, Buf> LearningAgent<E> for DqnAgent<E, Enc, Act, B, Buf>
where
    E: Environment,
    E::Observation: Clone + Send + Sync + 'static,
    E::Action: Clone + Send + Sync + 'static,
    Enc: ObservationEncoder<E::Observation, B> + ObservationEncoder<E::Observation, B::InnerBackend>,
    Act: DiscreteActionMapper<E::Action>,
    B: AutodiffBackend,
    Buf: ReplayBuffer<E::Observation, E::Action>,
{
    fn act(&mut self, obs: &E::Observation, mode: ActMode) -> E::Action {
        match mode {
            ActMode::Exploit => {
                // Greedy argmax — no RNG needed.
                let obs_tensor = self.encoder.encode(obs, &self.device).unsqueeze_dim(0);
                let q_values = self.online_net.forward(obs_tensor);
                let idx: usize = q_values
                    .argmax(1)
                    .into_data()
                    .to_vec::<i64>()
                    .unwrap()[0] as usize;
                self.action_mapper.index_to_action(idx)
            }
            ActMode::Explore => {
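                // ε-greedy: with probability ε (from the config schedule at the
                // current step count) pick a uniformly random action; otherwise
                // fall back to the same greedy argmax as `Exploit`.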
                let epsilon = self.config.epsilon_at(self.total_steps);
                if self.explore_rng.gen::<f64>() < epsilon {
                    let idx = self.explore_rng.gen_range(0..self.action_mapper.num_actions());
                    self.action_mapper.index_to_action(idx)
                } else {
                    let obs_tensor = self.encoder.encode(obs, &self.device).unsqueeze_dim(0);
                    let q_values = self.online_net.forward(obs_tensor);
                    let idx: usize = q_values
                        .argmax(1)
                        .into_data()
                        .to_vec::<i64>()
                        .unwrap()[0] as usize;
                    self.action_mapper.index_to_action(idx)
                }
            }
        }
    }

    fn observe(&mut self, experience: Experience<E::Observation, E::Action>) {
        self.buffer.push(experience);
        self.total_steps += 1;

        if self.total_steps.is_multiple_of(self.config.target_update_freq) {
            self.sync_target();
        }

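        // Learn only once a full batch can be sampled and the warm-up threshold
        // (min_replay_size) has been met.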
        if self.buffer.ready_for(self.config.batch_size)
            && self.buffer.len() >= self.config.min_replay_size
        {
            let loss = self.train_step();
            self.ep_loss_mean.update(loss);
            self.ep_loss_std.update(loss);
            self.ep_loss_max.update(loss);
        }
    }

    fn total_steps(&self) -> usize {
        self.total_steps
    }

    fn episode_extras(&self) -> HashMap<String, f64> {
        [
            ("epsilon".to_string(),    self.epsilon()),
            ("loss_mean".to_string(),  self.ep_loss_mean.value()),
            ("loss_std".to_string(),   self.ep_loss_std.value()),
            ("loss_max".to_string(),   self.ep_loss_max.value()),
        ]
        .into_iter()
        .filter(|(_, v)| v.is_finite())
        .collect()
    }

    fn on_episode_start(&mut self) {
        self.ep_loss_mean.reset();
        self.ep_loss_std.reset();
        self.ep_loss_max.reset();
    }
}

// ── Policy (greedy inference) ─────────────────────────────────────────────────

impl<E, Enc, Act, B, Buf> Policy<E::Observation, E::Action> for DqnAgent<E, Enc, Act, B, Buf>
where
    E: Environment,
    Enc: ObservationEncoder<E::Observation, B>,
    Act: DiscreteActionMapper<E::Action>,
    B: AutodiffBackend,
    Buf: ReplayBuffer<E::Observation, E::Action>,
{
    fn act(&self, obs: &E::Observation) -> E::Action {
        let obs_tensor = self.encoder.encode(obs, &self.device).unsqueeze_dim(0);
        let q_values = self.online_net.forward(obs_tensor);
        let idx: usize = q_values
            .argmax(1)
            .into_data()
            .to_vec::<i64>()
            .unwrap()[0] as usize;
        self.action_mapper.index_to_action(idx)
    }
}