border_candle_agent/bc/
config.rs

1//! Configuration of behavior cloning (BC) agent.
2use super::BcModelConfig;
3use crate::{model::SubModel1, opt::OptimizerConfig, util::OutDim, Device};
4use anyhow::Result;
5use candle_core::Tensor;
6use log::info;
7use serde::{de::DeserializeOwned, Deserialize, Serialize};
8use std::{
9    default::Default,
10    fs::File,
11    io::{BufReader, Write},
12    marker::PhantomData,
13    path::Path,
14};
15
16/// Action type of behavior cloning (BC) agent.
17#[derive(Clone, Debug, Deserialize, Serialize, PartialEq)]
18pub enum BcActionType {
19    /// Discrete action.
20    Discrete,
21
22    /// Continuous action.
23    Continuous,
24}
25
/// Configuration of [`Bc`](super::Bc) agent.
///
/// `P` is the type parameter of the policy model.
///
/// The struct is (de)serialized with serde; [`BcConfig::load`] and
/// [`BcConfig::save`] read and write it as YAML, so field names and order define
/// the on-disk layout. `Clone` is implemented by hand rather than derived, so
/// that `P` itself does not have to be `Clone`.
#[derive(Debug, Deserialize, Serialize, PartialEq)]
pub struct BcConfig<P>
where
    P: SubModel1<Output = Tensor>,
    P::Config: DeserializeOwned + Serialize + OutDim + std::fmt::Debug + PartialEq + Clone,
{
    /// Configuration of the policy model, including its optimizer settings
    /// (the `opt_config` field written by [`BcConfig::optimizer`]).
    pub policy_model_config: BcModelConfig<P::Config>,

    /// Batch size. `1` by default.
    pub batch_size: usize,

    /// Whether the agent uses discrete or continuous actions.
    /// [`BcActionType::Discrete`] by default.
    pub action_type: BcActionType,

    /// Device for the agent; `None` by default.
    /// NOTE(review): the fallback used when this is `None` is decided outside
    /// this file — confirm.
    pub device: Option<Device>,

    /// Verbosity level for recording; `0` by default.
    /// NOTE(review): the meaning of each level is defined by the agent, not
    /// here — confirm.
    pub record_verbose_level: usize,

    /// Zero-sized marker tying this configuration to the policy model type `P`.
    pub phantom: PhantomData<P>,
}
42
43impl<P> Clone for BcConfig<P>
44where
45    P: SubModel1<Output = Tensor>,
46    P::Config: DeserializeOwned + Serialize + OutDim + std::fmt::Debug + PartialEq + Clone,
47{
48    fn clone(&self) -> Self {
49        Self {
50            policy_model_config: self.policy_model_config.clone(),
51            batch_size: self.batch_size,
52            action_type: self.action_type.clone(),
53            device: self.device.clone(),
54            record_verbose_level: self.record_verbose_level,
55            phantom: PhantomData,
56        }
57    }
58}
59
60impl<P> Default for BcConfig<P>
61where
62    P: SubModel1<Output = Tensor>,
63    P::Config: DeserializeOwned + Serialize + OutDim + std::fmt::Debug + PartialEq + Clone,
64{
65    /// Constructs DQN builder with default parameters.
66    fn default() -> Self {
67        Self {
68            policy_model_config: Default::default(),
69            batch_size: 1,
70            action_type: BcActionType::Discrete,
71            device: None,
72            record_verbose_level: 0,
73            phantom: PhantomData,
74        }
75    }
76}
77
78impl<P> BcConfig<P>
79where
80    P: SubModel1<Output = Tensor>,
81    P::Config: DeserializeOwned + Serialize + OutDim + std::fmt::Debug + PartialEq + Clone,
82{
83    /// Sets batch size.
84    pub fn batch_size(mut self, v: usize) -> Self {
85        self.batch_size = v;
86        self
87    }
88
89    /// Sets the configuration of the policy model.
90    pub fn policy_model_config(mut self, policy_model_config: BcModelConfig<P::Config>) -> Self {
91        self.policy_model_config = policy_model_config;
92        self
93    }
94
95    /// Sets the output dimention of the agent.
96    pub fn out_dim(mut self, out_dim: i64) -> Self {
97        let policy_model_config = self.policy_model_config.clone();
98        self.policy_model_config = policy_model_config.out_dim(out_dim as _);
99        self
100    }
101
102    /// Sets device.
103    pub fn device(mut self, device: candle_core::Device) -> Self {
104        self.device = Some(device.into());
105        self
106    }
107
108    // Sets action type.
109    pub fn action_type(mut self, action_type: BcActionType) -> Self {
110        self.action_type = action_type;
111        self
112    }
113
114    /// Sets optimizer.
115    pub fn optimizer(mut self, opt_config: OptimizerConfig) -> Self {
116        self.policy_model_config.opt_config = opt_config;
117        self
118    }
119
120    /// Loads [`BcConfig`] from YAML file.
121    pub fn load(path: impl AsRef<Path>) -> Result<Self> {
122        let path_ = path.as_ref().to_owned();
123        let file = File::open(path)?;
124        let rdr = BufReader::new(file);
125        let b = serde_yaml::from_reader(rdr)?;
126        info!("Load config of BC agent from {}", path_.to_str().unwrap());
127        Ok(b)
128    }
129
130    /// Saves [`BcConfig`] to YAML file.
131    pub fn save(&self, path: impl AsRef<Path>) -> Result<()> {
132        let path_ = path.as_ref().to_owned();
133        let mut file = File::create(path)?;
134        file.write_all(serde_yaml::to_string(&self)?.as_bytes())?;
135        info!("Save config of BC agent into {}", path_.to_str().unwrap());
136        Ok(())
137    }
138}