1use csv::{ReaderBuilder, StringRecord};
2use ndarray::{Array2, Axis};
3use rayon::prelude::*;
4use serde::{Deserialize, Serialize};
5use std::cmp::Ordering;
6use std::collections::{HashMap, HashSet};
7use std::fmt;
8use std::path::Path;
9
/// Compares two level names in "natural" order: runs of ASCII digits are
/// compared by numeric value, every other byte is compared as-is.
///
/// When two digit runs have the same numeric value but a different number of
/// leading zeros, the shorter run sorts first (`"a1" < "a01"`). When one string
/// is a prefix of the other, the shorter string sorts first.
fn natural_level_cmp(a: &str, b: &str) -> Ordering {
    /// Index one past the end of the ASCII-digit run that starts at `from`.
    fn run_end(bytes: &[u8], from: usize) -> usize {
        bytes[from..]
            .iter()
            .position(|byte| !byte.is_ascii_digit())
            .map_or(bytes.len(), |offset| from + offset)
    }

    /// The digit run without leading zeros; an all-zero run becomes `"0"`.
    fn significant(run: &str) -> &str {
        match run.trim_start_matches('0') {
            "" => "0",
            rest => rest,
        }
    }

    let (left, right) = (a.as_bytes(), b.as_bytes());
    let (mut i, mut j) = (0usize, 0usize);

    while let (Some(&x), Some(&y)) = (left.get(i), right.get(j)) {
        if !(x.is_ascii_digit() && y.is_ascii_digit()) {
            // At least one side is not a digit: plain byte comparison.
            if x != y {
                return x.cmp(&y);
            }
            i += 1;
            j += 1;
            continue;
        }

        // Both sides start a digit run. The run boundaries are ASCII, so
        // slicing the original `&str` at these byte offsets is valid.
        let (i_end, j_end) = (run_end(left, i), run_end(right, j));
        let (run_a, run_b) = (&a[i..i_end], &b[j..j_end]);
        let (sig_a, sig_b) = (significant(run_a), significant(run_b));

        // More significant digits means a larger number; equal counts fall
        // back to lexicographic order, and equal values to the raw run length.
        let verdict = sig_a
            .len()
            .cmp(&sig_b.len())
            .then_with(|| sig_a.cmp(sig_b))
            .then_with(|| run_a.len().cmp(&run_b.len()));
        if verdict != Ordering::Equal {
            return verdict;
        }
        i = i_end;
        j = j_end;
    }

    left.len().cmp(&right.len())
}
48
49fn sort_levels_canonical(levels: &mut [String]) {
50 levels.sort_by(|a, b| natural_level_cmp(a, b));
51}
52
/// Error type returned by the dataset loaders in this file.
///
/// All variants except [`DataError::ColumnNotFound`] carry a ready-to-display
/// `reason` string; the `Display` impl prints that string verbatim.
#[derive(Debug, Clone)]
pub enum DataError {
    /// The data does not match what was expected of it, e.g. a row with the
    /// wrong number of fields, a requested column that is absent, a cell that
    /// violates the schema's column kind, or an unseen categorical level.
    SchemaMismatch { reason: String },
    /// The file could not be opened or read, or its extension is unsupported.
    ParseError { reason: String },
    /// An internal encoding step failed (e.g. a categorical column without its
    /// level map).
    EncodingFailure { reason: String },
    /// The input has no headers, no rows, or contains an empty field.
    EmptyInput { reason: String },
    /// A cell holds a value that cannot be used, such as a non-finite number,
    /// a null, or an unsupported parquet column type.
    InvalidValue { reason: String },
    /// A column was looked up by name and not found. Built by
    /// [`DataError::column_not_found`].
    ColumnNotFound {
        /// The name that was looked up.
        name: String,
        /// Optional role label that prefixes the column in the message
        /// (`"{role} column '{name}'"`).
        role: Option<String>,
        /// Every column name that was available for the lookup.
        available: Vec<String>,
        /// The subset of `available` judged similar to `name`.
        similar: Vec<String>,
        /// When `true`, the message adds a hint that the file looks
        /// tab-separated.
        tsv_hint: bool,
    },
}
110
111impl DataError {
112 pub fn column_not_found(
118 col_map: &HashMap<String, usize>,
119 name: &str,
120 role: Option<&str>,
121 ) -> Self {
122 let target_lower = name.to_lowercase();
123 let mut similar: Vec<String> = col_map
124 .keys()
125 .filter(|k| {
126 let k_lower = k.to_lowercase();
127 k_lower.contains(&target_lower)
128 || target_lower.contains(&k_lower)
129 || shared_prefix(&k_lower, &target_lower) >= 3
130 })
131 .cloned()
132 .collect();
133 similar.sort_unstable();
134 let mut available: Vec<String> = col_map.keys().cloned().collect();
135 available.sort_unstable();
136 let tsv_hint = available.len() == 1 && available[0].contains('\t');
137 Self::ColumnNotFound {
138 name: name.to_string(),
139 role: role.map(str::to_string),
140 available,
141 similar,
142 tsv_hint,
143 }
144 }
145}
146
147impl fmt::Display for DataError {
148 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
149 match self {
150 DataError::SchemaMismatch { reason }
151 | DataError::ParseError { reason }
152 | DataError::EncodingFailure { reason }
153 | DataError::EmptyInput { reason }
154 | DataError::InvalidValue { reason } => f.write_str(reason),
155 DataError::ColumnNotFound {
156 name,
157 role,
158 available,
159 similar,
160 tsv_hint,
161 } => {
162 let label = match role {
163 Some(r) => format!("{r} column '{name}'"),
164 None => format!("column '{name}'"),
165 };
166 let tsv_suffix = if *tsv_hint {
167 " — your file appears to be tab-separated; gam expects comma-separated CSV. \
168 Replace tabs with commas, or pre-convert with `tr '\\t' ',' < file.tsv > file.csv`."
169 } else {
170 ""
171 };
172 if similar.is_empty() {
173 write!(
174 f,
175 "{label} not found in data. Available columns: [{}]{tsv_suffix}",
176 available.join(", ")
177 )
178 } else {
179 write!(
180 f,
181 "{label} not found in data. Did you mean one of [{}]? Full list: [{}]{tsv_suffix}",
182 similar.join(", "),
183 available.join(", ")
184 )
185 }
186 }
187 }
188 }
189}
190
// Marks `DataError` as a standard error type; no methods are overridden.
impl std::error::Error for DataError {}
192
/// Converts a `DataError` into its `Display` text, so the error can be
/// returned where a `String` is expected.
impl From<DataError> for String {
    fn from(err: DataError) -> String {
        err.to_string()
    }
}
198
/// Serializable description of a dataset's columns.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct DataSchema {
    /// One entry per column.
    pub columns: Vec<SchemaColumn>,
}
207
/// Schema entry for a single column.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct SchemaColumn {
    /// Column name as it appears in the header.
    pub name: String,
    /// How the column's cells are encoded.
    pub kind: ColumnKindTag,
    /// Level names of a categorical column; a level's position in this list
    /// is its numeric code. Defaults to empty when absent from the
    /// serialized form.
    #[serde(default)]
    pub levels: Vec<String>,
}
215
/// Encoding kind of a column. Serialized in kebab-case
/// (`continuous`, `binary`, `categorical`).
#[derive(Clone, Copy, Debug, Serialize, Deserialize, Eq, PartialEq)]
#[serde(rename_all = "kebab-case")]
pub enum ColumnKindTag {
    /// Numeric column.
    Continuous,
    /// Numeric column whose values are all 0 or 1 (within 1e-12).
    Binary,
    /// Column encoded as indices into a list of level names.
    Categorical,
}
223
/// What to do when a categorical cell holds a level that the schema does not
/// list.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum UnseenCategoryPolicy {
    /// Unseen levels are never given a code.
    Error,
    /// Unseen levels in the named columns are given an extra code.
    EncodeUnknownForColumns(HashSet<String>),
}

impl UnseenCategoryPolicy {
    /// Builds a policy that encodes unseen levels for `columns`.
    ///
    /// An empty set yields [`UnseenCategoryPolicy::Error`].
    pub fn encode_unknown_for_columns(columns: HashSet<String>) -> Self {
        match columns.len() {
            0 => Self::Error,
            _ => Self::EncodeUnknownForColumns(columns),
        }
    }

    /// Returns the code for an unseen level in `column_name`, or `None` when
    /// the policy does not cover that column.
    ///
    /// The code is `level_count` as an `f64`, i.e. one past the last known
    /// level index.
    fn unseen_code_for(&self, column_name: &str, level_count: usize) -> Option<f64> {
        if let Self::EncodeUnknownForColumns(opted_in) = self {
            if opted_in.contains(column_name) {
                return Some(level_count as f64);
            }
        }
        None
    }
}
248
/// A loaded dataset with every column encoded as `f64`.
#[derive(Clone, Debug)]
pub struct EncodedDataset {
    /// Names of the loaded columns, in the same order as the columns of
    /// `values`.
    pub headers: Vec<String>,
    /// Encoded cells, one row per data row and one column per header.
    pub values: Array2<f64>,
    /// Schema describing the loaded columns.
    pub schema: DataSchema,
    /// Kind of each loaded column, in header order.
    pub column_kinds: Vec<ColumnKindTag>,
}
256
257impl EncodedDataset {
258 pub fn column_map(&self) -> HashMap<String, usize> {
259 self.headers
260 .iter()
261 .enumerate()
262 .map(|(index, header)| (header.clone(), index))
263 .collect()
264 }
265
266 pub fn feature_ranges(&self) -> Vec<(f64, f64)> {
272 self.values
279 .axis_iter(Axis(1))
280 .into_par_iter()
281 .map(|col| {
282 let (lo, hi) =
283 col.iter()
284 .fold((f64::INFINITY, f64::NEG_INFINITY), |(lo, hi), &v| {
285 if v.is_finite() {
286 (lo.min(v), hi.max(v))
287 } else {
288 (lo, hi)
289 }
290 });
291 if !lo.is_finite() || !hi.is_finite() {
292 (0.0, 0.0)
293 } else {
294 (lo, hi)
295 }
296 })
297 .collect()
298 }
299}
300
/// Counts how many leading characters (not bytes) `a` and `b` have in common.
fn shared_prefix(a: &str, b: &str) -> usize {
    let mut matched = 0;
    for (left, right) in a.chars().zip(b.chars()) {
        if left != right {
            break;
        }
        matched += 1;
    }
    matched
}
307
/// On-disk format of a data file, as chosen by `detect_format` from the file
/// extension.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum DataFormat {
    /// Comma-delimited text.
    Csv,
    /// Tab-delimited text.
    Tsv,
    /// Apache Parquet.
    Parquet,
}
318
319fn detect_format(path: &Path) -> Result<DataFormat, DataError> {
320 let ext = path
321 .extension()
322 .and_then(|s| s.to_str())
323 .unwrap_or_default()
324 .to_ascii_lowercase();
325 match ext.as_str() {
326 "csv" => Ok(DataFormat::Csv),
327 "tsv" | "txt" | "tab" => Ok(DataFormat::Tsv),
328 "parquet" | "pq" | "pqt" => Ok(DataFormat::Parquet),
329 other => Err(DataError::ParseError {
330 reason: format!(
331 "unsupported data file extension '.{other}'; expected csv, tsv, txt, parquet, or pq: '{}'",
332 path.display()
333 ),
334 }),
335 }
336}
337
338pub fn load_dataset_projected(
343 path: &Path,
344 requested_columns: &[String],
345) -> Result<EncodedDataset, DataError> {
346 load_dataset_projected_with_categorical_roles(path, requested_columns, &HashSet::new())
347}
348
349pub fn load_dataset_projected_with_categorical_roles(
371 path: &Path,
372 requested_columns: &[String],
373 categorical_roles: &HashSet<&str>,
374) -> Result<EncodedDataset, DataError> {
375 match detect_format(path)? {
376 DataFormat::Csv => {
377 load_delimited_inferred(path, b',', requested_columns, categorical_roles)
378 }
379 DataFormat::Tsv => {
380 load_delimited_inferred(path, b'\t', requested_columns, categorical_roles)
381 }
382 DataFormat::Parquet => load_parquet_inferred(path, requested_columns, categorical_roles),
383 }
384}
385
386pub fn load_datasetwith_schema_projected(
387 path: &Path,
388 schema: &DataSchema,
389 unseen_policy: UnseenCategoryPolicy,
390 requested_columns: &[String],
391) -> Result<EncodedDataset, DataError> {
392 match detect_format(path)? {
393 DataFormat::Csv => {
394 load_delimited_with_schema(path, b',', schema, unseen_policy, requested_columns)
395 }
396 DataFormat::Tsv => {
397 load_delimited_with_schema(path, b'\t', schema, unseen_policy, requested_columns)
398 }
399 DataFormat::Parquet => {
400 load_parquet_with_schema(path, schema, unseen_policy, requested_columns)
401 }
402 }
403}
404
405pub fn load_csvwith_inferred_schema(path: &Path) -> Result<EncodedDataset, DataError> {
410 load_delimited_inferred(path, b',', &[], &HashSet::new())
411}
412
/// Upper bound on the `sample_count` that `load_delimited_inferred` passes to
/// `infer_delimited_column` (it uses `total_rows.min(SCHEMA_SAMPLE_ROWS)`).
const SCHEMA_SAMPLE_ROWS: usize = 1024;
419
/// Character that [`strip_categorical_sentinel`] looks for at the start of a
/// cell.
pub const CATEGORICAL_CELL_SENTINEL: char = '\u{0}';

