border_tch_agent/sac/
base.rs

1use super::{Actor, Critic, EntCoef, SacConfig};
2use crate::{
3    model::{ModelBase, SubModel, SubModel2},
4    util::{track, CriticLoss, OutDim},
5};
6use anyhow::Result;
7use border_core::{
8    record::{Record, RecordValue},
9    Agent, Configurable, Env, Policy, ReplayBufferBase, TransitionBatch,
10};
11use serde::{de::DeserializeOwned, Serialize};
12// use log::info;
13use std::{convert::TryFrom, fs, marker::PhantomData, path::Path};
14use tch::{no_grad, Tensor};
15
16type ActionValue = Tensor;
17type ActMean = Tensor;
18type ActStd = Tensor;
19
20fn normal_logp(x: &Tensor) -> Tensor {
21    let tmp: Tensor = Tensor::from(-0.5 * (2.0 * std::f32::consts::PI).ln() as f32)
22        - 0.5 * x.pow_tensor_scalar(2);
23    tmp.sum_dim_intlist(Some([-1].as_slice()), false, tch::Kind::Float)
24}
25
/// Soft actor critic (SAC) agent.
pub struct Sac<E, Q, P, R>
where
    Q: SubModel2<Output = ActionValue>,
    P: SubModel<Output = (ActMean, ActStd)>,
    Q::Config: DeserializeOwned + Serialize + OutDim + std::fmt::Debug + PartialEq + Clone,
    P::Config: DeserializeOwned + Serialize + OutDim + std::fmt::Debug + PartialEq + Clone,
{
    // Ensemble of online critics; the minimum over the ensemble is used as the
    // Q-value estimate (see `qvals_min`).
    pub(super) qnets: Vec<Critic<Q>>,
    // Target critics, moved toward `qnets` by `soft_update` with rate `tau`.
    pub(super) qnets_tgt: Vec<Critic<Q>>,
    // Policy (actor) network producing mean and (log-)std of the action distribution.
    pub(super) pi: Actor<P>,
    // Discount factor applied to the bootstrapped next-state value.
    pub(super) gamma: f64,
    // Soft-update coefficient for the target critics.
    pub(super) tau: f64,
    // Entropy coefficient (alpha), possibly learned; see `EntCoef`.
    pub(super) ent_coef: EntCoef,
    // Small constant added inside log(1 - tanh(a)^2 + epsilon) for numerical stability.
    pub(super) epsilon: f64,
    // Clamp bounds for the log-std output of the actor.
    pub(super) min_lstd: f64,
    pub(super) max_lstd: f64,
    // Number of gradient updates performed per call to `opt_`.
    pub(super) n_updates_per_opt: usize,
    // Minibatch size sampled from the replay buffer.
    pub(super) batch_size: usize,
    // Training-mode flag: when true, `sample` adds exploration noise.
    pub(super) train: bool,
    // Multiplier applied to rewards when forming the critic target.
    pub(super) reward_scale: f32,
    // Total number of optimization steps performed so far.
    pub(super) n_opts: usize,
    // Which regression loss the critics use (MSE or smooth L1).
    pub(super) critic_loss: CriticLoss,
    pub(super) phantom: PhantomData<(E, R)>,
    // Device on which tensors are placed (e.g. CPU or CUDA).
    pub(super) device: tch::Device,
}
52
impl<E, Q, P, R> Sac<E, Q, P, R>
where
    E: Env,
    Q: SubModel2<Output = ActionValue>,
    P: SubModel<Output = (ActMean, ActStd)>,
    R: ReplayBufferBase,
    E::Obs: Into<Q::Input1> + Into<P::Input>,
    E::Act: Into<Q::Input2>,
    Q::Input2: From<ActMean>,
    Q::Config: DeserializeOwned + Serialize + OutDim + std::fmt::Debug + PartialEq + Clone,
    P::Config: DeserializeOwned + Serialize + OutDim + std::fmt::Debug + PartialEq + Clone,
    R::Batch: TransitionBatch,
    <R::Batch as TransitionBatch>::ObsBatch: Into<Q::Input1> + Into<P::Input> + Clone,
    <R::Batch as TransitionBatch>::ActBatch: Into<Q::Input2> + Into<Tensor>,
{
    /// Samples a tanh-squashed action for observation batch `o` via the
    /// reparameterization trick and returns it with its log-probability.
    ///
    /// The log-prob is the standard-normal log-density of the noise minus the
    /// tanh change-of-variables correction `sum log(1 - a^2 + epsilon)`;
    /// `epsilon` guards the log against zero.
    fn action_logp(&self, o: &P::Input) -> (Tensor, Tensor) {
        let (mean, lstd) = self.pi.forward(o);
        // Clamp log-std before exponentiating to keep std in a sane range.
        let std = lstd.clip(self.min_lstd, self.max_lstd).exp();
        let z = Tensor::randn(mean.size().as_slice(), tch::kind::FLOAT_CPU).to(self.device);
        // Reparameterized sample, squashed into (-1, 1).
        let a = (&std * &z + &mean).tanh();
        let log_p = normal_logp(&z)
            - (Tensor::from(1f32) - a.pow_tensor_scalar(2.0) + Tensor::from(self.epsilon))
                .log()
                .sum_dim_intlist(Some([-1].as_slice()), false, tch::Kind::Float);

        debug_assert_eq!(a.size().as_slice()[0], self.batch_size as i64);
        debug_assert_eq!(log_p.size().as_slice(), [self.batch_size as i64]);

        (a, log_p)
    }

    /// Evaluates every critic in `qnets` on `(obs, act)`, squeezing each
    /// output to a 1-D tensor of per-sample action values.
    fn qvals(&self, qnets: &[Critic<Q>], obs: &Q::Input1, act: &Q::Input2) -> Vec<Tensor> {
        qnets
            .iter()
            .map(|qnet| qnet.forward(obs, act).squeeze())
            .collect()
    }

    /// Returns the minimum values of q values over critics
    fn qvals_min(&self, qnets: &[Critic<Q>], obs: &Q::Input1, act: &Q::Input2) -> Tensor {
        let qvals = self.qvals(qnets, obs, act);
        // Stack critics along dim 0, then take the element-wise minimum.
        let qvals = Tensor::vstack(&qvals);
        let qvals_min = qvals.min_dim(0, false).0;

        debug_assert_eq!(qvals_min.size().as_slice(), [self.batch_size as i64]);

        qvals_min
    }

    /// Performs one gradient step on every critic against the shared soft
    /// Bellman target and returns the mean loss over critics (as `f32`).
    fn update_critic(&mut self, batch: R::Batch) -> f32 {
        let losses = {
            let (obs, act, next_obs, reward, is_terminated, _is_truncated, _, _) = batch.unpack();
            let reward = Tensor::from_slice(&reward[..]).to(self.device);
            let is_terminated = Tensor::from_slice(&is_terminated[..]).to(self.device);

            let preds = self.qvals(&self.qnets, &obs.into(), &act.into());
            let tgt = {
                // Target value is computed without tracking gradients: the
                // entropy-regularized minimum over the *target* critics at the
                // next state, using a freshly sampled next action.
                let next_q = no_grad(|| {
                    let (next_a, next_log_p) = self.action_logp(&next_obs.clone().into());
                    let next_q = self.qvals_min(&self.qnets_tgt, &next_obs.into(), &next_a.into());
                    next_q - self.ent_coef.alpha() * next_log_p
                });
                // `(1 - is_terminated)` masks out bootstrapping at terminal
                // transitions (assumes the flags convert to 0/1 floats).
                self.reward_scale * reward
                    + (1f32 - &is_terminated) * Tensor::from(self.gamma) * next_q
            };

            debug_assert_eq!(tgt.size().as_slice(), [self.batch_size as i64]);

            let losses: Vec<_> = match self.critic_loss {
                CriticLoss::Mse => preds
                    .iter()
                    .map(|pred| pred.mse_loss(&tgt, tch::Reduction::Mean))
                    .collect(),
                CriticLoss::SmoothL1 => preds
                    .iter()
                    .map(|pred| pred.smooth_l1_loss(&tgt, tch::Reduction::Mean, 1.0))
                    .collect(),
            };
            losses
        };

        // Each critic steps on its own loss.
        for (qnet, loss) in self.qnets.iter_mut().zip(&losses) {
            qnet.backward_step(&loss);
        }

        losses
            .iter()
            .map(f32::try_from)
            .map(|a| a.expect("Failed to convert Tensor to f32"))
            .sum::<f32>()
            / (self.qnets.len() as f32)
    }

    /// Performs one gradient step on the actor (and updates the entropy
    /// coefficient) and returns the actor loss as `f32`.
    fn update_actor(&mut self, batch: &R::Batch) -> f32 {
        let loss = {
            let o = batch.obs().clone();
            let (a, log_p) = self.action_logp(&o.into());

            // Update the entropy coefficient
            self.ent_coef.update(&log_p.detach());

            let o = batch.obs().clone();
            // Actor objective: E[alpha * log_p - min_i Q_i(s, a)];
            // alpha is detached so the actor step does not move it.
            let qval = self.qvals_min(&self.qnets, &o.into(), &a.into());
            (self.ent_coef.alpha().detach() * &log_p - &qval).mean(tch::Kind::Float)
        };

        self.pi.backward_step(&loss);

        f32::try_from(loss).expect("Failed to convert Tensor to f32")
    }

    /// Moves each target critic toward its online counterpart with rate `tau`.
    fn soft_update(&mut self) {
        for (qnet_tgt, qnet) in self.qnets_tgt.iter_mut().zip(&mut self.qnets) {
            track(qnet_tgt, qnet, self.tau);
        }
    }

    /// Runs `n_updates_per_opt` actor/critic updates on fresh minibatches and
    /// returns a record with the averaged losses and current entropy coefficient.
    fn opt_(&mut self, buffer: &mut R) -> Record {
        let mut loss_critic = 0f32;
        let mut loss_actor = 0f32;

        for _ in 0..self.n_updates_per_opt {
            let batch = buffer.batch(self.batch_size).unwrap();
            // Note: the actor is updated before the critics on the same batch.
            loss_actor += self.update_actor(&batch);
            loss_critic += self.update_critic(batch);
            self.soft_update();
            self.n_opts += 1;
        }

        loss_critic /= self.n_updates_per_opt as f32;
        loss_actor /= self.n_updates_per_opt as f32;

        Record::from_slice(&[
            ("loss_critic", RecordValue::Scalar(loss_critic)),
            ("loss_actor", RecordValue::Scalar(loss_actor)),
            (
                "ent_coef",
                RecordValue::Scalar(self.ent_coef.alpha().double_value(&[0]) as f32),
            ),
        ])
    }

    /// Returns a reference to the policy (actor) network.
    pub fn get_policy_net(&self) -> &Actor<P> {
        &self.pi
    }
}
199
200impl<E, Q, P, R> Policy<E> for Sac<E, Q, P, R>
201where
202    E: Env,
203    Q: SubModel2<Output = ActionValue>,
204    P: SubModel<Output = (ActMean, ActStd)>,
205    E::Obs: Into<Q::Input1> + Into<P::Input>,
206    E::Act: Into<Q::Input2> + From<Tensor>,
207    Q::Config: DeserializeOwned + Serialize + OutDim + std::fmt::Debug + PartialEq + Clone,
208    P::Config: DeserializeOwned + Serialize + OutDim + std::fmt::Debug + PartialEq + Clone,
209{
210    fn sample(&mut self, obs: &E::Obs) -> E::Act {
211        let obs = obs.clone().into();
212        let (mean, lstd) = self.pi.forward(&obs);
213        let std = lstd.clip(self.min_lstd, self.max_lstd).exp();
214        let act = if self.train {
215            std * Tensor::randn(&mean.size(), tch::kind::FLOAT_CPU).to(self.device) + mean
216        } else {
217            mean
218        };
219        act.tanh().into()
220    }
221}
222
223impl<E, Q, P, R> Configurable<E> for Sac<E, Q, P, R>
224where
225    E: Env,
226    Q: SubModel2<Output = ActionValue>,
227    P: SubModel<Output = (ActMean, ActStd)>,
228    E::Obs: Into<Q::Input1> + Into<P::Input>,
229    E::Act: Into<Q::Input2> + From<Tensor>,
230    Q::Config: DeserializeOwned + Serialize + OutDim + std::fmt::Debug + PartialEq + Clone,
231    P::Config: DeserializeOwned + Serialize + OutDim + std::fmt::Debug + PartialEq + Clone,
232{
233    type Config = SacConfig<Q, P>;
234
235    /// Constructs [`Sac`] agent.
236    fn build(config: Self::Config) -> Self {
237        let device = config
238            .device
239            .expect("No device is given for SAC agent")
240            .into();
241        let n_critics = config.n_critics;
242        let pi = Actor::build(config.actor_config, device).unwrap();
243        let mut qnets = vec![];
244        let mut qnets_tgt = vec![];
245        for _ in 0..n_critics {
246            let critic = Critic::build(config.critic_config.clone(), device).unwrap();
247            qnets.push(critic.clone());
248            qnets_tgt.push(critic);
249        }
250
251        if let Some(seed) = config.seed.as_ref() {
252            tch::manual_seed(*seed);
253        }
254
255        Sac {
256            qnets,
257            qnets_tgt,
258            pi,
259            gamma: config.gamma,
260            tau: config.tau,
261            ent_coef: EntCoef::new(config.ent_coef_mode, device),
262            epsilon: config.epsilon,
263            min_lstd: config.min_lstd,
264            max_lstd: config.max_lstd,
265            n_updates_per_opt: config.n_updates_per_opt,
266            batch_size: config.batch_size,
267            train: config.train,
268            reward_scale: config.reward_scale,
269            critic_loss: config.critic_loss,
270            n_opts: 0,
271            device,
272            phantom: PhantomData,
273        }
274    }
275}
276
277impl<E, Q, P, R> Agent<E, R> for Sac<E, Q, P, R>
278where
279    E: Env,
280    Q: SubModel2<Output = ActionValue>,
281    P: SubModel<Output = (ActMean, ActStd)>,
282    R: ReplayBufferBase,
283    E::Obs: Into<Q::Input1> + Into<P::Input>,
284    E::Act: Into<Q::Input2> + From<Tensor>,
285    Q::Input2: From<ActMean>,
286    Q::Config: DeserializeOwned + Serialize + OutDim + std::fmt::Debug + PartialEq + Clone,
287    P::Config: DeserializeOwned + Serialize + OutDim + std::fmt::Debug + PartialEq + Clone,
288    R::Batch: TransitionBatch,
289    <R::Batch as TransitionBatch>::ObsBatch: Into<Q::Input1> + Into<P::Input> + Clone,
290    <R::Batch as TransitionBatch>::ActBatch: Into<Q::Input2> + Into<Tensor>,
291{
292    fn train(&mut self) {
293        self.train = true;
294    }
295
296    fn eval(&mut self) {
297        self.train = false;
298    }
299
300    fn is_train(&self) -> bool {
301        self.train
302    }
303
304    fn opt_with_record(&mut self, buffer: &mut R) -> Record {
305        self.opt_(buffer)
306    }
307
308    fn save_params<T: AsRef<Path>>(&self, path: T) -> Result<()> {
309        // TODO: consider to rename the path if it already exists
310        fs::create_dir_all(&path)?;
311        for (i, (qnet, qnet_tgt)) in self.qnets.iter().zip(&self.qnets_tgt).enumerate() {
312            qnet.save(&path.as_ref().join(format!("qnet_{}.pt.tch", i)).as_path())?;
313            qnet_tgt.save(
314                &path
315                    .as_ref()
316                    .join(format!("qnet_tgt_{}.pt.tch", i))
317                    .as_path(),
318            )?;
319        }
320        self.pi.save(&path.as_ref().join("pi.pt.tch").as_path())?;
321        self.ent_coef
322            .save(&path.as_ref().join("ent_coef.pt.tch").as_path())?;
323        Ok(())
324    }
325
326    fn load_params<T: AsRef<Path>>(&mut self, path: T) -> Result<()> {
327        for (i, (qnet, qnet_tgt)) in self.qnets.iter_mut().zip(&mut self.qnets_tgt).enumerate() {
328            qnet.load(&path.as_ref().join(format!("qnet_{}.pt.tch", i)).as_path())?;
329            qnet_tgt.load(
330                &path
331                    .as_ref()
332                    .join(format!("qnet_tgt_{}.pt.tch", i))
333                    .as_path(),
334            )?;
335        }
336        self.pi.load(&path.as_ref().join("pi.pt.tch").as_path())?;
337        self.ent_coef
338            .load(&path.as_ref().join("ent_coef.pt.tch").as_path())?;
339        Ok(())
340    }
341}
342
343#[cfg(feature = "border-async-trainer")]
344use {crate::util::NamedTensors, border_async_trainer::SyncModel};
345
#[cfg(feature = "border-async-trainer")]
impl<E, Q, P, R> SyncModel for Sac<E, Q, P, R>
where
    E: Env,
    Q: SubModel2<Output = ActionValue>,
    P: SubModel<Output = (ActMean, ActStd)>,
    R: ReplayBufferBase,
    E::Obs: Into<Q::Input1> + Into<P::Input>,
    E::Act: Into<Q::Input2> + From<Tensor>,
    Q::Input2: From<ActMean>,
    Q::Config: DeserializeOwned + Serialize + OutDim + std::fmt::Debug + PartialEq + Clone,
    P::Config: DeserializeOwned + Serialize + OutDim + std::fmt::Debug + PartialEq + Clone,
    R::Batch: TransitionBatch,
    <R::Batch as TransitionBatch>::ObsBatch: Into<Q::Input1> + Into<P::Input> + Clone,
    <R::Batch as TransitionBatch>::ActBatch: Into<Q::Input2> + Into<Tensor>,
{
    type ModelInfo = NamedTensors;

    /// Returns the optimization-step counter together with a copy of the
    /// actor's parameters. Only the actor is synchronized; critics stay local.
    fn model_info(&self) -> (usize, Self::ModelInfo) {
        let actor_params = NamedTensors::copy_from(self.pi.get_var_store());
        (self.n_opts, actor_params)
    }

    /// Overwrites the local actor's parameters with those in `model_info`.
    fn sync_model(&mut self, model_info: &Self::ModelInfo) {
        model_info.copy_to(self.pi.get_var_store_mut());
    }
}