1use csv::{ReaderBuilder, StringRecord};
2use ndarray::{Array2, Axis};
3use rayon::prelude::*;
4use serde::{Deserialize, Serialize};
5use std::collections::{HashMap, HashSet};
6use std::fmt;
7use std::path::Path;
8
9#[derive(Debug, Clone)]
20pub enum DataError {
21 SchemaMismatch { reason: String },
26 ParseError { reason: String },
30 EncodingFailure { reason: String },
34 EmptyInput { reason: String },
37 InvalidValue { reason: String },
41 ColumnNotFound {
48 name: String,
50 role: Option<String>,
54 available: Vec<String>,
56 similar: Vec<String>,
59 tsv_hint: bool,
64 },
65}
66
67impl DataError {
68 pub fn column_not_found(
74 col_map: &HashMap<String, usize>,
75 name: &str,
76 role: Option<&str>,
77 ) -> Self {
78 let target_lower = name.to_lowercase();
79 let mut similar: Vec<String> = col_map
80 .keys()
81 .filter(|k| {
82 let k_lower = k.to_lowercase();
83 k_lower.contains(&target_lower)
84 || target_lower.contains(&k_lower)
85 || shared_prefix(&k_lower, &target_lower) >= 3
86 })
87 .cloned()
88 .collect();
89 similar.sort_unstable();
90 let mut available: Vec<String> = col_map.keys().cloned().collect();
91 available.sort_unstable();
92 let tsv_hint = available.len() == 1 && available[0].contains('\t');
93 Self::ColumnNotFound {
94 name: name.to_string(),
95 role: role.map(str::to_string),
96 available,
97 similar,
98 tsv_hint,
99 }
100 }
101}
102
103impl fmt::Display for DataError {
104 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
105 match self {
106 DataError::SchemaMismatch { reason }
107 | DataError::ParseError { reason }
108 | DataError::EncodingFailure { reason }
109 | DataError::EmptyInput { reason }
110 | DataError::InvalidValue { reason } => f.write_str(reason),
111 DataError::ColumnNotFound {
112 name,
113 role,
114 available,
115 similar,
116 tsv_hint,
117 } => {
118 let label = match role {
119 Some(r) => format!("{r} column '{name}'"),
120 None => format!("column '{name}'"),
121 };
122 let tsv_suffix = if *tsv_hint {
123 " — your file appears to be tab-separated; gam expects comma-separated CSV. \
124 Replace tabs with commas, or pre-convert with `tr '\\t' ',' < file.tsv > file.csv`."
125 } else {
126 ""
127 };
128 if similar.is_empty() {
129 write!(
130 f,
131 "{label} not found in data. Available columns: [{}]{tsv_suffix}",
132 available.join(", ")
133 )
134 } else {
135 write!(
136 f,
137 "{label} not found in data. Did you mean one of [{}]? Full list: [{}]{tsv_suffix}",
138 similar.join(", "),
139 available.join(", ")
140 )
141 }
142 }
143 }
144 }
145}
146
147impl std::error::Error for DataError {}
148
149impl From<DataError> for String {
150 fn from(err: DataError) -> String {
151 err.to_string()
152 }
153}
154
/// Serializable description of how each column of an [`EncodedDataset`] is encoded.
///
/// A schema produced by the loaders in this file can be passed back to
/// `load_datasetwith_schema_projected` so that new data is encoded the same way.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct DataSchema {
    /// Per-column entries. The loaders here emit them in projected-header order;
    /// the delimited schema loader looks entries up by name, so input order does
    /// not matter there.
    pub columns: Vec<SchemaColumn>,
}

/// Schema entry for a single column.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct SchemaColumn {
    /// Column header name.
    pub name: String,
    /// How the column's cells are interpreted.
    pub kind: ColumnKindTag,
    /// Categorical levels: a cell equal to `levels[i]` is encoded as `i as f64`.
    /// The loaders leave this empty for non-categorical columns; it also defaults
    /// to empty when missing from serialized input (`#[serde(default)]`).
    #[serde(default)]
    pub levels: Vec<String>,
}

/// Kind of an encoded column; serialized in kebab-case
/// (`"continuous"`, `"binary"`, `"categorical"`).
#[derive(Clone, Copy, Debug, Serialize, Deserialize, Eq, PartialEq)]
#[serde(rename_all = "kebab-case")]
pub enum ColumnKindTag {
    /// Arbitrary finite numeric values.
    Continuous,
    /// Numeric values that are 0 or 1 (within 1e-12).
    Binary,
    /// Discrete levels encoded as indices into [`SchemaColumn::levels`].
    Categorical,
}
179
/// How to encode categorical values that are missing from a schema's level list.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum UnseenCategoryPolicy {
    /// Every unseen level is an error.
    Error,
    /// For the named columns, unseen levels get one extra code equal to the number
    /// of known levels; unseen levels in any other column are still an error.
    EncodeUnknownForColumns(HashSet<String>),
}

impl UnseenCategoryPolicy {
    /// Builds a policy that tolerates unseen levels in `columns`.
    ///
    /// An empty set is normalized to [`UnseenCategoryPolicy::Error`].
    pub fn encode_unknown_for_columns(columns: HashSet<String>) -> Self {
        if columns.is_empty() {
            return Self::Error;
        }
        Self::EncodeUnknownForColumns(columns)
    }

    /// Code to use for an unseen level in `column_name`, or `None` if such a level
    /// must be rejected. The code is `level_count`, one past the last known level.
    fn unseen_code_for(&self, column_name: &str, level_count: usize) -> Option<f64> {
        match self {
            Self::EncodeUnknownForColumns(columns) if columns.contains(column_name) => {
                Some(level_count as f64)
            }
            Self::EncodeUnknownForColumns(_) | Self::Error => None,
        }
    }
}
204
205#[derive(Clone, Debug)]
206pub struct EncodedDataset {
207 pub headers: Vec<String>,
208 pub values: Array2<f64>,
209 pub schema: DataSchema,
210 pub column_kinds: Vec<ColumnKindTag>,
211}
212
213impl EncodedDataset {
214 pub fn column_map(&self) -> HashMap<String, usize> {
215 self.headers
216 .iter()
217 .enumerate()
218 .map(|(index, header)| (header.clone(), index))
219 .collect()
220 }
221
222 pub fn feature_ranges(&self) -> Vec<(f64, f64)> {
228 self.values
235 .axis_iter(Axis(1))
236 .into_par_iter()
237 .map(|col| {
238 let (lo, hi) =
239 col.iter()
240 .fold((f64::INFINITY, f64::NEG_INFINITY), |(lo, hi), &v| {
241 if v.is_finite() {
242 (lo.min(v), hi.max(v))
243 } else {
244 (lo, hi)
245 }
246 });
247 if !lo.is_finite() || !hi.is_finite() {
248 (0.0, 0.0)
249 } else {
250 (lo, hi)
251 }
252 })
253 .collect()
254 }
255}
256
/// Number of leading characters (Unicode scalar values) that `a` and `b` share.
fn shared_prefix(a: &str, b: &str) -> usize {
    let mut common = 0;
    for (left, right) in a.chars().zip(b.chars()) {
        if left != right {
            break;
        }
        common += 1;
    }
    common
}
263
264#[derive(Clone, Copy, Debug, Eq, PartialEq)]
269enum DataFormat {
270 Csv,
271 Tsv,
272 Parquet,
273}
274
275fn detect_format(path: &Path) -> Result<DataFormat, DataError> {
276 let ext = path
277 .extension()
278 .and_then(|s| s.to_str())
279 .unwrap_or_default()
280 .to_ascii_lowercase();
281 match ext.as_str() {
282 "csv" => Ok(DataFormat::Csv),
283 "tsv" | "txt" | "tab" => Ok(DataFormat::Tsv),
284 "parquet" | "pq" | "pqt" => Ok(DataFormat::Parquet),
285 other => Err(DataError::ParseError {
286 reason: format!(
287 "unsupported data file extension '.{other}'; expected csv, tsv, txt, parquet, or pq: '{}'",
288 path.display()
289 ),
290 }),
291 }
292}
293
294pub fn load_dataset_projected(
299 path: &Path,
300 requested_columns: &[String],
301) -> Result<EncodedDataset, DataError> {
302 load_dataset_projected_with_categorical_roles(path, requested_columns, &HashSet::new())
303}
304
305pub fn load_dataset_projected_with_categorical_roles(
327 path: &Path,
328 requested_columns: &[String],
329 categorical_roles: &HashSet<&str>,
330) -> Result<EncodedDataset, DataError> {
331 match detect_format(path)? {
332 DataFormat::Csv => {
333 load_delimited_inferred(path, b',', requested_columns, categorical_roles)
334 }
335 DataFormat::Tsv => {
336 load_delimited_inferred(path, b'\t', requested_columns, categorical_roles)
337 }
338 DataFormat::Parquet => load_parquet_inferred(path, requested_columns, categorical_roles),
339 }
340}
341
342pub fn load_datasetwith_schema_projected(
343 path: &Path,
344 schema: &DataSchema,
345 unseen_policy: UnseenCategoryPolicy,
346 requested_columns: &[String],
347) -> Result<EncodedDataset, DataError> {
348 match detect_format(path)? {
349 DataFormat::Csv => {
350 load_delimited_with_schema(path, b',', schema, unseen_policy, requested_columns)
351 }
352 DataFormat::Tsv => {
353 load_delimited_with_schema(path, b'\t', schema, unseen_policy, requested_columns)
354 }
355 DataFormat::Parquet => {
356 load_parquet_with_schema(path, schema, unseen_policy, requested_columns)
357 }
358 }
359}
360
361pub fn load_csvwith_inferred_schema(path: &Path) -> Result<EncodedDataset, DataError> {
366 load_delimited_inferred(path, b',', &[], &HashSet::new())
367}
368
/// Row count that `load_delimited_inferred` passes to `infer_delimited_column` as
/// its sample size (capped at the number of rows actually read).
const SCHEMA_SAMPLE_ROWS: usize = 1 << 10;

/// NUL marker that may prefix a cell; see [`strip_categorical_sentinel`].
pub const CATEGORICAL_CELL_SENTINEL: char = '\0';

