use anyhow::Result;
use border_core::{record::Record, Act};
use serde::{Deserialize, Serialize};
use std::{default::Default, marker::PhantomData};
/// An Atari action holding a single raw `u8` value.
///
/// The wrapper is a plain value type, so it is `Copy` and comparable/hashable.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct BorderAtariAct {
    /// Raw action value.
    /// NOTE(review): presumably an index into the emulator's action set — confirm.
    pub act: u8,
}
impl BorderAtariAct {
pub fn new(act: u8) -> Self {
Self { act }
}
}
impl Act for BorderAtariAct {
    /// Always `1`: this action carries exactly one value.
    fn len(&self) -> usize {
        1usize
    }
}
impl From<u8> for BorderAtariAct {
fn from(act: u8) -> Self {
Self { act }
}
}
#[cfg(feature = "candle")]
pub mod candle {
    //! Conversions between [`BorderAtariAct`] and `candle` tensors.
    use super::*;
    use border_candle_agent::TensorBatch;
    use candle_core::{DType, Device::Cpu, Tensor};

    impl From<BorderAtariAct> for Tensor {
        /// Builds a `u8` tensor of shape `[1, 1]` on the CPU holding the action value.
        fn from(act: BorderAtariAct) -> Tensor {
            // `act.act` is already `u8`; no cast needed.
            Tensor::from_vec(vec![act.act], &[1, 1], &Cpu)
                .expect("a 1-element buffer always matches shape [1, 1]")
        }
    }

    impl From<BorderAtariAct> for TensorBatch {
        /// Wraps the `[1, 1]` action tensor in a [`TensorBatch`].
        fn from(act: BorderAtariAct) -> Self {
            let tensor = act.into();
            TensorBatch::from_tensor(tensor)
        }
    }

    impl From<Tensor> for BorderAtariAct {
        /// Reads the action value from the first element of `t`.
        ///
        /// The tensor is flattened and cast to `i64` before being read. This means
        /// any rank and numeric dtype is accepted, including the `[1, 1]` `u8` tensor
        /// produced by `From<BorderAtariAct> for Tensor` above. The value is then
        /// truncated to `u8` with `as`.
        ///
        /// # Panics
        /// Panics if `t` has no elements or cannot be converted to `i64`.
        fn from(t: Tensor) -> Self {
            let values = t
                .flatten_all()
                .and_then(|flat| flat.to_dtype(DType::I64))
                .and_then(|ints| ints.to_vec1::<i64>())
                .expect("failed to read action tensor as i64");
            let first = *values
                .first()
                .expect("action tensor must contain at least one element");
            (first as u8).into()
        }
    }
}
#[cfg(feature = "tch")]
pub mod tch_ {
    //! Conversions between [`BorderAtariAct`] and `tch` tensors.
    use super::*;
    use border_tch_agent::TensorBatch;
    use std::convert::TryInto;
    use tch::Tensor;

    impl From<BorderAtariAct> for Tensor {
        /// Builds an `i64` tensor of shape `[1, 1]` holding the action value:
        /// a 1-element slice tensor with a trailing axis added.
        fn from(act: BorderAtariAct) -> Tensor {
            let value = i64::from(act.act);
            Tensor::from_slice(&[value]).unsqueeze(-1)
        }
    }

    impl From<BorderAtariAct> for TensorBatch {
        /// Wraps the `[1, 1]` action tensor in a [`TensorBatch`].
        fn from(act: BorderAtariAct) -> Self {
            TensorBatch::from_tensor(Tensor::from(act))
        }
    }

    impl From<Tensor> for BorderAtariAct {
        /// Reads an `i64` out of `t` via tch's `TryInto<i64>` conversion and truncates
        /// it to `u8`.
        ///
        /// # Panics
        /// Panics if that conversion fails.
        fn from(t: Tensor) -> Self {
            let value: i64 = t.try_into().unwrap();
            Self::from(value as u8)
        }
    }
}
/// Converts agent-side actions of type `A` into [`BorderAtariAct`].
pub trait BorderAtariActFilter<A>
where
    A: Act,
{
    /// Configuration consumed by [`BorderAtariActFilter::build`].
    type Config: Clone + Default;

    /// Constructs the filter from `config`.
    fn build(config: &Self::Config) -> Result<Self>
    where
        Self: Sized;

    /// Converts `act` into a [`BorderAtariAct`], also returning a [`Record`]
    /// (presumably for logging — confirm against callers).
    fn filt(&mut self, act: A) -> (BorderAtariAct, Record);

    /// Resets the filter. The default implementation does nothing.
    ///
    /// NOTE(review): `_is_done` presumably carries per-environment done flags — confirm.
    fn reset(&mut self, _is_done: &Option<&Vec<i8>>) {
    }
}
/// Configuration for [`BorderAtariActRawFilter`].
///
/// The raw filter takes no options, so this is a unit struct. Its `Default` is
/// derived rather than hand-written.
#[derive(Clone, Debug, Default, Deserialize, Serialize)]
pub struct BorderAtariActRawFilterConfig;
/// An action filter that converts `A` into [`BorderAtariAct`] through `Into`,
/// with no further processing.
pub struct BorderAtariActRawFilter<A> {
    // Ties the filter to the action type `A` without storing a value of it.
    _marker: PhantomData<A>,
}

impl<A: Act + Into<BorderAtariAct>> BorderAtariActFilter<A> for BorderAtariActRawFilter<A> {
    type Config = BorderAtariActRawFilterConfig;

    /// Builds the filter. The configuration has no options and is ignored; this never fails.
    fn build(_config: &Self::Config) -> Result<Self> {
        let filter = Self {
            _marker: PhantomData,
        };
        Ok(filter)
    }

    /// Converts `act` with `Into` and pairs it with an empty [`Record`].
    fn filt(&mut self, act: A) -> (BorderAtariAct, Record) {
        let converted: BorderAtariAct = act.into();
        (converted, Record::empty())
    }
}