use std::collections::HashMap;
use std::path::Path;

use burn::prelude::*;
use burn::tensor::backend::AutodiffBackend;
use burn::module::AutodiffModule;
use burn::optim::{Adam, AdamConfig, GradientsParams, Optimizer};
use burn::optim::adaptor::OptimizerAdaptor;
use burn::nn::loss::{HuberLossConfig, Reduction};
use burn::record::CompactRecorder;
use rand::{Rng, SeedableRng};
use rand::rngs::SmallRng;
use rl_traits::{Environment, Experience, Policy, ReplayBuffer};

use crate::encoding::{DiscreteActionMapper, ObservationEncoder};
use crate::stats::{Aggregator, Max, Mean, Std};
use crate::traits::{ActMode, Checkpointable, LearningAgent};
use super::config::DqnConfig;
use super::network::QNetwork;
use super::replay::CircularBuffer;

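/// A Deep Q-Network (DQN) agent: an online Q-network is trained against
/// bootstrapped targets from a periodically synchronised target network,
/// with transitions drawn from a replay buffer and actions selected
/// epsilon-greedily during exploration.
///
/// Illustrative construction sketch (the environment, encoder, and mapper
/// types here are placeholders, not part of this crate):
///
/// ```ignore
/// let agent: DqnAgent<MyEnv, MyEncoder, MyActionMapper, Autodiff<NdArray>> =
///     DqnAgent::new(encoder, mapper, config, device, 42);
/// ```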
pub struct DqnAgent<E, Enc, Act, B, Buf = CircularBuffer<
    <E as Environment>::Observation,
    <E as Environment>::Action,
>>
where
    E: Environment,
    B: AutodiffBackend,
    Buf: ReplayBuffer<E::Observation, E::Action>,
{
    /// Network being trained (autodiff backend).
    online_net: QNetwork<B>,
    /// Periodically synchronised copy used for bootstrap targets; kept on
    /// the inner (non-autodiff) backend so target evaluation builds no graph.
    target_net: QNetwork<B::InnerBackend>,

    optimiser: OptimizerAdaptor<Adam, QNetwork<B>, B>,

    buffer: Buf,

    encoder: Enc,
    action_mapper: Act,

    config: DqnConfig,
    device: B::Device,
    total_steps: usize,

    /// Separate RNG streams keep exploration noise and minibatch sampling
    /// independently reproducible.
    explore_rng: SmallRng,
    sample_rng: SmallRng,

    /// Per-episode loss statistics, reset on episode start.
    ep_loss_mean: Mean,
    ep_loss_std: Std,
    ep_loss_max: Max,

    _env: std::marker::PhantomData<E>,
}

impl<E, Enc, Act, B> DqnAgent<E, Enc, Act, B>
where
    E: Environment,
    E::Observation: Clone + Send + Sync + 'static,
    E::Action: Clone + Send + Sync + 'static,
    Enc: ObservationEncoder<E::Observation, B> + ObservationEncoder<E::Observation, B::InnerBackend>,
    Act: DiscreteActionMapper<E::Action>,
    B: AutodiffBackend,
{
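    /// Builds an agent backed by the default `CircularBuffer` replay buffer
    /// with capacity `config.buffer_capacity`.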
    pub fn new(encoder: Enc, action_mapper: Act, config: DqnConfig, device: B::Device, seed: u64) -> Self {
        let buffer = CircularBuffer::new(config.buffer_capacity);
        Self::new_with_buffer(encoder, action_mapper, config, device, seed, buffer)
    }
}

impl<E, Enc, Act, B, Buf> DqnAgent<E, Enc, Act, B, Buf>
where
    E: Environment,
    E::Observation: Clone + Send + Sync + 'static,
    E::Action: Clone + Send + Sync + 'static,
    Enc: ObservationEncoder<E::Observation, B> + ObservationEncoder<E::Observation, B::InnerBackend>,
    Act: DiscreteActionMapper<E::Action>,
    B: AutodiffBackend,
    Buf: ReplayBuffer<E::Observation, E::Action>,
{
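    /// Builds an agent around a caller-supplied replay buffer. The Q-network
    /// is an MLP shaped `[obs_size, hidden_sizes.., num_actions]`, and the
    /// target network starts as a copy of the online network.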
    pub fn new_with_buffer(
        encoder: Enc,
        action_mapper: Act,
        config: DqnConfig,
        device: B::Device,
        seed: u64,
        buffer: Buf,
    ) -> Self {
        let obs_size = <Enc as ObservationEncoder<E::Observation, B>>::obs_size(&encoder);
        let num_actions = action_mapper.num_actions();

        // MLP topology: observation size, hidden layers, one output per action.
        let mut layer_sizes = vec![obs_size];
        layer_sizes.extend_from_slice(&config.hidden_sizes);
        layer_sizes.push(num_actions);

        let online_net = QNetwork::new(&layer_sizes, &device);
        // Start the target network as an exact copy of the online network,
        // not a second independent random initialisation.
        let target_net = online_net.valid();

        let optimiser = AdamConfig::new()
            .with_epsilon(1e-8)
            .init::<B, QNetwork<B>>();

        Self {
            online_net,
            target_net,
            optimiser,
            buffer,
            encoder,
            action_mapper,
            device,
            total_steps: 0,
            config,
            explore_rng: SmallRng::seed_from_u64(seed),
            // Offset by a large odd constant so the two streams differ.
            sample_rng: SmallRng::seed_from_u64(seed.wrapping_add(0x9e37_79b9_7f4a_7c15)),
            ep_loss_mean: Mean::default(),
            ep_loss_std: Std::default(),
            ep_loss_max: Max::default(),
            _env: std::marker::PhantomData,
        }
    }

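    /// Current exploration rate under the config's epsilon schedule.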
    pub fn epsilon(&self) -> f64 {
        self.config.epsilon_at(self.total_steps)
    }

    pub fn set_total_steps(&mut self, steps: usize) {
        self.total_steps = steps;
    }

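    /// Consumes the agent, freezing the online network into an
    /// inference-only policy on the inner (non-autodiff) backend.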
    pub fn into_policy(self) -> super::inference::DqnPolicy<E, Enc, Act, B::InnerBackend> {
        super::inference::DqnPolicy::from_network(
            self.online_net.valid(),
            self.encoder,
            self.action_mapper,
            self.device,
        )
    }

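    /// Hard target update: copy the online weights into the target network.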
    fn sync_target(&mut self) {
        self.target_net = self.online_net.valid();
    }

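    /// One DQN update: sample a minibatch, form bootstrap targets
    /// `y = r + gamma * mask * max_a' Q_target(s', a')`, regress the online
    /// network's Q(s, a) for the taken actions towards `y` under a Huber
    /// loss, and take one Adam step. Returns the scalar loss for logging.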
    fn train_step(&mut self) -> f64 {
        let batch = self.buffer.sample(self.config.batch_size, &mut self.sample_rng);
        let batch_size = batch.len();

        let obs_batch: Vec<_> = batch.iter().map(|e| e.observation.clone()).collect();
        let next_obs_batch: Vec<_> = batch.iter().map(|e| e.next_observation.clone()).collect();

        let obs_tensor = self.encoder.encode_batch(&obs_batch, &self.device);
        let next_obs_tensor = self.encoder.encode_batch(&next_obs_batch, &self.device);

        let action_indices: Vec<usize> = batch.iter()
            .map(|e| self.action_mapper.action_to_index(&e.action))
            .collect();
        let rewards: Vec<f32> = batch.iter().map(|e| e.reward as f32).collect();
        // The bootstrap mask zeroes the next-state term for terminal transitions.
        let masks: Vec<f32> = batch.iter().map(|e| e.bootstrap_mask() as f32).collect();

        let rewards_t: Tensor<B, 1> = Tensor::from_floats(rewards.as_slice(), &self.device);
        let masks_t: Tensor<B, 1> = Tensor::from_floats(masks.as_slice(), &self.device);

        // Evaluate the target network on the inner backend (no graph), then
        // lift the result into the autodiff backend as a constant.
        let next_q_values = self.target_net.forward(next_obs_tensor);
        let max_next_q: Tensor<B::InnerBackend, 1> = next_q_values.max_dim(1).squeeze(1);
        let max_next_q_autodiff: Tensor<B, 1> = Tensor::from_inner(max_next_q);

        let targets = rewards_t + masks_t * max_next_q_autodiff * self.config.gamma as f32;

        // Q-values of the actions actually taken.
        let q_values = self.online_net.forward(obs_tensor);
        let action_indices_t = Tensor::<B, 1, Int>::from_ints(
            action_indices.iter().map(|&i| i as i32).collect::<Vec<_>>().as_slice(),
            &self.device,
        );
        let q_taken = q_values
            .gather(1, action_indices_t.reshape([batch_size, 1]))
            .squeeze(1);

        let loss = HuberLossConfig::new(1.0)
            .init()
            .forward(q_taken, targets.detach(), Reduction::Mean);

        let loss_val = loss.clone().into_scalar().elem::<f64>();

        let grads = loss.backward();
        let grads = GradientsParams::from_grads(grads, &self.online_net);
        self.online_net = self.optimiser.step(
            self.config.learning_rate,
            self.online_net.clone(),
            grads,
        );

        loss_val
    }
}
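
impl<E, Enc, Act, B, Buf> DqnAgent<E, Enc, Act, B, Buf>
where
    E: Environment,
    Enc: ObservationEncoder<E::Observation, B::InnerBackend>,
    Act: DiscreteActionMapper<E::Action>,
    B: AutodiffBackend,
    Buf: ReplayBuffer<E::Observation, E::Action>,
{
    /// Greedy action under the online network, shared by the exploit path,
    /// the non-random exploration path, and the `Policy` impl. Runs on the
    /// inner backend so action selection never builds an autodiff graph.
    fn greedy_action(&self, obs: &E::Observation) -> E::Action {
        let obs_tensor = self.encoder.encode(obs, &self.device).unsqueeze_dim(0);
        let q_values = self.online_net.valid().forward(obs_tensor);
        let idx = q_values
            .argmax(1)
            .into_data()
            .to_vec::<i64>()
            .unwrap()[0] as usize;
        self.action_mapper.index_to_action(idx)
    }
}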

impl<E, Enc, Act, B, Buf> Checkpointable for DqnAgent<E, Enc, Act, B, Buf>
where
    E: Environment,
    E::Observation: Clone + Send + Sync + 'static,
    E::Action: Clone + Send + Sync + 'static,
    Enc: ObservationEncoder<E::Observation, B> + ObservationEncoder<E::Observation, B::InnerBackend>,
    Act: DiscreteActionMapper<E::Action>,
    B: AutodiffBackend,
    Buf: ReplayBuffer<E::Observation, E::Action>,
{
    fn save(&self, path: &Path) -> anyhow::Result<()> {
        // Only the online network is persisted; the target network is
        // reconstructed from it on load.
        self.online_net
            .clone()
            .save_file(path.to_path_buf(), &CompactRecorder::new())
            .map_err(|e| anyhow::anyhow!(e))
    }

    fn load(mut self, path: &Path) -> anyhow::Result<Self> {
        self.online_net = self.online_net
            .load_file(path.to_path_buf(), &CompactRecorder::new(), &self.device)
            .map_err(|e| anyhow::anyhow!(e))?;
        self.target_net = self.online_net.valid();
        Ok(self)
    }
}

impl<E, Enc, Act, B, Buf> LearningAgent<E> for DqnAgent<E, Enc, Act, B, Buf>
where
    E: Environment,
    E::Observation: Clone + Send + Sync + 'static,
    E::Action: Clone + Send + Sync + 'static,
    Enc: ObservationEncoder<E::Observation, B> + ObservationEncoder<E::Observation, B::InnerBackend>,
    Act: DiscreteActionMapper<E::Action>,
    B: AutodiffBackend,
    Buf: ReplayBuffer<E::Observation, E::Action>,
{
    fn act(&mut self, obs: &E::Observation, mode: ActMode) -> E::Action {
        // Epsilon-greedy: in `Explore` mode take a uniformly random action
        // with probability epsilon; otherwise act greedily.
        if matches!(mode, ActMode::Explore) {
            let epsilon = self.config.epsilon_at(self.total_steps);
            if self.explore_rng.gen::<f64>() < epsilon {
                let idx = self.explore_rng.gen_range(0..self.action_mapper.num_actions());
                return self.action_mapper.index_to_action(idx);
            }
        }
        self.greedy_action(obs)
    }

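    /// Records a transition, then (a) hard-syncs the target network every
    /// `target_update_freq` steps and (b) runs one training step once the
    /// buffer can serve a batch and has reached `min_replay_size`.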
    fn observe(&mut self, experience: Experience<E::Observation, E::Action>) {
        self.buffer.push(experience);
        self.total_steps += 1;

        if self.total_steps.is_multiple_of(self.config.target_update_freq) {
            self.sync_target();
        }

        if self.buffer.ready_for(self.config.batch_size)
            && self.buffer.len() >= self.config.min_replay_size
        {
            let loss = self.train_step();
            self.ep_loss_mean.update(loss);
            self.ep_loss_std.update(loss);
            self.ep_loss_max.update(loss);
        }
    }

    fn total_steps(&self) -> usize {
        self.total_steps
    }

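    /// Loss statistics plus the current epsilon, for episode logging;
    /// non-finite values (e.g. statistics over an empty episode) are dropped.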
    fn episode_extras(&self) -> HashMap<String, f64> {
        [
            ("epsilon".to_string(), self.epsilon()),
            ("loss_mean".to_string(), self.ep_loss_mean.value()),
            ("loss_std".to_string(), self.ep_loss_std.value()),
            ("loss_max".to_string(), self.ep_loss_max.value()),
        ]
        .into_iter()
        .filter(|(_, v)| v.is_finite())
        .collect()
    }

    fn on_episode_start(&mut self) {
        self.ep_loss_mean.reset();
        self.ep_loss_std.reset();
        self.ep_loss_max.reset();
    }
}

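// Inference-only view: greedy action selection with no exploration, needing
// only the inner-backend encoder bound.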
impl<E, Enc, Act, B, Buf> Policy<E::Observation, E::Action> for DqnAgent<E, Enc, Act, B, Buf>
where
    E: Environment,
    Enc: ObservationEncoder<E::Observation, B::InnerBackend>,
    Act: DiscreteActionMapper<E::Action>,
    B: AutodiffBackend,
    Buf: ReplayBuffer<E::Observation, E::Action>,
{
    fn act(&self, obs: &E::Observation) -> E::Action {
        self.greedy_action(obs)
    }
}