//! SAC agent implementation (`border_tch_agent/sac/base.rs`).

1use super::{Actor, Critic, EntCoef, SacConfig};
2use crate::{
3    model::{ModelBase, SubModel, SubModel2},
4    util::{track, CriticLoss, OutDim},
5};
6use anyhow::Result;
7use border_core::{
8    record::{Record, RecordValue},
9    Agent, Configurable, Env, Policy, ReplayBufferBase, TransitionBatch,
10};
11use serde::{de::DeserializeOwned, Serialize};
12// use log::info;
13use std::{
14    convert::TryFrom,
15    fs,
16    marker::PhantomData,
17    path::{Path, PathBuf},
18};
19use tch::{no_grad, Tensor};
20
/// Action-value (Q-value) tensor produced by a critic network.
type ActionValue = Tensor;
/// Mean of the Gaussian policy head output.
type ActMean = Tensor;
/// Standard deviation (here: log-std before clipping/exp) of the Gaussian policy head output.
type ActStd = Tensor;
24
25fn normal_logp(x: &Tensor) -> Tensor {
26    let tmp: Tensor = Tensor::from(-0.5 * (2.0 * std::f32::consts::PI).ln() as f32)
27        - 0.5 * x.pow_tensor_scalar(2);
28    tmp.sum_dim_intlist(Some([-1].as_slice()), false, tch::Kind::Float)
29}
30
/// Soft actor critic (SAC) agent.
pub struct Sac<E, Q, P, R>
where
    Q: SubModel2<Output = ActionValue>,
    P: SubModel<Output = (ActMean, ActStd)>,
    Q::Config: DeserializeOwned + Serialize + OutDim + std::fmt::Debug + PartialEq + Clone,
    P::Config: DeserializeOwned + Serialize + OutDim + std::fmt::Debug + PartialEq + Clone,
{
    // Ensemble of online critic (Q) networks.
    pub(super) qnets: Vec<Critic<Q>>,
    // Target critics, softly tracked toward `qnets` in `soft_update`.
    pub(super) qnets_tgt: Vec<Critic<Q>>,
    // Actor (policy) network producing (mean, log-std) of a Gaussian.
    pub(super) pi: Actor<P>,
    // Discount factor applied to the bootstrapped next-state value.
    pub(super) gamma: f64,
    // Polyak averaging coefficient for target-network updates.
    pub(super) tau: f64,
    // Entropy temperature (alpha); may be optimized, see `EntCoef`.
    pub(super) ent_coef: EntCoef,
    // Small constant added inside log(1 - tanh^2 + eps) for numerical stability.
    pub(super) epsilon: f64,
    // Lower clip bound for the policy's log-std output.
    pub(super) min_lstd: f64,
    // Upper clip bound for the policy's log-std output.
    pub(super) max_lstd: f64,
    // Number of gradient updates performed per call to `opt_`.
    pub(super) n_updates_per_opt: usize,
    // Minibatch size sampled from the replay buffer.
    pub(super) batch_size: usize,
    // Training-mode flag: stochastic actions when true, deterministic when false.
    pub(super) train: bool,
    // Multiplier applied to raw rewards when forming the critic target.
    pub(super) reward_scale: f32,
    // Running count of optimization steps (used e.g. by SyncModel).
    pub(super) n_opts: usize,
    // Loss function used for the critics (MSE or smooth L1).
    pub(super) critic_loss: CriticLoss,
    // Marks the unused type parameters E (environment) and R (replay buffer).
    pub(super) phantom: PhantomData<(E, R)>,
    // Device on which all tensors/networks live.
    pub(super) device: tch::Device,
}
57
impl<E, Q, P, R> Sac<E, Q, P, R>
where
    E: Env,
    Q: SubModel2<Output = ActionValue>,
    P: SubModel<Output = (ActMean, ActStd)>,
    R: ReplayBufferBase,
    E::Obs: Into<Q::Input1> + Into<P::Input>,
    E::Act: Into<Q::Input2>,
    Q::Input2: From<ActMean>,
    Q::Config: DeserializeOwned + Serialize + OutDim + std::fmt::Debug + PartialEq + Clone,
    P::Config: DeserializeOwned + Serialize + OutDim + std::fmt::Debug + PartialEq + Clone,
    R::Batch: TransitionBatch,
    <R::Batch as TransitionBatch>::ObsBatch: Into<Q::Input1> + Into<P::Input> + Clone,
    <R::Batch as TransitionBatch>::ActBatch: Into<Q::Input2> + Into<Tensor>,
{
    /// Samples a tanh-squashed action and its log-probability for observation `o`.
    ///
    /// Uses the reparameterization trick: `a = tanh(mean + std * z)`, `z ~ N(0, I)`,
    /// with the tanh change-of-variables correction
    /// `log p(a) = log N(z) - sum(log(1 - a^2 + epsilon))`.
    ///
    /// NOTE(review): the Gaussian term uses `normal_logp(&z)` (standard normal)
    /// and does not subtract `sum(log std)`, i.e. the scale's log-determinant
    /// is omitted from the density — confirm whether this is intentional, as it
    /// shifts the entropy term used in both critic targets and the actor loss.
    fn action_logp(&self, o: &P::Input) -> (Tensor, Tensor) {
        let (mean, lstd) = self.pi.forward(o);
        // Clip log-std to [min_lstd, max_lstd] before exponentiating.
        let std = lstd.clip(self.min_lstd, self.max_lstd).exp();
        let z = Tensor::randn(mean.size().as_slice(), tch::kind::FLOAT_CPU).to(self.device);
        let a = (&std * &z + &mean).tanh();
        // epsilon keeps log(1 - a^2) finite when |a| is close to 1.
        let log_p = normal_logp(&z)
            - (Tensor::from(1f32) - a.pow_tensor_scalar(2.0) + Tensor::from(self.epsilon))
                .log()
                .sum_dim_intlist(Some([-1].as_slice()), false, tch::Kind::Float);

        debug_assert_eq!(a.size().as_slice()[0], self.batch_size as i64);
        debug_assert_eq!(log_p.size().as_slice(), [self.batch_size as i64]);

        (a, log_p)
    }

    /// Evaluates every critic in `qnets` on `(obs, act)`, squeezing each
    /// output to a 1-D per-sample value tensor.
    fn qvals(&self, qnets: &[Critic<Q>], obs: &Q::Input1, act: &Q::Input2) -> Vec<Tensor> {
        qnets
            .iter()
            .map(|qnet| qnet.forward(obs, act).squeeze())
            .collect()
    }

    /// Returns the minimum values of q values over critics
    /// (the clipped double-Q style pessimistic estimate, elementwise over the batch).
    fn qvals_min(&self, qnets: &[Critic<Q>], obs: &Q::Input1, act: &Q::Input2) -> Tensor {
        let qvals = self.qvals(qnets, obs, act);
        // Stack critics along dim 0, then take the elementwise minimum.
        let qvals = Tensor::vstack(&qvals);
        let qvals_min = qvals.min_dim(0, false).0;

        debug_assert_eq!(qvals_min.size().as_slice(), [self.batch_size as i64]);

        qvals_min
    }

    /// Performs one gradient step on every critic and returns the mean loss.
    ///
    /// Target: `reward_scale * r + (1 - terminated) * gamma *
    /// (min_i Q_tgt_i(s', a') - alpha * log pi(a'|s'))`, computed without
    /// gradient tracking. Each critic is regressed onto the same target with
    /// the configured loss (MSE or smooth L1).
    fn update_critic(&mut self, batch: R::Batch) -> f32 {
        let losses = {
            let (obs, act, next_obs, reward, is_terminated, _is_truncated, _, _) = batch.unpack();
            let reward = Tensor::from_slice(&reward[..]).to(self.device);
            let is_terminated = Tensor::from_slice(&is_terminated[..]).to(self.device);

            let preds = self.qvals(&self.qnets, &obs.into(), &act.into());
            let tgt = {
                // Bootstrapped soft value of the next state; no gradients flow
                // into the actor or the target critics here.
                let next_q = no_grad(|| {
                    let (next_a, next_log_p) = self.action_logp(&next_obs.clone().into());
                    let next_q = self.qvals_min(&self.qnets_tgt, &next_obs.into(), &next_a.into());
                    next_q - self.ent_coef.alpha() * next_log_p
                });
                // Terminated transitions drop the bootstrap term; truncated ones keep it.
                self.reward_scale * reward
                    + (1f32 - &is_terminated) * Tensor::from(self.gamma) * next_q
            };

            debug_assert_eq!(tgt.size().as_slice(), [self.batch_size as i64]);

            let losses: Vec<_> = match self.critic_loss {
                CriticLoss::Mse => preds
                    .iter()
                    .map(|pred| pred.mse_loss(&tgt, tch::Reduction::Mean))
                    .collect(),
                CriticLoss::SmoothL1 => preds
                    .iter()
                    .map(|pred| pred.smooth_l1_loss(&tgt, tch::Reduction::Mean, 1.0))
                    .collect(),
            };
            losses
        };

        // One optimizer step per critic, each on its own loss.
        for (qnet, loss) in self.qnets.iter_mut().zip(&losses) {
            qnet.backward_step(&loss);
        }

        // Average scalar loss over the critic ensemble, for logging.
        losses
            .iter()
            .map(f32::try_from)
            .map(|a| a.expect("Failed to convert Tensor to f32"))
            .sum::<f32>()
            / (self.qnets.len() as f32)
    }

    /// Performs one gradient step on the actor and returns the scalar loss.
    ///
    /// Minimizes `mean(alpha * log pi(a|s) - min_i Q_i(s, a))`, i.e. maximizes
    /// the entropy-regularized Q-value. Also updates the entropy coefficient
    /// from the (detached) log-probabilities before the actor step.
    fn update_actor(&mut self, batch: &R::Batch) -> f32 {
        let loss = {
            let o = batch.obs().clone();
            let (a, log_p) = self.action_logp(&o.into());

            // Update the entropy coefficient
            self.ent_coef.update(&log_p.detach());

            let o = batch.obs().clone();
            let qval = self.qvals_min(&self.qnets, &o.into(), &a.into());
            // alpha is detached so the actor loss does not backprop into it.
            (self.ent_coef.alpha().detach() * &log_p - &qval).mean(tch::Kind::Float)
        };

        self.pi.backward_step(&loss);

        f32::try_from(loss).expect("Failed to convert Tensor to f32")
    }

    /// Polyak-averages every target critic toward its online counterpart
    /// with coefficient `tau`.
    fn soft_update(&mut self) {
        for (qnet_tgt, qnet) in self.qnets_tgt.iter_mut().zip(&mut self.qnets) {
            track(qnet_tgt, qnet, self.tau);
        }
    }

    /// Runs `n_updates_per_opt` actor/critic updates on batches drawn from
    /// `buffer` and returns averaged losses plus the current entropy
    /// coefficient as a [`Record`].
    fn opt_(&mut self, buffer: &mut R) -> Record {
        let mut loss_critic = 0f32;
        let mut loss_actor = 0f32;

        for _ in 0..self.n_updates_per_opt {
            let batch = buffer.batch(self.batch_size).unwrap();
            // Actor first (borrows the batch), then critic (consumes it).
            loss_actor += self.update_actor(&batch);
            loss_critic += self.update_critic(batch);
            self.soft_update();
            self.n_opts += 1;
        }

        loss_critic /= self.n_updates_per_opt as f32;
        loss_actor /= self.n_updates_per_opt as f32;

        Record::from_slice(&[
            ("loss_critic", RecordValue::Scalar(loss_critic)),
            ("loss_actor", RecordValue::Scalar(loss_actor)),
            (
                "ent_coef",
                // alpha is assumed to be a 1-element tensor here.
                RecordValue::Scalar(self.ent_coef.alpha().double_value(&[0]) as f32),
            ),
        ])
    }

    /// Returns a shared reference to the policy (actor) network.
    pub fn get_policy_net(&self) -> &Actor<P> {
        &self.pi
    }
}
204
205impl<E, Q, P, R> Policy<E> for Sac<E, Q, P, R>
206where
207    E: Env,
208    Q: SubModel2<Output = ActionValue>,
209    P: SubModel<Output = (ActMean, ActStd)>,
210    E::Obs: Into<Q::Input1> + Into<P::Input>,
211    E::Act: Into<Q::Input2> + From<Tensor>,
212    Q::Config: DeserializeOwned + Serialize + OutDim + std::fmt::Debug + PartialEq + Clone,
213    P::Config: DeserializeOwned + Serialize + OutDim + std::fmt::Debug + PartialEq + Clone,
214{
215    fn sample(&mut self, obs: &E::Obs) -> E::Act {
216        let obs = obs.clone().into();
217        let (mean, lstd) = self.pi.forward(&obs);
218        let std = lstd.clip(self.min_lstd, self.max_lstd).exp();
219        let act = if self.train {
220            std * Tensor::randn(&mean.size(), tch::kind::FLOAT_CPU).to(self.device) + mean
221        } else {
222            mean
223        };
224        act.tanh().into()
225    }
226}
227
228impl<E, Q, P, R> Configurable for Sac<E, Q, P, R>
229where
230    E: Env,
231    Q: SubModel2<Output = ActionValue>,
232    P: SubModel<Output = (ActMean, ActStd)>,
233    E::Obs: Into<Q::Input1> + Into<P::Input>,
234    E::Act: Into<Q::Input2> + From<Tensor>,
235    Q::Config: DeserializeOwned + Serialize + OutDim + std::fmt::Debug + PartialEq + Clone,
236    P::Config: DeserializeOwned + Serialize + OutDim + std::fmt::Debug + PartialEq + Clone,
237{
238    type Config = SacConfig<Q, P>;
239
240    /// Constructs [`Sac`] agent.
241    fn build(config: Self::Config) -> Self {
242        let device = config
243            .device
244            .expect("No device is given for SAC agent")
245            .into();
246        let n_critics = config.n_critics;
247        let pi = Actor::build(config.actor_config, device).unwrap();
248        let mut qnets = vec![];
249        let mut qnets_tgt = vec![];
250        for _ in 0..n_critics {
251            let critic = Critic::build(config.critic_config.clone(), device).unwrap();
252            qnets.push(critic.clone());
253            qnets_tgt.push(critic);
254        }
255
256        if let Some(seed) = config.seed.as_ref() {
257            tch::manual_seed(*seed);
258        }
259
260        Sac {
261            qnets,
262            qnets_tgt,
263            pi,
264            gamma: config.gamma,
265            tau: config.tau,
266            ent_coef: EntCoef::new(config.ent_coef_mode, device),
267            epsilon: config.epsilon,
268            min_lstd: config.min_lstd,
269            max_lstd: config.max_lstd,
270            n_updates_per_opt: config.n_updates_per_opt,
271            batch_size: config.batch_size,
272            train: config.train,
273            reward_scale: config.reward_scale,
274            critic_loss: config.critic_loss,
275            n_opts: 0,
276            device,
277            phantom: PhantomData,
278        }
279    }
280}
281
282impl<E, Q, P, R> Agent<E, R> for Sac<E, Q, P, R>
283where
284    E: Env + 'static,
285    Q: SubModel2<Output = ActionValue> + 'static,
286    P: SubModel<Output = (ActMean, ActStd)> + 'static,
287    R: ReplayBufferBase + 'static,
288    E::Obs: Into<Q::Input1> + Into<P::Input>,
289    E::Act: Into<Q::Input2> + From<Tensor>,
290    Q::Input2: From<ActMean>,
291    Q::Config: DeserializeOwned + Serialize + OutDim + std::fmt::Debug + PartialEq + Clone,
292    P::Config: DeserializeOwned + Serialize + OutDim + std::fmt::Debug + PartialEq + Clone,
293    R::Batch: TransitionBatch,
294    <R::Batch as TransitionBatch>::ObsBatch: Into<Q::Input1> + Into<P::Input> + Clone,
295    <R::Batch as TransitionBatch>::ActBatch: Into<Q::Input2> + Into<Tensor>,
296{
297    fn train(&mut self) {
298        self.train = true;
299    }
300
301    fn eval(&mut self) {
302        self.train = false;
303    }
304
305    fn is_train(&self) -> bool {
306        self.train
307    }
308
309    fn opt_with_record(&mut self, buffer: &mut R) -> Record {
310        self.opt_(buffer)
311    }
312
313    fn save_params(&self, path: &Path) -> Result<Vec<PathBuf>> {
314        // TODO: consider to rename the path if it already exists
315        fs::create_dir_all(&path)?;
316        let mut paths = vec![];
317
318        for (i, (qnet, qnet_tgt)) in self.qnets.iter().zip(&self.qnets_tgt).enumerate() {
319            let path1 = path.join(format!("qnet_{}.pt.tch", i)).to_path_buf();
320            let path2 = path.join(format!("qnet_tgt_{}.pt.tch", i)).to_path_buf();
321            qnet.save(&path1)?;
322            qnet_tgt.save(&path2)?;
323            paths.push(path1);
324            paths.push(path2);
325        }
326
327        let path_actor = path.join("pi.pt.tch").to_path_buf();
328        let path_ent_coef = path.join("ent_coef.pt.tch").to_path_buf();
329        self.pi.save(&path_actor)?;
330        self.ent_coef.save(&path_ent_coef)?;
331        paths.push(path_actor);
332        paths.push(path_ent_coef);
333
334        Ok(paths)
335    }
336
337    fn load_params(&mut self, path: &Path) -> Result<()> {
338        for (i, (qnet, qnet_tgt)) in self.qnets.iter_mut().zip(&mut self.qnets_tgt).enumerate() {
339            qnet.load(path.join(format!("qnet_{}.pt.tch", i)).as_path())?;
340            qnet_tgt.load(path.join(format!("qnet_tgt_{}.pt.tch", i)).as_path())?;
341        }
342        self.pi.load(path.join("pi.pt.tch").as_path())?;
343        self.ent_coef.load(path.join("ent_coef.pt.tch").as_path())?;
344        Ok(())
345    }
346
347    fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
348        self
349    }
350
351    fn as_any_ref(&self) -> &dyn std::any::Any {
352        self
353    }
354}
355
356#[cfg(feature = "border-async-trainer")]
357use {crate::util::NamedTensors, border_async_trainer::SyncModel};
358
#[cfg(feature = "border-async-trainer")]
impl<E, Q, P, R> SyncModel for Sac<E, Q, P, R>
where
    E: Env,
    Q: SubModel2<Output = ActionValue>,
    P: SubModel<Output = (ActMean, ActStd)>,
    R: ReplayBufferBase,
    E::Obs: Into<Q::Input1> + Into<P::Input>,
    E::Act: Into<Q::Input2> + From<Tensor>,
    Q::Input2: From<ActMean>,
    Q::Config: DeserializeOwned + Serialize + OutDim + std::fmt::Debug + PartialEq + Clone,
    P::Config: DeserializeOwned + Serialize + OutDim + std::fmt::Debug + PartialEq + Clone,
    R::Batch: TransitionBatch,
    <R::Batch as TransitionBatch>::ObsBatch: Into<Q::Input1> + Into<P::Input> + Clone,
    <R::Batch as TransitionBatch>::ActBatch: Into<Q::Input2> + Into<Tensor>,
{
    // Only the actor's parameters are synchronized; critics and the entropy
    // coefficient stay local to each worker.
    type ModelInfo = NamedTensors;

    /// Returns the optimization-step counter together with a copy of the
    /// actor's parameters for asynchronous-trainer synchronization.
    fn model_info(&self) -> (usize, Self::ModelInfo) {
        (
            self.n_opts,
            NamedTensors::copy_from(self.pi.get_var_store()),
        )
    }

    /// Overwrites the actor's parameters with the given model info.
    fn sync_model(&mut self, model_info: &Self::ModelInfo) {
        model_info.copy_to(self.pi.get_var_store_mut());
    }
}