// border_candle_agent/bc/base.rs

1//! Behavior cloning (BC) agent implemented with candle.
2use super::{BcActionType, BcConfig, BcModel};
3use crate::{model::SubModel1, util::OutDim};
4use anyhow::Result;
5use border_core::{
6    record::{Record, RecordValue},
7    Agent, Configurable, Env, Policy, ReplayBufferBase, TransitionBatch,
8};
9use candle_core::{shape::D, DType, Device, Tensor};
10use candle_nn::loss::mse;
11use serde::{de::DeserializeOwned, Serialize};
12use std::{
13    fs,
14    marker::PhantomData,
15    path::{Path, PathBuf},
16};
17
18#[allow(dead_code)]
19/// Behavior cloning (BC) agent implemented with candle.
20///
21/// `P` is the type parameter of the policy model.
22pub struct Bc<E, P, R>
23where
24    P: SubModel1<Output = Tensor>,
25    P::Config: DeserializeOwned + Serialize + OutDim + std::fmt::Debug + PartialEq + Clone,
26{
27    policy_model: BcModel<P>,
28    batch_size: usize,
29    action_type: BcActionType,
30    device: Device,
31    record_verbose_level: usize,
32    phantom: PhantomData<(E, R)>,
33}
34
35impl<E, P, R> Policy<E> for Bc<E, P, R>
36where
37    E: Env,
38    P: SubModel1<Output = Tensor>,
39    E::Obs: Into<P::Input>,
40    E::Act: From<P::Output>,
41    P::Config: DeserializeOwned + Serialize + OutDim + std::fmt::Debug + PartialEq + Clone,
42{
43    /// Sample an action.
44    ///
45    /// When `action_type` is set to [`BcActionType::Discrete`], this method returns the action
46    /// corresponding to the argmax of the policy model's output tensor.
47    /// On the other hand, when `action_type` is set to [`BcActionType::Continuous`], this method
48    /// returns the output tensor as is.
49    fn sample(&mut self, obs: &E::Obs) -> E::Act {
50        let a = self.policy_model.forward(&obs.clone().into()).detach();
51        match self.action_type {
52            BcActionType::Discrete => {
53                let a = a.argmax(D::Minus1).unwrap().to_dtype(DType::I64).unwrap();
54                a.into()
55            }
56            BcActionType::Continuous => a.into(),
57        }
58    }
59}
60
61impl<E, P, R> Configurable for Bc<E, P, R>
62where
63    E: Env,
64    P: SubModel1<Output = Tensor>,
65    E::Obs: Into<P::Input>,
66    E::Act: From<P::Output>,
67    P::Config: DeserializeOwned + Serialize + OutDim + std::fmt::Debug + PartialEq + Clone,
68{
69    type Config = BcConfig<P>;
70
71    /// Constructs DQN agent.
72    fn build(config: Self::Config) -> Self {
73        let device: Device = config
74            .device
75            .expect("No device is given for DQN agent")
76            .into();
77        let policy_model =
78            BcModel::build(config.policy_model_config.clone(), device.clone()).unwrap();
79
80        Self {
81            policy_model,
82            batch_size: config.batch_size,
83            action_type: config.action_type,
84            device,
85            record_verbose_level: config.record_verbose_level,
86            phantom: PhantomData,
87        }
88    }
89}
90
91impl<E, P, R> Agent<E, R> for Bc<E, P, R>
92where
93    E: Env,
94    P: SubModel1<Output = Tensor>,
95    R: ReplayBufferBase,
96    E::Obs: Into<P::Input>,
97    E::Act: From<P::Output>,
98    P::Config: DeserializeOwned + Serialize + OutDim + std::fmt::Debug + PartialEq + Clone,
99    R::Batch: TransitionBatch,
100    <R::Batch as TransitionBatch>::ObsBatch: Into<P::Input>,
101    <R::Batch as TransitionBatch>::ActBatch: Into<Tensor>,
102{
103    /// For BC agent, this method does nothing.
104    fn train(&mut self) {}
105
106    /// For BC agent, this method does nothing.
107    fn eval(&mut self) {}
108
109    /// For BC agent, this method always returns `false`.
110    fn is_train(&self) -> bool {
111        false
112    }
113
114    fn opt(&mut self, buffer: &mut R) {
115        self.opt_(buffer);
116    }
117
118    fn opt_with_record(&mut self, buffer: &mut R) -> Record {
119        let record = {
120            let record = self.opt_(buffer);
121
122            match self.record_verbose_level >= 2 {
123                true => {
124                    let record_weights = self.policy_model.param_stats();
125                    let record = record.merge(record_weights);
126                    record
127                }
128                false => record,
129            }
130        };
131
132        record
133    }
134
135    /// Save model parameters in the given directory.
136    ///
137    /// The parameters of the policy_model are saved as `policy_model.pt`.
138    fn save_params(&self, path: &Path) -> Result<Vec<PathBuf>> {
139        // TODO: consider to rename the path if it already exists
140        fs::create_dir_all(&path)?;
141        let path = path.join("policy_model.pt").to_path_buf();
142        self.policy_model.save(&path)?;
143        Ok(vec![path])
144    }
145
146    /// Load model parameters in the given directory.
147    ///
148    /// The parameters of the policy_model are loaded from `policy_model.pt`.
149    fn load_params(&mut self, path: &Path) -> Result<()> {
150        self.policy_model
151            .load(&path.join("policy_model.pt").as_path())?;
152        Ok(())
153    }
154}
155
156impl<E, P, R> Bc<E, P, R>
157where
158    E: Env,
159    P: SubModel1<Output = Tensor>,
160    R: ReplayBufferBase,
161    P::Config: DeserializeOwned + Serialize + OutDim + std::fmt::Debug + PartialEq + Clone,
162    R::Batch: TransitionBatch,
163    <R::Batch as TransitionBatch>::ObsBatch: Into<P::Input>,
164    <R::Batch as TransitionBatch>::ActBatch: Into<Tensor>,
165{
166    // Currently, this method supports only continuous action.
167    fn opt_(&mut self, buffer: &mut R) -> Record {
168        let batch = buffer.batch(self.batch_size).unwrap();
169        let (obs, act, _, _, _, _, _, _) = batch.unpack();
170        let obs = obs.into();
171        let act = act.into().to_device(&self.device).unwrap();
172        let loss = match self.action_type {
173            BcActionType::Discrete => {
174                panic!();
175            }
176            BcActionType::Continuous => {
177                let act_ = self.policy_model.forward(&obs);
178                mse(&act_, &act)
179            }
180        }
181        .unwrap();
182        self.policy_model.backward_step(&loss).unwrap();
183
184        let mut record = Record::empty();
185        record.insert(
186            "loss",
187            RecordValue::Scalar(
188                loss.to_device(&Device::Cpu)
189                    .expect("Error when moving loss to CPU")
190                    .mean_all()
191                    .unwrap()
192                    .to_scalar()
193                    .unwrap(),
194            ),
195        );
196        record
197    }
198}