use anyhow::Result;
use border_core::{record::Record, Obs};
use serde::{Deserialize, Serialize};
use std::{default::Default, marker::PhantomData};
/// Raw Atari observation stored as a flat `u8` buffer.
///
/// NOTE(review): the tensor conversions in this file reshape the buffer to
/// `[1, 4, 1, 84, 84]`, so it presumably holds 4 stacked 84x84 frames
/// (`4 * 84 * 84` bytes) — confirm against whatever produces these values.
#[derive(Clone, Debug)]
pub struct BorderAtariObs {
    /// Frame bytes, laid out in the order the tensor conversions expect.
    pub frames: Vec<u8>,
}

impl From<Vec<u8>> for BorderAtariObs {
    /// Wraps `frames` as-is: no copy, no length validation.
    fn from(frames: Vec<u8>) -> Self {
        BorderAtariObs { frames }
    }
}
impl Obs for BorderAtariObs {
    /// Always reports a length of 1.
    ///
    /// NOTE(review): presumably because each value wraps a single observation
    /// (the tensor conversions use a leading dimension of 1) — confirm.
    fn len(&self) -> usize {
        const SINGLE_OBSERVATION: usize = 1;
        SINGLE_OBSERVATION
    }
}
#[cfg(feature = "tch")]
pub mod tch_ {
    //! Conversions from [`BorderAtariObs`] into `tch` tensor types.
    use super::*;
    use border_tch_agent::TensorBatch;
    use tch::Tensor;

    impl From<BorderAtariObs> for Tensor {
        /// Builds a `[1, 4, 1, 84, 84]` tensor from the frame bytes and casts it
        /// to `tch::Kind::Float`.
        ///
        /// NOTE(review): the reshape presumably fails (panics) unless `frames`
        /// holds exactly `4 * 84 * 84` bytes — confirm callers guarantee this.
        fn from(obs: BorderAtariObs) -> Tensor {
            let flat = Tensor::from_slice(&obs.frames);
            let shaped = flat.reshape(&[1, 4, 1, 84, 84]);
            shaped.to_kind(tch::Kind::Float)
        }
    }

    impl From<BorderAtariObs> for TensorBatch {
        /// Converts via the `Tensor` conversion above and wraps the result in a batch.
        fn from(obs: BorderAtariObs) -> Self {
            TensorBatch::from_tensor(Tensor::from(obs))
        }
    }
}
#[cfg(feature = "candle")]
pub mod canadle {
    //! Conversions from [`BorderAtariObs`] into `candle` tensor types.
    //!
    //! NOTE(review): the module name looks like a typo of `candle`; it is kept
    //! unchanged because it is part of the public path.
    use super::*;
    use border_candle_agent::TensorBatch;
    use candle_core::{Device::Cpu, Tensor};

    /// Target shape of the tensor built from one observation.
    ///
    /// NOTE(review): presumably (batch, frame stack, channel, height, width) — confirm.
    const OBS_SHAPE: [usize; 5] = [1, 4, 1, 84, 84];

    impl From<BorderAtariObs> for Tensor {
        /// Moves the frame bytes into a CPU tensor of shape `[1, 4, 1, 84, 84]`.
        ///
        /// No dtype conversion is applied, so the tensor keeps the `u8` element
        /// type of `frames`. NOTE(review): the `tch` conversion casts to float;
        /// confirm candle consumers convert the dtype themselves.
        ///
        /// # Panics
        ///
        /// Panics if `frames` does not contain exactly `4 * 84 * 84` bytes.
        fn from(obs: BorderAtariObs) -> Tensor {
            // Build directly in the final shape; `from_vec` checks the element
            // count, so a separate flat-then-reshape step is unnecessary.
            Tensor::from_vec(obs.frames, &OBS_SHAPE, &Cpu)
                .expect("BorderAtariObs frames must contain exactly 1 * 4 * 1 * 84 * 84 bytes")
        }
    }

    impl From<BorderAtariObs> for TensorBatch {
        /// Converts via the `Tensor` conversion above and wraps the result in a batch.
        fn from(obs: BorderAtariObs) -> Self {
            TensorBatch::from_tensor(Tensor::from(obs))
        }
    }
}
/// Turns a raw [`BorderAtariObs`] into the observation type `O` used downstream.
pub trait BorderAtariObsFilter<O: Obs> {
    /// Settings consumed by [`BorderAtariObsFilter::build`].
    type Config: Clone + Default;

    /// Creates a filter from `config`.
    ///
    /// # Errors
    ///
    /// Implementations may return any [`anyhow::Error`].
    fn build(config: &Self::Config) -> Result<Self>
    where
        Self: Sized;

    /// Converts `obs` into `O` and returns a [`Record`] alongside it.
    fn filt(&mut self, obs: BorderAtariObs) -> (O, Record);

    /// Converts `obs` into `O`, discarding the [`Record`].
    ///
    /// The default implementation delegates to [`BorderAtariObsFilter::filt`].
    /// NOTE(review): presumably called on environment reset — confirm at call sites.
    fn reset(&mut self, obs: BorderAtariObs) -> O {
        self.filt(obs).0
    }
}
/// Configuration for [`BorderAtariObsRawFilter`].
///
/// The raw filter has no settings, so this is a unit struct. It exists only to
/// fill the `Config` associated type of [`BorderAtariObsFilter`].
// `Default` is derived: for a unit struct this is exactly the former manual
// impl returning `Self` (clippy `derivable_impls`).
#[derive(Serialize, Deserialize, Debug, Clone, Default)]
pub struct BorderAtariObsRawFilterConfig;
/// Pass-through filter: converts each [`BorderAtariObs`] into `O` with `O::from`
/// and performs no further processing.
pub struct BorderAtariObsRawFilter<O> {
    /// Ties the filter to its output type `O`; no `O` value is stored.
    phantom: PhantomData<O>,
}

impl<O> BorderAtariObsFilter<O> for BorderAtariObsRawFilter<O>
where
    O: Obs + From<BorderAtariObs>,
{
    type Config = BorderAtariObsRawFilterConfig;

    /// Creates the filter. The config carries no settings and is ignored, so
    /// this always returns `Ok`.
    fn build(_config: &Self::Config) -> Result<Self> {
        let phantom = PhantomData;
        Ok(Self { phantom })
    }

    /// Converts `obs` via `O::from` and pairs it with an empty [`Record`].
    fn filt(&mut self, obs: BorderAtariObs) -> (O, Record) {
        let converted = O::from(obs);
        (converted, Record::empty())
    }
}