1use super::{average, IqnConfig, IqnExplorer, IqnModel, IqnSample};
3use crate::{
4 model::{ModelBase, SubModel},
5 util::{quantile_huber_loss, track, OutDim},
6};
7use anyhow::Result;
8use border_core::{
9 record::{Record, RecordValue},
10 Agent, Configurable, Env, Policy, ReplayBufferBase, TransitionBatch,
11};
12use log::trace;
13use serde::{de::DeserializeOwned, Serialize};
14use std::{
15 convert::TryFrom,
16 fs,
17 marker::PhantomData,
18 path::{Path, PathBuf},
19};
20use tch::{no_grad, Device, Tensor};
21
/// IQN (Implicit Quantile Network) agent.
///
/// `F` is a sub-model producing a feature tensor and `M` is a sub-model
/// mapping a tensor to a tensor; together they form the [`IqnModel`].
/// `E` is the environment and `R` the replay buffer type (both only used
/// via `PhantomData` here).
pub struct Iqn<E, F, M, R>
where
    F: SubModel<Output = Tensor>,
    M: SubModel<Input = Tensor, Output = Tensor>,
    F::Config: DeserializeOwned + Serialize,
    M::Config: DeserializeOwned + Serialize,
{
    /// Number of optimization steps between soft updates of the target model.
    pub(in crate::iqn) soft_update_interval: usize,
    /// Counts optimization steps since the last soft update.
    pub(in crate::iqn) soft_update_counter: usize,
    /// Number of critic updates performed per call to `opt_()`.
    pub(in crate::iqn) n_updates_per_opt: usize,
    /// Minibatch size sampled from the replay buffer.
    pub(in crate::iqn) batch_size: usize,
    /// The IQN model being trained.
    pub(in crate::iqn) iqn: IqnModel<F, M>,
    /// Target model used to compute Bellman targets.
    pub(in crate::iqn) iqn_tgt: IqnModel<F, M>,
    /// When `true`, `sample()` uses the explorer; when `false`, it acts greedily.
    pub(in crate::iqn) train: bool,
    pub(in crate::iqn) phantom: PhantomData<(E, R)>,
    /// Discount factor applied to the bootstrapped target value.
    pub(in crate::iqn) discount_factor: f64,
    /// Soft-update coefficient passed to `track()`.
    pub(in crate::iqn) tau: f64,
    /// Percent-point sampler for the prediction quantiles.
    pub(in crate::iqn) sample_percents_pred: IqnSample,
    /// Percent-point sampler for the target quantiles.
    pub(in crate::iqn) sample_percents_tgt: IqnSample,
    /// Percent-point sampler used when acting (see `average()`).
    pub(in crate::iqn) sample_percents_act: IqnSample,
    /// Exploration strategy used during training.
    pub(in crate::iqn) explorer: IqnExplorer,
    /// Device on which tensors and models live.
    pub(in crate::iqn) device: Device,
    /// Total number of optimization steps performed so far.
    pub(in crate::iqn) n_opts: usize,
}
50
51impl<E, F, M, R> Iqn<E, F, M, R>
52where
53 E: Env,
54 F: SubModel<Output = Tensor>,
55 M: SubModel<Input = Tensor, Output = Tensor>,
56 R: ReplayBufferBase,
57 F::Config: DeserializeOwned + Serialize,
58 M::Config: DeserializeOwned + Serialize + OutDim,
59 R::Batch: TransitionBatch,
60 <R::Batch as TransitionBatch>::ObsBatch: Into<F::Input>,
61 <R::Batch as TransitionBatch>::ActBatch: Into<Tensor>,
62{
63 fn update_critic(&mut self, buffer: &mut R) -> f32 {
64 trace!("IQN::update_critic()");
65 let batch = buffer.batch(self.batch_size).unwrap();
66 let (obs, act, next_obs, reward, is_terminated, _is_truncated, _ixs, _weight) =
67 batch.unpack();
68 let obs = obs.into();
69 let act = act.into().to(self.device);
70 let next_obs = next_obs.into();
71 let reward = Tensor::from_slice(&reward[..])
72 .to(self.device)
73 .unsqueeze(-1);
74 let is_terminated = Tensor::from_slice(&is_terminated[..])
75 .to(self.device)
76 .unsqueeze(-1);
77
78 let batch_size = self.batch_size as _;
79 let n_percent_points_pred = self.sample_percents_pred.n_percent_points();
80 let n_percent_points_tgt = self.sample_percents_tgt.n_percent_points();
81
82 debug_assert_eq!(reward.size().as_slice(), &[batch_size, 1]);
83 debug_assert_eq!(is_terminated.size().as_slice(), &[batch_size, 1]);
84 debug_assert_eq!(act.size().as_slice(), &[batch_size, 1]);
85
86 let loss = {
87 let (pred, tau) = {
90 let n_percent_points = n_percent_points_pred;
91
92 let tau = self.sample_percents_pred.sample(batch_size).to(self.device);
94 debug_assert_eq!(tau.size().as_slice(), &[batch_size, n_percent_points]);
95
96 let z = self.iqn.forward(&obs, &tau);
98 let n_actions = z.size()[z.size().len() - 1];
99 debug_assert_eq!(
100 z.size().as_slice(),
101 &[batch_size, n_percent_points, n_actions]
102 );
103
104 let a = act.unsqueeze(1).repeat(&[1, n_percent_points, 1]);
106 debug_assert_eq!(a.size().as_slice(), &[batch_size, n_percent_points, 1]);
107
108 let pred = z.gather(-1, &a, false).squeeze_dim(-1).unsqueeze(1);
110 debug_assert_eq!(pred.size().as_slice(), &[batch_size, 1, n_percent_points]);
111 (pred, tau)
112 };
113
114 let tgt = no_grad(|| {
118 let n_percent_points = n_percent_points_tgt;
119
120 let tau = self.sample_percents_tgt.sample(batch_size).to(self.device);
122 debug_assert_eq!(tau.size().as_slice(), &[batch_size, n_percent_points]);
123
124 let z = self.iqn_tgt.forward(&next_obs, &tau);
126 let n_actions = z.size()[z.size().len() - 1];
127 debug_assert_eq!(
128 z.size().as_slice(),
129 &[batch_size, n_percent_points, n_actions]
130 );
131
132 let y = z
134 .copy()
135 .mean_dim(Some([1].as_slice()), false, tch::Kind::Float);
136 let a = y.argmax(-1, false).unsqueeze(-1).unsqueeze(-1).repeat(&[
137 1,
138 n_percent_points,
139 1,
140 ]);
141 debug_assert_eq!(a.size(), &[batch_size, n_percent_points, 1]);
142
143 let z = z.gather(2, &a, false).squeeze_dim(-1);
145 debug_assert_eq!(z.size().as_slice(), &[batch_size, n_percent_points]);
146
147 let tgt: Tensor = reward + (1 - is_terminated) * self.discount_factor * z;
149 debug_assert_eq!(tgt.size().as_slice(), &[batch_size, n_percent_points]);
150
151 tgt.unsqueeze(-1)
152 });
153
154 let diff = tgt - pred;
155 debug_assert_eq!(
156 diff.size().as_slice(),
157 &[batch_size, n_percent_points_tgt, n_percent_points_pred]
158 );
159 let tau = tau.unsqueeze(1).repeat(&[1, n_percent_points_tgt, 1]);
163
164 quantile_huber_loss(&diff, &tau).mean(tch::Kind::Float)
165 };
166
167 self.iqn.backward_step(&loss);
168
169 f32::try_from(loss).expect("Failed to convert Tensor to f32")
170 }
171
172 fn opt_(&mut self, buffer: &mut R) -> Record {
173 let mut loss_critic = 0f32;
174
175 for _ in 0..self.n_updates_per_opt {
176 let loss = self.update_critic(buffer);
177 loss_critic += loss;
178 }
179
180 self.soft_update_counter += 1;
181 if self.soft_update_counter == self.soft_update_interval {
182 self.soft_update_counter = 0;
183 track(&mut self.iqn_tgt, &mut self.iqn, self.tau);
184 }
185
186 loss_critic /= self.n_updates_per_opt as f32;
187
188 self.n_opts += 1;
189
190 Record::from_slice(&[("loss_critic", RecordValue::Scalar(loss_critic))])
191 }
192}
193
194impl<E, F, M, R> Policy<E> for Iqn<E, F, M, R>
195where
196 E: Env,
197 F: SubModel<Output = Tensor>,
198 M: SubModel<Input = Tensor, Output = Tensor>,
199 E::Obs: Into<F::Input>,
200 E::Act: From<Tensor>,
201 F::Config: DeserializeOwned + Serialize + Clone,
202 M::Config: DeserializeOwned + Serialize + Clone + OutDim,
203{
204 fn sample(&mut self, obs: &E::Obs) -> E::Act {
205 let batch_size = 1;
207
208 let a = no_grad(|| {
209 let action_value = average(
210 batch_size,
211 &obs.clone().into(),
212 &self.iqn,
213 &self.sample_percents_act,
214 self.device,
215 );
216
217 if self.train {
218 match &mut self.explorer {
219 IqnExplorer::Softmax(softmax) => softmax.action(&action_value),
220 IqnExplorer::EpsilonGreedy(egreedy) => egreedy.action(action_value),
221 }
222 } else {
223 action_value.argmax(-1, true)
224 }
225 });
226
227 a.into()
228 }
229}
230
231impl<E, F, M, R> Configurable for Iqn<E, F, M, R>
232where
233 E: Env,
234 F: SubModel<Output = Tensor>,
235 M: SubModel<Input = Tensor, Output = Tensor>,
236 E::Obs: Into<F::Input>,
237 E::Act: From<Tensor>,
238 F::Config: DeserializeOwned + Serialize + Clone,
239 M::Config: DeserializeOwned + Serialize + Clone + OutDim,
240{
241 type Config = IqnConfig<F, M>;
242
243 fn build(config: Self::Config) -> Self {
245 let device = config
246 .device
247 .expect("No device is given for IQN agent")
248 .into();
249 let iqn = IqnModel::build(config.model_config, device).unwrap();
250 let iqn_tgt = iqn.clone();
251
252 Iqn {
253 iqn,
254 iqn_tgt,
255 soft_update_interval: config.soft_update_interval,
256 soft_update_counter: 0,
257 n_updates_per_opt: config.n_updates_per_opt,
258 batch_size: config.batch_size,
259 discount_factor: config.discount_factor,
260 tau: config.tau,
261 sample_percents_pred: config.sample_percents_pred,
262 sample_percents_tgt: config.sample_percents_tgt,
263 sample_percents_act: config.sample_percents_act,
264 train: config.train,
265 explorer: config.explorer,
266 device,
267 n_opts: 0,
268 phantom: PhantomData,
269 }
270 }
271}
272
273impl<E, F, M, R> Agent<E, R> for Iqn<E, F, M, R>
274where
275 E: Env + 'static,
276 F: SubModel<Output = Tensor> + 'static,
277 M: SubModel<Input = Tensor, Output = Tensor> + 'static,
278 R: ReplayBufferBase + 'static,
279 E::Obs: Into<F::Input>,
280 E::Act: From<Tensor>,
281 F::Config: DeserializeOwned + Serialize + Clone,
282 M::Config: DeserializeOwned + Serialize + Clone + OutDim,
283 R::Batch: TransitionBatch,
284 <R::Batch as TransitionBatch>::ObsBatch: Into<F::Input>,
285 <R::Batch as TransitionBatch>::ActBatch: Into<Tensor>,
286{
287 fn train(&mut self) {
288 self.train = true;
289 }
290
291 fn eval(&mut self) {
292 self.train = false;
293 }
294
295 fn is_train(&self) -> bool {
296 self.train
297 }
298
299 fn opt_with_record(&mut self, buffer: &mut R) -> Record {
300 self.opt_(buffer)
301 }
302
303 fn save_params(&self, path: &Path) -> Result<Vec<PathBuf>> {
304 fs::create_dir_all(&path)?;
306 let path1 = path.join("iqn.pt.tch").to_path_buf();
307 let path2 = path.join("iqn_tgt.pt.tch").to_path_buf();
308 self.iqn.save(&path1)?;
309 self.iqn_tgt.save(&path2)?;
310 Ok(vec![path1, path2])
311 }
312
313 fn load_params(&mut self, path: &Path) -> Result<()> {
314 self.iqn.load(path.join("iqn.pt.tch").as_path())?;
315 self.iqn_tgt.load(path.join("iqn_tgt.pt.tch").as_path())?;
316 Ok(())
317 }
318
319 fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
320 self
321 }
322
323 fn as_any_ref(&self) -> &dyn std::any::Any {
324 self
325 }
326}