border_candle_agent/bc/
model.rs

1use crate::{
2    model::SubModel1,
3    opt::{Optimizer, OptimizerConfig},
4    util::OutDim,
5};
6use anyhow::{Context, Result};
7use border_core::record::Record;
8use candle_core::{DType, Device, Tensor};
9use candle_nn::{VarBuilder, VarMap};
10use log::info;
11use serde::{de::DeserializeOwned, Deserialize, Serialize};
12use std::{
13    fs::File,
14    io::{BufReader, Write},
15    path::Path,
16};
17
18#[derive(Debug, Deserialize, Serialize, PartialEq, Clone)]
19/// Configuration of [`BcModel`].
20///
21/// The type parameter `C` should be a configuration of policy model, which should outputs a tensor.
22/// The policy model supports both discrete and continuous action spaces, leaving the interpretation
23/// of the output to the caller.
24pub struct BcModelConfig<C>
25where
26    C: OutDim + Clone,
27{
28    pub policy_model_config: Option<C>,
29    #[serde(default)]
30    pub opt_config: OptimizerConfig,
31}
32
33impl<C> Default for BcModelConfig<C>
34where
35    C: DeserializeOwned + Serialize + OutDim + Clone,
36{
37    fn default() -> Self {
38        Self {
39            policy_model_config: None,
40            opt_config: OptimizerConfig::default(),
41        }
42    }
43}
44
45impl<C> BcModelConfig<C>
46where
47    C: DeserializeOwned + Serialize + OutDim + Clone,
48{
49    /// Sets configurations for the policy model.
50    pub fn policy_model_config(mut self, v: C) -> Self {
51        self.policy_model_config = Some(v);
52        self
53    }
54
55    /// Sets output dimension of the model.
56    pub fn out_dim(mut self, v: i64) -> Self {
57        match &mut self.policy_model_config {
58            None => {}
59            Some(policy_model_config) => policy_model_config.set_out_dim(v),
60        };
61        self
62    }
63
64    /// Sets optimizer configuration.
65    pub fn opt_config(mut self, v: OptimizerConfig) -> Self {
66        self.opt_config = v;
67        self
68    }
69
70    /// Constructs [`BcModelConfig`] from YAML file.
71    pub fn load(path: impl AsRef<Path>) -> Result<Self> {
72        let file = File::open(path)?;
73        let rdr = BufReader::new(file);
74        let b = serde_yaml::from_reader(rdr)?;
75        Ok(b)
76    }
77
78    /// Saves [`BcModelConfig`] to as a YAML file.
79    pub fn save(&self, path: impl AsRef<Path>) -> Result<()> {
80        let mut file = File::create(path)?;
81        file.write_all(serde_yaml::to_string(&self)?.as_bytes())?;
82        Ok(())
83    }
84}
85
86/// Policy model for behaviour cloning.
87///
88/// The model's architecture is specified by the type parameter `P`,
89/// which must implement [`SubModel1`]. It takes [`SubModel1::Input`] as
90/// input and produces a tensor as output.
91pub struct BcModel<P>
92where
93    P: SubModel1<Output = Tensor>,
94    P::Config: DeserializeOwned + Serialize + OutDim,
95{
96    device: Device,
97    varmap: VarMap,
98
99    /// Dimension of the output vector.
100    out_dim: i64,
101
102    /// Policy model.
103    policy_model: P,
104
105    /// Optimizer configuration.
106    opt_config: OptimizerConfig,
107
108    /// Optimizer.
109    opt: Optimizer,
110
111    /// Policy model configuration.
112    policy_model_config: P::Config,
113}
114
115impl<P> BcModel<P>
116where
117    P: SubModel1<Output = Tensor>,
118    P::Config: DeserializeOwned + Serialize + OutDim + Clone,
119{
120    /// Constructs [`BcModel`].
121    pub fn build(config: BcModelConfig<P::Config>, device: Device) -> Result<Self> {
122        let out_dim = config.policy_model_config.as_ref().unwrap().get_out_dim();
123        let policy_model_config = config
124            .policy_model_config
125            .context("policy_model_config is not set.")?;
126        let opt_config = config.opt_config;
127        let varmap = VarMap::new();
128
129        // Build policy model
130        let policy_model = {
131            let vb = VarBuilder::from_varmap(&varmap, DType::F32, &device);
132            P::build(vb, policy_model_config.clone())
133        };
134
135        Ok(Self::_build(
136            device,
137            out_dim as _,
138            opt_config,
139            policy_model_config,
140            policy_model,
141            varmap,
142            None,
143        ))
144    }
145
146    fn _build(
147        device: Device,
148        out_dim: i64,
149        opt_config: OptimizerConfig,
150        policy_model_config: P::Config,
151        policy_model: P,
152        mut varmap: VarMap,
153        varmap_src: Option<&VarMap>,
154    ) -> Self {
155        // Optimizer
156        let opt = opt_config.build(varmap.all_vars()).unwrap();
157
158        // Copy varmap
159        if let Some(varmap_src) = varmap_src {
160            varmap.clone_from(varmap_src);
161        }
162
163        Self {
164            device,
165            out_dim,
166            opt_config,
167            varmap,
168            opt,
169            policy_model,
170            policy_model_config,
171        }
172    }
173
174    /// Outputs the action-value given observation(s).
175    pub fn forward(&self, obs: &P::Input) -> Tensor {
176        self.policy_model.forward(obs)
177    }
178
179    pub fn backward_step(&mut self, loss: &Tensor) -> Result<()> {
180        // Consider to use gradient clipping, below code
181        // let mut grads = loss.backward()?;
182        // for (_, var) in self.varmap.data().lock().unwrap().iter() {
183        //     let g1 = grads.get(var).unwrap();
184        //     let g2 = g1.clamp(-1.0, 1.0)?;
185        //     let _ = grads.remove(&var).unwrap();
186        //     let _ = grads.insert(&var, g2);
187        // }
188        // self.opt.step(&grads)
189        self.opt.backward_step(loss)
190    }
191
192    pub fn get_varmap(&self) -> &VarMap {
193        &self.varmap
194    }
195
196    pub fn save<T: AsRef<Path>>(&self, path: T) -> Result<()> {
197        self.varmap.save(&path)?;
198        info!("Save bc model to {:?}", path.as_ref());
199        Ok(())
200    }
201
202    pub fn load<T: AsRef<Path>>(&mut self, path: T) -> Result<()> {
203        self.varmap.load(&path)?;
204        info!("Load bc model from {:?}", path.as_ref());
205        Ok(())
206    }
207
208    pub fn param_stats(&self) -> Record {
209        crate::util::param_stats(&self.varmap)
210    }
211}
212
213impl<P> Clone for BcModel<P>
214where
215    P: SubModel1<Output = Tensor>,
216    P::Config: DeserializeOwned + Serialize + OutDim + Clone,
217{
218    fn clone(&self) -> Self {
219        let device = self.device.clone();
220        let out_dim = self.out_dim;
221        let opt_config = self.opt_config.clone();
222        let policy_model_config = self.policy_model_config.clone();
223        let varmap = VarMap::new();
224        let policy_model = {
225            let vb = VarBuilder::from_varmap(&varmap, DType::F32, &device);
226            P::build(vb, self.policy_model_config.clone())
227        };
228
229        Self::_build(
230            device,
231            out_dim,
232            opt_config,
233            policy_model_config,
234            policy_model,
235            varmap,
236            Some(&self.varmap),
237        )
238    }
239}