use std::path::Path;

use burn::prelude::*;
use burn::tensor::backend::AutodiffBackend;
use burn::module::AutodiffModule;
use burn::optim::{Adam, AdamConfig, GradientsParams, Optimizer};
use burn::optim::adaptor::OptimizerAdaptor;
use burn::nn::loss::{HuberLossConfig, Reduction};
use burn::record::{CompactRecorder, RecorderError};
use rand::{Rng, SeedableRng};
use rand::rngs::SmallRng;
use rl_traits::{Environment, Experience, Policy, ReplayBuffer};

use crate::encoding::{DiscreteActionMapper, ObservationEncoder};
use super::config::DqnConfig;
use super::network::QNetwork;
use super::replay::CircularBuffer;

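/// Deep Q-Network agent: an online Q-network trained with Adam against a
/// periodically synced target network, with epsilon-greedy exploration and a
/// pluggable replay buffer (a `CircularBuffer` by default).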
pub struct DqnAgent<E, Enc, Act, B, Buf = CircularBuffer<
    <E as Environment>::Observation,
    <E as Environment>::Action,
>>
where
    E: Environment,
    B: AutodiffBackend,
    Buf: ReplayBuffer<E::Observation, E::Action>,
{
    /// Q-network updated on every training step.
    online_net: QNetwork<B>,
    /// Frozen copy of the online network used to compute bootstrap targets.
    target_net: QNetwork<B::InnerBackend>,

    optimiser: OptimizerAdaptor<Adam, QNetwork<B>, B>,

    buffer: Buf,

    /// Turns observations into input tensors.
    encoder: Enc,
    /// Maps between environment actions and discrete action indices.
    action_mapper: Act,

    config: DqnConfig,
    device: B::Device,
    total_steps: usize,

    rng: SmallRng,

    _env: std::marker::PhantomData<E>,
}

impl<E, Enc, Act, B> DqnAgent<E, Enc, Act, B>
where
    E: Environment,
    E::Observation: Clone + Send + Sync + 'static,
    E::Action: Clone + Send + Sync + 'static,
    Enc: ObservationEncoder<E::Observation, B> + ObservationEncoder<E::Observation, B::InnerBackend>,
    Act: DiscreteActionMapper<E::Action>,
    B: AutodiffBackend,
{
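    /// Create an agent backed by the default `CircularBuffer`, sized from
    /// `config.buffer_capacity`.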
    pub fn new(encoder: Enc, action_mapper: Act, config: DqnConfig, device: B::Device, seed: u64) -> Self {
        let buffer = CircularBuffer::new(config.buffer_capacity);
        Self::new_with_buffer(encoder, action_mapper, config, device, seed, buffer)
    }
}

impl<E, Enc, Act, B, Buf> DqnAgent<E, Enc, Act, B, Buf>
where
    E: Environment,
    E::Observation: Clone + Send + Sync + 'static,
    E::Action: Clone + Send + Sync + 'static,
    Enc: ObservationEncoder<E::Observation, B> + ObservationEncoder<E::Observation, B::InnerBackend>,
    Act: DiscreteActionMapper<E::Action>,
    B: AutodiffBackend,
    Buf: ReplayBuffer<E::Observation, E::Action>,
{
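    /// Create an agent with a caller-provided replay buffer implementation.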
    pub fn new_with_buffer(
        encoder: Enc,
        action_mapper: Act,
        config: DqnConfig,
        device: B::Device,
        seed: u64,
        buffer: Buf,
    ) -> Self {
        let obs_size = <Enc as ObservationEncoder<E::Observation, B>>::obs_size(&encoder);
        let num_actions = action_mapper.num_actions();

        // Network layout: observation size -> hidden layers -> one output per action.
        let mut layer_sizes = vec![obs_size];
        layer_sizes.extend_from_slice(&config.hidden_sizes);
        layer_sizes.push(num_actions);

        let online_net = QNetwork::new(&layer_sizes, &device);
        // Initialise the target network as an exact copy of the online network
        // rather than a second, independently random network.
        let target_net = online_net.valid();

        let optimiser = AdamConfig::new()
            .with_epsilon(1e-8)
            .init::<B, QNetwork<B>>();

        Self {
            online_net,
            target_net,
            optimiser,
            buffer,
            encoder,
            action_mapper,
            device,
            total_steps: 0,
            config,
            rng: SmallRng::seed_from_u64(seed),
            _env: std::marker::PhantomData,
        }
    }

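    /// Record a transition, keep the target network in sync, and run one
    /// training step once the replay buffer holds at least `batch_size` and
    /// `min_replay_size` experiences. Returns `true` if a gradient step was taken.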
    pub fn observe(&mut self, experience: Experience<E::Observation, E::Action>) -> bool {
        self.buffer.push(experience);
        self.total_steps += 1;

        // Periodically copy the online weights into the target network.
        if self.total_steps.is_multiple_of(self.config.target_update_freq) {
            self.sync_target();
        }

        // Only train once the buffer can serve a full batch and has warmed up.
        if !self.buffer.ready_for(self.config.batch_size) {
            return false;
        }
        if self.buffer.len() < self.config.min_replay_size {
            return false;
        }

        self.train_step();
        true
    }

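    /// Current exploration rate for the agent's step count, as given by
    /// `DqnConfig::epsilon_at`.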
    pub fn epsilon(&self) -> f64 {
        self.config.epsilon_at(self.total_steps)
    }

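    /// Number of transitions observed so far.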
    pub fn total_steps(&self) -> usize {
        self.total_steps
    }

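    /// Select an action epsilon-greedily: with probability `epsilon()` a uniformly
    /// random action, otherwise the greedy action from the online network.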
    pub fn act_epsilon_greedy(&self, obs: &E::Observation, rng: &mut impl Rng) -> E::Action {
        let epsilon = self.epsilon();
        if rng.gen::<f64>() < epsilon {
            // Explore: uniformly random action index.
            let idx = rng.gen_range(0..self.action_mapper.num_actions());
            self.action_mapper.index_to_action(idx)
        } else {
            // Exploit: greedy action from the online network.
            self.act(obs)
        }
    }

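    /// Save the online network's weights to `path` using Burn's `CompactRecorder`.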
    pub fn save(&self, path: impl AsRef<Path>) -> Result<(), RecorderError> {
        self.online_net
            .clone()
            .save_file(path.as_ref().to_path_buf(), &CompactRecorder::new())
            .map(|_| ())
    }

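    /// Load weights saved by [`Self::save`] into the online network and re-sync
    /// the target network.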
    pub fn load(mut self, path: impl AsRef<Path>) -> Result<Self, RecorderError> {
        self.online_net = self
            .online_net
            .load_file(path.as_ref().to_path_buf(), &CompactRecorder::new(), &self.device)?;
        self.target_net = self.online_net.valid();
        Ok(self)
    }

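    /// Consume the agent and return an inference-only policy running the trained
    /// network on the non-autodiff backend.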
    pub fn into_policy(self) -> super::inference::DqnPolicy<E, Enc, Act, B::InnerBackend> {
        super::inference::DqnPolicy::from_network(
            self.online_net.valid(),
            self.encoder,
            self.action_mapper,
            self.device,
        )
    }

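    /// Override the internal step counter, e.g. when resuming training, so the
    /// epsilon schedule and target-sync cadence continue from a known point.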
    pub fn set_total_steps(&mut self, steps: usize) {
        self.total_steps = steps;
    }

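    /// Copy the online network's current weights into the target network.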
    fn sync_target(&mut self) {
        self.target_net = self.online_net.valid();
    }

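    /// Sample a batch from the replay buffer and take one gradient step on the
    /// Huber loss between predicted Q-values and bootstrapped TD targets.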
    fn train_step(&mut self) {
        let batch = self.buffer.sample(self.config.batch_size, &mut self.rng);

        let batch_size = batch.len();

        let obs_batch: Vec<_> = batch.iter().map(|e| e.observation.clone()).collect();
        let next_obs_batch: Vec<_> = batch.iter().map(|e| e.next_observation.clone()).collect();

        let obs_tensor = self.encoder.encode_batch(&obs_batch, &self.device);
        let next_obs_tensor = self.encoder.encode_batch(&next_obs_batch, &self.device);

        let action_indices: Vec<usize> = batch.iter()
            .map(|e| self.action_mapper.action_to_index(&e.action))
            .collect();

        let rewards: Vec<f32> = batch.iter()
            .map(|e| e.reward as f32)
            .collect();

        // The bootstrap mask zeroes the next-state value on terminal transitions.
        let masks: Vec<f32> = batch.iter()
            .map(|e| e.bootstrap_mask() as f32)
            .collect();

        let rewards_t: Tensor<B, 1> = Tensor::from_floats(rewards.as_slice(), &self.device);
        let masks_t: Tensor<B, 1> = Tensor::from_floats(masks.as_slice(), &self.device);

        // TD targets from the frozen target network: r + gamma * max_a' Q_target(s', a').
        let next_q_values = self.target_net.forward(next_obs_tensor);
        let max_next_q: Tensor<B::InnerBackend, 1> = next_q_values.max_dim(1).squeeze::<1>();
        let max_next_q_autodiff: Tensor<B, 1> = Tensor::from_inner(max_next_q);

        let targets = rewards_t + masks_t * max_next_q_autodiff * self.config.gamma as f32;

        // Q-values predicted by the online network for the actions actually taken.
        let q_values = self.online_net.forward(obs_tensor);
        let action_indices_t = Tensor::<B, 1, Int>::from_ints(
            action_indices.iter().map(|&i| i as i32).collect::<Vec<_>>().as_slice(),
            &self.device,
        );
        let q_taken = q_values
            .gather(1, action_indices_t.reshape([batch_size, 1]))
            .squeeze::<1>();

        // Huber loss on the TD error; targets are detached so gradients flow only
        // through the online network.
        let loss = HuberLossConfig::new(1.0)
            .init()
            .forward(q_taken, targets.detach(), Reduction::Mean);

        let grads = loss.backward();
        let grads = GradientsParams::from_grads(grads, &self.online_net);
        self.online_net = self.optimiser.step(
            self.config.learning_rate,
            self.online_net.clone(),
            grads,
        );
    }
}

impl<E, Enc, Act, B, Buf> Policy<E::Observation, E::Action> for DqnAgent<E, Enc, Act, B, Buf>
where
    E: Environment,
    Enc: ObservationEncoder<E::Observation, B>,
    Act: DiscreteActionMapper<E::Action>,
    B: AutodiffBackend,
    Buf: ReplayBuffer<E::Observation, E::Action>,
{
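    // Greedy action: the index with the highest Q-value from the online network.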
    fn act(&self, obs: &E::Observation) -> E::Action {
        let obs_tensor = self.encoder.encode(obs, &self.device).unsqueeze_dim(0);
        let q_values = self.online_net.forward(obs_tensor);
        let best_action_idx: usize = q_values
            .argmax(1)
            .into_data()
            .to_vec::<i64>()
            .unwrap()[0] as usize;
        self.action_mapper.index_to_action(best_action_idx)
    }
}