border_atari_env/env/
config.rs

//! Configuration of [`BorderAtariEnv`](super::BorderAtariEnv).
//!
//! If the environment variable `ATARI_ROM_DIR` is set, its value is used as the
//! default ROM directory, i.e. the directory from which the ROM images of Atari
//! games are loaded.
5use super::{BorderAtariActFilter, BorderAtariObsFilter};
6use border_core::{Act, Obs};
7use serde::{Deserialize, Serialize};
8use std::{default::Default, env};
9
#[derive(Serialize, Deserialize, Debug)]
/// Configuration of [`BorderAtariEnv`](super::BorderAtariEnv).
///
/// A configuration can be created with `Default::default()` and then adjusted
/// with the `name`, `eval` and `render` builder methods, or by setting the
/// public fields directly.
pub struct BorderAtariEnvConfig<O, A, OF, AF>
where
    O: Obs,
    A: Act,
    OF: BorderAtariObsFilter<O>,
    AF: BorderAtariActFilter<A>,
{
    /// Directory from which the ROM images are loaded.
    ///
    /// The `Default` implementation reads it from the `ATARI_ROM_DIR`
    /// environment variable and leaves it empty when that is unavailable.
    pub rom_dir: String,

    /// Name of the game (empty in the `Default` configuration).
    pub name: String,

    /// Configuration of the observation filter `OF`.
    pub obs_filter_config: OF::Config,

    /// Configuration of the action filter `AF`.
    pub act_filter_config: AF::Config,

    /// Training-mode flag: `true` by default, set to `false` by `eval()`.
    pub train: bool,

    /// Rendering flag: `false` by default, set with `render()`.
    pub render: bool,
}
26
27impl<O, A, OF, AF> Clone for BorderAtariEnvConfig<O, A, OF, AF>
28where
29    O: Obs,
30    A: Act,
31    OF: BorderAtariObsFilter<O>,
32    AF: BorderAtariActFilter<A>,
33{
34    fn clone(&self) -> Self {
35        Self {
36            rom_dir: self.rom_dir.clone(),
37            name: self.name.clone(),
38            obs_filter_config: self.obs_filter_config.clone(),
39            act_filter_config: self.act_filter_config.clone(),
40            train: self.train,
41            render: self.render,
42        }
43    }
44}
45
46impl<O, A, OF, AF> Default for BorderAtariEnvConfig<O, A, OF, AF>
47where
48    O: Obs,
49    A: Act,
50    OF: BorderAtariObsFilter<O>,
51    AF: BorderAtariActFilter<A>,
52{
53    fn default() -> Self {
54        let rom_dir = if let Ok(var) = env::var("ATARI_ROM_DIR") {
55            var
56        } else {
57            "".to_string()
58        };
59
60        Self {
61            rom_dir,
62            name: "".to_string(),
63            obs_filter_config: Default::default(),
64            act_filter_config: Default::default(),
65            train: true,
66            render: false,
67        }
68    }
69}
70
71impl<O, A, OF, AF> BorderAtariEnvConfig<O, A, OF, AF>
72where
73    O: Obs,
74    A: Act,
75    OF: BorderAtariObsFilter<O>,
76    AF: BorderAtariActFilter<A>,
77{
78    /// Sets the name of the game.
79    pub fn name(mut self, name: impl Into<String>) -> Self {
80        self.name = name.into();
81        self
82    }
83
84    /// Sets the evaluation flag.
85    pub fn eval(mut self) -> Self {
86        self.train = false;
87        self
88    }
89
90    pub fn render(mut self, render: bool) -> Self {
91        self.render = render;
92        self
93    }
94}