1use super::{config::DqnConfig, explorer::DqnExplorer, model::DqnModel};
3use crate::{
4 model::{ModelBase, SubModel},
5 util::{track, CriticLoss, OutDim},
6};
7use anyhow::Result;
8use border_core::{
9 record::{Record, RecordValue},
10 Agent, Configurable, Env, Policy, ReplayBufferBase, TransitionBatch,
11};
12use serde::{de::DeserializeOwned, Serialize};
13use std::{
14 convert::{TryFrom, TryInto},
15 fs,
16 marker::PhantomData,
17 path::{Path, PathBuf},
18};
19use tch::{no_grad, Device, Tensor};
20
/// Deep Q-Network (DQN) agent backed by tch-rs.
///
/// Type parameters: `E` is the environment, `Q` the Q-network submodel
/// (its forward pass yields a tensor of per-action values), and `R` the
/// replay buffer the agent is trained from.
#[allow(clippy::upper_case_acronyms)]
pub struct Dqn<E, Q, R>
where
    Q: SubModel<Output = Tensor>,
    Q::Config: DeserializeOwned + Serialize + OutDim + std::fmt::Debug + PartialEq + Clone,
{
    /// Number of optimization steps between soft updates of the target network.
    pub(in crate::dqn) soft_update_interval: usize,
    /// Optimization steps elapsed since the last target-network update.
    pub(in crate::dqn) soft_update_counter: usize,
    /// Number of critic updates performed per optimization call.
    pub(in crate::dqn) n_updates_per_opt: usize,
    /// Minibatch size sampled from the replay buffer.
    pub(in crate::dqn) batch_size: usize,
    /// Online Q-network.
    pub(in crate::dqn) qnet: DqnModel<Q>,
    /// Target Q-network, tracked towards `qnet` with rate `tau`.
    pub(in crate::dqn) qnet_tgt: DqnModel<Q>,
    /// Training-mode flag; controls exploration in `Policy::sample`.
    pub(in crate::dqn) train: bool,
    /// Marks `E` and `R` as logically-owned type parameters.
    pub(in crate::dqn) phantom: PhantomData<(E, R)>,
    /// Discount factor (gamma) applied when bootstrapping TD targets.
    pub(in crate::dqn) discount_factor: f64,
    /// Soft-update rate for tracking the target network.
    pub(in crate::dqn) tau: f64,
    /// Exploration strategy (softmax or epsilon-greedy).
    pub(in crate::dqn) explorer: DqnExplorer,
    /// Device on which batch tensors are placed.
    pub(in crate::dqn) device: Device,
    /// Total number of optimization calls performed so far.
    pub(in crate::dqn) n_opts: usize,
    /// Whether TD targets use double-DQN action selection.
    pub(in crate::dqn) double_dqn: bool,
    /// Reward clipping bound; not read anywhere in the visible code
    /// (kept, underscore-prefixed, for configuration compatibility).
    pub(in crate::dqn) _clip_reward: Option<f64>,
    /// Optional (min, max) clipping applied to absolute TD errors.
    pub(in crate::dqn) clip_td_err: Option<(f64, f64)>,
    /// Loss function used for the critic (smooth L1 or MSE).
    pub(in crate::dqn) critic_loss: CriticLoss,
    // Actions sampled since the last verbose report.
    n_samples_act: usize,
    // Greedy ("best") actions among those sampled since the last report.
    n_samples_best_act: usize,
    // Verbosity level; >= 2 enables extra diagnostic record entries.
    record_verbose_level: usize,
}
49
impl<E, Q, R> Dqn<E, Q, R>
where
    E: Env,
    Q: SubModel<Output = Tensor>,
    R: ReplayBufferBase,
    Q::Config: DeserializeOwned + Serialize + OutDim + std::fmt::Debug + PartialEq + Clone,
    R::Batch: TransitionBatch,
    <R::Batch as TransitionBatch>::ObsBatch: Into<Q::Input>,
    <R::Batch as TransitionBatch>::ActBatch: Into<Tensor>,
{
    /// Performs one critic update on a minibatch sampled from `buffer`.
    ///
    /// Computes the TD target (optionally with double-DQN action selection),
    /// takes one optimizer step on the online Q-network, and returns a
    /// record of diagnostics. When the batch carries importance-sampling
    /// weights (prioritized replay), the TD errors are weighted in the loss
    /// and the raw TD errors are reported back to the buffer as priorities.
    fn update_critic(&mut self, buffer: &mut R) -> Record {
        let mut record = Record::empty();
        // NOTE(review): `unwrap` assumes the buffer can always yield a batch
        // of `batch_size`; confirm callers only optimize after warm-up.
        let batch = buffer.batch(self.batch_size).unwrap();
        let (obs, act, next_obs, reward, is_terminated, _is_truncated, ixs, weight) =
            batch.unpack();
        let obs = obs.into();
        let act = act.into().to(self.device);
        let next_obs = next_obs.into();
        let reward = Tensor::from_slice(&reward[..]).to(self.device);
        let is_terminated = Tensor::from_slice(&is_terminated[..]).to(self.device);

        // Q(s, a) for the actions actually taken in the batch.
        let pred = {
            let x = self.qnet.forward(&obs);
            x.gather(-1, &act, false).squeeze()
        };

        if self.record_verbose_level >= 2 {
            record.insert(
                "pred_mean",
                RecordValue::Scalar(
                    f32::try_from(pred.mean(tch::Kind::Float))
                        .expect("Failed to convert Tensor to f32"),
                ),
            );
        }

        if self.record_verbose_level >= 2 {
            let reward_mean: f32 = reward.mean(tch::Kind::Float).try_into().unwrap();
            record.insert("reward_mean", RecordValue::Scalar(reward_mean));
        }

        // TD target, computed without gradient tracking.
        let tgt: Tensor = no_grad(|| {
            let q = if self.double_dqn {
                // Double DQN: select the next action with the online network,
                // evaluate it with the target network.
                let x = self.qnet.forward(&next_obs);
                let y = x.argmax(-1, false).unsqueeze(-1);
                self.qnet_tgt
                    .forward(&next_obs)
                    .gather(-1, &y, false)
                    .squeeze()
            } else {
                // Vanilla DQN: max over actions of the target network
                // (argmax + gather is equivalent to a max along -1).
                let x = self.qnet_tgt.forward(&next_obs);
                let y = x.argmax(-1, false).unsqueeze(-1);
                x.gather(-1, &y, false).squeeze()
            };
            // Bootstrap only for non-terminal transitions.
            reward + (1 - is_terminated) * self.discount_factor * q
        });

        if self.record_verbose_level >= 2 {
            record.insert(
                "tgt_mean",
                RecordValue::Scalar(
                    f32::try_from(tgt.mean(tch::Kind::Float))
                        .expect("Failed to convert Tensor to f32"),
                ),
            );
            let tgt_minus_pred_mean: f32 =
                (&tgt - &pred).mean(tch::Kind::Float).try_into().unwrap();
            record.insert(
                "tgt_minus_pred_mean",
                RecordValue::Scalar(tgt_minus_pred_mean),
            );
        }

        // Prioritized-replay branch: weight the (optionally clipped) absolute
        // TD errors by the importance-sampling weights.
        let loss = if let Some(ws) = weight {
            let n = ws.len() as i64;
            let td_errs = match self.clip_td_err {
                None => (&pred - &tgt).abs(),
                Some((min, max)) => (&pred - &tgt).abs().clip(min, max),
            };
            let loss = Tensor::from_slice(&ws[..]).to(self.device) * &td_errs;
            // NOTE(review): the loss is loss(w * |td_err|, 0), not
            // w * loss(td_err); confirm this matches the intended
            // prioritized-replay objective.
            let loss = match self.critic_loss {
                CriticLoss::SmoothL1 => loss.smooth_l1_loss(
                    &Tensor::zeros(&[n], tch::kind::FLOAT_CPU).to(self.device),
                    tch::Reduction::Mean,
                    1.0,
                ),
                CriticLoss::Mse => loss.mse_loss(
                    &Tensor::zeros(&[n], tch::kind::FLOAT_CPU).to(self.device),
                    tch::Reduction::Mean,
                ),
            };
            self.qnet.backward_step(&loss);
            // Unweighted TD errors become the transitions' new priorities.
            let td_errs = Vec::<f32>::try_from(td_errs).expect("Failed to convert Tensor to f32");
            buffer.update_priority(&ixs, &Some(td_errs));
            loss
        } else {
            // Uniform-sampling branch: plain regression of pred onto tgt.
            let loss = match self.critic_loss {
                CriticLoss::SmoothL1 => pred.smooth_l1_loss(&tgt, tch::Reduction::Mean, 1.0),
                CriticLoss::Mse => pred.mse_loss(&tgt, tch::Reduction::Mean),
            };
            self.qnet.backward_step(&loss);
            loss
        };

        record.insert(
            "loss",
            RecordValue::Scalar(f32::try_from(loss).expect("Failed to convert Tensor to f32")),
        );

        record
    }

    /// Runs `n_updates_per_opt` critic updates, then — every
    /// `soft_update_interval` calls — soft-updates the target network
    /// towards the online network with rate `tau`.
    fn opt_(&mut self, buffer: &mut R) -> Record {
        let mut record_ = Record::empty();

        for _ in 0..self.n_updates_per_opt {
            let record = self.update_critic(buffer);
            record_ = record_.merge(record);
        }

        self.soft_update_counter += 1;
        if self.soft_update_counter == self.soft_update_interval {
            self.soft_update_counter = 0;
            track(&mut self.qnet_tgt, &mut self.qnet, self.tau);
        }

        self.n_opts += 1;

        record_
    }
}
202
203impl<E, Q, R> Policy<E> for Dqn<E, Q, R>
204where
205 E: Env,
206 Q: SubModel<Output = Tensor>,
207 E::Obs: Into<Q::Input>,
208 E::Act: From<Q::Output>,
209 Q::Config: DeserializeOwned + Serialize + OutDim + std::fmt::Debug + PartialEq + Clone,
210{
211 fn sample(&mut self, obs: &E::Obs) -> E::Act {
212 no_grad(|| {
213 let a = self.qnet.forward(&obs.clone().into());
214 let a = if self.train {
215 self.n_samples_act += 1;
216 match &mut self.explorer {
217 DqnExplorer::Softmax(softmax) => softmax.action(&a),
218 DqnExplorer::EpsilonGreedy(egreedy) => {
219 if self.record_verbose_level >= 2 {
220 let (act, best) = egreedy.action_with_best(&a);
221 if best {
222 self.n_samples_best_act += 1;
223 }
224 act
225 } else {
226 egreedy.action(&a)
227 }
228 }
229 }
230 } else {
231 if fastrand::f32() < 0.01 {
232 let n_actions = a.size()[1] as i64;
233 let a = fastrand::i64(0..n_actions);
234 Tensor::from(a)
235 } else {
236 a.argmax(-1, true)
237 }
238 };
239 a.into()
240 })
241 }
242}
243
244impl<E, Q, R> Configurable for Dqn<E, Q, R>
245where
246 E: Env,
247 Q: SubModel<Output = Tensor>,
248 E::Obs: Into<Q::Input>,
249 E::Act: From<Q::Output>,
250 Q::Config: DeserializeOwned + Serialize + OutDim + std::fmt::Debug + PartialEq + Clone,
251{
252 type Config = DqnConfig<Q>;
253
254 fn build(config: Self::Config) -> Self {
256 let device = config
257 .device
258 .expect("No device is given for DQN agent")
259 .into();
260 let qnet = DqnModel::build(config.model_config, device);
261 let qnet_tgt = qnet.clone();
262
263 Dqn {
264 qnet,
265 qnet_tgt,
266 soft_update_interval: config.soft_update_interval,
267 soft_update_counter: 0,
268 n_updates_per_opt: config.n_updates_per_opt,
269 batch_size: config.batch_size,
270 discount_factor: config.discount_factor,
271 tau: config.tau,
272 train: config.train,
273 explorer: config.explorer,
274 device,
275 n_opts: 0,
276 _clip_reward: config.clip_reward,
277 double_dqn: config.double_dqn,
278 clip_td_err: config.clip_td_err,
279 critic_loss: config.critic_loss,
280 n_samples_act: 0,
281 n_samples_best_act: 0,
282 record_verbose_level: config.record_verbose_level,
283 phantom: PhantomData,
284 }
285 }
286}
287
288impl<E, Q, R> Agent<E, R> for Dqn<E, Q, R>
289where
290 E: Env + 'static,
291 Q: SubModel<Output = Tensor> + 'static,
292 R: ReplayBufferBase + 'static,
293 E::Obs: Into<Q::Input>,
294 E::Act: From<Q::Output>,
295 Q::Config: DeserializeOwned + Serialize + OutDim + std::fmt::Debug + PartialEq + Clone,
296 R::Batch: TransitionBatch,
297 <R::Batch as TransitionBatch>::ObsBatch: Into<Q::Input>,
298 <R::Batch as TransitionBatch>::ActBatch: Into<Tensor>,
299{
300 fn train(&mut self) {
301 self.train = true;
302 }
303
304 fn eval(&mut self) {
305 self.train = false;
306 }
307
308 fn is_train(&self) -> bool {
309 self.train
310 }
311
312 fn opt(&mut self, buffer: &mut R) {
313 self.opt_(buffer);
314 }
315
316 fn opt_with_record(&mut self, buffer: &mut R) -> Record {
317 let mut record = {
318 let record = self.opt_(buffer);
319
320 match self.record_verbose_level >= 2 {
321 true => {
322 let record_weights = self.qnet.param_stats();
323 let record = record.merge(record_weights);
324 record
325 }
326 false => record,
327 }
328 };
329
330 if self.record_verbose_level >= 2 {
332 let ratio = match self.n_samples_act == 0 {
333 true => 0f32,
334 false => self.n_samples_best_act as f32 / self.n_samples_act as f32,
335 };
336 record.insert("ratio_best_act", RecordValue::Scalar(ratio));
337 self.n_samples_act = 0;
338 self.n_samples_best_act = 0;
339 }
340
341 record
342 }
343
344 fn save_params(&self, path: &Path) -> Result<Vec<PathBuf>> {
349 fs::create_dir_all(&path)?;
351 let path1 = path.join("qnet.pt.tch").to_path_buf();
352 let path2 = path.join("qnet_tgt.pt.tch").to_path_buf();
353 self.qnet.save(&path1)?;
354 self.qnet_tgt.save(&path2)?;
355 Ok(vec![path1, path2])
356 }
357
358 fn load_params(&mut self, path: &Path) -> Result<()> {
359 self.qnet.load(path.join("qnet.pt.tch").as_path())?;
360 self.qnet_tgt.load(path.join("qnet_tgt.pt.tch").as_path())?;
361 Ok(())
362 }
363
364 fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
365 self
366 }
367
368 fn as_any_ref(&self) -> &dyn std::any::Any {
369 self
370 }
371}
372
373#[cfg(feature = "border-async-trainer")]
374use {crate::util::NamedTensors, border_async_trainer::SyncModel};
375
#[cfg(feature = "border-async-trainer")]
impl<E, Q, R> SyncModel for Dqn<E, Q, R>
where
    E: Env,
    Q: SubModel<Output = Tensor>,
    R: ReplayBufferBase,
    E::Obs: Into<Q::Input>,
    E::Act: From<Q::Output>,
    Q::Config: DeserializeOwned + Serialize + OutDim + std::fmt::Debug + PartialEq + Clone,
    R::Batch: TransitionBatch,
    <R::Batch as TransitionBatch>::ObsBatch: Into<Q::Input>,
    <R::Batch as TransitionBatch>::ActBatch: Into<Tensor>,
{
    type ModelInfo = NamedTensors;

    /// Returns the optimization-step counter together with a copy of the
    /// online Q-network's parameters.
    fn model_info(&self) -> (usize, Self::ModelInfo) {
        let params = NamedTensors::copy_from(self.qnet.get_var_store());
        (self.n_opts, params)
    }

    /// Overwrites the online Q-network's parameters with those carried in
    /// `model_info`.
    fn sync_model(&mut self, model_info: &Self::ModelInfo) {
        model_info.copy_to(self.qnet.get_var_store_mut());
    }
}