1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
//! Observation of [BorderAtariEnv](super::BorderAtariEnv).
//!
//! The following preprocessing is applied
//! (descriptions are adapted from the
//! [Stable Baselines](https://stable-baselines3.readthedocs.io/en/master/common/atari_wrappers.html#stable_baselines3.common.atari_wrappers.AtariWrapper)
//! API documentation):
//! * (WIP: NoopReset: obtain the initial state by taking a random number of no-ops on reset.)
//! * Frame skipping (four frames)
//! * Max pooling over the two most recent observations
//! * Resizing to 84 x 84
//! * Grayscale conversion
//! * Reward clipping to {-1, 0, 1} during training
//! * Stacking of four frames
//!
//! Pixel values are not rescaled from 255 to 1.0 here, which saves memory in the replay buffer;
//! the scaling is applied in the CNN model instead.
use anyhow::Result;
use border_core::{record::Record, Obs};
use serde::{Deserialize, Serialize};
use std::{default::Default, marker::PhantomData};
#[cfg(feature = "tch")]
use {std::convert::TryFrom, tch::Tensor};

/// Observation produced by [BorderAtariEnv](super::BorderAtariEnv).
///
/// Holds the pixel data of a single (non-vectorized) observation.
#[derive(Clone, Debug)]
pub struct BorderAtariObs {
    /// Pixel buffer of four stacked 84 x 84 frames, one byte per pixel.
    pub frames: Vec<u8>,
}

impl From<Vec<u8>> for BorderAtariObs {
    fn from(frames: Vec<u8>) -> Self {
        Self { frames }
    }
}

impl Obs for BorderAtariObs {
    fn dummy(_n: usize) -> Self {
        Self {
            frames: vec![0; 4 * 84 * 84],
        }
    }

    fn merge(self, _obs_reset: Self, _is_done: &[i8]) -> Self {
        unimplemented!();
    }

    fn len(&self) -> usize {
        1
    }
}

#[cfg(feature = "tch")]
impl From<BorderAtariObs> for Tensor {
    /// Converts the pixel buffer into a tensor reshaped to `[1, 4, 1, 84, 84]`.
    ///
    /// The leading dimension of 1 assumes a batch size of one, i.e. a non-vectorized
    /// environment.
    ///
    /// # Panics
    ///
    /// Panics if the buffer cannot be converted into a tensor, or if the reshape fails
    /// (e.g. the buffer does not hold exactly `4 * 84 * 84` elements).
    fn from(obs: BorderAtariObs) -> Tensor {
        let flat = Tensor::try_from(&obs.frames).unwrap();
        flat.reshape(&[1, 4, 1, 84, 84])
    }
}

/// Turns a raw [BorderAtariObs] into an observation of type `O`.
///
/// Implementations may apply arbitrary processing and report information about it
/// through the returned [Record].
pub trait BorderAtariObsFilter<O: Obs> {
    /// Configuration type consumed by [BorderAtariObsFilter::build].
    type Config: Clone + Default;

    /// Creates a filter from `config`.
    ///
    /// # Errors
    ///
    /// Implementation-defined; returned when the filter cannot be constructed.
    fn build(config: &Self::Config) -> Result<Self>
    where
        Self: Sized;

    /// Converts `obs` into `O`, paired with a [Record] describing the conversion.
    fn filt(&mut self, obs: BorderAtariObs) -> (O, Record);

    /// Resets the filter with `obs` and returns the converted observation.
    ///
    /// The default implementation simply forwards to [BorderAtariObsFilter::filt] and
    /// drops the accompanying [Record]; implementors holding internal state should
    /// override it.
    fn reset(&mut self, obs: BorderAtariObs) -> O {
        self.filt(obs).0
    }
}

/// Configuration of [BorderAtariObsRawFilter].
///
/// The raw filter has no settings, so this is a unit struct. It derives the serde
/// traits so it can be (de)serialized, and `Default` so it satisfies the
/// `Clone + Default` bound on [BorderAtariObsFilter::Config].
// `Default` is derived rather than hand-written: for a unit struct the derived impl is
// exactly `Self` (clippy `derivable_impls`). All derives are kept in one attribute,
// placed after the doc comment.
#[derive(Serialize, Deserialize, Debug, Clone, Default, PartialEq, Eq)]
pub struct BorderAtariObsRawFilterConfig;

/// Pass-through filter: hands the observation over without any processing.
///
/// `O` is the target observation type; it is tracked only at the type level.
pub struct BorderAtariObsRawFilter<O> {
    phantom: PhantomData<O>,
}

impl<O> BorderAtariObsFilter<O> for BorderAtariObsRawFilter<O>
where
    O: Obs + From<BorderAtariObs>,
{
    type Config = BorderAtariObsRawFilterConfig;

    fn build(_config: &Self::Config) -> Result<Self> {
        Ok(Self {
            phantom: PhantomData,
        })
    }

    fn filt(&mut self, obs: BorderAtariObs) -> (O, Record) {
        (obs.into(), Record::empty())
    }
}