lace/interface/engine/
builder.rs

1use rand::SeedableRng;
2use rand_xoshiro::Xoshiro256Plus;
3use thiserror::Error;
4
5use super::error::NewEngineError;
6use super::Engine;
7use crate::codebook::Codebook;
8use crate::data::DataSource;
9use crate::data::DefaultCodebookError;
10
/// Number of states `EngineBuilder::build` creates when `with_nstates` was
/// never called.
const DEFAULT_NSTATES: usize = 8;
/// State ID offset `EngineBuilder::build` uses when `id_offset` was never
/// called.
const DEFAULT_ID_OFFSET: usize = 0;
13
14/// Builds `Engine`s
15pub struct EngineBuilder {
16    n_states: Option<usize>,
17    codebook: Option<Codebook>,
18    data_source: DataSource,
19    id_offset: Option<usize>,
20    seed: Option<u64>,
21    flat_cols: bool,
22}
23
24#[derive(Debug, Error)]
25pub enum BuildEngineError {
26    #[error("error constructing Engine: {0}")]
27    NewEngineError(#[from] NewEngineError),
28    #[error("error generating default codebook: {0}")]
29    DefaultCodebookError(#[from] DefaultCodebookError),
30}
31
32impl EngineBuilder {
33    #[must_use]
34    pub fn new(data_source: DataSource) -> Self {
35        Self {
36            n_states: None,
37            codebook: None,
38            data_source,
39            id_offset: None,
40            seed: None,
41            flat_cols: false,
42        }
43    }
44
45    /// With a certain number of states
46    #[must_use]
47    pub fn with_nstates(mut self, n_states: usize) -> Self {
48        self.n_states = Some(n_states);
49        self
50    }
51
52    /// With a specific codebook
53    #[must_use]
54    pub fn codebook(mut self, codebook: Codebook) -> Self {
55        self.codebook = Some(codebook);
56        self
57    }
58
59    /// With state IDs starting at an offset
60    #[must_use]
61    pub fn id_offset(mut self, id_offset: usize) -> Self {
62        self.id_offset = Some(id_offset);
63        self
64    }
65
66    /// With a given random number generator
67    #[must_use]
68    pub fn seed_from_u64(mut self, seed: u64) -> Self {
69        self.seed = Some(seed);
70        self
71    }
72
73    /// With a flat column structure -- one view in each state
74    #[must_use]
75    pub fn flat_cols(mut self) -> Self {
76        self.flat_cols = true;
77        self
78    }
79
80    // Build the `Engine`; consume the `Builder`.
81    pub fn build(self) -> Result<Engine, BuildEngineError> {
82        let nstates = self.n_states.unwrap_or(DEFAULT_NSTATES);
83
84        let id_offset = self.id_offset.unwrap_or(DEFAULT_ID_OFFSET);
85        let rng = match self.seed {
86            Some(s) => Xoshiro256Plus::seed_from_u64(s),
87            None => Xoshiro256Plus::from_os_rng(),
88        };
89
90        let codebook = match self.codebook {
91            Some(codebook) => Ok(codebook),
92            None => self
93                .data_source
94                .default_codebook()
95                .map_err(BuildEngineError::DefaultCodebookError),
96        }?;
97
98        let mut engine =
99            Engine::new(nstates, codebook, self.data_source, id_offset, rng)
100                .map_err(BuildEngineError::NewEngineError)?;
101
102        if self.flat_cols {
103            engine.flatten_cols();
104        }
105
106        Ok(engine)
107    }
108}
109
#[cfg(test)]
mod tests {
    use std::collections::BTreeSet;
    use std::path::Path;

    use maplit::btreeset;
    use polars::prelude::CsvReadOptions;
    use polars::prelude::DataFrame;

    use super::*;
    use crate::codebook::ReadError;

    /// Reads a CSV file with a header row into a `DataFrame`.
    fn read_csv<P: AsRef<Path>>(path: P) -> Result<DataFrame, ReadError> {
        use polars::prelude::SerReader;
        let df = CsvReadOptions::default()
            .with_infer_schema_length(Some(1000))
            .with_has_header(true)
            .try_into_reader_with_file_path(Some(path.as_ref().into()))?
            .finish()?;
        Ok(df)
    }

    /// Loads the animals example dataset as a `DataSource`.
    fn animals_csv() -> DataSource {
        let df = read_csv("resources/datasets/animals/data.csv")
            .expect("the animals dataset should be readable");
        DataSource::Polars(df)
    }

    /// Collects the engine's state IDs into an ordered set.
    fn state_id_set(engine: &Engine) -> BTreeSet<usize> {
        engine.state_ids.iter().copied().collect()
    }

    /// Asserts that paired states, and the paired views inside them, have
    /// equal assignments.
    fn assert_same_assignments(engine_1: &Engine, engine_2: &Engine) {
        for (state_1, state_2) in
            engine_1.states.iter().zip(engine_2.states.iter())
        {
            assert_eq!(state_1.asgn(), state_2.asgn());
            for (view_1, view_2) in
                state_1.views.iter().zip(state_2.views.iter())
            {
                assert_eq!(view_1.asgn(), view_2.asgn());
            }
        }
    }

    #[test]
    fn default_build_settings() {
        let engine = EngineBuilder::new(animals_csv()).build().unwrap();
        let target_ids: BTreeSet<usize> = btreeset! {0, 1, 2, 3, 4, 5, 6, 7};
        assert_eq!(engine.n_states(), 8);
        assert_eq!(state_id_set(&engine), target_ids);
    }

    #[test]
    fn with_id_offset_3() {
        let engine = EngineBuilder::new(animals_csv())
            .id_offset(3)
            .build()
            .unwrap();
        let target_ids: BTreeSet<usize> = btreeset! {3, 4, 5, 6, 7, 8, 9, 10};
        assert_eq!(engine.n_states(), 8);
        assert_eq!(state_id_set(&engine), target_ids);
    }

    #[test]
    fn with_nstates_3() {
        let engine = EngineBuilder::new(animals_csv())
            .with_nstates(3)
            .build()
            .unwrap();
        let target_ids: BTreeSet<usize> = btreeset! {0, 1, 2};
        assert_eq!(engine.n_states(), 3);
        assert_eq!(state_id_set(&engine), target_ids);
    }

    #[test]
    fn with_nstates_0_causes_error() {
        let result = EngineBuilder::new(animals_csv()).with_nstates(0).build();

        assert!(result.is_err());
    }

    #[test]
    fn seeding_engine_works() {
        let seed: u64 = 8_675_309;
        let nstates = 4;
        let mut engine_1 = EngineBuilder::new(animals_csv())
            .with_nstates(nstates)
            .seed_from_u64(seed)
            .build()
            .unwrap();

        let mut engine_2 = EngineBuilder::new(animals_csv())
            .with_nstates(nstates)
            .seed_from_u64(seed)
            .build()
            .unwrap();

        // Initial state should be the same
        assert_same_assignments(&engine_1, &engine_2);

        engine_1.run(10).unwrap();
        engine_2.run(10).unwrap();

        // And should stay the same after the run
        assert_same_assignments(&engine_1, &engine_2);
    }
}