Skip to main content

zer_schema/
config.rs

1use std::collections::HashSet;
2use std::path::Path;
3
4use regex::Regex;
5use zer_core::{error::ZerError, schema::FieldKind};
6
/// Contents of `../heuristics_name.toml`, embedded into the binary at compile
/// time. Parsed by `NameHeuristics::load_default` when no usable override file
/// is supplied through the `ZER_NAME_HEURISTICS` environment variable.
const DEFAULT_NAME_HEURISTICS: &str = include_str!("../heuristics_name.toml");
/// Contents of `../heuristics_values.toml`, embedded into the binary at compile
/// time. Parsed by `ValuePatterns::load_default` when no usable override file
/// is supplied through the `ZER_VALUE_PATTERNS` environment variable.
const DEFAULT_VALUE_PATTERNS: &str = include_str!("../heuristics_values.toml");
9
10// ── Name heuristics ───────────────────────────────────────────────────────────
11
/// A single name-matching rule mapping one or more column-name patterns to a [`FieldKind`].
///
/// A rule matches a column when at least one pattern in any of its four lists
/// matches. Matching is done against the ASCII-lowercased column name. Every
/// list is optional in the TOML and deserializes to an empty list (which never
/// matches) when omitted.
#[derive(Debug, Clone, serde::Deserialize)]
pub struct NameRule {
    /// Kind reported for a column whose name matches this rule.
    pub kind: FieldKind,
    /// Substring patterns: match when the column name contains any of them.
    #[serde(default)]
    pub contains: Vec<String>,
    /// Whole-name patterns: match when the column name equals any of them.
    #[serde(default)]
    pub exact: Vec<String>,
    /// Prefix patterns: match when the column name starts with any of them.
    #[serde(default)]
    pub starts_with: Vec<String>,
    /// Suffix patterns: match when the column name ends with any of them.
    #[serde(default)]
    pub ends_with: Vec<String>,
}
25
/// Ordered list of name-matching rules loaded from `heuristics_name.toml`.
///
/// Order is significant: `infer_kind` walks `rules` front to back and returns
/// the kind of the first rule that matches.
#[derive(Debug, Clone, serde::Deserialize)]
pub struct NameHeuristics {
    /// Rules in evaluation order; deserialized from the TOML `rules` array.
    pub rules: Vec<NameRule>,
}
31
32impl NameHeuristics {
33    /// Parse from a TOML string.
34    pub fn from_toml_str(s: &str) -> Result<Self, ZerError> {
35        toml::from_str(s).map_err(|e| ZerError::Config(e.to_string()))
36    }
37
38    /// Load from a TOML file on disk.
39    pub fn from_file(path: &Path) -> Result<Self, ZerError> {
40        let content = std::fs::read_to_string(path)?;
41        Self::from_toml_str(&content)
42    }
43
44    /// Load the default heuristics.
45    ///
46    /// Checks `ZER_NAME_HEURISTICS` env var first; if set and loadable, uses
47    /// that file. Otherwise falls back to the embedded `heuristics_name.toml`.
48    pub fn load_default() -> Self {
49        if let Ok(path) = std::env::var("ZER_NAME_HEURISTICS") {
50            match Self::from_file(Path::new(&path)) {
51                Ok(h) => return h,
52                Err(e) => tracing::warn!(
53                    "ZER_NAME_HEURISTICS={path:?}: failed to load ({e}), using embedded default"
54                ),
55            }
56        }
57        Self::from_toml_str(DEFAULT_NAME_HEURISTICS)
58            .expect("embedded heuristics_name.toml is always valid")
59    }
60
61    /// Try to match a column name against the rules. Returns `None` when no
62    /// rule matches, signalling the caller to fall back to value sampling.
63    pub fn infer_kind(&self, name: &str) -> Option<FieldKind> {
64        let n = name.to_ascii_lowercase();
65        for rule in &self.rules {
66            if rule.exact.iter().any(|p| n == p.as_str())
67                || rule.contains.iter().any(|p| n.contains(p.as_str()))
68                || rule.starts_with.iter().any(|p| n.starts_with(p.as_str()))
69                || rule.ends_with.iter().any(|p| n.ends_with(p.as_str()))
70            {
71                return Some(rule.kind);
72            }
73        }
74        None
75    }
76}
77
78// ── Value patterns ────────────────────────────────────────────────────────────
79
80#[derive(Debug, serde::Deserialize)]
81struct RawValuePattern {
82    kind: FieldKind,
83    regex: String,
84    #[serde(default)]
85    threshold: f32,
86    unique_rate_min: Option<f32>,
87    unique_rate_max: Option<f32>,
88    avg_len_min: Option<f32>,
89    avg_len_max: Option<f32>,
90}
91
/// The `[fallback]` table as written in the values TOML.
#[derive(Debug, serde::Deserialize)]
struct RawFallback {
    /// Kind to report when no pattern matches; becomes
    /// `ValuePatterns::fallback_kind`.
    default_kind: FieldKind,
}
96
/// Top-level layout of the values TOML: an ordered `patterns` array plus a
/// mandatory `fallback` table. Converted into [`ValuePatterns`] by
/// `ValuePatterns::from_raw`.
#[derive(Debug, serde::Deserialize)]
struct RawValuePatterns {
    /// Pattern entries in file order (not yet compiled).
    patterns: Vec<RawValuePattern>,
    /// Kind to use when none of `patterns` matches.
    fallback: RawFallback,
}
102
/// A value-sampling pattern with its regex pre-compiled.
///
/// A pattern matches a sample set when the regex match fraction reaches
/// `threshold` and every bound that is present is satisfied; bounds that are
/// `None` are not checked.
#[derive(Debug)]
pub struct CompiledValuePattern {
    /// Kind reported when this pattern matches.
    pub kind: FieldKind,
    /// `None` when the pattern has no regex (purely statistical conditions).
    /// In that case the match fraction is taken to be `1.0`.
    pub regex: Option<Regex>,
    /// Minimum fraction of samples that must match `regex` (compared with `>=`).
    pub threshold: f32,
    /// Inclusive lower bound on the fraction of distinct samples.
    pub unique_rate_min: Option<f32>,
    /// Inclusive upper bound on the fraction of distinct samples.
    pub unique_rate_max: Option<f32>,
    /// Inclusive lower bound on the average sample length.
    pub avg_len_min: Option<f32>,
    /// Inclusive upper bound on the average sample length.
    pub avg_len_max: Option<f32>,
}
115
/// Ordered list of value-sampling patterns loaded from `heuristics_values.toml`.
///
/// Order is significant: `infer_kind` evaluates `patterns` front to back and
/// returns the kind of the first one that matches.
#[derive(Debug)]
pub struct ValuePatterns {
    /// Compiled patterns in evaluation order.
    pub patterns: Vec<CompiledValuePattern>,
    /// Kind returned when the sample set is empty or no pattern matches.
    pub fallback_kind: FieldKind,
}
122
123impl ValuePatterns {
124    fn from_raw(raw: RawValuePatterns) -> Result<Self, ZerError> {
125        let mut patterns = Vec::with_capacity(raw.patterns.len());
126        for p in raw.patterns {
127            let regex = if p.regex.is_empty() {
128                None
129            } else {
130                Some(Regex::new(&p.regex).map_err(|e| {
131                    ZerError::Config(format!("invalid regex {:?}: {e}", p.regex))
132                })?)
133            };
134            patterns.push(CompiledValuePattern {
135                kind: p.kind,
136                regex,
137                threshold: p.threshold,
138                unique_rate_min: p.unique_rate_min,
139                unique_rate_max: p.unique_rate_max,
140                avg_len_min: p.avg_len_min,
141                avg_len_max: p.avg_len_max,
142            });
143        }
144        Ok(Self { patterns, fallback_kind: raw.fallback.default_kind })
145    }
146
147    /// Parse from a TOML string. Returns `Err` if any regex is invalid.
148    pub fn from_toml_str(s: &str) -> Result<Self, ZerError> {
149        let raw: RawValuePatterns =
150            toml::from_str(s).map_err(|e| ZerError::Config(e.to_string()))?;
151        Self::from_raw(raw)
152    }
153
154    /// Load from a TOML file on disk.
155    pub fn from_file(path: &Path) -> Result<Self, ZerError> {
156        let content = std::fs::read_to_string(path)?;
157        Self::from_toml_str(&content)
158    }
159
160    /// Load the default patterns.
161    ///
162    /// Checks `ZER_VALUE_PATTERNS` env var first; if set and loadable, uses
163    /// that file. Otherwise falls back to the embedded `heuristics_values.toml`.
164    pub fn load_default() -> Self {
165        if let Ok(path) = std::env::var("ZER_VALUE_PATTERNS") {
166            match Self::from_file(Path::new(&path)) {
167                Ok(p) => return p,
168                Err(e) => tracing::warn!(
169                    "ZER_VALUE_PATTERNS={path:?}: failed to load ({e}), using embedded default"
170                ),
171            }
172        }
173        Self::from_toml_str(DEFAULT_VALUE_PATTERNS)
174            .expect("embedded heuristics_values.toml is always valid")
175    }
176
177    /// Infer a [`FieldKind`] from a slice of sampled text values.
178    ///
179    /// Evaluates patterns in order; returns the first match. Falls back to
180    /// `fallback_kind` (typically `FreeText`) when nothing matches.
181    pub fn infer_kind(&self, samples: &[&str]) -> FieldKind {
182        if samples.is_empty() {
183            return self.fallback_kind;
184        }
185        let total = samples.len() as f32;
186        let unique_rate = samples.iter().collect::<HashSet<_>>().len() as f32 / total;
187        let avg_len = samples.iter().map(|s| s.len() as f32).sum::<f32>() / total;
188
189        for pat in &self.patterns {
190            let match_frac = match &pat.regex {
191                Some(re) => samples.iter().filter(|s| re.is_match(s)).count() as f32 / total,
192                None => 1.0,
193            };
194            if match_frac >= pat.threshold
195                && pat.unique_rate_min.map_or(true, |min| unique_rate >= min)
196                && pat.unique_rate_max.map_or(true, |max| unique_rate <= max)
197                && pat.avg_len_max.map_or(true, |max| avg_len <= max)
198                && pat.avg_len_min.map_or(true, |min| avg_len >= min)
199            {
200                return pat.kind;
201            }
202        }
203        self.fallback_kind
204    }
205}
206
207// ── Tests ─────────────────────────────────────────────────────────────────────
208
#[cfg(test)]
mod tests {
    use super::*;

