1use super::{Actor, Critic, EntCoef, SacConfig};
2use crate::{
3 model::{ModelBase, SubModel, SubModel2},
4 util::{track, CriticLoss, OutDim},
5};
6use anyhow::Result;
7use border_core::{
8 record::{Record, RecordValue},
9 Agent, Configurable, Env, Policy, ReplayBufferBase, TransitionBatch,
10};
11use serde::{de::DeserializeOwned, Serialize};
12use std::{
14 convert::TryFrom,
15 fs,
16 marker::PhantomData,
17 path::{Path, PathBuf},
18};
19use tch::{no_grad, Tensor};
20
// Action-value (Q) tensor produced by a critic network.
type ActionValue = Tensor;
// Mean of the actor's Gaussian head.
type ActMean = Tensor;
// Second actor output; despite the name it is consumed as a log-std
// (it is clipped and exponentiated before use — see `action_logp`).
type ActStd = Tensor;
24
25fn normal_logp(x: &Tensor) -> Tensor {
26 let tmp: Tensor = Tensor::from(-0.5 * (2.0 * std::f32::consts::PI).ln() as f32)
27 - 0.5 * x.pow_tensor_scalar(2);
28 tmp.sum_dim_intlist(Some([-1].as_slice()), false, tch::Kind::Float)
29}
30
/// Soft actor-critic (SAC) agent built on `tch`.
pub struct Sac<E, Q, P, R>
where
    Q: SubModel2<Output = ActionValue>,
    P: SubModel<Output = (ActMean, ActStd)>,
    Q::Config: DeserializeOwned + Serialize + OutDim + std::fmt::Debug + PartialEq + Clone,
    P::Config: DeserializeOwned + Serialize + OutDim + std::fmt::Debug + PartialEq + Clone,
{
    // Online critics; clipped double-Q takes the element-wise min over them.
    pub(super) qnets: Vec<Critic<Q>>,
    // Target critics, tracked toward `qnets` by soft updates with rate `tau`.
    pub(super) qnets_tgt: Vec<Critic<Q>>,
    // Actor network producing (mean, log-std) of the pre-squash Gaussian.
    pub(super) pi: Actor<P>,
    // Discount factor used in the critic target.
    pub(super) gamma: f64,
    // Soft-update (Polyak) coefficient passed to `track`.
    pub(super) tau: f64,
    // Entropy coefficient alpha; `update` is fed detached log-probs during
    // actor updates.
    pub(super) ent_coef: EntCoef,
    // Numerical guard added inside log(1 - tanh(u)^2 + epsilon).
    pub(super) epsilon: f64,
    // Clipping range for the actor's log-std output.
    pub(super) min_lstd: f64,
    pub(super) max_lstd: f64,
    // Gradient updates performed per call to `opt_`.
    pub(super) n_updates_per_opt: usize,
    // Minibatch size requested from the replay buffer.
    pub(super) batch_size: usize,
    // Training-mode flag: stochastic actions when true, deterministic otherwise.
    pub(super) train: bool,
    // Multiplier applied to rewards when forming the critic target.
    pub(super) reward_scale: f32,
    // Total optimization steps performed so far (reported via `SyncModel`).
    pub(super) n_opts: usize,
    // Critic loss function (MSE or smooth L1).
    pub(super) critic_loss: CriticLoss,
    pub(super) phantom: PhantomData<(E, R)>,
    // Device on which sampled noise and batch tensors are placed.
    pub(super) device: tch::Device,
}
57
impl<E, Q, P, R> Sac<E, Q, P, R>
where
    E: Env,
    Q: SubModel2<Output = ActionValue>,
    P: SubModel<Output = (ActMean, ActStd)>,
    R: ReplayBufferBase,
    E::Obs: Into<Q::Input1> + Into<P::Input>,
    E::Act: Into<Q::Input2>,
    Q::Input2: From<ActMean>,
    Q::Config: DeserializeOwned + Serialize + OutDim + std::fmt::Debug + PartialEq + Clone,
    P::Config: DeserializeOwned + Serialize + OutDim + std::fmt::Debug + PartialEq + Clone,
    R::Batch: TransitionBatch,
    <R::Batch as TransitionBatch>::ObsBatch: Into<Q::Input1> + Into<P::Input> + Clone,
    <R::Batch as TransitionBatch>::ActBatch: Into<Q::Input2> + Into<Tensor>,
{
    /// Samples a tanh-squashed Gaussian action for `o` and returns it with its
    /// log-probability.
    ///
    /// Reparameterization trick: `a = tanh(mean + std * z)`, `z ~ N(0, I)`,
    /// with the squashing correction
    /// `log p(a) = log N(z; 0, I) - sum_i log(1 - a_i^2 + epsilon)`.
    ///
    /// NOTE(review): the Gaussian term is `log N(z; 0, I)` and omits the
    /// `-sum(log std)` contribution of the full `log N(u; mean, std)` —
    /// confirm this is intentional.
    fn action_logp(&self, o: &P::Input) -> (Tensor, Tensor) {
        let (mean, lstd) = self.pi.forward(o);
        // The actor outputs log-std; clip to [min_lstd, max_lstd] for stability.
        let std = lstd.clip(self.min_lstd, self.max_lstd).exp();
        let z = Tensor::randn(mean.size().as_slice(), tch::kind::FLOAT_CPU).to(self.device);
        let a = (&std * &z + &mean).tanh();
        // `epsilon` keeps the log finite when tanh saturates (1 - a^2 -> 0).
        let log_p = normal_logp(&z)
            - (Tensor::from(1f32) - a.pow_tensor_scalar(2.0) + Tensor::from(self.epsilon))
                .log()
                .sum_dim_intlist(Some([-1].as_slice()), false, tch::Kind::Float);

        debug_assert_eq!(a.size().as_slice()[0], self.batch_size as i64);
        debug_assert_eq!(log_p.size().as_slice(), [self.batch_size as i64]);

        (a, log_p)
    }

    /// Evaluates every critic in `qnets` on `(obs, act)`, squeezing each
    /// output to a flat per-sample Q-value tensor.
    fn qvals(&self, qnets: &[Critic<Q>], obs: &Q::Input1, act: &Q::Input2) -> Vec<Tensor> {
        qnets
            .iter()
            .map(|qnet| qnet.forward(obs, act).squeeze())
            .collect()
    }

    /// Element-wise minimum over the critics' Q-values (clipped double-Q).
    fn qvals_min(&self, qnets: &[Critic<Q>], obs: &Q::Input1, act: &Q::Input2) -> Tensor {
        let qvals = self.qvals(qnets, obs, act);
        // Stack to (n_critics, batch) and reduce over the critic axis.
        let qvals = Tensor::vstack(&qvals);
        let qvals_min = qvals.min_dim(0, false).0;

        debug_assert_eq!(qvals_min.size().as_slice(), [self.batch_size as i64]);

        qvals_min
    }

    /// Performs one gradient step on every critic toward the soft Bellman target
    /// `reward_scale * r + gamma * (1 - done) * (min_i Q_tgt_i(s', a') - alpha * log pi(a'|s'))`,
    /// with the target computed under `no_grad`.
    ///
    /// Returns the mean critic loss as an `f32`.
    fn update_critic(&mut self, batch: R::Batch) -> f32 {
        let losses = {
            let (obs, act, next_obs, reward, is_terminated, _is_truncated, _, _) = batch.unpack();
            let reward = Tensor::from_slice(&reward[..]).to(self.device);
            let is_terminated = Tensor::from_slice(&is_terminated[..]).to(self.device);

            let preds = self.qvals(&self.qnets, &obs.into(), &act.into());
            let tgt = {
                // Target value must not propagate gradients into the networks.
                let next_q = no_grad(|| {
                    let (next_a, next_log_p) = self.action_logp(&next_obs.clone().into());
                    let next_q = self.qvals_min(&self.qnets_tgt, &next_obs.into(), &next_a.into());
                    // Soft value: subtract the entropy term alpha * log pi.
                    next_q - self.ent_coef.alpha() * next_log_p
                });
                self.reward_scale * reward
                    + (1f32 - &is_terminated) * Tensor::from(self.gamma) * next_q
            };

            debug_assert_eq!(tgt.size().as_slice(), [self.batch_size as i64]);

            // One loss per critic, all regressed onto the same target.
            let losses: Vec<_> = match self.critic_loss {
                CriticLoss::Mse => preds
                    .iter()
                    .map(|pred| pred.mse_loss(&tgt, tch::Reduction::Mean))
                    .collect(),
                CriticLoss::SmoothL1 => preds
                    .iter()
                    .map(|pred| pred.smooth_l1_loss(&tgt, tch::Reduction::Mean, 1.0))
                    .collect(),
            };
            losses
        };

        for (qnet, loss) in self.qnets.iter_mut().zip(&losses) {
            qnet.backward_step(&loss);
        }

        // Average the per-critic losses into a single reported scalar.
        losses
            .iter()
            .map(f32::try_from)
            .map(|a| a.expect("Failed to convert Tensor to f32"))
            .sum::<f32>()
            / (self.qnets.len() as f32)
    }

    /// Performs one gradient step on the actor, minimizing
    /// `E[alpha * log pi(a|s) - min_i Q_i(s, a)]`, and lets the entropy
    /// coefficient update itself from the detached log-probabilities.
    ///
    /// Returns the actor loss as an `f32`.
    fn update_actor(&mut self, batch: &R::Batch) -> f32 {
        let loss = {
            let o = batch.obs().clone();
            let (a, log_p) = self.action_logp(&o.into());

            // Entropy-coefficient update uses detached log-probs so it does
            // not backprop into the actor.
            self.ent_coef.update(&log_p.detach());

            let o = batch.obs().clone();
            let qval = self.qvals_min(&self.qnets, &o.into(), &a.into());
            // alpha is detached here: the actor loss must not train alpha.
            (self.ent_coef.alpha().detach() * &log_p - &qval).mean(tch::Kind::Float)
        };

        self.pi.backward_step(&loss);

        f32::try_from(loss).expect("Failed to convert Tensor to f32")
    }

    /// Polyak-averages every target critic toward its online critic with
    /// rate `tau`.
    fn soft_update(&mut self) {
        for (qnet_tgt, qnet) in self.qnets_tgt.iter_mut().zip(&mut self.qnets) {
            track(qnet_tgt, qnet, self.tau);
        }
    }

    /// Runs `n_updates_per_opt` update rounds (actor, then critic, then soft
    /// update — note the actor is updated before the critic on each batch) and
    /// returns a record with the averaged losses and the current alpha.
    fn opt_(&mut self, buffer: &mut R) -> Record {
        let mut loss_critic = 0f32;
        let mut loss_actor = 0f32;

        for _ in 0..self.n_updates_per_opt {
            let batch = buffer.batch(self.batch_size).unwrap();
            loss_actor += self.update_actor(&batch);
            loss_critic += self.update_critic(batch);
            self.soft_update();
            self.n_opts += 1;
        }

        loss_critic /= self.n_updates_per_opt as f32;
        loss_actor /= self.n_updates_per_opt as f32;

        Record::from_slice(&[
            ("loss_critic", RecordValue::Scalar(loss_critic)),
            ("loss_actor", RecordValue::Scalar(loss_actor)),
            (
                "ent_coef",
                RecordValue::Scalar(self.ent_coef.alpha().double_value(&[0]) as f32),
            ),
        ])
    }

    /// Returns a reference to the actor (policy) network.
    pub fn get_policy_net(&self) -> &Actor<P> {
        &self.pi
    }
}
204
205impl<E, Q, P, R> Policy<E> for Sac<E, Q, P, R>
206where
207 E: Env,
208 Q: SubModel2<Output = ActionValue>,
209 P: SubModel<Output = (ActMean, ActStd)>,
210 E::Obs: Into<Q::Input1> + Into<P::Input>,
211 E::Act: Into<Q::Input2> + From<Tensor>,
212 Q::Config: DeserializeOwned + Serialize + OutDim + std::fmt::Debug + PartialEq + Clone,
213 P::Config: DeserializeOwned + Serialize + OutDim + std::fmt::Debug + PartialEq + Clone,
214{
215 fn sample(&mut self, obs: &E::Obs) -> E::Act {
216 let obs = obs.clone().into();
217 let (mean, lstd) = self.pi.forward(&obs);
218 let std = lstd.clip(self.min_lstd, self.max_lstd).exp();
219 let act = if self.train {
220 std * Tensor::randn(&mean.size(), tch::kind::FLOAT_CPU).to(self.device) + mean
221 } else {
222 mean
223 };
224 act.tanh().into()
225 }
226}
227
228impl<E, Q, P, R> Configurable for Sac<E, Q, P, R>
229where
230 E: Env,
231 Q: SubModel2<Output = ActionValue>,
232 P: SubModel<Output = (ActMean, ActStd)>,
233 E::Obs: Into<Q::Input1> + Into<P::Input>,
234 E::Act: Into<Q::Input2> + From<Tensor>,
235 Q::Config: DeserializeOwned + Serialize + OutDim + std::fmt::Debug + PartialEq + Clone,
236 P::Config: DeserializeOwned + Serialize + OutDim + std::fmt::Debug + PartialEq + Clone,
237{
238 type Config = SacConfig<Q, P>;
239
240 fn build(config: Self::Config) -> Self {
242 let device = config
243 .device
244 .expect("No device is given for SAC agent")
245 .into();
246 let n_critics = config.n_critics;
247 let pi = Actor::build(config.actor_config, device).unwrap();
248 let mut qnets = vec![];
249 let mut qnets_tgt = vec![];
250 for _ in 0..n_critics {
251 let critic = Critic::build(config.critic_config.clone(), device).unwrap();
252 qnets.push(critic.clone());
253 qnets_tgt.push(critic);
254 }
255
256 if let Some(seed) = config.seed.as_ref() {
257 tch::manual_seed(*seed);
258 }
259
260 Sac {
261 qnets,
262 qnets_tgt,
263 pi,
264 gamma: config.gamma,
265 tau: config.tau,
266 ent_coef: EntCoef::new(config.ent_coef_mode, device),
267 epsilon: config.epsilon,
268 min_lstd: config.min_lstd,
269 max_lstd: config.max_lstd,
270 n_updates_per_opt: config.n_updates_per_opt,
271 batch_size: config.batch_size,
272 train: config.train,
273 reward_scale: config.reward_scale,
274 critic_loss: config.critic_loss,
275 n_opts: 0,
276 device,
277 phantom: PhantomData,
278 }
279 }
280}
281
282impl<E, Q, P, R> Agent<E, R> for Sac<E, Q, P, R>
283where
284 E: Env + 'static,
285 Q: SubModel2<Output = ActionValue> + 'static,
286 P: SubModel<Output = (ActMean, ActStd)> + 'static,
287 R: ReplayBufferBase + 'static,
288 E::Obs: Into<Q::Input1> + Into<P::Input>,
289 E::Act: Into<Q::Input2> + From<Tensor>,
290 Q::Input2: From<ActMean>,
291 Q::Config: DeserializeOwned + Serialize + OutDim + std::fmt::Debug + PartialEq + Clone,
292 P::Config: DeserializeOwned + Serialize + OutDim + std::fmt::Debug + PartialEq + Clone,
293 R::Batch: TransitionBatch,
294 <R::Batch as TransitionBatch>::ObsBatch: Into<Q::Input1> + Into<P::Input> + Clone,
295 <R::Batch as TransitionBatch>::ActBatch: Into<Q::Input2> + Into<Tensor>,
296{
297 fn train(&mut self) {
298 self.train = true;
299 }
300
301 fn eval(&mut self) {
302 self.train = false;
303 }
304
305 fn is_train(&self) -> bool {
306 self.train
307 }
308
309 fn opt_with_record(&mut self, buffer: &mut R) -> Record {
310 self.opt_(buffer)
311 }
312
313 fn save_params(&self, path: &Path) -> Result<Vec<PathBuf>> {
314 fs::create_dir_all(&path)?;
316 let mut paths = vec![];
317
318 for (i, (qnet, qnet_tgt)) in self.qnets.iter().zip(&self.qnets_tgt).enumerate() {
319 let path1 = path.join(format!("qnet_{}.pt.tch", i)).to_path_buf();
320 let path2 = path.join(format!("qnet_tgt_{}.pt.tch", i)).to_path_buf();
321 qnet.save(&path1)?;
322 qnet_tgt.save(&path2)?;
323 paths.push(path1);
324 paths.push(path2);
325 }
326
327 let path_actor = path.join("pi.pt.tch").to_path_buf();
328 let path_ent_coef = path.join("ent_coef.pt.tch").to_path_buf();
329 self.pi.save(&path_actor)?;
330 self.ent_coef.save(&path_ent_coef)?;
331 paths.push(path_actor);
332 paths.push(path_ent_coef);
333
334 Ok(paths)
335 }
336
337 fn load_params(&mut self, path: &Path) -> Result<()> {
338 for (i, (qnet, qnet_tgt)) in self.qnets.iter_mut().zip(&mut self.qnets_tgt).enumerate() {
339 qnet.load(path.join(format!("qnet_{}.pt.tch", i)).as_path())?;
340 qnet_tgt.load(path.join(format!("qnet_tgt_{}.pt.tch", i)).as_path())?;
341 }
342 self.pi.load(path.join("pi.pt.tch").as_path())?;
343 self.ent_coef.load(path.join("ent_coef.pt.tch").as_path())?;
344 Ok(())
345 }
346
347 fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
348 self
349 }
350
351 fn as_any_ref(&self) -> &dyn std::any::Any {
352 self
353 }
354}
355
356#[cfg(feature = "border-async-trainer")]
357use {crate::util::NamedTensors, border_async_trainer::SyncModel};
358
359#[cfg(feature = "border-async-trainer")]
360impl<E, Q, P, R> SyncModel for Sac<E, Q, P, R>
361where
362 E: Env,
363 Q: SubModel2<Output = ActionValue>,
364 P: SubModel<Output = (ActMean, ActStd)>,
365 R: ReplayBufferBase,
366 E::Obs: Into<Q::Input1> + Into<P::Input>,
367 E::Act: Into<Q::Input2> + From<Tensor>,
368 Q::Input2: From<ActMean>,
369 Q::Config: DeserializeOwned + Serialize + OutDim + std::fmt::Debug + PartialEq + Clone,
370 P::Config: DeserializeOwned + Serialize + OutDim + std::fmt::Debug + PartialEq + Clone,
371 R::Batch: TransitionBatch,
372 <R::Batch as TransitionBatch>::ObsBatch: Into<Q::Input1> + Into<P::Input> + Clone,
373 <R::Batch as TransitionBatch>::ActBatch: Into<Q::Input2> + Into<Tensor>,
374{
375 type ModelInfo = NamedTensors;
376
377 fn model_info(&self) -> (usize, Self::ModelInfo) {
378 (
379 self.n_opts,
380 NamedTensors::copy_from(self.pi.get_var_store()),
381 )
382 }
383
384 fn sync_model(&mut self, model_info: &Self::ModelInfo) {
385 model_info.copy_to(self.pi.get_var_store_mut());
386 }
387}