border_atari_env/util/
test.rs1use crate::{
3 BorderAtariAct, BorderAtariActRawFilter, BorderAtariEnv, BorderAtariEnvConfig, BorderAtariObs,
4 BorderAtariObsRawFilter,
5};
6use anyhow::Result;
7use border_core::{
8 generic_replay_buffer::{BatchBase, SimpleReplayBuffer},
9 record::Record,
10 Agent as Agent_, Configurable, Policy, ReplayBufferBase,
11};
12use serde::Deserialize;
13use std::ptr::copy;
14
15pub type Obs = BorderAtariObs;
16pub type Act = BorderAtariAct;
17pub type ObsFilter = BorderAtariObsRawFilter<Obs>;
18pub type ActFilter = BorderAtariActRawFilter<Act>;
19pub type EnvConfig = BorderAtariEnvConfig<Obs, Act, ObsFilter, ActFilter>;
20pub type ReplayBuffer = SimpleReplayBuffer<ObsBatch, ActBatch>;
21pub type Env = BorderAtariEnv<Obs, Act, ObsFilter, ActFilter>;
22pub type Agent = RandomAgent;
23
24const FRAME_IN_BYTES: usize = 84 * 84;
25
/// A batch of observations packed into one flat byte buffer.
///
/// Sample `k` occupies the byte range `buf[k * m..(k + 1) * m]`.
pub struct ObsBatch {
    /// Number of samples carried by this batch.
    pub n: usize,

    /// Number of bytes per sample.
    pub m: usize,

    /// Flat sample storage, `m` bytes per sample.
    pub buf: Vec<u8>,
}
37
38impl BatchBase for ObsBatch {
39 fn new(capacity: usize) -> Self {
40 let m = 4 * FRAME_IN_BYTES;
41 Self {
42 n: 0,
43 m,
44 buf: vec![0; capacity * m],
45 }
46 }
47
48 #[inline]
49 fn push(&mut self, i: usize, data: Self) {
50 unsafe {
51 let src: *const u8 = &data.buf[0];
52 let dst: *mut u8 = &mut self.buf[i * self.m];
53 copy(src, dst, self.m);
54 }
55 }
56
57 fn sample(&self, ixs: &Vec<usize>) -> Self {
58 let n = ixs.len();
59 let m = self.m;
60 let mut buf = vec![0; n];
61 (0..n).enumerate().for_each(|(i, ix)| unsafe {
62 let src: *const u8 = &self.buf[ix];
63 let dst: *mut u8 = &mut buf[i * self.m];
64 copy(src, dst, self.m);
65 });
66
67 Self { m, n, buf }
68 }
69}
70
71impl From<Obs> for ObsBatch {
72 fn from(obs: Obs) -> Self {
73 Self {
74 n: 1,
75 m: 4 * FRAME_IN_BYTES,
76 buf: obs.frames,
77 }
78 }
79}
80
/// A batch of discrete actions packed into one flat byte buffer.
///
/// Sample `k` occupies the byte range `buf[k * m..(k + 1) * m]`.
pub struct ActBatch {
    /// Number of samples carried by this batch.
    pub n: usize,

    /// Number of bytes per sample; every constructor in this module uses 1.
    pub m: usize,

    /// Flat sample storage, `m` bytes per sample.
    pub buf: Vec<u8>,
}
92
93impl BatchBase for ActBatch {
94 fn new(capacity: usize) -> Self {
95 let m = 1;
96 Self {
97 n: 0,
98 m,
99 buf: vec![0; capacity * m],
100 }
101 }
102
103 #[inline]
104 fn push(&mut self, i: usize, data: Self) {
105 unsafe {
106 let src: *const u8 = &data.buf[0];
107 let dst: *mut u8 = &mut self.buf[i * self.m];
108 copy(src, dst, self.m);
109 }
110 }
111
112 fn sample(&self, ixs: &Vec<usize>) -> Self {
113 let n = ixs.len();
114 let m = self.m;
115 let mut buf = vec![0; n];
116 (0..n).enumerate().for_each(|(i, ix)| unsafe {
117 let src: *const u8 = &self.buf[ix];
118 let dst: *mut u8 = &mut buf[i * self.m];
119 copy(src, dst, self.m);
120 });
121
122 Self { m, n, buf }
123 }
124}
125
126impl From<Act> for ActBatch {
127 fn from(act: Act) -> Self {
128 Self {
129 n: 1,
130 m: 1,
131 buf: vec![act.act],
132 }
133 }
134}
135
136#[derive(Clone, Deserialize)]
137pub struct RandomAgentConfig {
139 pub n_acts: usize,
140}
141
/// Test agent that ignores observations and picks random actions.
///
/// Its optimization step does no learning; it only counts how often it was invoked.
pub struct RandomAgent {
    /// Number of discrete actions; sampled actions lie in `0..n_acts`.
    n_acts: usize,
    /// How many optimization steps have been requested so far.
    n_opts_steps: usize,
    /// `true` while in training mode, `false` in evaluation mode.
    train: bool,
}
148
149impl Policy<Env> for RandomAgent {
150 fn sample(&mut self, _: &Obs) -> Act {
151 fastrand::u8(..self.n_acts as u8).into()
152 }
153}
154
155impl Configurable for RandomAgent {
156 type Config = RandomAgentConfig;
157
158 fn build(config: Self::Config) -> Self {
159 Self {
160 n_acts: config.n_acts,
161 n_opts_steps: 0,
162 train: true,
163 }
164 }
165}
166
167impl<R: ReplayBufferBase> Agent_<Env, R> for RandomAgent {
168 fn train(&mut self) {
169 self.train = true;
170 }
171
172 fn eval(&mut self) {
173 self.train = false;
174 }
175
176 fn is_train(&self) -> bool {
177 self.train
178 }
179
180 fn opt_with_record(&mut self, _buffer: &mut R) -> border_core::record::Record {
181 self.n_opts_steps += 1;
183 Record::empty()
184 }
185
186 fn save_params(&self, _path: &std::path::Path) -> Result<Vec<std::path::PathBuf>> {
187 println!("save() was invoked");
188 Ok(vec![])
189 }
190
191 fn load_params(&mut self, _path: &std::path::Path) -> Result<()> {
192 println!("load() was invoked");
193 Ok(())
194 }
195
196 fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
197 self
198 }
199
200 fn as_any_ref(&self) -> &dyn std::any::Any {
201 self
202 }
203}
204
205impl RandomAgent {
206 pub fn n_opts_steps(&self) -> usize {
208 self.n_opts_steps
209 }
210}
211
212pub fn env_config(name: String) -> EnvConfig {
214 EnvConfig::default().name(name)
215}
216
217