border_tch_agent/dqn/
base.rs

1//! DQN agent implemented with tch-rs.
2use super::{config::DqnConfig, explorer::DqnExplorer, model::DqnModel};
3use crate::{
4    model::{ModelBase, SubModel},
5    util::{track, CriticLoss, OutDim},
6};
7use anyhow::Result;
8use border_core::{
9    record::{Record, RecordValue},
10    Agent, Configurable, Env, Policy, ReplayBufferBase, TransitionBatch,
11};
12use serde::{de::DeserializeOwned, Serialize};
13use std::{
14    convert::{TryFrom, TryInto},
15    fs,
16    marker::PhantomData,
17    path::{Path, PathBuf},
18};
19use tch::{no_grad, Device, Tensor};
20
#[allow(clippy::upper_case_acronyms)]
/// DQN agent implemented with tch-rs.
pub struct Dqn<E, Q, R>
where
    Q: SubModel<Output = Tensor>,
    Q::Config: DeserializeOwned + Serialize + OutDim + std::fmt::Debug + PartialEq + Clone,
{
    /// Number of `opt_()` calls between soft updates of the target network.
    pub(in crate::dqn) soft_update_interval: usize,
    /// Counts `opt_()` calls since the last soft update.
    pub(in crate::dqn) soft_update_counter: usize,
    /// Number of critic updates performed per `opt_()` call.
    pub(in crate::dqn) n_updates_per_opt: usize,
    /// Minibatch size sampled from the replay buffer.
    pub(in crate::dqn) batch_size: usize,
    /// Online (trained) Q-network.
    pub(in crate::dqn) qnet: DqnModel<Q>,
    /// Target Q-network, tracked towards `qnet` via soft updates.
    pub(in crate::dqn) qnet_tgt: DqnModel<Q>,
    /// Training mode flag; in eval mode, `sample` is mostly greedy.
    pub(in crate::dqn) train: bool,
    pub(in crate::dqn) phantom: PhantomData<(E, R)>,
    /// Discount factor (gamma) applied to the bootstrapped target.
    pub(in crate::dqn) discount_factor: f64,
    /// Soft-update coefficient for tracking the target network.
    pub(in crate::dqn) tau: f64,
    /// Exploration strategy used during training.
    pub(in crate::dqn) explorer: DqnExplorer,
    /// Device on which tensors/models live.
    pub(in crate::dqn) device: Device,
    /// Total number of `opt_()` calls so far.
    pub(in crate::dqn) n_opts: usize,
    /// Enables double-DQN target computation.
    pub(in crate::dqn) double_dqn: bool,
    /// Currently unused; reward clipping threshold from the config.
    pub(in crate::dqn) _clip_reward: Option<f64>,
    /// Optional (min, max) clipping of TD errors before the weighted loss.
    pub(in crate::dqn) clip_td_err: Option<(f64, f64)>,
    /// Loss function applied to the critic (smooth L1 or MSE).
    pub(in crate::dqn) critic_loss: CriticLoss,
    // Counts actions sampled while training (for the best-action ratio metric).
    n_samples_act: usize,
    // Counts greedy ("best") actions chosen by epsilon-greedy while training.
    n_samples_best_act: usize,
    // Verbosity of recorded diagnostics; >= 2 enables extra scalars.
    record_verbose_level: usize,
}
49
impl<E, Q, R> Dqn<E, Q, R>
where
    E: Env,
    Q: SubModel<Output = Tensor>,
    R: ReplayBufferBase,
    Q::Config: DeserializeOwned + Serialize + OutDim + std::fmt::Debug + PartialEq + Clone,
    R::Batch: TransitionBatch,
    <R::Batch as TransitionBatch>::ObsBatch: Into<Q::Input>,
    <R::Batch as TransitionBatch>::ActBatch: Into<Tensor>,
{
    /// Performs one critic (Q-network) update on a minibatch sampled from `buffer`.
    ///
    /// Builds TD targets with the target network under `no_grad` (using
    /// double-DQN action selection when `self.double_dqn` is set), applies
    /// importance-sampling weights when the batch provides them (prioritized
    /// replay), backpropagates through the online network, and returns a
    /// [`Record`] containing at least the scalar `"loss"`. With
    /// `record_verbose_level >= 2`, also records `pred_mean`, `reward_mean`,
    /// `tgt_mean`, and `tgt_minus_pred_mean`.
    fn update_critic(&mut self, buffer: &mut R) -> Record {
        let mut record = Record::empty();
        // NOTE(review): `unwrap` assumes the buffer can always produce a full
        // batch — confirm callers only invoke this once enough samples exist.
        let batch = buffer.batch(self.batch_size).unwrap();
        let (obs, act, next_obs, reward, is_terminated, _is_truncated, ixs, weight) =
            batch.unpack();
        let obs = obs.into();
        let act = act.into().to(self.device);
        let next_obs = next_obs.into();
        let reward = Tensor::from_slice(&reward[..]).to(self.device);
        let is_terminated = Tensor::from_slice(&is_terminated[..]).to(self.device);

        // Q(s, a) for the taken actions, from the online network.
        let pred = {
            let x = self.qnet.forward(&obs);
            x.gather(-1, &act, false).squeeze()
        };

        if self.record_verbose_level >= 2 {
            record.insert(
                "pred_mean",
                RecordValue::Scalar(
                    f32::try_from(pred.mean(tch::Kind::Float))
                        .expect("Failed to convert Tensor to f32"),
                ),
            );
        }

        if self.record_verbose_level >= 2 {
            let reward_mean: f32 = reward.mean(tch::Kind::Float).try_into().unwrap();
            record.insert("reward_mean", RecordValue::Scalar(reward_mean));
        }

        // TD target: r + gamma * (1 - done) * Q_tgt(s', a*). No gradients flow
        // through the target computation.
        let tgt: Tensor = no_grad(|| {
            let q = if self.double_dqn {
                // Double DQN: select a* with the online network, evaluate it
                // with the target network to reduce overestimation bias.
                let x = self.qnet.forward(&next_obs);
                let y = x.argmax(-1, false).unsqueeze(-1);
                self.qnet_tgt
                    .forward(&next_obs)
                    .gather(-1, &y, false)
                    .squeeze()
            } else {
                // Vanilla DQN: both selection and evaluation use the target
                // network (argmax + gather is equivalent to a max over actions).
                let x = self.qnet_tgt.forward(&next_obs);
                let y = x.argmax(-1, false).unsqueeze(-1);
                x.gather(-1, &y, false).squeeze()
            };
            reward + (1 - is_terminated) * self.discount_factor * q
        });

        if self.record_verbose_level >= 2 {
            record.insert(
                "tgt_mean",
                RecordValue::Scalar(
                    f32::try_from(tgt.mean(tch::Kind::Float))
                        .expect("Failed to convert Tensor to f32"),
                ),
            );
            let tgt_minus_pred_mean: f32 =
                (&tgt - &pred).mean(tch::Kind::Float).try_into().unwrap();
            record.insert(
                "tgt_minus_pred_mean",
                RecordValue::Scalar(tgt_minus_pred_mean),
            );
        }

        let loss = if let Some(ws) = weight {
            // Prioritized replay path: weight per-sample |TD errors| by the
            // importance-sampling weights, then regress the weighted errors
            // towards zero; afterwards feed the raw |TD errors| back as new
            // priorities.
            let n = ws.len() as i64;
            let td_errs = match self.clip_td_err {
                None => (&pred - &tgt).abs(),
                Some((min, max)) => (&pred - &tgt).abs().clip(min, max),
            };
            let loss = Tensor::from_slice(&ws[..]).to(self.device) * &td_errs;
            let loss = match self.critic_loss {
                CriticLoss::SmoothL1 => loss.smooth_l1_loss(
                    &Tensor::zeros(&[n], tch::kind::FLOAT_CPU).to(self.device),
                    tch::Reduction::Mean,
                    1.0,
                ),
                CriticLoss::Mse => loss.mse_loss(
                    &Tensor::zeros(&[n], tch::kind::FLOAT_CPU).to(self.device),
                    tch::Reduction::Mean,
                ),
            };
            self.qnet.backward_step(&loss);
            let td_errs = Vec::<f32>::try_from(td_errs).expect("Failed to convert Tensor to f32");
            buffer.update_priority(&ixs, &Some(td_errs));
            loss
        } else {
            // Uniform replay path: plain regression of predictions onto targets.
            let loss = match self.critic_loss {
                CriticLoss::SmoothL1 => pred.smooth_l1_loss(&tgt, tch::Reduction::Mean, 1.0),
                CriticLoss::Mse => pred.mse_loss(&tgt, tch::Reduction::Mean),
            };
            self.qnet.backward_step(&loss);
            loss
        };

        record.insert(
            "loss",
            RecordValue::Scalar(f32::try_from(loss).expect("Failed to convert Tensor to f32")),
        );

        record
    }

    /// Runs `n_updates_per_opt` critic updates, merging their records, then
    /// soft-updates the target network every `soft_update_interval` calls.
    fn opt_(&mut self, buffer: &mut R) -> Record {
        let mut record_ = Record::empty();

        for _ in 0..self.n_updates_per_opt {
            let record = self.update_critic(buffer);
            record_ = record_.merge(record);
        }

        self.soft_update_counter += 1;
        if self.soft_update_counter == self.soft_update_interval {
            self.soft_update_counter = 0;
            // Move qnet_tgt towards qnet with coefficient tau.
            track(&mut self.qnet_tgt, &mut self.qnet, self.tau);
        }

        self.n_opts += 1;

        record_
    }
}
202
203impl<E, Q, R> Policy<E> for Dqn<E, Q, R>
204where
205    E: Env,
206    Q: SubModel<Output = Tensor>,
207    E::Obs: Into<Q::Input>,
208    E::Act: From<Q::Output>,
209    Q::Config: DeserializeOwned + Serialize + OutDim + std::fmt::Debug + PartialEq + Clone,
210{
211    fn sample(&mut self, obs: &E::Obs) -> E::Act {
212        no_grad(|| {
213            let a = self.qnet.forward(&obs.clone().into());
214            let a = if self.train {
215                self.n_samples_act += 1;
216                match &mut self.explorer {
217                    DqnExplorer::Softmax(softmax) => softmax.action(&a),
218                    DqnExplorer::EpsilonGreedy(egreedy) => {
219                        if self.record_verbose_level >= 2 {
220                            let (act, best) = egreedy.action_with_best(&a);
221                            if best {
222                                self.n_samples_best_act += 1;
223                            }
224                            act
225                        } else {
226                            egreedy.action(&a)
227                        }
228                    }
229                }
230            } else {
231                if fastrand::f32() < 0.01 {
232                    let n_actions = a.size()[1] as i64;
233                    let a = fastrand::i64(0..n_actions);
234                    Tensor::from(a)
235                } else {
236                    a.argmax(-1, true)
237                }
238            };
239            a.into()
240        })
241    }
242}
243
244impl<E, Q, R> Configurable for Dqn<E, Q, R>
245where
246    E: Env,
247    Q: SubModel<Output = Tensor>,
248    E::Obs: Into<Q::Input>,
249    E::Act: From<Q::Output>,
250    Q::Config: DeserializeOwned + Serialize + OutDim + std::fmt::Debug + PartialEq + Clone,
251{
252    type Config = DqnConfig<Q>;
253
254    /// Constructs DQN agent.
255    fn build(config: Self::Config) -> Self {
256        let device = config
257            .device
258            .expect("No device is given for DQN agent")
259            .into();
260        let qnet = DqnModel::build(config.model_config, device);
261        let qnet_tgt = qnet.clone();
262
263        Dqn {
264            qnet,
265            qnet_tgt,
266            soft_update_interval: config.soft_update_interval,
267            soft_update_counter: 0,
268            n_updates_per_opt: config.n_updates_per_opt,
269            batch_size: config.batch_size,
270            discount_factor: config.discount_factor,
271            tau: config.tau,
272            train: config.train,
273            explorer: config.explorer,
274            device,
275            n_opts: 0,
276            _clip_reward: config.clip_reward,
277            double_dqn: config.double_dqn,
278            clip_td_err: config.clip_td_err,
279            critic_loss: config.critic_loss,
280            n_samples_act: 0,
281            n_samples_best_act: 0,
282            record_verbose_level: config.record_verbose_level,
283            phantom: PhantomData,
284        }
285    }
286}
287
288impl<E, Q, R> Agent<E, R> for Dqn<E, Q, R>
289where
290    E: Env + 'static,
291    Q: SubModel<Output = Tensor> + 'static,
292    R: ReplayBufferBase + 'static,
293    E::Obs: Into<Q::Input>,
294    E::Act: From<Q::Output>,
295    Q::Config: DeserializeOwned + Serialize + OutDim + std::fmt::Debug + PartialEq + Clone,
296    R::Batch: TransitionBatch,
297    <R::Batch as TransitionBatch>::ObsBatch: Into<Q::Input>,
298    <R::Batch as TransitionBatch>::ActBatch: Into<Tensor>,
299{
300    fn train(&mut self) {
301        self.train = true;
302    }
303
304    fn eval(&mut self) {
305        self.train = false;
306    }
307
308    fn is_train(&self) -> bool {
309        self.train
310    }
311
312    fn opt(&mut self, buffer: &mut R) {
313        self.opt_(buffer);
314    }
315
316    fn opt_with_record(&mut self, buffer: &mut R) -> Record {
317        let mut record = {
318            let record = self.opt_(buffer);
319
320            match self.record_verbose_level >= 2 {
321                true => {
322                    let record_weights = self.qnet.param_stats();
323                    let record = record.merge(record_weights);
324                    record
325                }
326                false => record,
327            }
328        };
329
330        // Best action ratio for epsilon greedy
331        if self.record_verbose_level >= 2 {
332            let ratio = match self.n_samples_act == 0 {
333                true => 0f32,
334                false => self.n_samples_best_act as f32 / self.n_samples_act as f32,
335            };
336            record.insert("ratio_best_act", RecordValue::Scalar(ratio));
337            self.n_samples_act = 0;
338            self.n_samples_best_act = 0;
339        }
340
341        record
342    }
343
344    /// Save model parameters in the given directory.
345    ///
346    /// The parameters of the model are saved as `qnet.pt`.
347    /// The parameters of the target model are saved as `qnet_tgt.pt`.
348    fn save_params(&self, path: &Path) -> Result<Vec<PathBuf>> {
349        // TODO: consider to rename the path if it already exists
350        fs::create_dir_all(&path)?;
351        let path1 = path.join("qnet.pt.tch").to_path_buf();
352        let path2 = path.join("qnet_tgt.pt.tch").to_path_buf();
353        self.qnet.save(&path1)?;
354        self.qnet_tgt.save(&path2)?;
355        Ok(vec![path1, path2])
356    }
357
358    fn load_params(&mut self, path: &Path) -> Result<()> {
359        self.qnet.load(path.join("qnet.pt.tch").as_path())?;
360        self.qnet_tgt.load(path.join("qnet_tgt.pt.tch").as_path())?;
361        Ok(())
362    }
363
364    fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
365        self
366    }
367
368    fn as_any_ref(&self) -> &dyn std::any::Any {
369        self
370    }
371}
372
373#[cfg(feature = "border-async-trainer")]
374use {crate::util::NamedTensors, border_async_trainer::SyncModel};
375
#[cfg(feature = "border-async-trainer")]
impl<E, Q, R> SyncModel for Dqn<E, Q, R>
where
    E: Env,
    Q: SubModel<Output = Tensor>,
    R: ReplayBufferBase,
    E::Obs: Into<Q::Input>,
    E::Act: From<Q::Output>,
    Q::Config: DeserializeOwned + Serialize + OutDim + std::fmt::Debug + PartialEq + Clone,
    R::Batch: TransitionBatch,
    <R::Batch as TransitionBatch>::ObsBatch: Into<Q::Input>,
    <R::Batch as TransitionBatch>::ActBatch: Into<Tensor>,
{
    type ModelInfo = NamedTensors;

    /// Returns the optimization-step count together with a copy of the online
    /// Q-network parameters.
    fn model_info(&self) -> (usize, Self::ModelInfo) {
        let params = NamedTensors::copy_from(self.qnet.get_var_store());
        (self.n_opts, params)
    }

    /// Overwrites the online Q-network parameters with those in `model_info`.
    fn sync_model(&mut self, model_info: &Self::ModelInfo) {
        model_info.copy_to(self.qnet.get_var_store_mut());
    }
}
402}