// lace/interface/engine/builder.rs
use rand::SeedableRng;
use rand_xoshiro::Xoshiro256Plus;
use thiserror::Error;

use super::error::NewEngineError;
use super::Engine;
use crate::codebook::Codebook;
use crate::data::DataSource;
use crate::data::DefaultCodebookError;

/// Number of states requested from `Engine::new` when the builder was not
/// given an explicit count via `EngineBuilder::with_nstates`.
const DEFAULT_NSTATES: usize = 8;
/// State ID offset passed to `Engine::new` when the builder was not given one
/// via `EngineBuilder::id_offset`.
const DEFAULT_ID_OFFSET: usize = 0;

14pub struct EngineBuilder {
16 n_states: Option<usize>,
17 codebook: Option<Codebook>,
18 data_source: DataSource,
19 id_offset: Option<usize>,
20 seed: Option<u64>,
21 flat_cols: bool,
22}
23
24#[derive(Debug, Error)]
25pub enum BuildEngineError {
26 #[error("error constructing Engine: {0}")]
27 NewEngineError(#[from] NewEngineError),
28 #[error("error generating default codebook: {0}")]
29 DefaultCodebookError(#[from] DefaultCodebookError),
30}
31
32impl EngineBuilder {
33 #[must_use]
34 pub fn new(data_source: DataSource) -> Self {
35 Self {
36 n_states: None,
37 codebook: None,
38 data_source,
39 id_offset: None,
40 seed: None,
41 flat_cols: false,
42 }
43 }
44
45 #[must_use]
47 pub fn with_nstates(mut self, n_states: usize) -> Self {
48 self.n_states = Some(n_states);
49 self
50 }
51
52 #[must_use]
54 pub fn codebook(mut self, codebook: Codebook) -> Self {
55 self.codebook = Some(codebook);
56 self
57 }
58
59 #[must_use]
61 pub fn id_offset(mut self, id_offset: usize) -> Self {
62 self.id_offset = Some(id_offset);
63 self
64 }
65
66 #[must_use]
68 pub fn seed_from_u64(mut self, seed: u64) -> Self {
69 self.seed = Some(seed);
70 self
71 }
72
73 #[must_use]
75 pub fn flat_cols(mut self) -> Self {
76 self.flat_cols = true;
77 self
78 }
79
80 pub fn build(self) -> Result<Engine, BuildEngineError> {
82 let nstates = self.n_states.unwrap_or(DEFAULT_NSTATES);
83
84 let id_offset = self.id_offset.unwrap_or(DEFAULT_ID_OFFSET);
85 let rng = match self.seed {
86 Some(s) => Xoshiro256Plus::seed_from_u64(s),
87 None => Xoshiro256Plus::from_os_rng(),
88 };
89
90 let codebook = match self.codebook {
91 Some(codebook) => Ok(codebook),
92 None => self
93 .data_source
94 .default_codebook()
95 .map_err(BuildEngineError::DefaultCodebookError),
96 }?;
97
98 let mut engine =
99 Engine::new(nstates, codebook, self.data_source, id_offset, rng)
100 .map_err(BuildEngineError::NewEngineError)?;
101
102 if self.flat_cols {
103 engine.flatten_cols();
104 }
105
106 Ok(engine)
107 }
108}
109
#[cfg(test)]
mod tests {
    use std::collections::BTreeSet;
    use std::path::Path;

    use maplit::btreeset;
    use polars::prelude::CsvReadOptions;
    use polars::prelude::DataFrame;

    use super::*;
    use crate::codebook::ReadError;

    /// Reads the CSV file at `path` (with a header row) into a `DataFrame`.
    fn read_csv<P: AsRef<Path>>(path: P) -> Result<DataFrame, ReadError> {
        use polars::prelude::SerReader;
        let df = CsvReadOptions::default()
            .with_infer_schema_length(Some(1000))
            .with_has_header(true)
            .try_into_reader_with_file_path(Some(path.as_ref().into()))?
            .finish()?;
        Ok(df)
    }

    /// Loads the animals example dataset as a `DataSource`.
    fn animals_csv() -> DataSource {
        let df = read_csv("resources/datasets/animals/data.csv").unwrap();
        DataSource::Polars(df)
    }

    /// Collects an engine's state IDs into an ordered set.
    fn state_id_set(engine: &Engine) -> BTreeSet<usize> {
        engine.state_ids.iter().copied().collect()
    }

    /// Asserts that two engines hold identical state and view assignments.
    ///
    /// The counts are checked explicitly because `zip` stops at the shorter
    /// input and would otherwise hide a mismatch in the number of states or
    /// views.
    fn assert_same_assignments(engine_1: &Engine, engine_2: &Engine) {
        assert_eq!(engine_1.states.len(), engine_2.states.len());
        for (state_1, state_2) in
            engine_1.states.iter().zip(engine_2.states.iter())
        {
            assert_eq!(state_1.asgn(), state_2.asgn());
            assert_eq!(state_1.views.len(), state_2.views.len());
            for (view_1, view_2) in
                state_1.views.iter().zip(state_2.views.iter())
            {
                assert_eq!(view_1.asgn(), view_2.asgn());
            }
        }
    }

    #[test]
    fn default_build_settings() {
        let engine = EngineBuilder::new(animals_csv()).build().unwrap();
        let target_ids: BTreeSet<usize> = btreeset! {0, 1, 2, 3, 4, 5, 6, 7};
        assert_eq!(engine.n_states(), 8);
        assert_eq!(state_id_set(&engine), target_ids);
    }

    #[test]
    fn with_id_offset_3() {
        let engine = EngineBuilder::new(animals_csv())
            .id_offset(3)
            .build()
            .unwrap();
        let target_ids: BTreeSet<usize> = btreeset! {3, 4, 5, 6, 7, 8, 9, 10};
        assert_eq!(engine.n_states(), 8);
        assert_eq!(state_id_set(&engine), target_ids);
    }

    #[test]
    fn with_nstates_3() {
        let engine = EngineBuilder::new(animals_csv())
            .with_nstates(3)
            .build()
            .unwrap();
        let target_ids: BTreeSet<usize> = btreeset! {0, 1, 2};
        assert_eq!(engine.n_states(), 3);
        assert_eq!(state_id_set(&engine), target_ids);
    }

    #[test]
    fn with_nstates_0_causes_error() {
        let result = EngineBuilder::new(animals_csv()).with_nstates(0).build();

        assert!(result.is_err());
    }

    #[test]
    fn seeding_engine_works() {
        let seed: u64 = 8_675_309;
        let nstates = 4;
        let mut engine_1 = EngineBuilder::new(animals_csv())
            .with_nstates(nstates)
            .seed_from_u64(seed)
            .build()
            .unwrap();

        let mut engine_2 = EngineBuilder::new(animals_csv())
            .with_nstates(nstates)
            .seed_from_u64(seed)
            .build()
            .unwrap();

        // Identically seeded engines must start from the same assignments...
        assert_same_assignments(&engine_1, &engine_2);

        engine_1.run(10).unwrap();
        engine_2.run(10).unwrap();

        // ...and must still agree after running the same number of steps.
        assert_same_assignments(&engine_1, &engine_2);
    }
}