/// Removes one leading [`CATEGORICAL_CELL_SENTINEL`] from `cell`, if present.
///
/// Returns the remaining text and whether the sentinel was found.
pub fn strip_categorical_sentinel(cell: &str) -> (&str, bool) {
    if cell.starts_with(CATEGORICAL_CELL_SENTINEL) {
        (&cell[CATEGORICAL_CELL_SENTINEL.len_utf8()..], true)
    } else {
        (cell, false)
    }
}
395
396fn resolve_requested_columns(
397 all_headers: &[String],
398 requested_columns: &[String],
399) -> Result<Vec<usize>, DataError> {
400 if requested_columns.is_empty() {
401 return Ok((0..all_headers.len()).collect());
402 }
403
404 let requested_set: HashSet<&str> = requested_columns.iter().map(String::as_str).collect();
405 let mut selected = Vec::with_capacity(requested_set.len());
406 for (idx, name) in all_headers.iter().enumerate() {
407 if requested_set.contains(name.as_str()) {
408 selected.push(idx);
409 }
410 }
411
412 if selected.len() != requested_set.len() {
413 let available_map: HashMap<String, usize> = all_headers
414 .iter()
415 .enumerate()
416 .map(|(index, header)| (header.clone(), index))
417 .collect();
418 let missing = requested_columns
419 .iter()
420 .filter(|name| !available_map.contains_key(name.as_str()))
421 .map(|name| {
422 DataError::column_not_found(&available_map, name, Some("requested")).to_string()
423 })
424 .collect::<Vec<_>>();
425 return Err(DataError::SchemaMismatch {
426 reason: missing.join("; "),
427 });
428 }
429
430 Ok(selected)
431}
432
/// Clones the headers at `selected_indices`, in that order.
///
/// # Panics
/// Panics if any index is out of bounds for `all_headers`.
fn projected_headers(all_headers: &[String], selected_indices: &[usize]) -> Vec<String> {
    let mut projected = Vec::with_capacity(selected_indices.len());
    for idx in selected_indices.iter().copied() {
        projected.push(all_headers[idx].clone());
    }
    projected
}
439
440fn load_delimited_inferred(
441 path: &Path,
442 delimiter: u8,
443 requested_columns: &[String],
444 categorical_roles: &HashSet<&str>,
445) -> Result<EncodedDataset, DataError> {
446 let t_open = std::time::Instant::now();
447 let mut rdr = ReaderBuilder::new()
448 .has_headers(true)
449 .delimiter(delimiter)
450 .from_path(path)
451 .map_err(|e| DataError::ParseError {
452 reason: format!("failed to open '{}': {e}", path.display()),
453 })?;
454
455 let all_headers: Vec<String> = rdr
456 .headers()
457 .map_err(|e| DataError::ParseError {
458 reason: format!("failed to read headers: {e}"),
459 })?
460 .iter()
461 .map(|s| s.trim().to_string())
462 .collect();
463 if all_headers.is_empty() {
464 return Err(DataError::EmptyInput {
465 reason: "file has no headers".to_string(),
466 });
467 }
468 let selected_indices = resolve_requested_columns(&all_headers, requested_columns)?;
469 let headers = projected_headers(&all_headers, &selected_indices);
470 let p = headers.len();
471 let open_ms = t_open.elapsed().as_secs_f64() * 1000.0;
472 if open_ms > 100.0 {
473 log::info!(
474 "[DATA-LOAD] delim_open+headers | n_headers={} | n_proj={} | {:.1}ms",
475 all_headers.len(),
476 p,
477 open_ms
478 );
479 }
480
481 let mut raw_fields = Vec::<String>::new();
489 let mut total_rows: usize = 0;
490 let mut stream_error: Option<DataError> = None;
491
492 let t_stream = std::time::Instant::now();
493 let mut record = StringRecord::new();
494 while rdr
495 .read_record(&mut record)
496 .map_err(|e| DataError::ParseError {
497 reason: format!("failed reading row: {e}"),
498 })?
499 {
500 if record.len() != all_headers.len() {
501 stream_error = Some(DataError::SchemaMismatch {
502 reason: format!(
503 "row width mismatch at row {}: got {} fields, expected {}",
504 total_rows + 1,
505 record.len(),
506 all_headers.len()
507 ),
508 });
509 break;
510 }
511 total_rows += 1;
512
513 for &selected_idx in &selected_indices {
514 let raw = record.get(selected_idx).unwrap().trim();
515 raw_fields.push(raw.to_string());
516 }
517 }
518
519 let stream_ms = t_stream.elapsed().as_secs_f64() * 1000.0;
520 if stream_ms > 100.0 {
521 log::info!(
522 "[DATA-LOAD] delim_stream | n_rows={} | n_cols={} | {:.1}ms",
523 total_rows,
524 p,
525 stream_ms
526 );
527 }
528
529 if total_rows == 0 {
530 if let Some(err) = stream_error {
531 return Err(err);
532 }
533 return Err(DataError::EmptyInput {
534 reason: "file has no rows".to_string(),
535 });
536 }
537
538 let t_schema = std::time::Instant::now();
539 let sample_count = total_rows.min(SCHEMA_SAMPLE_ROWS);
540 let inferred_columns = (0..p)
541 .into_par_iter()
542 .map(|j| {
543 infer_delimited_column(
544 &raw_fields,
545 total_rows,
546 p,
547 j,
548 &headers[j],
549 sample_count,
550 categorical_roles.contains(headers[j].as_str()),
551 )
552 })
553 .collect::<Vec<_>>();
554
555 let first_error = inferred_columns
556 .iter()
557 .filter_map(|result| result.as_ref().err())
558 .min_by_key(|err| (err.row, err.col));
559 if let Some(err) = first_error {
560 return Err(err.error.clone());
561 }
562 if let Some(err) = stream_error {
563 return Err(err);
564 }
565
566 let inferred_columns = inferred_columns
567 .into_iter()
568 .map(Result::unwrap)
569 .collect::<Vec<_>>();
570
571 let mut schema_cols = Vec::<SchemaColumn>::with_capacity(p);
573 let mut column_kinds = Vec::<ColumnKindTag>::with_capacity(p);
574 for (j, inferred) in inferred_columns.iter().enumerate() {
575 column_kinds.push(inferred.kind);
576 schema_cols.push(SchemaColumn {
577 name: headers[j].clone(),
578 kind: inferred.kind,
579 levels: if matches!(inferred.kind, ColumnKindTag::Categorical) {
580 inferred.levels.clone()
581 } else {
582 Vec::new()
583 },
584 });
585 }
586 let schema_ms = t_schema.elapsed().as_secs_f64() * 1000.0;
587 if schema_ms > 100.0 {
588 let n_cat = column_kinds
589 .iter()
590 .filter(|k| matches!(k, ColumnKindTag::Categorical))
591 .count();
592 log::info!(
593 "[DATA-LOAD] delim_convert+infer | n_cols={} | n_cat={} | {:.1}ms",
594 p,
595 n_cat,
596 schema_ms
597 );
598 }
599
600 let t_assemble = std::time::Instant::now();
601 let mut values = Array2::<f64>::zeros((total_rows, p));
603 values
604 .axis_iter_mut(Axis(1))
605 .into_par_iter()
606 .zip(inferred_columns.par_iter())
607 .for_each(|(mut out_col, inferred)| {
608 for (dst, &src) in out_col.iter_mut().zip(inferred.values.iter()) {
609 *dst = src;
610 }
611 });
612 let assemble_ms = t_assemble.elapsed().as_secs_f64() * 1000.0;
613 if assemble_ms > 100.0 {
614 log::info!(
615 "[DATA-LOAD] delim_assemble_array2 | n_rows={} | n_cols={} | {:.1}ms",
616 total_rows,
617 p,
618 assemble_ms
619 );
620 }
621
622 let schema = DataSchema {
623 columns: schema_cols,
624 };
625 Ok(EncodedDataset {
626 headers,
627 values,
628 schema,
629 column_kinds,
630 })
631}
632
/// Result of inferring and encoding one delimited-text column.
struct InferredDelimitedColumn {
    /// Encoded value for every row: the parsed number, or the level index for
    /// categorical columns.
    values: Vec<f64>,
    /// Inferred (or forced) column kind.
    kind: ColumnKindTag,
    /// Sorted distinct raw strings when `kind` is categorical; empty otherwise.
    levels: Vec<String>,
}

