// cubek_test_utils/config.rs

//! `cubek.toml` loader.
//!
//! On the first call to [`config`] we walk up from the current working
//! directory looking for a file named `cubek.toml`. The first match wins, the
//! parsed result is cached in a `OnceLock`, and every other module asks for
//! the parts it needs through this module.
//!
//! The parser is intentionally tiny — `cubek.toml` is a flat schema (one
//! `[section]` header per group, scalar values only). Anything more elaborate
//! belongs in a real TOML library; for now we accept exactly what the
//! example file documents and panic loudly on anything else.

use std::collections::HashMap;
use std::collections::hash_map::Entry;
use std::path::PathBuf;
use std::sync::OnceLock;

use crate::correctness::{TensorFilter, parse_tensor_filter};

19/// Top-level configuration loaded from `cubek.toml`.
20#[derive(Clone, Debug)]
21pub struct CubekConfig {
22    pub test: TestSection,
23    pub print: PrintSection,
24}
25
26#[derive(Clone, Debug)]
27pub struct TestSection {
28    pub policy: TestPolicy,
29}
30
/// Policy selected by the `policy` key of the `[test]` section.
///
/// Each variant corresponds to one accepted string in `cubek.toml`. The
/// `Default` implementation is the variant used when the key is absent.
#[derive(Copy, Clone, Debug, Default, PartialEq, Eq)]
pub enum TestPolicy {
    /// Written as `"correct"`; also used when `policy` is not set.
    #[default]
    Correct,
    /// Written as `"strict"`.
    Strict,
    /// Written as `"fail-if-run"`.
    FailIfRun,
}

38#[derive(Clone, Debug)]
39pub struct PrintSection {
40    pub enabled: bool,
41    pub view: PrintView,
42    pub force_fail: bool,
43    pub fail_only: bool,
44    pub show_expected: bool,
45    pub filter: TensorFilter,
46}
47
/// Layout selected by the `view` key of the `[print]` section.
///
/// The `Default` implementation is the variant used when the key is absent.
#[derive(Copy, Clone, Debug, Default, PartialEq, Eq)]
pub enum PrintView {
    /// Written as `"table"`; also used when `view` is not set.
    #[default]
    Table,
    /// Written as `"lines"`.
    Lines,
}

54impl Default for CubekConfig {
55    fn default() -> Self {
56        Self {
57            test: TestSection {
58                policy: TestPolicy::Correct,
59            },
60            print: PrintSection {
61                enabled: false,
62                view: PrintView::Table,
63                force_fail: true,
64                fail_only: false,
65                show_expected: false,
66                filter: Vec::new(),
67            },
68        }
69    }
70}
71
72/// Returns the active config. Cached on first call.
73pub fn config() -> &'static CubekConfig {
74    static CACHE: OnceLock<CubekConfig> = OnceLock::new();
75    CACHE.get_or_init(load_config)
76}
77
78fn load_config() -> CubekConfig {
79    let Some(path) = find_cubek_toml() else {
80        return CubekConfig::default();
81    };
82    let text = std::fs::read_to_string(&path)
83        .unwrap_or_else(|e| panic!("cannot read {}: {e}", path.display()));
84    parse_cubek_toml(&text)
85        .unwrap_or_else(|e| panic!("invalid cubek.toml ({}): {e}", path.display()))
86}
87
/// Searches the current working directory, then each of its ancestors, for a
/// file named `cubek.toml`.
///
/// Returns the path of the closest match. Returns `None` when the working
/// directory cannot be determined or no directory on the way up to the root
/// contains the file.
fn find_cubek_toml() -> Option<PathBuf> {
    let start = std::env::current_dir().ok()?;
    // `ancestors` yields `start` itself first, then each parent in turn.
    start
        .ancestors()
        .map(|dir| dir.join("cubek.toml"))
        .find(|candidate| candidate.is_file())
}

// ---------- minimal flat-section TOML parser ----------

/// Raw parse result: section name → (key → value text).
///
/// Values are kept as written in the file, with surrounding whitespace and
/// any trailing comment removed; string values still carry their quotes.
type Sections = HashMap<String, HashMap<String, String>>;

