border_candle_agent/iql/
value.rs1use crate::{
3 model::SubModel1,
4 opt::{Optimizer, OptimizerConfig},
5};
6use anyhow::{Context, Result};
7use candle_core::{DType, Device, Tensor};
8use candle_nn::{VarBuilder, VarMap};
9use log::info;
10use serde::{de::DeserializeOwned, Deserialize, Serialize};
11use std::{
12 fs::File,
13 io::{BufReader, Write},
14 path::{Path, PathBuf},
15};
16
17#[derive(Debug, Deserialize, Serialize, PartialEq, Clone)]
19pub struct ValueConfig<P> {
20 pub value_config: Option<P>,
22
23 pub opt_config: OptimizerConfig,
25}
26
27impl<Q> Default for ValueConfig<Q> {
28 fn default() -> Self {
29 Self {
30 value_config: None,
31 opt_config: OptimizerConfig::Adam { lr: 0.0003 },
32 }
33 }
34}
35
36impl<P> ValueConfig<P>
37where
38 P: DeserializeOwned + Serialize,
39{
40 pub fn value_config(mut self, v: P) -> Self {
42 self.value_config = Some(v);
43 self
44 }
45
46 pub fn opt_config(mut self, v: OptimizerConfig) -> Self {
48 self.opt_config = v;
49 self
50 }
51
52 pub fn load(path: impl AsRef<Path>) -> Result<Self> {
54 let file = File::open(path)?;
55 let rdr = BufReader::new(file);
56 let b = serde_yaml::from_reader(rdr)?;
57 Ok(b)
58 }
59
60 pub fn save(&self, path: impl AsRef<Path>) -> Result<()> {
62 let mut file = File::create(path)?;
63 file.write_all(serde_yaml::to_string(&self)?.as_bytes())?;
64 Ok(())
65 }
66}
67
68pub struct Value<P>
70where
71 P: SubModel1<Output = Tensor>,
72 P::Config: DeserializeOwned + Serialize + Clone,
73{
74 #[allow(dead_code)]
75 device: Device, varmap: VarMap,
77
78 #[allow(dead_code)]
80 value_config: P::Config, value: P,
82
83 #[allow(dead_code)]
85 opt_config: OptimizerConfig, opt: Optimizer,
87}
88
89impl<P> Value<P>
90where
91 P: SubModel1<Output = Tensor>,
92 P::Config: DeserializeOwned + Serialize + Clone,
93{
94 pub fn build(config: ValueConfig<P::Config>, device: Device) -> Result<Value<P>> {
96 let value_config = config.value_config.context("value_config is not set.")?;
97 let varmap = VarMap::new();
98 let value = {
99 let vb = VarBuilder::from_varmap(&varmap, DType::F32, &device).set_prefix("value");
100 P::build(vb, value_config.clone())
101 };
102 let opt_config = config.opt_config;
103 let opt = opt_config.build(varmap.all_vars()).unwrap();
104
105 Ok(Self {
106 device,
107 opt_config,
108 varmap,
109 opt,
110 value,
111 value_config,
112 })
113 }
114
115 pub fn forward(&self, x: &P::Input) -> Tensor {
117 self.value.forward(&x)
118 }
119
120 pub fn backward_step(&mut self, loss: &Tensor) -> Result<()> {
122 self.opt.backward_step(loss)
123 }
124
125 pub fn save(&self, prefix: impl AsRef<Path>) -> Result<PathBuf> {
127 let mut path = PathBuf::from(prefix.as_ref());
128 path.set_extension("pt");
129 self.varmap.save(&path.as_path())?;
130 info!("Save value network parameters to {:?}", path);
131
132 Ok(path)
133 }
134
135 pub fn load(&mut self, prefix: impl AsRef<Path>) -> Result<()> {
137 let mut path = PathBuf::from(prefix.as_ref());
138 path.set_extension("pt");
139 self.varmap.load(&path.as_path())?;
140 info!("Load value network parameters from {:?}", path);
141
142 Ok(())
143 }
144}
145
146impl<P> Clone for Value<P>
147where
148 P: SubModel1<Output = Tensor>,
149 P::Config: DeserializeOwned + Serialize + Clone,
150{
151 fn clone(&self) -> Self {
152 unimplemented!();
153 }
154}