border_candle_agent/util/
critic.rs1use crate::{
3 model::SubModel2,
4 opt::{Optimizer, OptimizerConfig},
5 util::track_with_replace_substring,
6};
7use anyhow::{Context, Result};
8use candle_core::{DType::F32, Device, Tensor, D};
9use candle_nn::{VarBuilder, VarMap};
10use log::info;
11use serde::{de::DeserializeOwned, Deserialize, Serialize};
12use std::{
13 fs::File,
14 io::{BufReader, Write},
15 path::{Path, PathBuf},
16};
17
18#[derive(Debug, Deserialize, Serialize, PartialEq, Clone)]
19pub struct MultiCriticConfig<Q> {
21 pub n_nets: usize,
23
24 pub q_config: Option<Q>,
26
27 pub opt_config: OptimizerConfig,
29
30 pub tau: f64,
32}
33
34impl<Q> Default for MultiCriticConfig<Q> {
35 fn default() -> Self {
36 Self {
37 n_nets: 2,
38 q_config: None,
39 opt_config: OptimizerConfig::Adam { lr: 0.0003 },
40 tau: 0.005,
41 }
42 }
43}
44
45impl<Q> MultiCriticConfig<Q>
46where
47 Q: DeserializeOwned + Serialize,
48{
49 pub fn n_nets(mut self, v: usize) -> Self {
51 self.n_nets = v;
52 self
53 }
54
55 pub fn q_config(mut self, v: Q) -> Self {
57 self.q_config = Some(v);
58 self
59 }
60
61 pub fn opt_config(mut self, v: OptimizerConfig) -> Self {
63 self.opt_config = v;
64 self
65 }
66
67 pub fn tau(mut self, v: f64) -> Self {
69 self.tau = v;
70 self
71 }
72
73 pub fn load(path: impl AsRef<Path>) -> Result<Self> {
75 let file = File::open(path)?;
76 let rdr = BufReader::new(file);
77 let b = serde_yaml::from_reader(rdr)?;
78 Ok(b)
79 }
80
81 pub fn save(&self, path: impl AsRef<Path>) -> Result<()> {
83 let mut file = File::create(path)?;
84 file.write_all(serde_yaml::to_string(&self)?.as_bytes())?;
85 Ok(())
86 }
87}
88
89pub struct MultiCritic<Q>
95where
96 Q: SubModel2<Output = Tensor>,
97 Q::Config: DeserializeOwned + Serialize,
98{
99 n_nets: usize,
100 tau: f64,
101 device: Device,
102 varmap: VarMap,
103 varmap_tgt: VarMap, q_config: Q::Config,
107 qs: Vec<Q>,
108 qs_tgt: Vec<Q>, opt_config: OptimizerConfig,
111 opt: Optimizer, }
113
114impl<Q> MultiCritic<Q>
115where
116 Q: SubModel2<Output = Tensor>,
117 Q::Config: DeserializeOwned + Serialize + Clone,
118{
119 pub fn build(config: MultiCriticConfig<Q::Config>, device: Device) -> Result<MultiCritic<Q>> {
121 let tau = config.tau;
122 let n_nets = config.n_nets;
123 let q_config = config.q_config.context("q_config is not set.")?;
124 let opt_config = config.opt_config;
125
126 let (varmap, qs) = Self::build_critic_networks(&q_config, &device, n_nets, "critic");
128
129 let (varmap_tgt, qs_tgt) =
131 Self::build_critic_networks(&q_config, &device, n_nets, "critic_tgt");
132
133 let opt = opt_config.build(varmap.all_vars())?;
135
136 track_with_replace_substring(&varmap_tgt, &varmap, 1.0, ("critic", "critic_tgt"))?;
138
139 Ok(Self {
140 tau,
141 n_nets,
142 device,
143 varmap,
144 varmap_tgt,
145 q_config,
146 qs,
147 qs_tgt,
148 opt_config,
149 opt,
150 })
151 }
152
153 fn build_critic_networks(
154 q_config: &Q::Config,
155 device: &Device,
156 n_nets: usize,
157 prefix: &str,
158 ) -> (VarMap, Vec<Q>) {
159 let varmap = VarMap::new();
160 let qs = (0..n_nets)
161 .map(|ix| {
162 if device.is_cuda() {
163 device.set_seed((ix + 10) as _).unwrap();
164 }
165 let vb = VarBuilder::from_varmap(&varmap, F32, &device)
166 .set_prefix(format!("{}{}", prefix, ix));
167 Q::build(vb, q_config.clone())
168 })
169 .collect();
170
171 (varmap, qs)
172 }
173
174 pub fn soft_update(&mut self) -> Result<()> {
175 track_with_replace_substring(
176 &self.varmap_tgt,
177 &self.varmap,
178 self.tau,
179 ("critic", "critic_tgt"),
180 )?;
181 Ok(())
182 }
183
184 pub fn qvals(&self, obs: &Q::Input1, act: &Q::Input2) -> Vec<Tensor> {
186 self.qs
187 .iter()
188 .map(|critic| {
189 let q = critic.forward(obs, act).squeeze(D::Minus1).unwrap();
190 q
192 })
193 .collect()
194 }
195
196 pub fn qvals_min(&self, obs: &Q::Input1, act: &Q::Input2) -> Result<Tensor> {
198 let qvals = self.qvals(obs, act);
199 let qvals = Tensor::stack(&qvals, 0)?; let qvals_min = qvals.min(0)?.squeeze(D::Minus1)?; Ok(qvals_min)
202 }
203
204 pub fn qvals_min_tgt(&self, obs: &Q::Input1, act: &Q::Input2) -> Result<Tensor> {
206 let qvals: Vec<Tensor> = self
207 .qs_tgt
208 .iter()
209 .map(|critic| {
210 let q = critic.forward(obs, act).squeeze(D::Minus1).unwrap();
211 q
213 })
214 .collect();
215 let qvals = Tensor::stack(&qvals, 0)?; let qvals_min = qvals.min(0)?.squeeze(D::Minus1)?; Ok(qvals_min)
218 }
219}
220
221impl<Q> Clone for MultiCritic<Q>
222where
223 Q: SubModel2<Output = Tensor>,
224 Q::Config: DeserializeOwned + Serialize + Clone,
225{
226 fn clone(&self) -> Self {
227 let tau = self.tau;
228 let n_nets = self.n_nets;
229 let device = self.device.clone();
230 let q_config = self.q_config.clone();
231 let opt_config = self.opt_config.clone();
232
233 let (mut varmap, qs) = Self::build_critic_networks(&q_config, &device, n_nets, "critic");
235
236 let (mut varmap_tgt, qs_tgt) =
238 Self::build_critic_networks(&q_config, &device, n_nets, "critic_tgt");
239
240 let opt = opt_config.build(varmap.all_vars()).unwrap();
242
243 varmap.clone_from(&self.varmap);
245 varmap_tgt.clone_from(&self.varmap_tgt);
246
247 Self {
248 tau,
249 n_nets,
250 device,
251 varmap,
252 varmap_tgt,
253 q_config,
254 qs,
255 qs_tgt,
256 opt_config,
257 opt,
258 }
259 }
260}
261
262impl<Q> MultiCritic<Q>
263where
264 Q: SubModel2<Output = Tensor>,
265 Q::Config: DeserializeOwned + Serialize,
266{
267 pub fn backward_step(&mut self, loss: &Tensor) -> Result<()> {
269 self.opt.backward_step(loss)
270 }
271
272 pub fn save<T: AsRef<Path>>(&self, prefix: T) -> Result<(PathBuf, PathBuf)> {
274 let mut path = PathBuf::from(prefix.as_ref());
275 path.set_extension("pt");
276 self.varmap.save(&path.as_path())?;
277 info!("Save critics to {:?}", path);
278
279 let mut path_tgt = PathBuf::from(prefix.as_ref());
280 path_tgt.set_extension("tgt.pt");
281 self.varmap.save(&path_tgt.as_path())?;
282 info!("Save target critics to {:?}", path_tgt);
283
284 Ok((path, path_tgt))
285 }
286
287 pub fn load<T: AsRef<Path>>(&mut self, prefix: T) -> Result<()> {
289 let mut path = PathBuf::from(prefix.as_ref());
290 path.set_extension("pt");
291 self.varmap.load(&path.as_path())?;
292 info!("Load critics from {:?}", path);
293
294 let mut path = PathBuf::from(prefix.as_ref());
295 path.set_extension("tgt.pt");
296 self.varmap.load(&path.as_path())?;
297 info!("Load target critics from {:?}", path);
298
299 Ok(())
300 }
301}
302
#[cfg(test)]
mod test {
    use super::*;

