border_candle_agent/awac/
base.rs

1use super::AwacConfig;
2use crate::{
3    model::{SubModel1, SubModel2},
4    util::{
5        actor::GaussianActor, critic::MultiCritic, gamma_not_done, smooth_l1_loss, CriticLoss,
6        OutDim,
7    },
8};
9use anyhow::Result;
10use border_core::{
11    record::{Record, RecordValue},
12    Agent, Configurable, Env, Policy, ReplayBufferBase, TransitionBatch,
13};
14use candle_core::{Device, Tensor, D};
15use candle_nn::{loss::mse, ops::softmax};
16use serde::{de::DeserializeOwned, Serialize};
17use std::{
18    fs,
19    marker::PhantomData,
20    path::{Path, PathBuf},
21};
22
/// Action value (Q) emitted by a critic sub-model.
type ActionValue = Tensor;
/// Mean of the Gaussian policy's action distribution.
type ActMean = Tensor;
/// Standard deviation of the Gaussian policy's action distribution.
type ActStd = Tensor;
26
/// Advantage weighted actor critic (AWAC) agent.
pub struct Awac<E, Q, P, R>
where
    Q: SubModel2<Output = ActionValue>,
    P: SubModel1<Output = (ActMean, ActStd)>,
    Q::Config: DeserializeOwned + Serialize + OutDim + std::fmt::Debug + PartialEq + Clone,
    P::Config: DeserializeOwned + Serialize + OutDim + std::fmt::Debug + PartialEq + Clone,
{
    /// Ensemble of Q-networks together with their soft-updated targets.
    critic: MultiCritic<Q>,
    /// Gaussian policy network.
    actor: GaussianActor<P>,
    /// Discount factor applied to the bootstrapped next-state value.
    gamma: f64,
    /// Inverse temperature (1/lambda) multiplying the advantage before
    /// exponentiation or softmax in the actor weight.
    inv_lambda: f64,
    /// Number of actor/critic gradient steps per optimization round.
    n_updates_per_opt: usize,
    /// Mini-batch size sampled from the replay buffer.
    batch_size: usize,
    /// Training-mode flag; forwarded to `actor.sample` and toggled via
    /// `Agent::train` / `Agent::eval`.
    train: bool,
    // reward_scale: f32,
    /// Total number of gradient steps performed so far.
    n_opts: usize,
    /// Upper clamp for the exponential advantage weight exp(adv / lambda).
    exp_adv_max: f64,
    /// Regression loss used for the critic (MSE or smooth L1).
    critic_loss: CriticLoss,
    /// Ties the unused `E` and `R` type parameters to this struct.
    phantom: PhantomData<(E, R)>,
    /// Device on which tensors are created (e.g. CPU or CUDA).
    device: Device,
    /// If true, normalize advantage weights with a softmax over the batch
    /// instead of exponentiation + clamping.
    adv_softmax: bool,
}
50
51impl<E, Q, P, R> Awac<E, Q, P, R>
52where
53    E: Env,
54    Q: SubModel2<Output = ActionValue>,
55    P: SubModel1<Output = (ActMean, ActStd)>,
56    R: ReplayBufferBase,
57    E::Obs: Into<Q::Input1> + Into<P::Input>,
58    E::Act: Into<Q::Input2> + Into<Tensor>,
59    Q::Input2: From<ActMean> + Into<Tensor>,
60    Q::Config: DeserializeOwned + Serialize + OutDim + std::fmt::Debug + PartialEq + Clone,
61    P::Config: DeserializeOwned + Serialize + OutDim + std::fmt::Debug + PartialEq + Clone,
62    R::Batch: TransitionBatch,
63    <R::Batch as TransitionBatch>::ObsBatch: Into<Q::Input1> + Into<P::Input> + Clone,
64    <R::Batch as TransitionBatch>::ActBatch: Into<Q::Input2> + Into<Tensor> + Clone,
65{
66    fn update_critic(&mut self, batch: R::Batch) -> Result<(f32, f32, f32, f32)> {
67        let (loss, q_tgt_abs_mean, reward_mean, next_q_mean) = {
68            // Extract items in the batch
69            let (obs, act, next_obs, reward, is_terminated, is_truncated, _, _) = batch.unpack();
70            let batch_size = reward.len();
71            let reward = Tensor::from_slice(&reward[..], (batch_size,), &self.device)?;
72
73            // Prediction
74            let qs = self.critic.qvals(&obs.into(), &act.into());
75
76            // Target
77            let (tgt, reward, next_q) = {
78                let gamma_not_done = gamma_not_done(
79                    self.gamma as f32,
80                    is_terminated,
81                    Some(is_truncated),
82                    &self.device,
83                )?;
84                let next_act = self.actor.sample(&next_obs.clone().into(), self.train)?;
85                let next_q = self
86                    .critic
87                    .qvals_min_tgt(&next_obs.into(), &next_act.into())?;
88                let tgt = (&reward + (&gamma_not_done * &next_q)?)?.squeeze(D::Minus1)?;
89
90                (tgt.detach(), reward, next_q)
91            };
92            debug_assert_eq!(tgt.dims(), [self.batch_size]);
93
94            // Loss
95            let losses: Vec<_> = match self.critic_loss {
96                CriticLoss::Mse => qs.iter().map(|pred| mse(&pred, &tgt).unwrap()).collect(),
97                CriticLoss::SmoothL1 => qs
98                    .iter()
99                    .map(|pred| smooth_l1_loss(&pred, &tgt).unwrap())
100                    .collect(),
101            };
102
103            // for debug
104            let q_tgt_abs_mean = tgt.abs()?.mean_all()?.to_scalar::<f32>()?;
105            let reward_mean = reward.mean_all()?.to_scalar::<f32>()?;
106            let next_q_mean = next_q.mean_all()?.to_scalar::<f32>()?;
107
108            (
109                Tensor::stack(&losses, 0)?.sum_all()?,
110                q_tgt_abs_mean,
111                reward_mean,
112                next_q_mean,
113            )
114        };
115
116        self.critic.backward_step(&loss)?;
117        self.critic.soft_update()?;
118
119        Ok((
120            loss.to_scalar::<f32>()?,
121            q_tgt_abs_mean,
122            reward_mean,
123            next_q_mean,
124        ))
125    }
126
127    fn update_actor(&mut self, batch: &R::Batch) -> Result<(f32, f32, f32, f32)> {
128        // Extract items in the batch
129        log::trace!("Extract items in the batch");
130        let obs = batch.obs().clone();
131        let act = batch.act().clone();
132
133        let (w, adv) = {
134            let act_ = self.actor.sample(&obs.clone().into(), self.train)?;
135            let q = self
136                .critic
137                .qvals_min(&obs.clone().into(), &act.clone().into())?;
138            let v = self.critic.qvals_min(&obs.clone().into(), &act_.into())?;
139            let adv = (&q - &v)?;
140            debug_assert_eq!(adv.dims(), &[self.batch_size]);
141
142            let w = match self.adv_softmax {
143                false => (&adv * self.inv_lambda)?
144                    .exp()?
145                    .clamp(0f64, self.exp_adv_max)?,
146                true => softmax(&(&adv * self.inv_lambda)?, 0)?,
147            }
148            .detach();
149            (w, adv)
150        };
151        debug_assert_eq!(w.dims(), &[self.batch_size]);
152
153        let (loss, logp) = {
154            let logp = self.actor.logp(&obs.into(), &act.into())?;
155            debug_assert_eq!(logp.dims(), &[self.batch_size]);
156
157            ((-1f64 * &logp * w)?.mean_all()?, logp)
158        };
159
160        self.actor.backward_step(&loss)?;
161
162        let loss = loss.to_scalar::<f32>()?;
163        let adv_mean = adv.mean_all()?.to_scalar::<f32>()?;
164        let adv_abs_mean = adv.abs()?.mean_all()?.to_scalar::<f32>()?;
165        let logp_mean = logp.mean_all()?.to_scalar::<f32>()?;
166
167        Ok((loss, adv_mean, adv_abs_mean, logp_mean))
168    }
169
170    fn opt_(&mut self, buffer: &mut R) -> Result<Record> {
171        let mut loss_critic = 0f32;
172        let mut loss_actor = 0f32;
173        let mut q_tgt_abs_mean = 0f32;
174        let mut adv_mean = 0f32;
175        let mut adv_abs_mean = 0f32;
176        let mut logp_mean = 0f32;
177        let mut reward_mean = 0f32;
178        let mut next_q_mean = 0f32;
179
180        for _ in 0..self.n_updates_per_opt {
181            let batch = buffer.batch(self.batch_size).unwrap();
182            let (loss_actor_, adv_mean_, adv_abs_mean_, logp_mean_) = self.update_actor(&batch)?;
183            loss_actor += loss_actor_;
184            adv_mean += adv_mean_;
185            adv_abs_mean += adv_abs_mean_;
186            logp_mean += logp_mean_;
187
188            let (loss_critic_, q_tgt_abs_mean_, reward_mean_, next_q_mean_) =
189                self.update_critic(batch)?;
190            loss_critic += loss_critic_;
191            q_tgt_abs_mean += q_tgt_abs_mean_;
192            reward_mean += reward_mean_;
193            next_q_mean += next_q_mean_;
194            self.n_opts += 1;
195        }
196
197        loss_critic /= self.n_updates_per_opt as f32;
198        loss_actor /= self.n_updates_per_opt as f32;
199        q_tgt_abs_mean /= self.n_updates_per_opt as f32;
200        adv_mean /= self.n_updates_per_opt as f32;
201        adv_abs_mean /= self.n_updates_per_opt as f32;
202
203        let record = Record::from_slice(&[
204            ("loss_critic", RecordValue::Scalar(loss_critic)),
205            ("loss_actor", RecordValue::Scalar(loss_actor)),
206            ("q_tgt_abs_mean", RecordValue::Scalar(q_tgt_abs_mean)),
207            ("adv_mean", RecordValue::Scalar(adv_mean)),
208            ("adv_abs_mean", RecordValue::Scalar(adv_abs_mean)),
209            ("logp_mean", RecordValue::Scalar(logp_mean)),
210            ("reward_mean", RecordValue::Scalar(reward_mean)),
211            ("next_q_mean", RecordValue::Scalar(next_q_mean)),
212        ]);
213
214        Ok(record)
215    }
216}
217
218impl<E, Q, P, R> Policy<E> for Awac<E, Q, P, R>
219where
220    E: Env,
221    Q: SubModel2<Output = ActionValue>,
222    P: SubModel1<Output = (ActMean, ActStd)>,
223    E::Obs: Into<Q::Input1> + Into<P::Input>,
224    E::Act: Into<Q::Input2> + From<Tensor>,
225    Q::Config: DeserializeOwned + Serialize + OutDim + std::fmt::Debug + PartialEq + Clone,
226    P::Config: DeserializeOwned + Serialize + OutDim + std::fmt::Debug + PartialEq + Clone,
227{
228    fn sample(&mut self, obs: &E::Obs) -> E::Act {
229        self.actor
230            .sample(&obs.clone().into(), self.train)
231            .unwrap()
232            .into()
233    }
234}
235
236impl<E, Q, P, R> Configurable for Awac<E, Q, P, R>
237where
238    E: Env,
239    Q: SubModel2<Output = ActionValue>,
240    P: SubModel1<Output = (ActMean, ActStd)>,
241    E::Obs: Into<Q::Input1> + Into<P::Input>,
242    E::Act: Into<Q::Input2> + From<Tensor>,
243    Q::Config: DeserializeOwned + Serialize + OutDim + std::fmt::Debug + PartialEq + Clone,
244    P::Config: DeserializeOwned + Serialize + OutDim + std::fmt::Debug + PartialEq + Clone,
245{
246    type Config = AwacConfig<Q, P>;
247
248    /// Constructs [`Awac`] agent.
249    fn build(config: Self::Config) -> Self {
250        let device: Device = config
251            .device
252            .expect("No device is given for AWAC agent")
253            .into();
254        let actor = GaussianActor::build(config.actor_config, device.clone().into()).unwrap();
255        let critics = MultiCritic::build(config.critic_config, device.clone().into()).unwrap();
256
257        Awac {
258            critic: critics,
259            actor,
260            gamma: config.gamma,
261            // action_min: config.action_min,
262            // action_max: config.action_max,
263            // min_lstd: config.min_lstd,
264            // max_lstd: config.max_lstd,
265            n_updates_per_opt: config.n_updates_per_opt,
266            batch_size: config.batch_size,
267            // reward_scale: config.reward_scale,
268            critic_loss: config.critic_loss,
269            inv_lambda: config.inv_lambda,
270            exp_adv_max: config.exp_adv_max,
271            n_opts: 0,
272            train: false,
273            device: device.into(),
274            adv_softmax: config.adv_softmax,
275            phantom: PhantomData,
276        }
277    }
278}
279
280impl<E, Q, P, R> Agent<E, R> for Awac<E, Q, P, R>
281where
282    E: Env + 'static,
283    Q: SubModel2<Output = ActionValue> + 'static,
284    P: SubModel1<Output = (ActMean, ActStd)> + 'static,
285    R: ReplayBufferBase + 'static,
286    E::Obs: Into<Q::Input1> + Into<P::Input>,
287    E::Act: Into<Q::Input2> + Into<Tensor> + From<Tensor>,
288    Q::Input2: From<ActMean> + Into<Tensor>,
289    Q::Config: DeserializeOwned + Serialize + OutDim + std::fmt::Debug + PartialEq + Clone,
290    P::Config: DeserializeOwned + Serialize + OutDim + std::fmt::Debug + PartialEq + Clone,
291    R::Batch: TransitionBatch,
292    <R::Batch as TransitionBatch>::ObsBatch: Into<Q::Input1> + Into<P::Input> + Clone,
293    <R::Batch as TransitionBatch>::ActBatch: Into<Q::Input2> + Into<Tensor> + Clone,
294{
295    fn train(&mut self) {
296        self.train = true;
297    }
298
299    fn eval(&mut self) {
300        self.train = false;
301    }
302
303    fn is_train(&self) -> bool {
304        self.train
305    }
306
307    fn opt_with_record(&mut self, buffer: &mut R) -> Record {
308        self.opt_(buffer).expect("Failed in Awac::opt_()")
309    }
310
311    fn save_params(&self, path: &Path) -> Result<Vec<PathBuf>> {
312        // TODO: consider to rename the path if it already exists
313        fs::create_dir_all(&path)?;
314
315        let actor_path = self.actor.save(path.join("actor"))?;
316        let (critic_path, critic_tgt_path) = self.critic.save(path.join("critic"))?;
317
318        Ok(vec![actor_path, critic_path, critic_tgt_path])
319    }
320
321    fn load_params(&mut self, path: &Path) -> Result<()> {
322        self.actor.load(path.join("actor").as_path())?;
323        self.critic.load(path.join("critic").as_path())?;
324
325        Ok(())
326    }
327
328    fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
329        self
330    }
331
332    fn as_any_ref(&self) -> &dyn std::any::Any {
333        self
334    }
335}