use super::AwacConfig;
use crate::{
    model::{SubModel1, SubModel2},
    util::{
        actor::GaussianActor, critic::MultiCritic, gamma_not_done, smooth_l1_loss, CriticLoss,
        OutDim,
    },
};
use anyhow::Result;
use border_core::{
    record::{Record, RecordValue},
    Agent, Configurable, Env, Policy, ReplayBufferBase, TransitionBatch,
};
use candle_core::{Device, Tensor, D};
use candle_nn::{loss::mse, ops::softmax};
use serde::{de::DeserializeOwned, Serialize};
use std::{
    fs,
    marker::PhantomData,
    path::{Path, PathBuf},
};

type ActionValue = Tensor;
type ActMean = Tensor;
type ActStd = Tensor;

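/// Advantage-Weighted Actor-Critic (AWAC) agent.
///
/// The agent couples a Gaussian actor with an ensemble of action-value critics
/// that track soft-updated target networks. Critics regress onto the TD target
/// `y = r + gamma * (1 - done) * min_i Q_tgt_i(s', a')` with `a' ~ pi(.|s')`,
/// while the actor maximizes the advantage-weighted log-likelihood
/// `E[exp(A(s, a) / lambda) * log pi(a|s)]` of the actions stored in the
/// replay buffer (optionally normalizing the weights with a softmax over the
/// batch, see `adv_softmax`).
///
/// A minimal construction sketch; `MyEnv`, `MyBuffer`, `MyCritic`, and
/// `MyActor` are placeholder types standing in for concrete `Env`,
/// `ReplayBufferBase`, and submodel implementations, and a `Default` config
/// is assumed for brevity:
///
/// ```ignore
/// let config = AwacConfig::<MyCritic, MyActor>::default(); // a device must be set
/// let mut agent: Awac<MyEnv, MyCritic, MyActor, MyBuffer> = Awac::build(config);
/// agent.train(); // switch to training mode before calling opt_with_record()
/// ```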
pub struct Awac<E, Q, P, R>
where
    Q: SubModel2<Output = ActionValue>,
    P: SubModel1<Output = (ActMean, ActStd)>,
    Q::Config: DeserializeOwned + Serialize + OutDim + std::fmt::Debug + PartialEq + Clone,
    P::Config: DeserializeOwned + Serialize + OutDim + std::fmt::Debug + PartialEq + Clone,
{
    critic: MultiCritic<Q>,
    actor: GaussianActor<P>,
    /// Discount factor.
    gamma: f64,
    /// Inverse of the temperature `lambda` scaling the advantage.
    inv_lambda: f64,
    n_updates_per_opt: usize,
    batch_size: usize,
    train: bool,
    /// Number of optimization steps performed so far.
    n_opts: usize,
    /// Upper bound on the exponentiated advantage weight.
    exp_adv_max: f64,
    critic_loss: CriticLoss,
    phantom: PhantomData<(E, R)>,
    device: Device,
    /// If true, normalize the weights with a softmax over the batch instead
    /// of clamped exponentiation.
    adv_softmax: bool,
}

impl<E, Q, P, R> Awac<E, Q, P, R>
where
    E: Env,
    Q: SubModel2<Output = ActionValue>,
    P: SubModel1<Output = (ActMean, ActStd)>,
    R: ReplayBufferBase,
    E::Obs: Into<Q::Input1> + Into<P::Input>,
    E::Act: Into<Q::Input2> + Into<Tensor>,
    Q::Input2: From<ActMean> + Into<Tensor>,
    Q::Config: DeserializeOwned + Serialize + OutDim + std::fmt::Debug + PartialEq + Clone,
    P::Config: DeserializeOwned + Serialize + OutDim + std::fmt::Debug + PartialEq + Clone,
    R::Batch: TransitionBatch,
    <R::Batch as TransitionBatch>::ObsBatch: Into<Q::Input1> + Into<P::Input> + Clone,
    <R::Batch as TransitionBatch>::ActBatch: Into<Q::Input2> + Into<Tensor> + Clone,
{
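    /// Performs one gradient step on all critics and soft-updates their targets.
    ///
    /// Every critic regresses onto the shared, detached TD target
    /// `y = r + gamma * (1 - done) * min_i Q_tgt_i(s', a')` with `a' ~ pi(.|s')`.
    /// Returns the summed critic loss and diagnostic means of `|y|`, the
    /// rewards, and the target Q-values.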
    fn update_critic(&mut self, batch: R::Batch) -> Result<(f32, f32, f32, f32)> {
        let (loss, q_tgt_abs_mean, reward_mean, next_q_mean) = {
            let (obs, act, next_obs, reward, is_terminated, is_truncated, _, _) = batch.unpack();
            let batch_size = reward.len();
            let reward = Tensor::from_slice(&reward[..], (batch_size,), &self.device)?;

            let qs = self.critic.qvals(&obs.into(), &act.into());

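            // TD target: y = r + gamma * (1 - done) * min_i Q_tgt_i(s', a'),
            // with a' sampled from the current policy; detached below so that
            // gradients only flow through the online critics.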
            let (tgt, reward, next_q) = {
                let gamma_not_done = gamma_not_done(
                    self.gamma as f32,
                    is_terminated,
                    Some(is_truncated),
                    &self.device,
                )?;
                let next_act = self.actor.sample(&next_obs.clone().into(), self.train)?;
                let next_q = self
                    .critic
                    .qvals_min_tgt(&next_obs.into(), &next_act.into())?;
                let tgt = (&reward + (&gamma_not_done * &next_q)?)?.squeeze(D::Minus1)?;

                (tgt.detach(), reward, next_q)
            };
            debug_assert_eq!(tgt.dims(), [self.batch_size]);

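            // Regress every online critic onto the shared target with the
            // configured loss (MSE or smooth L1).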
            let losses: Vec<_> = match self.critic_loss {
                CriticLoss::Mse => qs.iter().map(|pred| mse(pred, &tgt).unwrap()).collect(),
                CriticLoss::SmoothL1 => qs
                    .iter()
                    .map(|pred| smooth_l1_loss(pred, &tgt).unwrap())
                    .collect(),
            };

            let q_tgt_abs_mean = tgt.abs()?.mean_all()?.to_scalar::<f32>()?;
            let reward_mean = reward.mean_all()?.to_scalar::<f32>()?;
            let next_q_mean = next_q.mean_all()?.to_scalar::<f32>()?;

            (
                Tensor::stack(&losses, 0)?.sum_all()?,
                q_tgt_abs_mean,
                reward_mean,
                next_q_mean,
            )
        };

        self.critic.backward_step(&loss)?;
        self.critic.soft_update()?;

        Ok((
            loss.to_scalar::<f32>()?,
            q_tgt_abs_mean,
            reward_mean,
            next_q_mean,
        ))
    }

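    /// Performs one gradient step on the actor.
    ///
    /// The advantage is estimated as `A(s, a) = min_i Q_i(s, a) - min_i Q_i(s, a~)`,
    /// where `a~ ~ pi(.|s)` provides a one-sample value estimate. The weights are
    /// `w = clamp(exp(A / lambda), 0, exp_adv_max)`, or `softmax(A / lambda)` over
    /// the batch when `adv_softmax` is set, and the loss is `-E[w * log pi(a|s)]`.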
    fn update_actor(&mut self, batch: &R::Batch) -> Result<(f32, f32, f32, f32)> {
        log::trace!("Extract items in the batch");
        let obs = batch.obs().clone();
        let act = batch.act().clone();

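        // Estimate the advantage and turn it into non-negative weights,
        // detached from the computation graph.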
        let (w, adv) = {
            let act_ = self.actor.sample(&obs.clone().into(), self.train)?;
            let q = self
                .critic
                .qvals_min(&obs.clone().into(), &act.clone().into())?;
            let v = self.critic.qvals_min(&obs.clone().into(), &act_.into())?;
            let adv = (&q - &v)?;
            debug_assert_eq!(adv.dims(), &[self.batch_size]);

            let w = match self.adv_softmax {
                false => (&adv * self.inv_lambda)?
                    .exp()?
                    .clamp(0f64, self.exp_adv_max)?,
                true => softmax(&(&adv * self.inv_lambda)?, 0)?,
            }
            .detach();
            (w, adv)
        };
        debug_assert_eq!(w.dims(), &[self.batch_size]);

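        // Advantage-weighted negative log-likelihood of the batch actions:
        // loss = -mean(w * log pi(a|s)).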
        let (loss, logp) = {
            let logp = self.actor.logp(&obs.into(), &act.into())?;
            debug_assert_eq!(logp.dims(), &[self.batch_size]);

            ((-1f64 * &logp * w)?.mean_all()?, logp)
        };

        self.actor.backward_step(&loss)?;

        let loss = loss.to_scalar::<f32>()?;
        let adv_mean = adv.mean_all()?.to_scalar::<f32>()?;
        let adv_abs_mean = adv.abs()?.mean_all()?.to_scalar::<f32>()?;
        let logp_mean = logp.mean_all()?.to_scalar::<f32>()?;

        Ok((loss, adv_mean, adv_abs_mean, logp_mean))
    }

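    /// Runs `n_updates_per_opt` actor and critic updates on freshly sampled
    /// mini-batches and returns the diagnostics, averaged over the updates,
    /// as a `Record`.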
    fn opt_(&mut self, buffer: &mut R) -> Result<Record> {
        let mut loss_critic = 0f32;
        let mut loss_actor = 0f32;
        let mut q_tgt_abs_mean = 0f32;
        let mut adv_mean = 0f32;
        let mut adv_abs_mean = 0f32;
        let mut logp_mean = 0f32;
        let mut reward_mean = 0f32;
        let mut next_q_mean = 0f32;

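        // Each update draws a fresh mini-batch; the actor update runs on the
        // batch before the critic update consumes it.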
        for _ in 0..self.n_updates_per_opt {
            let batch = buffer.batch(self.batch_size).unwrap();
            let (loss_actor_, adv_mean_, adv_abs_mean_, logp_mean_) = self.update_actor(&batch)?;
            loss_actor += loss_actor_;
            adv_mean += adv_mean_;
            adv_abs_mean += adv_abs_mean_;
            logp_mean += logp_mean_;

            let (loss_critic_, q_tgt_abs_mean_, reward_mean_, next_q_mean_) =
                self.update_critic(batch)?;
            loss_critic += loss_critic_;
            q_tgt_abs_mean += q_tgt_abs_mean_;
            reward_mean += reward_mean_;
            next_q_mean += next_q_mean_;
            self.n_opts += 1;
        }

        loss_critic /= self.n_updates_per_opt as f32;
        loss_actor /= self.n_updates_per_opt as f32;
        q_tgt_abs_mean /= self.n_updates_per_opt as f32;
        adv_mean /= self.n_updates_per_opt as f32;
        adv_abs_mean /= self.n_updates_per_opt as f32;
        logp_mean /= self.n_updates_per_opt as f32;
        reward_mean /= self.n_updates_per_opt as f32;
        next_q_mean /= self.n_updates_per_opt as f32;

        let record = Record::from_slice(&[
            ("loss_critic", RecordValue::Scalar(loss_critic)),
            ("loss_actor", RecordValue::Scalar(loss_actor)),
            ("q_tgt_abs_mean", RecordValue::Scalar(q_tgt_abs_mean)),
            ("adv_mean", RecordValue::Scalar(adv_mean)),
            ("adv_abs_mean", RecordValue::Scalar(adv_abs_mean)),
            ("logp_mean", RecordValue::Scalar(logp_mean)),
            ("reward_mean", RecordValue::Scalar(reward_mean)),
            ("next_q_mean", RecordValue::Scalar(next_q_mean)),
        ]);

        Ok(record)
    }
}

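// Inference interface: action selection delegates to the Gaussian actor.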
impl<E, Q, P, R> Policy<E> for Awac<E, Q, P, R>
where
    E: Env,
    Q: SubModel2<Output = ActionValue>,
    P: SubModel1<Output = (ActMean, ActStd)>,
    E::Obs: Into<Q::Input1> + Into<P::Input>,
    E::Act: Into<Q::Input2> + From<Tensor>,
    Q::Config: DeserializeOwned + Serialize + OutDim + std::fmt::Debug + PartialEq + Clone,
    P::Config: DeserializeOwned + Serialize + OutDim + std::fmt::Debug + PartialEq + Clone,
{
    fn sample(&mut self, obs: &E::Obs) -> E::Act {
        self.actor
            .sample(&obs.clone().into(), self.train)
            .unwrap()
            .into()
    }
}

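// Construction from a typed configuration; a device must be specified.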
impl<E, Q, P, R> Configurable for Awac<E, Q, P, R>
where
    E: Env,
    Q: SubModel2<Output = ActionValue>,
    P: SubModel1<Output = (ActMean, ActStd)>,
    E::Obs: Into<Q::Input1> + Into<P::Input>,
    E::Act: Into<Q::Input2> + From<Tensor>,
    Q::Config: DeserializeOwned + Serialize + OutDim + std::fmt::Debug + PartialEq + Clone,
    P::Config: DeserializeOwned + Serialize + OutDim + std::fmt::Debug + PartialEq + Clone,
{
    type Config = AwacConfig<Q, P>;

    fn build(config: Self::Config) -> Self {
        let device: Device = config
            .device
            .expect("No device is given for AWAC agent")
            .into();
        let actor = GaussianActor::build(config.actor_config, device.clone().into()).unwrap();
        let critics = MultiCritic::build(config.critic_config, device.clone().into()).unwrap();

        Awac {
            critic: critics,
            actor,
            gamma: config.gamma,
            n_updates_per_opt: config.n_updates_per_opt,
            batch_size: config.batch_size,
            critic_loss: config.critic_loss,
            inv_lambda: config.inv_lambda,
            exp_adv_max: config.exp_adv_max,
            n_opts: 0,
            train: false,
            device,
            adv_softmax: config.adv_softmax,
            phantom: PhantomData,
        }
    }
}

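// Training lifecycle: mode switching, optimization steps, and parameter I/O.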
impl<E, Q, P, R> Agent<E, R> for Awac<E, Q, P, R>
where
    E: Env + 'static,
    Q: SubModel2<Output = ActionValue> + 'static,
    P: SubModel1<Output = (ActMean, ActStd)> + 'static,
    R: ReplayBufferBase + 'static,
    E::Obs: Into<Q::Input1> + Into<P::Input>,
    E::Act: Into<Q::Input2> + Into<Tensor> + From<Tensor>,
    Q::Input2: From<ActMean> + Into<Tensor>,
    Q::Config: DeserializeOwned + Serialize + OutDim + std::fmt::Debug + PartialEq + Clone,
    P::Config: DeserializeOwned + Serialize + OutDim + std::fmt::Debug + PartialEq + Clone,
    R::Batch: TransitionBatch,
    <R::Batch as TransitionBatch>::ObsBatch: Into<Q::Input1> + Into<P::Input> + Clone,
    <R::Batch as TransitionBatch>::ActBatch: Into<Q::Input2> + Into<Tensor> + Clone,
{
    fn train(&mut self) {
        self.train = true;
    }

    fn eval(&mut self) {
        self.train = false;
    }

    fn is_train(&self) -> bool {
        self.train
    }

    fn opt_with_record(&mut self, buffer: &mut R) -> Record {
        self.opt_(buffer).expect("Failed in Awac::opt_()")
    }

    fn save_params(&self, path: &Path) -> Result<Vec<PathBuf>> {
        fs::create_dir_all(path)?;

        let actor_path = self.actor.save(path.join("actor"))?;
        let (critic_path, critic_tgt_path) = self.critic.save(path.join("critic"))?;

        Ok(vec![actor_path, critic_path, critic_tgt_path])
    }

    fn load_params(&mut self, path: &Path) -> Result<()> {
        self.actor.load(path.join("actor").as_path())?;
        self.critic.load(path.join("critic").as_path())?;

        Ok(())
    }

    fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
        self
    }

    fn as_any_ref(&self) -> &dyn std::any::Any {
        self
    }
}