border_atari_env/env/
config.rs1use super::{BorderAtariActFilter, BorderAtariObsFilter};
6use border_core::{Act, Obs};
7use serde::{Deserialize, Serialize};
8use std::{default::Default, env};
9
10#[derive(Serialize, Deserialize, Debug)]
11pub struct BorderAtariEnvConfig<O, A, OF, AF>
13where
14 O: Obs,
15 A: Act,
16 OF: BorderAtariObsFilter<O>,
17 AF: BorderAtariActFilter<A>,
18{
19 pub rom_dir: String,
20 pub name: String,
21 pub obs_filter_config: OF::Config,
22 pub act_filter_config: AF::Config,
23 pub train: bool,
24 pub render: bool,
25}
26
27impl<O, A, OF, AF> Clone for BorderAtariEnvConfig<O, A, OF, AF>
28where
29 O: Obs,
30 A: Act,
31 OF: BorderAtariObsFilter<O>,
32 AF: BorderAtariActFilter<A>,
33{
34 fn clone(&self) -> Self {
35 Self {
36 rom_dir: self.rom_dir.clone(),
37 name: self.name.clone(),
38 obs_filter_config: self.obs_filter_config.clone(),
39 act_filter_config: self.act_filter_config.clone(),
40 train: self.train,
41 render: self.render,
42 }
43 }
44}
45
46impl<O, A, OF, AF> Default for BorderAtariEnvConfig<O, A, OF, AF>
47where
48 O: Obs,
49 A: Act,
50 OF: BorderAtariObsFilter<O>,
51 AF: BorderAtariActFilter<A>,
52{
53 fn default() -> Self {
54 let rom_dir = if let Ok(var) = env::var("ATARI_ROM_DIR") {
55 var
56 } else {
57 "".to_string()
58 };
59
60 Self {
61 rom_dir,
62 name: "".to_string(),
63 obs_filter_config: Default::default(),
64 act_filter_config: Default::default(),
65 train: true,
66 render: false,
67 }
68 }
69}
70
71impl<O, A, OF, AF> BorderAtariEnvConfig<O, A, OF, AF>
72where
73 O: Obs,
74 A: Act,
75 OF: BorderAtariObsFilter<O>,
76 AF: BorderAtariActFilter<A>,
77{
78 pub fn name(mut self, name: impl Into<String>) -> Self {
80 self.name = name.into();
81 self
82 }
83
84 pub fn eval(mut self) -> Self {
86 self.train = false;
87 self
88 }
89
90 pub fn render(mut self, render: bool) -> Self {
91 self.render = render;
92 self
93 }
94}