1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107
//! Observation of [`BorderAtariEnv`](super::BorderAtariEnv).
//!
//! It applies the following preprocessing
//! (explanations are adapted from [Stable Baselines](https://stable-baselines3.readthedocs.io/en/master/common/atari_wrappers.html#stable_baselines3.common.atari_wrappers.AtariWrapper)
//! API documentation):
//! * (WIP: NoopReset: obtain initial state by taking random number of no-ops on reset.)
//! * Four frames skipping
//! * Max pooling: most recent two observations
//! * Resize to 84 x 84
//! * Grayscale
//! * Clip reward to {-1, 0, 1} in training
//! * Stacking four frames
//!
//! It does not scale pixel values from 255 to 1.0, in order to save memory in the replay buffer.
//! Instead, the scaling is applied in the CNN model.
use anyhow::Result;
use border_core::{record::Record, Obs};
use serde::{Deserialize, Serialize};
use std::{default::Default, marker::PhantomData};
#[cfg(feature = "tch")]
use {std::convert::TryFrom, tch::Tensor};
/// Observation type of [`BorderAtariEnv`](super::BorderAtariEnv).
///
/// The pixel data is kept as raw bytes in a single flat buffer.
#[derive(Clone, Debug)]
pub struct BorderAtariObs {
    /// Pixel data of four frames, each 84 * 84 pixels.
    pub frames: Vec<u8>,
}
impl From<Vec<u8>> for BorderAtariObs {
fn from(frames: Vec<u8>) -> Self {
Self { frames }
}
}
/// [`Obs`] implementation for [`BorderAtariObs`].
impl Obs for BorderAtariObs {
    /// Returns a placeholder observation whose pixel buffer is filled with zeros.
    ///
    /// The argument is ignored: the buffer always holds `4 * 84 * 84` bytes.
    fn dummy(_n: usize) -> Self {
        // Four frames of 84 x 84 pixels, one byte per pixel.
        const BUFFER_LEN: usize = 4 * 84 * 84;
        let frames = vec![0u8; BUFFER_LEN];
        Self { frames }
    }

    /// Always returns 1, independent of the buffer contents.
    fn len(&self) -> usize {
        1
    }
}
#[cfg(feature = "tch")]
impl From<BorderAtariObs> for Tensor {
    /// Converts the frame buffer into a tensor reshaped to `[1, 4, 1, 84, 84]`.
    ///
    /// # Panics
    ///
    /// Panics if the buffer cannot be converted into a tensor, because the
    /// conversion result is unwrapped.
    fn from(obs: BorderAtariObs) -> Tensor {
        let flat = Tensor::try_from(&obs.frames).unwrap();
        // The leading dimension assumes a batch size of 1, implying a non-vectorized environment.
        flat.reshape(&[1, 4, 1, 84, 84])
    }
}
/// Transforms a [`BorderAtariObs`] into an observation of type `O` with arbitrary processing.
pub trait BorderAtariObsFilter<O: Obs> {
    /// Configuration from which the filter is built.
    type Config: Clone + Default;

    /// Builds the filter from `config`.
    fn build(config: &Self::Config) -> Result<Self>
    where
        Self: Sized;

    /// Converts the original observation `obs` into `O`, returned together with a [`Record`].
    fn filt(&mut self, obs: BorderAtariObs) -> (O, Record);

    /// Resets the filter with `obs` and returns the converted observation.
    ///
    /// The default implementation delegates to [`filt`](Self::filt) and discards the record.
    fn reset(&mut self, obs: BorderAtariObs) -> O {
        self.filt(obs).0
    }
}
/// Configuration of [`BorderAtariObsRawFilter`].
///
/// This is a unit struct: it carries no settings, and its default value is the
/// struct itself. `Default` is derived rather than hand-written, which is
/// equivalent for a unit struct.
#[derive(Clone, Debug, Default, Deserialize, Serialize)]
pub struct BorderAtariObsRawFilterConfig;
/// A filter without any processing.
///
/// The type parameter `O` is the observation type the filter is associated with.
pub struct BorderAtariObsRawFilter<O> {
    /// Zero-sized marker that ties the filter to `O`; no value of `O` is stored.
    phantom: PhantomData<O>,
}
/// Pass-through filter: the only work done is the `From<BorderAtariObs>` conversion into `O`.
impl<O> BorderAtariObsFilter<O> for BorderAtariObsRawFilter<O>
where
    O: Obs + From<BorderAtariObs>,
{
    type Config = BorderAtariObsRawFilterConfig;

    /// Creates the filter. The configuration is not consulted, and this always returns `Ok`.
    fn build(_config: &Self::Config) -> Result<Self> {
        let filter = Self {
            phantom: PhantomData,
        };
        Ok(filter)
    }

    /// Converts `obs` into `O` and pairs it with an empty [`Record`].
    fn filt(&mut self, obs: BorderAtariObs) -> (O, Record) {
        let converted = O::from(obs);
        (converted, Record::empty())
    }
}