border_candle_agent/iql/
value.rs

1//! State value function.
2use crate::{
3    model::SubModel1,
4    opt::{Optimizer, OptimizerConfig},
5};
6use anyhow::{Context, Result};
7use candle_core::{DType, Device, Tensor};
8use candle_nn::{VarBuilder, VarMap};
9use log::info;
10use serde::{de::DeserializeOwned, Deserialize, Serialize};
11use std::{
12    fs::File,
13    io::{BufReader, Write},
14    path::{Path, PathBuf},
15};
16
17/// Configuration of [`Value`].
18#[derive(Debug, Deserialize, Serialize, PartialEq, Clone)]
19pub struct ValueConfig<P> {
20    /// Configuration of value function network.
21    pub value_config: Option<P>,
22
23    /// Configuration of optimizer.
24    pub opt_config: OptimizerConfig,
25}
26
27impl<Q> Default for ValueConfig<Q> {
28    fn default() -> Self {
29        Self {
30            value_config: None,
31            opt_config: OptimizerConfig::Adam { lr: 0.0003 },
32        }
33    }
34}
35
36impl<P> ValueConfig<P>
37where
38    P: DeserializeOwned + Serialize,
39{
40    /// Sets configurations for value function network.
41    pub fn value_config(mut self, v: P) -> Self {
42        self.value_config = Some(v);
43        self
44    }
45
46    /// Sets optimizer configuration.
47    pub fn opt_config(mut self, v: OptimizerConfig) -> Self {
48        self.opt_config = v;
49        self
50    }
51
52    /// Loads [`ValueConfig`] from YAML file.
53    pub fn load(path: impl AsRef<Path>) -> Result<Self> {
54        let file = File::open(path)?;
55        let rdr = BufReader::new(file);
56        let b = serde_yaml::from_reader(rdr)?;
57        Ok(b)
58    }
59
60    /// Saves [`ValueConfig`] as YAML file.
61    pub fn save(&self, path: impl AsRef<Path>) -> Result<()> {
62        let mut file = File::create(path)?;
63        file.write_all(serde_yaml::to_string(&self)?.as_bytes())?;
64        Ok(())
65    }
66}
67
68/// State value function.
69pub struct Value<P>
70where
71    P: SubModel1<Output = Tensor>,
72    P::Config: DeserializeOwned + Serialize + Clone,
73{
74    #[allow(dead_code)]
75    device: Device, // required when implementing Clone trait
76    varmap: VarMap,
77
78    // State-value function
79    #[allow(dead_code)]
80    value_config: P::Config, // required when implementing Clone trait
81    value: P,
82
83    // Optimizer
84    #[allow(dead_code)]
85    opt_config: OptimizerConfig, // required when implementing Clone trait
86    opt: Optimizer,
87}
88
89impl<P> Value<P>
90where
91    P: SubModel1<Output = Tensor>,
92    P::Config: DeserializeOwned + Serialize + Clone,
93{
94    /// Constructs [`Value`].
95    pub fn build(config: ValueConfig<P::Config>, device: Device) -> Result<Value<P>> {
96        let value_config = config.value_config.context("value_config is not set.")?;
97        let varmap = VarMap::new();
98        let value = {
99            let vb = VarBuilder::from_varmap(&varmap, DType::F32, &device).set_prefix("value");
100            P::build(vb, value_config.clone())
101        };
102        let opt_config = config.opt_config;
103        let opt = opt_config.build(varmap.all_vars()).unwrap();
104
105        Ok(Self {
106            device,
107            opt_config,
108            varmap,
109            opt,
110            value,
111            value_config,
112        })
113    }
114
115    /// Returns the state-value for given state (observation).
116    pub fn forward(&self, x: &P::Input) -> Tensor {
117        self.value.forward(&x)
118    }
119
120    /// Backward step for all variables in the value network.
121    pub fn backward_step(&mut self, loss: &Tensor) -> Result<()> {
122        self.opt.backward_step(loss)
123    }
124
125    /// Save variables to prefix + ".pt".
126    pub fn save(&self, prefix: impl AsRef<Path>) -> Result<PathBuf> {
127        let mut path = PathBuf::from(prefix.as_ref());
128        path.set_extension("pt");
129        self.varmap.save(&path.as_path())?;
130        info!("Save value network parameters to {:?}", path);
131
132        Ok(path)
133    }
134
135    /// Load variables from prefix + ".pt".
136    pub fn load(&mut self, prefix: impl AsRef<Path>) -> Result<()> {
137        let mut path = PathBuf::from(prefix.as_ref());
138        path.set_extension("pt");
139        self.varmap.load(&path.as_path())?;
140        info!("Load value network parameters from {:?}", path);
141
142        Ok(())
143    }
144}
145
146impl<P> Clone for Value<P>
147where
148    P: SubModel1<Output = Tensor>,
149    P::Config: DeserializeOwned + Serialize + Clone,
150{
151    fn clone(&self) -> Self {
152        unimplemented!();
153    }
154}