border_tch_agent/sac/critic/
base.rs

1use super::CriticConfig;
2use crate::{
3    model::{ModelBase, SubModel2},
4    opt::{Optimizer, OptimizerConfig},
5};
6use anyhow::{Context, Result};
7use log::{info, trace};
8use serde::{de::DeserializeOwned, Serialize};
9use std::path::Path;
10use tch::{nn, Device, Tensor};
11
12#[allow(clippy::upper_case_acronyms)]
13/// Represents soft critic for SAC agents.
14///
15/// It takes observations and actions as inputs and outputs action values.
16pub struct Critic<Q>
17where
18    Q: SubModel2<Output = Tensor>,
19    Q::Config: DeserializeOwned + Serialize,
20{
21    device: Device,
22    var_store: nn::VarStore,
23
24    // Action-value function
25    q: Q,
26
27    // Optimizer
28    opt_config: OptimizerConfig,
29    opt: Optimizer,
30}
31
32impl<Q> Critic<Q>
33where
34    Q: SubModel2<Output = Tensor>,
35    Q::Config: DeserializeOwned + Serialize,
36{
37    /// Constructs [Critic].
38    pub fn build(config: CriticConfig<Q::Config>, device: Device) -> Result<Critic<Q>> {
39        let q_config = config.q_config.context("q_config is not set.")?;
40        let opt_config = config.opt_config;
41        let var_store = nn::VarStore::new(device);
42        let q = Q::build(&var_store, q_config);
43
44        Ok(Critic::_build(device, opt_config, q, var_store, None))
45    }
46
47    fn _build(
48        device: Device,
49        opt_config: OptimizerConfig,
50        q: Q,
51        mut var_store: nn::VarStore,
52        var_store_src: Option<&nn::VarStore>,
53    ) -> Self {
54        // Optimizer
55        let opt = opt_config.build(&var_store).unwrap();
56
57        // Copy var_store
58        if let Some(var_store_src) = var_store_src {
59            var_store.copy(var_store_src).unwrap();
60        }
61
62        Self {
63            device,
64            opt_config,
65            var_store,
66            opt,
67            q,
68        }
69    }
70
71    /// Outputs the action-value given observations and actions.
72    pub fn forward(&self, obs: &Q::Input1, act: &Q::Input2) -> Tensor {
73        self.q.forward(obs, act)
74    }
75}
76
77impl<Q> Clone for Critic<Q>
78where
79    Q: SubModel2<Output = Tensor>,
80    Q::Config: DeserializeOwned + Serialize,
81{
82    fn clone(&self) -> Self {
83        let device = self.device;
84        let opt_config = self.opt_config.clone();
85        let var_store = nn::VarStore::new(device);
86        let q = self.q.clone_with_var_store(&var_store);
87
88        Self::_build(device, opt_config, q, var_store, Some(&self.var_store))
89    }
90}
91
92impl<Q> ModelBase for Critic<Q>
93where
94    Q: SubModel2<Output = Tensor>,
95    Q::Config: DeserializeOwned + Serialize,
96{
97    fn backward_step(&mut self, loss: &Tensor) {
98        self.opt.backward_step(loss);
99    }
100
101    fn get_var_store_mut(&mut self) -> &mut nn::VarStore {
102        &mut self.var_store
103    }
104
105    fn get_var_store(&self) -> &nn::VarStore {
106        &self.var_store
107    }
108
109    fn save<T: AsRef<Path>>(&self, path: T) -> Result<()> {
110        self.var_store.save(&path)?;
111        info!("Save critic to {:?}", path.as_ref());
112        let vs = self.var_store.variables();
113        for (name, _) in vs.iter() {
114            trace!("Save variable {}", name);
115        }
116        Ok(())
117    }
118
119    fn load<T: AsRef<Path>>(&mut self, path: T) -> Result<()> {
120        self.var_store.load(&path)?;
121        info!("Load critic from {:?}", path.as_ref());
122        Ok(())
123    }
124}