    /// Variables created through two prefixed `VarBuilder`s share one `VarMap` and are keyed
    /// by their full dotted path.
    #[test]
    fn test_varmap() -> Result<()> {
        let varmap = VarMap::new();

        let vb1 = VarBuilder::from_varmap(&varmap, F32, &Device::Cpu).set_prefix("critic1");
        candle_nn::linear(4, 4, vb1.pp("layer1"))?;
        candle_nn::linear(4, 4, vb1.pp("layer2"))?;

        let vb2 = VarBuilder::from_varmap(&varmap, F32, &Device::Cpu).set_prefix("critic2");
        candle_nn::linear(4, 4, vb2.pp("layer1"))?;
        candle_nn::linear(4, 4, vb2.pp("layer2"))?;

        // HashMap iteration order is unspecified, so compare sorted keys.
        let mut keys: Vec<String> = varmap.data().lock().unwrap().keys().cloned().collect();
        keys.sort();

        let expected = [
            "critic1.layer1.bias",
            "critic1.layer1.weight",
            "critic1.layer2.bias",
            "critic1.layer2.weight",
            "critic2.layer1.bias",
            "critic2.layer1.weight",
            "critic2.layer2.bias",
            "critic2.layer2.weight",
        ];
        assert_eq!(keys, expected);

        Ok(())
    }

    /// `|tau - 1[u < 0]|` for a scalar `tau` broadcast against a vector `u`.
    #[test]
    fn test_lt() -> Result<()> {
        use std::convert::TryFrom;

        let tau = Tensor::try_from(0.7f32)?;
        let u = Tensor::from_slice(&[0.2f32, -0.1, 0.0, 0.02, -10.0], &[5], &Device::Cpu)?;

        let w = tau
            .broadcast_sub(&u.lt(0f32)?.to_dtype(F32)?)?
            .abs()?;

        // A scalar tensor has no dimensions.
        assert!(tau.dims().is_empty());

        let got = w.to_vec1::<f32>()?;
        let expected = [0.7f32, 0.3, 0.7, 0.7, 0.3];
        assert_eq!(got.len(), expected.len());
        for (g, e) in got.iter().zip(expected.iter()) {
            assert!((g - e).abs() < 1e-6, "{} != {}", g, e);
        }

        Ok(())
    }
}