1use std::collections::HashMap;
14use std::path::PathBuf;
15use std::sync::OnceLock;
16
17use crate::correctness::{TensorFilter, parse_tensor_filter};
18
19#[derive(Clone, Debug)]
21pub struct CubekConfig {
22 pub test: TestSection,
23 pub print: PrintSection,
24}
25
26#[derive(Clone, Debug)]
27pub struct TestSection {
28 pub policy: TestPolicy,
29}
30
31#[derive(Copy, Clone, Debug, PartialEq, Eq)]
32pub enum TestPolicy {
33 Correct,
34 Strict,
35 FailIfRun,
36}
37
38#[derive(Clone, Debug)]
39pub struct PrintSection {
40 pub enabled: bool,
41 pub view: PrintView,
42 pub force_fail: bool,
43 pub fail_only: bool,
44 pub show_expected: bool,
45 pub filter: TensorFilter,
46}
47
48#[derive(Copy, Clone, Debug, PartialEq, Eq)]
49pub enum PrintView {
50 Table,
51 Lines,
52}
53
54impl Default for CubekConfig {
55 fn default() -> Self {
56 Self {
57 test: TestSection {
58 policy: TestPolicy::Correct,
59 },
60 print: PrintSection {
61 enabled: false,
62 view: PrintView::Table,
63 force_fail: true,
64 fail_only: false,
65 show_expected: false,
66 filter: Vec::new(),
67 },
68 }
69 }
70}
71
72pub fn config() -> &'static CubekConfig {
74 static CACHE: OnceLock<CubekConfig> = OnceLock::new();
75 CACHE.get_or_init(load_config)
76}
77
78fn load_config() -> CubekConfig {
79 let Some(path) = find_cubek_toml() else {
80 return CubekConfig::default();
81 };
82 let text = std::fs::read_to_string(&path)
83 .unwrap_or_else(|e| panic!("cannot read {}: {e}", path.display()));
84 parse_cubek_toml(&text)
85 .unwrap_or_else(|e| panic!("invalid cubek.toml ({}): {e}", path.display()))
86}
87
/// Searches the current directory and then each of its ancestors for a `cubek.toml` file.
///
/// Returns the first match, or `None` if no ancestor has one or the current directory cannot
/// be determined.
fn find_cubek_toml() -> Option<PathBuf> {
    let start = std::env::current_dir().ok()?;
    start
        .ancestors()
        .map(|dir| dir.join("cubek.toml"))
        .find(|candidate| candidate.is_file())
}
100
/// Raw result of `parse_sections`: section name -> (key -> trimmed value text, quotes kept).
type Sections = HashMap<String, HashMap<String, String>>;
104
105fn parse_cubek_toml(text: &str) -> Result<CubekConfig, String> {
106 let sections = parse_sections(text)?;
107
108 let mut cfg = CubekConfig::default();
109
110 if let Some(map) = sections.get("test") {
111 cfg.test.policy = match get_string(map, "policy")?.as_deref() {
112 None | Some("correct") => TestPolicy::Correct,
113 Some("strict") => TestPolicy::Strict,
114 Some("fail-if-run") => TestPolicy::FailIfRun,
115 Some(other) => {
116 return Err(format!(
117 "[test] policy='{}': expected one of \"correct\", \"strict\", \"fail-if-run\"",
118 other
119 ));
120 }
121 };
122 reject_unknown_keys("test", map, &["policy"])?;
123 }
124
125 if let Some(map) = sections.get("print") {
126 let enabled = get_bool(map, "enabled")?.unwrap_or(false);
127 let view = match get_string(map, "view")?.as_deref() {
128 None | Some("table") => PrintView::Table,
129 Some("lines") => PrintView::Lines,
130 Some(other) => {
131 return Err(format!(
132 "[print] view='{}': expected \"table\" or \"lines\"",
133 other
134 ));
135 }
136 };
137 let force_fail = get_bool(map, "force-fail")?.unwrap_or(true);
138 let fail_only = get_bool(map, "fail-only")?.unwrap_or(false);
139 let show_expected = get_bool(map, "show-expected")?.unwrap_or(false);
140 let filter_str = get_string(map, "filter")?.unwrap_or_default();
141 let filter = if filter_str.is_empty() {
142 Vec::new()
143 } else {
144 parse_tensor_filter(&filter_str)
145 .map_err(|e| format!("[print] filter='{}': {}", filter_str, e))?
146 };
147
148 cfg.print = PrintSection {
149 enabled,
150 view,
151 force_fail,
152 fail_only,
153 show_expected,
154 filter,
155 };
156
157 reject_unknown_keys(
158 "print",
159 map,
160 &[
161 "enabled",
162 "view",
163 "force-fail",
164 "fail-only",
165 "show-expected",
166 "filter",
167 ],
168 )?;
169 }
170
171 for sec in sections.keys() {
172 if sec != "test" && sec != "print" {
173 return Err(format!("unknown section [{}]", sec));
174 }
175 }
176
177 Ok(cfg)
178}
179
180fn parse_sections(text: &str) -> Result<Sections, String> {
181 let mut sections: Sections = HashMap::new();
182 let mut current: Option<String> = None;
183
184 for (line_no, raw) in text.lines().enumerate() {
185 let line = strip_comment(raw).trim();
186 if line.is_empty() {
187 continue;
188 }
189
190 if let Some(rest) = line.strip_prefix('[')
191 && let Some(name) = rest.strip_suffix(']')
192 {
193 let name = name.trim();
194 if name.is_empty() || name.contains('.') {
195 return Err(format!(
196 "line {}: section name '[{}]' must be a single identifier",
197 line_no + 1,
198 name
199 ));
200 }
201 sections.entry(name.to_string()).or_default();
202 current = Some(name.to_string());
203 continue;
204 }
205
206 let Some(section) = current.as_ref() else {
207 return Err(format!(
208 "line {}: key '{}' before any [section]",
209 line_no + 1,
210 line
211 ));
212 };
213
214 let Some((k, v)) = line.split_once('=') else {
215 return Err(format!(
216 "line {}: expected `key = value`, got '{}'",
217 line_no + 1,
218 line
219 ));
220 };
221 let key = k.trim().to_string();
222 let val = v.trim().to_string();
223 sections.get_mut(section).unwrap().insert(key, val);
224 }
225
226 Ok(sections)
227}
228
/// Returns `line` cut at the first `#` that is not inside a double-quoted span.
///
/// Quote tracking simply toggles on every `"` byte; backslash escapes are not recognised.
/// Slicing at the `#` is always on a char boundary because `#` is ASCII.
fn strip_comment(line: &str) -> &str {
    let mut quoted = false;
    let cut = line.bytes().position(|byte| {
        if byte == b'"' {
            quoted = !quoted;
        }
        byte == b'#' && !quoted
    });
    match cut {
        Some(end) => &line[..end],
        None => line,
    }
}
243
244fn get_string(map: &HashMap<String, String>, key: &str) -> Result<Option<String>, String> {
245 let Some(raw) = map.get(key) else {
246 return Ok(None);
247 };
248 let s = unquote(raw)
249 .ok_or_else(|| format!("key '{}' must be a quoted string, got `{}`", key, raw))?;
250 Ok(Some(s))
251}
252
/// Looks up `key` and interprets it as a bare boolean.
///
/// Returns `Ok(None)` when the key is absent. Only the exact spellings `true` and `false` are
/// accepted, which is exactly what `bool`'s `FromStr` implementation accepts.
fn get_bool(map: &HashMap<String, String>, key: &str) -> Result<Option<bool>, String> {
    map.get(key)
        .map(|raw| {
            raw.parse::<bool>().map_err(|_| {
                format!("key '{}' must be `true` or `false`, got `{}`", key, raw)
            })
        })
        .transpose()
}
266
/// Returns the text between a leading and a trailing `"`, or `None` if `s` is not wrapped in
/// double quotes.
///
/// A lone `"` is not a quoted string: after the opening quote is removed there is nothing left
/// to match the closing one. Escape sequences are not interpreted.
fn unquote(s: &str) -> Option<String> {
    s.strip_prefix('"')
        .and_then(|rest| rest.strip_suffix('"'))
        .map(String::from)
}
274
/// Fails if `map` (the raw keys of section `[section]`) contains a key not listed in `known`.
///
/// All offending keys are reported in sorted order, so the message is stable regardless of
/// `HashMap` iteration order.
///
/// # Errors
///
/// `[<section>] unknown key '<k>'. Known: <known, ...>` for one unknown key, or
/// `[<section>] unknown keys '<k1>', '<k2>'. Known: <known, ...>` for several.
fn reject_unknown_keys(
    section: &str,
    map: &HashMap<String, String>,
    known: &[&str],
) -> Result<(), String> {
    let mut unknown: Vec<&str> = map
        .keys()
        .map(String::as_str)
        .filter(|key| !known.contains(key))
        .collect();
    if unknown.is_empty() {
        return Ok(());
    }
    unknown.sort_unstable();
    let noun = if unknown.len() == 1 { "key" } else { "keys" };
    Err(format!(
        "[{}] unknown {} '{}'. Known: {}",
        section,
        noun,
        unknown.join("', '"),
        known.join(", ")
    ))
}
292
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses `text`, panicking with the parser's message if it is rejected.
    fn parse_ok(text: &str) -> CubekConfig {
        parse_cubek_toml(text).unwrap()
    }

    /// Parses `text` and returns the rejection message, panicking if it was accepted.
    fn parse_err(text: &str) -> String {
        parse_cubek_toml(text).unwrap_err()
    }

    /// Every recognised key, each set away from its default.
    #[test]
    fn parses_full_example() {
        let text = r#"
[test]
policy = "strict"

[print]
enabled = true
view = "lines"
force-fail = false
fail-only = true
show-expected = true
filter = "0,1-2"
"#;
        let CubekConfig { test, print } = parse_ok(text);
        assert_eq!(test.policy, TestPolicy::Strict);
        assert!(print.enabled);
        assert_eq!(print.view, PrintView::Lines);
        assert!(!print.force_fail);
        assert!(print.fail_only);
        assert!(print.show_expected);
        assert_eq!(print.filter.len(), 2);
    }

    /// A file with no sections falls back to the defaults.
    #[test]
    fn empty_file_gives_defaults() {
        let CubekConfig { test, print } = parse_ok("");
        assert_eq!(test.policy, TestPolicy::Correct);
        assert!(!print.enabled);
        assert_eq!(print.view, PrintView::Table);
    }

    #[test]
    fn rejects_unknown_section() {
        let err = parse_err("[bogus]\nx=1\n");
        assert!(err.contains("unknown section"), "{}", err);
    }

    #[test]
    fn rejects_unknown_key() {
        let err = parse_err("[print]\nbogus = true\n");
        assert!(err.contains("unknown key"), "{}", err);
    }

    #[test]
    fn rejects_bad_policy() {
        let err = parse_err("[test]\npolicy = \"loose\"\n");
        assert!(err.contains("policy"), "{}", err);
    }

    #[test]
    fn rejects_unquoted_string() {
        let err = parse_err("[print]\nview = table\n");
        assert!(err.contains("quoted string"), "{}", err);
    }

    /// `show-delta` is not a recognised `[print]` key.
    #[test]
    fn rejects_show_delta_key() {
        let err = parse_err("[print]\nshow-delta = true\n");
        assert!(err.contains("unknown key"), "{}", err);
    }
}