1use super::{BcActionType, BcConfig, BcModel};
3use crate::{model::SubModel1, util::OutDim};
4use anyhow::Result;
5use border_core::{
6 record::{Record, RecordValue},
7 Agent, Configurable, Env, Policy, ReplayBufferBase, TransitionBatch,
8};
9use candle_core::{shape::D, DType, Device, Tensor};
10use candle_nn::loss::mse;
11use serde::{de::DeserializeOwned, Serialize};
12use std::{
13 fs,
14 marker::PhantomData,
15 path::{Path, PathBuf},
16};
17
/// Behavior cloning (BC) agent.
///
/// Holds a policy model that is trained by supervised regression against
/// expert transitions sampled from a replay buffer (see the inherent
/// `opt_` method), and used greedily at inference time.
#[allow(dead_code)]
pub struct Bc<E, P, R>
where
    P: SubModel1<Output = Tensor>,
    P::Config: DeserializeOwned + Serialize + OutDim + std::fmt::Debug + PartialEq + Clone,
{
    // Neural network mapping observations to action values / logits.
    policy_model: BcModel<P>,
    // Number of transitions sampled from the buffer per optimization step.
    batch_size: usize,
    // Discrete (argmax over output) vs. continuous (raw output) actions.
    action_type: BcActionType,
    // Device (CPU/GPU) tensors are moved to before the loss is computed.
    device: Device,
    // When >= 2, `opt_with_record` also logs model parameter statistics.
    record_verbose_level: usize,
    // Ties the otherwise-unused type parameters `E` and `R` to the struct.
    phantom: PhantomData<(E, R)>,
}
34
35impl<E, P, R> Policy<E> for Bc<E, P, R>
36where
37 E: Env,
38 P: SubModel1<Output = Tensor>,
39 E::Obs: Into<P::Input>,
40 E::Act: From<P::Output>,
41 P::Config: DeserializeOwned + Serialize + OutDim + std::fmt::Debug + PartialEq + Clone,
42{
43 fn sample(&mut self, obs: &E::Obs) -> E::Act {
50 let a = self.policy_model.forward(&obs.clone().into()).detach();
51 match self.action_type {
52 BcActionType::Discrete => {
53 let a = a.argmax(D::Minus1).unwrap().to_dtype(DType::I64).unwrap();
54 a.into()
55 }
56 BcActionType::Continuous => a.into(),
57 }
58 }
59}
60
61impl<E, P, R> Configurable for Bc<E, P, R>
62where
63 E: Env,
64 P: SubModel1<Output = Tensor>,
65 E::Obs: Into<P::Input>,
66 E::Act: From<P::Output>,
67 P::Config: DeserializeOwned + Serialize + OutDim + std::fmt::Debug + PartialEq + Clone,
68{
69 type Config = BcConfig<P>;
70
71 fn build(config: Self::Config) -> Self {
73 let device: Device = config
74 .device
75 .expect("No device is given for DQN agent")
76 .into();
77 let policy_model =
78 BcModel::build(config.policy_model_config.clone(), device.clone()).unwrap();
79
80 Self {
81 policy_model,
82 batch_size: config.batch_size,
83 action_type: config.action_type,
84 device,
85 record_verbose_level: config.record_verbose_level,
86 phantom: PhantomData,
87 }
88 }
89}
90
91impl<E, P, R> Agent<E, R> for Bc<E, P, R>
92where
93 E: Env,
94 P: SubModel1<Output = Tensor>,
95 R: ReplayBufferBase,
96 E::Obs: Into<P::Input>,
97 E::Act: From<P::Output>,
98 P::Config: DeserializeOwned + Serialize + OutDim + std::fmt::Debug + PartialEq + Clone,
99 R::Batch: TransitionBatch,
100 <R::Batch as TransitionBatch>::ObsBatch: Into<P::Input>,
101 <R::Batch as TransitionBatch>::ActBatch: Into<Tensor>,
102{
103 fn train(&mut self) {}
105
106 fn eval(&mut self) {}
108
109 fn is_train(&self) -> bool {
111 false
112 }
113
114 fn opt(&mut self, buffer: &mut R) {
115 self.opt_(buffer);
116 }
117
118 fn opt_with_record(&mut self, buffer: &mut R) -> Record {
119 let record = {
120 let record = self.opt_(buffer);
121
122 match self.record_verbose_level >= 2 {
123 true => {
124 let record_weights = self.policy_model.param_stats();
125 let record = record.merge(record_weights);
126 record
127 }
128 false => record,
129 }
130 };
131
132 record
133 }
134
135 fn save_params(&self, path: &Path) -> Result<Vec<PathBuf>> {
139 fs::create_dir_all(&path)?;
141 let path = path.join("policy_model.pt").to_path_buf();
142 self.policy_model.save(&path)?;
143 Ok(vec![path])
144 }
145
146 fn load_params(&mut self, path: &Path) -> Result<()> {
150 self.policy_model
151 .load(&path.join("policy_model.pt").as_path())?;
152 Ok(())
153 }
154}
155
156impl<E, P, R> Bc<E, P, R>
157where
158 E: Env,
159 P: SubModel1<Output = Tensor>,
160 R: ReplayBufferBase,
161 P::Config: DeserializeOwned + Serialize + OutDim + std::fmt::Debug + PartialEq + Clone,
162 R::Batch: TransitionBatch,
163 <R::Batch as TransitionBatch>::ObsBatch: Into<P::Input>,
164 <R::Batch as TransitionBatch>::ActBatch: Into<Tensor>,
165{
166 fn opt_(&mut self, buffer: &mut R) -> Record {
168 let batch = buffer.batch(self.batch_size).unwrap();
169 let (obs, act, _, _, _, _, _, _) = batch.unpack();
170 let obs = obs.into();
171 let act = act.into().to_device(&self.device).unwrap();
172 let loss = match self.action_type {
173 BcActionType::Discrete => {
174 panic!();
175 }
176 BcActionType::Continuous => {
177 let act_ = self.policy_model.forward(&obs);
178 mse(&act_, &act)
179 }
180 }
181 .unwrap();
182 self.policy_model.backward_step(&loss).unwrap();
183
184 let mut record = Record::empty();
185 record.insert(
186 "loss",
187 RecordValue::Scalar(
188 loss.to_device(&Device::Cpu)
189 .expect("Error when moving loss to CPU")
190 .mean_all()
191 .unwrap()
192 .to_scalar()
193 .unwrap(),
194 ),
195 );
196 record
197 }
198}