/// Removes one leading [`CATEGORICAL_CELL_SENTINEL`] from `cell`.
///
/// Returns the remaining text and `true` when the sentinel was present, or the
/// unchanged cell and `false` otherwise. Only a single leading sentinel is
/// removed.
pub fn strip_categorical_sentinel(cell: &str) -> (&str, bool) {
    let stripped = cell.strip_prefix(CATEGORICAL_CELL_SENTINEL);
    (stripped.unwrap_or(cell), stripped.is_some())
}
439
440fn resolve_requested_columns(
441 all_headers: &[String],
442 requested_columns: &[String],
443) -> Result<Vec<usize>, DataError> {
444 if requested_columns.is_empty() {
445 return Ok((0..all_headers.len()).collect());
446 }
447
448 let requested_set: HashSet<&str> = requested_columns.iter().map(String::as_str).collect();
449 let mut selected = Vec::with_capacity(requested_set.len());
450 for (idx, name) in all_headers.iter().enumerate() {
451 if requested_set.contains(name.as_str()) {
452 selected.push(idx);
453 }
454 }
455
456 if selected.len() != requested_set.len() {
457 let available_map: HashMap<String, usize> = all_headers
458 .iter()
459 .enumerate()
460 .map(|(index, header)| (header.clone(), index))
461 .collect();
462 let missing = requested_columns
463 .iter()
464 .filter(|name| !available_map.contains_key(name.as_str()))
465 .map(|name| {
466 DataError::column_not_found(&available_map, name, Some("requested")).to_string()
467 })
468 .collect::<Vec<_>>();
469 return Err(DataError::SchemaMismatch {
470 reason: missing.join("; "),
471 });
472 }
473
474 Ok(selected)
475}
476
/// Returns the header names at `selected_indices`, in the order given.
///
/// # Panics
///
/// Panics if an index is out of bounds for `all_headers`.
fn projected_headers(all_headers: &[String], selected_indices: &[usize]) -> Vec<String> {
    let mut projected = Vec::with_capacity(selected_indices.len());
    for &idx in selected_indices {
        projected.push(all_headers[idx].clone());
    }
    projected
}
483
484fn load_delimited_inferred(
485 path: &Path,
486 delimiter: u8,
487 requested_columns: &[String],
488 categorical_roles: &HashSet<&str>,
489) -> Result<EncodedDataset, DataError> {
490 let t_open = std::time::Instant::now();
491 let mut rdr = ReaderBuilder::new()
492 .has_headers(true)
493 .delimiter(delimiter)
494 .from_path(path)
495 .map_err(|e| DataError::ParseError {
496 reason: format!("failed to open '{}': {e}", path.display()),
497 })?;
498
499 let all_headers: Vec<String> = rdr
500 .headers()
501 .map_err(|e| DataError::ParseError {
502 reason: format!("failed to read headers: {e}"),
503 })?
504 .iter()
505 .map(|s| s.trim().to_string())
506 .collect();
507 if all_headers.is_empty() {
508 return Err(DataError::EmptyInput {
509 reason: "file has no headers".to_string(),
510 });
511 }
512 let selected_indices = resolve_requested_columns(&all_headers, requested_columns)?;
513 let headers = projected_headers(&all_headers, &selected_indices);
514 let p = headers.len();
515 let open_ms = t_open.elapsed().as_secs_f64() * 1000.0;
516 if open_ms > 100.0 {
517 log::info!(
518 "[DATA-LOAD] delim_open+headers | n_headers={} | n_proj={} | {:.1}ms",
519 all_headers.len(),
520 p,
521 open_ms
522 );
523 }
524
525 let mut raw_fields = Vec::<String>::new();
533 let mut total_rows: usize = 0;
534 let mut stream_error: Option<DataError> = None;
535
536 let t_stream = std::time::Instant::now();
537 let mut record = StringRecord::new();
538 while rdr
539 .read_record(&mut record)
540 .map_err(|e| DataError::ParseError {
541 reason: format!("failed reading row: {e}"),
542 })?
543 {
544 if record.len() != all_headers.len() {
545 stream_error = Some(DataError::SchemaMismatch {
546 reason: format!(
547 "row width mismatch at row {}: got {} fields, expected {}",
548 total_rows + 1,
549 record.len(),
550 all_headers.len()
551 ),
552 });
553 break;
554 }
555 total_rows += 1;
556
557 for &selected_idx in &selected_indices {
558 let raw = record.get(selected_idx).unwrap().trim();
559 raw_fields.push(raw.to_string());
560 }
561 }
562
563 let stream_ms = t_stream.elapsed().as_secs_f64() * 1000.0;
564 if stream_ms > 100.0 {
565 log::info!(
566 "[DATA-LOAD] delim_stream | n_rows={} | n_cols={} | {:.1}ms",
567 total_rows,
568 p,
569 stream_ms
570 );
571 }
572
573 if total_rows == 0 {
574 if let Some(err) = stream_error {
575 return Err(err);
576 }
577 return Err(DataError::EmptyInput {
578 reason: "file has no rows".to_string(),
579 });
580 }
581
582 let t_schema = std::time::Instant::now();
583 let sample_count = total_rows.min(SCHEMA_SAMPLE_ROWS);
584 let inferred_columns = (0..p)
585 .into_par_iter()
586 .map(|j| {
587 infer_delimited_column(
588 &raw_fields,
589 total_rows,
590 p,
591 j,
592 &headers[j],
593 sample_count,
594 categorical_roles.contains(headers[j].as_str()),
595 )
596 })
597 .collect::<Vec<_>>();
598
599 let first_error = inferred_columns
600 .iter()
601 .filter_map(|result| result.as_ref().err())
602 .min_by_key(|err| (err.row, err.col));
603 if let Some(err) = first_error {
604 return Err(err.error.clone());
605 }
606 if let Some(err) = stream_error {
607 return Err(err);
608 }
609
610 let inferred_columns = inferred_columns
611 .into_iter()
612 .map(Result::unwrap)
613 .collect::<Vec<_>>();
614
615 let mut schema_cols = Vec::<SchemaColumn>::with_capacity(p);
617 let mut column_kinds = Vec::<ColumnKindTag>::with_capacity(p);
618 for (j, inferred) in inferred_columns.iter().enumerate() {
619 column_kinds.push(inferred.kind);
620 schema_cols.push(SchemaColumn {
621 name: headers[j].clone(),
622 kind: inferred.kind,
623 levels: if matches!(inferred.kind, ColumnKindTag::Categorical) {
624 inferred.levels.clone()
625 } else {
626 Vec::new()
627 },
628 });
629 }
630 let schema_ms = t_schema.elapsed().as_secs_f64() * 1000.0;
631 if schema_ms > 100.0 {
632 let n_cat = column_kinds
633 .iter()
634 .filter(|k| matches!(k, ColumnKindTag::Categorical))
635 .count();
636 log::info!(
637 "[DATA-LOAD] delim_convert+infer | n_cols={} | n_cat={} | {:.1}ms",
638 p,
639 n_cat,
640 schema_ms
641 );
642 }
643
644 let t_assemble = std::time::Instant::now();
645 let mut values = Array2::<f64>::zeros((total_rows, p));
647 values
648 .axis_iter_mut(Axis(1))
649 .into_par_iter()
650 .zip(inferred_columns.par_iter())
651 .for_each(|(mut out_col, inferred)| {
652 for (dst, &src) in out_col.iter_mut().zip(inferred.values.iter()) {
653 *dst = src;
654 }
655 });
656 let assemble_ms = t_assemble.elapsed().as_secs_f64() * 1000.0;
657 if assemble_ms > 100.0 {
658 log::info!(
659 "[DATA-LOAD] delim_assemble_array2 | n_rows={} | n_cols={} | {:.1}ms",
660 total_rows,
661 p,
662 assemble_ms
663 );
664 }
665
666 let schema = DataSchema {
667 columns: schema_cols,
668 };
669 Ok(EncodedDataset {
670 headers,
671 values,
672 schema,
673 column_kinds,
674 })
675}
676
/// Result of `infer_delimited_column` for one column.
struct InferredDelimitedColumn {
    /// Encoded cell values, one per row.
    values: Vec<f64>,
    /// Inferred kind of the column.
    kind: ColumnKindTag,
    /// Level names collected for the column; sorted when `kind` is
    /// categorical.
    levels: Vec<String>,
}
682
/// A column-inference failure together with its position, so the caller can
/// pick the failure that comes first in `(row, col)` order.
#[derive(Debug)]
struct DelimitedInferenceError {
    /// 1-based row of the offending cell.
    row: usize,
    /// Index of the column among the selected columns.
    col: usize,
    /// The error to report.
    error: DataError,
}
689
690fn infer_delimited_column(
691 raw_fields: &[String],
692 total_rows: usize,
693 n_cols: usize,
694 col: usize,
695 header: &str,
696 sample_count: usize,
697 force_categorical: bool,
698) -> Result<InferredDelimitedColumn, DelimitedInferenceError> {
699 let mut values = Vec::<f64>::with_capacity(total_rows);
701 let mut all_numeric = true;
702 let mut all_binary = true;
703 let mut level_index = HashMap::<String, usize>::new();
704 let mut levels = Vec::<String>::new();
705
706 let non_finite_err = |row_idx: usize| DelimitedInferenceError {
710 row: row_idx + 1,
711 col,
712 error: DataError::InvalidValue {
713 reason: format!(
714 "non-finite value at row {}, column '{}'",
715 row_idx + 1,
716 header
717 ),
718 },
719 };
720
721 for row_idx in 0..total_rows {
722 let raw = raw_fields[row_idx * n_cols + col].as_str();
723 if raw.is_empty() {
724 return Err(DelimitedInferenceError {
725 row: row_idx + 1,
726 col,
727 error: DataError::EmptyInput {
728 reason: format!("empty field at row {}, column '{}'", row_idx + 1, header),
729 },
730 });
731 }
732
733 if row_idx < sample_count {
735 if let Ok(v) = raw.parse::<f64>() {
736 if !v.is_finite() {
737 return Err(non_finite_err(row_idx));
738 }
739 if (v - 0.0).abs() >= 1e-12 && (v - 1.0).abs() >= 1e-12 {
740 all_binary = false;
741 }
742 values.push(v);
743 } else {
744 all_numeric = false;
745 all_binary = false;
746 level_index.entry(raw.to_string()).or_insert_with(|| {
747 let idx = levels.len();
748 levels.push(raw.to_string());
749 idx
750 });
751 values.push(f64::NAN);
755 }
756 } else if let Ok(v) = raw.parse::<f64>() {
757 if !v.is_finite() {
761 return Err(non_finite_err(row_idx));
762 }
763 if (v - 0.0).abs() >= 1e-12 && (v - 1.0).abs() >= 1e-12 {
764 all_binary = false;
765 }
766 values.push(v);
767 } else {
768 all_numeric = false;
769 all_binary = false;
770 let idx = *level_index.entry(raw.to_string()).or_insert_with(|| {
771 let new_idx = levels.len();
772 levels.push(raw.to_string());
773 new_idx
774 });
775 values.push(idx as f64);
776 }
777 }
778
779 let kind = if force_categorical {
788 ColumnKindTag::Categorical
789 } else if all_numeric {
790 if all_binary {
791 ColumnKindTag::Binary
792 } else {
793 ColumnKindTag::Continuous
794 }
795 } else {
796 ColumnKindTag::Categorical
797 };
798
799 if matches!(kind, ColumnKindTag::Categorical) {
800 for row_idx in 0..total_rows {
821 let raw = raw_fields[row_idx * n_cols + col].as_str();
822 level_index.entry(raw.to_string()).or_insert_with(|| {
823 let new_idx = levels.len();
824 levels.push(raw.to_string());
825 new_idx
826 });
827 }
828 sort_levels_canonical(&mut levels);
829 level_index.clear();
830 for (idx, level) in levels.iter().enumerate() {
831 level_index.insert(level.clone(), idx);
832 }
833 for row_idx in 0..total_rows {
834 let raw = raw_fields[row_idx * n_cols + col].as_str();
835 values[row_idx] = level_index[raw] as f64;
836 }
837 }
838
839 for (row_idx, &v) in values.iter().enumerate() {
840 if !v.is_finite() {
841 return Err(non_finite_err(row_idx));
842 }
843 }
844
845 Ok(InferredDelimitedColumn {
846 values,
847 kind,
848 levels,
849 })
850}
851
852fn load_delimited_with_schema(
853 path: &Path,
854 delimiter: u8,
855 schema: &DataSchema,
856 unseen_policy: UnseenCategoryPolicy,
857 requested_columns: &[String],
858) -> Result<EncodedDataset, DataError> {
859 let t_open = std::time::Instant::now();
860 let mut rdr = ReaderBuilder::new()
861 .has_headers(true)
862 .delimiter(delimiter)
863 .from_path(path)
864 .map_err(|e| DataError::ParseError {
865 reason: format!("failed to open '{}': {e}", path.display()),
866 })?;
867
868 let all_headers: Vec<String> = rdr
869 .headers()
870 .map_err(|e| DataError::ParseError {
871 reason: format!("failed to read headers: {e}"),
872 })?
873 .iter()
874 .map(|s| s.trim().to_string())
875 .collect();
876 if all_headers.is_empty() {
877 return Err(DataError::EmptyInput {
878 reason: "file has no headers".to_string(),
879 });
880 }
881 let selected_indices = resolve_requested_columns(&all_headers, requested_columns)?;
882 let headers = projected_headers(&all_headers, &selected_indices);
883 let p = headers.len();
884 let open_ms = t_open.elapsed().as_secs_f64() * 1000.0;
885 if open_ms > 100.0 {
886 log::info!(
887 "[DATA-LOAD] delim_schema_open+headers | n_headers={} | n_proj={} | {:.1}ms",
888 all_headers.len(),
889 p,
890 open_ms
891 );
892 }
893
894 let schema_byname: HashMap<&str, &SchemaColumn> = schema
896 .columns
897 .iter()
898 .map(|c| (c.name.as_str(), c))
899 .collect();
900
901 let mut col_meta = Vec::<ColMeta>::with_capacity(p);
902 for name in &headers {
903 if let Some(sc) = schema_byname.get(name.as_str()) {
904 let level_map = if matches!(sc.kind, ColumnKindTag::Categorical) {
905 Some(
906 sc.levels
907 .iter()
908 .enumerate()
909 .map(|(idx, v)| (v.clone(), idx as f64))
910 .collect::<HashMap<_, _>>(),
911 )
912 } else {
913 None
914 };
915 col_meta.push(ColMeta {
916 kind: sc.kind,
917 level_map,
918 schema_col: (*sc).clone(),
919 });
920 } else {
921 col_meta.push(ColMeta {
923 kind: ColumnKindTag::Continuous, level_map: None,
925 schema_col: SchemaColumn {
926 name: name.clone(),
927 kind: ColumnKindTag::Continuous,
928 levels: Vec::new(),
929 },
930 });
931 }
932 }
933
934 let needs_inference: Vec<bool> = headers
936 .iter()
937 .map(|h| !schema_byname.contains_key(h.as_str()))
938 .collect();
939
940 let mut col_vecs: Vec<Vec<f64>> = vec![Vec::new(); p];
942 let mut infer_all_numeric: Vec<bool> = vec![true; p];
944 let mut infer_all_binary: Vec<bool> = vec![true; p];
945 let mut infer_level_index: Vec<HashMap<String, usize>> = vec![HashMap::new(); p];
946 let mut infer_levels: Vec<Vec<String>> = vec![Vec::new(); p];
947 let mut infer_strings: Vec<Vec<(usize, String)>> = vec![Vec::new(); p]; let mut total_rows: usize = 0;
950 let t_stream = std::time::Instant::now();
951 let mut record = StringRecord::new();
952 while rdr
953 .read_record(&mut record)
954 .map_err(|e| DataError::ParseError {
955 reason: format!("failed reading row: {e}"),
956 })?
957 {
958 if record.len() != all_headers.len() {
959 return Err(DataError::SchemaMismatch {
960 reason: format!(
961 "row width mismatch at row {}: got {} fields, expected {}",
962 total_rows + 1,
963 record.len(),
964 all_headers.len()
965 ),
966 });
967 }
968 total_rows += 1;
969
970 for j in 0..p {
971 let raw = record.get(selected_indices[j]).unwrap().trim();
972 if raw.is_empty() {
973 return Err(DataError::EmptyInput {
974 reason: format!(
975 "empty field at row {}, column '{}'",
976 total_rows, &headers[j]
977 ),
978 });
979 }
980
981 if needs_inference[j] {
982 if let Ok(v) = raw.parse::<f64>() {
984 if !v.is_finite() {
985 return Err(DataError::InvalidValue {
986 reason: format!(
987 "non-finite value at row {}, column '{}'",
988 total_rows, &headers[j]
989 ),
990 });
991 }
992 if (v - 0.0).abs() >= 1e-12 && (v - 1.0).abs() >= 1e-12 {
993 infer_all_binary[j] = false;
994 }
995 col_vecs[j].push(v);
996 infer_strings[j].push((total_rows - 1, raw.to_string()));
1005 } else {
1006 infer_all_numeric[j] = false;
1007 infer_all_binary[j] = false;
1008 let levels_ref = &mut infer_levels[j];
1009 infer_level_index[j]
1010 .entry(raw.to_string())
1011 .or_insert_with(|| {
1012 let idx = levels_ref.len();
1013 levels_ref.push(raw.to_string());
1014 idx
1015 });
1016 infer_strings[j].push((total_rows - 1, raw.to_string()));
1017 col_vecs[j].push(f64::NAN); }
1019 } else {
1020 let val = parse_cell_with_schema(
1022 raw,
1023 &col_meta[j],
1024 total_rows,
1025 &headers[j],
1026 &unseen_policy,
1027 )?;
1028 col_vecs[j].push(val);
1029 }
1030 }
1031 }
1032
1033 let stream_ms = t_stream.elapsed().as_secs_f64() * 1000.0;
1034 if stream_ms > 100.0 {
1035 let n_inf = needs_inference.iter().filter(|x| **x).count();
1036 log::info!(
1037 "[DATA-LOAD] delim_schema_stream | n_rows={} | n_cols={} | n_inf={} | {:.1}ms",
1038 total_rows,
1039 p,
1040 n_inf,
1041 stream_ms
1042 );
1043 }
1044
1045 if total_rows == 0 {
1046 return Err(DataError::EmptyInput {
1047 reason: "file has no rows".to_string(),
1048 });
1049 }
1050
1051 let t_finalize = std::time::Instant::now();
1052 let mut column_kinds = Vec::<ColumnKindTag>::with_capacity(p);
1054 for j in 0..p {
1055 if needs_inference[j] {
1056 let kind = if infer_all_numeric[j] {
1057 if infer_all_binary[j] {
1058 ColumnKindTag::Binary
1059 } else {
1060 ColumnKindTag::Continuous
1061 }
1062 } else {
1063 ColumnKindTag::Categorical
1064 };
1065 col_meta[j].kind = kind;
1066 col_meta[j].schema_col.kind = kind;
1067 if matches!(kind, ColumnKindTag::Categorical) {
1068 for (_, raw) in &infer_strings[j] {
1081 let levels_ref = &mut infer_levels[j];
1082 infer_level_index[j].entry(raw.clone()).or_insert_with(|| {
1083 let new_idx = levels_ref.len();
1084 levels_ref.push(raw.clone());
1085 new_idx
1086 });
1087 }
1088 infer_levels[j].sort();
1089 infer_level_index[j].clear();
1090 for (idx, level) in infer_levels[j].iter().enumerate() {
1091 infer_level_index[j].insert(level.clone(), idx);
1092 }
1093 for (row_idx, raw) in &infer_strings[j] {
1094 col_vecs[j][*row_idx] = infer_level_index[j][raw] as f64;
1095 }
1096 col_meta[j].schema_col.levels = infer_levels[j].clone();
1097 }
1098 }
1099 column_kinds.push(col_meta[j].kind);
1100 }
1101 let finalize_ms = t_finalize.elapsed().as_secs_f64() * 1000.0;
1102 if finalize_ms > 100.0 {
1103 log::info!(
1104 "[DATA-LOAD] delim_schema_finalize | n_cols={} | {:.1}ms",
1105 p,
1106 finalize_ms
1107 );
1108 }
1109
1110 let t_assemble = std::time::Instant::now();
1111 let mut values = Array2::<f64>::zeros((total_rows, p));
1117 let assemble_err: Option<DataError> = values
1118 .axis_iter_mut(Axis(1))
1119 .into_par_iter()
1120 .zip(col_vecs.par_iter())
1121 .zip(headers.par_iter())
1122 .map(|((mut out_col, col_vec), header)| {
1123 for (i, &v) in col_vec.iter().enumerate() {
1124 if !v.is_finite() {
1125 return Some(DataError::InvalidValue {
1126 reason: format!("non-finite value at row {}, column '{}'", i + 1, header),
1127 });
1128 }
1129 out_col[i] = v;
1130 }
1131 None
1132 })
1133 .reduce(|| None, |a, b| a.or(b));
1134 if let Some(e) = assemble_err {
1135 return Err(e);
1136 }
1137 let assemble_ms = t_assemble.elapsed().as_secs_f64() * 1000.0;
1138 if assemble_ms > 100.0 {
1139 log::info!(
1140 "[DATA-LOAD] delim_schema_assemble | n_rows={} | n_cols={} | {:.1}ms",
1141 total_rows,
1142 p,
1143 assemble_ms
1144 );
1145 }
1146
1147 let schema_out = DataSchema {
1148 columns: col_meta.into_iter().map(|m| m.schema_col).collect(),
1149 };
1150 Ok(EncodedDataset {
1151 headers,
1152 values,
1153 schema: schema_out,
1154 column_kinds,
1155 })
1156}
1157
1158fn parse_cell_with_schema(
1159 raw: &str,
1160 meta: &ColMeta,
1161 row: usize,
1162 col_name: &str,
1163 unseen_policy: &UnseenCategoryPolicy,
1164) -> Result<f64, DataError> {
1165 let val = match meta.kind {
1166 ColumnKindTag::Continuous => raw.parse::<f64>().map_err(|err| {
1167 DataError::SchemaMismatch {
1168 reason: format!(
1169 "column '{}' is continuous in schema but row {} has non-numeric value '{}': {}",
1170 col_name, row, raw, err
1171 ),
1172 }
1173 })?,
1174 ColumnKindTag::Binary => {
1175 let v = raw
1176 .parse::<f64>()
1177 .map_err(|err| DataError::SchemaMismatch {
1178 reason: format!(
1179 "column '{}' is binary in schema but row {} has non-numeric value '{}': {}",
1180 col_name, row, raw, err
1181 ),
1182 })?;
1183 if (v - 0.0).abs() >= 1e-12 && (v - 1.0).abs() >= 1e-12 {
1184 return Err(DataError::SchemaMismatch {
1185 reason: format!(
1186 "column '{}' is binary in schema but row {} has value {}; expected 0 or 1",
1187 col_name, row, v
1188 ),
1189 });
1190 }
1191 v
1192 }
1193 ColumnKindTag::Categorical => {
1194 let map = meta
1195 .level_map
1196 .as_ref()
1197 .ok_or_else(|| DataError::EncodingFailure {
1198 reason: "internal categorical schema map missing".to_string(),
1199 })?;
1200 match map.get(raw) {
1201 Some(v) => *v,
1202 None => unseen_policy
1203 .unseen_code_for(col_name, meta.schema_col.levels.len())
1204 .ok_or_else(|| DataError::SchemaMismatch {
1205 reason: format!(
1206 "unseen level '{}' in categorical column '{}' at row {}",
1207 raw, col_name, row
1208 ),
1209 })?,
1210 }
1211 }
1212 };
1213 if !val.is_finite() {
1214 return Err(DataError::InvalidValue {
1215 reason: format!("non-finite value at row {}, column '{}'", row, col_name),
1216 });
1217 }
1218 Ok(val)
1219}
1220
/// Per-column state used while loading a file against a schema.
struct ColMeta {
    /// Kind used to parse and report the column.
    kind: ColumnKindTag,
    /// Level name → code, present only for categorical columns taken from the
    /// schema.
    level_map: Option<HashMap<String, f64>>,
    /// Schema entry emitted for this column in the loaded dataset.
    schema_col: SchemaColumn,
}
1228
/// One column of a parquet record batch, as produced by
/// `decode_parquet_batch_column`.
enum ParquetBatchColumn {
    /// Values of a numeric or boolean column, converted to `f64`.
    Numeric(Vec<f64>),
    /// Cell text of a string column.
    Strings(Vec<String>),
}
1237
1238fn parquet_field_is_string(dt: &arrow::datatypes::DataType) -> bool {
1247 use arrow::datatypes::DataType;
1248 match dt {
1249 DataType::Utf8 | DataType::LargeUtf8 => true,
1250 DataType::Dictionary(_, value_type) => parquet_field_is_string(value_type),
1251 _ => false,
1252 }
1253}
1254
1255fn decode_parquet_batch_column(
1256 col: &dyn arrow::array::Array,
1257 n_rows: usize,
1258 base_row: usize,
1259 header: &str,
1260 is_string_col: bool,
1261) -> Result<ParquetBatchColumn, DataError> {
1262 use arrow::array::{
1263 Array as ArrowArray, BooleanArray, Float32Array, Float64Array, Int8Array, Int16Array,
1264 Int32Array, Int64Array, LargeStringArray, StringArray, UInt8Array, UInt16Array,
1265 UInt32Array, UInt64Array,
1266 };
1267 use arrow::datatypes::DataType;
1268
1269 if col.null_count() > 0 {
1270 for i in 0..n_rows {
1271 if col.is_null(i) {
1272 return Err(DataError::InvalidValue {
1273 reason: format!(
1274 "null value at row {}, column '{}'",
1275 base_row + i + 1,
1276 header
1277 ),
1278 });
1279 }
1280 }
1281 }
1282
1283 if is_string_col {
1284 if let Some(arr) = col.as_any().downcast_ref::<StringArray>() {
1285 return Ok(ParquetBatchColumn::Strings(
1286 (0..n_rows).map(|i| arr.value(i).to_string()).collect(),
1287 ));
1288 }
1289 if let Some(arr) = col.as_any().downcast_ref::<LargeStringArray>() {
1290 return Ok(ParquetBatchColumn::Strings(
1291 (0..n_rows).map(|i| arr.value(i).to_string()).collect(),
1292 ));
1293 }
1294
1295 let casted =
1299 arrow::compute::cast(col, &DataType::Utf8).map_err(|e| DataError::ParseError {
1300 reason: format!("failed to cast column '{}' to string: {e}", header),
1301 })?;
1302 let arr = casted
1303 .as_any()
1304 .downcast_ref::<StringArray>()
1305 .ok_or_else(|| DataError::EncodingFailure {
1306 reason: format!("column '{}' could not be read as string after cast", header),
1307 })?;
1308 return Ok(ParquetBatchColumn::Strings(
1309 (0..n_rows).map(|i| arr.value(i).to_string()).collect(),
1310 ));
1311 }
1312
1313 let decoded_col;
1320 let col: &dyn arrow::array::Array = if let DataType::Dictionary(_, value_type) = col.data_type()
1321 {
1322 decoded_col = arrow::compute::cast(col, value_type).map_err(|e| DataError::ParseError {
1323 reason: format!(
1324 "failed to decode dictionary-encoded numeric column '{}': {e}",
1325 header
1326 ),
1327 })?;
1328 decoded_col.as_ref()
1329 } else {
1330 col
1331 };
1332
1333 let mut values = Vec::with_capacity(n_rows);
1334 match col.data_type() {
1335 DataType::Float64 => {
1336 let arr = col.as_any().downcast_ref::<Float64Array>().unwrap();
1337 values.extend(arr.values().iter().copied());
1338 }
1339 DataType::Float32 => {
1340 let arr = col.as_any().downcast_ref::<Float32Array>().unwrap();
1341 values.extend(arr.values().iter().map(|&v| v as f64));
1342 }
1343 DataType::Int64 => {
1344 let arr = col.as_any().downcast_ref::<Int64Array>().unwrap();
1345 values.extend(arr.values().iter().map(|&v| v as f64));
1346 }
1347 DataType::Int32 => {
1348 let arr = col.as_any().downcast_ref::<Int32Array>().unwrap();
1349 values.extend(arr.values().iter().map(|&v| v as f64));
1350 }
1351 DataType::Int16 => {
1352 let arr = col.as_any().downcast_ref::<Int16Array>().unwrap();
1353 values.extend(arr.values().iter().map(|&v| v as f64));
1354 }
1355 DataType::Int8 => {
1356 let arr = col.as_any().downcast_ref::<Int8Array>().unwrap();
1357 values.extend(arr.values().iter().map(|&v| v as f64));
1358 }
1359 DataType::UInt64 => {
1360 let arr = col.as_any().downcast_ref::<UInt64Array>().unwrap();
1361 values.extend(arr.values().iter().map(|&v| v as f64));
1362 }
1363 DataType::UInt32 => {
1364 let arr = col.as_any().downcast_ref::<UInt32Array>().unwrap();
1365 values.extend(arr.values().iter().map(|&v| v as f64));
1366 }
1367 DataType::UInt16 => {
1368 let arr = col.as_any().downcast_ref::<UInt16Array>().unwrap();
1369 values.extend(arr.values().iter().map(|&v| v as f64));
1370 }
1371 DataType::UInt8 => {
1372 let arr = col.as_any().downcast_ref::<UInt8Array>().unwrap();
1373 values.extend(arr.values().iter().map(|&v| v as f64));
1374 }
1375 DataType::Boolean => {
1376 let arr = col.as_any().downcast_ref::<BooleanArray>().unwrap();
1377 values.extend((0..n_rows).map(|i| if arr.value(i) { 1.0 } else { 0.0 }));
1378 }
1379 other => {
1380 return Err(DataError::InvalidValue {
1381 reason: format!(
1382 "unsupported parquet column type {:?} for column '{}'",
1383 other, header
1384 ),
1385 });
1386 }
1387 }
1388
1389 if let Some(i) = values.iter().position(|v| !v.is_finite()) {
1390 return Err(DataError::InvalidValue {
1391 reason: format!(
1392 "non-finite value at row {}, column '{}'",
1393 base_row + i + 1,
1394 header
1395 ),
1396 });
1397 }
1398
1399 Ok(ParquetBatchColumn::Numeric(values))
1400}
1401
1402fn load_parquet_inferred(
1403 path: &Path,
1404 requested_columns: &[String],
1405 categorical_roles: &HashSet<&str>,
1406) -> Result<EncodedDataset, DataError> {
1407 use parquet::arrow::{ProjectionMask, arrow_reader::ParquetRecordBatchReaderBuilder};
1408 use rayon::prelude::*;
1409 use std::fs::File;
1410
1411 let t_open = std::time::Instant::now();
1412 let file = File::open(path).map_err(|e| DataError::ParseError {
1413 reason: format!("failed to open parquet '{}': {e}", path.display()),
1414 })?;
1415 let builder =
1416 ParquetRecordBatchReaderBuilder::try_new(file).map_err(|e| DataError::ParseError {
1417 reason: format!("failed to read parquet metadata '{}': {e}", path.display()),
1418 })?;
1419
1420 let full_schema = builder.schema().clone();
1421 let all_headers: Vec<String> = full_schema
1422 .fields()
1423 .iter()
1424 .map(|f| f.name().clone())
1425 .collect();
1426 if all_headers.is_empty() {
1427 return Err(DataError::EmptyInput {
1428 reason: "parquet file has no columns".to_string(),
1429 });
1430 }
1431 let selected_indices = resolve_requested_columns(&all_headers, requested_columns)?;
1432 let headers = projected_headers(&all_headers, &selected_indices);
1433 let selected_fields = selected_indices
1434 .iter()
1435 .map(|&idx| full_schema.fields()[idx].clone())
1436 .collect::<Vec<_>>();
1437 let projection =
1438 ProjectionMask::roots(builder.parquet_schema(), selected_indices.iter().copied());
1439 let reader =
1440 builder
1441 .with_projection(projection)
1442 .build()
1443 .map_err(|e| DataError::ParseError {
1444 reason: format!("failed to build parquet reader: {e}"),
1445 })?;
1446 let p = headers.len();
1447 let open_ms = t_open.elapsed().as_secs_f64() * 1000.0;
1448 if open_ms > 100.0 {
1449 log::info!(
1450 "[DATA-LOAD] parquet_open+meta | n_headers={} | n_proj={} | {:.1}ms",
1451 all_headers.len(),
1452 p,
1453 open_ms
1454 );
1455 }
1456
1457 let t_batches = std::time::Instant::now();
1458 let mut col_vecs: Vec<Vec<f64>> = vec![Vec::new(); p];
1460 let mut string_cols: Vec<Option<Vec<String>>> = (0..p).map(|_| None).collect();
1462 let mut is_string_col: Vec<bool> = vec![false; p];
1463
1464 for (j, field) in selected_fields.iter().enumerate() {
1465 if parquet_field_is_string(field.data_type()) {
1469 is_string_col[j] = true;
1470 string_cols[j] = Some(Vec::new());
1471 }
1472 }
1473
1474 let mut rows_seen = 0usize;
1475 for batch_result in reader {
1476 let batch = batch_result.map_err(|e| DataError::ParseError {
1477 reason: format!("failed to read parquet record batch: {e}"),
1478 })?;
1479 let n_rows = batch.num_rows();
1480
1481 let decoded_columns: Vec<Result<ParquetBatchColumn, DataError>> = (0..p)
1482 .into_par_iter()
1483 .map(|j| {
1484 decode_parquet_batch_column(
1485 batch.column(j).as_ref(),
1486 n_rows,
1487 rows_seen,
1488 &headers[j],
1489 is_string_col[j],
1490 )
1491 })
1492 .collect();
1493
1494 for (j, decoded) in decoded_columns.into_iter().enumerate() {
1495 match decoded? {
1496 ParquetBatchColumn::Strings(mut strings) => {
1497 assert!(is_string_col[j]);
1498 string_cols[j].as_mut().unwrap().append(&mut strings);
1499 let new_len = col_vecs[j].len() + n_rows;
1500 col_vecs[j].resize(new_len, f64::NAN);
1501 }
1502 ParquetBatchColumn::Numeric(mut values) => {
1503 assert!(!is_string_col[j]);
1504 col_vecs[j].append(&mut values);
1505 }
1506 }
1507 }
1508 rows_seen += n_rows;
1509 }
1510
1511 let total_rows = col_vecs[0].len();
1512 let batches_ms = t_batches.elapsed().as_secs_f64() * 1000.0;
1513 if batches_ms > 100.0 {
1514 log::info!(
1515 "[DATA-LOAD] parquet_batches_decode | n_rows={} | n_cols={} | {:.1}ms",
1516 total_rows,
1517 p,
1518 batches_ms
1519 );
1520 }
1521 if total_rows == 0 {
1522 return Err(DataError::EmptyInput {
1523 reason: "parquet file has no rows".to_string(),
1524 });
1525 }
1526
1527 let t_schema = std::time::Instant::now();
1528 let mut schema_cols = Vec::<SchemaColumn>::with_capacity(p);
1530 let mut column_kinds = Vec::<ColumnKindTag>::with_capacity(p);
1531
1532 let finalized_columns: Vec<(Vec<f64>, ColumnKindTag, SchemaColumn)> = col_vecs
1533 .into_par_iter()
1534 .zip(string_cols.into_par_iter())
1535 .zip(is_string_col.into_par_iter())
1536 .zip(headers.par_iter())
1537 .map(|(((mut col_values, strings), is_string), header)| {
1538 if is_string {
1539 let strings = strings.expect("string column storage missing");
1543 let mut level_index: HashMap<String, usize> = HashMap::new();
1544 let mut levels_vec: Vec<String> = Vec::new();
1545 for s in &strings {
1546 level_index.entry(s.clone()).or_insert_with(|| {
1547 let idx = levels_vec.len();
1548 levels_vec.push(s.clone());
1549 idx
1550 });
1551 }
1552 for (i, s) in strings.iter().enumerate() {
1553 col_values[i] = *level_index.get(s.as_str()).unwrap() as f64;
1554 }
1555 (
1556 col_values,
1557 ColumnKindTag::Categorical,
1558 SchemaColumn {
1559 name: header.clone(),
1560 kind: ColumnKindTag::Categorical,
1561 levels: levels_vec,
1562 },
1563 )
1564 } else if categorical_roles.contains(header.as_str()) {
1565 let labels: Vec<String> = col_values.iter().map(|v| v.to_string()).collect();
1573 let mut levels_vec: Vec<String> = Vec::new();
1574 let mut level_index: HashMap<String, usize> = HashMap::new();
1575 for label in &labels {
1576 level_index.entry(label.clone()).or_insert_with(|| {
1577 let idx = levels_vec.len();
1578 levels_vec.push(label.clone());
1579 idx
1580 });
1581 }
1582 levels_vec.sort();
1583 level_index.clear();
1584 for (idx, level) in levels_vec.iter().enumerate() {
1585 level_index.insert(level.clone(), idx);
1586 }
1587 for (i, label) in labels.iter().enumerate() {
1588 col_values[i] = level_index[label] as f64;
1589 }
1590 (
1591 col_values,
1592 ColumnKindTag::Categorical,
1593 SchemaColumn {
1594 name: header.clone(),
1595 kind: ColumnKindTag::Categorical,
1596 levels: levels_vec,
1597 },
1598 )
1599 } else {
1600 let all_binary = col_values
1602 .iter()
1603 .all(|&v| (v - 0.0).abs() < 1e-12 || (v - 1.0).abs() < 1e-12);
1604 let kind = if all_binary {
1605 ColumnKindTag::Binary
1606 } else {
1607 ColumnKindTag::Continuous
1608 };
1609 (
1610 col_values,
1611 kind,
1612 SchemaColumn {
1613 name: header.clone(),
1614 kind,
1615 levels: Vec::new(),
1616 },
1617 )
1618 }
1619 })
1620 .collect();
1621
1622 let mut col_vecs = Vec::with_capacity(p);
1623 for (col_values, kind, schema_col) in finalized_columns {
1624 col_vecs.push(col_values);
1625 column_kinds.push(kind);
1626 schema_cols.push(schema_col);
1627 }
1628 let schema_ms = t_schema.elapsed().as_secs_f64() * 1000.0;
1629 if schema_ms > 100.0 {
1630 let n_cat = column_kinds
1631 .iter()
1632 .filter(|k| matches!(k, ColumnKindTag::Categorical))
1633 .count();
1634 log::info!(
1635 "[DATA-LOAD] parquet_finalize_schema | n_cols={} | n_cat={} | {:.1}ms",
1636 p,
1637 n_cat,
1638 schema_ms
1639 );
1640 }
1641
1642 let t_assemble = std::time::Instant::now();
1643 let mut values = Array2::<f64>::zeros((total_rows, p));
1648 values
1649 .axis_iter_mut(Axis(1))
1650 .into_par_iter()
1651 .zip(col_vecs.par_iter())
1652 .for_each(|(mut out_col, src)| {
1653 for (dst, &v) in out_col.iter_mut().zip(src.iter()) {
1654 *dst = v;
1655 }
1656 });
1657 let assemble_ms = t_assemble.elapsed().as_secs_f64() * 1000.0;
1658 if assemble_ms > 100.0 {
1659 log::info!(
1660 "[DATA-LOAD] parquet_assemble_array2 | n_rows={} | n_cols={} | {:.1}ms",
1661 total_rows,
1662 p,
1663 assemble_ms
1664 );
1665 }
1666
1667 Ok(EncodedDataset {
1668 headers,
1669 values,
1670 schema: DataSchema {
1671 columns: schema_cols,
1672 },
1673 column_kinds,
1674 })
1675}
1676
1677fn load_parquet_with_schema(
1678 path: &Path,
1679 schema: &DataSchema,
1680 unseen_policy: UnseenCategoryPolicy,
1681 requested_columns: &[String],
1682) -> Result<EncodedDataset, DataError> {
1683 let inferred = load_parquet_inferred(path, requested_columns, &HashSet::new())?;
1687 let p = inferred.headers.len();
1688 let n = inferred.values.nrows();
1689
1690 let schema_byname: HashMap<&str, &SchemaColumn> = schema
1691 .columns
1692 .iter()
1693 .map(|c| (c.name.as_str(), c))
1694 .collect();
1695
1696 let mut column_kinds = Vec::<ColumnKindTag>::with_capacity(p);
1697 let mut schema_cols = Vec::<SchemaColumn>::with_capacity(p);
1698 let mut values = inferred.values;
1699
1700 for j in 0..p {
1701 let name = &inferred.headers[j];
1702 if let Some(sc) = schema_byname.get(name.as_str()) {
1703 column_kinds.push(sc.kind);
1704 schema_cols.push((*sc).clone());
1705
1706 match sc.kind {
1707 ColumnKindTag::Continuous => {
1708 if matches!(inferred.column_kinds[j], ColumnKindTag::Categorical) {
1709 return Err(DataError::SchemaMismatch {
1710 reason: format!(
1711 "column '{}' is continuous in schema but parquet column is string/categorical",
1712 name
1713 ),
1714 });
1715 }
1716 }
1717 ColumnKindTag::Binary => {
1718 if matches!(inferred.column_kinds[j], ColumnKindTag::Categorical) {
1719 return Err(DataError::SchemaMismatch {
1720 reason: format!(
1721 "column '{}' is binary in schema but parquet column is string/categorical",
1722 name
1723 ),
1724 });
1725 }
1726 if let Some(row) = values.column(j).iter().position(|value| {
1727 (*value - 0.0).abs() >= 1e-12 && (*value - 1.0).abs() >= 1e-12
1728 }) {
1729 return Err(DataError::SchemaMismatch {
1730 reason: format!(
1731 "column '{}' is binary in schema but row {} has value {}; expected 0 or 1",
1732 name,
1733 row + 1,
1734 values[[row, j]]
1735 ),
1736 });
1737 }
1738 }
1739 ColumnKindTag::Categorical => {
1740 if !matches!(inferred.column_kinds[j], ColumnKindTag::Categorical) {
1741 return Err(DataError::SchemaMismatch {
1742 reason: format!(
1743 "column '{}' is categorical in schema but parquet column is numeric",
1744 name
1745 ),
1746 });
1747 }
1748 let inferred_col = &inferred.schema.columns[j];
1749 let schema_level_map: HashMap<&str, f64> = sc
1751 .levels
1752 .iter()
1753 .enumerate()
1754 .map(|(idx, v)| (v.as_str(), idx as f64))
1755 .collect();
1756 let inferred_to_schema: Vec<f64> = inferred_col
1757 .levels
1758 .iter()
1759 .map(|lv| {
1760 schema_level_map
1761 .get(lv.as_str())
1762 .copied()
1763 .or_else(|| unseen_policy.unseen_code_for(name, sc.levels.len()))
1764 .ok_or_else(|| DataError::SchemaMismatch {
1765 reason: format!(
1766 "unseen level '{}' in categorical column '{}'",
1767 lv, name
1768 ),
1769 })
1770 })
1771 .collect::<Result<Vec<_>, _>>()?;
1772 for i in 0..n {
1773 let old_code = values[[i, j]] as usize;
1774 if old_code >= inferred_to_schema.len() {
1775 let Some(unseen_code) =
1776 unseen_policy.unseen_code_for(name, sc.levels.len())
1777 else {
1778 return Err(DataError::SchemaMismatch {
1779 reason: format!(
1780 "unseen categorical code at row {}, column '{}'",
1781 i + 1,
1782 name
1783 ),
1784 });
1785 };
1786 values[[i, j]] = unseen_code;
1787 continue;
1788 }
1789 values[[i, j]] = inferred_to_schema[old_code];
1790 }
1791 }
1792 }
1793 } else {
1794 column_kinds.push(inferred.column_kinds[j]);
1796 schema_cols.push(inferred.schema.columns[j].clone());
1797 }
1798 }
1799
1800 Ok(EncodedDataset {
1801 headers: inferred.headers,
1802 values,
1803 schema: DataSchema {
1804 columns: schema_cols,
1805 },
1806 column_kinds,
1807 })
1808}
1809
1810pub fn encode_recordswith_inferred_schema(
1811 headers: Vec<String>,
1812 records: Vec<StringRecord>,
1813) -> Result<EncodedDataset, String> {
1814 if records.is_empty() {
1815 return Err(DataError::EmptyInput {
1816 reason: "table data cannot be empty".to_string(),
1817 }
1818 .into());
1819 }
1820 let schema_cols = headers
1826 .par_iter()
1827 .enumerate()
1828 .map(|(j, name)| infer_schema_column(name, &records, j).map_err(String::from))
1829 .collect::<Result<Vec<SchemaColumn>, String>>()?;
1830 let schema = DataSchema {
1831 columns: schema_cols,
1832 };
1833 encode_recordswith_schema(headers, records, &schema, UnseenCategoryPolicy::Error)
1834}
1835
1836pub fn encode_recordswith_schema(
1837 headers: Vec<String>,
1838 records: Vec<StringRecord>,
1839 schema: &DataSchema,
1840 unseen_policy: UnseenCategoryPolicy,
1841) -> Result<EncodedDataset, String> {
1842 let n = records.len();
1843 if n == 0 {
1844 return Err(DataError::EmptyInput {
1845 reason: "table data cannot be empty".to_string(),
1846 }
1847 .into());
1848 }
1849 let p = headers.len();
1850 if p == 0 {
1851 return Err(DataError::EmptyInput {
1852 reason: "table data must have at least one header column".to_string(),
1853 }
1854 .into());
1855 }
1856 for (i, rec) in records.iter().enumerate() {
1863 if rec.len() != p {
1864 return Err(DataError::SchemaMismatch {
1865 reason: format!(
1866 "row width mismatch at row {}: got {} fields, expected {} (one per header)",
1867 i + 1,
1868 rec.len(),
1869 p
1870 ),
1871 }
1872 .into());
1873 }
1874 }
1875 let schema_byname: HashMap<&str, &SchemaColumn> = schema
1876 .columns
1877 .iter()
1878 .map(|c| (c.name.as_str(), c))
1879 .collect();
1880
1881 let encoded_columns = headers
1887 .par_iter()
1888 .enumerate()
1889 .map(|(j, name)| {
1890 let inferred_for_extra;
1891 let col_schema = if let Some(s) = schema_byname.get(name.as_str()) {
1892 *s
1893 } else {
1894 inferred_for_extra =
1895 infer_schema_column(name, &records, j).map_err(String::from)?;
1896 &inferred_for_extra
1897 };
1898 let column = encode_one_column(name, &records, j, col_schema, &unseen_policy)?;
1899 Ok::<(ColumnKindTag, Vec<f64>), String>((col_schema.kind, column))
1900 })
1901 .collect::<Result<Vec<(ColumnKindTag, Vec<f64>)>, String>>()?;
1902
1903 let mut column_kinds = Vec::<ColumnKindTag>::with_capacity(p);
1904 let mut values = Array2::<f64>::zeros((n, p));
1905 for (j, (kind, column)) in encoded_columns.into_iter().enumerate() {
1906 column_kinds.push(kind);
1907 values
1908 .column_mut(j)
1909 .assign(&ndarray::ArrayView1::from(&column));
1910 }
1911
1912 Ok(EncodedDataset {
1913 headers,
1914 values,
1915 schema: schema.clone(),
1916 column_kinds,
1917 })
1918}
1919
1920fn encode_one_column(
1927 name: &str,
1928 records: &[StringRecord],
1929 j: usize,
1930 col_schema: &SchemaColumn,
1931 unseen_policy: &UnseenCategoryPolicy,
1932) -> Result<Vec<f64>, String> {
1933 let level_map = if matches!(col_schema.kind, ColumnKindTag::Categorical) {
1934 Some(
1935 col_schema
1936 .levels
1937 .iter()
1938 .enumerate()
1939 .map(|(idx, v)| (v.as_str(), idx as f64))
1940 .collect::<HashMap<_, _>>(),
1941 )
1942 } else {
1943 None
1944 };
1945
1946 let mut column = Vec::<f64>::with_capacity(records.len());
1947 for (i, rec) in records.iter().enumerate() {
1948 let raw = rec
1949 .get(j)
1950 .ok_or_else(|| {
1951 String::from(DataError::SchemaMismatch {
1952 reason: format!("missing field at row {}, col {}", i + 1, j + 1),
1953 })
1954 })?
1955 .trim();
1956 if raw.is_empty() {
1957 return Err(DataError::EmptyInput {
1958 reason: format!("empty field at row {}, column '{}'", i + 1, name),
1959 }
1960 .into());
1961 }
1962 let val = match col_schema.kind {
1963 ColumnKindTag::Continuous => raw.parse::<f64>().map_err(|err| {
1964 String::from(DataError::SchemaMismatch {
1965 reason: format!(
1966 "column '{}' is continuous in schema but row {} has non-numeric value '{}': {}",
1967 name,
1968 i + 1,
1969 raw,
1970 err
1971 ),
1972 })
1973 })?,
1974 ColumnKindTag::Binary => {
1975 let v = raw.parse::<f64>().map_err(|err| {
1976 String::from(DataError::SchemaMismatch {
1977 reason: format!(
1978 "column '{}' is binary in schema but row {} has non-numeric value '{}': {}",
1979 name,
1980 i + 1,
1981 raw,
1982 err
1983 ),
1984 })
1985 })?;
1986 if (v - 0.0).abs() >= 1e-12 && (v - 1.0).abs() >= 1e-12 {
1987 return Err(DataError::SchemaMismatch {
1988 reason: format!(
1989 "column '{}' is binary in schema but row {} has value {}; expected 0 or 1",
1990 name,
1991 i + 1,
1992 v
1993 ),
1994 }
1995 .into());
1996 }
1997 v
1998 }
1999 ColumnKindTag::Categorical => {
2000 let map = level_map.as_ref().ok_or_else(|| {
2001 String::from(DataError::EncodingFailure {
2002 reason: "internal categorical schema map missing".to_string(),
2003 })
2004 })?;
2005 match map.get(raw) {
2006 Some(v) => *v,
2007 None => unseen_policy
2008 .unseen_code_for(name, col_schema.levels.len())
2009 .ok_or_else(|| {
2010 String::from(DataError::SchemaMismatch {
2011 reason: format!(
2012 "unseen level '{}' in categorical column '{}' at row {}; allowed levels: {}",
2013 raw,
2014 name,
2015 i + 1,
2016 col_schema.levels.join(",")
2017 ),
2018 })
2019 })?,
2020 }
2021 }
2022 };
2023 if !val.is_finite() {
2024 return Err(DataError::InvalidValue {
2025 reason: format!("non-finite value at row {}, column '{}'", i + 1, name),
2026 }
2027 .into());
2028 }
2029 column.push(val);
2030 }
2031 Ok(column)
2032}
2033
2034fn infer_schema_column(
2035 name: &str,
2036 records: &[StringRecord],
2037 col_idx: usize,
2038) -> Result<SchemaColumn, DataError> {
2039 let mut all_numeric = true;
2040 let mut all_binary = true;
2041 let mut levels = Vec::<String>::new();
2042 let mut level_index = HashMap::<String, usize>::new();
2043 for (i, rec) in records.iter().enumerate() {
2044 let raw = rec
2045 .get(col_idx)
2046 .ok_or_else(|| DataError::SchemaMismatch {
2047 reason: format!("missing field at row {}, col {}", i + 1, col_idx + 1),
2048 })?
2049 .trim();
2050 if raw.is_empty() {
2051 return Err(DataError::EmptyInput {
2052 reason: format!("empty field at row {}, column '{}'", i + 1, name),
2053 });
2054 }
2055 if let Ok(v) = raw.parse::<f64>() {
2056 if !v.is_finite() {
2057 return Err(DataError::InvalidValue {
2058 reason: format!("non-finite value at row {}, column '{}'", i + 1, name),
2059 });
2060 }
2061 if (v - 0.0).abs() >= 1e-12 && (v - 1.0).abs() >= 1e-12 {
2062 all_binary = false;
2063 }
2064 } else {
2065 all_numeric = false;
2066 all_binary = false;
2067 level_index.entry(raw.to_string()).or_insert_with(|| {
2068 let idx = levels.len();
2069 levels.push(raw.to_string());
2070 idx
2071 });
2072 }
2073 }
2074 let kind = if all_numeric {
2075 if all_binary {
2076 ColumnKindTag::Binary
2077 } else {
2078 ColumnKindTag::Continuous
2079 }
2080 } else {
2081 ColumnKindTag::Categorical
2082 };
2083 if matches!(kind, ColumnKindTag::Categorical) {
2088 sort_levels_canonical(&mut levels);
2089 }
2090 Ok(SchemaColumn {
2091 name: name.to_string(),
2092 kind,
2093 levels: if matches!(kind, ColumnKindTag::Categorical) {
2094 levels
2095 } else {
2096 Vec::new()
2097 },
2098 })
2099}
2100
2101pub fn infer_and_encode_column_major(
2114 name: &str,
2115 column: &[&str],
2116 col_index: usize,
2117) -> Result<(SchemaColumn, Vec<f64>), String> {
2118 if column.is_empty() {
2119 return Err(DataError::EmptyInput {
2120 reason: "table data cannot be empty".to_string(),
2121 }
2122 .into());
2123 }
2124 let force_categorical = column.iter().any(|c| strip_categorical_sentinel(c).1);
2129 let mut all_numeric = !force_categorical;
2130 let mut all_binary = !force_categorical;
2131 let mut levels = Vec::<String>::new();
2132 let mut level_index = HashMap::<String, usize>::new();
2133 let mut trimmed = Vec::<&str>::with_capacity(column.len());
2134 let mut parsed = Vec::<Option<f64>>::with_capacity(column.len());
2141 for (i, raw_field) in column.iter().enumerate() {
2142 let (raw, _) = strip_categorical_sentinel(raw_field);
2145 let raw = raw.trim();
2146 if raw.is_empty() {
2147 return Err(DataError::EmptyInput {
2148 reason: format!("empty field at row {}, column '{}'", i + 1, name),
2149 }
2150 .into());
2151 }
2152 if !force_categorical {
2155 if let Ok(v) = raw.parse::<f64>() {
2156 if !v.is_finite() {
2157 return Err(DataError::InvalidValue {
2158 reason: format!("non-finite value at row {}, column '{}'", i + 1, name),
2159 }
2160 .into());
2161 }
2162 if (v - 0.0).abs() >= 1e-12 && (v - 1.0).abs() >= 1e-12 {
2163 all_binary = false;
2164 }
2165 parsed.push(Some(v));
2166 trimmed.push(raw);
2167 continue;
2168 }
2169 all_numeric = false;
2170 all_binary = false;
2171 }
2172 level_index.entry(raw.to_string()).or_insert_with(|| {
2173 let idx = levels.len();
2174 levels.push(raw.to_string());
2175 idx
2176 });
2177 parsed.push(None);
2178 trimmed.push(raw);
2179 }
2180 let kind = if all_numeric {
2181 if all_binary {
2182 ColumnKindTag::Binary
2183 } else {
2184 ColumnKindTag::Continuous
2185 }
2186 } else {
2187 ColumnKindTag::Categorical
2188 };
2189 if matches!(kind, ColumnKindTag::Categorical) {
2203 sort_levels_canonical(&mut levels);
2204 }
2205 let schema = SchemaColumn {
2206 name: name.to_string(),
2207 kind,
2208 levels: if matches!(kind, ColumnKindTag::Categorical) {
2209 levels
2210 } else {
2211 Vec::new()
2212 },
2213 };
2214
2215 let level_map = if matches!(kind, ColumnKindTag::Categorical) {
2216 Some(
2217 schema
2218 .levels
2219 .iter()
2220 .enumerate()
2221 .map(|(idx, v)| (v.as_str(), idx as f64))
2222 .collect::<HashMap<_, _>>(),
2223 )
2224 } else {
2225 None
2226 };
2227
2228 let mut values = Vec::<f64>::with_capacity(trimmed.len());
2229 for (i, raw) in trimmed.iter().enumerate() {
2230 let raw = *raw;
2231 let val = match kind {
2232 ColumnKindTag::Continuous => parsed[i].ok_or_else(|| {
2236 String::from(DataError::EncodingFailure {
2237 reason: format!(
2238 "internal: continuous column '{}' lost its parsed value at row {} (col {})",
2239 name,
2240 i + 1,
2241 col_index
2242 ),
2243 })
2244 })?,
2245 ColumnKindTag::Binary => {
2246 let v = parsed[i].ok_or_else(|| {
2247 String::from(DataError::EncodingFailure {
2248 reason: format!(
2249 "internal: binary column '{}' lost its parsed value at row {} (col {})",
2250 name,
2251 i + 1,
2252 col_index
2253 ),
2254 })
2255 })?;
2256 if (v - 0.0).abs() >= 1e-12 && (v - 1.0).abs() >= 1e-12 {
2257 return Err(DataError::SchemaMismatch {
2258 reason: format!(
2259 "column '{}' is binary in schema but row {} has value {}; expected 0 or 1",
2260 name,
2261 i + 1,
2262 v
2263 ),
2264 }
2265 .into());
2266 }
2267 v
2268 }
2269 ColumnKindTag::Categorical => {
2270 let map = level_map.as_ref().ok_or_else(|| {
2271 String::from(DataError::EncodingFailure {
2272 reason: "internal categorical schema map missing".to_string(),
2273 })
2274 })?;
2275 *map.get(raw).ok_or_else(|| {
2276 String::from(DataError::EncodingFailure {
2277 reason: format!(
2278 "internal: level '{}' missing from freshly built map for column '{}' (col {})",
2279 raw, name, col_index
2280 ),
2281 })
2282 })?
2283 }
2284 };
2285 if !val.is_finite() {
2286 return Err(DataError::InvalidValue {
2287 reason: format!("non-finite value at row {}, column '{}'", i + 1, name),
2288 }
2289 .into());
2290 }
2291 values.push(val);
2292 }
2293 Ok((schema, values))
2294}
2295
2296#[cfg(test)]
2297mod tests {
2298 use super::*;
2299
2300 #[test]
2301 fn encode_records_rejects_empty_input() {
2302 let headers = vec!["x".to_string()];
2303 let schema = DataSchema {
2304 columns: vec![SchemaColumn {
2305 name: "x".to_string(),
2306 kind: ColumnKindTag::Continuous,
2307 levels: Vec::new(),
2308 }],
2309 };
2310
2311 let err = encode_recordswith_inferred_schema(headers.clone(), Vec::new())
2312 .expect_err("empty inferred records should error");
2313 assert_eq!(err, "table data cannot be empty");
2314
2315 let err =
2316 encode_recordswith_schema(headers, Vec::new(), &schema, UnseenCategoryPolicy::Error)
2317 .expect_err("empty schema-guided records should error");
2318 assert_eq!(err, "table data cannot be empty");
2319 }
2320
2321 #[test]
2322 fn column_major_matches_record_driven_inferred_encode() {
2323 let headers = vec!["cont".to_string(), "bin".to_string(), "cat".to_string()];
2328 let raw_rows = vec![
2329 vec!["1.5", "0", "a"],
2330 vec!["2.0", "1", "b"],
2331 vec!["-3.25", "1", "a"],
2332 vec!["0.0", "0", "c"],
2333 ];
2334 let records: Vec<StringRecord> = raw_rows
2335 .iter()
2336 .map(|r| StringRecord::from(r.clone()))
2337 .collect();
2338 let record_ds = encode_recordswith_inferred_schema(headers.clone(), records)
2339 .expect("record-driven encode");
2340
2341 for (j, name) in headers.iter().enumerate() {
2342 let column: Vec<&str> = raw_rows.iter().map(|r| r[j]).collect();
2343 let (schema_col, values) =
2344 infer_and_encode_column_major(name, &column, j + 1).expect("column-major encode");
2345 assert_eq!(schema_col.kind, record_ds.schema.columns[j].kind);
2346 assert_eq!(schema_col.levels, record_ds.schema.columns[j].levels);
2347 for (i, v) in values.iter().enumerate() {
2348 assert_eq!(*v, record_ds.values[[i, j]], "row {i} col {name}");
2349 }
2350 }
2351 }
2352
2353 #[test]
2354 fn encode_records_can_encode_unseen_named_categorical_column() {
2355 let schema = DataSchema {
2356 columns: vec![
2357 SchemaColumn {
2358 name: "g".to_string(),
2359 kind: ColumnKindTag::Categorical,
2360 levels: vec!["a".to_string(), "b".to_string()],
2361 },
2362 SchemaColumn {
2363 name: "x".to_string(),
2364 kind: ColumnKindTag::Categorical,
2365 levels: vec!["low".to_string(), "high".to_string()],
2366 },
2367 ],
2368 };
2369 let headers = vec!["g".to_string(), "x".to_string()];
2370 let records = vec![StringRecord::from(vec!["new-group", "low"])];
2371 let policy =
2372 UnseenCategoryPolicy::encode_unknown_for_columns(HashSet::from(["g".to_string()]));
2373
2374 let ds =
2375 encode_recordswith_schema(headers, records, &schema, policy).expect("encoded dataset");
2376
2377 assert_eq!(ds.values[[0, 0]], 2.0);
2378 assert_eq!(ds.values[[0, 1]], 0.0);
2379 }
2380
2381 #[test]
2382 fn numeric_valued_dictionary_column_classifies_and_decodes_as_numeric() {
2383 use arrow::array::{Array, ArrayRef, DictionaryArray, Int8Array, Int64Array};
2393 use arrow::datatypes::{DataType, Int8Type};
2394 use std::sync::Arc;
2395
2396 let keys = Int8Array::from(vec![0i8, 1, 0, 1, 0]);
2398 let dict_values: ArrayRef = Arc::new(Int64Array::from(vec![5i64, 7]));
2399 let dict: DictionaryArray<Int8Type> = DictionaryArray::new(keys, dict_values);
2400
2401 assert!(matches!(dict.data_type(), DataType::Dictionary(_, _)));
2404 assert!(
2405 !parquet_field_is_string(dict.data_type()),
2406 "Dictionary(Int8, Int64) must not be treated as a string column"
2407 );
2408
2409 let str_dict: DictionaryArray<Int8Type> = vec!["a", "b", "a"].into_iter().collect();
2411 assert!(
2412 parquet_field_is_string(str_dict.data_type()),
2413 "Dictionary(Int8, Utf8) must remain a string column"
2414 );
2415
2416 let decoded = decode_parquet_batch_column(&dict, dict.len(), 0, "x", false)
2420 .expect("numeric dictionary column should decode as numeric");
2421 match decoded {
2422 ParquetBatchColumn::Numeric(values) => {
2423 assert_eq!(values, vec![5.0, 7.0, 5.0, 7.0, 5.0]);
2424 }
2425 ParquetBatchColumn::Strings(_) => {
2426 panic!("numeric dictionary column was decoded as strings");
2427 }
2428 }
2429
2430 use arrow::datatypes::{Field, Schema};
2435 use arrow::record_batch::RecordBatch;
2436 use parquet::arrow::ArrowWriter;
2437
2438 let arrow_schema = Arc::new(Schema::new(vec![Field::new(
2439 "x",
2440 dict.data_type().clone(),
2441 false,
2442 )]));
2443 let batch = RecordBatch::try_new(arrow_schema.clone(), vec![Arc::new(dict.clone())])
2444 .expect("record batch with a dictionary numeric column");
2445
2446 let dir = tempfile::tempdir().expect("tempdir");
2447 let path = dir.path().join("dict_numeric.parquet");
2448 {
2449 let file = std::fs::File::create(&path).expect("create parquet");
2450 let mut writer =
2451 ArrowWriter::try_new(file, arrow_schema, None).expect("arrow parquet writer");
2452 writer.write(&batch).expect("write batch");
2453 writer.close().expect("close writer");
2454 }
2455
2456 let inferred =
2459 load_parquet_inferred(&path, &[], &HashSet::new()).expect("inferred parquet load");
2460 assert_eq!(inferred.column_kinds, vec![ColumnKindTag::Continuous]);
2461 assert_eq!(
2462 inferred.values.column(0).to_vec(),
2463 vec![5.0, 7.0, 5.0, 7.0, 5.0]
2464 );
2465
2466 let schema = DataSchema {
2470 columns: vec![SchemaColumn {
2471 name: "x".to_string(),
2472 kind: ColumnKindTag::Continuous,
2473 levels: Vec::new(),
2474 }],
2475 };
2476 let schema_loaded =
2477 load_parquet_with_schema(&path, &schema, UnseenCategoryPolicy::Error, &[])
2478 .expect("dictionary-encoded numeric parquet must load against a Continuous schema");
2479 assert_eq!(schema_loaded.column_kinds, vec![ColumnKindTag::Continuous]);
2480 assert_eq!(
2481 schema_loaded.values.column(0).to_vec(),
2482 vec![5.0, 7.0, 5.0, 7.0, 5.0]
2483 );
2484 }
2485
2486 #[test]
2487 fn encode_records_keeps_unlisted_categorical_columns_strict() {
2488 let schema = DataSchema {
2489 columns: vec![
2490 SchemaColumn {
2491 name: "g".to_string(),
2492 kind: ColumnKindTag::Categorical,
2493 levels: vec!["a".to_string(), "b".to_string()],
2494 },
2495 SchemaColumn {
2496 name: "x".to_string(),
2497 kind: ColumnKindTag::Categorical,
2498 levels: vec!["low".to_string(), "high".to_string()],
2499 },
2500 ],
2501 };
2502 let headers = vec!["g".to_string(), "x".to_string()];
2503 let records = vec![StringRecord::from(vec!["a", "new-level"])];
2504 let policy =
2505 UnseenCategoryPolicy::encode_unknown_for_columns(HashSet::from(["g".to_string()]));
2506
2507 let err = encode_recordswith_schema(headers, records, &schema, policy)
2508 .expect_err("ordinary categorical column should stay strict");
2509
2510 assert!(err.contains("unseen level 'new-level' in categorical column 'x'"));
2511 }
2512
2513 #[test]
2518 fn sentinel_strip_present_returns_rest_and_true() {
2519 let marked = format!("{}{}", CATEGORICAL_CELL_SENTINEL, "hello");
2520 let (rest, found) = strip_categorical_sentinel(&marked);
2521 assert_eq!(rest, "hello");
2522 assert!(found);
2523 }
2524
2525 #[test]
2526 fn sentinel_strip_absent_returns_original_and_false() {
2527 let (rest, found) = strip_categorical_sentinel("hello");
2528 assert_eq!(rest, "hello");
2529 assert!(!found);
2530 }
2531
2532 #[test]
2533 fn sentinel_strip_empty_string_returns_empty_and_false() {
2534 let (rest, found) = strip_categorical_sentinel("");
2535 assert_eq!(rest, "");
2536 assert!(!found);
2537 }
2538
2539 #[test]
2540 fn sentinel_strip_only_sentinel_returns_empty_and_true() {
2541 let marked = CATEGORICAL_CELL_SENTINEL.to_string();
2542 let (rest, found) = strip_categorical_sentinel(&marked);
2543 assert_eq!(rest, "");
2544 assert!(found);
2545 }
2546
2547 #[test]
2552 fn feature_ranges_two_columns() {
2553 let values = ndarray::arr2(&[[1.0_f64, 10.0], [3.0, 20.0], [2.0, 15.0]]);
2554 let ds = EncodedDataset {
2555 headers: vec!["a".to_string(), "b".to_string()],
2556 values,
2557 schema: DataSchema { columns: vec![] },
2558 column_kinds: vec![ColumnKindTag::Continuous, ColumnKindTag::Continuous],
2559 };
2560 let ranges = ds.feature_ranges();
2561 assert_eq!(ranges.len(), 2);
2562 assert_eq!(ranges[0], (1.0, 3.0));
2563 assert_eq!(ranges[1], (10.0, 20.0));
2564 }
2565
2566 #[test]
2567 fn feature_ranges_single_row_min_equals_max() {
2568 let values = ndarray::arr2(&[[5.0_f64, -3.0]]);
2569 let ds = EncodedDataset {
2570 headers: vec!["x".to_string(), "y".to_string()],
2571 values,
2572 schema: DataSchema { columns: vec![] },
2573 column_kinds: vec![ColumnKindTag::Continuous, ColumnKindTag::Continuous],
2574 };
2575 let ranges = ds.feature_ranges();
2576 assert_eq!(ranges[0], (5.0, 5.0));
2577 assert_eq!(ranges[1], (-3.0, -3.0));
2578 }
2579
2580 #[test]
2581 fn feature_ranges_all_nan_defaults_to_zero() {
2582 let values = ndarray::arr2(&[[f64::NAN], [f64::NAN]]);
2583 let ds = EncodedDataset {
2584 headers: vec!["x".to_string()],
2585 values,
2586 schema: DataSchema { columns: vec![] },
2587 column_kinds: vec![ColumnKindTag::Continuous],
2588 };
2589 let ranges = ds.feature_ranges();
2590 assert_eq!(ranges[0], (0.0, 0.0));
2591 }
2592
/// `column_map` maps each header name to its positional index and contains
/// exactly one entry per header.
#[test]
fn column_map_indexes_by_name() {
    let dataset = EncodedDataset {
        headers: vec![String::from("alpha"), String::from("beta")],
        values: ndarray::arr2(&[[0.0_f64, 1.0], [2.0, 3.0]]),
        schema: DataSchema {
            columns: Vec::new(),
        },
        column_kinds: vec![ColumnKindTag::Continuous, ColumnKindTag::Continuous],
    };

    let by_name = dataset.column_map();

    assert_eq!(by_name["alpha"], 0);
    assert_eq!(by_name["beta"], 1);
    assert_eq!(by_name.len(), 2);
}
2611
/// Two identical strings share a prefix as long as the whole string.
#[test]
fn shared_prefix_identical_strings() {
    let common = shared_prefix("hello", "hello");
    assert_eq!(common, 5);
}
2618
/// Strings that differ at the very first character share a prefix of length zero.
#[test]
fn shared_prefix_no_common_prefix() {
    let common = shared_prefix("abc", "xyz");
    assert_eq!(common, 0);
}
2623
/// "foobar" and "foobaz" agree on their first five characters only.
#[test]
fn shared_prefix_partial_match() {
    let common = shared_prefix("foobar", "foobaz");
    assert_eq!(common, 5);
}
2628
/// An empty string on either side yields a shared prefix of zero.
#[test]
fn shared_prefix_one_empty() {
    let empty_on_left = shared_prefix("", "hello");
    assert_eq!(empty_on_left, 0);

    let empty_on_right = shared_prefix("hello", "");
    assert_eq!(empty_on_right, 0);
}
2634
/// Two empty strings share a prefix of length zero.
#[test]
fn shared_prefix_both_empty() {
    let common = shared_prefix("", "");
    assert_eq!(common, 0);
}
2639
/// When one string is a full prefix of the other, the result is the shorter
/// string's length.
#[test]
fn shared_prefix_shorter_string_is_prefix() {
    let common = shared_prefix("foo", "foobar");
    assert_eq!(common, 3);
}
2644
/// A `.csv` extension is detected as `DataFormat::Csv`.
#[test]
fn detect_format_csv() {
    let detected = detect_format(std::path::Path::new("data.csv")).unwrap();
    assert_eq!(detected, DataFormat::Csv);
}
2652
/// The `.tsv`, `.txt` and `.tab` extensions are all detected as `DataFormat::Tsv`.
#[test]
fn detect_format_tsv() {
    // Local shorthand so each extension gets its own assertion line.
    let format_of = |file_name: &str| detect_format(std::path::Path::new(file_name)).unwrap();

    assert_eq!(format_of("data.tsv"), DataFormat::Tsv);
    assert_eq!(format_of("data.txt"), DataFormat::Tsv);
    assert_eq!(format_of("data.tab"), DataFormat::Tsv);
}
2668
/// The `.parquet`, `.pq` and `.pqt` extensions are all detected as
/// `DataFormat::Parquet`.
#[test]
fn detect_format_parquet() {
    // Local shorthand so each extension gets its own assertion line.
    let format_of = |file_name: &str| detect_format(std::path::Path::new(file_name)).unwrap();

    assert_eq!(format_of("data.parquet"), DataFormat::Parquet);
    assert_eq!(format_of("data.pq"), DataFormat::Parquet);
    assert_eq!(format_of("data.pqt"), DataFormat::Parquet);
}
2684
/// An upper-case `.CSV` extension is still detected as `DataFormat::Csv`.
#[test]
fn detect_format_uppercase_extension() {
    let shouting = std::path::Path::new("data.CSV");
    let detected = detect_format(shouting).unwrap();
    assert_eq!(detected, DataFormat::Csv);
}
2692
/// An unrecognised extension must produce an error whose debug rendering
/// contains either "json" or "unsupported".
#[test]
fn detect_format_unknown_extension_is_error() {
    let msg = format!(
        "{:?}",
        detect_format(std::path::Path::new("data.json")).unwrap_err()
    );
    let mentions_extension = ["json", "unsupported"]
        .iter()
        .any(|needle| msg.contains(*needle));
    assert!(
        mentions_extension,
        "error should mention extension, got: {msg}"
    );
}
2702
/// A cell starting with a literal NUL marker is reported as marked and the
/// remaining text is returned without the marker.
#[test]
fn strip_categorical_sentinel_marked_cell() {
    let (payload, was_marked) = strip_categorical_sentinel("\u{0}hello");
    assert!(was_marked);
    assert_eq!(payload, "hello");
}
2713
/// A cell without the marker is returned unchanged and reported as unmarked.
#[test]
fn strip_categorical_sentinel_unmarked_cell() {
    let (payload, was_marked) = strip_categorical_sentinel("plain");
    assert!(!was_marked);
    assert_eq!(payload, "plain");
}
2720
/// An empty cell is reported as unmarked and stays empty.
#[test]
fn strip_categorical_sentinel_empty_string() {
    let (payload, was_marked) = strip_categorical_sentinel("");
    assert!(!was_marked);
    assert_eq!(payload, "");
}
2727
/// A cell that is just the literal NUL marker is reported as marked with an
/// empty remainder.
#[test]
fn strip_categorical_sentinel_only_sentinel() {
    let (payload, was_marked) = strip_categorical_sentinel("\u{0}");
    assert!(was_marked);
    assert_eq!(payload, "");
}
2735
/// Selecting indices 1 and 3 out of four headers returns exactly those two
/// headers, in that order.
#[test]
fn projected_headers_selects_by_index() {
    let all: Vec<String> = ["a", "b", "c", "d"]
        .iter()
        .map(|name| name.to_string())
        .collect();
    let expected = vec![String::from("b"), String::from("d")];

    let selected = projected_headers(&all, &[1, 3]);

    assert_eq!(selected, expected);
}
2744
/// An empty index list produces an empty header selection.
#[test]
fn projected_headers_empty_selection() {
    let all = vec![String::from("x"), String::from("y")];
    let selected = projected_headers(&all, &[]);
    assert!(selected.is_empty());
}
2751
/// Selecting every index in order reproduces the full header list.
#[test]
fn projected_headers_all_indices() {
    let all = vec![String::from("p"), String::from("q")];
    let selected = projected_headers(&all, &[0, 1]);
    assert_eq!(selected, all);
}
2758}