// border_tch_agent/sac/actor/base.rs

1use super::ActorConfig;
2use crate::{
3    model::{ModelBase, SubModel},
4    opt::{Optimizer, OptimizerConfig},
5    util::OutDim,
6};
7use anyhow::{Context, Result};
8use log::{info, trace};
9use serde::{de::DeserializeOwned, Serialize};
10use std::path::Path;
11use tch::{nn, Device, Tensor};
12
13/// Stochastic policy for SAC agents.
14pub struct Actor<P>
15where
16    P: SubModel<Output = (Tensor, Tensor)>,
17    P::Config: DeserializeOwned + Serialize + OutDim,
18{
19    device: Device,
20    var_store: nn::VarStore,
21
22    // Dimension of the action vector.
23    pub(super) out_dim: i64,
24
25    // Action-value function
26    pi: P,
27
28    // Optimizer
29    opt_config: OptimizerConfig,
30    opt: Optimizer,
31}
32
33impl<P> Actor<P>
34where
35    P: SubModel<Output = (Tensor, Tensor)>,
36    P::Config: DeserializeOwned + Serialize + OutDim,
37{
38    /// Constructs [`Actor`].
39    pub fn build(config: ActorConfig<P::Config>, device: Device) -> Result<Actor<P>> {
40        let pi_config = config.pi_config.context("pi_config is not set.")?;
41        let out_dim = pi_config.get_out_dim();
42        let opt_config = config.opt_config;
43        let var_store = nn::VarStore::new(device);
44        let pi = P::build(&var_store, pi_config);
45
46        Ok(Actor::_build(
47            device, out_dim, opt_config, pi, var_store, None,
48        ))
49    }
50
51    fn _build(
52        device: Device,
53        out_dim: i64,
54        opt_config: OptimizerConfig,
55        pi: P,
56        mut var_store: nn::VarStore,
57        var_store_src: Option<&nn::VarStore>,
58    ) -> Self {
59        // Optimizer
60        let opt = opt_config.build(&var_store).unwrap();
61
62        // Copy var_store
63        if let Some(var_store_src) = var_store_src {
64            var_store.copy(var_store_src).unwrap();
65        }
66
67        Self {
68            device,
69            out_dim,
70            opt_config,
71            var_store,
72            opt,
73            pi,
74        }
75    }
76
77    /// Outputs the parameters of Gaussian distribution given an observation.
78    pub fn forward(&self, x: &P::Input) -> (Tensor, Tensor) {
79        let (mean, std) = self.pi.forward(&x);
80        debug_assert_eq!(mean.size().as_slice()[1], self.out_dim);
81        debug_assert_eq!(std.size().as_slice()[1], self.out_dim);
82        (mean, std)
83    }
84}
85
86impl<P> Clone for Actor<P>
87where
88    P: SubModel<Output = (Tensor, Tensor)>,
89    P::Config: DeserializeOwned + Serialize + OutDim,
90{
91    fn clone(&self) -> Self {
92        let device = self.device;
93        let out_dim = self.out_dim;
94        let opt_config = self.opt_config.clone();
95        let var_store = nn::VarStore::new(device);
96        let pi = self.pi.clone_with_var_store(&var_store);
97
98        Self::_build(
99            device,
100            out_dim,
101            opt_config,
102            pi,
103            var_store,
104            Some(&self.var_store),
105        )
106    }
107}
108
109impl<P> ModelBase for Actor<P>
110where
111    P: SubModel<Output = (Tensor, Tensor)>,
112    P::Config: DeserializeOwned + Serialize + OutDim,
113{
114    fn backward_step(&mut self, loss: &Tensor) {
115        self.opt.backward_step(loss);
116    }
117
118    fn get_var_store_mut(&mut self) -> &mut nn::VarStore {
119        &mut self.var_store
120    }
121
122    fn get_var_store(&self) -> &nn::VarStore {
123        &self.var_store
124    }
125
126    fn save<T: AsRef<Path>>(&self, path: T) -> Result<()> {
127        self.var_store.save(&path)?;
128        info!("Save actor to {:?}", path.as_ref());
129        let vs = self.var_store.variables();
130        for (name, _) in vs.iter() {
131            trace!("Save variable {}", name);
132        }
133        Ok(())
134    }
135
136    fn load<T: AsRef<Path>>(&mut self, path: T) -> Result<()> {
137        self.var_store.load(&path)?;
138        info!("Load actor from {:?}", path.as_ref());
139        Ok(())
140    }
141}