/// An inference failure tagged with its position, so that when columns are
/// inferred in parallel the earliest failure can be reported.
#[derive(Debug)]
struct DelimitedInferenceError {
    /// 1-based data row where the problem was found.
    row: usize,
    /// 0-based index of the column among the projected columns.
    col: usize,
    /// The error surfaced to the caller.
    error: DataError,
}
645
646fn infer_delimited_column(
647 raw_fields: &[String],
648 total_rows: usize,
649 n_cols: usize,
650 col: usize,
651 header: &str,
652 sample_count: usize,
653 force_categorical: bool,
654) -> Result<InferredDelimitedColumn, DelimitedInferenceError> {
655 let mut values = Vec::<f64>::with_capacity(total_rows);
657 let mut all_numeric = true;
658 let mut all_binary = true;
659 let mut level_index = HashMap::<String, usize>::new();
660 let mut levels = Vec::<String>::new();
661
662 let non_finite_err = |row_idx: usize| DelimitedInferenceError {
666 row: row_idx + 1,
667 col,
668 error: DataError::InvalidValue {
669 reason: format!(
670 "non-finite value at row {}, column '{}'",
671 row_idx + 1,
672 header
673 ),
674 },
675 };
676
677 for row_idx in 0..total_rows {
678 let raw = raw_fields[row_idx * n_cols + col].as_str();
679 if raw.is_empty() {
680 return Err(DelimitedInferenceError {
681 row: row_idx + 1,
682 col,
683 error: DataError::EmptyInput {
684 reason: format!("empty field at row {}, column '{}'", row_idx + 1, header),
685 },
686 });
687 }
688
689 if row_idx < sample_count {
691 if let Ok(v) = raw.parse::<f64>() {
692 if !v.is_finite() {
693 return Err(non_finite_err(row_idx));
694 }
695 if (v - 0.0).abs() >= 1e-12 && (v - 1.0).abs() >= 1e-12 {
696 all_binary = false;
697 }
698 values.push(v);
699 } else {
700 all_numeric = false;
701 all_binary = false;
702 level_index.entry(raw.to_string()).or_insert_with(|| {
703 let idx = levels.len();
704 levels.push(raw.to_string());
705 idx
706 });
707 values.push(f64::NAN);
711 }
712 } else if let Ok(v) = raw.parse::<f64>() {
713 if !v.is_finite() {
717 return Err(non_finite_err(row_idx));
718 }
719 if (v - 0.0).abs() >= 1e-12 && (v - 1.0).abs() >= 1e-12 {
720 all_binary = false;
721 }
722 values.push(v);
723 } else {
724 all_numeric = false;
725 all_binary = false;
726 let idx = *level_index.entry(raw.to_string()).or_insert_with(|| {
727 let new_idx = levels.len();
728 levels.push(raw.to_string());
729 new_idx
730 });
731 values.push(idx as f64);
732 }
733 }
734
735 let kind = if force_categorical {
744 ColumnKindTag::Categorical
745 } else if all_numeric {
746 if all_binary {
747 ColumnKindTag::Binary
748 } else {
749 ColumnKindTag::Continuous
750 }
751 } else {
752 ColumnKindTag::Categorical
753 };
754
755 if matches!(kind, ColumnKindTag::Categorical) {
756 for row_idx in 0..total_rows {
777 let raw = raw_fields[row_idx * n_cols + col].as_str();
778 level_index.entry(raw.to_string()).or_insert_with(|| {
779 let new_idx = levels.len();
780 levels.push(raw.to_string());
781 new_idx
782 });
783 }
784 levels.sort();
785 level_index.clear();
786 for (idx, level) in levels.iter().enumerate() {
787 level_index.insert(level.clone(), idx);
788 }
789 for row_idx in 0..total_rows {
790 let raw = raw_fields[row_idx * n_cols + col].as_str();
791 values[row_idx] = level_index[raw] as f64;
792 }
793 }
794
795 for (row_idx, &v) in values.iter().enumerate() {
796 if !v.is_finite() {
797 return Err(non_finite_err(row_idx));
798 }
799 }
800
801 Ok(InferredDelimitedColumn {
802 values,
803 kind,
804 levels,
805 })
806}
807
808fn load_delimited_with_schema(
809 path: &Path,
810 delimiter: u8,
811 schema: &DataSchema,
812 unseen_policy: UnseenCategoryPolicy,
813 requested_columns: &[String],
814) -> Result<EncodedDataset, DataError> {
815 let t_open = std::time::Instant::now();
816 let mut rdr = ReaderBuilder::new()
817 .has_headers(true)
818 .delimiter(delimiter)
819 .from_path(path)
820 .map_err(|e| DataError::ParseError {
821 reason: format!("failed to open '{}': {e}", path.display()),
822 })?;
823
824 let all_headers: Vec<String> = rdr
825 .headers()
826 .map_err(|e| DataError::ParseError {
827 reason: format!("failed to read headers: {e}"),
828 })?
829 .iter()
830 .map(|s| s.trim().to_string())
831 .collect();
832 if all_headers.is_empty() {
833 return Err(DataError::EmptyInput {
834 reason: "file has no headers".to_string(),
835 });
836 }
837 let selected_indices = resolve_requested_columns(&all_headers, requested_columns)?;
838 let headers = projected_headers(&all_headers, &selected_indices);
839 let p = headers.len();
840 let open_ms = t_open.elapsed().as_secs_f64() * 1000.0;
841 if open_ms > 100.0 {
842 log::info!(
843 "[DATA-LOAD] delim_schema_open+headers | n_headers={} | n_proj={} | {:.1}ms",
844 all_headers.len(),
845 p,
846 open_ms
847 );
848 }
849
850 let schema_byname: HashMap<&str, &SchemaColumn> = schema
852 .columns
853 .iter()
854 .map(|c| (c.name.as_str(), c))
855 .collect();
856
857 let mut col_meta = Vec::<ColMeta>::with_capacity(p);
858 for name in &headers {
859 if let Some(sc) = schema_byname.get(name.as_str()) {
860 let level_map = if matches!(sc.kind, ColumnKindTag::Categorical) {
861 Some(
862 sc.levels
863 .iter()
864 .enumerate()
865 .map(|(idx, v)| (v.clone(), idx as f64))
866 .collect::<HashMap<_, _>>(),
867 )
868 } else {
869 None
870 };
871 col_meta.push(ColMeta {
872 kind: sc.kind,
873 level_map,
874 schema_col: (*sc).clone(),
875 });
876 } else {
877 col_meta.push(ColMeta {
879 kind: ColumnKindTag::Continuous, level_map: None,
881 schema_col: SchemaColumn {
882 name: name.clone(),
883 kind: ColumnKindTag::Continuous,
884 levels: Vec::new(),
885 },
886 });
887 }
888 }
889
890 let needs_inference: Vec<bool> = headers
892 .iter()
893 .map(|h| !schema_byname.contains_key(h.as_str()))
894 .collect();
895
896 let mut col_vecs: Vec<Vec<f64>> = vec![Vec::new(); p];
898 let mut infer_all_numeric: Vec<bool> = vec![true; p];
900 let mut infer_all_binary: Vec<bool> = vec![true; p];
901 let mut infer_level_index: Vec<HashMap<String, usize>> = vec![HashMap::new(); p];
902 let mut infer_levels: Vec<Vec<String>> = vec![Vec::new(); p];
903 let mut infer_strings: Vec<Vec<(usize, String)>> = vec![Vec::new(); p]; let mut total_rows: usize = 0;
906 let t_stream = std::time::Instant::now();
907 let mut record = StringRecord::new();
908 while rdr
909 .read_record(&mut record)
910 .map_err(|e| DataError::ParseError {
911 reason: format!("failed reading row: {e}"),
912 })?
913 {
914 if record.len() != all_headers.len() {
915 return Err(DataError::SchemaMismatch {
916 reason: format!(
917 "row width mismatch at row {}: got {} fields, expected {}",
918 total_rows + 1,
919 record.len(),
920 all_headers.len()
921 ),
922 });
923 }
924 total_rows += 1;
925
926 for j in 0..p {
927 let raw = record.get(selected_indices[j]).unwrap().trim();
928 if raw.is_empty() {
929 return Err(DataError::EmptyInput {
930 reason: format!(
931 "empty field at row {}, column '{}'",
932 total_rows, &headers[j]
933 ),
934 });
935 }
936
937 if needs_inference[j] {
938 if let Ok(v) = raw.parse::<f64>() {
940 if !v.is_finite() {
941 return Err(DataError::InvalidValue {
942 reason: format!(
943 "non-finite value at row {}, column '{}'",
944 total_rows, &headers[j]
945 ),
946 });
947 }
948 if (v - 0.0).abs() >= 1e-12 && (v - 1.0).abs() >= 1e-12 {
949 infer_all_binary[j] = false;
950 }
951 col_vecs[j].push(v);
952 infer_strings[j].push((total_rows - 1, raw.to_string()));
961 } else {
962 infer_all_numeric[j] = false;
963 infer_all_binary[j] = false;
964 let levels_ref = &mut infer_levels[j];
965 infer_level_index[j]
966 .entry(raw.to_string())
967 .or_insert_with(|| {
968 let idx = levels_ref.len();
969 levels_ref.push(raw.to_string());
970 idx
971 });
972 infer_strings[j].push((total_rows - 1, raw.to_string()));
973 col_vecs[j].push(f64::NAN); }
975 } else {
976 let val = parse_cell_with_schema(
978 raw,
979 &col_meta[j],
980 total_rows,
981 &headers[j],
982 &unseen_policy,
983 )?;
984 col_vecs[j].push(val);
985 }
986 }
987 }
988
989 let stream_ms = t_stream.elapsed().as_secs_f64() * 1000.0;
990 if stream_ms > 100.0 {
991 let n_inf = needs_inference.iter().filter(|x| **x).count();
992 log::info!(
993 "[DATA-LOAD] delim_schema_stream | n_rows={} | n_cols={} | n_inf={} | {:.1}ms",
994 total_rows,
995 p,
996 n_inf,
997 stream_ms
998 );
999 }
1000
1001 if total_rows == 0 {
1002 return Err(DataError::EmptyInput {
1003 reason: "file has no rows".to_string(),
1004 });
1005 }
1006
1007 let t_finalize = std::time::Instant::now();
1008 let mut column_kinds = Vec::<ColumnKindTag>::with_capacity(p);
1010 for j in 0..p {
1011 if needs_inference[j] {
1012 let kind = if infer_all_numeric[j] {
1013 if infer_all_binary[j] {
1014 ColumnKindTag::Binary
1015 } else {
1016 ColumnKindTag::Continuous
1017 }
1018 } else {
1019 ColumnKindTag::Categorical
1020 };
1021 col_meta[j].kind = kind;
1022 col_meta[j].schema_col.kind = kind;
1023 if matches!(kind, ColumnKindTag::Categorical) {
1024 for (_, raw) in &infer_strings[j] {
1037 let levels_ref = &mut infer_levels[j];
1038 infer_level_index[j].entry(raw.clone()).or_insert_with(|| {
1039 let new_idx = levels_ref.len();
1040 levels_ref.push(raw.clone());
1041 new_idx
1042 });
1043 }
1044 infer_levels[j].sort();
1045 infer_level_index[j].clear();
1046 for (idx, level) in infer_levels[j].iter().enumerate() {
1047 infer_level_index[j].insert(level.clone(), idx);
1048 }
1049 for (row_idx, raw) in &infer_strings[j] {
1050 col_vecs[j][*row_idx] = infer_level_index[j][raw] as f64;
1051 }
1052 col_meta[j].schema_col.levels = infer_levels[j].clone();
1053 }
1054 }
1055 column_kinds.push(col_meta[j].kind);
1056 }
1057 let finalize_ms = t_finalize.elapsed().as_secs_f64() * 1000.0;
1058 if finalize_ms > 100.0 {
1059 log::info!(
1060 "[DATA-LOAD] delim_schema_finalize | n_cols={} | {:.1}ms",
1061 p,
1062 finalize_ms
1063 );
1064 }
1065
1066 let t_assemble = std::time::Instant::now();
1067 let mut values = Array2::<f64>::zeros((total_rows, p));
1073 let assemble_err: Option<DataError> = values
1074 .axis_iter_mut(Axis(1))
1075 .into_par_iter()
1076 .zip(col_vecs.par_iter())
1077 .zip(headers.par_iter())
1078 .map(|((mut out_col, col_vec), header)| {
1079 for (i, &v) in col_vec.iter().enumerate() {
1080 if !v.is_finite() {
1081 return Some(DataError::InvalidValue {
1082 reason: format!("non-finite value at row {}, column '{}'", i + 1, header),
1083 });
1084 }
1085 out_col[i] = v;
1086 }
1087 None
1088 })
1089 .reduce(|| None, |a, b| a.or(b));
1090 if let Some(e) = assemble_err {
1091 return Err(e);
1092 }
1093 let assemble_ms = t_assemble.elapsed().as_secs_f64() * 1000.0;
1094 if assemble_ms > 100.0 {
1095 log::info!(
1096 "[DATA-LOAD] delim_schema_assemble | n_rows={} | n_cols={} | {:.1}ms",
1097 total_rows,
1098 p,
1099 assemble_ms
1100 );
1101 }
1102
1103 let schema_out = DataSchema {
1104 columns: col_meta.into_iter().map(|m| m.schema_col).collect(),
1105 };
1106 Ok(EncodedDataset {
1107 headers,
1108 values,
1109 schema: schema_out,
1110 column_kinds,
1111 })
1112}
1113
1114fn parse_cell_with_schema(
1115 raw: &str,
1116 meta: &ColMeta,
1117 row: usize,
1118 col_name: &str,
1119 unseen_policy: &UnseenCategoryPolicy,
1120) -> Result<f64, DataError> {
1121 let val = match meta.kind {
1122 ColumnKindTag::Continuous => raw.parse::<f64>().map_err(|err| {
1123 DataError::SchemaMismatch {
1124 reason: format!(
1125 "column '{}' is continuous in schema but row {} has non-numeric value '{}': {}",
1126 col_name, row, raw, err
1127 ),
1128 }
1129 })?,
1130 ColumnKindTag::Binary => {
1131 let v = raw
1132 .parse::<f64>()
1133 .map_err(|err| DataError::SchemaMismatch {
1134 reason: format!(
1135 "column '{}' is binary in schema but row {} has non-numeric value '{}': {}",
1136 col_name, row, raw, err
1137 ),
1138 })?;
1139 if (v - 0.0).abs() >= 1e-12 && (v - 1.0).abs() >= 1e-12 {
1140 return Err(DataError::SchemaMismatch {
1141 reason: format!(
1142 "column '{}' is binary in schema but row {} has value {}; expected 0 or 1",
1143 col_name, row, v
1144 ),
1145 });
1146 }
1147 v
1148 }
1149 ColumnKindTag::Categorical => {
1150 let map = meta
1151 .level_map
1152 .as_ref()
1153 .ok_or_else(|| DataError::EncodingFailure {
1154 reason: "internal categorical schema map missing".to_string(),
1155 })?;
1156 match map.get(raw) {
1157 Some(v) => *v,
1158 None => unseen_policy
1159 .unseen_code_for(col_name, meta.schema_col.levels.len())
1160 .ok_or_else(|| DataError::SchemaMismatch {
1161 reason: format!(
1162 "unseen level '{}' in categorical column '{}' at row {}",
1163 raw, col_name, row
1164 ),
1165 })?,
1166 }
1167 }
1168 };
1169 if !val.is_finite() {
1170 return Err(DataError::InvalidValue {
1171 reason: format!("non-finite value at row {}, column '{}'", row, col_name),
1172 });
1173 }
1174 Ok(val)
1175}
1176
/// Per-column decoding state used by `load_delimited_with_schema`.
struct ColMeta {
    /// How cells of this column are parsed; provisional for columns not in the schema.
    kind: ColumnKindTag,
    /// Level string -> code lookup. `Some` only for categorical columns taken from
    /// the schema; inferred columns keep `None`.
    level_map: Option<HashMap<String, f64>>,
    /// Schema entry emitted for this column in the output dataset.
    schema_col: SchemaColumn,
}
1184
/// One decoded column slice from a Parquet record batch.
enum ParquetBatchColumn {
    /// Numeric cells widened to `f64` (booleans become 1.0 / 0.0).
    Numeric(Vec<f64>),
    /// Raw string cells, for columns treated as text.
    Strings(Vec<String>),
}
1193
1194fn parquet_field_is_string(dt: &arrow::datatypes::DataType) -> bool {
1203 use arrow::datatypes::DataType;
1204 match dt {
1205 DataType::Utf8 | DataType::LargeUtf8 => true,
1206 DataType::Dictionary(_, value_type) => parquet_field_is_string(value_type),
1207 _ => false,
1208 }
1209}
1210
/// Decodes one Arrow column of a Parquet record batch into owned values.
///
/// * `col` – the batch column; `n_rows` – number of rows to read from it.
/// * `base_row` – rows consumed by earlier batches. It is added to in-batch indices
///   so error messages report 1-based row numbers across the whole file.
/// * `header` – column name used in error messages.
/// * `is_string_col` – when true, the column is returned as strings (cast to `Utf8`
///   unless it already is `Utf8`/`LargeUtf8`); otherwise it is decoded as numbers.
///
/// # Errors
/// * `InvalidValue` for any null, for a non-finite numeric value, or for an
///   unsupported Arrow type.
/// * `ParseError` / `EncodingFailure` when an Arrow cast fails or yields an
///   unexpected array type.
fn decode_parquet_batch_column(
    col: &dyn arrow::array::Array,
    n_rows: usize,
    base_row: usize,
    header: &str,
    is_string_col: bool,
) -> Result<ParquetBatchColumn, DataError> {
    use arrow::array::{
        Array as ArrowArray, BooleanArray, Float32Array, Float64Array, Int8Array, Int16Array,
        Int32Array, Int64Array, LargeStringArray, StringArray, UInt8Array, UInt16Array,
        UInt32Array, UInt64Array,
    };
    use arrow::datatypes::DataType;

    // Nulls are rejected outright; report the first null's file-wide 1-based row.
    if col.null_count() > 0 {
        for i in 0..n_rows {
            if col.is_null(i) {
                return Err(DataError::InvalidValue {
                    reason: format!(
                        "null value at row {}, column '{}'",
                        base_row + i + 1,
                        header
                    ),
                });
            }
        }
    }

    if is_string_col {
        // Fast paths for plain string arrays.
        if let Some(arr) = col.as_any().downcast_ref::<StringArray>() {
            return Ok(ParquetBatchColumn::Strings(
                (0..n_rows).map(|i| arr.value(i).to_string()).collect(),
            ));
        }
        if let Some(arr) = col.as_any().downcast_ref::<LargeStringArray>() {
            return Ok(ParquetBatchColumn::Strings(
                (0..n_rows).map(|i| arr.value(i).to_string()).collect(),
            ));
        }

        // Any other text-flagged column (e.g. a dictionary of strings) is cast to Utf8.
        let casted =
            arrow::compute::cast(col, &DataType::Utf8).map_err(|e| DataError::ParseError {
                reason: format!("failed to cast column '{}' to string: {e}", header),
            })?;
        let arr = casted
            .as_any()
            .downcast_ref::<StringArray>()
            .ok_or_else(|| DataError::EncodingFailure {
                reason: format!("column '{}' could not be read as string after cast", header),
            })?;
        return Ok(ParquetBatchColumn::Strings(
            (0..n_rows).map(|i| arr.value(i).to_string()).collect(),
        ));
    }

    // Dictionary-encoded numeric columns are first cast to their value type.
    // `decoded_col` owns the cast array so that `col` can borrow from it below.
    let decoded_col;
    let col: &dyn arrow::array::Array = if let DataType::Dictionary(_, value_type) = col.data_type()
    {
        decoded_col = arrow::compute::cast(col, value_type).map_err(|e| DataError::ParseError {
            reason: format!(
                "failed to decode dictionary-encoded numeric column '{}': {e}",
                header
            ),
        })?;
        decoded_col.as_ref()
    } else {
        col
    };

    // Each `unwrap` below follows a match on `data_type()`, so the downcast to the
    // corresponding array type is expected to succeed.
    let mut values = Vec::with_capacity(n_rows);
    match col.data_type() {
        DataType::Float64 => {
            let arr = col.as_any().downcast_ref::<Float64Array>().unwrap();
            values.extend(arr.values().iter().copied());
        }
        DataType::Float32 => {
            let arr = col.as_any().downcast_ref::<Float32Array>().unwrap();
            values.extend(arr.values().iter().map(|&v| v as f64));
        }
        DataType::Int64 => {
            let arr = col.as_any().downcast_ref::<Int64Array>().unwrap();
            values.extend(arr.values().iter().map(|&v| v as f64));
        }
        DataType::Int32 => {
            let arr = col.as_any().downcast_ref::<Int32Array>().unwrap();
            values.extend(arr.values().iter().map(|&v| v as f64));
        }
        DataType::Int16 => {
            let arr = col.as_any().downcast_ref::<Int16Array>().unwrap();
            values.extend(arr.values().iter().map(|&v| v as f64));
        }
        DataType::Int8 => {
            let arr = col.as_any().downcast_ref::<Int8Array>().unwrap();
            values.extend(arr.values().iter().map(|&v| v as f64));
        }
        DataType::UInt64 => {
            let arr = col.as_any().downcast_ref::<UInt64Array>().unwrap();
            values.extend(arr.values().iter().map(|&v| v as f64));
        }
        DataType::UInt32 => {
            let arr = col.as_any().downcast_ref::<UInt32Array>().unwrap();
            values.extend(arr.values().iter().map(|&v| v as f64));
        }
        DataType::UInt16 => {
            let arr = col.as_any().downcast_ref::<UInt16Array>().unwrap();
            values.extend(arr.values().iter().map(|&v| v as f64));
        }
        DataType::UInt8 => {
            let arr = col.as_any().downcast_ref::<UInt8Array>().unwrap();
            values.extend(arr.values().iter().map(|&v| v as f64));
        }
        DataType::Boolean => {
            // Booleans encode as 1.0 (true) / 0.0 (false).
            let arr = col.as_any().downcast_ref::<BooleanArray>().unwrap();
            values.extend((0..n_rows).map(|i| if arr.value(i) { 1.0 } else { 0.0 }));
        }
        other => {
            return Err(DataError::InvalidValue {
                reason: format!(
                    "unsupported parquet column type {:?} for column '{}'",
                    other, header
                ),
            });
        }
    }

    // Reject NaN / infinities (only the float types can produce them).
    if let Some(i) = values.iter().position(|v| !v.is_finite()) {
        return Err(DataError::InvalidValue {
            reason: format!(
                "non-finite value at row {}, column '{}'",
                base_row + i + 1,
                header
            ),
        });
    }

    Ok(ParquetBatchColumn::Numeric(values))
}
1357
1358fn load_parquet_inferred(
1359 path: &Path,
1360 requested_columns: &[String],
1361 categorical_roles: &HashSet<&str>,
1362) -> Result<EncodedDataset, DataError> {
1363 use parquet::arrow::{ProjectionMask, arrow_reader::ParquetRecordBatchReaderBuilder};
1364 use rayon::prelude::*;
1365 use std::fs::File;
1366
1367 let t_open = std::time::Instant::now();
1368 let file = File::open(path).map_err(|e| DataError::ParseError {
1369 reason: format!("failed to open parquet '{}': {e}", path.display()),
1370 })?;
1371 let builder =
1372 ParquetRecordBatchReaderBuilder::try_new(file).map_err(|e| DataError::ParseError {
1373 reason: format!("failed to read parquet metadata '{}': {e}", path.display()),
1374 })?;
1375
1376 let full_schema = builder.schema().clone();
1377 let all_headers: Vec<String> = full_schema
1378 .fields()
1379 .iter()
1380 .map(|f| f.name().clone())
1381 .collect();
1382 if all_headers.is_empty() {
1383 return Err(DataError::EmptyInput {
1384 reason: "parquet file has no columns".to_string(),
1385 });
1386 }
1387 let selected_indices = resolve_requested_columns(&all_headers, requested_columns)?;
1388 let headers = projected_headers(&all_headers, &selected_indices);
1389 let selected_fields = selected_indices
1390 .iter()
1391 .map(|&idx| full_schema.fields()[idx].clone())
1392 .collect::<Vec<_>>();
1393 let projection =
1394 ProjectionMask::roots(builder.parquet_schema(), selected_indices.iter().copied());
1395 let reader =
1396 builder
1397 .with_projection(projection)
1398 .build()
1399 .map_err(|e| DataError::ParseError {
1400 reason: format!("failed to build parquet reader: {e}"),
1401 })?;
1402 let p = headers.len();
1403 let open_ms = t_open.elapsed().as_secs_f64() * 1000.0;
1404 if open_ms > 100.0 {
1405 log::info!(
1406 "[DATA-LOAD] parquet_open+meta | n_headers={} | n_proj={} | {:.1}ms",
1407 all_headers.len(),
1408 p,
1409 open_ms
1410 );
1411 }
1412
1413 let t_batches = std::time::Instant::now();
1414 let mut col_vecs: Vec<Vec<f64>> = vec![Vec::new(); p];
1416 let mut string_cols: Vec<Option<Vec<String>>> = (0..p).map(|_| None).collect();
1418 let mut is_string_col: Vec<bool> = vec![false; p];
1419
1420 for (j, field) in selected_fields.iter().enumerate() {
1421 if parquet_field_is_string(field.data_type()) {
1425 is_string_col[j] = true;
1426 string_cols[j] = Some(Vec::new());
1427 }
1428 }
1429
1430 let mut rows_seen = 0usize;
1431 for batch_result in reader {
1432 let batch = batch_result.map_err(|e| DataError::ParseError {
1433 reason: format!("failed to read parquet record batch: {e}"),
1434 })?;
1435 let n_rows = batch.num_rows();
1436
1437 let decoded_columns: Vec<Result<ParquetBatchColumn, DataError>> = (0..p)
1438 .into_par_iter()
1439 .map(|j| {
1440 decode_parquet_batch_column(
1441 batch.column(j).as_ref(),
1442 n_rows,
1443 rows_seen,
1444 &headers[j],
1445 is_string_col[j],
1446 )
1447 })
1448 .collect();
1449
1450 for (j, decoded) in decoded_columns.into_iter().enumerate() {
1451 match decoded? {
1452 ParquetBatchColumn::Strings(mut strings) => {
1453 assert!(is_string_col[j]);
1454 string_cols[j].as_mut().unwrap().append(&mut strings);
1455 let new_len = col_vecs[j].len() + n_rows;
1456 col_vecs[j].resize(new_len, f64::NAN);
1457 }
1458 ParquetBatchColumn::Numeric(mut values) => {
1459 assert!(!is_string_col[j]);
1460 col_vecs[j].append(&mut values);
1461 }
1462 }
1463 }
1464 rows_seen += n_rows;
1465 }
1466
1467 let total_rows = col_vecs[0].len();
1468 let batches_ms = t_batches.elapsed().as_secs_f64() * 1000.0;
1469 if batches_ms > 100.0 {
1470 log::info!(
1471 "[DATA-LOAD] parquet_batches_decode | n_rows={} | n_cols={} | {:.1}ms",
1472 total_rows,
1473 p,
1474 batches_ms
1475 );
1476 }
1477 if total_rows == 0 {
1478 return Err(DataError::EmptyInput {
1479 reason: "parquet file has no rows".to_string(),
1480 });
1481 }
1482
1483 let t_schema = std::time::Instant::now();
1484 let mut schema_cols = Vec::<SchemaColumn>::with_capacity(p);
1486 let mut column_kinds = Vec::<ColumnKindTag>::with_capacity(p);
1487
1488 let finalized_columns: Vec<(Vec<f64>, ColumnKindTag, SchemaColumn)> = col_vecs
1489 .into_par_iter()
1490 .zip(string_cols.into_par_iter())
1491 .zip(is_string_col.into_par_iter())
1492 .zip(headers.par_iter())
1493 .map(|(((mut col_values, strings), is_string), header)| {
1494 if is_string {
1495 let strings = strings.expect("string column storage missing");
1499 let mut level_index: HashMap<String, usize> = HashMap::new();
1500 let mut levels_vec: Vec<String> = Vec::new();
1501 for s in &strings {
1502 level_index.entry(s.clone()).or_insert_with(|| {
1503 let idx = levels_vec.len();
1504 levels_vec.push(s.clone());
1505 idx
1506 });
1507 }
1508 for (i, s) in strings.iter().enumerate() {
1509 col_values[i] = *level_index.get(s.as_str()).unwrap() as f64;
1510 }
1511 (
1512 col_values,
1513 ColumnKindTag::Categorical,
1514 SchemaColumn {
1515 name: header.clone(),
1516 kind: ColumnKindTag::Categorical,
1517 levels: levels_vec,
1518 },
1519 )
1520 } else if categorical_roles.contains(header.as_str()) {
1521 let labels: Vec<String> = col_values.iter().map(|v| v.to_string()).collect();
1529 let mut levels_vec: Vec<String> = Vec::new();
1530 let mut level_index: HashMap<String, usize> = HashMap::new();
1531 for label in &labels {
1532 level_index.entry(label.clone()).or_insert_with(|| {
1533 let idx = levels_vec.len();
1534 levels_vec.push(label.clone());
1535 idx
1536 });
1537 }
1538 levels_vec.sort();
1539 level_index.clear();
1540 for (idx, level) in levels_vec.iter().enumerate() {
1541 level_index.insert(level.clone(), idx);
1542 }
1543 for (i, label) in labels.iter().enumerate() {
1544 col_values[i] = level_index[label] as f64;
1545 }
1546 (
1547 col_values,
1548 ColumnKindTag::Categorical,
1549 SchemaColumn {
1550 name: header.clone(),
1551 kind: ColumnKindTag::Categorical,
1552 levels: levels_vec,
1553 },
1554 )
1555 } else {
1556 let all_binary = col_values
1558 .iter()
1559 .all(|&v| (v - 0.0).abs() < 1e-12 || (v - 1.0).abs() < 1e-12);
1560 let kind = if all_binary {
1561 ColumnKindTag::Binary
1562 } else {
1563 ColumnKindTag::Continuous
1564 };
1565 (
1566 col_values,
1567 kind,
1568 SchemaColumn {
1569 name: header.clone(),
1570 kind,
1571 levels: Vec::new(),
1572 },
1573 )
1574 }
1575 })
1576 .collect();
1577
1578 let mut col_vecs = Vec::with_capacity(p);
1579 for (col_values, kind, schema_col) in finalized_columns {
1580 col_vecs.push(col_values);
1581 column_kinds.push(kind);
1582 schema_cols.push(schema_col);
1583 }
1584 let schema_ms = t_schema.elapsed().as_secs_f64() * 1000.0;
1585 if schema_ms > 100.0 {
1586 let n_cat = column_kinds
1587 .iter()
1588 .filter(|k| matches!(k, ColumnKindTag::Categorical))
1589 .count();
1590 log::info!(
1591 "[DATA-LOAD] parquet_finalize_schema | n_cols={} | n_cat={} | {:.1}ms",
1592 p,
1593 n_cat,
1594 schema_ms
1595 );
1596 }
1597
1598 let t_assemble = std::time::Instant::now();
1599 let mut values = Array2::<f64>::zeros((total_rows, p));
1604 values
1605 .axis_iter_mut(Axis(1))
1606 .into_par_iter()
1607 .zip(col_vecs.par_iter())
1608 .for_each(|(mut out_col, src)| {
1609 for (dst, &v) in out_col.iter_mut().zip(src.iter()) {
1610 *dst = v;
1611 }
1612 });
1613 let assemble_ms = t_assemble.elapsed().as_secs_f64() * 1000.0;
1614 if assemble_ms > 100.0 {
1615 log::info!(
1616 "[DATA-LOAD] parquet_assemble_array2 | n_rows={} | n_cols={} | {:.1}ms",
1617 total_rows,
1618 p,
1619 assemble_ms
1620 );
1621 }
1622
1623 Ok(EncodedDataset {
1624 headers,
1625 values,
1626 schema: DataSchema {
1627 columns: schema_cols,
1628 },
1629 column_kinds,
1630 })
1631}
1632
1633fn load_parquet_with_schema(
1634 path: &Path,
1635 schema: &DataSchema,
1636 unseen_policy: UnseenCategoryPolicy,
1637 requested_columns: &[String],
1638) -> Result<EncodedDataset, DataError> {
1639 let inferred = load_parquet_inferred(path, requested_columns, &HashSet::new())?;
1643 let p = inferred.headers.len();
1644 let n = inferred.values.nrows();
1645
1646 let schema_byname: HashMap<&str, &SchemaColumn> = schema
1647 .columns
1648 .iter()
1649 .map(|c| (c.name.as_str(), c))
1650 .collect();
1651
1652 let mut column_kinds = Vec::<ColumnKindTag>::with_capacity(p);
1653 let mut schema_cols = Vec::<SchemaColumn>::with_capacity(p);
1654 let mut values = inferred.values;
1655
1656 for j in 0..p {
1657 let name = &inferred.headers[j];
1658 if let Some(sc) = schema_byname.get(name.as_str()) {
1659 column_kinds.push(sc.kind);
1660 schema_cols.push((*sc).clone());
1661
1662 match sc.kind {
1663 ColumnKindTag::Continuous => {
1664 if matches!(inferred.column_kinds[j], ColumnKindTag::Categorical) {
1665 return Err(DataError::SchemaMismatch {
1666 reason: format!(
1667 "column '{}' is continuous in schema but parquet column is string/categorical",
1668 name
1669 ),
1670 });
1671 }
1672 }
1673 ColumnKindTag::Binary => {
1674 if matches!(inferred.column_kinds[j], ColumnKindTag::Categorical) {
1675 return Err(DataError::SchemaMismatch {
1676 reason: format!(
1677 "column '{}' is binary in schema but parquet column is string/categorical",
1678 name
1679 ),
1680 });
1681 }
1682 if let Some(row) = values.column(j).iter().position(|value| {
1683 (*value - 0.0).abs() >= 1e-12 && (*value - 1.0).abs() >= 1e-12
1684 }) {
1685 return Err(DataError::SchemaMismatch {
1686 reason: format!(
1687 "column '{}' is binary in schema but row {} has value {}; expected 0 or 1",
1688 name,
1689 row + 1,
1690 values[[row, j]]
1691 ),
1692 });
1693 }
1694 }
1695 ColumnKindTag::Categorical => {
1696 if !matches!(inferred.column_kinds[j], ColumnKindTag::Categorical) {
1697 return Err(DataError::SchemaMismatch {
1698 reason: format!(
1699 "column '{}' is categorical in schema but parquet column is numeric",
1700 name
1701 ),
1702 });
1703 }
1704 let inferred_col = &inferred.schema.columns[j];
1705 let schema_level_map: HashMap<&str, f64> = sc
1707 .levels
1708 .iter()
1709 .enumerate()
1710 .map(|(idx, v)| (v.as_str(), idx as f64))
1711 .collect();
1712 let inferred_to_schema: Vec<f64> = inferred_col
1713 .levels
1714 .iter()
1715 .map(|lv| {
1716 schema_level_map
1717 .get(lv.as_str())
1718 .copied()
1719 .or_else(|| unseen_policy.unseen_code_for(name, sc.levels.len()))
1720 .ok_or_else(|| DataError::SchemaMismatch {
1721 reason: format!(
1722 "unseen level '{}' in categorical column '{}'",
1723 lv, name
1724 ),
1725 })
1726 })
1727 .collect::<Result<Vec<_>, _>>()?;
1728 for i in 0..n {
1729 let old_code = values[[i, j]] as usize;
1730 if old_code >= inferred_to_schema.len() {
1731 let Some(unseen_code) =
1732 unseen_policy.unseen_code_for(name, sc.levels.len())
1733 else {
1734 return Err(DataError::SchemaMismatch {
1735 reason: format!(
1736 "unseen categorical code at row {}, column '{}'",
1737 i + 1,
1738 name
1739 ),
1740 });
1741 };
1742 values[[i, j]] = unseen_code;
1743 continue;
1744 }
1745 values[[i, j]] = inferred_to_schema[old_code];
1746 }
1747 }
1748 }
1749 } else {
1750 column_kinds.push(inferred.column_kinds[j]);
1752 schema_cols.push(inferred.schema.columns[j].clone());
1753 }
1754 }
1755
1756 Ok(EncodedDataset {
1757 headers: inferred.headers,
1758 values,
1759 schema: DataSchema {
1760 columns: schema_cols,
1761 },
1762 column_kinds,
1763 })
1764}
1765
1766pub fn encode_recordswith_inferred_schema(
1767 headers: Vec<String>,
1768 records: Vec<StringRecord>,
1769) -> Result<EncodedDataset, String> {
1770 if records.is_empty() {
1771 return Err(DataError::EmptyInput {
1772 reason: "table data cannot be empty".to_string(),
1773 }
1774 .into());
1775 }
1776 let schema_cols = headers
1782 .par_iter()
1783 .enumerate()
1784 .map(|(j, name)| infer_schema_column(name, &records, j).map_err(String::from))
1785 .collect::<Result<Vec<SchemaColumn>, String>>()?;
1786 let schema = DataSchema {
1787 columns: schema_cols,
1788 };
1789 encode_recordswith_schema(headers, records, &schema, UnseenCategoryPolicy::Error)
1790}
1791
1792pub fn encode_recordswith_schema(
1793 headers: Vec<String>,
1794 records: Vec<StringRecord>,
1795 schema: &DataSchema,
1796 unseen_policy: UnseenCategoryPolicy,
1797) -> Result<EncodedDataset, String> {
1798 let n = records.len();
1799 if n == 0 {
1800 return Err(DataError::EmptyInput {
1801 reason: "table data cannot be empty".to_string(),
1802 }
1803 .into());
1804 }
1805 let p = headers.len();
1806 if p == 0 {
1807 return Err(DataError::EmptyInput {
1808 reason: "table data must have at least one header column".to_string(),
1809 }
1810 .into());
1811 }
1812 for (i, rec) in records.iter().enumerate() {
1819 if rec.len() != p {
1820 return Err(DataError::SchemaMismatch {
1821 reason: format!(
1822 "row width mismatch at row {}: got {} fields, expected {} (one per header)",
1823 i + 1,
1824 rec.len(),
1825 p
1826 ),
1827 }
1828 .into());
1829 }
1830 }
1831 let schema_byname: HashMap<&str, &SchemaColumn> = schema
1832 .columns
1833 .iter()
1834 .map(|c| (c.name.as_str(), c))
1835 .collect();
1836
1837 let encoded_columns = headers
1843 .par_iter()
1844 .enumerate()
1845 .map(|(j, name)| {
1846 let inferred_for_extra;
1847 let col_schema = if let Some(s) = schema_byname.get(name.as_str()) {
1848 *s
1849 } else {
1850 inferred_for_extra =
1851 infer_schema_column(name, &records, j).map_err(String::from)?;
1852 &inferred_for_extra
1853 };
1854 let column = encode_one_column(name, &records, j, col_schema, &unseen_policy)?;
1855 Ok::<(ColumnKindTag, Vec<f64>), String>((col_schema.kind, column))
1856 })
1857 .collect::<Result<Vec<(ColumnKindTag, Vec<f64>)>, String>>()?;
1858
1859 let mut column_kinds = Vec::<ColumnKindTag>::with_capacity(p);
1860 let mut values = Array2::<f64>::zeros((n, p));
1861 for (j, (kind, column)) in encoded_columns.into_iter().enumerate() {
1862 column_kinds.push(kind);
1863 values
1864 .column_mut(j)
1865 .assign(&ndarray::ArrayView1::from(&column));
1866 }
1867
1868 Ok(EncodedDataset {
1869 headers,
1870 values,
1871 schema: schema.clone(),
1872 column_kinds,
1873 })
1874}
1875
1876fn encode_one_column(
1883 name: &str,
1884 records: &[StringRecord],
1885 j: usize,
1886 col_schema: &SchemaColumn,
1887 unseen_policy: &UnseenCategoryPolicy,
1888) -> Result<Vec<f64>, String> {
1889 let level_map = if matches!(col_schema.kind, ColumnKindTag::Categorical) {
1890 Some(
1891 col_schema
1892 .levels
1893 .iter()
1894 .enumerate()
1895 .map(|(idx, v)| (v.as_str(), idx as f64))
1896 .collect::<HashMap<_, _>>(),
1897 )
1898 } else {
1899 None
1900 };
1901
1902 let mut column = Vec::<f64>::with_capacity(records.len());
1903 for (i, rec) in records.iter().enumerate() {
1904 let raw = rec
1905 .get(j)
1906 .ok_or_else(|| {
1907 String::from(DataError::SchemaMismatch {
1908 reason: format!("missing field at row {}, col {}", i + 1, j + 1),
1909 })
1910 })?
1911 .trim();
1912 if raw.is_empty() {
1913 return Err(DataError::EmptyInput {
1914 reason: format!("empty field at row {}, column '{}'", i + 1, name),
1915 }
1916 .into());
1917 }
1918 let val = match col_schema.kind {
1919 ColumnKindTag::Continuous => raw.parse::<f64>().map_err(|err| {
1920 String::from(DataError::SchemaMismatch {
1921 reason: format!(
1922 "column '{}' is continuous in schema but row {} has non-numeric value '{}': {}",
1923 name,
1924 i + 1,
1925 raw,
1926 err
1927 ),
1928 })
1929 })?,
1930 ColumnKindTag::Binary => {
1931 let v = raw.parse::<f64>().map_err(|err| {
1932 String::from(DataError::SchemaMismatch {
1933 reason: format!(
1934 "column '{}' is binary in schema but row {} has non-numeric value '{}': {}",
1935 name,
1936 i + 1,
1937 raw,
1938 err
1939 ),
1940 })
1941 })?;
1942 if (v - 0.0).abs() >= 1e-12 && (v - 1.0).abs() >= 1e-12 {
1943 return Err(DataError::SchemaMismatch {
1944 reason: format!(
1945 "column '{}' is binary in schema but row {} has value {}; expected 0 or 1",
1946 name,
1947 i + 1,
1948 v
1949 ),
1950 }
1951 .into());
1952 }
1953 v
1954 }
1955 ColumnKindTag::Categorical => {
1956 let map = level_map.as_ref().ok_or_else(|| {
1957 String::from(DataError::EncodingFailure {
1958 reason: "internal categorical schema map missing".to_string(),
1959 })
1960 })?;
1961 match map.get(raw) {
1962 Some(v) => *v,
1963 None => unseen_policy
1964 .unseen_code_for(name, col_schema.levels.len())
1965 .ok_or_else(|| {
1966 String::from(DataError::SchemaMismatch {
1967 reason: format!(
1968 "unseen level '{}' in categorical column '{}' at row {}; allowed levels: {}",
1969 raw,
1970 name,
1971 i + 1,
1972 col_schema.levels.join(",")
1973 ),
1974 })
1975 })?,
1976 }
1977 }
1978 };
1979 if !val.is_finite() {
1980 return Err(DataError::InvalidValue {
1981 reason: format!("non-finite value at row {}, column '{}'", i + 1, name),
1982 }
1983 .into());
1984 }
1985 column.push(val);
1986 }
1987 Ok(column)
1988}
1989
1990fn infer_schema_column(
1991 name: &str,
1992 records: &[StringRecord],
1993 col_idx: usize,
1994) -> Result<SchemaColumn, DataError> {
1995 let mut all_numeric = true;
1996 let mut all_binary = true;
1997 let mut levels = Vec::<String>::new();
1998 let mut level_index = HashMap::<String, usize>::new();
1999 for (i, rec) in records.iter().enumerate() {
2000 let raw = rec
2001 .get(col_idx)
2002 .ok_or_else(|| DataError::SchemaMismatch {
2003 reason: format!("missing field at row {}, col {}", i + 1, col_idx + 1),
2004 })?
2005 .trim();
2006 if raw.is_empty() {
2007 return Err(DataError::EmptyInput {
2008 reason: format!("empty field at row {}, column '{}'", i + 1, name),
2009 });
2010 }
2011 if let Ok(v) = raw.parse::<f64>() {
2012 if !v.is_finite() {
2013 return Err(DataError::InvalidValue {
2014 reason: format!("non-finite value at row {}, column '{}'", i + 1, name),
2015 });
2016 }
2017 if (v - 0.0).abs() >= 1e-12 && (v - 1.0).abs() >= 1e-12 {
2018 all_binary = false;
2019 }
2020 } else {
2021 all_numeric = false;
2022 all_binary = false;
2023 level_index.entry(raw.to_string()).or_insert_with(|| {
2024 let idx = levels.len();
2025 levels.push(raw.to_string());
2026 idx
2027 });
2028 }
2029 }
2030 let kind = if all_numeric {
2031 if all_binary {
2032 ColumnKindTag::Binary
2033 } else {
2034 ColumnKindTag::Continuous
2035 }
2036 } else {
2037 ColumnKindTag::Categorical
2038 };
2039 if matches!(kind, ColumnKindTag::Categorical) {
2043 levels.sort();
2044 }
2045 Ok(SchemaColumn {
2046 name: name.to_string(),
2047 kind,
2048 levels: if matches!(kind, ColumnKindTag::Categorical) {
2049 levels
2050 } else {
2051 Vec::new()
2052 },
2053 })
2054}
2055
2056pub fn infer_and_encode_column_major(
2069 name: &str,
2070 column: &[&str],
2071 col_index: usize,
2072) -> Result<(SchemaColumn, Vec<f64>), String> {
2073 if column.is_empty() {
2074 return Err(DataError::EmptyInput {
2075 reason: "table data cannot be empty".to_string(),
2076 }
2077 .into());
2078 }
2079 let force_categorical = column.iter().any(|c| strip_categorical_sentinel(c).1);
2084 let mut all_numeric = !force_categorical;
2085 let mut all_binary = !force_categorical;
2086 let mut levels = Vec::<String>::new();
2087 let mut level_index = HashMap::<String, usize>::new();
2088 let mut trimmed = Vec::<&str>::with_capacity(column.len());
2089 let mut parsed = Vec::<Option<f64>>::with_capacity(column.len());
2096 for (i, raw_field) in column.iter().enumerate() {
2097 let (raw, _) = strip_categorical_sentinel(raw_field);
2100 let raw = raw.trim();
2101 if raw.is_empty() {
2102 return Err(DataError::EmptyInput {
2103 reason: format!("empty field at row {}, column '{}'", i + 1, name),
2104 }
2105 .into());
2106 }
2107 if !force_categorical {
2110 if let Ok(v) = raw.parse::<f64>() {
2111 if !v.is_finite() {
2112 return Err(DataError::InvalidValue {
2113 reason: format!("non-finite value at row {}, column '{}'", i + 1, name),
2114 }
2115 .into());
2116 }
2117 if (v - 0.0).abs() >= 1e-12 && (v - 1.0).abs() >= 1e-12 {
2118 all_binary = false;
2119 }
2120 parsed.push(Some(v));
2121 trimmed.push(raw);
2122 continue;
2123 }
2124 all_numeric = false;
2125 all_binary = false;
2126 }
2127 level_index.entry(raw.to_string()).or_insert_with(|| {
2128 let idx = levels.len();
2129 levels.push(raw.to_string());
2130 idx
2131 });
2132 parsed.push(None);
2133 trimmed.push(raw);
2134 }
2135 let kind = if all_numeric {
2136 if all_binary {
2137 ColumnKindTag::Binary
2138 } else {
2139 ColumnKindTag::Continuous
2140 }
2141 } else {
2142 ColumnKindTag::Categorical
2143 };
2144 if matches!(kind, ColumnKindTag::Categorical) {
2157 levels.sort();
2158 }
2159 let schema = SchemaColumn {
2160 name: name.to_string(),
2161 kind,
2162 levels: if matches!(kind, ColumnKindTag::Categorical) {
2163 levels
2164 } else {
2165 Vec::new()
2166 },
2167 };
2168
2169 let level_map = if matches!(kind, ColumnKindTag::Categorical) {
2170 Some(
2171 schema
2172 .levels
2173 .iter()
2174 .enumerate()
2175 .map(|(idx, v)| (v.as_str(), idx as f64))
2176 .collect::<HashMap<_, _>>(),
2177 )
2178 } else {
2179 None
2180 };
2181
2182 let mut values = Vec::<f64>::with_capacity(trimmed.len());
2183 for (i, raw) in trimmed.iter().enumerate() {
2184 let raw = *raw;
2185 let val = match kind {
2186 ColumnKindTag::Continuous => parsed[i].ok_or_else(|| {
2190 String::from(DataError::EncodingFailure {
2191 reason: format!(
2192 "internal: continuous column '{}' lost its parsed value at row {} (col {})",
2193 name,
2194 i + 1,
2195 col_index
2196 ),
2197 })
2198 })?,
2199 ColumnKindTag::Binary => {
2200 let v = parsed[i].ok_or_else(|| {
2201 String::from(DataError::EncodingFailure {
2202 reason: format!(
2203 "internal: binary column '{}' lost its parsed value at row {} (col {})",
2204 name,
2205 i + 1,
2206 col_index
2207 ),
2208 })
2209 })?;
2210 if (v - 0.0).abs() >= 1e-12 && (v - 1.0).abs() >= 1e-12 {
2211 return Err(DataError::SchemaMismatch {
2212 reason: format!(
2213 "column '{}' is binary in schema but row {} has value {}; expected 0 or 1",
2214 name,
2215 i + 1,
2216 v
2217 ),
2218 }
2219 .into());
2220 }
2221 v
2222 }
2223 ColumnKindTag::Categorical => {
2224 let map = level_map.as_ref().ok_or_else(|| {
2225 String::from(DataError::EncodingFailure {
2226 reason: "internal categorical schema map missing".to_string(),
2227 })
2228 })?;
2229 *map.get(raw).ok_or_else(|| {
2230 String::from(DataError::EncodingFailure {
2231 reason: format!(
2232 "internal: level '{}' missing from freshly built map for column '{}' (col {})",
2233 raw, name, col_index
2234 ),
2235 })
2236 })?
2237 }
2238 };
2239 if !val.is_finite() {
2240 return Err(DataError::InvalidValue {
2241 reason: format!("non-finite value at row {}, column '{}'", i + 1, name),
2242 }
2243 .into());
2244 }
2245 values.push(val);
2246 }
2247 Ok((schema, values))
2248}
2249
// Unit tests for the record/column encoders, the parquet loaders, the categorical-cell
// sentinel helper, and `EncodedDataset` accessors defined in this file.
#[cfg(test)]
mod tests {
    use super::*;

