border_tch_agent/dqn/model/
base.rs

1use super::DqnModelConfig;
2use crate::{
3    model::{ModelBase, SubModel},
4    opt::{Optimizer, OptimizerConfig},
5    util::OutDim,
6};
7use anyhow::Result;
8use border_core::record::Record;
9use log::{info, trace};
10use serde::{de::DeserializeOwned, Serialize};
11use std::{marker::PhantomData, path::Path};
12use tch::{nn, Device, Tensor};
13
14/// Action value function model for DQN.
15///
16/// The architecture of the model is defined by the type parameter `Q`,
17/// which should implement [`SubModel`].
18/// This takes [`SubModel::Input`] as input and outputs a tensor.
19/// The output tensor should have the same dimension as the number of actions.
20pub struct DqnModel<Q>
21where
22    Q: SubModel<Output = Tensor>,
23    Q::Config: DeserializeOwned + Serialize + OutDim,
24{
25    device: Device,
26    var_store: nn::VarStore,
27
28    // Dimension of the output vector (equal to the number of actions).
29    pub(super) out_dim: i64,
30
31    // Action-value function
32    q: Q,
33
34    // Optimizer
35    opt_config: OptimizerConfig,
36    opt: Optimizer,
37
38    phantom: PhantomData<Q>,
39}
40
41impl<Q> DqnModel<Q>
42where
43    Q: SubModel<Output = Tensor>,
44    Q::Config: DeserializeOwned + Serialize + OutDim,
45{
46    pub fn build(config: DqnModelConfig<Q::Config>, device: Device) -> Self {
47        let out_dim = config.q_config.as_ref().unwrap().get_out_dim();
48        let opt_config = config.opt_config.clone();
49        let var_store = nn::VarStore::new(device);
50        let q = Q::build(&var_store, config.q_config.unwrap());
51
52        Self::_build(device, out_dim, opt_config, q, var_store, None)
53    }
54
55    fn _build(
56        device: Device,
57        out_dim: i64,
58        opt_config: OptimizerConfig,
59        q: Q,
60        mut var_store: nn::VarStore,
61        var_store_src: Option<&nn::VarStore>,
62    ) -> Self {
63        // Optimizer
64        let opt = opt_config.build(&var_store).unwrap();
65
66        // Copy var_store
67        if let Some(var_store_src) = var_store_src {
68            var_store.copy(var_store_src).unwrap();
69        }
70
71        Self {
72            device,
73            out_dim,
74            opt_config,
75            var_store,
76            opt,
77            q,
78            phantom: PhantomData,
79        }
80    }
81
82    /// Outputs the action-value given observation(s).
83    pub fn forward(&self, x: &Q::Input) -> Tensor {
84        let a = self.q.forward(&x);
85        debug_assert_eq!(a.size().as_slice()[1], self.out_dim);
86        a
87    }
88
89    pub fn param_stats(&self) -> Record {
90        crate::util::param_stats(&self.var_store)
91    }
92}
93
94impl<Q> Clone for DqnModel<Q>
95where
96    Q: SubModel<Output = Tensor>,
97    Q::Config: DeserializeOwned + Serialize + OutDim,
98{
99    fn clone(&self) -> Self {
100        let device = self.device;
101        let out_dim = self.out_dim;
102        let opt_config = self.opt_config.clone();
103        let var_store = nn::VarStore::new(device);
104        let q = self.q.clone_with_var_store(&var_store);
105
106        Self::_build(
107            device,
108            out_dim,
109            opt_config,
110            q,
111            var_store,
112            Some(&self.var_store),
113        )
114    }
115}
116
117impl<Q> ModelBase for DqnModel<Q>
118where
119    Q: SubModel<Output = Tensor>,
120    Q::Config: DeserializeOwned + Serialize + OutDim,
121{
122    fn backward_step(&mut self, loss: &Tensor) {
123        self.opt.backward_step(loss);
124    }
125
126    fn get_var_store_mut(&mut self) -> &mut nn::VarStore {
127        &mut self.var_store
128    }
129
130    fn get_var_store(&self) -> &nn::VarStore {
131        &self.var_store
132    }
133
134    fn save<T: AsRef<Path>>(&self, path: T) -> Result<()> {
135        self.var_store.save(&path)?;
136        info!("Save DQN model to {:?}", path.as_ref());
137        let vs = self.var_store.variables();
138        for (name, _) in vs.iter() {
139            trace!("Save variable {}", name);
140        }
141        Ok(())
142    }
143
144    fn load<T: AsRef<Path>>(&mut self, path: T) -> Result<()> {
145        self.var_store.load(&path)?;
146        info!("Load DQN model from {:?}", path.as_ref());
147        Ok(())
148    }
149}