use crate::data::{
    Dataset, FeatureMatrix, DataError, FeatureMatrixError,
    preprocessing::{CategoricalEncoder, detect_categorical_columns},
};
use serde::{Deserialize, Serialize};
use std::collections::BTreeMap;
use std::fs::File;
use std::io::{BufReader, BufWriter};
use std::path::Path;
use std::str::FromStr;
use thiserror::Error;
use ndarray::Array2;
use csv::ReaderBuilder;
use parquet::arrow::arrow_reader::ParquetRecordBatchReaderBuilder;
use arrow::array::{Array, Float64Array, Int64Array, Float32Array, Int32Array};
use arrow::datatypes::DataType;

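/// Errors produced while loading, saving, or splitting datasets.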
#[derive(Error, Debug)]
pub enum DataLoaderError {
    #[error("IO error: {0}")]
    IoError(#[from] std::io::Error),

    #[error("CSV error: {0}")]
    CsvError(#[from] csv::Error),

    #[error("JSON error: {0}")]
    JsonError(#[from] serde_json::Error),

    #[error("Data error: {0}")]
    DataError(#[from] DataError),

    #[error("Feature matrix error: {0}")]
    FeatureMatrixError(#[from] FeatureMatrixError),

    #[error("Unsupported format: {format}")]
    UnsupportedFormat { format: String },

    #[error("Parse error: {field} = {value} ({message})")]
    ParseError {
        field: String,
        value: String,
        message: String,
    },

    #[error("Validation error: {0}")]
    ValidationError(String),

    #[error("Missing column: {0}")]
    MissingColumn(String),

    #[error("Invalid file: {0}")]
    InvalidFile(String),

    #[error("Preprocessing error: {0}")]
    PreprocessingError(String),

    #[error("Parquet error: {0}")]
    ParquetError(#[from] parquet::errors::ParquetError),
}

pub type DataLoaderResult<T> = std::result::Result<T, DataLoaderError>;

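/// Supported on-disk dataset formats. Parsed case-insensitively from strings
/// such as `"csv"`, `"json"`, and `"parquet"`.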
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
pub enum DataFormat {
    Csv,
    Json,
    Parquet,
}

impl std::str::FromStr for DataFormat {
    type Err = DataLoaderError;

    fn from_str(s: &str) -> Result<Self, Self::Err> {
        match s.to_lowercase().as_str() {
            "csv" => Ok(DataFormat::Csv),
            "json" => Ok(DataFormat::Json),
            "parquet" => Ok(DataFormat::Parquet),
            _ => Err(DataLoaderError::UnsupportedFormat {
                format: s.to_string(),
            }),
        }
    }
}

impl std::fmt::Display for DataFormat {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            DataFormat::Csv => write!(f, "csv"),
            DataFormat::Json => write!(f, "json"),
            DataFormat::Parquet => write!(f, "parquet"),
        }
    }
}

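/// Options controlling how raw files are parsed into a [`Dataset`].
///
/// A minimal builder sketch (the column name is illustrative):
///
/// ```ignore
/// let opts = LoadOptions::new()
///     .with_delimiter(b';')
///     .with_target_column("price")
///     .with_max_rows(10_000);
/// ```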
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LoadOptions {
    /// Field delimiter for CSV input (defaults to `b','`).
    pub delimiter: u8,
    /// Whether the first CSV row is a header row.
    pub has_headers: bool,
    /// Name of the target column; when `None`, a format-specific default is used.
    pub target_column: Option<String>,
    /// Explicit feature columns; when `None`, all non-target columns are used.
    pub feature_columns: Option<Vec<String>>,
    /// Number of leading data rows to skip before reading.
    pub skip_rows: usize,
    /// Maximum number of rows to load; `None` loads everything.
    pub max_rows: Option<usize>,
    /// Sentinel string treated as a missing value (defaults to `"NaN"`).
    pub missing_value: Option<String>,
    /// Declared text encoding; the loaders in this module currently read UTF-8 only.
    pub encoding: String,
}

impl Default for LoadOptions {
    fn default() -> Self {
        Self {
            delimiter: b',',
            has_headers: true,
            target_column: None,
            feature_columns: None,
            skip_rows: 0,
            max_rows: None,
            missing_value: Some("NaN".to_string()),
            encoding: "utf-8".to_string(),
        }
    }
}

impl LoadOptions {
    pub fn new() -> Self {
        Self::default()
    }

    pub fn with_delimiter(mut self, delimiter: u8) -> Self {
        self.delimiter = delimiter;
        self
    }

    pub fn with_headers(mut self, has_headers: bool) -> Self {
        self.has_headers = has_headers;
        self
    }

    pub fn with_target_column(mut self, target_column: &str) -> Self {
        self.target_column = Some(target_column.to_string());
        self
    }

    pub fn with_feature_columns(mut self, feature_columns: &[&str]) -> Self {
        self.feature_columns = Some(feature_columns.iter().map(|s| s.to_string()).collect());
        self
    }

    pub fn with_skip_rows(mut self, skip_rows: usize) -> Self {
        self.skip_rows = skip_rows;
        self
    }

    pub fn with_max_rows(mut self, max_rows: usize) -> Self {
        self.max_rows = Some(max_rows);
        self
    }

    pub fn with_missing_value(mut self, missing_value: &str) -> Self {
        self.missing_value = Some(missing_value.to_string());
        self
    }
}

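/// High-level loader bundling [`LoadOptions`] with convenience methods for
/// reading, saving, and splitting datasets.
///
/// A minimal sketch (the path and column name are illustrative):
///
/// ```ignore
/// let loader = DataLoader::with_options(LoadOptions::new().with_target_column("price"));
/// let dataset = loader.load_file("data/train.csv")?;
/// let (train, test) = loader.split_data(&dataset, 0.2, true, Some(42))?;
/// ```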
pub struct DataLoader {
    options: LoadOptions,
}

impl DataLoader {
    /// Creates a loader with default options. Currently infallible.
    pub fn new() -> DataLoaderResult<Self> {
        Ok(Self {
            options: LoadOptions::default(),
        })
    }

    /// Creates a loader with the given options.
    pub fn with_options(options: LoadOptions) -> Self {
        Self { options }
    }

    /// Loads a dataset, inferring the format from the file extension.
    pub fn load_file<P: AsRef<Path>>(&self, path: P) -> DataLoaderResult<Dataset> {
        load_from_file(path, &self.options)
    }

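    /// Loads a CSV file, label-encoding categorical feature columns before
    /// building the feature matrix.
    ///
    /// `categorical_columns` holds indices into the *original* CSV columns
    /// (the target column is allowed there and ignored); when `None`,
    /// categorical columns are auto-detected by `detect_categorical_columns`
    /// using `categorical_threshold`.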
    pub fn load_csv_with_categorical<P: AsRef<Path>>(
        &self,
        path: P,
        target_column: &str,
        categorical_columns: Option<&[usize]>,
        categorical_threshold: f64,
    ) -> DataLoaderResult<Dataset> {
        let mut reader = ReaderBuilder::new()
            .delimiter(self.options.delimiter)
            .has_headers(self.options.has_headers)
            .from_path(&path)?;

        let headers: Vec<String> = reader
            .headers()?
            .iter()
            .map(|s| s.to_string())
            .collect();

        let target_col_idx = headers
            .iter()
            .position(|h| h == target_column)
            .ok_or_else(|| {
                DataLoaderError::MissingColumn(format!(
                    "Target column '{}' not found in CSV headers. Available columns: {:?}",
                    target_column, headers
                ))
            })?;

        let mut string_data = Vec::new();
        for (row_idx, result) in reader.records().enumerate() {
            let record = result?;
            let row: Vec<String> = record.iter().map(|s| s.to_string()).collect();

            if row.len() != headers.len() {
                return Err(DataLoaderError::ValidationError(format!(
                    "Row {} has {} columns, expected {} columns. Row: {:?}",
                    row_idx,
                    row.len(),
                    headers.len(),
                    row
                )));
            }
            string_data.push(row);
        }

        if string_data.is_empty() {
            return Err(DataLoaderError::ValidationError(
                "CSV file contains no data rows".to_string(),
            ));
        }

        let n_columns = headers.len();
        let feature_col_indices: Vec<usize> = (0..n_columns)
            .filter(|&i| i != target_col_idx)
            .collect();

        let feature_data: Vec<Vec<String>> = string_data
            .iter()
            .map(|row| {
                feature_col_indices
                    .iter()
                    .map(|&i| row[i].clone())
                    .collect()
            })
            .collect();

        // Re-map categorical column indices from original CSV positions to
        // positions within the feature matrix (the target column is removed).
        let cat_cols_adjusted = if let Some(cols) = categorical_columns {
            cols.iter()
                .filter_map(|&col_idx| {
                    if col_idx == target_col_idx {
                        None
                    } else if col_idx < target_col_idx {
                        Some(col_idx)
                    } else {
                        Some(col_idx - 1)
                    }
                })
                .collect::<Vec<_>>()
        } else {
            detect_categorical_columns(&feature_data, categorical_threshold)
        };

        let mut encoder = CategoricalEncoder::new();
        let encoded_features = encoder
            .fit_transform(&feature_data, &cat_cols_adjusted)
            .map_err(|e| DataLoaderError::PreprocessingError(e.to_string()))?;

        let targets: Vec<f64> = string_data
            .iter()
            .enumerate()
            .map(|(row_idx, row)| {
                row[target_col_idx]
                    .parse::<f64>()
                    .map_err(|_| DataLoaderError::ParseError {
                        field: target_column.to_string(),
                        value: row[target_col_idx].clone(),
                        message: format!("Failed to parse target column at row {} as f64", row_idx),
                    })
            })
            .collect::<Result<Vec<_>, _>>()?;

        for (i, &target) in targets.iter().enumerate() {
            if !target.is_finite() {
                return Err(DataLoaderError::ValidationError(format!(
                    "Non-finite target value at row {}: {}",
                    i, target
                )));
            }
        }

        let n_samples = encoded_features.len();
        let n_features = encoded_features.first().map_or(0, |row| row.len());

        let flat_data: Vec<f64> = encoded_features.into_iter().flatten().collect();
        let array = Array2::from_shape_vec((n_samples, n_features), flat_data).map_err(|e| {
            DataLoaderError::ValidationError(format!("Failed to create feature matrix: {}", e))
        })?;

        let feature_names: Vec<String> = feature_col_indices
            .iter()
            .map(|&i| headers[i].clone())
            .collect();

        let features = FeatureMatrix::with_feature_names(array, feature_names)?;

        Dataset::new(features, targets).map_err(DataLoaderError::DataError)
    }

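    /// Loads a CSV file using the configured options. When a non-empty target
    /// column is configured, this delegates to
    /// [`Self::load_csv_with_categorical`] with an explicitly empty
    /// categorical-column list (purely numeric parsing).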
    pub fn load_csv<P: AsRef<Path>>(&self, path: P) -> DataLoaderResult<Dataset> {
        match self.options.target_column.as_deref() {
            None | Some("") => load_csv(path, &self.options),
            Some(target_col) => self.load_csv_with_categorical(path, target_col, Some(&[]), 0.0),
        }
    }

    /// Loads a JSON dataset using the configured options.
    pub fn load_json<P: AsRef<Path>>(&self, path: P) -> DataLoaderResult<Dataset> {
        load_json(path, &self.options)
    }

    /// Loads a Parquet dataset using the configured options.
    pub fn load_parquet<P: AsRef<Path>>(&self, path: P) -> DataLoaderResult<Dataset> {
        load_parquet(path, &self.options)
    }

    /// Saves a dataset in the given format.
    pub fn save_dataset<P: AsRef<Path>>(
        &self,
        dataset: &Dataset,
        path: P,
        format: DataFormat,
    ) -> DataLoaderResult<()> {
        save_dataset(dataset, path, format)
    }

    /// Splits a dataset into train and test sets; see [`split_data`].
    pub fn split_data(
        &self,
        dataset: &Dataset,
        test_size: f64,
        shuffle: bool,
        random_state: Option<u64>,
    ) -> DataLoaderResult<(Dataset, Dataset)> {
        split_data(dataset, test_size, shuffle, random_state)
    }

    /// Produces k-fold cross-validation splits; see [`cross_validation_split`].
    pub fn cross_validation_splits(
        &self,
        dataset: &Dataset,
        n_splits: usize,
        shuffle: bool,
        random_state: Option<u64>,
    ) -> DataLoaderResult<Vec<(Dataset, Dataset)>> {
        cross_validation_split(dataset, n_splits, shuffle, random_state)
    }
}

impl Default for DataLoader {
    fn default() -> Self {
        Self::new().expect("DataLoader::new is infallible")
    }
}

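/// Loads a dataset from `path`, choosing the parser from the file extension
/// (`csv`, `json`, or `parquet`).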
pub fn load_from_file<P: AsRef<Path>>(
    path: P,
    options: &LoadOptions,
) -> DataLoaderResult<Dataset> {
    let path = path.as_ref();
    let extension = path
        .extension()
        .and_then(|ext| ext.to_str())
        .ok_or_else(|| DataLoaderError::InvalidFile("Cannot determine file extension".to_string()))?;

    let format = DataFormat::from_str(extension)?;

    match format {
        DataFormat::Csv => load_csv(path, options),
        DataFormat::Json => load_json(path, options),
        DataFormat::Parquet => load_parquet(path, options),
    }
}

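/// Loads a CSV file into a [`Dataset`], parsing every selected column as
/// `f64`. Unparseable or sentinel feature values become `NaN`; rows whose
/// target is missing or unparseable are dropped.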
pub fn load_csv<P: AsRef<Path>>(
    path: P,
    options: &LoadOptions,
) -> DataLoaderResult<Dataset> {
    let file = File::open(path)?;
    let mut reader = csv::ReaderBuilder::new()
        .delimiter(options.delimiter)
        .has_headers(options.has_headers)
        .from_reader(BufReader::new(file));

    // Skip leading data rows (the header row, if any, is consumed
    // transparently by the csv crate on the first read).
    for _ in 0..options.skip_rows {
        if reader.records().next().is_none() {
            break;
        }
    }

    // When the file has no header row, the first record is used to infer the
    // column count; it is buffered so it still contributes a data row.
    let mut pending: Option<csv::StringRecord> = None;
    let headers: Vec<String> = if options.has_headers {
        reader.headers()?.iter().map(|s| s.to_string()).collect()
    } else {
        match reader.records().next() {
            Some(record) => {
                let record = record?;
                let names = (0..record.len()).map(|i| format!("feature_{}", i)).collect();
                pending = Some(record);
                names
            }
            None => {
                return Err(DataLoaderError::ValidationError("Empty CSV file".to_string()));
            }
        }
    };

    let target_index = if let Some(ref target_col) = options.target_column {
        headers
            .iter()
            .position(|h| h == target_col)
            .ok_or_else(|| DataLoaderError::MissingColumn(target_col.clone()))?
    } else {
        // Default to the last column as the target.
        headers.len() - 1
    };

    let feature_indices: Vec<usize> = if let Some(ref feature_cols) = options.feature_columns {
        feature_cols
            .iter()
            .map(|col| {
                headers
                    .iter()
                    .position(|h| h == col)
                    .ok_or_else(|| DataLoaderError::MissingColumn(col.clone()))
            })
            .collect::<Result<Vec<_>, _>>()?
    } else {
        (0..headers.len()).filter(|&i| i != target_index).collect()
    };

    // Missing-value sentinels and unparseable cells both map to NaN.
    let parse_cell = |value_str: &str| -> f64 {
        if let Some(ref missing) = options.missing_value {
            if value_str == missing {
                return f64::NAN;
            }
        }
        value_str.parse().unwrap_or(f64::NAN)
    };

    let mut data = Vec::new();
    let mut targets = Vec::new();
    let mut record_count = 0;

    for result in pending.into_iter().map(Ok).chain(reader.records()) {
        if let Some(max_rows) = options.max_rows {
            if record_count >= max_rows {
                break;
            }
        }

        let record = result?;
        let mut row = Vec::with_capacity(feature_indices.len());

        for &idx in &feature_indices {
            row.push(parse_cell(record.get(idx).unwrap_or("")));
        }

        let target = parse_cell(record.get(target_index).unwrap_or(""));

        // Keep only rows with a valid target; feature NaNs are preserved.
        if !target.is_nan() {
            data.push(row);
            targets.push(target);
            record_count += 1;
        }
    }

    if data.is_empty() {
        return Err(DataLoaderError::ValidationError("No valid data found".to_string()));
    }

    let n_samples = data.len();
    let n_features = feature_indices.len();
    let flat_data: Vec<f64> = data.into_iter().flatten().collect();
    let array = ndarray::Array2::from_shape_vec((n_samples, n_features), flat_data)
        .map_err(|e| DataLoaderError::ValidationError(e.to_string()))?;

    let feature_names: Vec<String> = feature_indices
        .iter()
        .map(|&idx| headers[idx].clone())
        .collect();

    let features = FeatureMatrix::with_feature_names(array, feature_names)?;
    Dataset::new(features, targets).map_err(DataLoaderError::DataError)
}

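/// Loads a dataset from a JSON file containing an array of flat objects,
/// e.g. `[{"x1": 1.0, "x2": 2.0, "target": 0.0}, ...]`. Non-numeric values
/// become `NaN`; rows without a numeric target are dropped. Without an
/// explicit `target_column`, the loader tries `target`, `label`, `y`, then
/// `class`, and otherwise falls back to the lexicographically last key.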
pub fn load_json<P: AsRef<Path>>(
    path: P,
    options: &LoadOptions,
) -> DataLoaderResult<Dataset> {
    let file = File::open(path)?;
    let reader = BufReader::new(file);

    #[derive(Deserialize)]
    struct JsonRecord {
        // BTreeMap keeps key order deterministic, so the feature-column order
        // and the fallback target choice do not vary from run to run.
        #[serde(flatten)]
        values: BTreeMap<String, serde_json::Value>,
    }

    let records: Vec<JsonRecord> = serde_json::from_reader(reader)?;

    if records.is_empty() {
        return Err(DataLoaderError::ValidationError("Empty JSON file".to_string()));
    }

    let target_column = if let Some(ref target_col) = options.target_column {
        target_col.clone()
    } else {
        let possible_targets = ["target", "label", "y", "class"];
        possible_targets
            .iter()
            .find(|&&col| records[0].values.contains_key(col))
            .map(|s| s.to_string())
            .unwrap_or_else(|| {
                // Fall back to the lexicographically last key.
                records[0].values.keys().last().unwrap().to_string()
            })
    };

    let feature_columns: Vec<String> = if let Some(ref feature_cols) = options.feature_columns {
        feature_cols.clone()
    } else {
        records[0]
            .values
            .keys()
            .filter(|&k| k != &target_column)
            .cloned()
            .collect()
    };

    let mut data = Vec::new();
    let mut targets = Vec::new();
    let mut record_count = 0;

    for record in records {
        if let Some(max_rows) = options.max_rows {
            if record_count >= max_rows {
                break;
            }
        }

        let mut row = Vec::with_capacity(feature_columns.len());
        for col in &feature_columns {
            let value = record
                .values
                .get(col)
                .and_then(|v| v.as_f64())
                .unwrap_or(f64::NAN);
            row.push(value);
        }

        let target = record
            .values
            .get(&target_column)
            .and_then(|v| v.as_f64())
            .unwrap_or(f64::NAN);

        if !target.is_nan() {
            data.push(row);
            targets.push(target);
            record_count += 1;
        }
    }

    if data.is_empty() {
        return Err(DataLoaderError::ValidationError("No valid data found".to_string()));
    }

    let n_samples = data.len();
    let n_features = feature_columns.len();
    let flat_data: Vec<f64> = data.into_iter().flatten().collect();
    let array = ndarray::Array2::from_shape_vec((n_samples, n_features), flat_data)
        .map_err(|e| DataLoaderError::ValidationError(e.to_string()))?;

    let features = FeatureMatrix::with_feature_names(array, feature_columns)?;
    Dataset::new(features, targets).map_err(DataLoaderError::DataError)
}

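/// Loads a dataset from a Parquet file via the Arrow record-batch reader.
/// Numeric columns (`Float64`/`Float32`/`Int64`/`Int32`) are widened to
/// `f64`; rows with a non-finite target or any non-finite feature are dropped.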
pub fn load_parquet<P: AsRef<Path>>(
    path: P,
    options: &LoadOptions,
) -> DataLoaderResult<Dataset> {
    let file = File::open(path)?;
    let builder = ParquetRecordBatchReaderBuilder::try_new(file)?;
    let schema = builder.schema().clone();

    let field_names: Vec<String> = schema.fields().iter().map(|f| f.name().clone()).collect();

    if field_names.is_empty() {
        return Err(DataLoaderError::ValidationError(
            "Parquet file has no columns".to_string(),
        ));
    }

    let target_column = if let Some(ref target_col) = options.target_column {
        target_col.clone()
    } else {
        let possible_targets = ["target", "label", "y", "class"];
        possible_targets
            .iter()
            .find(|&&col| field_names.contains(&col.to_string()))
            .map(|s| s.to_string())
            .unwrap_or_else(|| field_names.last().unwrap().clone())
    };

    let target_idx = field_names
        .iter()
        .position(|h| h == &target_column)
        .ok_or_else(|| {
            DataLoaderError::MissingColumn(format!(
                "Target column '{}' not found in Parquet file. Available columns: {:?}",
                target_column, field_names
            ))
        })?;

    let feature_names: Vec<String> = if let Some(ref feature_cols) = options.feature_columns {
        feature_cols.clone()
    } else {
        field_names
            .iter()
            .filter(|&name| name != &target_column)
            .cloned()
            .collect()
    };

    let feature_indices: Vec<usize> = feature_names
        .iter()
        .map(|name| {
            field_names
                .iter()
                .position(|f| f == name)
                .ok_or_else(|| DataLoaderError::MissingColumn(name.clone()))
        })
        .collect::<Result<Vec<_>, _>>()?;

    let reader = builder.build()?;
    let mut data: Vec<Vec<f64>> = Vec::new();
    let mut targets: Vec<f64> = Vec::new();

    for batch_result in reader {
        let batch = batch_result.map_err(|e| DataLoaderError::PreprocessingError(e.to_string()))?;

        if let Some(max_rows) = options.max_rows {
            if data.len() >= max_rows {
                break;
            }
        }

        let num_rows = batch.num_rows();
        let target_array = batch.column(target_idx);

        for row_idx in 0..num_rows {
            if let Some(max_rows) = options.max_rows {
                if data.len() >= max_rows {
                    break;
                }
            }

            let mut feature_row = Vec::with_capacity(feature_indices.len());
            for &feature_idx in &feature_indices {
                let array = batch.column(feature_idx);
                let value = extract_numeric_value(array, row_idx)?;

                // Non-finite cells fall back to the (numerically parsed)
                // missing-value sentinel, or NaN when it is not numeric.
                let final_value = if !value.is_finite() {
                    options
                        .missing_value
                        .as_ref()
                        .and_then(|s| s.parse::<f64>().ok())
                        .unwrap_or(f64::NAN)
                } else {
                    value
                };
                feature_row.push(final_value);
            }

            let target_value = extract_numeric_value(target_array, row_idx)?;

            // Keep only fully finite rows with a finite target.
            if target_value.is_finite() && feature_row.iter().all(|x| x.is_finite()) {
                data.push(feature_row);
                targets.push(target_value);
            }
        }
    }

    if data.is_empty() {
        return Err(DataLoaderError::ValidationError(
            "No valid data found in Parquet file".to_string(),
        ));
    }

    let n_samples = data.len();
    let n_features = feature_names.len();

    let flat_data: Vec<f64> = data.into_iter().flatten().collect();
    let array = Array2::from_shape_vec((n_samples, n_features), flat_data).map_err(|e| {
        DataLoaderError::ValidationError(format!("Failed to create feature matrix: {}", e))
    })?;

    let features = FeatureMatrix::with_feature_names(array, feature_names)?;
    Dataset::new(features, targets).map_err(DataLoaderError::DataError)
}

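/// Extracts the value at `row_idx` from a primitive Arrow array, widening to
/// `f64`. Nulls and unsupported data types yield `NaN` rather than an error.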
fn extract_numeric_value(array: &dyn Array, row_idx: usize) -> DataLoaderResult<f64> {
    // Treat nulls as missing values instead of reading an arbitrary slot.
    if array.is_null(row_idx) {
        return Ok(f64::NAN);
    }

    Ok(match array.data_type() {
        DataType::Float64 => {
            let arr = array.as_any().downcast_ref::<Float64Array>().unwrap();
            arr.value(row_idx)
        }
        DataType::Float32 => {
            let arr = array.as_any().downcast_ref::<Float32Array>().unwrap();
            arr.value(row_idx) as f64
        }
        DataType::Int64 => {
            let arr = array.as_any().downcast_ref::<Int64Array>().unwrap();
            arr.value(row_idx) as f64
        }
        DataType::Int32 => {
            let arr = array.as_any().downcast_ref::<Int32Array>().unwrap();
            arr.value(row_idx) as f64
        }
        // The arms above cover every supported type, so re-attempting the
        // same downcasts here was dead code; unsupported types map to NaN.
        _ => f64::NAN,
    })
}

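/// Saves a dataset to `path` in the given format. Parquet output is not yet
/// implemented and returns `UnsupportedFormat`.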
pub fn save_dataset<P: AsRef<Path>>(
    dataset: &Dataset,
    path: P,
    format: DataFormat,
) -> DataLoaderResult<()> {
    match format {
        DataFormat::Csv => save_dataset_csv(dataset, path),
        DataFormat::Json => save_dataset_json(dataset, path),
        DataFormat::Parquet => save_dataset_parquet(dataset, path),
    }
}

/// Writes the dataset as CSV with feature columns followed by a `target` column.
fn save_dataset_csv<P: AsRef<Path>>(dataset: &Dataset, path: P) -> DataLoaderResult<()> {
    let file = File::create(path)?;
    let mut writer = csv::Writer::from_writer(BufWriter::new(file));

    let mut headers: Vec<String> = dataset.features().feature_names().to_vec();
    headers.push("target".to_string());
    writer.write_record(&headers)?;

    for i in 0..dataset.n_samples() {
        let sample = dataset.features().get_sample(i)?;
        let target = dataset.targets()[i];

        let mut record: Vec<String> = sample.iter().map(|x| x.to_string()).collect();
        record.push(target.to_string());

        writer.write_record(&record)?;
    }

    writer.flush()?;
    Ok(())
}

/// Writes the dataset as a JSON array of objects, one per sample, with the
/// target stored under the `target` key.
fn save_dataset_json<P: AsRef<Path>>(dataset: &Dataset, path: P) -> DataLoaderResult<()> {
    use serde_json::Value;

    let mut records = Vec::new();

    for i in 0..dataset.n_samples() {
        let sample = dataset.features().get_sample(i)?;
        let target = dataset.targets()[i];

        let mut record = serde_json::Map::new();

        for (j, &value) in sample.iter().enumerate() {
            let feature_name = dataset.features().feature_names()[j].clone();
            record.insert(feature_name, Value::from(value));
        }

        record.insert("target".to_string(), Value::from(target));

        records.push(Value::Object(record));
    }

    let file = File::create(path)?;
    serde_json::to_writer_pretty(BufWriter::new(file), &records)?;
    Ok(())
}

fn save_dataset_parquet<P: AsRef<Path>>(_dataset: &Dataset, _path: P) -> DataLoaderResult<()> {
    // Parquet writing is not implemented yet; report it as unsupported.
    Err(DataLoaderError::UnsupportedFormat {
        format: "parquet".to_string(),
    })
}

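/// Splits a dataset into train and test sets. `test_size` must lie strictly
/// between 0 and 1. The `_random_state` parameter is currently unused;
/// shuffling delegates to `Dataset::train_test_split`.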
pub fn split_data(
    dataset: &Dataset,
    test_size: f64,
    shuffle: bool,
    _random_state: Option<u64>,
) -> DataLoaderResult<(Dataset, Dataset)> {
    // Exclusive bounds: a split of exactly 0 or 1 would leave one side empty.
    if !(test_size > 0.0 && test_size < 1.0) {
        return Err(DataLoaderError::ValidationError(
            "test_size must be strictly between 0 and 1".to_string(),
        ));
    }

    dataset
        .train_test_split(test_size, shuffle)
        .map_err(DataLoaderError::DataError)
}

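/// Produces `n_splits` contiguous k-fold train/test splits. When `shuffle` is
/// set, indices are permuted first; passing `random_state` makes the
/// permutation reproducible.
///
/// A minimal usage sketch (the dataset construction is elided):
///
/// ```ignore
/// for (train, test) in cross_validation_split(&dataset, 5, true, Some(7))? {
///     // fit on `train`, evaluate on `test`
/// }
/// ```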
pub fn cross_validation_split(
    dataset: &Dataset,
    n_splits: usize,
    shuffle: bool,
    random_state: Option<u64>,
) -> DataLoaderResult<Vec<(Dataset, Dataset)>> {
    if n_splits < 2 {
        return Err(DataLoaderError::ValidationError(
            "n_splits must be at least 2".to_string(),
        ));
    }

    if dataset.n_samples() < n_splits {
        return Err(DataLoaderError::ValidationError(
            "Number of samples must be at least n_splits".to_string(),
        ));
    }

    let mut splits = Vec::new();
    let fold_size = dataset.n_samples() / n_splits;

    let mut indices: Vec<usize> = (0..dataset.n_samples()).collect();
    if shuffle {
        use rand::seq::SliceRandom;
        use rand::{rngs::StdRng, SeedableRng};
        // Honor the seed when provided so folds are reproducible.
        match random_state {
            Some(seed) => indices.shuffle(&mut StdRng::seed_from_u64(seed)),
            None => indices.shuffle(&mut rand::thread_rng()),
        }
    }

    for fold in 0..n_splits {
        let test_start = fold * fold_size;
        let test_end = if fold == n_splits - 1 {
            dataset.n_samples()
        } else {
            (fold + 1) * fold_size
        };

        let test_indices: Vec<usize> = indices[test_start..test_end].to_vec();
        let train_indices: Vec<usize> = indices[0..test_start]
            .iter()
            .chain(&indices[test_end..])
            .cloned()
            .collect();

        let train_dataset = dataset.select_samples(&train_indices)?;
        let test_dataset = dataset.select_samples(&test_indices)?;

        splits.push((train_dataset, test_dataset));
    }

    Ok(splits)
}