    // Both record encoders must reject an empty record list with the same
    // `EmptyInput` message.
    #[test]
    fn encode_records_rejects_empty_input() {
        let headers = vec!["x".to_string()];
        let schema = DataSchema {
            columns: vec![SchemaColumn {
                name: "x".to_string(),
                kind: ColumnKindTag::Continuous,
                levels: Vec::new(),
            }],
        };

        let err = encode_recordswith_inferred_schema(headers.clone(), Vec::new())
            .expect_err("empty inferred records should error");
        assert_eq!(err, "table data cannot be empty");

        let err =
            encode_recordswith_schema(headers, Vec::new(), &schema, UnseenCategoryPolicy::Error)
                .expect_err("empty schema-guided records should error");
        assert_eq!(err, "table data cannot be empty");
    }

    // The column-major encoder must agree with the record-driven inferred encoder on
    // kind, levels, and every encoded value for continuous, binary, and categorical
    // columns.
    #[test]
    fn column_major_matches_record_driven_inferred_encode() {
        let headers = vec!["cont".to_string(), "bin".to_string(), "cat".to_string()];
        let raw_rows = vec![
            vec!["1.5", "0", "a"],
            vec!["2.0", "1", "b"],
            vec!["-3.25", "1", "a"],
            vec!["0.0", "0", "c"],
        ];
        let records: Vec<StringRecord> = raw_rows
            .iter()
            .map(|r| StringRecord::from(r.clone()))
            .collect();
        let record_ds = encode_recordswith_inferred_schema(headers.clone(), records)
            .expect("record-driven encode");

        for (j, name) in headers.iter().enumerate() {
            // Transpose row `j` values into a column-major slice.
            let column: Vec<&str> = raw_rows.iter().map(|r| r[j]).collect();
            let (schema_col, values) =
                infer_and_encode_column_major(name, &column, j + 1).expect("column-major encode");
            assert_eq!(schema_col.kind, record_ds.schema.columns[j].kind);
            assert_eq!(schema_col.levels, record_ds.schema.columns[j].levels);
            for (i, v) in values.iter().enumerate() {
                assert_eq!(*v, record_ds.values[[i, j]], "row {i} col {name}");
            }
        }
    }

