Skip to main content

gam_data/
lib.rs

1use csv::{ReaderBuilder, StringRecord};
2use ndarray::{Array2, Axis};
3use rayon::prelude::*;
4use serde::{Deserialize, Serialize};
5use std::collections::{HashMap, HashSet};
6use std::fmt;
7use std::path::Path;
8
9// ---------------------------------------------------------------------------
10// Typed error
11// ---------------------------------------------------------------------------
12
13/// Typed error variants for the data-loading module.
14///
15/// Public entry points continue to return `Result<_, String>`; this enum is
16/// materialized at leaf sites and converted at the boundary via
17/// `From<DataError> for String` so error text remains byte-identical to the
18/// previous ad-hoc `format!(...)` output.
19#[derive(Debug, Clone)]
20pub enum DataError {
21    /// Schema/column shape disagrees with the file: row width mismatch,
22    /// requested column missing from headers, schema-declared kind violated by
23    /// a row, or an unseen categorical level encountered under
24    /// `UnseenCategoryPolicy::Error`.
25    SchemaMismatch { reason: String },
26    /// Failed to open, decode, or read structural bytes of the source
27    /// (CSV/TSV row read, parquet metadata, file extension detection, parquet
28    /// arrow-cast for string columns).
29    ParseError { reason: String },
30    /// Internal encoding bookkeeping failed: a categorical map expected by the
31    /// schema path was missing, or a level expected to be present in the
32    /// per-column inference state was not found during fix-up.
33    EncodingFailure { reason: String },
34    /// The source has no headers, no rows, or contains an empty / missing
35    /// field at a row that requires a value.
36    EmptyInput { reason: String },
37    /// A cell value cannot be used as a feature: non-finite float, null in a
38    /// numeric parquet column, or an unsupported parquet data type for the
39    /// column.
40    InvalidValue { reason: String },
41    /// A formula or call site references a column name that is not present in
42    /// the input data. Structured so the FFI boundary can raise a typed
43    /// Python exception (`gamfit.ColumnNotFoundError`) carrying the missing
44    /// name, available columns, and similarity suggestions as attributes —
45    /// not as a parsed-back-out substring of the human display text.
46    ///
47    ColumnNotFound {
48        /// The missing column name, exactly as the user wrote it.
49        name: String,
50        /// Optional role label (`"response"`, `"entry"`, `"exit"`, etc.)
51        /// supplied at the resolution site to disambiguate which slot in the
52        /// formula referenced the bad name. `None` for bare term references.
53        role: Option<String>,
54        /// All headers present in the input table at resolution time, sorted.
55        available: Vec<String>,
56        /// Cheap similarity suggestions (case-insensitive substring or
57        /// shared-prefix length ≥ 3), sorted; empty when no header is close.
58        similar: Vec<String>,
59        /// True iff the available set has exactly one entry and that entry
60        /// contains a literal tab — i.e. the user almost certainly handed gam
61        /// a TSV file under a `.csv` filename. Surfaced as a structured
62        /// boolean rather than re-parsed from prose at the boundary.
63        tsv_hint: bool,
64    },
65}
66
67impl DataError {
68    /// Build a typed `ColumnNotFound` from the column map of the resolved
69    /// dataset. Centralises the similarity / TSV-hint heuristics that the
70    /// legacy `missing_column_message` helper used to perform inline so all
71    /// callers — leaf `resolve_col*` shims and the multi-column requested-
72    /// columns aggregator — produce identical payloads.
73    pub fn column_not_found(
74        col_map: &HashMap<String, usize>,
75        name: &str,
76        role: Option<&str>,
77    ) -> Self {
78        let target_lower = name.to_lowercase();
79        let mut similar: Vec<String> = col_map
80            .keys()
81            .filter(|k| {
82                let k_lower = k.to_lowercase();
83                k_lower.contains(&target_lower)
84                    || target_lower.contains(&k_lower)
85                    || shared_prefix(&k_lower, &target_lower) >= 3
86            })
87            .cloned()
88            .collect();
89        similar.sort_unstable();
90        let mut available: Vec<String> = col_map.keys().cloned().collect();
91        available.sort_unstable();
92        let tsv_hint = available.len() == 1 && available[0].contains('\t');
93        Self::ColumnNotFound {
94            name: name.to_string(),
95            role: role.map(str::to_string),
96            available,
97            similar,
98            tsv_hint,
99        }
100    }
101}
102
103impl fmt::Display for DataError {
104    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
105        match self {
106            DataError::SchemaMismatch { reason }
107            | DataError::ParseError { reason }
108            | DataError::EncodingFailure { reason }
109            | DataError::EmptyInput { reason }
110            | DataError::InvalidValue { reason } => f.write_str(reason),
111            DataError::ColumnNotFound {
112                name,
113                role,
114                available,
115                similar,
116                tsv_hint,
117            } => {
118                let label = match role {
119                    Some(r) => format!("{r} column '{name}'"),
120                    None => format!("column '{name}'"),
121                };
122                let tsv_suffix = if *tsv_hint {
123                    " — your file appears to be tab-separated; gam expects comma-separated CSV. \
124         Replace tabs with commas, or pre-convert with `tr '\\t' ',' < file.tsv > file.csv`."
125                } else {
126                    ""
127                };
128                if similar.is_empty() {
129                    write!(
130                        f,
131                        "{label} not found in data. Available columns: [{}]{tsv_suffix}",
132                        available.join(", ")
133                    )
134                } else {
135                    write!(
136                        f,
137                        "{label} not found in data. Did you mean one of [{}]? Full list: [{}]{tsv_suffix}",
138                        similar.join(", "),
139                        available.join(", ")
140                    )
141                }
142            }
143        }
144    }
145}
146
147impl std::error::Error for DataError {}
148
149impl From<DataError> for String {
150    fn from(err: DataError) -> String {
151        err.to_string()
152    }
153}
154
155// ---------------------------------------------------------------------------
156// Public types
157// ---------------------------------------------------------------------------
158
159#[derive(Clone, Debug, Serialize, Deserialize)]
160pub struct DataSchema {
161    pub columns: Vec<SchemaColumn>,
162}
163
164#[derive(Clone, Debug, Serialize, Deserialize)]
165pub struct SchemaColumn {
166    pub name: String,
167    pub kind: ColumnKindTag,
168    #[serde(default)]
169    pub levels: Vec<String>,
170}
171
172#[derive(Clone, Copy, Debug, Serialize, Deserialize, Eq, PartialEq)]
173#[serde(rename_all = "kebab-case")]
174pub enum ColumnKindTag {
175    Continuous,
176    Binary,
177    Categorical,
178}
179
/// What a schema-guided load does with a categorical label that is not among
/// the schema's recorded levels.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum UnseenCategoryPolicy {
    /// No column gets an "unseen" code; unseen labels are rejected.
    Error,
    /// Columns named in the set encode unseen labels as one extra code equal
    /// to the number of known levels; other columns get no unseen code.
    EncodeUnknownForColumns(HashSet<String>),
}

impl UnseenCategoryPolicy {
    /// Build a policy that encodes unseen labels for `columns`.
    ///
    /// An empty set is normalised to [`UnseenCategoryPolicy::Error`], so the
    /// `EncodeUnknownForColumns` variant built here always names at least one
    /// column.
    pub fn encode_unknown_for_columns(columns: HashSet<String>) -> Self {
        if !columns.is_empty() {
            return Self::EncodeUnknownForColumns(columns);
        }
        Self::Error
    }

    /// Code to assign to an unseen label in `column_name`, given that the
    /// column has `level_count` known levels; `None` means "no unseen code".
    fn unseen_code_for(&self, column_name: &str, level_count: usize) -> Option<f64> {
        match self {
            Self::EncodeUnknownForColumns(allowed) if allowed.contains(column_name) => {
                Some(level_count as f64)
            }
            Self::EncodeUnknownForColumns(_) | Self::Error => None,
        }
    }
}
204
205#[derive(Clone, Debug)]
206pub struct EncodedDataset {
207    pub headers: Vec<String>,
208    pub values: Array2<f64>,
209    pub schema: DataSchema,
210    pub column_kinds: Vec<ColumnKindTag>,
211}
212
213impl EncodedDataset {
214    pub fn column_map(&self) -> HashMap<String, usize> {
215        self.headers
216            .iter()
217            .enumerate()
218            .map(|(index, header)| (header.clone(), index))
219            .collect()
220    }
221
222    /// Per-column finite (min, max) of the training values, parallel to
223    /// `headers`. Columns with no finite values default to (0.0, 0.0) so that
224    /// downstream clipping is a no-op for them. Used to populate
225    /// `training_feature_ranges` so prediction can clip out-of-hull inputs
226    /// to the training bounding box.
227    pub fn feature_ranges(&self) -> Vec<(f64, f64)> {
228        // Iterate column-by-column (contiguous in C-order Array2 along axis 0
229        // only when the array is Fortran-order; here Array2 is row-major so
230        // each column is strided. However, scanning one column at a time keeps
231        // each column's working set hot, lets rayon parallelize across
232        // columns, and avoids the previous outer-col/inner-row pattern that
233        // re-streamed all rows per column with stride `p`.
234        self.values
235            .axis_iter(Axis(1))
236            .into_par_iter()
237            .map(|col| {
238                let (lo, hi) =
239                    col.iter()
240                        .fold((f64::INFINITY, f64::NEG_INFINITY), |(lo, hi), &v| {
241                            if v.is_finite() {
242                                (lo.min(v), hi.max(v))
243                            } else {
244                                (lo, hi)
245                            }
246                        });
247                if !lo.is_finite() || !hi.is_finite() {
248                    (0.0, 0.0)
249                } else {
250                    (lo, hi)
251                }
252            })
253            .collect()
254    }
255}
256
/// Number of leading `char`s that `a` and `b` have in common.
fn shared_prefix(a: &str, b: &str) -> usize {
    let mut common = 0;
    for (left, right) in a.chars().zip(b.chars()) {
        if left != right {
            break;
        }
        common += 1;
    }
    common
}
263
264// ---------------------------------------------------------------------------
265// Format detection
266// ---------------------------------------------------------------------------
267
268#[derive(Clone, Copy, Debug, Eq, PartialEq)]
269enum DataFormat {
270    Csv,
271    Tsv,
272    Parquet,
273}
274
275fn detect_format(path: &Path) -> Result<DataFormat, DataError> {
276    let ext = path
277        .extension()
278        .and_then(|s| s.to_str())
279        .unwrap_or_default()
280        .to_ascii_lowercase();
281    match ext.as_str() {
282        "csv" => Ok(DataFormat::Csv),
283        "tsv" | "txt" | "tab" => Ok(DataFormat::Tsv),
284        "parquet" | "pq" | "pqt" => Ok(DataFormat::Parquet),
285        other => Err(DataError::ParseError {
286            reason: format!(
287                "unsupported data file extension '.{other}'; expected csv, tsv, txt, parquet, or pq: '{}'",
288                path.display()
289            ),
290        }),
291    }
292}
293
294// ---------------------------------------------------------------------------
295// Unified public API  — format auto-detected, zero extra CLI args
296// ---------------------------------------------------------------------------
297
298pub fn load_dataset_projected(
299    path: &Path,
300    requested_columns: &[String],
301) -> Result<EncodedDataset, DataError> {
302    load_dataset_projected_with_categorical_roles(path, requested_columns, &HashSet::new())
303}
304
305/// Schema-inferring projected loader that forces a set of columns to
306/// [`ColumnKindTag::Categorical`] regardless of whether their labels parse as
307/// numbers.
308///
309/// An untyped CSV/TSV/parquet-numeric frame cannot carry the dtype the typed
310/// Python frame stamps via [`CATEGORICAL_CELL_SENTINEL`], so the value-based
311/// inferer would otherwise demote an integer/numeric-coded factor (e.g. a
312/// `group(region)` grouping coded `0,1,2,3`) to `Continuous` and fit a single
313/// numeric ramp instead of one centred factor level per code. That makes the
314/// CLI's design strictly lower-capacity than the Python `gamfit.fit` design for
315/// the same data, which generalizes worse on every seed.
316///
317/// `categorical_roles` is keyed on the *formula role*, not on a value
318/// heuristic: a column is forced categorical only when the formula uses it in a
319/// role that is a factor by construction (`group(g)` / `factor(g)` / `re(g)`
320/// random-effect terms, or a categorical/multinomial response). A bare `+ x`
321/// linear term and a smooth argument `s(x)` are deliberately NOT included — they
322/// stay value-inferred, so a genuinely continuous integer covariate like
323/// `s(age)` or `+ age` remains `Continuous`. This mirrors the Python sentinel
324/// outcome (`force_categorical`, the column-major inferer) while keying it on
325/// the role the user actually declared.
326pub fn load_dataset_projected_with_categorical_roles(
327    path: &Path,
328    requested_columns: &[String],
329    categorical_roles: &HashSet<&str>,
330) -> Result<EncodedDataset, DataError> {
331    match detect_format(path)? {
332        DataFormat::Csv => {
333            load_delimited_inferred(path, b',', requested_columns, categorical_roles)
334        }
335        DataFormat::Tsv => {
336            load_delimited_inferred(path, b'\t', requested_columns, categorical_roles)
337        }
338        DataFormat::Parquet => load_parquet_inferred(path, requested_columns, categorical_roles),
339    }
340}
341
342pub fn load_datasetwith_schema_projected(
343    path: &Path,
344    schema: &DataSchema,
345    unseen_policy: UnseenCategoryPolicy,
346    requested_columns: &[String],
347) -> Result<EncodedDataset, DataError> {
348    match detect_format(path)? {
349        DataFormat::Csv => {
350            load_delimited_with_schema(path, b',', schema, unseen_policy, requested_columns)
351        }
352        DataFormat::Tsv => {
353            load_delimited_with_schema(path, b'\t', schema, unseen_policy, requested_columns)
354        }
355        DataFormat::Parquet => {
356            load_parquet_with_schema(path, schema, unseen_policy, requested_columns)
357        }
358    }
359}
360
361// ---------------------------------------------------------------------------
362// CSV convenience loader — infers the schema from the file header.
363// ---------------------------------------------------------------------------
364
365pub fn load_csvwith_inferred_schema(path: &Path) -> Result<EncodedDataset, DataError> {
366    load_delimited_inferred(path, b',', &[], &HashSet::new())
367}
368
369// ---------------------------------------------------------------------------
370// Delimited (CSV / TSV) — streaming, columnar, single-pass
371// ---------------------------------------------------------------------------
372
/// Size of the leading "sample window" used by the delimited inferer
/// (`sample_count = total_rows.min(SCHEMA_SAMPLE_ROWS)`).
///
/// NOTE(review): despite the name, the delimited inference path
/// (`infer_delimited_column`) updates its numeric/binary kind state for every
/// row, not only the first `SCHEMA_SAMPLE_ROWS`, so this value does not cap how
/// many rows decide a column's kind there. Confirm how any other loader uses
/// it before treating it as a row limit.
const SCHEMA_SAMPLE_ROWS: usize = 1024;
375
/// Marker character a typed Python frame prepends to each cell that comes from
/// a genuinely categorical source column (string / object / categorical
/// dtype).
///
/// Both the column-major inference (`infer_and_encode_column_major`) and the
/// schema-guided predict ingest (`gam-pyffi::string_records_from_rows`) remove
/// the marker before recording or matching a level. Its presence forces the
/// column to `Categorical` even if every label parses as a number. A string
/// column labelled "0","1","2" therefore gets one centred factor level per
/// label instead of a numeric ramp (#1317 / #1318). Numeric literals never
/// start with NUL, so untyped CSV/array frames (which carry no marker) are
/// unaffected.
pub const CATEGORICAL_CELL_SENTINEL: char = '\u{0}';

