1use anyhow::Result;
16use border_core::{record::Record, Obs};
17use serde::{Deserialize, Serialize};
18use std::{default::Default, marker::PhantomData};
19
/// Observation emitted by the Atari environment wrapper.
///
/// Holds the raw frame bytes in `frames`. The tensor conversions in this file reshape the
/// buffer to `[1, 4, 1, 84, 84]`, so it is presumably expected to contain `4 * 84 * 84`
/// bytes — TODO confirm with the producer.
#[derive(Clone, Debug)]
pub struct BorderAtariObs {
    /// Raw frame bytes.
    pub frames: Vec<u8>,
}

impl From<Vec<u8>> for BorderAtariObs {
    /// Wraps a byte buffer as an observation by moving it (no copy).
    fn from(bytes: Vec<u8>) -> Self {
        BorderAtariObs { frames: bytes }
    }
}
32
33impl Obs for BorderAtariObs {
34 fn len(&self) -> usize {
35 1
36 }
37}
38
#[cfg(feature = "tch")]
pub mod tch_ {
    //! Conversions from [`BorderAtariObs`] into `tch` tensors.
    use super::*;
    use border_tch_agent::TensorBatch;
    use tch::Tensor;

    impl From<BorderAtariObs> for Tensor {
        /// Builds a `Float` tensor of shape `[1, 4, 1, 84, 84]` from the frame bytes.
        ///
        /// # Panics
        ///
        /// Panics (inside `tch`'s non-fallible `reshape`) if `frames` does not contain
        /// exactly `4 * 84 * 84` elements.
        fn from(obs: BorderAtariObs) -> Tensor {
            let flat = Tensor::from_slice(obs.frames.as_slice());
            let shaped = flat.reshape(&[1, 4, 1, 84, 84]);
            shaped.to_kind(tch::Kind::Float)
        }
    }

    impl From<BorderAtariObs> for TensorBatch {
        /// Wraps the tensor produced by `Tensor::from` in a [`TensorBatch`].
        fn from(obs: BorderAtariObs) -> Self {
            TensorBatch::from_tensor(Tensor::from(obs))
        }
    }
}
61
#[cfg(feature = "candle")]
pub mod canadle {
    //! Conversions from [`BorderAtariObs`] into `candle` tensors.
    //!
    //! The module name (`canadle`) is kept as-is so existing import paths keep working.
    use super::*;
    use border_candle_agent::TensorBatch;
    use candle_core::{Device::Cpu, Tensor};

    /// Target tensor shape for one observation; the same shape the `tch` conversion uses.
    const OBS_SHAPE: [usize; 5] = [1, 4, 1, 84, 84];

    impl From<BorderAtariObs> for Tensor {
        /// Moves the frame bytes into a CPU tensor of shape `[1, 4, 1, 84, 84]`.
        ///
        /// Unlike the `tch` conversion, no dtype cast is applied: the tensor keeps the `u8`
        /// element type of `frames`. NOTE(review): confirm this is what candle models expect.
        ///
        /// # Panics
        ///
        /// Panics if `frames` does not contain exactly `4 * 84 * 84` bytes.
        fn from(obs: BorderAtariObs) -> Tensor {
            // Build directly with the final shape instead of a flat tensor plus reshape;
            // `from_vec` already rejects a buffer whose length disagrees with the shape.
            Tensor::from_vec(obs.frames, &OBS_SHAPE, &Cpu)
                .expect("BorderAtariObs::frames must contain exactly 1 * 4 * 1 * 84 * 84 bytes")
        }
    }

    impl From<BorderAtariObs> for TensorBatch {
        /// Wraps the tensor produced by `Tensor::from` in a [`TensorBatch`].
        fn from(obs: BorderAtariObs) -> Self {
            TensorBatch::from_tensor(Tensor::from(obs))
        }
    }
}
86
87pub trait BorderAtariObsFilter<O: Obs> {
89 type Config: Clone + Default;
91
92 fn build(config: &Self::Config) -> Result<Self>
94 where
95 Self: Sized;
96
97 fn filt(&mut self, obs: BorderAtariObs) -> (O, Record);
99
100 fn reset(&mut self, obs: BorderAtariObs) -> O {
102 let (obs, _) = self.filt(obs);
103 obs
104 }
105}
106
107#[derive(Serialize, Deserialize, Debug)]
108#[derive(Clone)]
110pub struct BorderAtariObsRawFilterConfig;
111
112impl Default for BorderAtariObsRawFilterConfig {
113 fn default() -> Self {
114 Self
115 }
116}
117
118pub struct BorderAtariObsRawFilter<O> {
120 phantom: PhantomData<O>,
121}
122
123impl<O> BorderAtariObsFilter<O> for BorderAtariObsRawFilter<O>
124where
125 O: Obs + From<BorderAtariObs>,
126{
127 type Config = BorderAtariObsRawFilterConfig;
128
129 fn build(_config: &Self::Config) -> Result<Self> {
130 Ok(Self {
131 phantom: PhantomData,
132 })
133 }
134
135 fn filt(&mut self, obs: BorderAtariObs) -> (O, Record) {
136 (obs.into(), Record::empty())
137 }
138}