    // A column listed in the unseen-category policy ("g") accepts an unseen level;
    // here "new-group" is expected to encode as 2.0 (one past the two known levels),
    // while the known level "low" keeps code 0.
    #[test]
    fn encode_records_can_encode_unseen_named_categorical_column() {
        let schema = DataSchema {
            columns: vec![
                SchemaColumn {
                    name: "g".to_string(),
                    kind: ColumnKindTag::Categorical,
                    levels: vec!["a".to_string(), "b".to_string()],
                },
                SchemaColumn {
                    name: "x".to_string(),
                    kind: ColumnKindTag::Categorical,
                    levels: vec!["low".to_string(), "high".to_string()],
                },
            ],
        };
        let headers = vec!["g".to_string(), "x".to_string()];
        let records = vec![StringRecord::from(vec!["new-group", "low"])];
        let policy =
            UnseenCategoryPolicy::encode_unknown_for_columns(HashSet::from(["g".to_string()]));

        let ds =
            encode_recordswith_schema(headers, records, &schema, policy).expect("encoded dataset");

        assert_eq!(ds.values[[0, 0]], 2.0);
        assert_eq!(ds.values[[0, 1]], 0.0);
    }

    // A dictionary array with numeric values (Dictionary(Int8, Int64)) must be treated
    // as numeric at every layer: the string classifier, the batch decoder, and an
    // end-to-end parquet round trip loaded both by inference and against a
    // Continuous schema. A Utf8-valued dictionary must still count as a string column.
    #[test]
    fn numeric_valued_dictionary_column_classifies_and_decodes_as_numeric() {
        use arrow::array::{Array, ArrayRef, DictionaryArray, Int8Array, Int64Array};
        use arrow::datatypes::{DataType, Int8Type};
        use std::sync::Arc;

        // Keys index into the two dictionary values: 0 -> 5, 1 -> 7.
        let keys = Int8Array::from(vec![0i8, 1, 0, 1, 0]);
        let dict_values: ArrayRef = Arc::new(Int64Array::from(vec![5i64, 7]));
        let dict: DictionaryArray<Int8Type> = DictionaryArray::new(keys, dict_values);

        assert!(matches!(dict.data_type(), DataType::Dictionary(_, _)));
        assert!(
            !parquet_field_is_string(dict.data_type()),
            "Dictionary(Int8, Int64) must not be treated as a string column"
        );

        let str_dict: DictionaryArray<Int8Type> = vec!["a", "b", "a"].into_iter().collect();
        assert!(
            parquet_field_is_string(str_dict.data_type()),
            "Dictionary(Int8, Utf8) must remain a string column"
        );

        // Decoder level: keys are resolved to their dictionary values.
        let decoded = decode_parquet_batch_column(&dict, dict.len(), 0, "x", false)
            .expect("numeric dictionary column should decode as numeric");
        match decoded {
            ParquetBatchColumn::Numeric(values) => {
                assert_eq!(values, vec![5.0, 7.0, 5.0, 7.0, 5.0]);
            }
            ParquetBatchColumn::Strings(_) => {
                panic!("numeric dictionary column was decoded as strings");
            }
        }

        // End-to-end: write the array to a temporary parquet file and load it back.
        use arrow::datatypes::{Field, Schema};
        use arrow::record_batch::RecordBatch;
        use parquet::arrow::ArrowWriter;

        let arrow_schema = Arc::new(Schema::new(vec![Field::new(
            "x",
            dict.data_type().clone(),
            false,
        )]));
        let batch = RecordBatch::try_new(arrow_schema.clone(), vec![Arc::new(dict.clone())])
            .expect("record batch with a dictionary numeric column");

        let dir = tempfile::tempdir().expect("tempdir");
        let path = dir.path().join("dict_numeric.parquet");
        {
            let file = std::fs::File::create(&path).expect("create parquet");
            let mut writer =
                ArrowWriter::try_new(file, arrow_schema, None).expect("arrow parquet writer");
            writer.write(&batch).expect("write batch");
            writer.close().expect("close writer");
        }

        // Inference path: values 5/7 are not 0/1, so the column is Continuous.
        let inferred =
            load_parquet_inferred(&path, &[], &HashSet::new()).expect("inferred parquet load");
        assert_eq!(inferred.column_kinds, vec![ColumnKindTag::Continuous]);
        assert_eq!(
            inferred.values.column(0).to_vec(),
            vec![5.0, 7.0, 5.0, 7.0, 5.0]
        );

        // Schema-guided path with a Continuous schema for the same column.
        let schema = DataSchema {
            columns: vec![SchemaColumn {
                name: "x".to_string(),
                kind: ColumnKindTag::Continuous,
                levels: Vec::new(),
            }],
        };
        let schema_loaded =
            load_parquet_with_schema(&path, &schema, UnseenCategoryPolicy::Error, &[])
                .expect("dictionary-encoded numeric parquet must load against a Continuous schema");
        assert_eq!(schema_loaded.column_kinds, vec![ColumnKindTag::Continuous]);
        assert_eq!(
            schema_loaded.values.column(0).to_vec(),
            vec![5.0, 7.0, 5.0, 7.0, 5.0]
        );
    }

