border_tch_agent/dqn/model/
base.rs1use super::DqnModelConfig;
2use crate::{
3 model::{ModelBase, SubModel},
4 opt::{Optimizer, OptimizerConfig},
5 util::OutDim,
6};
7use anyhow::Result;
8use border_core::record::Record;
9use log::{info, trace};
10use serde::{de::DeserializeOwned, Serialize};
11use std::{marker::PhantomData, path::Path};
12use tch::{nn, Device, Tensor};
13
14pub struct DqnModel<Q>
21where
22 Q: SubModel<Output = Tensor>,
23 Q::Config: DeserializeOwned + Serialize + OutDim,
24{
25 device: Device,
26 var_store: nn::VarStore,
27
28 pub(super) out_dim: i64,
30
31 q: Q,
33
34 opt_config: OptimizerConfig,
36 opt: Optimizer,
37
38 phantom: PhantomData<Q>,
39}
40
41impl<Q> DqnModel<Q>
42where
43 Q: SubModel<Output = Tensor>,
44 Q::Config: DeserializeOwned + Serialize + OutDim,
45{
46 pub fn build(config: DqnModelConfig<Q::Config>, device: Device) -> Self {
47 let out_dim = config.q_config.as_ref().unwrap().get_out_dim();
48 let opt_config = config.opt_config.clone();
49 let var_store = nn::VarStore::new(device);
50 let q = Q::build(&var_store, config.q_config.unwrap());
51
52 Self::_build(device, out_dim, opt_config, q, var_store, None)
53 }
54
55 fn _build(
56 device: Device,
57 out_dim: i64,
58 opt_config: OptimizerConfig,
59 q: Q,
60 mut var_store: nn::VarStore,
61 var_store_src: Option<&nn::VarStore>,
62 ) -> Self {
63 let opt = opt_config.build(&var_store).unwrap();
65
66 if let Some(var_store_src) = var_store_src {
68 var_store.copy(var_store_src).unwrap();
69 }
70
71 Self {
72 device,
73 out_dim,
74 opt_config,
75 var_store,
76 opt,
77 q,
78 phantom: PhantomData,
79 }
80 }
81
82 pub fn forward(&self, x: &Q::Input) -> Tensor {
84 let a = self.q.forward(&x);
85 debug_assert_eq!(a.size().as_slice()[1], self.out_dim);
86 a
87 }
88
89 pub fn param_stats(&self) -> Record {
90 crate::util::param_stats(&self.var_store)
91 }
92}
93
94impl<Q> Clone for DqnModel<Q>
95where
96 Q: SubModel<Output = Tensor>,
97 Q::Config: DeserializeOwned + Serialize + OutDim,
98{
99 fn clone(&self) -> Self {
100 let device = self.device;
101 let out_dim = self.out_dim;
102 let opt_config = self.opt_config.clone();
103 let var_store = nn::VarStore::new(device);
104 let q = self.q.clone_with_var_store(&var_store);
105
106 Self::_build(
107 device,
108 out_dim,
109 opt_config,
110 q,
111 var_store,
112 Some(&self.var_store),
113 )
114 }
115}
116
117impl<Q> ModelBase for DqnModel<Q>
118where
119 Q: SubModel<Output = Tensor>,
120 Q::Config: DeserializeOwned + Serialize + OutDim,
121{
122 fn backward_step(&mut self, loss: &Tensor) {
123 self.opt.backward_step(loss);
124 }
125
126 fn get_var_store_mut(&mut self) -> &mut nn::VarStore {
127 &mut self.var_store
128 }
129
130 fn get_var_store(&self) -> &nn::VarStore {
131 &self.var_store
132 }
133
134 fn save<T: AsRef<Path>>(&self, path: T) -> Result<()> {
135 self.var_store.save(&path)?;
136 info!("Save DQN model to {:?}", path.as_ref());
137 let vs = self.var_store.variables();
138 for (name, _) in vs.iter() {
139 trace!("Save variable {}", name);
140 }
141 Ok(())
142 }
143
144 fn load<T: AsRef<Path>>(&mut self, path: T) -> Result<()> {
145 self.var_store.load(&path)?;
146 info!("Load DQN model from {:?}", path.as_ref());
147 Ok(())
148 }
149}