border_atari_env/
obs.rs

1//! Observation of [`BorderAtariEnv`](super::BorderAtariEnv).
2//!
3//! It applies the following preprocessing
4//! (explanations are adapted from [Stable Baselines](https://stable-baselines3.readthedocs.io/en/master/common/atari_wrappers.html#stable_baselines3.common.atari_wrappers.AtariWrapper)
5//! API document):
//! * (WIP: NoopReset: obtain the initial state by taking a random number of no-ops on reset.)
7//! * Four frames skipping
8//! * Max pooling: most recent two observations
9//! * Resize to 84 x 84
10//! * Grayscale
11//! * Clip reward to {-1, 0, 1} in training
12//! * Stacking four frames
//!
//! It does not scale pixel values from 255 to 1.0, in order to save memory in the replay buffer.
//! Instead, the scaling is applied in the CNN model.
15use anyhow::Result;
16use border_core::{record::Record, Obs};
17use serde::{Deserialize, Serialize};
18use std::{default::Default, marker::PhantomData};
19
/// Raw observation emitted by [`BorderAtariEnv`](super::BorderAtariEnv).
#[derive(Clone, Debug)]
pub struct BorderAtariObs {
    /// Pixel data of four stacked frames, each 84 * 84 pixels.
    pub frames: Vec<u8>,
}
26
27impl From<Vec<u8>> for BorderAtariObs {
28    fn from(frames: Vec<u8>) -> Self {
29        Self { frames }
30    }
31}
32
33impl Obs for BorderAtariObs {
34    fn len(&self) -> usize {
35        1
36    }
37}
38
/// Conversions of [`BorderAtariObs`] into `tch` tensor types.
///
/// Compiled only when the `tch` feature is enabled.
#[cfg(feature = "tch")]
pub mod tch_ {
    use super::*;
    use border_tch_agent::TensorBatch;
    use tch::Tensor;

    impl From<BorderAtariObs> for Tensor {
        /// Builds a float tensor of shape `[1, 4, 1, 84, 84]` from the frame bytes.
        ///
        /// The leading dimension is the batch size, which is fixed to 1 because
        /// the environment is assumed to be non-vectorized.
        fn from(obs: BorderAtariObs) -> Tensor {
            let pixels = Tensor::from_slice(&obs.frames);
            let shaped = pixels.reshape(&[1, 4, 1, 84, 84]);
            // Pixel values are only cast to float here, not rescaled.
            shaped.to_kind(tch::Kind::Float)
        }
    }

    impl From<BorderAtariObs> for TensorBatch {
        /// Converts the observation into a tensor and wraps it as a batch.
        fn from(obs: BorderAtariObs) -> Self {
            TensorBatch::from_tensor(Tensor::from(obs))
        }
    }
}
61
/// Conversions of [`BorderAtariObs`] into `candle` tensor types.
///
/// Compiled only when the `candle` feature is enabled.
// NOTE(review): the module name looks like a misspelling of `candle`; it is
// kept as is because renaming a public module would change its public path.
#[cfg(feature = "candle")]
pub mod canadle {
    use super::*;
    use border_candle_agent::TensorBatch;
    use candle_core::{Device::Cpu, Tensor};

    impl From<BorderAtariObs> for Tensor {
        /// Moves the frame bytes into a CPU tensor of shape `[1, 4, 1, 84, 84]`.
        ///
        /// The element type stays `u8`; no cast or scaling is applied here.
        ///
        /// # Panics
        ///
        /// Panics if `obs.frames` does not hold exactly `4 * 84 * 84` bytes.
        fn from(obs: BorderAtariObs) -> Tensor {
            // Assumes the batch size is 1, implying non-vectorized environment.
            // The final shape is given directly, so no separate reshape is needed.
            Tensor::from_vec(obs.frames, &[1, 4, 1, 84, 84], &Cpu)
                .expect("BorderAtariObs must hold 4 * 84 * 84 bytes")
        }
    }

    impl From<BorderAtariObs> for TensorBatch {
        /// Converts the observation into a tensor and wraps it as a batch.
        fn from(obs: BorderAtariObs) -> Self {
            TensorBatch::from_tensor(Tensor::from(obs))
        }
    }
}
86
87/// Converts [`BorderAtariObs`] to observation of type `O` with an arbitrary processing.
88pub trait BorderAtariObsFilter<O: Obs> {
89    /// Configuration of the filter.
90    type Config: Clone + Default;
91
92    /// Constructs the filter given a configuration.
93    fn build(config: &Self::Config) -> Result<Self>
94    where
95        Self: Sized;
96
97    /// Converts the original observation into `O`.
98    fn filt(&mut self, obs: BorderAtariObs) -> (O, Record);
99
100    /// Resets the filter.
101    fn reset(&mut self, obs: BorderAtariObs) -> O {
102        let (obs, _) = self.filt(obs);
103        obs
104    }
105}
106
107#[derive(Serialize, Deserialize, Debug)]
108/// Configuration of [`BorderAtariObsRawFilter`].
109#[derive(Clone)]
110pub struct BorderAtariObsRawFilterConfig;
111
112impl Default for BorderAtariObsRawFilterConfig {
113    fn default() -> Self {
114        Self
115    }
116}
117
/// A pass-through filter: it performs no processing of its own.
pub struct BorderAtariObsRawFilter<O> {
    // Zero-sized marker that ties the filter to its output type `O`.
    phantom: PhantomData<O>,
}
122
123impl<O> BorderAtariObsFilter<O> for BorderAtariObsRawFilter<O>
124where
125    O: Obs + From<BorderAtariObs>,
126{
127    type Config = BorderAtariObsRawFilterConfig;
128
129    fn build(_config: &Self::Config) -> Result<Self> {
130        Ok(Self {
131            phantom: PhantomData,
132        })
133    }
134
135    fn filt(&mut self, obs: BorderAtariObs) -> (O, Record) {
136        (obs.into(), Record::empty())
137    }
138}