1use super::{Actor, Critic, EntCoef, SacConfig};
2use crate::{
3 model::{ModelBase, SubModel, SubModel2},
4 util::{track, CriticLoss, OutDim},
5};
6use anyhow::Result;
7use border_core::{
8 record::{Record, RecordValue},
9 Agent, Configurable, Env, Policy, ReplayBufferBase, TransitionBatch,
10};
11use serde::{de::DeserializeOwned, Serialize};
12use std::{convert::TryFrom, fs, marker::PhantomData, path::Path};
14use tch::{no_grad, Tensor};
15
// Critic output: action-value (Q) estimates, one per (obs, act) pair.
type ActionValue = Tensor;
// Mean of the actor's (squashed) Gaussian policy head.
type ActMean = Tensor;
// Second actor output; NOTE(review): per `action_logp` this tensor holds the
// *log*-standard deviation (it is clipped then exponentiated) — the alias name
// understates that. Confirm against the `SubModel` implementors.
type ActStd = Tensor;
19
20fn normal_logp(x: &Tensor) -> Tensor {
21 let tmp: Tensor = Tensor::from(-0.5 * (2.0 * std::f32::consts::PI).ln() as f32)
22 - 0.5 * x.pow_tensor_scalar(2);
23 tmp.sum_dim_intlist(Some([-1].as_slice()), false, tch::Kind::Float)
24}
25
/// Soft actor-critic (SAC) agent.
///
/// `E` is the environment, `Q` the critic sub-model (action-value head),
/// `P` the actor sub-model (mean / log-std head), and `R` the replay buffer
/// the agent trains from.
pub struct Sac<E, Q, P, R>
where
    Q: SubModel2<Output = ActionValue>,
    P: SubModel<Output = (ActMean, ActStd)>,
    Q::Config: DeserializeOwned + Serialize + OutDim + std::fmt::Debug + PartialEq + Clone,
    P::Config: DeserializeOwned + Serialize + OutDim + std::fmt::Debug + PartialEq + Clone,
{
    // Ensemble of online critics; targets use the minimum over the ensemble.
    pub(super) qnets: Vec<Critic<Q>>,
    // Target critics, tracked from `qnets` by Polyak averaging (`soft_update`).
    pub(super) qnets_tgt: Vec<Critic<Q>>,
    // Policy (actor) network producing mean and log-std of a squashed Gaussian.
    pub(super) pi: Actor<P>,
    // Discount factor used in the critic target.
    pub(super) gamma: f64,
    // Soft-update coefficient passed to `track` for the target networks.
    pub(super) tau: f64,
    // Entropy temperature alpha (possibly learned; see `EntCoef`).
    pub(super) ent_coef: EntCoef,
    // Small constant added inside the tanh-correction log for numerical stability.
    pub(super) epsilon: f64,
    // Lower clip bound applied to the actor's log-std output.
    pub(super) min_lstd: f64,
    // Upper clip bound applied to the actor's log-std output.
    pub(super) max_lstd: f64,
    // Number of gradient updates per optimization call (`opt_`).
    pub(super) n_updates_per_opt: usize,
    // Minibatch size sampled from the replay buffer.
    pub(super) batch_size: usize,
    // `true` while in training mode (stochastic action sampling).
    pub(super) train: bool,
    // Scale applied to raw rewards when forming the critic target.
    pub(super) reward_scale: f32,
    // Total number of optimization steps performed so far.
    pub(super) n_opts: usize,
    // Loss function used for the critics (MSE or smooth L1).
    pub(super) critic_loss: CriticLoss,
    // Marks `E` and `R` as logically owned without storing values of them.
    pub(super) phantom: PhantomData<(E, R)>,
    // Device on which tensors are created/placed.
    pub(super) device: tch::Device,
}
52
53impl<E, Q, P, R> Sac<E, Q, P, R>
54where
55 E: Env,
56 Q: SubModel2<Output = ActionValue>,
57 P: SubModel<Output = (ActMean, ActStd)>,
58 R: ReplayBufferBase,
59 E::Obs: Into<Q::Input1> + Into<P::Input>,
60 E::Act: Into<Q::Input2>,
61 Q::Input2: From<ActMean>,
62 Q::Config: DeserializeOwned + Serialize + OutDim + std::fmt::Debug + PartialEq + Clone,
63 P::Config: DeserializeOwned + Serialize + OutDim + std::fmt::Debug + PartialEq + Clone,
64 R::Batch: TransitionBatch,
65 <R::Batch as TransitionBatch>::ObsBatch: Into<Q::Input1> + Into<P::Input> + Clone,
66 <R::Batch as TransitionBatch>::ActBatch: Into<Q::Input2> + Into<Tensor>,
67{
68 fn action_logp(&self, o: &P::Input) -> (Tensor, Tensor) {
69 let (mean, lstd) = self.pi.forward(o);
70 let std = lstd.clip(self.min_lstd, self.max_lstd).exp();
71 let z = Tensor::randn(mean.size().as_slice(), tch::kind::FLOAT_CPU).to(self.device);
72 let a = (&std * &z + &mean).tanh();
73 let log_p = normal_logp(&z)
74 - (Tensor::from(1f32) - a.pow_tensor_scalar(2.0) + Tensor::from(self.epsilon))
75 .log()
76 .sum_dim_intlist(Some([-1].as_slice()), false, tch::Kind::Float);
77
78 debug_assert_eq!(a.size().as_slice()[0], self.batch_size as i64);
79 debug_assert_eq!(log_p.size().as_slice(), [self.batch_size as i64]);
80
81 (a, log_p)
82 }
83
84 fn qvals(&self, qnets: &[Critic<Q>], obs: &Q::Input1, act: &Q::Input2) -> Vec<Tensor> {
85 qnets
86 .iter()
87 .map(|qnet| qnet.forward(obs, act).squeeze())
88 .collect()
89 }
90
91 fn qvals_min(&self, qnets: &[Critic<Q>], obs: &Q::Input1, act: &Q::Input2) -> Tensor {
93 let qvals = self.qvals(qnets, obs, act);
94 let qvals = Tensor::vstack(&qvals);
95 let qvals_min = qvals.min_dim(0, false).0;
96
97 debug_assert_eq!(qvals_min.size().as_slice(), [self.batch_size as i64]);
98
99 qvals_min
100 }
101
102 fn update_critic(&mut self, batch: R::Batch) -> f32 {
103 let losses = {
104 let (obs, act, next_obs, reward, is_terminated, _is_truncated, _, _) = batch.unpack();
105 let reward = Tensor::from_slice(&reward[..]).to(self.device);
106 let is_terminated = Tensor::from_slice(&is_terminated[..]).to(self.device);
107
108 let preds = self.qvals(&self.qnets, &obs.into(), &act.into());
109 let tgt = {
110 let next_q = no_grad(|| {
111 let (next_a, next_log_p) = self.action_logp(&next_obs.clone().into());
112 let next_q = self.qvals_min(&self.qnets_tgt, &next_obs.into(), &next_a.into());
113 next_q - self.ent_coef.alpha() * next_log_p
114 });
115 self.reward_scale * reward
116 + (1f32 - &is_terminated) * Tensor::from(self.gamma) * next_q
117 };
118
119 debug_assert_eq!(tgt.size().as_slice(), [self.batch_size as i64]);
120
121 let losses: Vec<_> = match self.critic_loss {
122 CriticLoss::Mse => preds
123 .iter()
124 .map(|pred| pred.mse_loss(&tgt, tch::Reduction::Mean))
125 .collect(),
126 CriticLoss::SmoothL1 => preds
127 .iter()
128 .map(|pred| pred.smooth_l1_loss(&tgt, tch::Reduction::Mean, 1.0))
129 .collect(),
130 };
131 losses
132 };
133
134 for (qnet, loss) in self.qnets.iter_mut().zip(&losses) {
135 qnet.backward_step(&loss);
136 }
137
138 losses
139 .iter()
140 .map(f32::try_from)
141 .map(|a| a.expect("Failed to convert Tensor to f32"))
142 .sum::<f32>()
143 / (self.qnets.len() as f32)
144 }
145
146 fn update_actor(&mut self, batch: &R::Batch) -> f32 {
147 let loss = {
148 let o = batch.obs().clone();
149 let (a, log_p) = self.action_logp(&o.into());
150
151 self.ent_coef.update(&log_p.detach());
153
154 let o = batch.obs().clone();
155 let qval = self.qvals_min(&self.qnets, &o.into(), &a.into());
156 (self.ent_coef.alpha().detach() * &log_p - &qval).mean(tch::Kind::Float)
157 };
158
159 self.pi.backward_step(&loss);
160
161 f32::try_from(loss).expect("Failed to convert Tensor to f32")
162 }
163
164 fn soft_update(&mut self) {
165 for (qnet_tgt, qnet) in self.qnets_tgt.iter_mut().zip(&mut self.qnets) {
166 track(qnet_tgt, qnet, self.tau);
167 }
168 }
169
170 fn opt_(&mut self, buffer: &mut R) -> Record {
171 let mut loss_critic = 0f32;
172 let mut loss_actor = 0f32;
173
174 for _ in 0..self.n_updates_per_opt {
175 let batch = buffer.batch(self.batch_size).unwrap();
176 loss_actor += self.update_actor(&batch);
177 loss_critic += self.update_critic(batch);
178 self.soft_update();
179 self.n_opts += 1;
180 }
181
182 loss_critic /= self.n_updates_per_opt as f32;
183 loss_actor /= self.n_updates_per_opt as f32;
184
185 Record::from_slice(&[
186 ("loss_critic", RecordValue::Scalar(loss_critic)),
187 ("loss_actor", RecordValue::Scalar(loss_actor)),
188 (
189 "ent_coef",
190 RecordValue::Scalar(self.ent_coef.alpha().double_value(&[0]) as f32),
191 ),
192 ])
193 }
194
195 pub fn get_policy_net(&self) -> &Actor<P> {
196 &self.pi
197 }
198}
199
200impl<E, Q, P, R> Policy<E> for Sac<E, Q, P, R>
201where
202 E: Env,
203 Q: SubModel2<Output = ActionValue>,
204 P: SubModel<Output = (ActMean, ActStd)>,
205 E::Obs: Into<Q::Input1> + Into<P::Input>,
206 E::Act: Into<Q::Input2> + From<Tensor>,
207 Q::Config: DeserializeOwned + Serialize + OutDim + std::fmt::Debug + PartialEq + Clone,
208 P::Config: DeserializeOwned + Serialize + OutDim + std::fmt::Debug + PartialEq + Clone,
209{
210 fn sample(&mut self, obs: &E::Obs) -> E::Act {
211 let obs = obs.clone().into();
212 let (mean, lstd) = self.pi.forward(&obs);
213 let std = lstd.clip(self.min_lstd, self.max_lstd).exp();
214 let act = if self.train {
215 std * Tensor::randn(&mean.size(), tch::kind::FLOAT_CPU).to(self.device) + mean
216 } else {
217 mean
218 };
219 act.tanh().into()
220 }
221}
222
223impl<E, Q, P, R> Configurable<E> for Sac<E, Q, P, R>
224where
225 E: Env,
226 Q: SubModel2<Output = ActionValue>,
227 P: SubModel<Output = (ActMean, ActStd)>,
228 E::Obs: Into<Q::Input1> + Into<P::Input>,
229 E::Act: Into<Q::Input2> + From<Tensor>,
230 Q::Config: DeserializeOwned + Serialize + OutDim + std::fmt::Debug + PartialEq + Clone,
231 P::Config: DeserializeOwned + Serialize + OutDim + std::fmt::Debug + PartialEq + Clone,
232{
233 type Config = SacConfig<Q, P>;
234
235 fn build(config: Self::Config) -> Self {
237 let device = config
238 .device
239 .expect("No device is given for SAC agent")
240 .into();
241 let n_critics = config.n_critics;
242 let pi = Actor::build(config.actor_config, device).unwrap();
243 let mut qnets = vec![];
244 let mut qnets_tgt = vec![];
245 for _ in 0..n_critics {
246 let critic = Critic::build(config.critic_config.clone(), device).unwrap();
247 qnets.push(critic.clone());
248 qnets_tgt.push(critic);
249 }
250
251 if let Some(seed) = config.seed.as_ref() {
252 tch::manual_seed(*seed);
253 }
254
255 Sac {
256 qnets,
257 qnets_tgt,
258 pi,
259 gamma: config.gamma,
260 tau: config.tau,
261 ent_coef: EntCoef::new(config.ent_coef_mode, device),
262 epsilon: config.epsilon,
263 min_lstd: config.min_lstd,
264 max_lstd: config.max_lstd,
265 n_updates_per_opt: config.n_updates_per_opt,
266 batch_size: config.batch_size,
267 train: config.train,
268 reward_scale: config.reward_scale,
269 critic_loss: config.critic_loss,
270 n_opts: 0,
271 device,
272 phantom: PhantomData,
273 }
274 }
275}
276
277impl<E, Q, P, R> Agent<E, R> for Sac<E, Q, P, R>
278where
279 E: Env,
280 Q: SubModel2<Output = ActionValue>,
281 P: SubModel<Output = (ActMean, ActStd)>,
282 R: ReplayBufferBase,
283 E::Obs: Into<Q::Input1> + Into<P::Input>,
284 E::Act: Into<Q::Input2> + From<Tensor>,
285 Q::Input2: From<ActMean>,
286 Q::Config: DeserializeOwned + Serialize + OutDim + std::fmt::Debug + PartialEq + Clone,
287 P::Config: DeserializeOwned + Serialize + OutDim + std::fmt::Debug + PartialEq + Clone,
288 R::Batch: TransitionBatch,
289 <R::Batch as TransitionBatch>::ObsBatch: Into<Q::Input1> + Into<P::Input> + Clone,
290 <R::Batch as TransitionBatch>::ActBatch: Into<Q::Input2> + Into<Tensor>,
291{
292 fn train(&mut self) {
293 self.train = true;
294 }
295
296 fn eval(&mut self) {
297 self.train = false;
298 }
299
300 fn is_train(&self) -> bool {
301 self.train
302 }
303
304 fn opt_with_record(&mut self, buffer: &mut R) -> Record {
305 self.opt_(buffer)
306 }
307
308 fn save_params<T: AsRef<Path>>(&self, path: T) -> Result<()> {
309 fs::create_dir_all(&path)?;
311 for (i, (qnet, qnet_tgt)) in self.qnets.iter().zip(&self.qnets_tgt).enumerate() {
312 qnet.save(&path.as_ref().join(format!("qnet_{}.pt.tch", i)).as_path())?;
313 qnet_tgt.save(
314 &path
315 .as_ref()
316 .join(format!("qnet_tgt_{}.pt.tch", i))
317 .as_path(),
318 )?;
319 }
320 self.pi.save(&path.as_ref().join("pi.pt.tch").as_path())?;
321 self.ent_coef
322 .save(&path.as_ref().join("ent_coef.pt.tch").as_path())?;
323 Ok(())
324 }
325
326 fn load_params<T: AsRef<Path>>(&mut self, path: T) -> Result<()> {
327 for (i, (qnet, qnet_tgt)) in self.qnets.iter_mut().zip(&mut self.qnets_tgt).enumerate() {
328 qnet.load(&path.as_ref().join(format!("qnet_{}.pt.tch", i)).as_path())?;
329 qnet_tgt.load(
330 &path
331 .as_ref()
332 .join(format!("qnet_tgt_{}.pt.tch", i))
333 .as_path(),
334 )?;
335 }
336 self.pi.load(&path.as_ref().join("pi.pt.tch").as_path())?;
337 self.ent_coef
338 .load(&path.as_ref().join("ent_coef.pt.tch").as_path())?;
339 Ok(())
340 }
341}
342
343#[cfg(feature = "border-async-trainer")]
344use {crate::util::NamedTensors, border_async_trainer::SyncModel};
345
#[cfg(feature = "border-async-trainer")]
impl<E, Q, P, R> SyncModel for Sac<E, Q, P, R>
where
    E: Env,
    Q: SubModel2<Output = ActionValue>,
    P: SubModel<Output = (ActMean, ActStd)>,
    R: ReplayBufferBase,
    E::Obs: Into<Q::Input1> + Into<P::Input>,
    E::Act: Into<Q::Input2> + From<Tensor>,
    Q::Input2: From<ActMean>,
    Q::Config: DeserializeOwned + Serialize + OutDim + std::fmt::Debug + PartialEq + Clone,
    P::Config: DeserializeOwned + Serialize + OutDim + std::fmt::Debug + PartialEq + Clone,
    R::Batch: TransitionBatch,
    <R::Batch as TransitionBatch>::ObsBatch: Into<Q::Input1> + Into<P::Input> + Clone,
    <R::Batch as TransitionBatch>::ActBatch: Into<Q::Input2> + Into<Tensor>,
{
    type ModelInfo = NamedTensors;

    /// Snapshots the actor's parameters together with the number of
    /// optimization steps performed so far.
    fn model_info(&self) -> (usize, Self::ModelInfo) {
        let params = NamedTensors::copy_from(self.pi.get_var_store());
        (self.n_opts, params)
    }

    /// Overwrites the actor's parameters with those in `model_info`.
    fn sync_model(&mut self, model_info: &Self::ModelInfo) {
        let var_store = self.pi.get_var_store_mut();
        model_info.copy_to(var_store);
    }
}
374}