border_tch_agent/sac/critic/
base.rs1use super::CriticConfig;
2use crate::{
3 model::{ModelBase, SubModel2},
4 opt::{Optimizer, OptimizerConfig},
5};
6use anyhow::{Context, Result};
7use log::{info, trace};
8use serde::{de::DeserializeOwned, Serialize};
9use std::path::Path;
10use tch::{nn, Device, Tensor};
11
12#[allow(clippy::upper_case_acronyms)]
13pub struct Critic<Q>
17where
18 Q: SubModel2<Output = Tensor>,
19 Q::Config: DeserializeOwned + Serialize,
20{
21 device: Device,
22 var_store: nn::VarStore,
23
24 q: Q,
26
27 opt_config: OptimizerConfig,
29 opt: Optimizer,
30}
31
32impl<Q> Critic<Q>
33where
34 Q: SubModel2<Output = Tensor>,
35 Q::Config: DeserializeOwned + Serialize,
36{
37 pub fn build(config: CriticConfig<Q::Config>, device: Device) -> Result<Critic<Q>> {
39 let q_config = config.q_config.context("q_config is not set.")?;
40 let opt_config = config.opt_config;
41 let var_store = nn::VarStore::new(device);
42 let q = Q::build(&var_store, q_config);
43
44 Ok(Critic::_build(device, opt_config, q, var_store, None))
45 }
46
47 fn _build(
48 device: Device,
49 opt_config: OptimizerConfig,
50 q: Q,
51 mut var_store: nn::VarStore,
52 var_store_src: Option<&nn::VarStore>,
53 ) -> Self {
54 let opt = opt_config.build(&var_store).unwrap();
56
57 if let Some(var_store_src) = var_store_src {
59 var_store.copy(var_store_src).unwrap();
60 }
61
62 Self {
63 device,
64 opt_config,
65 var_store,
66 opt,
67 q,
68 }
69 }
70
71 pub fn forward(&self, obs: &Q::Input1, act: &Q::Input2) -> Tensor {
73 self.q.forward(obs, act)
74 }
75}
76
77impl<Q> Clone for Critic<Q>
78where
79 Q: SubModel2<Output = Tensor>,
80 Q::Config: DeserializeOwned + Serialize,
81{
82 fn clone(&self) -> Self {
83 let device = self.device;
84 let opt_config = self.opt_config.clone();
85 let var_store = nn::VarStore::new(device);
86 let q = self.q.clone_with_var_store(&var_store);
87
88 Self::_build(device, opt_config, q, var_store, Some(&self.var_store))
89 }
90}
91
92impl<Q> ModelBase for Critic<Q>
93where
94 Q: SubModel2<Output = Tensor>,
95 Q::Config: DeserializeOwned + Serialize,
96{
97 fn backward_step(&mut self, loss: &Tensor) {
98 self.opt.backward_step(loss);
99 }
100
101 fn get_var_store_mut(&mut self) -> &mut nn::VarStore {
102 &mut self.var_store
103 }
104
105 fn get_var_store(&self) -> &nn::VarStore {
106 &self.var_store
107 }
108
109 fn save<T: AsRef<Path>>(&self, path: T) -> Result<()> {
110 self.var_store.save(&path)?;
111 info!("Save critic to {:?}", path.as_ref());
112 let vs = self.var_store.variables();
113 for (name, _) in vs.iter() {
114 trace!("Save variable {}", name);
115 }
116 Ok(())
117 }
118
119 fn load<T: AsRef<Path>>(&mut self, path: T) -> Result<()> {
120 self.var_store.load(&path)?;
121 info!("Load critic from {:?}", path.as_ref());
122 Ok(())
123 }
124}