// border_candle_agent/bc/config.rs
use super::BcModelConfig;
use crate::{model::SubModel1, opt::OptimizerConfig, util::OutDim, Device};
use anyhow::{Context, Result};
use candle_core::Tensor;
use log::info;
use serde::{de::DeserializeOwned, Deserialize, Serialize};
use std::{
    default::Default,
    fs::File,
    io::{BufReader, Write},
    marker::PhantomData,
    path::Path,
};

16#[derive(Clone, Debug, Deserialize, Serialize, PartialEq)]
18pub enum BcActionType {
19 Discrete,
21
22 Continuous,
24}
25
26#[derive(Debug, Deserialize, Serialize, PartialEq)]
30pub struct BcConfig<P>
31where
32 P: SubModel1<Output = Tensor>,
33 P::Config: DeserializeOwned + Serialize + OutDim + std::fmt::Debug + PartialEq + Clone,
34{
35 pub policy_model_config: BcModelConfig<P::Config>,
36 pub batch_size: usize,
37 pub action_type: BcActionType,
38 pub device: Option<Device>,
39 pub record_verbose_level: usize,
40 pub phantom: PhantomData<P>,
41}
42
43impl<P> Clone for BcConfig<P>
44where
45 P: SubModel1<Output = Tensor>,
46 P::Config: DeserializeOwned + Serialize + OutDim + std::fmt::Debug + PartialEq + Clone,
47{
48 fn clone(&self) -> Self {
49 Self {
50 policy_model_config: self.policy_model_config.clone(),
51 batch_size: self.batch_size,
52 action_type: self.action_type.clone(),
53 device: self.device.clone(),
54 record_verbose_level: self.record_verbose_level,
55 phantom: PhantomData,
56 }
57 }
58}
59
60impl<P> Default for BcConfig<P>
61where
62 P: SubModel1<Output = Tensor>,
63 P::Config: DeserializeOwned + Serialize + OutDim + std::fmt::Debug + PartialEq + Clone,
64{
65 fn default() -> Self {
67 Self {
68 policy_model_config: Default::default(),
69 batch_size: 1,
70 action_type: BcActionType::Discrete,
71 device: None,
72 record_verbose_level: 0,
73 phantom: PhantomData,
74 }
75 }
76}
77
78impl<P> BcConfig<P>
79where
80 P: SubModel1<Output = Tensor>,
81 P::Config: DeserializeOwned + Serialize + OutDim + std::fmt::Debug + PartialEq + Clone,
82{
83 pub fn batch_size(mut self, v: usize) -> Self {
85 self.batch_size = v;
86 self
87 }
88
89 pub fn policy_model_config(mut self, policy_model_config: BcModelConfig<P::Config>) -> Self {
91 self.policy_model_config = policy_model_config;
92 self
93 }
94
95 pub fn out_dim(mut self, out_dim: i64) -> Self {
97 let policy_model_config = self.policy_model_config.clone();
98 self.policy_model_config = policy_model_config.out_dim(out_dim as _);
99 self
100 }
101
102 pub fn device(mut self, device: candle_core::Device) -> Self {
104 self.device = Some(device.into());
105 self
106 }
107
108 pub fn action_type(mut self, action_type: BcActionType) -> Self {
110 self.action_type = action_type;
111 self
112 }
113
114 pub fn optimizer(mut self, opt_config: OptimizerConfig) -> Self {
116 self.policy_model_config.opt_config = opt_config;
117 self
118 }
119
120 pub fn load(path: impl AsRef<Path>) -> Result<Self> {
122 let path_ = path.as_ref().to_owned();
123 let file = File::open(path)?;
124 let rdr = BufReader::new(file);
125 let b = serde_yaml::from_reader(rdr)?;
126 info!("Load config of BC agent from {}", path_.to_str().unwrap());
127 Ok(b)
128 }
129
130 pub fn save(&self, path: impl AsRef<Path>) -> Result<()> {
132 let path_ = path.as_ref().to_owned();
133 let mut file = File::create(path)?;
134 file.write_all(serde_yaml::to_string(&self)?.as_bytes())?;
135 info!("Save config of BC agent into {}", path_.to_str().unwrap());
136 Ok(())
137 }
138}