border_candle_agent/util/
critic.rs

1//! Critic for agents with continuous action.
2use crate::{
3    model::SubModel2,
4    opt::{Optimizer, OptimizerConfig},
5    util::track_with_replace_substring,
6};
7use anyhow::{Context, Result};
8use candle_core::{DType::F32, Device, Tensor, D};
9use candle_nn::{VarBuilder, VarMap};
10use log::info;
11use serde::{de::DeserializeOwned, Deserialize, Serialize};
12use std::{
13    fs::File,
14    io::{BufReader, Write},
15    path::{Path, PathBuf},
16};
17
18#[derive(Debug, Deserialize, Serialize, PartialEq, Clone)]
19/// Configuration of [`MultiCritic`].
20pub struct MultiCriticConfig<Q> {
21    /// The number of critic networks.
22    pub n_nets: usize,
23
24    /// Configuration of critic networks.
25    pub q_config: Option<Q>,
26
27    /// Configuration of the optimizer.
28    pub opt_config: OptimizerConfig,
29
30    /// Soft update coefficient.
31    pub tau: f64,
32}
33
34impl<Q> Default for MultiCriticConfig<Q> {
35    fn default() -> Self {
36        Self {
37            n_nets: 2,
38            q_config: None,
39            opt_config: OptimizerConfig::Adam { lr: 0.0003 },
40            tau: 0.005,
41        }
42    }
43}
44
45impl<Q> MultiCriticConfig<Q>
46where
47    Q: DeserializeOwned + Serialize,
48{
49    /// Sets the numver of critic networks.
50    pub fn n_nets(mut self, v: usize) -> Self {
51        self.n_nets = v;
52        self
53    }
54
55    /// Sets configurations for action-value function.
56    pub fn q_config(mut self, v: Q) -> Self {
57        self.q_config = Some(v);
58        self
59    }
60
61    /// Sets optimizer configuration.
62    pub fn opt_config(mut self, v: OptimizerConfig) -> Self {
63        self.opt_config = v;
64        self
65    }
66
67    /// Sets soft update parameter tau.
68    pub fn tau(mut self, v: f64) -> Self {
69        self.tau = v;
70        self
71    }
72
73    /// Constructs [`MultiCriticConfig`] from YAML file.
74    pub fn load(path: impl AsRef<Path>) -> Result<Self> {
75        let file = File::open(path)?;
76        let rdr = BufReader::new(file);
77        let b = serde_yaml::from_reader(rdr)?;
78        Ok(b)
79    }
80
81    /// Saves [`MultiCriticConfig`] as YAML file.
82    pub fn save(&self, path: impl AsRef<Path>) -> Result<()> {
83        let mut file = File::create(path)?;
84        file.write_all(serde_yaml::to_string(&self)?.as_bytes())?;
85        Ok(())
86    }
87}
88
89/// Critic for agents with continuous action.
90///
91/// It takes observations and actions as inputs and outputs action values.
92///
93/// This struct has multiple q functions and corresponding target networks.
94pub struct MultiCritic<Q>
95where
96    Q: SubModel2<Output = Tensor>,
97    Q::Config: DeserializeOwned + Serialize,
98{
99    n_nets: usize,
100    tau: f64,
101    device: Device,
102    varmap: VarMap,
103    varmap_tgt: VarMap, // for target network
104
105    /// Action-value function
106    q_config: Q::Config,
107    qs: Vec<Q>,
108    qs_tgt: Vec<Q>, // for target network
109
110    opt_config: OptimizerConfig,
111    opt: Optimizer, // no optimizer required for tatget networks
112}
113
114impl<Q> MultiCritic<Q>
115where
116    Q: SubModel2<Output = Tensor>,
117    Q::Config: DeserializeOwned + Serialize + Clone,
118{
119    /// Constructs [`MultiCritic`].
120    pub fn build(config: MultiCriticConfig<Q::Config>, device: Device) -> Result<MultiCritic<Q>> {
121        let tau = config.tau;
122        let n_nets = config.n_nets;
123        let q_config = config.q_config.context("q_config is not set.")?;
124        let opt_config = config.opt_config;
125
126        // Critic networks
127        let (varmap, qs) = Self::build_critic_networks(&q_config, &device, n_nets, "critic");
128
129        // Target networks
130        let (varmap_tgt, qs_tgt) =
131            Self::build_critic_networks(&q_config, &device, n_nets, "critic_tgt");
132
133        // Optimizer, shared with critic networks
134        let opt = opt_config.build(varmap.all_vars())?;
135
136        // Copy parameters
137        track_with_replace_substring(&varmap_tgt, &varmap, 1.0, ("critic", "critic_tgt"))?;
138
139        Ok(Self {
140            tau,
141            n_nets,
142            device,
143            varmap,
144            varmap_tgt,
145            q_config,
146            qs,
147            qs_tgt,
148            opt_config,
149            opt,
150        })
151    }
152
153    fn build_critic_networks(
154        q_config: &Q::Config,
155        device: &Device,
156        n_nets: usize,
157        prefix: &str,
158    ) -> (VarMap, Vec<Q>) {
159        let varmap = VarMap::new();
160        let qs = (0..n_nets)
161            .map(|ix| {
162                if device.is_cuda() {
163                    device.set_seed((ix + 10) as _).unwrap();
164                }
165                let vb = VarBuilder::from_varmap(&varmap, F32, &device)
166                    .set_prefix(format!("{}{}", prefix, ix));
167                Q::build(vb, q_config.clone())
168            })
169            .collect();
170
171        (varmap, qs)
172    }
173
174    pub fn soft_update(&mut self) -> Result<()> {
175        track_with_replace_substring(
176            &self.varmap_tgt,
177            &self.varmap,
178            self.tau,
179            ("critic", "critic_tgt"),
180        )?;
181        Ok(())
182    }
183
184    /// Returns action values of all critics.
185    pub fn qvals(&self, obs: &Q::Input1, act: &Q::Input2) -> Vec<Tensor> {
186        self.qs
187            .iter()
188            .map(|critic| {
189                let q = critic.forward(obs, act).squeeze(D::Minus1).unwrap();
190                // debug_assert_eq!(q.dims(), &[self.batch_size]);
191                q
192            })
193            .collect()
194    }
195
196    /// Returns minimum action values of all target critics.
197    pub fn qvals_min(&self, obs: &Q::Input1, act: &Q::Input2) -> Result<Tensor> {
198        let qvals = self.qvals(obs, act);
199        let qvals = Tensor::stack(&qvals, 0)?; // [batch_size, self.n_nets]
200        let qvals_min = qvals.min(0)?.squeeze(D::Minus1)?; // [batch_size]
201        Ok(qvals_min)
202    }
203
204    /// Returns minimum action values of all target critics.
205    pub fn qvals_min_tgt(&self, obs: &Q::Input1, act: &Q::Input2) -> Result<Tensor> {
206        let qvals: Vec<Tensor> = self
207            .qs_tgt
208            .iter()
209            .map(|critic| {
210                let q = critic.forward(obs, act).squeeze(D::Minus1).unwrap();
211                // debug_assert_eq!(q.dims(), &[self.batch_size]);
212                q
213            })
214            .collect();
215        let qvals = Tensor::stack(&qvals, 0)?; // [self.n_nets, batch_size]
216        let qvals_min = qvals.min(0)?.squeeze(D::Minus1)?; // [batch_size]
217        Ok(qvals_min)
218    }
219}
220
221impl<Q> Clone for MultiCritic<Q>
222where
223    Q: SubModel2<Output = Tensor>,
224    Q::Config: DeserializeOwned + Serialize + Clone,
225{
226    fn clone(&self) -> Self {
227        let tau = self.tau;
228        let n_nets = self.n_nets;
229        let device = self.device.clone();
230        let q_config = self.q_config.clone();
231        let opt_config = self.opt_config.clone();
232
233        // Critic networks
234        let (mut varmap, qs) = Self::build_critic_networks(&q_config, &device, n_nets, "critic");
235
236        // Target networks
237        let (mut varmap_tgt, qs_tgt) =
238            Self::build_critic_networks(&q_config, &device, n_nets, "critic_tgt");
239
240        // Optimizer, shared with critic networks
241        let opt = opt_config.build(varmap.all_vars()).unwrap();
242
243        // Copy variables
244        varmap.clone_from(&self.varmap);
245        varmap_tgt.clone_from(&self.varmap_tgt);
246
247        Self {
248            tau,
249            n_nets,
250            device,
251            varmap,
252            varmap_tgt,
253            q_config,
254            qs,
255            qs_tgt,
256            opt_config,
257            opt,
258        }
259    }
260}
261
262impl<Q> MultiCritic<Q>
263where
264    Q: SubModel2<Output = Tensor>,
265    Q::Config: DeserializeOwned + Serialize,
266{
267    /// Backward step for all variables in critic networks.
268    pub fn backward_step(&mut self, loss: &Tensor) -> Result<()> {
269        self.opt.backward_step(loss)
270    }
271
272    /// Save variables to prefix + ".pt" and + "_tgt.pt".
273    pub fn save<T: AsRef<Path>>(&self, prefix: T) -> Result<(PathBuf, PathBuf)> {
274        let mut path = PathBuf::from(prefix.as_ref());
275        path.set_extension("pt");
276        self.varmap.save(&path.as_path())?;
277        info!("Save critics to {:?}", path);
278
279        let mut path_tgt = PathBuf::from(prefix.as_ref());
280        path_tgt.set_extension("tgt.pt");
281        self.varmap.save(&path_tgt.as_path())?;
282        info!("Save target critics to {:?}", path_tgt);
283
284        Ok((path, path_tgt))
285    }
286
287    /// Load variables from prefix + ".pt" and + "_tgt.pt".
288    pub fn load<T: AsRef<Path>>(&mut self, prefix: T) -> Result<()> {
289        let mut path = PathBuf::from(prefix.as_ref());
290        path.set_extension("pt");
291        self.varmap.load(&path.as_path())?;
292        info!("Load critics from {:?}", path);
293
294        let mut path = PathBuf::from(prefix.as_ref());
295        path.set_extension("tgt.pt");
296        self.varmap.load(&path.as_path())?;
297        info!("Load target critics from {:?}", path);
298
299        Ok(())
300    }
301}
302
#[cfg(test)]
mod test {
    use super::*;

