// border_tch_agent/iqn/base.rs

//! IQN agent implemented with tch-rs.
use super::{average, IqnConfig, IqnExplorer, IqnModel, IqnSample};
use crate::{
    model::{ModelBase, SubModel},
    util::{quantile_huber_loss, track, OutDim},
};
use anyhow::Result;
use border_core::{
    record::{Record, RecordValue},
    Agent, Configurable, Env, Policy, ReplayBufferBase, TransitionBatch,
};
use log::trace;
use serde::{de::DeserializeOwned, Serialize};
use std::{
    convert::TryFrom,
    fs,
    marker::PhantomData,
    path::{Path, PathBuf},
};
use tch::{no_grad, Device, Tensor};
21
/// IQN agent implemented with tch-rs.
///
/// The type parameter `M` is a feature extractor, which takes
/// `M::Input` and returns feature vectors.
pub struct Iqn<E, F, M, R>
where
    F: SubModel<Output = Tensor>,
    M: SubModel<Input = Tensor, Output = Tensor>,
    F::Config: DeserializeOwned + Serialize,
    M::Config: DeserializeOwned + Serialize,
{
    /// Number of optimization rounds between soft updates of the target model.
    pub(in crate::iqn) soft_update_interval: usize,
    /// Optimization rounds elapsed since the last soft update.
    pub(in crate::iqn) soft_update_counter: usize,
    /// Number of critic updates performed per optimization round.
    pub(in crate::iqn) n_updates_per_opt: usize,
    /// Minibatch size drawn from the replay buffer.
    pub(in crate::iqn) batch_size: usize,
    /// Main (online) IQN model.
    pub(in crate::iqn) iqn: IqnModel<F, M>,
    /// Target IQN model, tracked from `iqn` via soft updates.
    pub(in crate::iqn) iqn_tgt: IqnModel<F, M>,
    /// `true` while the agent is in training mode (exploration enabled).
    pub(in crate::iqn) train: bool,
    // Marks ownership of the otherwise-unused `E` and `R` type parameters.
    pub(in crate::iqn) phantom: PhantomData<(E, R)>,
    /// Discount factor applied to bootstrapped target values.
    pub(in crate::iqn) discount_factor: f64,
    /// Soft-update coefficient handed to `track()` when syncing the target model.
    pub(in crate::iqn) tau: f64,
    /// Percent-point sampler used for predictions in the critic loss.
    pub(in crate::iqn) sample_percents_pred: IqnSample,
    /// Percent-point sampler used for target values in the critic loss.
    pub(in crate::iqn) sample_percents_tgt: IqnSample,
    /// Percent-point sampler used when selecting actions.
    pub(in crate::iqn) sample_percents_act: IqnSample,
    /// Exploration strategy applied in training mode.
    pub(in crate::iqn) explorer: IqnExplorer,
    /// Device on which tensors are placed.
    pub(in crate::iqn) device: Device,
    /// Total number of optimization rounds performed so far.
    pub(in crate::iqn) n_opts: usize,
}
50
impl<E, F, M, R> Iqn<E, F, M, R>
where
    E: Env,
    F: SubModel<Output = Tensor>,
    M: SubModel<Input = Tensor, Output = Tensor>,
    R: ReplayBufferBase,
    F::Config: DeserializeOwned + Serialize,
    M::Config: DeserializeOwned + Serialize + OutDim,
    R::Batch: TransitionBatch,
    <R::Batch as TransitionBatch>::ObsBatch: Into<F::Input>,
    <R::Batch as TransitionBatch>::ActBatch: Into<Tensor>,
{
    /// Runs one critic update on a minibatch sampled from `buffer` and
    /// returns the scalar quantile Huber loss.
    ///
    /// Panics if the buffer cannot produce a batch of `self.batch_size`
    /// transitions (the `unwrap` below).
    fn update_critic(&mut self, buffer: &mut R) -> f32 {
        trace!("IQN::update_critic()");
        let batch = buffer.batch(self.batch_size).unwrap();
        // Prioritized-replay indices and weights are currently unused.
        let (obs, act, next_obs, reward, is_terminated, _is_truncated, _ixs, _weight) =
            batch.unpack();
        let obs = obs.into();
        let act = act.into().to(self.device);
        let next_obs = next_obs.into();
        // Reshape reward / termination flags to [batch_size, 1] so they
        // broadcast against the per-quantile target values below.
        let reward = Tensor::from_slice(&reward[..])
            .to(self.device)
            .unsqueeze(-1);
        let is_terminated = Tensor::from_slice(&is_terminated[..])
            .to(self.device)
            .unsqueeze(-1);

        let batch_size = self.batch_size as _;
        let n_percent_points_pred = self.sample_percents_pred.n_percent_points();
        let n_percent_points_tgt = self.sample_percents_tgt.n_percent_points();

        debug_assert_eq!(reward.size().as_slice(), &[batch_size, 1]);
        debug_assert_eq!(is_terminated.size().as_slice(), &[batch_size, 1]);
        debug_assert_eq!(act.size().as_slice(), &[batch_size, 1]);

        let loss = {
            // predictions of z(s, a), where a is from minibatch
            // pred.size() == [batch_size, 1, n_percent_points]
            let (pred, tau) = {
                let n_percent_points = n_percent_points_pred;

                // percent points
                let tau = self.sample_percents_pred.sample(batch_size).to(self.device);
                debug_assert_eq!(tau.size().as_slice(), &[batch_size, n_percent_points]);

                // predictions for all actions
                let z = self.iqn.forward(&obs, &tau);
                let n_actions = z.size()[z.size().len() - 1];
                debug_assert_eq!(
                    z.size().as_slice(),
                    &[batch_size, n_percent_points, n_actions]
                );

                // Reshape action for applying torch.gather
                let a = act.unsqueeze(1).repeat(&[1, n_percent_points, 1]);
                debug_assert_eq!(a.size().as_slice(), &[batch_size, n_percent_points, 1]);

                // takes z(s, a) with a from minibatch
                let pred = z.gather(-1, &a, false).squeeze_dim(-1).unsqueeze(1);
                debug_assert_eq!(pred.size().as_slice(), &[batch_size, 1, n_percent_points]);
                (pred, tau)
            };

            // target values with max_a q(s, a)
            // tgt.size() == [batch_size, n_percent_points, 1]
            // in theory, n_percent_points can be different with that for predictions
            let tgt = no_grad(|| {
                let n_percent_points = n_percent_points_tgt;

                // percent points
                let tau = self.sample_percents_tgt.sample(batch_size).to(self.device);
                debug_assert_eq!(tau.size().as_slice(), &[batch_size, n_percent_points]);

                // target values for all actions
                let z = self.iqn_tgt.forward(&next_obs, &tau);
                let n_actions = z.size()[z.size().len() - 1];
                debug_assert_eq!(
                    z.size().as_slice(),
                    &[batch_size, n_percent_points, n_actions]
                );

                // argmax_a z(s,a), where z are averaged over tau
                let y = z
                    .copy()
                    .mean_dim(Some([1].as_slice()), false, tch::Kind::Float);
                let a = y.argmax(-1, false).unsqueeze(-1).unsqueeze(-1).repeat(&[
                    1,
                    n_percent_points,
                    1,
                ]);
                debug_assert_eq!(a.size(), &[batch_size, n_percent_points, 1]);

                // takes z(s, a)
                let z = z.gather(2, &a, false).squeeze_dim(-1);
                debug_assert_eq!(z.size().as_slice(), &[batch_size, n_percent_points]);

                // target value; the (1 - is_terminated) mask zeroes the
                // bootstrap term at episode termination
                let tgt: Tensor = reward + (1 - is_terminated) * self.discount_factor * z;
                debug_assert_eq!(tgt.size().as_slice(), &[batch_size, n_percent_points]);

                tgt.unsqueeze(-1)
            });

            // Pairwise TD errors: broadcasting [B, T, 1] - [B, 1, P] yields
            // every (target quantile, predicted quantile) combination.
            let diff = tgt - pred;
            debug_assert_eq!(
                diff.size().as_slice(),
                &[batch_size, n_percent_points_tgt, n_percent_points_pred]
            );
            // need to convert diff to vec<f32>
            // buffer.update_priority(&ixs, &Some(diff));

            let tau = tau.unsqueeze(1).repeat(&[1, n_percent_points_tgt, 1]);

            quantile_huber_loss(&diff, &tau).mean(tch::Kind::Float)
        };

        self.iqn.backward_step(&loss);

        f32::try_from(loss).expect("Failed to convert Tensor to f32")
    }

    /// One optimization round: performs `n_updates_per_opt` critic updates,
    /// soft-updates the target model every `soft_update_interval` rounds,
    /// and returns the average critic loss as a [`Record`].
    fn opt_(&mut self, buffer: &mut R) -> Record {
        let mut loss_critic = 0f32;

        for _ in 0..self.n_updates_per_opt {
            let loss = self.update_critic(buffer);
            loss_critic += loss;
        }

        self.soft_update_counter += 1;
        if self.soft_update_counter == self.soft_update_interval {
            self.soft_update_counter = 0;
            // Blend target-model parameters towards the online model by `tau`.
            track(&mut self.iqn_tgt, &mut self.iqn, self.tau);
        }

        loss_critic /= self.n_updates_per_opt as f32;

        self.n_opts += 1;

        Record::from_slice(&[("loss_critic", RecordValue::Scalar(loss_critic))])
    }
}
193
194impl<E, F, M, R> Policy<E> for Iqn<E, F, M, R>
195where
196    E: Env,
197    F: SubModel<Output = Tensor>,
198    M: SubModel<Input = Tensor, Output = Tensor>,
199    E::Obs: Into<F::Input>,
200    E::Act: From<Tensor>,
201    F::Config: DeserializeOwned + Serialize + Clone,
202    M::Config: DeserializeOwned + Serialize + Clone + OutDim,
203{
204    fn sample(&mut self, obs: &E::Obs) -> E::Act {
205        // Do not support vectorized env
206        let batch_size = 1;
207
208        let a = no_grad(|| {
209            let action_value = average(
210                batch_size,
211                &obs.clone().into(),
212                &self.iqn,
213                &self.sample_percents_act,
214                self.device,
215            );
216
217            if self.train {
218                match &mut self.explorer {
219                    IqnExplorer::Softmax(softmax) => softmax.action(&action_value),
220                    IqnExplorer::EpsilonGreedy(egreedy) => egreedy.action(action_value),
221                }
222            } else {
223                action_value.argmax(-1, true)
224            }
225        });
226
227        a.into()
228    }
229}
230
231impl<E, F, M, R> Configurable for Iqn<E, F, M, R>
232where
233    E: Env,
234    F: SubModel<Output = Tensor>,
235    M: SubModel<Input = Tensor, Output = Tensor>,
236    E::Obs: Into<F::Input>,
237    E::Act: From<Tensor>,
238    F::Config: DeserializeOwned + Serialize + Clone,
239    M::Config: DeserializeOwned + Serialize + Clone + OutDim,
240{
241    type Config = IqnConfig<F, M>;
242
243    /// Constructs [`Iqn`] agent.
244    fn build(config: Self::Config) -> Self {
245        let device = config
246            .device
247            .expect("No device is given for IQN agent")
248            .into();
249        let iqn = IqnModel::build(config.model_config, device).unwrap();
250        let iqn_tgt = iqn.clone();
251
252        Iqn {
253            iqn,
254            iqn_tgt,
255            soft_update_interval: config.soft_update_interval,
256            soft_update_counter: 0,
257            n_updates_per_opt: config.n_updates_per_opt,
258            batch_size: config.batch_size,
259            discount_factor: config.discount_factor,
260            tau: config.tau,
261            sample_percents_pred: config.sample_percents_pred,
262            sample_percents_tgt: config.sample_percents_tgt,
263            sample_percents_act: config.sample_percents_act,
264            train: config.train,
265            explorer: config.explorer,
266            device,
267            n_opts: 0,
268            phantom: PhantomData,
269        }
270    }
271}
272
273impl<E, F, M, R> Agent<E, R> for Iqn<E, F, M, R>
274where
275    E: Env + 'static,
276    F: SubModel<Output = Tensor> + 'static,
277    M: SubModel<Input = Tensor, Output = Tensor> + 'static,
278    R: ReplayBufferBase + 'static,
279    E::Obs: Into<F::Input>,
280    E::Act: From<Tensor>,
281    F::Config: DeserializeOwned + Serialize + Clone,
282    M::Config: DeserializeOwned + Serialize + Clone + OutDim,
283    R::Batch: TransitionBatch,
284    <R::Batch as TransitionBatch>::ObsBatch: Into<F::Input>,
285    <R::Batch as TransitionBatch>::ActBatch: Into<Tensor>,
286{
287    fn train(&mut self) {
288        self.train = true;
289    }
290
291    fn eval(&mut self) {
292        self.train = false;
293    }
294
295    fn is_train(&self) -> bool {
296        self.train
297    }
298
299    fn opt_with_record(&mut self, buffer: &mut R) -> Record {
300        self.opt_(buffer)
301    }
302
303    fn save_params(&self, path: &Path) -> Result<Vec<PathBuf>> {
304        // TODO: consider to rename the path if it already exists
305        fs::create_dir_all(&path)?;
306        let path1 = path.join("iqn.pt.tch").to_path_buf();
307        let path2 = path.join("iqn_tgt.pt.tch").to_path_buf();
308        self.iqn.save(&path1)?;
309        self.iqn_tgt.save(&path2)?;
310        Ok(vec![path1, path2])
311    }
312
313    fn load_params(&mut self, path: &Path) -> Result<()> {
314        self.iqn.load(path.join("iqn.pt.tch").as_path())?;
315        self.iqn_tgt.load(path.join("iqn_tgt.pt.tch").as_path())?;
316        Ok(())
317    }
318
319    fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
320        self
321    }
322
323    fn as_any_ref(&self) -> &dyn std::any::Any {
324        self
325    }
326}