105fn parse_cubek_toml(text: &str) -> Result<CubekConfig, String> {
106    let sections = parse_sections(text)?;
107
108    let mut cfg = CubekConfig::default();
109
110    if let Some(map) = sections.get("test") {
111        cfg.test.policy = match get_string(map, "policy")?.as_deref() {
112            None | Some("correct") => TestPolicy::Correct,
113            Some("strict") => TestPolicy::Strict,
114            Some("fail-if-run") => TestPolicy::FailIfRun,
115            Some(other) => {
116                return Err(format!(
117                    "[test] policy='{}': expected one of \"correct\", \"strict\", \"fail-if-run\"",
118                    other
119                ));
120            }
121        };
122        reject_unknown_keys("test", map, &["policy"])?;
123    }
124
125    if let Some(map) = sections.get("print") {
126        let enabled = get_bool(map, "enabled")?.unwrap_or(false);
127        let view = match get_string(map, "view")?.as_deref() {
128            None | Some("table") => PrintView::Table,
129            Some("lines") => PrintView::Lines,
130            Some(other) => {
131                return Err(format!(
132                    "[print] view='{}': expected \"table\" or \"lines\"",
133                    other
134                ));
135            }
136        };
137        let force_fail = get_bool(map, "force-fail")?.unwrap_or(true);
138        let fail_only = get_bool(map, "fail-only")?.unwrap_or(false);
139        let show_expected = get_bool(map, "show-expected")?.unwrap_or(false);
140        let filter_str = get_string(map, "filter")?.unwrap_or_default();
141        let filter = if filter_str.is_empty() {
142            Vec::new()
143        } else {
144            parse_tensor_filter(&filter_str)
145                .map_err(|e| format!("[print] filter='{}': {}", filter_str, e))?
146        };
147
148        cfg.print = PrintSection {
149            enabled,
150            view,
151            force_fail,
152            fail_only,
153            show_expected,
154            filter,
155        };
156
157        reject_unknown_keys(
158            "print",
159            map,
160            &[
161                "enabled",
162                "view",
163                "force-fail",
164                "fail-only",
165                "show-expected",
166                "filter",
167            ],
168        )?;
169    }
170
171    for sec in sections.keys() {
172        if sec != "test" && sec != "print" {
173            return Err(format!("unknown section [{}]", sec));
174        }
175    }
176
177    Ok(cfg)
178}
179
180fn parse_sections(text: &str) -> Result<Sections, String> {
181    let mut sections: Sections = HashMap::new();
182    let mut current: Option<String> = None;
183
184    for (line_no, raw) in text.lines().enumerate() {
185        let line = strip_comment(raw).trim();
186        if line.is_empty() {
187            continue;
188        }
189
190        if let Some(rest) = line.strip_prefix('[')
191            && let Some(name) = rest.strip_suffix(']')
192        {
193            let name = name.trim();
194            if name.is_empty() || name.contains('.') {
195                return Err(format!(
196                    "line {}: section name '[{}]' must be a single identifier",
197                    line_no + 1,
198                    name
199                ));
200            }
201            sections.entry(name.to_string()).or_default();
202            current = Some(name.to_string());
203            continue;
204        }
205
206        let Some(section) = current.as_ref() else {
207            return Err(format!(
208                "line {}: key '{}' before any [section]",
209                line_no + 1,
210                line
211            ));
212        };
213
214        let Some((k, v)) = line.split_once('=') else {
215            return Err(format!(
216                "line {}: expected `key = value`, got '{}'",
217                line_no + 1,
218                line
219            ));
220        };
221        let key = k.trim().to_string();
222        let val = v.trim().to_string();
223        sections.get_mut(section).unwrap().insert(key, val);
224    }
225
226    Ok(sections)
227}
228
/// Returns `line` with any trailing `#` comment removed.
///
/// A `#` starts a comment unless it sits between a pair of double quotes, so
/// `key = "a#b" # note` keeps the quoted `#`. Quote tracking is a plain
/// toggle on `"`: escape sequences are not interpreted, so a string that
/// contains an escaped quote is not supported. That is enough for
/// `cubek.toml`.
fn strip_comment(line: &str) -> &str {
    let mut in_string = false;
    let comment_start = line.bytes().position(|b| {
        if b == b'"' {
            in_string = !in_string;
        }
        b == b'#' && !in_string
    });
    // `#` is ASCII, so its byte offset is always a valid `str` boundary.
    comment_start.map_or(line, |at| &line[..at])
}

244fn get_string(map: &HashMap<String, String>, key: &str) -> Result<Option<String>, String> {
245    let Some(raw) = map.get(key) else {
246        return Ok(None);
247    };
248    let s = unquote(raw)
249        .ok_or_else(|| format!("key '{}' must be a quoted string, got `{}`", key, raw))?;
250    Ok(Some(s))
251}
252
/// Looks up `key` and interprets its value as a TOML boolean.
///
/// Returns `Ok(None)` when the key is absent.
///
/// # Errors
///
/// Returns a message naming the key and the raw text when the value is
/// anything other than the bare words `true` or `false`.
fn get_bool(map: &HashMap<String, String>, key: &str) -> Result<Option<bool>, String> {
    map.get(key)
        .map(|raw| match raw.as_str() {
            "true" => Ok(true),
            "false" => Ok(false),
            other => Err(format!(
                "key '{}' must be `true` or `false`, got `{}`",
                key, other
            )),
        })
        .transpose()
}