    /// Parse the embedded name heuristics directly.
    ///
    /// `NameHeuristics::load_default` consults `ZER_NAME_HEURISTICS`, so going
    /// through it would make these tests depend on the environment they run in.
    fn embedded_name_heuristics() -> NameHeuristics {
        NameHeuristics::from_toml_str(DEFAULT_NAME_HEURISTICS)
            .expect("embedded heuristics_name.toml is always valid")
    }

    /// Parse the embedded value patterns directly, bypassing the
    /// `ZER_VALUE_PATTERNS` override for the same reason.
    fn embedded_value_patterns() -> ValuePatterns {
        ValuePatterns::from_toml_str(DEFAULT_VALUE_PATTERNS)
            .expect("embedded heuristics_values.toml is always valid")
    }

    #[test]
    fn name_heuristics_embedded_default_loads() {
        let h = embedded_name_heuristics();
        assert!(!h.rules.is_empty());
    }

    #[test]
    fn name_heuristics_matches_known_patterns() {
        let h = embedded_name_heuristics();
        assert_eq!(h.infer_kind("first_name"), Some(FieldKind::Name));
        assert_eq!(h.infer_kind("geboortedatum"), Some(FieldKind::Date));
        assert_eq!(h.infer_kind("msisdn"), Some(FieldKind::Phone));
        assert_eq!(h.infer_kind("postcode"), Some(FieldKind::Address));
        assert_eq!(h.infer_kind("bsn"), Some(FieldKind::Id));
    }

