border_atari_env/
act.rs

1//! Action for [BorderAtariEnv](crate::BorderAtariEnv)
2use anyhow::Result;
3use border_core::{record::Record, Act};
4use serde::{Deserialize, Serialize};
5use std::{default::Default, marker::PhantomData};
6
/// Action for [`BorderAtariEnv`](crate::BorderAtariEnv).
///
/// This action is a discrete action and denotes pushing a button.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct BorderAtariAct {
    /// Discrete action value identifying the button to push.
    pub act: u8,
}

15impl BorderAtariAct {
16    pub fn new(act: u8) -> Self {
17        Self { act }
18    }
19}
20
21impl Act for BorderAtariAct {
22    fn len(&self) -> usize {
23        1
24    }
25}
26
27impl From<u8> for BorderAtariAct {
28    fn from(act: u8) -> Self {
29        Self { act }
30    }
31}
32
#[cfg(feature = "candle")]
pub mod candle {
    //! Conversions between [`BorderAtariAct`] and `candle` tensors.
    use super::*;
    use border_candle_agent::TensorBatch;
    use candle_core::{DType, Device::Cpu, Tensor};

    impl From<BorderAtariAct> for Tensor {
        /// Converts the action into a `u8` tensor of shape `[1, 1]` on the CPU.
        fn from(act: BorderAtariAct) -> Tensor {
            Tensor::from_vec(vec![act.act], &[1, 1], &Cpu)
                .expect("a single element always matches the [1, 1] shape")
        }
    }

    impl From<BorderAtariAct> for TensorBatch {
        /// Wraps the `[1, 1]` tensor built from the action into a [`TensorBatch`].
        fn from(act: BorderAtariAct) -> Self {
            let tensor = act.into();
            TensorBatch::from_tensor(tensor)
        }
    }

    impl From<Tensor> for BorderAtariAct {
        /// Reads the action value from the first element of `t`.
        ///
        /// `t` is expected to hold a single item. The tensor is flattened and
        /// cast to `i64` before reading, so both of these are accepted:
        /// - any shape, e.g. `[]`, `[1]`, or the `[1, 1]` produced by
        ///   `From<BorderAtariAct> for Tensor`;
        /// - any dtype candle can cast to `I64`, e.g. the `u8` tensors built
        ///   above.
        ///
        /// The previous `to_vec1::<i64>()` call rejected every tensor that was
        /// not rank-1 `i64`, so converting an action to a tensor and back panicked.
        ///
        /// # Panics
        ///
        /// Panics if `t` is empty or if flattening/casting fails.
        fn from(t: Tensor) -> Self {
            let values = t
                .flatten_all()
                .and_then(|x| x.to_dtype(DType::I64))
                .and_then(|x| x.to_vec1::<i64>())
                .expect("failed to read an action value from the tensor");
            let value = *values
                .first()
                .expect("action tensor must contain at least one element");
            // `as u8` wraps out-of-range values; kept for backward compatibility.
            (value as u8).into()
        }
    }
}

#[cfg(feature = "tch")]
pub mod tch_ {
    //! Conversions between [`BorderAtariAct`] and `tch` tensors.
    use super::*;
    use border_tch_agent::TensorBatch;
    use std::convert::TryInto;
    use tch::Tensor;

    impl From<BorderAtariAct> for Tensor {
        /// Builds a one-element `i64` tensor from the action value, then appends
        /// a trailing unit dimension with `unsqueeze(-1)`.
        fn from(act: BorderAtariAct) -> Tensor {
            let value = i64::from(act.act);
            let flat = Tensor::from_slice(&[value]);
            flat.unsqueeze(-1)
        }
    }

    impl From<BorderAtariAct> for TensorBatch {
        /// Wraps the tensor form of the action into a [`TensorBatch`].
        fn from(act: BorderAtariAct) -> Self {
            TensorBatch::from_tensor(Tensor::from(act))
        }
    }

    impl From<Tensor> for BorderAtariAct {
        /// Extracts the action value from `t` via tch's `TryInto<i64>`.
        ///
        /// `t` is expected to hold a single item. The result is unwrapped, so
        /// this panics if tch rejects the conversion (presumably when `t` holds
        /// more or fewer than one element; confirm against tch).
        fn from(t: Tensor) -> Self {
            let value: i64 = t.try_into().unwrap();
            Self::from(value as u8)
        }
    }
}

88/// Converts action of type `A` to [`BorderAtariAct`].
89pub trait BorderAtariActFilter<A: Act> {
90    /// Configuration of the filter.
91    type Config: Clone + Default;
92
93    /// Constructs the filter given a configuration.
94    fn build(config: &Self::Config) -> Result<Self>
95    where
96        Self: Sized;
97
98    /// Converts `A` into an action of [BorderAtariAct].
99    fn filt(&mut self, act: A) -> (BorderAtariAct, Record);
100
101    /// Resets the filter. Does nothing in the default implementation.
102    fn reset(&mut self, _is_done: &Option<&Vec<i8>>) {}
103}
104
105#[derive(Debug, Deserialize, Serialize)]
106/// Configuration of [`BorderAtariActRawFilter`].
107#[derive(Clone)]
108pub struct BorderAtariActRawFilterConfig;
109
110impl Default for BorderAtariActRawFilterConfig {
111    fn default() -> Self {
112        Self
113    }
114}
115
116/// A filter that performs no processing.
117pub struct BorderAtariActRawFilter<A> {
118    phantom: PhantomData<A>,
119}
120
121impl<A> BorderAtariActFilter<A> for BorderAtariActRawFilter<A>
122where
123    A: Act + Into<BorderAtariAct>,
124{
125    type Config = BorderAtariActRawFilterConfig;
126
127    fn build(_config: &Self::Config) -> Result<Self> {
128        Ok(Self {
129            phantom: PhantomData,
130        })
131    }
132
133    fn filt(&mut self, act: A) -> (BorderAtariAct, Record) {
134        (act.into(), Record::empty())
135    }
136}