border_candle_agent/iql/config.rs

1//! Configuration of IQL agent.
2use super::ValueConfig;
3use crate::{
4    model::{SubModel1, SubModel2},
5    util::{actor::GaussianActorConfig, critic::MultiCriticConfig, CriticLoss, OutDim},
6    Device,
7};
8use anyhow::Result;
9use candle_core::Tensor;
10use log::info;
11use serde::{de::DeserializeOwned, Deserialize, Serialize};
12use std::{
13    fmt::Debug,
14    fs::File,
15    io::{BufReader, Write},
16    path::Path,
17};
18
/// Configuration of [`Iql`](super::Iql).
#[allow(clippy::upper_case_acronyms)]
#[derive(Debug, Deserialize, Serialize, PartialEq)]
pub struct IqlConfig<Q, P, V>
where
    Q: SubModel2<Output = Tensor>,
    P: SubModel1<Output = (Tensor, Tensor)>,
    V: SubModel1<Output = Tensor>,
    Q::Config: DeserializeOwned + Serialize + Debug + PartialEq + Clone,
    P::Config: DeserializeOwned + Serialize + OutDim + Debug + PartialEq + Clone,
    V::Config: DeserializeOwned + Serialize + Debug + PartialEq + Clone,
{
    /// Configuration of the value model.
    pub value_config: ValueConfig<V::Config>,

    /// Configuration of the critic model.
    pub critic_config: MultiCriticConfig<Q::Config>,

    /// Configuration of the actor model.
    pub actor_config: GaussianActorConfig<P::Config>,

    /// Discount factor.
    pub gamma: f32,

    /// Expectile value.
    pub tau_iql: f64,

    /// The inverse of lambda in the paper.
    ///
    /// Exposed through the [`Self::lambda`] builder method, which stores `1.0 / lambda`.
    pub inv_lambda: f64,

    /// Number of parameter updates per optimization step.
    pub n_updates_per_opt: usize,

    /// Batch size for training.
    pub batch_size: usize,

    // /// Scaling factor for rewards.
    // pub reward_scale: f32,
    /// If true, advantage weights are calculated with softmax within each mini-batch.
    pub adv_softmax: bool,

    // /// If `true`, the agent is
    // pub train: bool,
    /// Type of critic loss function.
    pub critic_loss: CriticLoss,

    /// Device used for the actor and critic models (e.g., CPU or GPU).
    ///
    /// `None` leaves device selection to the consumer of this config.
    pub device: Option<Device>,

    /// Maximum of exponent of advantage.
    pub exp_adv_max: f64,
}
71
72impl<Q, P, V> Clone for IqlConfig<Q, P, V>
73where
74    Q: SubModel2<Output = Tensor>,
75    P: SubModel1<Output = (Tensor, Tensor)>,
76    V: SubModel1<Output = Tensor>,
77    Q::Config: DeserializeOwned + Serialize + Debug + PartialEq + Clone,
78    P::Config: DeserializeOwned + Serialize + OutDim + Debug + PartialEq + Clone,
79    V::Config: DeserializeOwned + Serialize + Debug + PartialEq + Clone,
80{
81    fn clone(&self) -> Self {
82        Self {
83            value_config: self.value_config.clone(),
84            critic_config: self.critic_config.clone(),
85            actor_config: self.actor_config.clone(),
86            gamma: self.gamma,
87            tau_iql: self.tau_iql,
88            inv_lambda: self.inv_lambda,
89            n_updates_per_opt: self.n_updates_per_opt,
90            batch_size: self.batch_size,
91            // reward_scale: self.reward_scale,
92            adv_softmax: self.adv_softmax,
93            critic_loss: self.critic_loss.clone(),
94            device: self.device.clone(),
95            exp_adv_max: self.exp_adv_max,
96        }
97    }
98}
99
100impl<Q, P, V> Default for IqlConfig<Q, P, V>
101where
102    Q: SubModel2<Output = Tensor>,
103    P: SubModel1<Output = (Tensor, Tensor)>,
104    V: SubModel1<Output = Tensor>,
105    Q::Config: DeserializeOwned + Serialize + Debug + PartialEq + Clone,
106    P::Config: DeserializeOwned + Serialize + OutDim + Debug + PartialEq + Clone,
107    V::Config: DeserializeOwned + Serialize + Debug + PartialEq + Clone,
108{
109    fn default() -> Self {
110        Self {
111            value_config: Default::default(),
112            critic_config: Default::default(),
113            actor_config: Default::default(),
114            gamma: 0.99,
115            tau_iql: 0.7,
116            inv_lambda: 10.0,
117            n_updates_per_opt: 1,
118            batch_size: 1,
119            // reward_scale: 1.0,
120            adv_softmax: false,
121            critic_loss: CriticLoss::Mse,
122            device: None,
123            exp_adv_max: 100.0,
124        }
125    }
126}
127
128impl<Q, P, V> IqlConfig<Q, P, V>
129where
130    Q: SubModel2<Output = Tensor>,
131    P: SubModel1<Output = (Tensor, Tensor)>,
132    V: SubModel1<Output = Tensor>,
133    Q::Config: DeserializeOwned + Serialize + Debug + PartialEq + Clone,
134    P::Config: DeserializeOwned + Serialize + OutDim + Debug + PartialEq + Clone,
135    V::Config: DeserializeOwned + Serialize + Debug + PartialEq + Clone,
136{
137    /// Sets lambda.
138    pub fn lambda(mut self, v: f64) -> Self {
139        self.inv_lambda = 1.0 / v;
140        self
141    }
142
143    /// Sets the numper of parameter update steps per optimization step.
144    pub fn n_updates_per_opt(mut self, v: usize) -> Self {
145        self.n_updates_per_opt = v;
146        self
147    }
148
149    /// Batch size.
150    pub fn batch_size(mut self, v: usize) -> Self {
151        self.batch_size = v;
152        self
153    }
154
155    /// Discount factor.
156    pub fn discount_factor(mut self, v: f32) -> Self {
157        self.gamma = v;
158        self
159    }
160
161    // /// Reward scale.
162    // ///
163    // /// It works for obtaining target values, not the values in logs.
164    // pub fn reward_scale(mut self, v: f32) -> Self {
165    //     self.reward_scale = v;
166    //     self
167    // }
168
169    /// Critic loss.
170    pub fn critic_loss(mut self, v: CriticLoss) -> Self {
171        self.critic_loss = v;
172        self
173    }
174
175    /// Configuration of value function.
176    pub fn value_config(mut self, value_config: ValueConfig<V::Config>) -> Self {
177        self.value_config = value_config;
178        self
179    }
180
181    /// Configuration of actor.
182    pub fn actor_config(mut self, actor_config: GaussianActorConfig<P::Config>) -> Self {
183        self.actor_config = actor_config;
184        self
185    }
186
187    /// Configuration of critic.
188    pub fn critic_config(mut self, critic_config: MultiCriticConfig<Q::Config>) -> Self {
189        self.critic_config = critic_config;
190        self
191    }
192
193    /// Device.
194    pub fn device(mut self, device: candle_core::Device) -> Self {
195        self.device = Some(device.into());
196        self
197    }
198
199    /// If true, advantage weights are calculated with softmax within each mini-batch.
200    pub fn adv_softmax(mut self, b: bool) -> Self {
201        self.adv_softmax = b;
202        self
203    }
204
205    /// Saves [`IqlConfig`] to YAML file.
206    pub fn save(&self, path: impl AsRef<Path>) -> Result<()> {
207        let path_ = path.as_ref().to_owned();
208        let mut file = File::create(path)?;
209        file.write_all(serde_yaml::to_string(&self)?.as_bytes())?;
210        info!("Save config of IQL agent into {}", path_.to_str().unwrap());
211        Ok(())
212    }
213
214    /// Constructs [`IqlConfig`] from YAML file.
215    pub fn load(path: impl AsRef<Path>) -> Result<Self> {
216        let path_ = path.as_ref().to_owned();
217        let file = File::open(path)?;
218        let rdr = BufReader::new(file);
219        let b = serde_yaml::from_reader(rdr)?;
220        info!("Load config of IQL agent from {}", path_.to_str().unwrap());
221        Ok(b)
222    }
223}