1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
use std::{env, default::Default};
use border_core::{Obs, Act};
use super::{BorderAtariObsFilter, BorderAtariActFilter};
use serde::{Deserialize, Serialize};
/// Configuration for an Atari environment built on `border_core`.
///
/// Type parameters:
/// - `O`: observation type, bounded by `Obs`.
/// - `A`: action type, bounded by `Act`.
/// - `OF`: observation filter (`BorderAtariObsFilter<O>`); a value of its associated
///   `Config` type is stored in `obs_filter_config`.
/// - `AF`: action filter (`BorderAtariActFilter<A>`); a value of its associated
///   `Config` type is stored in `act_filter_config`.
///
/// `O` and `A` have no field of their own. They are referenced only through the filter
/// bounds (`OF::Config`, `AF::Config`), which is why no `PhantomData` marker is present.
///
/// Fields are `pub(super)`. The public builder methods visible in this file are
/// `name` and `eval`.
#[derive(Serialize, Deserialize, Debug)]
pub struct BorderAtariEnvConfig<O, A, OF, AF>
where
O: Obs,
A: Act,
OF: BorderAtariObsFilter<O>,
AF: BorderAtariActFilter<A>,
{
/// `Default` fills this from the `ATARI_ROM_DIR` environment variable, or `""` if unset.
/// NOTE(review): presumably the directory holding the Atari ROM files; confirm.
pub(super) rom_dir: String,
/// Set via the `name` builder method; `""` by default.
/// NOTE(review): presumably the game/ROM name; confirm against the environment constructor.
pub(super) name: String,
/// Configuration value for the observation filter `OF` (`Default::default()` by default).
pub(super) obs_filter_config: OF::Config,
/// Configuration value for the action filter `AF` (`Default::default()` by default).
pub(super) act_filter_config: AF::Config,
/// `true` by default; the `eval` builder method sets it to `false`.
/// NOTE(review): how this flag is consumed is not visible here; confirm.
pub(super) train: bool,
}
/// Field-by-field `Clone` for the environment config.
///
/// This is written by hand rather than derived. `#[derive(Clone)]` would add a `Clone`
/// bound on every type parameter (`O`, `A`, `OF`, `AF`), but only the stored values
/// (the strings, the flag and the filter configs `OF::Config` / `AF::Config`) need to
/// be cloned.
impl<O, A, OF, AF> Clone for BorderAtariEnvConfig<O, A, OF, AF>
where
    O: Obs,
    A: Act,
    OF: BorderAtariObsFilter<O>,
    AF: BorderAtariActFilter<A>,
{
    fn clone(&self) -> Self {
        // Borrow every field, then rebuild a new value from independent copies.
        let Self {
            rom_dir,
            name,
            obs_filter_config,
            act_filter_config,
            train,
        } = self;

        Self {
            rom_dir: rom_dir.clone(),
            name: name.clone(),
            obs_filter_config: obs_filter_config.clone(),
            act_filter_config: act_filter_config.clone(),
            train: *train,
        }
    }
}
/// Default environment config.
///
/// `rom_dir` is taken from the `ATARI_ROM_DIR` environment variable. It falls back to
/// an empty string when that variable is unset or not valid Unicode. `name` is empty,
/// both filter configs use their own `Default`, and `train` is `true`.
impl<O, A, OF, AF> Default for BorderAtariEnvConfig<O, A, OF, AF>
where
    O: Obs,
    A: Act,
    OF: BorderAtariObsFilter<O>,
    AF: BorderAtariActFilter<A>,
{
    fn default() -> Self {
        Self {
            // `env::var` errors (missing or non-Unicode) collapse to `String::default()`.
            rom_dir: env::var("ATARI_ROM_DIR").unwrap_or_default(),
            name: String::new(),
            obs_filter_config: Default::default(),
            act_filter_config: Default::default(),
            train: true,
        }
    }
}
/// Builder-style setters.
///
/// Each method consumes the config and returns the updated value, so calls can be
/// chained (e.g. `config.name("...").eval()`). The methods are `#[must_use]` because
/// discarding the returned value throws the configured config away.
impl<O, A, OF, AF> BorderAtariEnvConfig<O, A, OF, AF>
where
    O: Obs,
    A: Act,
    OF: BorderAtariObsFilter<O>,
    AF: BorderAtariActFilter<A>,
{
    /// Sets `name` and returns the updated config.
    ///
    /// NOTE(review): presumably the game/ROM name used by the environment; confirm.
    #[must_use]
    pub fn name(mut self, name: impl Into<String>) -> Self {
        self.name = name.into();
        self
    }

    /// Sets `train` to `false` and returns the updated config.
    ///
    /// NOTE(review): presumably switches the environment to evaluation mode; confirm
    /// how `train` is consumed.
    #[must_use]
    pub fn eval(mut self) -> Self {
        self.train = false;
        self
    }
}