    #[test]
    /// Check variable names in a VarMap.
    ///
    /// Two networks with different prefixes share one `VarMap`; every variable
    /// must be registered as `<prefix>.<layer>.<weight|bias>`.
    fn test_varmap() -> Result<()> {
        let varmap = VarMap::new();

        // network 1
        let vb1 = VarBuilder::from_varmap(&varmap, F32, &Device::Cpu).set_prefix("critic1");
        candle_nn::linear(4, 4, vb1.pp("layer1"))?;
        candle_nn::linear(4, 4, vb1.pp("layer2"))?;

        // network 2
        let vb2 = VarBuilder::from_varmap(&varmap, F32, &Device::Cpu).set_prefix("critic2");
        candle_nn::linear(4, 4, vb2.pp("layer1"))?;
        candle_nn::linear(4, 4, vb2.pp("layer2"))?;

        // The map has no stable iteration order, so the names are sorted first.
        let mut names: Vec<String> = varmap.data().lock().unwrap().keys().cloned().collect();
        names.sort();

        // show variables in VarMap
        names.iter().for_each(|key| println!("{:?}", key));

        assert_eq!(
            names,
            vec![
                "critic1.layer1.bias",
                "critic1.layer1.weight",
                "critic1.layer2.bias",
                "critic1.layer2.weight",
                "critic2.layer1.bias",
                "critic2.layer1.weight",
                "critic2.layer2.bias",
                "critic2.layer2.weight",
            ]
        );

        Ok(())
    }

    #[test]
    /// Check broadcast on Tensor::lt().
    ///
    /// Computes the expectile loss weight `|tau - 1(u < 0)|` for a scalar
    /// `tau` and a vector `u`.
    fn test_lt() -> Result<()> {
        use std::convert::TryFrom;

        // A scalar
        let tau = Tensor::try_from(0.7f32)?;

        // A vector
        let u = Tensor::from_slice(&[0.2f32, -0.1, 0.0, 0.02, -10.0], &[5], &Device::Cpu)?;

        // Expectile loss weight
        let w = tau
            .broadcast_sub(&u.lt(0f32)?.to_dtype(candle_core::DType::F32)?)?
            .abs()?;

        println!("{:?}", tau.dims());
        println!("{:?}", w);

        // `tau` is a scalar, the weight has one entry per element of `u`.
        assert!(tau.dims().is_empty());
        let w = w.to_vec1::<f32>()?;
        let expected = [0.7f32, 0.3, 0.7, 0.7, 0.3];
        assert_eq!(w.len(), expected.len());
        for (got, want) in w.iter().zip(expected.iter()) {
            assert!((got - want).abs() < 1e-6, "got {}, want {}", got, want);
        }

        Ok(())
    }
}