    // The unseen policy only covers the columns it names: an unseen level in "x"
    // (not listed) must still be rejected with the strict error message.
    #[test]
    fn encode_records_keeps_unlisted_categorical_columns_strict() {
        let schema = DataSchema {
            columns: vec![
                SchemaColumn {
                    name: "g".to_string(),
                    kind: ColumnKindTag::Categorical,
                    levels: vec!["a".to_string(), "b".to_string()],
                },
                SchemaColumn {
                    name: "x".to_string(),
                    kind: ColumnKindTag::Categorical,
                    levels: vec!["low".to_string(), "high".to_string()],
                },
            ],
        };
        let headers = vec!["g".to_string(), "x".to_string()];
        let records = vec![StringRecord::from(vec!["a", "new-level"])];
        let policy =
            UnseenCategoryPolicy::encode_unknown_for_columns(HashSet::from(["g".to_string()]));

        let err = encode_recordswith_schema(headers, records, &schema, policy)
            .expect_err("ordinary categorical column should stay strict");

        assert!(err.contains("unseen level 'new-level' in categorical column 'x'"));
    }

    // `strip_categorical_sentinel` returns (text after the sentinel, true) when the
    // sentinel prefixes the cell...
    #[test]
    fn sentinel_strip_present_returns_rest_and_true() {
        let marked = format!("{}{}", CATEGORICAL_CELL_SENTINEL, "hello");
        let (rest, found) = strip_categorical_sentinel(&marked);
        assert_eq!(rest, "hello");
        assert!(found);
    }