    #[test]
    fn name_heuristics_returns_none_for_unknown() {
        let h = embedded_name_heuristics();
        assert_eq!(h.infer_kind("xyzzy_col"), None);
    }

    #[test]
    fn value_patterns_embedded_default_loads() {
        let p = embedded_value_patterns();
        assert!(!p.patterns.is_empty());
    }

    #[test]
    fn value_patterns_date_detection() {
        let p = embedded_value_patterns();
        let samples: Vec<&str> = (0..20).map(|_| "2024-03-15").collect();
        assert_eq!(p.infer_kind(&samples), FieldKind::Date);
    }

    #[test]
    fn value_patterns_fallback_on_empty() {
        let p = embedded_value_patterns();
        assert_eq!(p.infer_kind(&[]), FieldKind::FreeText);
    }

    #[test]
    fn custom_name_heuristics_from_file() {
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("custom_name.toml");
        std::fs::write(
            &path,
            r#"
[[rules]]
kind  = "Id"
exact = ["mijnkolom"]
"#,
        )
        .unwrap();

        let h = NameHeuristics::from_file(&path).unwrap();
        assert_eq!(h.infer_kind("mijnkolom"), Some(FieldKind::Id));
        assert_eq!(h.infer_kind("other"), None);
    }

    #[test]
    fn custom_value_patterns_from_file() {
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("custom_values.toml");
        std::fs::write(
            &path,
            r#"
[[patterns]]
kind      = "Phone"
regex     = '^\+31\d{9}$'
threshold = 0.8

[fallback]
default_kind = "FreeText"
"#,
        )
        .unwrap();

        let p = ValuePatterns::from_file(&path).unwrap();
        let samples: Vec<&str> = (0..20).map(|_| "+31612345678").collect();
        assert_eq!(p.infer_kind(&samples), FieldKind::Phone);
    }

    #[test]
    fn invalid_toml_returns_error() {
        let result = NameHeuristics::from_toml_str("this is not toml ][");
        assert!(matches!(result, Err(ZerError::Config(_))));
    }

    #[test]
    fn invalid_regex_returns_error() {
        let result = ValuePatterns::from_toml_str(
            r#"
[[patterns]]
kind      = "Date"
regex     = '[invalid'
threshold = 0.8

[fallback]
default_kind = "FreeText"
"#,
        );
        assert!(matches!(result, Err(ZerError::Config(_))));
    }
}