1use anyhow::Result;
3use border_core::{record::Record, Act};
4use serde::{Deserialize, Serialize};
5use std::{default::Default, marker::PhantomData};
6
/// Action of the Atari environment, holding a single `u8` action value.
///
/// The type is a thin wrapper around a `u8`, so it is cheap to copy and can be
/// compared and hashed directly (e.g. in tests or when collecting action statistics).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct BorderAtariAct {
    /// Raw action value.
    pub act: u8,
}
14
15impl BorderAtariAct {
16 pub fn new(act: u8) -> Self {
17 Self { act }
18 }
19}
20
21impl Act for BorderAtariAct {
22 fn len(&self) -> usize {
23 1
24 }
25}
26
27impl From<u8> for BorderAtariAct {
28 fn from(act: u8) -> Self {
29 Self { act }
30 }
31}
32
#[cfg(feature = "candle")]
pub mod candle {
    //! Conversions between [`BorderAtariAct`] and `candle` tensors.
    use super::*;
    use border_candle_agent::TensorBatch;
    use candle_core::{DType, Device::Cpu, Tensor};
    use std::convert::TryFrom;

    impl From<BorderAtariAct> for Tensor {
        /// Builds a CPU tensor of dtype `u8` and shape `[1, 1]` holding the action value.
        fn from(act: BorderAtariAct) -> Tensor {
            Tensor::from_vec(vec![act.act], &[1, 1], &Cpu)
                .expect("a one-element buffer always matches shape [1, 1]")
        }
    }

    impl From<BorderAtariAct> for TensorBatch {
        /// Wraps the `[1, 1]` action tensor produced by `From<BorderAtariAct> for Tensor`
        /// in a [`TensorBatch`].
        fn from(act: BorderAtariAct) -> Self {
            let tensor: Tensor = act.into();
            TensorBatch::from_tensor(tensor)
        }
    }

    impl From<Tensor> for BorderAtariAct {
        /// Reads the first element of `t` as the action value.
        ///
        /// `t` may have any rank (it is flattened first) and any dtype that candle can
        /// cast to `i64`. In particular, the `[1, 1]` `u8` tensor produced by
        /// `From<BorderAtariAct> for Tensor` converts back to the same action.
        ///
        /// # Panics
        ///
        /// Panics if `t` cannot be flattened or cast to `i64`, if it has no elements,
        /// or if its first value does not fit in a `u8`.
        fn from(t: Tensor) -> Self {
            let values = t
                .flatten_all()
                .and_then(|flat| flat.to_dtype(DType::I64))
                .and_then(|flat| flat.to_vec1::<i64>())
                .expect("action tensor must be convertible to a 1-D i64 tensor");
            let raw = *values.first().expect("action tensor must not be empty");
            // Reject out-of-range values instead of silently wrapping them into a
            // different action, as an `as u8` cast would.
            let act = u8::try_from(raw)
                .unwrap_or_else(|_| panic!("action value {} does not fit in u8", raw));
            Self::from(act)
        }
    }
}
59
#[cfg(feature = "tch")]
pub mod tch_ {
    //! Conversions between [`BorderAtariAct`] and `tch` tensors.
    use super::*;
    use border_tch_agent::TensorBatch;
    use std::convert::{TryFrom, TryInto};
    use tch::Tensor;

    impl From<BorderAtariAct> for Tensor {
        /// Builds an `i64` tensor of shape `[1, 1]` holding the action value.
        fn from(act: BorderAtariAct) -> Tensor {
            // `i64::from` is a lossless widening, unlike an `as` cast.
            Tensor::from_slice(&[i64::from(act.act)]).unsqueeze(-1)
        }
    }

    impl From<BorderAtariAct> for TensorBatch {
        /// Wraps the `[1, 1]` action tensor produced by `From<BorderAtariAct> for Tensor`
        /// in a [`TensorBatch`].
        fn from(act: BorderAtariAct) -> Self {
            let tensor: Tensor = act.into();
            TensorBatch::from_tensor(tensor)
        }
    }

    impl From<Tensor> for BorderAtariAct {
        /// Converts `t` to a single `i64` using tch's tensor-to-`i64` conversion, then
        /// treats that value as the action.
        ///
        /// # Panics
        ///
        /// Panics if tch cannot convert `t` to an `i64`, or if the value does not fit
        /// in a `u8`.
        fn from(t: Tensor) -> Self {
            let raw = TryInto::<i64>::try_into(t)
                .expect("failed to convert action tensor to an i64 value");
            // Reject out-of-range values instead of silently wrapping them into a
            // different action, as an `as u8` cast would.
            let act = u8::try_from(raw)
                .unwrap_or_else(|_| panic!("action value {} does not fit in u8", raw));
            Self::from(act)
        }
    }
}
87
88pub trait BorderAtariActFilter<A: Act> {
90 type Config: Clone + Default;
92
93 fn build(config: &Self::Config) -> Result<Self>
95 where
96 Self: Sized;
97
98 fn filt(&mut self, act: A) -> (BorderAtariAct, Record);
100
101 fn reset(&mut self, _is_done: &Option<&Vec<i8>>) {}
103}
104
105#[derive(Debug, Deserialize, Serialize)]
106#[derive(Clone)]
108pub struct BorderAtariActRawFilterConfig;
109
110impl Default for BorderAtariActRawFilterConfig {
111 fn default() -> Self {
112 Self
113 }
114}
115
/// Action filter that converts actions of type `A` into [`BorderAtariAct`] via `Into`,
/// with no further processing.
pub struct BorderAtariActRawFilter<A> {
    /// Zero-sized marker tying the filter to its input action type `A`.
    phantom: std::marker::PhantomData<A>,
}
120
121impl<A> BorderAtariActFilter<A> for BorderAtariActRawFilter<A>
122where
123 A: Act + Into<BorderAtariAct>,
124{
125 type Config = BorderAtariActRawFilterConfig;
126
127 fn build(_config: &Self::Config) -> Result<Self> {
128 Ok(Self {
129 phantom: PhantomData,
130 })
131 }
132
133 fn filt(&mut self, act: A) -> (BorderAtariAct, Record) {
134 (act.into(), Record::empty())
135 }
136}