    // ...and (unchanged input, false) when it does not.
    #[test]
    fn sentinel_strip_absent_returns_original_and_false() {
        let (rest, found) = strip_categorical_sentinel("hello");
        assert_eq!(rest, "hello");
        assert!(!found);
    }

    #[test]
    fn sentinel_strip_empty_string_returns_empty_and_false() {
        let (rest, found) = strip_categorical_sentinel("");
        assert_eq!(rest, "");
        assert!(!found);
    }

    // A cell that is only the sentinel yields an empty remainder but is still flagged.
    #[test]
    fn sentinel_strip_only_sentinel_returns_empty_and_true() {
        let marked = CATEGORICAL_CELL_SENTINEL.to_string();
        let (rest, found) = strip_categorical_sentinel(&marked);
        assert_eq!(rest, "");
        assert!(found);
    }

    // `feature_ranges` yields one (min, max) pair per column.
    #[test]
    fn feature_ranges_two_columns() {
        let values = ndarray::arr2(&[[1.0_f64, 10.0], [3.0, 20.0], [2.0, 15.0]]);
        let ds = EncodedDataset {
            headers: vec!["a".to_string(), "b".to_string()],
            values,
            schema: DataSchema { columns: vec![] },
            column_kinds: vec![ColumnKindTag::Continuous, ColumnKindTag::Continuous],
        };
        let ranges = ds.feature_ranges();
        assert_eq!(ranges.len(), 2);
        assert_eq!(ranges[0], (1.0, 3.0));
        assert_eq!(ranges[1], (10.0, 20.0));
    }

