border_tch_agent/sac/actor/
base.rs1use super::ActorConfig;
2use crate::{
3 model::{ModelBase, SubModel},
4 opt::{Optimizer, OptimizerConfig},
5 util::OutDim,
6};
7use anyhow::{Context, Result};
8use log::{info, trace};
9use serde::{de::DeserializeOwned, Serialize};
10use std::path::Path;
11use tch::{nn, Device, Tensor};
12
13pub struct Actor<P>
15where
16 P: SubModel<Output = (Tensor, Tensor)>,
17 P::Config: DeserializeOwned + Serialize + OutDim,
18{
19 device: Device,
20 var_store: nn::VarStore,
21
22 pub(super) out_dim: i64,
24
25 pi: P,
27
28 opt_config: OptimizerConfig,
30 opt: Optimizer,
31}
32
33impl<P> Actor<P>
34where
35 P: SubModel<Output = (Tensor, Tensor)>,
36 P::Config: DeserializeOwned + Serialize + OutDim,
37{
38 pub fn build(config: ActorConfig<P::Config>, device: Device) -> Result<Actor<P>> {
40 let pi_config = config.pi_config.context("pi_config is not set.")?;
41 let out_dim = pi_config.get_out_dim();
42 let opt_config = config.opt_config;
43 let var_store = nn::VarStore::new(device);
44 let pi = P::build(&var_store, pi_config);
45
46 Ok(Actor::_build(
47 device, out_dim, opt_config, pi, var_store, None,
48 ))
49 }
50
51 fn _build(
52 device: Device,
53 out_dim: i64,
54 opt_config: OptimizerConfig,
55 pi: P,
56 mut var_store: nn::VarStore,
57 var_store_src: Option<&nn::VarStore>,
58 ) -> Self {
59 let opt = opt_config.build(&var_store).unwrap();
61
62 if let Some(var_store_src) = var_store_src {
64 var_store.copy(var_store_src).unwrap();
65 }
66
67 Self {
68 device,
69 out_dim,
70 opt_config,
71 var_store,
72 opt,
73 pi,
74 }
75 }
76
77 pub fn forward(&self, x: &P::Input) -> (Tensor, Tensor) {
79 let (mean, std) = self.pi.forward(&x);
80 debug_assert_eq!(mean.size().as_slice()[1], self.out_dim);
81 debug_assert_eq!(std.size().as_slice()[1], self.out_dim);
82 (mean, std)
83 }
84}
85
86impl<P> Clone for Actor<P>
87where
88 P: SubModel<Output = (Tensor, Tensor)>,
89 P::Config: DeserializeOwned + Serialize + OutDim,
90{
91 fn clone(&self) -> Self {
92 let device = self.device;
93 let out_dim = self.out_dim;
94 let opt_config = self.opt_config.clone();
95 let var_store = nn::VarStore::new(device);
96 let pi = self.pi.clone_with_var_store(&var_store);
97
98 Self::_build(
99 device,
100 out_dim,
101 opt_config,
102 pi,
103 var_store,
104 Some(&self.var_store),
105 )
106 }
107}
108
109impl<P> ModelBase for Actor<P>
110where
111 P: SubModel<Output = (Tensor, Tensor)>,
112 P::Config: DeserializeOwned + Serialize + OutDim,
113{
114 fn backward_step(&mut self, loss: &Tensor) {
115 self.opt.backward_step(loss);
116 }
117
118 fn get_var_store_mut(&mut self) -> &mut nn::VarStore {
119 &mut self.var_store
120 }
121
122 fn get_var_store(&self) -> &nn::VarStore {
123 &self.var_store
124 }
125
126 fn save<T: AsRef<Path>>(&self, path: T) -> Result<()> {
127 self.var_store.save(&path)?;
128 info!("Save actor to {:?}", path.as_ref());
129 let vs = self.var_store.variables();
130 for (name, _) in vs.iter() {
131 trace!("Save variable {}", name);
132 }
133 Ok(())
134 }
135
136 fn load<T: AsRef<Path>>(&mut self, path: T) -> Result<()> {
137 self.var_store.load(&path)?;
138 info!("Load actor from {:?}", path.as_ref());
139 Ok(())
140 }
141}