border_candle_agent/bc/
model.rs1use crate::{
2 model::SubModel1,
3 opt::{Optimizer, OptimizerConfig},
4 util::OutDim,
5};
6use anyhow::{Context, Result};
7use border_core::record::Record;
8use candle_core::{DType, Device, Tensor};
9use candle_nn::{VarBuilder, VarMap};
10use log::info;
11use serde::{de::DeserializeOwned, Deserialize, Serialize};
12use std::{
13 fs::File,
14 io::{BufReader, Write},
15 path::Path,
16};
17
18#[derive(Debug, Deserialize, Serialize, PartialEq, Clone)]
19pub struct BcModelConfig<C>
25where
26 C: OutDim + Clone,
27{
28 pub policy_model_config: Option<C>,
29 #[serde(default)]
30 pub opt_config: OptimizerConfig,
31}
32
33impl<C> Default for BcModelConfig<C>
34where
35 C: DeserializeOwned + Serialize + OutDim + Clone,
36{
37 fn default() -> Self {
38 Self {
39 policy_model_config: None,
40 opt_config: OptimizerConfig::default(),
41 }
42 }
43}
44
45impl<C> BcModelConfig<C>
46where
47 C: DeserializeOwned + Serialize + OutDim + Clone,
48{
49 pub fn policy_model_config(mut self, v: C) -> Self {
51 self.policy_model_config = Some(v);
52 self
53 }
54
55 pub fn out_dim(mut self, v: i64) -> Self {
57 match &mut self.policy_model_config {
58 None => {}
59 Some(policy_model_config) => policy_model_config.set_out_dim(v),
60 };
61 self
62 }
63
64 pub fn opt_config(mut self, v: OptimizerConfig) -> Self {
66 self.opt_config = v;
67 self
68 }
69
70 pub fn load(path: impl AsRef<Path>) -> Result<Self> {
72 let file = File::open(path)?;
73 let rdr = BufReader::new(file);
74 let b = serde_yaml::from_reader(rdr)?;
75 Ok(b)
76 }
77
78 pub fn save(&self, path: impl AsRef<Path>) -> Result<()> {
80 let mut file = File::create(path)?;
81 file.write_all(serde_yaml::to_string(&self)?.as_bytes())?;
82 Ok(())
83 }
84}
85
86pub struct BcModel<P>
92where
93 P: SubModel1<Output = Tensor>,
94 P::Config: DeserializeOwned + Serialize + OutDim,
95{
96 device: Device,
97 varmap: VarMap,
98
99 out_dim: i64,
101
102 policy_model: P,
104
105 opt_config: OptimizerConfig,
107
108 opt: Optimizer,
110
111 policy_model_config: P::Config,
113}
114
115impl<P> BcModel<P>
116where
117 P: SubModel1<Output = Tensor>,
118 P::Config: DeserializeOwned + Serialize + OutDim + Clone,
119{
120 pub fn build(config: BcModelConfig<P::Config>, device: Device) -> Result<Self> {
122 let out_dim = config.policy_model_config.as_ref().unwrap().get_out_dim();
123 let policy_model_config = config
124 .policy_model_config
125 .context("policy_model_config is not set.")?;
126 let opt_config = config.opt_config;
127 let varmap = VarMap::new();
128
129 let policy_model = {
131 let vb = VarBuilder::from_varmap(&varmap, DType::F32, &device);
132 P::build(vb, policy_model_config.clone())
133 };
134
135 Ok(Self::_build(
136 device,
137 out_dim as _,
138 opt_config,
139 policy_model_config,
140 policy_model,
141 varmap,
142 None,
143 ))
144 }
145
146 fn _build(
147 device: Device,
148 out_dim: i64,
149 opt_config: OptimizerConfig,
150 policy_model_config: P::Config,
151 policy_model: P,
152 mut varmap: VarMap,
153 varmap_src: Option<&VarMap>,
154 ) -> Self {
155 let opt = opt_config.build(varmap.all_vars()).unwrap();
157
158 if let Some(varmap_src) = varmap_src {
160 varmap.clone_from(varmap_src);
161 }
162
163 Self {
164 device,
165 out_dim,
166 opt_config,
167 varmap,
168 opt,
169 policy_model,
170 policy_model_config,
171 }
172 }
173
174 pub fn forward(&self, obs: &P::Input) -> Tensor {
176 self.policy_model.forward(obs)
177 }
178
179 pub fn backward_step(&mut self, loss: &Tensor) -> Result<()> {
180 self.opt.backward_step(loss)
190 }
191
192 pub fn get_varmap(&self) -> &VarMap {
193 &self.varmap
194 }
195
196 pub fn save<T: AsRef<Path>>(&self, path: T) -> Result<()> {
197 self.varmap.save(&path)?;
198 info!("Save bc model to {:?}", path.as_ref());
199 Ok(())
200 }
201
202 pub fn load<T: AsRef<Path>>(&mut self, path: T) -> Result<()> {
203 self.varmap.load(&path)?;
204 info!("Load bc model from {:?}", path.as_ref());
205 Ok(())
206 }
207
208 pub fn param_stats(&self) -> Record {
209 crate::util::param_stats(&self.varmap)
210 }
211}
212
213impl<P> Clone for BcModel<P>
214where
215 P: SubModel1<Output = Tensor>,
216 P::Config: DeserializeOwned + Serialize + OutDim + Clone,
217{
218 fn clone(&self) -> Self {
219 let device = self.device.clone();
220 let out_dim = self.out_dim;
221 let opt_config = self.opt_config.clone();
222 let policy_model_config = self.policy_model_config.clone();
223 let varmap = VarMap::new();
224 let policy_model = {
225 let vb = VarBuilder::from_varmap(&varmap, DType::F32, &device);
226 P::build(vb, self.policy_model_config.clone())
227 };
228
229 Self::_build(
230 device,
231 out_dim,
232 opt_config,
233 policy_model_config,
234 policy_model,
235 varmap,
236 Some(&self.varmap),
237 )
238 }
239}