    #[test]
    fn feature_ranges_single_row_min_equals_max() {
        let values = ndarray::arr2(&[[5.0_f64, -3.0]]);
        let ds = EncodedDataset {
            headers: vec!["x".to_string(), "y".to_string()],
            values,
            schema: DataSchema { columns: vec![] },
            column_kinds: vec![ColumnKindTag::Continuous, ColumnKindTag::Continuous],
        };
        let ranges = ds.feature_ranges();
        assert_eq!(ranges[0], (5.0, 5.0));
        assert_eq!(ranges[1], (-3.0, -3.0));
    }

    // A column with no finite values is expected to report (0.0, 0.0).
    #[test]
    fn feature_ranges_all_nan_defaults_to_zero() {
        let values = ndarray::arr2(&[[f64::NAN], [f64::NAN]]);
        let ds = EncodedDataset {
            headers: vec!["x".to_string()],
            values,
            schema: DataSchema { columns: vec![] },
            column_kinds: vec![ColumnKindTag::Continuous],
        };
        let ranges = ds.feature_ranges();
        assert_eq!(ranges[0], (0.0, 0.0));
    }

    // `column_map` maps each header name to its column index.
    #[test]
    fn column_map_indexes_by_name() {
        let values = ndarray::arr2(&[[0.0_f64, 1.0], [2.0, 3.0]]);
        let ds = EncodedDataset {
            headers: vec!["alpha".to_string(), "beta".to_string()],
            values,
            schema: DataSchema { columns: vec![] },
            column_kinds: vec![ColumnKindTag::Continuous, ColumnKindTag::Continuous],
        };
        let map = ds.column_map();
        assert_eq!(map["alpha"], 0);
        assert_eq!(map["beta"], 1);
        assert_eq!(map.len(), 2);
    }
}