Skip to main content

nxs/convert/
infer.rs

1//! Two-pass streaming sigil inference.
2//!
3//! Pass 1: iterate source records, keep a per-key lattice state. Pass 2 is
4//! driven by the caller (json_in/csv_in/xml_in) using the frozen schema.
5//!
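//! Priority lattice (from spec inference_rules.priority):
//!   = (int)    > ~ (float) > ? (bool) > @ (time) > < (hex) > ^ (null) > " (string)
//!
//! In this module the order is the per-value classification order: an empty
//! value counts as null (^); any other value takes the first of int, float,
//! bool, time, hex, string that it matches. Types are not widened across
//! values — a key that sees more than one non-null type is a conflict, settled
//! by `ConflictPolicy`.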
8//!
9//! Tie-breaks (spec § tie_breaks):
10//!   - 0/1 columns → int (not bool); int rule fires first
11//!   - hex requires length ≥ 16, even length, all hex chars
12//!   - fallback is always string
13
14use super::{ConflictPolicy, InferredKey, InferredSchema};
15use crate::consts::{SIGIL_BINARY, SIGIL_STR};
16pub use crate::consts::{SIGIL_BOOL, SIGIL_FLOAT, SIGIL_INT, SIGIL_NULL, SIGIL_TIME};
17use crate::error::{NxsError, Result};
18
/// Sigil assigned to values that look like hex-encoded binary; an alias of `SIGIL_BINARY`.
pub const SIGIL_HEX: u8 = SIGIL_BINARY;
/// Sigil assigned to plain string values (the fallback class); an alias of `SIGIL_STR`.
pub const SIGIL_STRING: u8 = SIGIL_STR;
21
/// Per-key state maintained during pass 1.
///
/// Each `seen_*` flag records whether at least one observed value was
/// classified as that type; `KeyState::resolve_sigil` collapses the flags into
/// a single sigil once pass 1 is done.
#[derive(Debug, Default, Clone, PartialEq, Eq)]
pub struct KeyState {
    /// At least one value parsed as `i64`.
    pub seen_int: bool,
    /// At least one value parsed as `f64` but not as `i64`.
    pub seen_float: bool,
    /// At least one value was exactly `true` or `false`.
    pub seen_bool: bool,
    /// At least one value passed the date/datetime heuristic.
    pub seen_time: bool,
    /// At least one value passed the hex-encoded-binary heuristic.
    pub seen_binary_hex: bool,
    /// At least one non-empty value matched none of the typed classes above.
    pub seen_string: bool,
    /// At least one value was empty (treated as null).
    pub seen_null: bool,
    /// Bumped once per `observe` call, empty values included.
    pub total_records_seen_in: usize,
    /// Records in which this key was present (non-null).
    pub present_count: usize,
    /// Sigil from the very first non-null observation (for `FirstWins` policy).
    pub first_sigil: Option<u8>,
}
38
39impl KeyState {
40    /// Classify a raw string observation and merge into `self`.
41    pub fn observe(&mut self, raw: &str) {
42        self.total_records_seen_in += 1;
43        if raw.is_empty() {
44            self.seen_null = true;
45            return;
46        }
47        self.present_count += 1;
48
49        // int: parses as i64. Fires first so 0/1 stay int, not bool.
50        if raw.parse::<i64>().is_ok() {
51            self.seen_int = true;
52            self.first_sigil.get_or_insert(SIGIL_INT);
53            return;
54        }
55        // float: parses as f64 (and is not a pure int)
56        if raw.parse::<f64>().is_ok() {
57            self.seen_float = true;
58            self.first_sigil.get_or_insert(SIGIL_FLOAT);
59            return;
60        }
61        // bool: exactly true/false
62        if raw == "true" || raw == "false" {
63            self.seen_bool = true;
64            self.first_sigil.get_or_insert(SIGIL_BOOL);
65            return;
66        }
67        // time: contains '-' or 'T' and passes basic date/datetime heuristic
68        if is_time_like(raw) {
69            self.seen_time = true;
70            self.first_sigil.get_or_insert(SIGIL_TIME);
71            return;
72        }
73        // hex: length ≥ 16, even, all hex chars
74        if is_hex_like(raw) {
75            self.seen_binary_hex = true;
76            self.first_sigil.get_or_insert(SIGIL_HEX);
77            return;
78        }
79        self.seen_string = true;
80        self.first_sigil.get_or_insert(SIGIL_STRING);
81    }
82
83    /// Collapse accumulated flags to a single sigil byte per plan priority.
84    pub fn resolve_sigil(&self, policy: ConflictPolicy) -> Result<u8> {
85        // Count how many distinct (non-null) types were observed.
86        // String is itself a type for conflict-detection purposes.
87        let type_count = [
88            self.seen_int,
89            self.seen_float,
90            self.seen_bool,
91            self.seen_time,
92            self.seen_binary_hex,
93            self.seen_string,
94        ]
95        .iter()
96        .filter(|&&b| b)
97        .count();
98
99        if type_count > 1 {
100            return match policy {
101                ConflictPolicy::Error => Err(NxsError::ConvertSchemaConflict(
102                    "mixed types observed for key".into(),
103                )),
104                ConflictPolicy::CoerceString => Ok(SIGIL_STRING),
105                ConflictPolicy::FirstWins => {
106                    // Use the sigil from the very first non-null observation.
107                    Ok(self.first_sigil.unwrap_or(SIGIL_STRING))
108                }
109            };
110        }
111
112        // Single type — no conflict.
113        if self.seen_string {
114            return Ok(SIGIL_STRING);
115        }
116
117        if self.seen_int {
118            return Ok(SIGIL_INT);
119        }
120        if self.seen_float {
121            return Ok(SIGIL_FLOAT);
122        }
123        if self.seen_bool {
124            return Ok(SIGIL_BOOL);
125        }
126        if self.seen_time {
127            return Ok(SIGIL_TIME);
128        }
129        if self.seen_binary_hex {
130            return Ok(SIGIL_HEX);
131        }
132        // All null/missing
133        Ok(SIGIL_NULL)
134    }
135}
136
/// Heuristic check for ISO-8601-style date / datetime strings.
///
/// Accepted shapes:
/// * an extended date `YYYY-MM-DD`, alone, or followed by a zone designator
///   (`Z`, or `+`/`-` then digits and colons), or followed by a time part;
/// * a basic date `YYYYMMDD` followed by a time part.
///
/// A time part is `T`, then a digit, then only digits and `- : Z + .`.
/// Month and day ranges are not validated.
///
/// Anchoring on a leading 4-digit year keeps phone numbers (`555-1234`),
/// ZIP+4 codes (`12345-6789`), SSNs (`123-45-6789`) and similar
/// dash-separated digit runs out of the time class.
fn is_time_like(s: &str) -> bool {
    let b = s.as_bytes();
    let all_digits = |r: &[u8]| r.iter().all(u8::is_ascii_digit);

    let extended_date = b.len() >= 10
        && all_digits(&b[..4])
        && b[4] == b'-'
        && all_digits(&b[5..7])
        && b[7] == b'-'
        && all_digits(&b[8..10]);
    let date_len = if extended_date {
        10
    } else if b.len() >= 8 && all_digits(&b[..8]) {
        8
    } else {
        return false;
    };

    match b[date_len..].split_first() {
        // A bare `YYYYMMDD` is indistinguishable from an integer; only the
        // extended form may stand alone.
        None => extended_date,
        Some((&b'T', time)) => {
            matches!(time.first(), Some(c) if c.is_ascii_digit())
                && time
                    .iter()
                    .all(|&c| c.is_ascii_digit() || matches!(c, b'-' | b':' | b'Z' | b'+' | b'.'))
        }
        // Date with a UTC designator only, e.g. `2026-04-30Z`.
        Some((&b'Z', rest)) => extended_date && rest.is_empty(),
        // Date with a numeric offset only, e.g. `2026-04-30+02:00`.
        Some((&sign, offset)) if sign == b'+' || sign == b'-' => {
            extended_date
                && !offset.is_empty()
                && offset.iter().all(|&c| c.is_ascii_digit() || c == b':')
        }
        Some(_) => false,
    }
}
148
/// Report whether `s` looks like hex-encoded binary: at least 16 bytes long,
/// an even length (whole bytes), and made only of ASCII hex digits in either
/// case.
fn is_hex_like(s: &str) -> bool {
    let len = s.len();
    let long_enough = len >= 16;
    let whole_bytes = len % 2 == 0;
    long_enough && whole_bytes && s.bytes().all(|b| b.is_ascii_hexdigit())
}
152
153/// Merge a per-record set of observations into the accumulator.
154pub fn merge(acc: &mut InferredSchema, record: &[(String, String)]) {
155    for (key, value) in record {
156        let entry = acc.keys.iter().position(|k| &k.name == key);
157        if let Some(i) = entry {
158            if let Some(ks) = acc.key_states.get_mut(i) {
159                ks.observe(value);
160            }
161        } else {
162            let mut ks = KeyState::default();
163            ks.observe(value);
164            acc.keys.push(InferredKey {
165                name: key.clone(),
166                sigil: 0,
167                optional: false,
168                list_of: None,
169            });
170            acc.key_states.push(ks);
171        }
172    }
173    acc.total_records += 1;
174}
175
176/// Freeze the accumulator into a schema ready to drive `NxsWriter`.
177pub fn finalize(mut acc: InferredSchema, policy: ConflictPolicy) -> Result<InferredSchema> {
178    for (key, state) in acc.keys.iter_mut().zip(acc.key_states.iter()) {
179        key.sigil = state.resolve_sigil(policy)?;
180        key.optional = state.present_count < acc.total_records;
181    }
182    Ok(acc)
183}
184
#[cfg(test)]
mod tests {
    use super::*;
    use crate::convert::ConflictPolicy;

    /// Build a `KeyState` by observing each value in order.
    fn observe_all(values: &[&str]) -> KeyState {
        values.iter().fold(KeyState::default(), |mut ks, v| {
            ks.observe(v);
            ks
        })
    }

    /// Observe `values` and resolve the resulting state under `policy`.
    fn resolve(values: &[&str], policy: ConflictPolicy) -> Result<u8> {
        observe_all(values).resolve_sigil(policy)
    }

    #[test]
    fn test_infer_priority_order() {
        // Single-type columns resolve to their own sigil.
        assert_eq!(resolve(&["1", "2", "3"], ConflictPolicy::Error).unwrap(), SIGIL_INT);
        assert_eq!(
            resolve(&["true", "false", "true"], ConflictPolicy::Error).unwrap(),
            SIGIL_BOOL
        );
        assert_eq!(
            resolve(&["2026-04-30", "2025-01-01"], ConflictPolicy::Error).unwrap(),
            SIGIL_TIME
        );
        assert_eq!(
            resolve(&["deadbeefcafe0001", "0123456789abcdef"], ConflictPolicy::Error).unwrap(),
            SIGIL_HEX
        );

        // Nothing but nulls resolves to the null sigil.
        assert_eq!(resolve(&["", ""], ConflictPolicy::Error).unwrap(), SIGIL_NULL);

        // Mixed types are conflicts rather than widenings: int + float and
        // int + string both coerce to string ...
        assert_eq!(
            resolve(&["1", "2.5"], ConflictPolicy::CoerceString).unwrap(),
            SIGIL_STRING
        );
        assert_eq!(
            resolve(&["1", "hello"], ConflictPolicy::CoerceString).unwrap(),
            SIGIL_STRING
        );
        // ... and "0" is int (not bool), so bool + "0" fails under Error.
        assert!(resolve(&["true", "0"], ConflictPolicy::Error).is_err());
    }

    #[test]
    fn test_infer_missing_keys_marked_optional() {
        let mut acc = InferredSchema::default();
        // The first record carries "email"; the second omits it, so only the
        // record counter advances.
        merge(&mut acc, &[("email".to_string(), "a@b.com".to_string())]);
        acc.total_records += 1;

        let schema = finalize(acc, ConflictPolicy::Error).unwrap();
        let email = schema
            .keys
            .iter()
            .find(|k| k.name == "email")
            .expect("email key should have been inferred");
        assert!(email.optional, "key absent in one record must be optional");
    }

    #[test]
    fn test_infer_on_conflict_coerce_string() {
        let sigil = resolve(&["1", "hello"], ConflictPolicy::CoerceString).unwrap();
        assert_eq!(sigil, SIGIL_STRING);
    }

    #[test]
    fn test_infer_on_conflict_error() {
        let err = resolve(&["1", "hello"], ConflictPolicy::Error).unwrap_err();
        assert!(matches!(err, NxsError::ConvertSchemaConflict(_)));
    }

    #[test]
    fn test_infer_first_wins_returns_first_observed_sigil() {
        // int first, then string: int wins.
        assert_eq!(
            resolve(&["1", "hello"], ConflictPolicy::FirstWins).unwrap(),
            SIGIL_INT,
            "FirstWins: first-seen type (int) must win"
        );

        // string first, then int: string wins.
        assert_eq!(
            resolve(&["hello", "1"], ConflictPolicy::FirstWins).unwrap(),
            SIGIL_STRING,
            "FirstWins: first-seen type (string) must win"
        );

        // A leading null must not claim `first_sigil`.
        assert_eq!(
            resolve(&["", "42", "abc"], ConflictPolicy::FirstWins).unwrap(),
            SIGIL_INT,
            "FirstWins: null observations must not pollute first_sigil"
        );
    }
}