border_atari_env/util/
test.rs

//! Utilities for tests: type aliases, replay-buffer batch types and a random agent.
2use crate::{
3    BorderAtariAct, BorderAtariActRawFilter, BorderAtariEnv, BorderAtariEnvConfig, BorderAtariObs,
4    BorderAtariObsRawFilter,
5};
6use anyhow::Result;
7use border_core::{
8    generic_replay_buffer::{BatchBase, SimpleReplayBuffer},
9    record::Record,
10    Agent as Agent_, Configurable, Policy, ReplayBufferBase,
11};
12use serde::Deserialize;
13use std::ptr::copy;
14
15pub type Obs = BorderAtariObs;
16pub type Act = BorderAtariAct;
17pub type ObsFilter = BorderAtariObsRawFilter<Obs>;
18pub type ActFilter = BorderAtariActRawFilter<Act>;
19pub type EnvConfig = BorderAtariEnvConfig<Obs, Act, ObsFilter, ActFilter>;
20pub type ReplayBuffer = SimpleReplayBuffer<ObsBatch, ActBatch>;
21pub type Env = BorderAtariEnv<Obs, Act, ObsFilter, ActFilter>;
22pub type Agent = RandomAgent;
23
/// Size in bytes of one 84×84 frame (presumably one byte per pixel).
const FRAME_IN_BYTES: usize = 84usize.pow(2);
25
/// Observation part of a batch stored in a [`SimpleReplayBuffer`].
///
/// Samples are laid out back to back in `buf`, each occupying `m` bytes.
pub struct ObsBatch {
    /// Flat byte storage holding the samples.
    pub buf: Vec<u8>,

    /// Number of samples held by the batch.
    pub n: usize,

    /// Size of a single sample, in bytes.
    pub m: usize,
}
37
38impl BatchBase for ObsBatch {
39    fn new(capacity: usize) -> Self {
40        let m = 4 * FRAME_IN_BYTES;
41        Self {
42            n: 0,
43            m,
44            buf: vec![0; capacity * m],
45        }
46    }
47
48    #[inline]
49    fn push(&mut self, i: usize, data: Self) {
50        unsafe {
51            let src: *const u8 = &data.buf[0];
52            let dst: *mut u8 = &mut self.buf[i * self.m];
53            copy(src, dst, self.m);
54        }
55    }
56
57    fn sample(&self, ixs: &Vec<usize>) -> Self {
58        let n = ixs.len();
59        let m = self.m;
60        let mut buf = vec![0; n];
61        (0..n).enumerate().for_each(|(i, ix)| unsafe {
62            let src: *const u8 = &self.buf[ix];
63            let dst: *mut u8 = &mut buf[i * self.m];
64            copy(src, dst, self.m);
65        });
66
67        Self { m, n, buf }
68    }
69}
70
71impl From<Obs> for ObsBatch {
72    fn from(obs: Obs) -> Self {
73        Self {
74            n: 1,
75            m: 4 * FRAME_IN_BYTES,
76            buf: obs.frames,
77        }
78    }
79}
80
/// Action part of a batch stored in a [`SimpleReplayBuffer`].
///
/// Samples are laid out back to back in `buf`, each occupying `m` bytes.
pub struct ActBatch {
    /// Flat byte storage holding the samples.
    pub buf: Vec<u8>,

    /// Number of samples held by the batch.
    pub n: usize,

    /// Size of a single sample, in bytes.
    pub m: usize,
}
92
93impl BatchBase for ActBatch {
94    fn new(capacity: usize) -> Self {
95        let m = 1;
96        Self {
97            n: 0,
98            m,
99            buf: vec![0; capacity * m],
100        }
101    }
102
103    #[inline]
104    fn push(&mut self, i: usize, data: Self) {
105        unsafe {
106            let src: *const u8 = &data.buf[0];
107            let dst: *mut u8 = &mut self.buf[i * self.m];
108            copy(src, dst, self.m);
109        }
110    }
111
112    fn sample(&self, ixs: &Vec<usize>) -> Self {
113        let n = ixs.len();
114        let m = self.m;
115        let mut buf = vec![0; n];
116        (0..n).enumerate().for_each(|(i, ix)| unsafe {
117            let src: *const u8 = &self.buf[ix];
118            let dst: *mut u8 = &mut buf[i * self.m];
119            copy(src, dst, self.m);
120        });
121
122        Self { m, n, buf }
123    }
124}
125
126impl From<Act> for ActBatch {
127    fn from(act: Act) -> Self {
128        Self {
129            n: 1,
130            m: 1,
131            buf: vec![act.act],
132        }
133    }
134}
135
136#[derive(Clone, Deserialize)]
137/// Configuration of [`RandomAgent``].
138pub struct RandomAgentConfig {
139    pub n_acts: usize,
140}
141
/// Agent that picks actions at random and performs no learning.
pub struct RandomAgent {
    /// `true` in training mode, `false` in evaluation mode.
    train: bool,

    /// Number of optimization steps requested so far.
    n_opts_steps: usize,

    /// Number of discrete actions to sample from.
    n_acts: usize,
}
148
149impl Policy<Env> for RandomAgent {
150    fn sample(&mut self, _: &Obs) -> Act {
151        fastrand::u8(..self.n_acts as u8).into()
152    }
153}
154
155impl Configurable for RandomAgent {
156    type Config = RandomAgentConfig;
157
158    fn build(config: Self::Config) -> Self {
159        Self {
160            n_acts: config.n_acts,
161            n_opts_steps: 0,
162            train: true,
163        }
164    }
165}
166
167impl<R: ReplayBufferBase> Agent_<Env, R> for RandomAgent {
168    fn train(&mut self) {
169        self.train = true;
170    }
171
172    fn eval(&mut self) {
173        self.train = false;
174    }
175
176    fn is_train(&self) -> bool {
177        self.train
178    }
179
180    fn opt_with_record(&mut self, _buffer: &mut R) -> border_core::record::Record {
181        // Do nothing
182        self.n_opts_steps += 1;
183        Record::empty()
184    }
185
186    fn save_params(&self, _path: &std::path::Path) -> Result<Vec<std::path::PathBuf>> {
187        println!("save() was invoked");
188        Ok(vec![])
189    }
190
191    fn load_params(&mut self, _path: &std::path::Path) -> Result<()> {
192        println!("load() was invoked");
193        Ok(())
194    }
195
196    fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
197        self
198    }
199
200    fn as_any_ref(&self) -> &dyn std::any::Any {
201        self
202    }
203}
204
205impl RandomAgent {
206    /// Returns the number of optimization steps;
207    pub fn n_opts_steps(&self) -> usize {
208        self.n_opts_steps
209    }
210}
211
212/// Returns the default configuration of [BorderAtariEnv].
213pub fn env_config(name: String) -> EnvConfig {
214    EnvConfig::default().name(name)
215}
216
// Example usage (kept for reference; not compiled):
//
// fn main() -> Result<()> {
//     env_logger::Builder::from_env(env_logger::Env::default().default_filter_or("info")).init();
//     fastrand::seed(42);
//
//     let env_config = env_config("pong".to_string());
//     let mut env = Env::build(&env_config, 42)?;
//     let mut recorder = BufferedRecorder::new();
//     let n_acts = env.get_num_actions_atari();
//     let policy_config = RandomAgentConfig {
//         n_acts: n_acts as _,
//     };
//     let mut policy = RandomAgent::build(policy_config);
//
//     env.open()?;
//     let _ = util::eval_with_recorder(&mut env, &mut policy, 5, &mut recorder)?;
//
//     Ok(())
// }