/// Removes the surrounding double quotes from a simple quoted string.
///
/// Returns `None` when `s` is not wrapped in a pair of `"`. It also returns
/// `None` when the text between the quotes contains a `"` or a `\`: this
/// parser does not implement TOML escape sequences, so such a value is
/// rejected rather than passed through uninterpreted.
fn unquote(s: &str) -> Option<String> {
    let inner = s.strip_prefix('"')?.strip_suffix('"')?;
    if inner.contains(|c: char| c == '"' || c == '\\') {
        return None;
    }
    Some(inner.to_string())
}

/// Fails if `map` contains a key that is not listed in `known`.
///
/// `section` is only used to label the error message.
///
/// # Errors
///
/// Returns a message naming the offending key and listing the accepted ones.
/// When several unknown keys are present, the one that sorts first is
/// reported, so the message does not depend on hash-map iteration order.
fn reject_unknown_keys(
    section: &str,
    map: &HashMap<String, String>,
    known: &[&str],
) -> Result<(), String> {
    let first_unknown = map
        .keys()
        .map(String::as_str)
        .filter(|k| !known.contains(k))
        .min();

    match first_unknown {
        None => Ok(()),
        Some(k) => Err(format!(
            "[{}] unknown key '{}'. Known: {}",
            section,
            k,
            known.join(", ")
        )),
    }
}

// ---------- unit tests for the parser ----------

#[cfg(test)]
mod tests {
    use super::*;

    /// Every documented key set to a non-default value is picked up.
    #[test]
    fn parses_full_example() {
        let text = r#"
[test]
policy = "strict"

[print]
enabled = true
view = "lines"
force-fail = false
fail-only = true
show-expected = true
filter = "0,1-2"
"#;
        let cfg = parse_cubek_toml(text).unwrap();
        assert_eq!(cfg.test.policy, TestPolicy::Strict);
        assert!(cfg.print.enabled);
        assert_eq!(cfg.print.view, PrintView::Lines);
        assert!(!cfg.print.force_fail);
        assert!(cfg.print.fail_only);
        assert!(cfg.print.show_expected);
        assert_eq!(cfg.print.filter.len(), 2);
    }

    /// An empty file yields exactly the built-in defaults.
    #[test]
    fn empty_file_gives_defaults() {
        let cfg = parse_cubek_toml("").unwrap();
        assert_eq!(cfg.test.policy, TestPolicy::Correct);
        assert!(!cfg.print.enabled);
        assert_eq!(cfg.print.view, PrintView::Table);
        assert!(cfg.print.force_fail);
        assert!(!cfg.print.fail_only);
        assert!(!cfg.print.show_expected);
        assert!(cfg.print.filter.is_empty());
    }

    /// A section header without keys leaves that section at its defaults.
    #[test]
    fn empty_print_section_gives_defaults() {
        let cfg = parse_cubek_toml("[print]\n").unwrap();
        assert!(!cfg.print.enabled);
        assert_eq!(cfg.print.view, PrintView::Table);
        assert!(cfg.print.force_fail);
        assert!(cfg.print.filter.is_empty());
    }

    /// Full-line and trailing comments are ignored.
    #[test]
    fn ignores_comments() {
        let text = "# header\n\n[test] # section\npolicy = \"strict\" # inline\n";
        let cfg = parse_cubek_toml(text).unwrap();
        assert_eq!(cfg.test.policy, TestPolicy::Strict);
    }

    #[test]
    fn rejects_unknown_section() {
        let err = parse_cubek_toml("[bogus]\nx=1\n").unwrap_err();
        assert!(err.contains("unknown section"), "{}", err);
    }

    #[test]
    fn rejects_unknown_key() {
        let err = parse_cubek_toml("[print]\nbogus = true\n").unwrap_err();
        assert!(err.contains("unknown key"), "{}", err);
    }

    #[test]
    fn rejects_bad_policy() {
        let err = parse_cubek_toml("[test]\npolicy = \"loose\"\n").unwrap_err();
        assert!(err.contains("policy"), "{}", err);
    }

    #[test]
    fn rejects_unquoted_string() {
        let err = parse_cubek_toml("[print]\nview = table\n").unwrap_err();
        assert!(err.contains("quoted string"), "{}", err);
    }

    #[test]
    fn rejects_non_boolean_flag() {
        let err = parse_cubek_toml("[print]\nenabled = yes\n").unwrap_err();
        assert!(err.contains("`true` or `false`"), "{}", err);
    }

    #[test]
    fn rejects_key_before_section() {
        let err = parse_cubek_toml("policy = \"strict\"\n").unwrap_err();
        assert!(err.contains("before any [section]"), "{}", err);
    }

    #[test]
    fn rejects_show_delta_key() {
        // We removed `show-delta`/`show-epsilon` — they should now be
        // unknown keys, since lines view always shows them and table view
        // never does.
        let err = parse_cubek_toml("[print]\nshow-delta = true\n").unwrap_err();
        assert!(err.contains("unknown key"), "{}", err);
    }
}
363}