/// Remove one leading [`CATEGORICAL_CELL_SENTINEL`] from `cell`.
///
/// Returns the remaining text paired with `true` when the marker was present,
/// or the untouched input paired with `false` otherwise.
pub fn strip_categorical_sentinel(cell: &str) -> (&str, bool) {
    cell.strip_prefix(CATEGORICAL_CELL_SENTINEL)
        .map_or((cell, false), |rest| (rest, true))
}
395
396fn resolve_requested_columns(
397    all_headers: &[String],
398    requested_columns: &[String],
399) -> Result<Vec<usize>, DataError> {
400    if requested_columns.is_empty() {
401        return Ok((0..all_headers.len()).collect());
402    }
403
404    let requested_set: HashSet<&str> = requested_columns.iter().map(String::as_str).collect();
405    let mut selected = Vec::with_capacity(requested_set.len());
406    for (idx, name) in all_headers.iter().enumerate() {
407        if requested_set.contains(name.as_str()) {
408            selected.push(idx);
409        }
410    }
411
412    if selected.len() != requested_set.len() {
413        let available_map: HashMap<String, usize> = all_headers
414            .iter()
415            .enumerate()
416            .map(|(index, header)| (header.clone(), index))
417            .collect();
418        let missing = requested_columns
419            .iter()
420            .filter(|name| !available_map.contains_key(name.as_str()))
421            .map(|name| {
422                DataError::column_not_found(&available_map, name, Some("requested")).to_string()
423            })
424            .collect::<Vec<_>>();
425        return Err(DataError::SchemaMismatch {
426            reason: missing.join("; "),
427        });
428    }
429
430    Ok(selected)
431}
432
/// Copy out the header names at `selected_indices`, in the given order.
///
/// Panics if an index is out of range for `all_headers`.
fn projected_headers(all_headers: &[String], selected_indices: &[usize]) -> Vec<String> {
    let mut picked = Vec::with_capacity(selected_indices.len());
    for &position in selected_indices {
        picked.push(all_headers[position].clone());
    }
    picked
}
439
440fn load_delimited_inferred(
441    path: &Path,
442    delimiter: u8,
443    requested_columns: &[String],
444    categorical_roles: &HashSet<&str>,
445) -> Result<EncodedDataset, DataError> {
446    let t_open = std::time::Instant::now();
447    let mut rdr = ReaderBuilder::new()
448        .has_headers(true)
449        .delimiter(delimiter)
450        .from_path(path)
451        .map_err(|e| DataError::ParseError {
452            reason: format!("failed to open '{}': {e}", path.display()),
453        })?;
454
455    let all_headers: Vec<String> = rdr
456        .headers()
457        .map_err(|e| DataError::ParseError {
458            reason: format!("failed to read headers: {e}"),
459        })?
460        .iter()
461        .map(|s| s.trim().to_string())
462        .collect();
463    if all_headers.is_empty() {
464        return Err(DataError::EmptyInput {
465            reason: "file has no headers".to_string(),
466        });
467    }
468    let selected_indices = resolve_requested_columns(&all_headers, requested_columns)?;
469    let headers = projected_headers(&all_headers, &selected_indices);
470    let p = headers.len();
471    let open_ms = t_open.elapsed().as_secs_f64() * 1000.0;
472    if open_ms > 100.0 {
473        log::info!(
474            "[DATA-LOAD] delim_open+headers | n_headers={} | n_proj={} | {:.1}ms",
475            all_headers.len(),
476            p,
477            open_ms
478        );
479    }
480
481    // Phase 1: stream CSV structure exactly as before, but keep projected,
482    // trimmed fields in row-major order. Field validation, type conversion,
483    // inference, and row-to-column transposition happen after the streaming
484    // read so those CPU-heavy passes can run in parallel across independent
485    // columns. If a later row has malformed CSV width, defer returning that
486    // error until previously streamed rows have been validated to preserve the
487    // serial row-major error precedence.
488    let mut raw_fields = Vec::<String>::new();
489    let mut total_rows: usize = 0;
490    let mut stream_error: Option<DataError> = None;
491
492    let t_stream = std::time::Instant::now();
493    let mut record = StringRecord::new();
494    while rdr
495        .read_record(&mut record)
496        .map_err(|e| DataError::ParseError {
497            reason: format!("failed reading row: {e}"),
498        })?
499    {
500        if record.len() != all_headers.len() {
501            stream_error = Some(DataError::SchemaMismatch {
502                reason: format!(
503                    "row width mismatch at row {}: got {} fields, expected {}",
504                    total_rows + 1,
505                    record.len(),
506                    all_headers.len()
507                ),
508            });
509            break;
510        }
511        total_rows += 1;
512
513        for &selected_idx in &selected_indices {
514            let raw = record.get(selected_idx).unwrap().trim();
515            raw_fields.push(raw.to_string());
516        }
517    }
518
519    let stream_ms = t_stream.elapsed().as_secs_f64() * 1000.0;
520    if stream_ms > 100.0 {
521        log::info!(
522            "[DATA-LOAD] delim_stream | n_rows={} | n_cols={} | {:.1}ms",
523            total_rows,
524            p,
525            stream_ms
526        );
527    }
528
529    if total_rows == 0 {
530        if let Some(err) = stream_error {
531            return Err(err);
532        }
533        return Err(DataError::EmptyInput {
534            reason: "file has no rows".to_string(),
535        });
536    }
537
538    let t_schema = std::time::Instant::now();
539    let sample_count = total_rows.min(SCHEMA_SAMPLE_ROWS);
540    let inferred_columns = (0..p)
541        .into_par_iter()
542        .map(|j| {
543            infer_delimited_column(
544                &raw_fields,
545                total_rows,
546                p,
547                j,
548                &headers[j],
549                sample_count,
550                categorical_roles.contains(headers[j].as_str()),
551            )
552        })
553        .collect::<Vec<_>>();
554
555    let first_error = inferred_columns
556        .iter()
557        .filter_map(|result| result.as_ref().err())
558        .min_by_key(|err| (err.row, err.col));
559    if let Some(err) = first_error {
560        return Err(err.error.clone());
561    }
562    if let Some(err) = stream_error {
563        return Err(err);
564    }
565
566    let inferred_columns = inferred_columns
567        .into_iter()
568        .map(Result::unwrap)
569        .collect::<Vec<_>>();
570
571    // Build schema from inference state.
572    let mut schema_cols = Vec::<SchemaColumn>::with_capacity(p);
573    let mut column_kinds = Vec::<ColumnKindTag>::with_capacity(p);
574    for (j, inferred) in inferred_columns.iter().enumerate() {
575        column_kinds.push(inferred.kind);
576        schema_cols.push(SchemaColumn {
577            name: headers[j].clone(),
578            kind: inferred.kind,
579            levels: if matches!(inferred.kind, ColumnKindTag::Categorical) {
580                inferred.levels.clone()
581            } else {
582                Vec::new()
583            },
584        });
585    }
586    let schema_ms = t_schema.elapsed().as_secs_f64() * 1000.0;
587    if schema_ms > 100.0 {
588        let n_cat = column_kinds
589            .iter()
590            .filter(|k| matches!(k, ColumnKindTag::Categorical))
591            .count();
592        log::info!(
593            "[DATA-LOAD] delim_convert+infer | n_cols={} | n_cat={} | {:.1}ms",
594            p,
595            n_cat,
596            schema_ms
597        );
598    }
599
600    let t_assemble = std::time::Instant::now();
601    // Assemble into Array2 from independent column vectors in parallel.
602    let mut values = Array2::<f64>::zeros((total_rows, p));
603    values
604        .axis_iter_mut(Axis(1))
605        .into_par_iter()
606        .zip(inferred_columns.par_iter())
607        .for_each(|(mut out_col, inferred)| {
608            for (dst, &src) in out_col.iter_mut().zip(inferred.values.iter()) {
609                *dst = src;
610            }
611        });
612    let assemble_ms = t_assemble.elapsed().as_secs_f64() * 1000.0;
613    if assemble_ms > 100.0 {
614        log::info!(
615            "[DATA-LOAD] delim_assemble_array2 | n_rows={} | n_cols={} | {:.1}ms",
616            total_rows,
617            p,
618            assemble_ms
619        );
620    }
621
622    let schema = DataSchema {
623        columns: schema_cols,
624    };
625    Ok(EncodedDataset {
626        headers,
627        values,
628        schema,
629        column_kinds,
630    })
631}
632
633struct InferredDelimitedColumn {
634    values: Vec<f64>,
635    kind: ColumnKindTag,
636    levels: Vec<String>,
637}
638
639#[derive(Debug)]
640struct DelimitedInferenceError {
641    row: usize,
642    col: usize,
643    error: DataError,
644}
645
646fn infer_delimited_column(
647    raw_fields: &[String],
648    total_rows: usize,
649    n_cols: usize,
650    col: usize,
651    header: &str,
652    sample_count: usize,
653    force_categorical: bool,
654) -> Result<InferredDelimitedColumn, DelimitedInferenceError> {
655    // Per-column inference state (mirrors infer_schema_column logic).
656    let mut values = Vec::<f64>::with_capacity(total_rows);
657    let mut all_numeric = true;
658    let mut all_binary = true;
659    let mut level_index = HashMap::<String, usize>::new();
660    let mut levels = Vec::<String>::new();
661
662    // Shared constructor for the "non-finite parsed value" rejection, which is
663    // raised identically from the sample-window, post-window, and final recode
664    // passes below. `col`/`header` are in scope for the whole function.
665    let non_finite_err = |row_idx: usize| DelimitedInferenceError {
666        row: row_idx + 1,
667        col,
668        error: DataError::InvalidValue {
669            reason: format!(
670                "non-finite value at row {}, column '{}'",
671                row_idx + 1,
672                header
673            ),
674        },
675    };
676
677    for row_idx in 0..total_rows {
678        let raw = raw_fields[row_idx * n_cols + col].as_str();
679        if raw.is_empty() {
680            return Err(DelimitedInferenceError {
681                row: row_idx + 1,
682                col,
683                error: DataError::EmptyInput {
684                    reason: format!("empty field at row {}, column '{}'", row_idx + 1, header),
685                },
686            });
687        }
688
689        // Schema inference on sample window.
690        if row_idx < sample_count {
691            if let Ok(v) = raw.parse::<f64>() {
692                if !v.is_finite() {
693                    return Err(non_finite_err(row_idx));
694                }
695                if (v - 0.0).abs() >= 1e-12 && (v - 1.0).abs() >= 1e-12 {
696                    all_binary = false;
697                }
698                values.push(v);
699            } else {
700                all_numeric = false;
701                all_binary = false;
702                level_index.entry(raw.to_string()).or_insert_with(|| {
703                    let idx = levels.len();
704                    levels.push(raw.to_string());
705                    idx
706                });
707                // Store a placeholder for sample-window strings; once the
708                // final column kind is known, categorical columns are fixed up
709                // with the same level codes as the previous serial path.
710                values.push(f64::NAN);
711            }
712        } else if let Ok(v) = raw.parse::<f64>() {
713            // After sample window: we still accumulate inference state for
714            // correctness (a column that looks binary in the first 1024 rows
715            // may contain 2.5 on row 1025).
716            if !v.is_finite() {
717                return Err(non_finite_err(row_idx));
718            }
719            if (v - 0.0).abs() >= 1e-12 && (v - 1.0).abs() >= 1e-12 {
720                all_binary = false;
721            }
722            values.push(v);
723        } else {
724            all_numeric = false;
725            all_binary = false;
726            let idx = *level_index.entry(raw.to_string()).or_insert_with(|| {
727                let new_idx = levels.len();
728                levels.push(raw.to_string());
729                new_idx
730            });
731            values.push(idx as f64);
732        }
733    }
734
735    // A column the formula uses in a factor-by-construction role (an explicit
736    // `group(g)` / `factor(g)` / `re(g)` random effect, or a
737    // categorical/multinomial response) is encoded as a factor even when every
738    // label parsed as a number — the role-based analogue of the typed-frame
739    // `CATEGORICAL_CELL_SENTINEL` path, so CLI and Python produce the same
740    // factor design for a numeric-coded grouping column. The categorical fixup
741    // pass below recodes the already-parsed numeric `values` into sorted level
742    // indices, identical to the genuinely-non-numeric case.
743    let kind = if force_categorical {
744        ColumnKindTag::Categorical
745    } else if all_numeric {
746        if all_binary {
747            ColumnKindTag::Binary
748        } else {
749            ColumnKindTag::Continuous
750        }
751    } else {
752        ColumnKindTag::Categorical
753    };
754
755    if matches!(kind, ColumnKindTag::Categorical) {
756        // A column is categorical only if at least one row failed numeric
757        // parsing. Two failure modes used to silently corrupt the encoded
758        // values for such columns:
759        //   1. Sample-window rows that parsed as numbers stored the raw f64
760        //      (e.g. 0.0) in `values` without adding the raw string to
761        //      `level_index`.
762        //   2. Post-window rows that parsed as numbers stored the raw f64
763        //      directly without consulting `level_index`.
764        // After the column is declared categorical, those rows must be
765        // recoded as level indices using their original raw strings, treating
766        // every distinct raw string as a categorical level (including the
767        // numeric ones). Without this pass, a column like
768        // "0, 0, ..., 0, foo" mixes raw doubles with level codes, breaking
769        // the categorical encoding invariant.
770        //
771        // First discover every distinct level, then sort the level set
772        // lexicographically so the encoding is canonical (matching R `factor()`
773        // / pandas `Categorical`) and independent of row order — the same
774        // contract the column-major Python path enforces (#1319). Recode in a
775        // second pass against the sorted level → index map.
776        for row_idx in 0..total_rows {
777            let raw = raw_fields[row_idx * n_cols + col].as_str();
778            level_index.entry(raw.to_string()).or_insert_with(|| {
779                let new_idx = levels.len();
780                levels.push(raw.to_string());
781                new_idx
782            });
783        }
784        levels.sort();
785        level_index.clear();
786        for (idx, level) in levels.iter().enumerate() {
787            level_index.insert(level.clone(), idx);
788        }
789        for row_idx in 0..total_rows {
790            let raw = raw_fields[row_idx * n_cols + col].as_str();
791            values[row_idx] = level_index[raw] as f64;
792        }
793    }
794
795    for (row_idx, &v) in values.iter().enumerate() {
796        if !v.is_finite() {
797            return Err(non_finite_err(row_idx));
798        }
799    }
800
801    Ok(InferredDelimitedColumn {
802        values,
803        kind,
804        levels,
805    })
806}
807
/// Load a delimited text file (field separator `delimiter`) into an
/// `EncodedDataset`, using `schema` for the columns it declares and inferring
/// a kind for every projected column it does not declare.
///
/// * `path`, `delimiter` — the file to read. Its first record is the header
///   row; header names are whitespace-trimmed.
/// * `schema` — projected columns named here are parsed strictly by
///   `parse_cell_with_schema`. A schema-categorical column's codes are the
///   positions of its values in `SchemaColumn::levels`.
/// * `unseen_policy` — consulted for schema-categorical values that are not in
///   the schema's levels.
/// * `requested_columns` — projection, resolved against the headers by
///   `resolve_requested_columns`. Only the selected columns are returned, in
///   the order given by `projected_headers`.
///
/// Columns not in `schema` are inferred from the (trimmed) data:
/// * every value numeric and within 1e-12 of 0 or 1 → `Binary`;
/// * every value numeric → `Continuous`;
/// * otherwise `Categorical`. The levels are every distinct raw string in the
///   column, numeric-looking ones included, sorted lexicographically.
///
/// The returned schema records the final kind (and, for inferred categorical
/// columns, the levels) of every projected column. Row numbers in error
/// messages are 1-based and exclude the header row. Each phase (open, stream,
/// finalize, assemble) is logged at info level only when it exceeds 100 ms.
///
/// # Errors
///
/// * `ParseError` — the file cannot be opened, or the header row or a data
///   row cannot be read.
/// * `EmptyInput` — no headers, no data rows, or a projected field that is
///   empty after trimming.
/// * `SchemaMismatch` — a row whose field count differs from the header count
///   (compared against *all* headers, not just projected ones), or a value that
///   violates its schema-declared kind.
/// * `InvalidValue` — a non-finite numeric value.
/// * Anything returned by `resolve_requested_columns` or
///   `parse_cell_with_schema`.
fn load_delimited_with_schema(
    path: &Path,
    delimiter: u8,
    schema: &DataSchema,
    unseen_policy: UnseenCategoryPolicy,
    requested_columns: &[String],
) -> Result<EncodedDataset, DataError> {
    let t_open = std::time::Instant::now();
    let mut rdr = ReaderBuilder::new()
        .has_headers(true)
        .delimiter(delimiter)
        .from_path(path)
        .map_err(|e| DataError::ParseError {
            reason: format!("failed to open '{}': {e}", path.display()),
        })?;

    let all_headers: Vec<String> = rdr
        .headers()
        .map_err(|e| DataError::ParseError {
            reason: format!("failed to read headers: {e}"),
        })?
        .iter()
        .map(|s| s.trim().to_string())
        .collect();
    if all_headers.is_empty() {
        return Err(DataError::EmptyInput {
            reason: "file has no headers".to_string(),
        });
    }
    // `selected_indices[j]` is the file-column index of projected column `j`.
    // Every per-column vector below is indexed by that projected `j` (0..p).
    let selected_indices = resolve_requested_columns(&all_headers, requested_columns)?;
    let headers = projected_headers(&all_headers, &selected_indices);
    let p = headers.len();
    let open_ms = t_open.elapsed().as_secs_f64() * 1000.0;
    if open_ms > 100.0 {
        log::info!(
            "[DATA-LOAD] delim_schema_open+headers | n_headers={} | n_proj={} | {:.1}ms",
            all_headers.len(),
            p,
            open_ms
        );
    }

    // Build per-column metadata from schema.
    let schema_byname: HashMap<&str, &SchemaColumn> = schema
        .columns
        .iter()
        .map(|c| (c.name.as_str(), c))
        .collect();

    let mut col_meta = Vec::<ColMeta>::with_capacity(p);
    for name in &headers {
        if let Some(sc) = schema_byname.get(name.as_str()) {
            // For schema-categorical columns, map each level string to its
            // position in `sc.levels`. The code is stored as f64, the element
            // type of the column vectors it is pushed into.
            let level_map = if matches!(sc.kind, ColumnKindTag::Categorical) {
                Some(
                    sc.levels
                        .iter()
                        .enumerate()
                        .map(|(idx, v)| (v.clone(), idx as f64))
                        .collect::<HashMap<_, _>>(),
                )
            } else {
                None
            };
            col_meta.push(ColMeta {
                kind: sc.kind,
                level_map,
                schema_col: (*sc).clone(),
            });
        } else {
            // Column not in schema — will be inferred below (fallback).
            // Both `kind` fields are placeholders; the finalize pass
            // overwrites them with the inferred kind.
            col_meta.push(ColMeta {
                kind: ColumnKindTag::Continuous, // tentative
                level_map: None,
                schema_col: SchemaColumn {
                    name: name.clone(),
                    kind: ColumnKindTag::Continuous,
                    levels: Vec::new(),
                },
            });
        }
    }

    // Track which columns need inference (not in provided schema).
    let needs_inference: Vec<bool> = headers
        .iter()
        .map(|h| !schema_byname.contains_key(h.as_str()))
        .collect();

    // Stream rows into column vecs.
    let mut col_vecs: Vec<Vec<f64>> = vec![Vec::new(); p];
    // For columns needing inference, track strings for categorical fixup.
    // `infer_all_numeric[j]` / `infer_all_binary[j]` stay true only while
    // every value seen so far in column `j` parses as a number / is 0 or 1.
    let mut infer_all_numeric: Vec<bool> = vec![true; p];
    let mut infer_all_binary: Vec<bool> = vec![true; p];
    // Level string -> provisional index, plus levels in first-seen order.
    // Both are rebuilt in sorted order if the column ends up categorical.
    let mut infer_level_index: Vec<HashMap<String, usize>> = vec![HashMap::new(); p];
    let mut infer_levels: Vec<Vec<String>> = vec![Vec::new(); p];
    let mut infer_strings: Vec<Vec<(usize, String)>> = vec![Vec::new(); p]; // (0-based row_idx, raw)

    let mut total_rows: usize = 0;
    let t_stream = std::time::Instant::now();
    let mut record = StringRecord::new();
    while rdr
        .read_record(&mut record)
        .map_err(|e| DataError::ParseError {
            reason: format!("failed reading row: {e}"),
        })?
    {
        // `total_rows` has not been incremented for this record yet, so the
        // 1-based row number reported here is `total_rows + 1`.
        if record.len() != all_headers.len() {
            return Err(DataError::SchemaMismatch {
                reason: format!(
                    "row width mismatch at row {}: got {} fields, expected {}",
                    total_rows + 1,
                    record.len(),
                    all_headers.len()
                ),
            });
        }
        // From here on `total_rows` is the 1-based number of the current row.
        total_rows += 1;

        for j in 0..p {
            // The width check above makes this `unwrap` safe, assuming
            // `resolve_requested_columns` only returns indices below
            // `all_headers.len()` (not visible here — confirm).
            let raw = record.get(selected_indices[j]).unwrap().trim();
            if raw.is_empty() {
                return Err(DataError::EmptyInput {
                    reason: format!(
                        "empty field at row {}, column '{}'",
                        total_rows, &headers[j]
                    ),
                });
            }

            if needs_inference[j] {
                // Accumulate inference state.
                if let Ok(v) = raw.parse::<f64>() {
                    if !v.is_finite() {
                        return Err(DataError::InvalidValue {
                            reason: format!(
                                "non-finite value at row {}, column '{}'",
                                total_rows, &headers[j]
                            ),
                        });
                    }
                    if (v - 0.0).abs() >= 1e-12 && (v - 1.0).abs() >= 1e-12 {
                        infer_all_binary[j] = false;
                    }
                    col_vecs[j].push(v);
                    // Also remember the raw string in case this column ends up
                    // categorical (because a *later* row fails numeric parsing).
                    // Without this, numeric-parsing rows would keep their raw
                    // f64 values mixed with level codes — silently corrupting
                    // the encoding for columns like `0, 0, ..., 0, foo`. This
                    // mirrors the fix-up already performed in the schema-less
                    // `infer_delimited_column` path. If the column ends up
                    // continuous/binary, this Vec is simply dropped.
                    infer_strings[j].push((total_rows - 1, raw.to_string()));
                } else {
                    infer_all_numeric[j] = false;
                    infer_all_binary[j] = false;
                    let levels_ref = &mut infer_levels[j];
                    infer_level_index[j]
                        .entry(raw.to_string())
                        .or_insert_with(|| {
                            let idx = levels_ref.len();
                            levels_ref.push(raw.to_string());
                            idx
                        });
                    infer_strings[j].push((total_rows - 1, raw.to_string()));
                    col_vecs[j].push(f64::NAN); // placeholder; replaced by a level code in the finalize pass
                }
            } else {
                // Schema-driven parse.
                let val = parse_cell_with_schema(
                    raw,
                    &col_meta[j],
                    total_rows,
                    &headers[j],
                    &unseen_policy,
                )?;
                col_vecs[j].push(val);
            }
        }
    }

    let stream_ms = t_stream.elapsed().as_secs_f64() * 1000.0;
    if stream_ms > 100.0 {
        let n_inf = needs_inference.iter().filter(|x| **x).count();
        log::info!(
            "[DATA-LOAD] delim_schema_stream | n_rows={} | n_cols={} | n_inf={} | {:.1}ms",
            total_rows,
            p,
            n_inf,
            stream_ms
        );
    }

    if total_rows == 0 {
        return Err(DataError::EmptyInput {
            reason: "file has no rows".to_string(),
        });
    }

    let t_finalize = std::time::Instant::now();
    // Finalize inferred columns.
    let mut column_kinds = Vec::<ColumnKindTag>::with_capacity(p);
    for j in 0..p {
        if needs_inference[j] {
            // All numeric and all 0/1 -> Binary; all numeric -> Continuous;
            // any non-numeric value -> Categorical.
            let kind = if infer_all_numeric[j] {
                if infer_all_binary[j] {
                    ColumnKindTag::Binary
                } else {
                    ColumnKindTag::Continuous
                }
            } else {
                ColumnKindTag::Categorical
            };
            col_meta[j].kind = kind;
            col_meta[j].schema_col.kind = kind;
            if matches!(kind, ColumnKindTag::Categorical) {
                // Re-encode the entire column as categorical level codes.
                // `infer_strings[j]` contains every (row_idx, raw) seen during
                // streaming (both numeric- and non-numeric-parsing rows), so
                // numeric-looking strings like "0" become their own levels
                // instead of leaking through as raw f64 values that would
                // collide with real level codes.
                //
                // First discover the full level set, then sort it
                // lexicographically and recode against the canonical order, so
                // the encoding matches R `factor()` / pandas `Categorical` and
                // is independent of row order (#1319) — the same contract as the
                // schema-less and column-major inference paths.
                for (_, raw) in &infer_strings[j] {
                    let levels_ref = &mut infer_levels[j];
                    infer_level_index[j].entry(raw.clone()).or_insert_with(|| {
                        let new_idx = levels_ref.len();
                        levels_ref.push(raw.clone());
                        new_idx
                    });
                }
                infer_levels[j].sort();
                infer_level_index[j].clear();
                for (idx, level) in infer_levels[j].iter().enumerate() {
                    infer_level_index[j].insert(level.clone(), idx);
                }
                for (row_idx, raw) in &infer_strings[j] {
                    col_vecs[j][*row_idx] = infer_level_index[j][raw] as f64;
                }
                col_meta[j].schema_col.levels = infer_levels[j].clone();
            }
        }
        column_kinds.push(col_meta[j].kind);
    }
    let finalize_ms = t_finalize.elapsed().as_secs_f64() * 1000.0;
    if finalize_ms > 100.0 {
        log::info!(
            "[DATA-LOAD] delim_schema_finalize | n_cols={} | {:.1}ms",
            p,
            finalize_ms
        );
    }

    let t_assemble = std::time::Instant::now();
    // Assemble Array2 by column in parallel (mirrors the inferred path).
    // Each column carries its own finiteness check. Per-column errors are
    // combined with `reduce(|| None, |a, b| a.or(b))`, which keeps the left
    // operand. Provided rayon combines partial results in sequence order (its
    // `reduce` requires associativity, not commutativity), the reported cell
    // is the first non-finite one in (column, row) order.
    //
    // Every value reaching this point has already been finiteness-checked
    // (numeric inference and `parse_cell_with_schema`) or overwritten with a
    // level code (categorical fix-up). This check therefore appears to be
    // purely defensive.
    let mut values = Array2::<f64>::zeros((total_rows, p));
    let assemble_err: Option<DataError> = values
        .axis_iter_mut(Axis(1))
        .into_par_iter()
        .zip(col_vecs.par_iter())
        .zip(headers.par_iter())
        .map(|((mut out_col, col_vec), header)| {
            for (i, &v) in col_vec.iter().enumerate() {
                if !v.is_finite() {
                    return Some(DataError::InvalidValue {
                        reason: format!("non-finite value at row {}, column '{}'", i + 1, header),
                    });
                }
                out_col[i] = v;
            }
            None
        })
        .reduce(|| None, |a, b| a.or(b));
    if let Some(e) = assemble_err {
        return Err(e);
    }
    let assemble_ms = t_assemble.elapsed().as_secs_f64() * 1000.0;
    if assemble_ms > 100.0 {
        log::info!(
            "[DATA-LOAD] delim_schema_assemble | n_rows={} | n_cols={} | {:.1}ms",
            total_rows,
            p,
            assemble_ms
        );
    }

    let schema_out = DataSchema {
        columns: col_meta.into_iter().map(|m| m.schema_col).collect(),
    };
    Ok(EncodedDataset {
        headers,
        values,
        schema: schema_out,
        column_kinds,
    })
}
1113
1114fn parse_cell_with_schema(
1115    raw: &str,
1116    meta: &ColMeta,
1117    row: usize,
1118    col_name: &str,
1119    unseen_policy: &UnseenCategoryPolicy,
1120) -> Result<f64, DataError> {
1121    let val = match meta.kind {
1122        ColumnKindTag::Continuous => raw.parse::<f64>().map_err(|err| {
1123            DataError::SchemaMismatch {
1124                reason: format!(
1125                    "column '{}' is continuous in schema but row {} has non-numeric value '{}': {}",
1126                    col_name, row, raw, err
1127                ),
1128            }
1129        })?,
1130        ColumnKindTag::Binary => {
1131            let v = raw
1132                .parse::<f64>()
1133                .map_err(|err| DataError::SchemaMismatch {
1134                    reason: format!(
1135                        "column '{}' is binary in schema but row {} has non-numeric value '{}': {}",
1136                        col_name, row, raw, err
1137                    ),
1138                })?;
1139            if (v - 0.0).abs() >= 1e-12 && (v - 1.0).abs() >= 1e-12 {
1140                return Err(DataError::SchemaMismatch {
1141                    reason: format!(
1142                        "column '{}' is binary in schema but row {} has value {}; expected 0 or 1",
1143                        col_name, row, v
1144                    ),
1145                });
1146            }
1147            v
1148        }
1149        ColumnKindTag::Categorical => {
1150            let map = meta
1151                .level_map
1152                .as_ref()
1153                .ok_or_else(|| DataError::EncodingFailure {
1154                    reason: "internal categorical schema map missing".to_string(),
1155                })?;
1156            match map.get(raw) {
1157                Some(v) => *v,
1158                None => unseen_policy
1159                    .unseen_code_for(col_name, meta.schema_col.levels.len())
1160                    .ok_or_else(|| DataError::SchemaMismatch {
1161                        reason: format!(
1162                            "unseen level '{}' in categorical column '{}' at row {}",
1163                            raw, col_name, row
1164                        ),
1165                    })?,
1166            }
1167        }
1168    };
1169    if !val.is_finite() {
1170        return Err(DataError::InvalidValue {
1171            reason: format!("non-finite value at row {}, column '{}'", row, col_name),
1172        });
1173    }
1174    Ok(val)
1175}
1176
/// Per-column parsing metadata used by `load_delimited_with_schema`.
///
/// Declared at module level rather than inside that function so that
/// `parse_cell_with_schema` can name it in its signature.
struct ColMeta {
    /// Kind used to parse each cell. For columns absent from the caller's
    /// schema this starts as a tentative `Continuous` and is overwritten once
    /// inference has seen every row.
    kind: ColumnKindTag,
    /// Level string -> code (the level's position in the schema's `levels`,
    /// as f64). Built only for schema-declared categorical columns; `None`
    /// otherwise, including for columns later inferred as categorical.
    level_map: Option<HashMap<String, f64>>,
    /// Schema entry emitted for this column in the output `DataSchema`.
    schema_col: SchemaColumn,
}
1184
1185// ---------------------------------------------------------------------------
1186// Parquet — columnar, zero StringRecord, schema from metadata
1187// ---------------------------------------------------------------------------
1188
/// One column's worth of decoded values from a single parquet record batch.
enum ParquetBatchColumn {
    /// Numeric values widened to `f64` (booleans as 0.0 / 1.0), one per row.
    Numeric(Vec<f64>),
    /// Raw string values, one per row, for columns read on the string path.
    /// The caller converts these to categorical level codes afterwards.
    Strings(Vec<String>),
}
1193
1194/// True iff an Arrow column should be treated as a string/categorical column.
1195///
1196/// Dictionary encoding is a *storage* detail, not a semantic type: pyarrow
1197/// dictionary-encodes low-cardinality columns by default, including numeric
1198/// ones (integer factor levels, small enums stored as ints). A
1199/// `Dictionary(K, V)` column is categorical iff its *value* type `V` is a
1200/// string type; `Dictionary(_, Int*/UInt*/Float*/Bool)` is numeric. We recurse
1201/// through the value type so nested dictionaries resolve to their leaf type.
1202fn parquet_field_is_string(dt: &arrow::datatypes::DataType) -> bool {
1203    use arrow::datatypes::DataType;
1204    match dt {
1205        DataType::Utf8 | DataType::LargeUtf8 => true,
1206        DataType::Dictionary(_, value_type) => parquet_field_is_string(value_type),
1207        _ => false,
1208    }
1209}
1210
1211fn decode_parquet_batch_column(
1212    col: &dyn arrow::array::Array,
1213    n_rows: usize,
1214    base_row: usize,
1215    header: &str,
1216    is_string_col: bool,
1217) -> Result<ParquetBatchColumn, DataError> {
1218    use arrow::array::{
1219        Array as ArrowArray, BooleanArray, Float32Array, Float64Array, Int8Array, Int16Array,
1220        Int32Array, Int64Array, LargeStringArray, StringArray, UInt8Array, UInt16Array,
1221        UInt32Array, UInt64Array,
1222    };
1223    use arrow::datatypes::DataType;
1224
1225    if col.null_count() > 0 {
1226        for i in 0..n_rows {
1227            if col.is_null(i) {
1228                return Err(DataError::InvalidValue {
1229                    reason: format!(
1230                        "null value at row {}, column '{}'",
1231                        base_row + i + 1,
1232                        header
1233                    ),
1234                });
1235            }
1236        }
1237    }
1238
1239    if is_string_col {
1240        if let Some(arr) = col.as_any().downcast_ref::<StringArray>() {
1241            return Ok(ParquetBatchColumn::Strings(
1242                (0..n_rows).map(|i| arr.value(i).to_string()).collect(),
1243            ));
1244        }
1245        if let Some(arr) = col.as_any().downcast_ref::<LargeStringArray>() {
1246            return Ok(ParquetBatchColumn::Strings(
1247                (0..n_rows).map(|i| arr.value(i).to_string()).collect(),
1248            ));
1249        }
1250
1251        // Dictionary-encoded strings are not directly a StringArray. Cast only
1252        // those remaining string-like arrays rather than falling back for every
1253        // Utf8/LargeUtf8 column.
1254        let casted =
1255            arrow::compute::cast(col, &DataType::Utf8).map_err(|e| DataError::ParseError {
1256                reason: format!("failed to cast column '{}' to string: {e}", header),
1257            })?;
1258        let arr = casted
1259            .as_any()
1260            .downcast_ref::<StringArray>()
1261            .ok_or_else(|| DataError::EncodingFailure {
1262                reason: format!("column '{}' could not be read as string after cast", header),
1263            })?;
1264        return Ok(ParquetBatchColumn::Strings(
1265            (0..n_rows).map(|i| arr.value(i).to_string()).collect(),
1266        ));
1267    }
1268
1269    // Numeric-valued dictionary columns (pyarrow dictionary-encodes
1270    // low-cardinality numeric columns by default) are not directly a
1271    // primitive array. Decode them to their concrete value type so the normal
1272    // numeric arms below apply. `parquet_field_is_string` has already routed
1273    // string-valued dictionaries through the categorical branch above, so any
1274    // dictionary reaching here has a numeric value type.
1275    let decoded_col;
1276    let col: &dyn arrow::array::Array = if let DataType::Dictionary(_, value_type) = col.data_type()
1277    {
1278        decoded_col = arrow::compute::cast(col, value_type).map_err(|e| DataError::ParseError {
1279            reason: format!(
1280                "failed to decode dictionary-encoded numeric column '{}': {e}",
1281                header
1282            ),
1283        })?;
1284        decoded_col.as_ref()
1285    } else {
1286        col
1287    };
1288
1289    let mut values = Vec::with_capacity(n_rows);
1290    match col.data_type() {
1291        DataType::Float64 => {
1292            let arr = col.as_any().downcast_ref::<Float64Array>().unwrap();
1293            values.extend(arr.values().iter().copied());
1294        }
1295        DataType::Float32 => {
1296            let arr = col.as_any().downcast_ref::<Float32Array>().unwrap();
1297            values.extend(arr.values().iter().map(|&v| v as f64));
1298        }
1299        DataType::Int64 => {
1300            let arr = col.as_any().downcast_ref::<Int64Array>().unwrap();
1301            values.extend(arr.values().iter().map(|&v| v as f64));
1302        }
1303        DataType::Int32 => {
1304            let arr = col.as_any().downcast_ref::<Int32Array>().unwrap();
1305            values.extend(arr.values().iter().map(|&v| v as f64));
1306        }
1307        DataType::Int16 => {
1308            let arr = col.as_any().downcast_ref::<Int16Array>().unwrap();
1309            values.extend(arr.values().iter().map(|&v| v as f64));
1310        }
1311        DataType::Int8 => {
1312            let arr = col.as_any().downcast_ref::<Int8Array>().unwrap();
1313            values.extend(arr.values().iter().map(|&v| v as f64));
1314        }
1315        DataType::UInt64 => {
1316            let arr = col.as_any().downcast_ref::<UInt64Array>().unwrap();
1317            values.extend(arr.values().iter().map(|&v| v as f64));
1318        }
1319        DataType::UInt32 => {
1320            let arr = col.as_any().downcast_ref::<UInt32Array>().unwrap();
1321            values.extend(arr.values().iter().map(|&v| v as f64));
1322        }
1323        DataType::UInt16 => {
1324            let arr = col.as_any().downcast_ref::<UInt16Array>().unwrap();
1325            values.extend(arr.values().iter().map(|&v| v as f64));
1326        }
1327        DataType::UInt8 => {
1328            let arr = col.as_any().downcast_ref::<UInt8Array>().unwrap();
1329            values.extend(arr.values().iter().map(|&v| v as f64));
1330        }
1331        DataType::Boolean => {
1332            let arr = col.as_any().downcast_ref::<BooleanArray>().unwrap();
1333            values.extend((0..n_rows).map(|i| if arr.value(i) { 1.0 } else { 0.0 }));
1334        }
1335        other => {
1336            return Err(DataError::InvalidValue {
1337                reason: format!(
1338                    "unsupported parquet column type {:?} for column '{}'",
1339                    other, header
1340                ),
1341            });
1342        }
1343    }
1344
1345    if let Some(i) = values.iter().position(|v| !v.is_finite()) {
1346        return Err(DataError::InvalidValue {
1347            reason: format!(
1348                "non-finite value at row {}, column '{}'",
1349                base_row + i + 1,
1350                header
1351            ),
1352        });
1353    }
1354
1355    Ok(ParquetBatchColumn::Numeric(values))
1356}
1357
1358fn load_parquet_inferred(
1359    path: &Path,
1360    requested_columns: &[String],
1361    categorical_roles: &HashSet<&str>,
1362) -> Result<EncodedDataset, DataError> {
1363    use parquet::arrow::{ProjectionMask, arrow_reader::ParquetRecordBatchReaderBuilder};
1364    use rayon::prelude::*;
1365    use std::fs::File;
1366
1367    let t_open = std::time::Instant::now();
1368    let file = File::open(path).map_err(|e| DataError::ParseError {
1369        reason: format!("failed to open parquet '{}': {e}", path.display()),
1370    })?;
1371    let builder =
1372        ParquetRecordBatchReaderBuilder::try_new(file).map_err(|e| DataError::ParseError {
1373            reason: format!("failed to read parquet metadata '{}': {e}", path.display()),
1374        })?;
1375
1376    let full_schema = builder.schema().clone();
1377    let all_headers: Vec<String> = full_schema
1378        .fields()
1379        .iter()
1380        .map(|f| f.name().clone())
1381        .collect();
1382    if all_headers.is_empty() {
1383        return Err(DataError::EmptyInput {
1384            reason: "parquet file has no columns".to_string(),
1385        });
1386    }
1387    let selected_indices = resolve_requested_columns(&all_headers, requested_columns)?;
1388    let headers = projected_headers(&all_headers, &selected_indices);
1389    let selected_fields = selected_indices
1390        .iter()
1391        .map(|&idx| full_schema.fields()[idx].clone())
1392        .collect::<Vec<_>>();
1393    let projection =
1394        ProjectionMask::roots(builder.parquet_schema(), selected_indices.iter().copied());
1395    let reader =
1396        builder
1397            .with_projection(projection)
1398            .build()
1399            .map_err(|e| DataError::ParseError {
1400                reason: format!("failed to build parquet reader: {e}"),
1401            })?;
1402    let p = headers.len();
1403    let open_ms = t_open.elapsed().as_secs_f64() * 1000.0;
1404    if open_ms > 100.0 {
1405        log::info!(
1406            "[DATA-LOAD] parquet_open+meta | n_headers={} | n_proj={} | {:.1}ms",
1407            all_headers.len(),
1408            p,
1409            open_ms
1410        );
1411    }
1412
1413    let t_batches = std::time::Instant::now();
1414    // Collect all batches.
1415    let mut col_vecs: Vec<Vec<f64>> = vec![Vec::new(); p];
1416    // For string columns: accumulate raw strings to build level maps.
1417    let mut string_cols: Vec<Option<Vec<String>>> = (0..p).map(|_| None).collect();
1418    let mut is_string_col: Vec<bool> = vec![false; p];
1419
1420    for (j, field) in selected_fields.iter().enumerate() {
1421        // A dictionary-encoded column is categorical only when its *value* type
1422        // is a string type; a numeric-valued dictionary (e.g. pyarrow's default
1423        // encoding of low-cardinality integer columns) must stay numeric.
1424        if parquet_field_is_string(field.data_type()) {
1425            is_string_col[j] = true;
1426            string_cols[j] = Some(Vec::new());
1427        }
1428    }
1429
1430    let mut rows_seen = 0usize;
1431    for batch_result in reader {
1432        let batch = batch_result.map_err(|e| DataError::ParseError {
1433            reason: format!("failed to read parquet record batch: {e}"),
1434        })?;
1435        let n_rows = batch.num_rows();
1436
1437        let decoded_columns: Vec<Result<ParquetBatchColumn, DataError>> = (0..p)
1438            .into_par_iter()
1439            .map(|j| {
1440                decode_parquet_batch_column(
1441                    batch.column(j).as_ref(),
1442                    n_rows,
1443                    rows_seen,
1444                    &headers[j],
1445                    is_string_col[j],
1446                )
1447            })
1448            .collect();
1449
1450        for (j, decoded) in decoded_columns.into_iter().enumerate() {
1451            match decoded? {
1452                ParquetBatchColumn::Strings(mut strings) => {
1453                    assert!(is_string_col[j]);
1454                    string_cols[j].as_mut().unwrap().append(&mut strings);
1455                    let new_len = col_vecs[j].len() + n_rows;
1456                    col_vecs[j].resize(new_len, f64::NAN);
1457                }
1458                ParquetBatchColumn::Numeric(mut values) => {
1459                    assert!(!is_string_col[j]);
1460                    col_vecs[j].append(&mut values);
1461                }
1462            }
1463        }
1464        rows_seen += n_rows;
1465    }
1466
1467    let total_rows = col_vecs[0].len();
1468    let batches_ms = t_batches.elapsed().as_secs_f64() * 1000.0;
1469    if batches_ms > 100.0 {
1470        log::info!(
1471            "[DATA-LOAD] parquet_batches_decode | n_rows={} | n_cols={} | {:.1}ms",
1472            total_rows,
1473            p,
1474            batches_ms
1475        );
1476    }
1477    if total_rows == 0 {
1478        return Err(DataError::EmptyInput {
1479            reason: "parquet file has no rows".to_string(),
1480        });
1481    }
1482
1483    let t_schema = std::time::Instant::now();
1484    // Build schema: infer kind from data.
1485    let mut schema_cols = Vec::<SchemaColumn>::with_capacity(p);
1486    let mut column_kinds = Vec::<ColumnKindTag>::with_capacity(p);
1487
1488    let finalized_columns: Vec<(Vec<f64>, ColumnKindTag, SchemaColumn)> = col_vecs
1489        .into_par_iter()
1490        .zip(string_cols.into_par_iter())
1491        .zip(is_string_col.into_par_iter())
1492        .zip(headers.par_iter())
1493        .map(|(((mut col_values, strings), is_string), header)| {
1494            if is_string {
1495                // Categorical. Preserve level order by scanning each column in
1496                // row order; columns are independent and can be finalized in
1497                // parallel without changing schema order after collection.
1498                let strings = strings.expect("string column storage missing");
1499                let mut level_index: HashMap<String, usize> = HashMap::new();
1500                let mut levels_vec: Vec<String> = Vec::new();
1501                for s in &strings {
1502                    level_index.entry(s.clone()).or_insert_with(|| {
1503                        let idx = levels_vec.len();
1504                        levels_vec.push(s.clone());
1505                        idx
1506                    });
1507                }
1508                for (i, s) in strings.iter().enumerate() {
1509                    col_values[i] = *level_index.get(s.as_str()).unwrap() as f64;
1510                }
1511                (
1512                    col_values,
1513                    ColumnKindTag::Categorical,
1514                    SchemaColumn {
1515                        name: header.clone(),
1516                        kind: ColumnKindTag::Categorical,
1517                        levels: levels_vec,
1518                    },
1519                )
1520            } else if categorical_roles.contains(header.as_str()) {
1521                // Numeric column the formula uses in a factor-by-construction
1522                // role (group()/factor()/re() or a categorical response):
1523                // recode the numeric labels into sorted factor levels so a
1524                // numeric-coded grouping column becomes one centred level per
1525                // code, matching the typed-frame sentinel outcome and the
1526                // delimited loader's `force_categorical` path. Labels are
1527                // formatted with the same `{}` Display the level map keys on.
1528                let labels: Vec<String> = col_values.iter().map(|v| v.to_string()).collect();
1529                let mut levels_vec: Vec<String> = Vec::new();
1530                let mut level_index: HashMap<String, usize> = HashMap::new();
1531                for label in &labels {
1532                    level_index.entry(label.clone()).or_insert_with(|| {
1533                        let idx = levels_vec.len();
1534                        levels_vec.push(label.clone());
1535                        idx
1536                    });
1537                }
1538                levels_vec.sort();
1539                level_index.clear();
1540                for (idx, level) in levels_vec.iter().enumerate() {
1541                    level_index.insert(level.clone(), idx);
1542                }
1543                for (i, label) in labels.iter().enumerate() {
1544                    col_values[i] = level_index[label] as f64;
1545                }
1546                (
1547                    col_values,
1548                    ColumnKindTag::Categorical,
1549                    SchemaColumn {
1550                        name: header.clone(),
1551                        kind: ColumnKindTag::Categorical,
1552                        levels: levels_vec,
1553                    },
1554                )
1555            } else {
1556                // Numeric: check if binary.
1557                let all_binary = col_values
1558                    .iter()
1559                    .all(|&v| (v - 0.0).abs() < 1e-12 || (v - 1.0).abs() < 1e-12);
1560                let kind = if all_binary {
1561                    ColumnKindTag::Binary
1562                } else {
1563                    ColumnKindTag::Continuous
1564                };
1565                (
1566                    col_values,
1567                    kind,
1568                    SchemaColumn {
1569                        name: header.clone(),
1570                        kind,
1571                        levels: Vec::new(),
1572                    },
1573                )
1574            }
1575        })
1576        .collect();
1577
1578    let mut col_vecs = Vec::with_capacity(p);
1579    for (col_values, kind, schema_col) in finalized_columns {
1580        col_vecs.push(col_values);
1581        column_kinds.push(kind);
1582        schema_cols.push(schema_col);
1583    }
1584    let schema_ms = t_schema.elapsed().as_secs_f64() * 1000.0;
1585    if schema_ms > 100.0 {
1586        let n_cat = column_kinds
1587            .iter()
1588            .filter(|k| matches!(k, ColumnKindTag::Categorical))
1589            .count();
1590        log::info!(
1591            "[DATA-LOAD] parquet_finalize_schema | n_cols={} | n_cat={} | {:.1}ms",
1592            p,
1593            n_cat,
1594            schema_ms
1595        );
1596    }
1597
1598    let t_assemble = std::time::Instant::now();
1599    // Assemble Array2. Columns are independent; write by column in parallel
1600    // so each task touches a contiguous source vec (and the strided
1601    // destination column once) rather than scattering across all p columns
1602    // per row.
1603    let mut values = Array2::<f64>::zeros((total_rows, p));
1604    values
1605        .axis_iter_mut(Axis(1))
1606        .into_par_iter()
1607        .zip(col_vecs.par_iter())
1608        .for_each(|(mut out_col, src)| {
1609            for (dst, &v) in out_col.iter_mut().zip(src.iter()) {
1610                *dst = v;
1611            }
1612        });
1613    let assemble_ms = t_assemble.elapsed().as_secs_f64() * 1000.0;
1614    if assemble_ms > 100.0 {
1615        log::info!(
1616            "[DATA-LOAD] parquet_assemble_array2 | n_rows={} | n_cols={} | {:.1}ms",
1617            total_rows,
1618            p,
1619            assemble_ms
1620        );
1621    }
1622
1623    Ok(EncodedDataset {
1624        headers,
1625        values,
1626        schema: DataSchema {
1627            columns: schema_cols,
1628        },
1629        column_kinds,
1630    })
1631}
1632
1633fn load_parquet_with_schema(
1634    path: &Path,
1635    schema: &DataSchema,
1636    unseen_policy: UnseenCategoryPolicy,
1637    requested_columns: &[String],
1638) -> Result<EncodedDataset, DataError> {
1639    // Load with inference first, then validate/re-encode against provided schema.
1640    // No formula roles are threaded here: the saved schema already records each
1641    // column's categorical kind, and the re-encode pass below pins kinds to it.
1642    let inferred = load_parquet_inferred(path, requested_columns, &HashSet::new())?;
1643    let p = inferred.headers.len();
1644    let n = inferred.values.nrows();
1645
1646    let schema_byname: HashMap<&str, &SchemaColumn> = schema
1647        .columns
1648        .iter()
1649        .map(|c| (c.name.as_str(), c))
1650        .collect();
1651
1652    let mut column_kinds = Vec::<ColumnKindTag>::with_capacity(p);
1653    let mut schema_cols = Vec::<SchemaColumn>::with_capacity(p);
1654    let mut values = inferred.values;
1655
1656    for j in 0..p {
1657        let name = &inferred.headers[j];
1658        if let Some(sc) = schema_byname.get(name.as_str()) {
1659            column_kinds.push(sc.kind);
1660            schema_cols.push((*sc).clone());
1661
1662            match sc.kind {
1663                ColumnKindTag::Continuous => {
1664                    if matches!(inferred.column_kinds[j], ColumnKindTag::Categorical) {
1665                        return Err(DataError::SchemaMismatch {
1666                            reason: format!(
1667                                "column '{}' is continuous in schema but parquet column is string/categorical",
1668                                name
1669                            ),
1670                        });
1671                    }
1672                }
1673                ColumnKindTag::Binary => {
1674                    if matches!(inferred.column_kinds[j], ColumnKindTag::Categorical) {
1675                        return Err(DataError::SchemaMismatch {
1676                            reason: format!(
1677                                "column '{}' is binary in schema but parquet column is string/categorical",
1678                                name
1679                            ),
1680                        });
1681                    }
1682                    if let Some(row) = values.column(j).iter().position(|value| {
1683                        (*value - 0.0).abs() >= 1e-12 && (*value - 1.0).abs() >= 1e-12
1684                    }) {
1685                        return Err(DataError::SchemaMismatch {
1686                            reason: format!(
1687                                "column '{}' is binary in schema but row {} has value {}; expected 0 or 1",
1688                                name,
1689                                row + 1,
1690                                values[[row, j]]
1691                            ),
1692                        });
1693                    }
1694                }
1695                ColumnKindTag::Categorical => {
1696                    if !matches!(inferred.column_kinds[j], ColumnKindTag::Categorical) {
1697                        return Err(DataError::SchemaMismatch {
1698                            reason: format!(
1699                                "column '{}' is categorical in schema but parquet column is numeric",
1700                                name
1701                            ),
1702                        });
1703                    }
1704                    let inferred_col = &inferred.schema.columns[j];
1705                    // Build mapping: inferred_level_name -> schema_level_index.
1706                    let schema_level_map: HashMap<&str, f64> = sc
1707                        .levels
1708                        .iter()
1709                        .enumerate()
1710                        .map(|(idx, v)| (v.as_str(), idx as f64))
1711                        .collect();
1712                    let inferred_to_schema: Vec<f64> = inferred_col
1713                        .levels
1714                        .iter()
1715                        .map(|lv| {
1716                            schema_level_map
1717                                .get(lv.as_str())
1718                                .copied()
1719                                .or_else(|| unseen_policy.unseen_code_for(name, sc.levels.len()))
1720                                .ok_or_else(|| DataError::SchemaMismatch {
1721                                    reason: format!(
1722                                        "unseen level '{}' in categorical column '{}'",
1723                                        lv, name
1724                                    ),
1725                                })
1726                        })
1727                        .collect::<Result<Vec<_>, _>>()?;
1728                    for i in 0..n {
1729                        let old_code = values[[i, j]] as usize;
1730                        if old_code >= inferred_to_schema.len() {
1731                            let Some(unseen_code) =
1732                                unseen_policy.unseen_code_for(name, sc.levels.len())
1733                            else {
1734                                return Err(DataError::SchemaMismatch {
1735                                    reason: format!(
1736                                        "unseen categorical code at row {}, column '{}'",
1737                                        i + 1,
1738                                        name
1739                                    ),
1740                                });
1741                            };
1742                            values[[i, j]] = unseen_code;
1743                            continue;
1744                        }
1745                        values[[i, j]] = inferred_to_schema[old_code];
1746                    }
1747                }
1748            }
1749        } else {
1750            // Column not in schema — keep inferred.
1751            column_kinds.push(inferred.column_kinds[j]);
1752            schema_cols.push(inferred.schema.columns[j].clone());
1753        }
1754    }
1755
1756    Ok(EncodedDataset {
1757        headers: inferred.headers,
1758        values,
1759        schema: DataSchema {
1760            columns: schema_cols,
1761        },
1762        column_kinds,
1763    })
1764}
1765
1766pub fn encode_recordswith_inferred_schema(
1767    headers: Vec<String>,
1768    records: Vec<StringRecord>,
1769) -> Result<EncodedDataset, String> {
1770    if records.is_empty() {
1771        return Err(DataError::EmptyInput {
1772            reason: "table data cannot be empty".to_string(),
1773        }
1774        .into());
1775    }
1776    // Schema inference is column-independent: each column scans only its own
1777    // field across all rows. With wide frames (e.g. biobank: 22 cols × 194k
1778    // rows) the serial outer loop dominated ingest time, so fan the per-column
1779    // inference passes out over rayon. Order is preserved because `map` over an
1780    // indexed parallel iterator collects back in column order.
1781    let schema_cols = headers
1782        .par_iter()
1783        .enumerate()
1784        .map(|(j, name)| infer_schema_column(name, &records, j).map_err(String::from))
1785        .collect::<Result<Vec<SchemaColumn>, String>>()?;
1786    let schema = DataSchema {
1787        columns: schema_cols,
1788    };
1789    encode_recordswith_schema(headers, records, &schema, UnseenCategoryPolicy::Error)
1790}
1791
1792pub fn encode_recordswith_schema(
1793    headers: Vec<String>,
1794    records: Vec<StringRecord>,
1795    schema: &DataSchema,
1796    unseen_policy: UnseenCategoryPolicy,
1797) -> Result<EncodedDataset, String> {
1798    let n = records.len();
1799    if n == 0 {
1800        return Err(DataError::EmptyInput {
1801            reason: "table data cannot be empty".to_string(),
1802        }
1803        .into());
1804    }
1805    let p = headers.len();
1806    if p == 0 {
1807        return Err(DataError::EmptyInput {
1808            reason: "table data must have at least one header column".to_string(),
1809        }
1810        .into());
1811    }
1812    // Validate the row-width invariant up front. Without this check, records
1813    // wider than `headers` would be silently truncated (only the first
1814    // `headers.len()` fields per record would be encoded) and records
1815    // narrower than `headers` would only fail late when a per-column
1816    // `rec.get(j)` lookup returned `None`. Reject both cases explicitly so
1817    // callers cannot accidentally drop data via header/record shape skew.
1818    for (i, rec) in records.iter().enumerate() {
1819        if rec.len() != p {
1820            return Err(DataError::SchemaMismatch {
1821                reason: format!(
1822                    "row width mismatch at row {}: got {} fields, expected {} (one per header)",
1823                    i + 1,
1824                    rec.len(),
1825                    p
1826                ),
1827            }
1828            .into());
1829        }
1830    }
1831    let schema_byname: HashMap<&str, &SchemaColumn> = schema
1832        .columns
1833        .iter()
1834        .map(|c| (c.name.as_str(), c))
1835        .collect();
1836
1837    // Each column is encoded independently from the same row-major records, so
1838    // fan the per-column passes out over rayon (columns, not rows, so threads
1839    // never contend on a shared output cell). Each task returns its dense
1840    // `(kind, Vec<f64>)`; we then assemble the row-major `Array2` from the
1841    // collected columns. For wide frames this is the dominant ingest cost.
1842    let encoded_columns = headers
1843        .par_iter()
1844        .enumerate()
1845        .map(|(j, name)| {
1846            let inferred_for_extra;
1847            let col_schema = if let Some(s) = schema_byname.get(name.as_str()) {
1848                *s
1849            } else {
1850                inferred_for_extra =
1851                    infer_schema_column(name, &records, j).map_err(String::from)?;
1852                &inferred_for_extra
1853            };
1854            let column = encode_one_column(name, &records, j, col_schema, &unseen_policy)?;
1855            Ok::<(ColumnKindTag, Vec<f64>), String>((col_schema.kind, column))
1856        })
1857        .collect::<Result<Vec<(ColumnKindTag, Vec<f64>)>, String>>()?;
1858
1859    let mut column_kinds = Vec::<ColumnKindTag>::with_capacity(p);
1860    let mut values = Array2::<f64>::zeros((n, p));
1861    for (j, (kind, column)) in encoded_columns.into_iter().enumerate() {
1862        column_kinds.push(kind);
1863        values
1864            .column_mut(j)
1865            .assign(&ndarray::ArrayView1::from(&column));
1866    }
1867
1868    Ok(EncodedDataset {
1869        headers,
1870        values,
1871        schema: schema.clone(),
1872        column_kinds,
1873    })
1874}
1875
1876/// Encode a single column `j` of `records` to its dense `f64` representation
1877/// under `col_schema`. Continuous/binary values are parsed; categorical values
1878/// are mapped to their level index (or the unseen code under `unseen_policy`).
1879/// This is the per-column work unit fanned out across columns in
1880/// [`encode_recordswith_schema`]; it scans only field `j` of each record so
1881/// distinct columns never touch shared state.
1882fn encode_one_column(
1883    name: &str,
1884    records: &[StringRecord],
1885    j: usize,
1886    col_schema: &SchemaColumn,
1887    unseen_policy: &UnseenCategoryPolicy,
1888) -> Result<Vec<f64>, String> {
1889    let level_map = if matches!(col_schema.kind, ColumnKindTag::Categorical) {
1890        Some(
1891            col_schema
1892                .levels
1893                .iter()
1894                .enumerate()
1895                .map(|(idx, v)| (v.as_str(), idx as f64))
1896                .collect::<HashMap<_, _>>(),
1897        )
1898    } else {
1899        None
1900    };
1901
1902    let mut column = Vec::<f64>::with_capacity(records.len());
1903    for (i, rec) in records.iter().enumerate() {
1904        let raw = rec
1905            .get(j)
1906            .ok_or_else(|| {
1907                String::from(DataError::SchemaMismatch {
1908                    reason: format!("missing field at row {}, col {}", i + 1, j + 1),
1909                })
1910            })?
1911            .trim();
1912        if raw.is_empty() {
1913            return Err(DataError::EmptyInput {
1914                reason: format!("empty field at row {}, column '{}'", i + 1, name),
1915            }
1916            .into());
1917        }
1918        let val = match col_schema.kind {
1919            ColumnKindTag::Continuous => raw.parse::<f64>().map_err(|err| {
1920                String::from(DataError::SchemaMismatch {
1921                    reason: format!(
1922                        "column '{}' is continuous in schema but row {} has non-numeric value '{}': {}",
1923                        name,
1924                        i + 1,
1925                        raw,
1926                        err
1927                    ),
1928                })
1929            })?,
1930            ColumnKindTag::Binary => {
1931                let v = raw.parse::<f64>().map_err(|err| {
1932                    String::from(DataError::SchemaMismatch {
1933                        reason: format!(
1934                            "column '{}' is binary in schema but row {} has non-numeric value '{}': {}",
1935                            name,
1936                            i + 1,
1937                            raw,
1938                            err
1939                        ),
1940                    })
1941                })?;
1942                if (v - 0.0).abs() >= 1e-12 && (v - 1.0).abs() >= 1e-12 {
1943                    return Err(DataError::SchemaMismatch {
1944                        reason: format!(
1945                            "column '{}' is binary in schema but row {} has value {}; expected 0 or 1",
1946                            name,
1947                            i + 1,
1948                            v
1949                        ),
1950                    }
1951                    .into());
1952                }
1953                v
1954            }
1955            ColumnKindTag::Categorical => {
1956                let map = level_map.as_ref().ok_or_else(|| {
1957                    String::from(DataError::EncodingFailure {
1958                        reason: "internal categorical schema map missing".to_string(),
1959                    })
1960                })?;
1961                match map.get(raw) {
1962                    Some(v) => *v,
1963                    None => unseen_policy
1964                        .unseen_code_for(name, col_schema.levels.len())
1965                        .ok_or_else(|| {
1966                            String::from(DataError::SchemaMismatch {
1967                                reason: format!(
1968                                    "unseen level '{}' in categorical column '{}' at row {}; allowed levels: {}",
1969                                    raw,
1970                                    name,
1971                                    i + 1,
1972                                    col_schema.levels.join(",")
1973                                ),
1974                            })
1975                        })?,
1976                }
1977            }
1978        };
1979        if !val.is_finite() {
1980            return Err(DataError::InvalidValue {
1981                reason: format!("non-finite value at row {}, column '{}'", i + 1, name),
1982            }
1983            .into());
1984        }
1985        column.push(val);
1986    }
1987    Ok(column)
1988}
1989
1990fn infer_schema_column(
1991    name: &str,
1992    records: &[StringRecord],
1993    col_idx: usize,
1994) -> Result<SchemaColumn, DataError> {
1995    let mut all_numeric = true;
1996    let mut all_binary = true;
1997    let mut levels = Vec::<String>::new();
1998    let mut level_index = HashMap::<String, usize>::new();
1999    for (i, rec) in records.iter().enumerate() {
2000        let raw = rec
2001            .get(col_idx)
2002            .ok_or_else(|| DataError::SchemaMismatch {
2003                reason: format!("missing field at row {}, col {}", i + 1, col_idx + 1),
2004            })?
2005            .trim();
2006        if raw.is_empty() {
2007            return Err(DataError::EmptyInput {
2008                reason: format!("empty field at row {}, column '{}'", i + 1, name),
2009            });
2010        }
2011        if let Ok(v) = raw.parse::<f64>() {
2012            if !v.is_finite() {
2013                return Err(DataError::InvalidValue {
2014                    reason: format!("non-finite value at row {}, column '{}'", i + 1, name),
2015                });
2016            }
2017            if (v - 0.0).abs() >= 1e-12 && (v - 1.0).abs() >= 1e-12 {
2018                all_binary = false;
2019            }
2020        } else {
2021            all_numeric = false;
2022            all_binary = false;
2023            level_index.entry(raw.to_string()).or_insert_with(|| {
2024                let idx = levels.len();
2025                levels.push(raw.to_string());
2026                idx
2027            });
2028        }
2029    }
2030    let kind = if all_numeric {
2031        if all_binary {
2032            ColumnKindTag::Binary
2033        } else {
2034            ColumnKindTag::Continuous
2035        }
2036    } else {
2037        ColumnKindTag::Categorical
2038    };
2039    // Canonical (sorted) level order — see `infer_and_encode_column_major`. The
2040    // record-driven and column-major inference paths must produce byte-identical
2041    // schemas, so both sort the level set lexicographically (#1319).
2042    if matches!(kind, ColumnKindTag::Categorical) {
2043        levels.sort();
2044    }
2045    Ok(SchemaColumn {
2046        name: name.to_string(),
2047        kind,
2048        levels: if matches!(kind, ColumnKindTag::Categorical) {
2049            levels
2050        } else {
2051            Vec::new()
2052        },
2053    })
2054}
2055
2056/// Infer the schema of, and densely encode, a single column presented in
2057/// column-major form (`name` + its raw string field for every row).
2058///
2059/// This is the column-major sibling of the record-driven path: it produces the
2060/// byte-identical `(SchemaColumn, Vec<f64>)` that `encode_recordswith_inferred_schema`
2061/// would produce for the same column, but it reads from a `&[&str]` column
2062/// slice instead of indexing field `col_idx` of every `StringRecord`. It exists
2063/// so callers holding column-major data (e.g. the Python FFI, which can
2064/// fingerprint and cache invariant columns shared across many fits of the same
2065/// base cohort) can encode one column at a time without first materializing the
2066/// full row-major record table. `col_index` is 1-based only for error text and
2067/// matches the record-driven messages.
2068pub fn infer_and_encode_column_major(
2069    name: &str,
2070    column: &[&str],
2071    col_index: usize,
2072) -> Result<(SchemaColumn, Vec<f64>), String> {
2073    if column.is_empty() {
2074        return Err(DataError::EmptyInput {
2075            reason: "table data cannot be empty".to_string(),
2076        }
2077        .into());
2078    }
2079    // A typed Python frame prefixes every cell of a categorical-dtype column
2080    // with `CATEGORICAL_CELL_SENTINEL` so the column is encoded as a factor even
2081    // when its labels parse as numbers ("0","1","2"). Detect and strip the
2082    // marker before inference; its presence forces `Categorical` (#1317/#1318).
2083    let force_categorical = column.iter().any(|c| strip_categorical_sentinel(c).1);
2084    let mut all_numeric = !force_categorical;
2085    let mut all_binary = !force_categorical;
2086    let mut levels = Vec::<String>::new();
2087    let mut level_index = HashMap::<String, usize>::new();
2088    let mut trimmed = Vec::<&str>::with_capacity(column.len());
2089    // Capture the parsed numeric value alongside each trimmed field during the
2090    // single inference scan, so the encode pass below never re-parses a numeric
2091    // string. For wide biobank frames the f64 parse dominated ingest, and the
2092    // record-driven path used to parse every continuous/binary field twice
2093    // (once to infer the schema, once to encode). `parsed[i]` is `Some(v)` iff
2094    // field `i` parsed as a finite f64; categorical columns ignore it.
2095    let mut parsed = Vec::<Option<f64>>::with_capacity(column.len());
2096    for (i, raw_field) in column.iter().enumerate() {
2097        // Strip the categorical marker (if any) so the recorded level label and
2098        // any numeric parse see the user's clean text, not the sentinel.
2099        let (raw, _) = strip_categorical_sentinel(raw_field);
2100        let raw = raw.trim();
2101        if raw.is_empty() {
2102            return Err(DataError::EmptyInput {
2103                reason: format!("empty field at row {}, column '{}'", i + 1, name),
2104            }
2105            .into());
2106        }
2107        // When the source column is dtype-categorical, every cell is a level
2108        // regardless of whether its label parses as a number.
2109        if !force_categorical {
2110            if let Ok(v) = raw.parse::<f64>() {
2111                if !v.is_finite() {
2112                    return Err(DataError::InvalidValue {
2113                        reason: format!("non-finite value at row {}, column '{}'", i + 1, name),
2114                    }
2115                    .into());
2116                }
2117                if (v - 0.0).abs() >= 1e-12 && (v - 1.0).abs() >= 1e-12 {
2118                    all_binary = false;
2119                }
2120                parsed.push(Some(v));
2121                trimmed.push(raw);
2122                continue;
2123            }
2124            all_numeric = false;
2125            all_binary = false;
2126        }
2127        level_index.entry(raw.to_string()).or_insert_with(|| {
2128            let idx = levels.len();
2129            levels.push(raw.to_string());
2130            idx
2131        });
2132        parsed.push(None);
2133        trimmed.push(raw);
2134    }
2135    let kind = if all_numeric {
2136        if all_binary {
2137            ColumnKindTag::Binary
2138        } else {
2139            ColumnKindTag::Continuous
2140        }
2141    } else {
2142        ColumnKindTag::Categorical
2143    };
2144    // Canonical level ordering: sort factor levels lexicographically rather than
2145    // recording them in first-appearance order. Every reference tool a gam user
2146    // comes from — R `factor()` (C-locale sort), pandas `Categorical`, sklearn
2147    // `LabelEncoder` — orders categorical levels canonically, and downstream
2148    // consumers key off that order: the multinomial driver lays out one output
2149    // probability column per level and takes the *last* level as the softmax
2150    // reference, so first-appearance order made the `(n, K)` prediction columns
2151    // depend on which class happened to appear first in the training rows (a
2152    // row-shuffle would permute the output) instead of on the class labels
2153    // (#1319). Sorting makes the encoding a deterministic function of the label
2154    // *set*, independent of row order, and matches the factor convention so
2155    // column `k` of a multinomial prediction is class `levels[k]`.
2156    if matches!(kind, ColumnKindTag::Categorical) {
2157        levels.sort();
2158    }
2159    let schema = SchemaColumn {
2160        name: name.to_string(),
2161        kind,
2162        levels: if matches!(kind, ColumnKindTag::Categorical) {
2163            levels
2164        } else {
2165            Vec::new()
2166        },
2167    };
2168
2169    let level_map = if matches!(kind, ColumnKindTag::Categorical) {
2170        Some(
2171            schema
2172                .levels
2173                .iter()
2174                .enumerate()
2175                .map(|(idx, v)| (v.as_str(), idx as f64))
2176                .collect::<HashMap<_, _>>(),
2177        )
2178    } else {
2179        None
2180    };
2181
2182    let mut values = Vec::<f64>::with_capacity(trimmed.len());
2183    for (i, raw) in trimmed.iter().enumerate() {
2184        let raw = *raw;
2185        let val = match kind {
2186            // Continuous/Binary kinds are only selected when every field parsed
2187            // as a finite f64 during inference, so `parsed[i]` is always `Some`
2188            // here — reuse it instead of re-parsing the string.
2189            ColumnKindTag::Continuous => parsed[i].ok_or_else(|| {
2190                String::from(DataError::EncodingFailure {
2191                    reason: format!(
2192                        "internal: continuous column '{}' lost its parsed value at row {} (col {})",
2193                        name,
2194                        i + 1,
2195                        col_index
2196                    ),
2197                })
2198            })?,
2199            ColumnKindTag::Binary => {
2200                let v = parsed[i].ok_or_else(|| {
2201                    String::from(DataError::EncodingFailure {
2202                        reason: format!(
2203                            "internal: binary column '{}' lost its parsed value at row {} (col {})",
2204                            name,
2205                            i + 1,
2206                            col_index
2207                        ),
2208                    })
2209                })?;
2210                if (v - 0.0).abs() >= 1e-12 && (v - 1.0).abs() >= 1e-12 {
2211                    return Err(DataError::SchemaMismatch {
2212                        reason: format!(
2213                            "column '{}' is binary in schema but row {} has value {}; expected 0 or 1",
2214                            name,
2215                            i + 1,
2216                            v
2217                        ),
2218                    }
2219                    .into());
2220                }
2221                v
2222            }
2223            ColumnKindTag::Categorical => {
2224                let map = level_map.as_ref().ok_or_else(|| {
2225                    String::from(DataError::EncodingFailure {
2226                        reason: "internal categorical schema map missing".to_string(),
2227                    })
2228                })?;
2229                *map.get(raw).ok_or_else(|| {
2230                    String::from(DataError::EncodingFailure {
2231                        reason: format!(
2232                            "internal: level '{}' missing from freshly built map for column '{}' (col {})",
2233                            raw, name, col_index
2234                        ),
2235                    })
2236                })?
2237            }
2238        };
2239        if !val.is_finite() {
2240            return Err(DataError::InvalidValue {
2241                reason: format!("non-finite value at row {}, column '{}'", i + 1, name),
2242            }
2243            .into());
2244        }
2245        values.push(val);
2246    }
2247    Ok((schema, values))
2248}
2249
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an `EncodedDataset` whose columns are all `Continuous` and whose
    /// schema is empty — the fixture shape shared by the `feature_ranges` and
    /// `column_map` tests below.
    fn continuous_dataset(headers: &[&str], values: ndarray::Array2<f64>) -> EncodedDataset {
        EncodedDataset {
            headers: headers.iter().map(|h| h.to_string()).collect(),
            values,
            schema: DataSchema { columns: vec![] },
            column_kinds: headers.iter().map(|_| ColumnKindTag::Continuous).collect(),
        }
    }

    // -----------------------------------------------------------------------
    // record / column-major encoding
    // -----------------------------------------------------------------------

    /// Both the inferred and the schema-guided record encoders must reject a
    /// frame with headers but zero rows, with the same error text.
    #[test]
    fn encode_records_rejects_empty_input() {
        let headers = vec!["x".to_string()];
        let schema = DataSchema {
            columns: vec![SchemaColumn {
                name: "x".to_string(),
                kind: ColumnKindTag::Continuous,
                levels: Vec::new(),
            }],
        };

        let err = encode_recordswith_inferred_schema(headers.clone(), Vec::new())
            .expect_err("empty inferred records should error");
        assert_eq!(err, "table data cannot be empty");

        let err =
            encode_recordswith_schema(headers, Vec::new(), &schema, UnseenCategoryPolicy::Error)
                .expect_err("empty schema-guided records should error");
        assert_eq!(err, "table data cannot be empty");
    }

    #[test]
    fn column_major_matches_record_driven_inferred_encode() {
        // The FFI ingest path encodes column-by-column via
        // `infer_and_encode_column_major`; it must produce byte-identical
        // schema + values to the record-driven `encode_recordswith_inferred_schema`
        // for the same frame across all three inferred kinds.
        let headers = vec!["cont".to_string(), "bin".to_string(), "cat".to_string()];
        let raw_rows = vec![
            vec!["1.5", "0", "a"],
            vec!["2.0", "1", "b"],
            vec!["-3.25", "1", "a"],
            vec!["0.0", "0", "c"],
        ];
        let records: Vec<StringRecord> = raw_rows
            .iter()
            .map(|r| StringRecord::from(r.clone()))
            .collect();
        let record_ds = encode_recordswith_inferred_schema(headers.clone(), records)
            .expect("record-driven encode");
        assert_eq!(record_ds.values.nrows(), raw_rows.len());

        for (j, name) in headers.iter().enumerate() {
            let column: Vec<&str> = raw_rows.iter().map(|r| r[j]).collect();
            let (schema_col, values) =
                infer_and_encode_column_major(name, &column, j + 1).expect("column-major encode");
            assert_eq!(schema_col.kind, record_ds.schema.columns[j].kind);
            assert_eq!(schema_col.levels, record_ds.schema.columns[j].levels);
            // Without a length check, a column-major encoder that dropped
            // trailing rows would still pass the element-wise loop below.
            assert_eq!(values.len(), raw_rows.len(), "row count for col {name}");
            for (i, v) in values.iter().enumerate() {
                // Compare bit patterns so the "byte-identical" contract is
                // actually enforced (`==` treats 0.0 and -0.0 as equal).
                assert_eq!(
                    v.to_bits(),
                    record_ds.values[[i, j]].to_bits(),
                    "row {i} col {name}"
                );
            }
        }
    }

    /// A column opted into unknown-level encoding maps an unseen level to the
    /// index one past its known levels (here `2.0` for levels `a`, `b`).
    #[test]
    fn encode_records_can_encode_unseen_named_categorical_column() {
        let schema = DataSchema {
            columns: vec![
                SchemaColumn {
                    name: "g".to_string(),
                    kind: ColumnKindTag::Categorical,
                    levels: vec!["a".to_string(), "b".to_string()],
                },
                SchemaColumn {
                    name: "x".to_string(),
                    kind: ColumnKindTag::Categorical,
                    levels: vec!["low".to_string(), "high".to_string()],
                },
            ],
        };
        let headers = vec!["g".to_string(), "x".to_string()];
        let records = vec![StringRecord::from(vec!["new-group", "low"])];
        let policy =
            UnseenCategoryPolicy::encode_unknown_for_columns(HashSet::from(["g".to_string()]));

        let ds =
            encode_recordswith_schema(headers, records, &schema, policy).expect("encoded dataset");

        assert_eq!(ds.values[[0, 0]], 2.0);
        assert_eq!(ds.values[[0, 1]], 0.0);
    }

    #[test]
    fn numeric_valued_dictionary_column_classifies_and_decodes_as_numeric() {
        // Regression for #1162: pyarrow dictionary-encodes low-cardinality
        // *numeric* columns by default (e.g. `Dictionary(Int8, Int64)`).
        // Dictionary encoding is a storage detail, not a semantic type, so such
        // a column must stay numeric — both at classification time
        // (`parquet_field_is_string`) and at decode time
        // (`decode_parquet_batch_column`). Previously the loader matched ALL
        // `Dictionary(_, _)` as string/categorical, silently flipping numeric
        // features to categorical and rejecting valid numeric prediction files
        // with SchemaMismatch.
        use arrow::array::{Array, ArrayRef, DictionaryArray, Int8Array, Int64Array};
        use arrow::datatypes::{DataType, Int8Type};
        use std::sync::Arc;

        // Logical column values: 5, 7, 5, 7, 5 (low-cardinality integers).
        let keys = Int8Array::from(vec![0i8, 1, 0, 1, 0]);
        let dict_values: ArrayRef = Arc::new(Int64Array::from(vec![5i64, 7]));
        let dict: DictionaryArray<Int8Type> = DictionaryArray::new(keys, dict_values);

        // The dictionary's *value* type is numeric, so the column must NOT be
        // classified as string/categorical.
        assert!(matches!(dict.data_type(), DataType::Dictionary(_, _)));
        assert!(
            !parquet_field_is_string(dict.data_type()),
            "Dictionary(Int8, Int64) must not be treated as a string column"
        );

        // A genuine string-valued dictionary still classifies as string.
        let str_dict: DictionaryArray<Int8Type> = vec!["a", "b", "a"].into_iter().collect();
        assert!(
            parquet_field_is_string(str_dict.data_type()),
            "Dictionary(Int8, Utf8) must remain a string column"
        );

        // Decoding the numeric dictionary with `is_string_col = false` must
        // resolve indices through the dictionary and yield the underlying
        // numeric values (not error, not strings).
        let decoded = decode_parquet_batch_column(&dict, dict.len(), 0, "x", false)
            .expect("numeric dictionary column should decode as numeric");
        match decoded {
            ParquetBatchColumn::Numeric(values) => {
                assert_eq!(values, vec![5.0, 7.0, 5.0, 7.0, 5.0]);
            }
            ParquetBatchColumn::Strings(_) => {
                panic!("numeric dictionary column was decoded as strings");
            }
        }

        // End-to-end: write a real parquet file whose only column is the
        // dictionary-encoded numeric one, then load it both ways. This is the
        // exact repro from #1162 (pyarrow's default dictionary encoding of a
        // low-cardinality numeric column).
        use arrow::datatypes::{Field, Schema};
        use arrow::record_batch::RecordBatch;
        use parquet::arrow::ArrowWriter;

        let arrow_schema = Arc::new(Schema::new(vec![Field::new(
            "x",
            dict.data_type().clone(),
            false,
        )]));
        // `dict` is not needed after this point, so move it into the batch
        // instead of cloning the array.
        let batch = RecordBatch::try_new(arrow_schema.clone(), vec![Arc::new(dict)])
            .expect("record batch with a dictionary numeric column");

        let dir = tempfile::tempdir().expect("tempdir");
        let path = dir.path().join("dict_numeric.parquet");
        {
            let file = std::fs::File::create(&path).expect("create parquet");
            let mut writer =
                ArrowWriter::try_new(file, arrow_schema, None).expect("arrow parquet writer");
            writer.write(&batch).expect("write batch");
            writer.close().expect("close writer");
        }

        // Inferred load: the column must be Continuous (5 and 7 are not 0/1),
        // never Categorical.
        let inferred =
            load_parquet_inferred(&path, &[], &HashSet::new()).expect("inferred parquet load");
        assert_eq!(inferred.column_kinds, vec![ColumnKindTag::Continuous]);
        assert_eq!(
            inferred.values.column(0).to_vec(),
            vec![5.0, 7.0, 5.0, 7.0, 5.0]
        );

        // Schema-driven load with the column declared Continuous (as it would be
        // after training on CSV / non-dictionary parquet) must NOT raise the
        // SchemaMismatch that #1162 reported on valid numeric data.
        let schema = DataSchema {
            columns: vec![SchemaColumn {
                name: "x".to_string(),
                kind: ColumnKindTag::Continuous,
                levels: Vec::new(),
            }],
        };
        let schema_loaded =
            load_parquet_with_schema(&path, &schema, UnseenCategoryPolicy::Error, &[])
                .expect("dictionary-encoded numeric parquet must load against a Continuous schema");
        assert_eq!(schema_loaded.column_kinds, vec![ColumnKindTag::Continuous]);
        assert_eq!(
            schema_loaded.values.column(0).to_vec(),
            vec![5.0, 7.0, 5.0, 7.0, 5.0]
        );
    }

    /// Opting one column into unknown-level encoding must not relax the other
    /// categorical columns: an unseen level in `x` still errors.
    #[test]
    fn encode_records_keeps_unlisted_categorical_columns_strict() {
        let schema = DataSchema {
            columns: vec![
                SchemaColumn {
                    name: "g".to_string(),
                    kind: ColumnKindTag::Categorical,
                    levels: vec!["a".to_string(), "b".to_string()],
                },
                SchemaColumn {
                    name: "x".to_string(),
                    kind: ColumnKindTag::Categorical,
                    levels: vec!["low".to_string(), "high".to_string()],
                },
            ],
        };
        let headers = vec!["g".to_string(), "x".to_string()];
        let records = vec![StringRecord::from(vec!["a", "new-level"])];
        let policy =
            UnseenCategoryPolicy::encode_unknown_for_columns(HashSet::from(["g".to_string()]));

        let err = encode_recordswith_schema(headers, records, &schema, policy)
            .expect_err("ordinary categorical column should stay strict");

        assert!(err.contains("unseen level 'new-level' in categorical column 'x'"));
    }

    // -----------------------------------------------------------------------
    // strip_categorical_sentinel
    // -----------------------------------------------------------------------

    #[test]
    fn sentinel_strip_present_returns_rest_and_true() {
        let marked = format!("{}{}", CATEGORICAL_CELL_SENTINEL, "hello");
        let (rest, found) = strip_categorical_sentinel(&marked);
        assert_eq!(rest, "hello");
        assert!(found);
    }

    #[test]
    fn sentinel_strip_absent_returns_original_and_false() {
        let (rest, found) = strip_categorical_sentinel("hello");
        assert_eq!(rest, "hello");
        assert!(!found);
    }

    #[test]
    fn sentinel_strip_empty_string_returns_empty_and_false() {
        let (rest, found) = strip_categorical_sentinel("");
        assert_eq!(rest, "");
        assert!(!found);
    }

    #[test]
    fn sentinel_strip_only_sentinel_returns_empty_and_true() {
        let marked = CATEGORICAL_CELL_SENTINEL.to_string();
        let (rest, found) = strip_categorical_sentinel(&marked);
        assert_eq!(rest, "");
        assert!(found);
    }

    // -----------------------------------------------------------------------
    // EncodedDataset::feature_ranges
    // -----------------------------------------------------------------------

    #[test]
    fn feature_ranges_two_columns() {
        let ds = continuous_dataset(
            &["a", "b"],
            ndarray::arr2(&[[1.0_f64, 10.0], [3.0, 20.0], [2.0, 15.0]]),
        );
        let ranges = ds.feature_ranges();
        assert_eq!(ranges.len(), 2);
        assert_eq!(ranges[0], (1.0, 3.0));
        assert_eq!(ranges[1], (10.0, 20.0));
    }

    #[test]
    fn feature_ranges_single_row_min_equals_max() {
        let ds = continuous_dataset(&["x", "y"], ndarray::arr2(&[[5.0_f64, -3.0]]));
        let ranges = ds.feature_ranges();
        assert_eq!(ranges[0], (5.0, 5.0));
        assert_eq!(ranges[1], (-3.0, -3.0));
    }

    #[test]
    fn feature_ranges_all_nan_defaults_to_zero() {
        let ds = continuous_dataset(&["x"], ndarray::arr2(&[[f64::NAN], [f64::NAN]]));
        let ranges = ds.feature_ranges();
        assert_eq!(ranges[0], (0.0, 0.0));
    }

    // -----------------------------------------------------------------------
    // EncodedDataset::column_map
    // -----------------------------------------------------------------------

    #[test]
    fn column_map_indexes_by_name() {
        let ds = continuous_dataset(&["alpha", "beta"], ndarray::arr2(&[[0.0_f64, 1.0], [2.0, 3.0]]));
        let map = ds.column_map();
        assert_eq!(map["alpha"], 0);
        assert_eq!(map["beta"], 1);
        assert_eq!(map.len(), 2);
    }
}