gbrt_rs/io/
data_loader.rs

1#![allow(unused)]
2
3//! Data Loading and I/O Operations
4//! 
5//! This module provides comprehensive data loading capabilities for multiple formats
6//! with built-in preprocessing support:
7//! 
8//! - **CSV**: Full-featured CSV parsing with headers, delimiters, missing value handling
9//! - **JSON**: Structured JSON records with automatic schema inference
//! - **Parquet**: Apache Parquet files read through the Arrow integration
11//! - **Categorical Encoding**: Automatic detection and encoding of categorical features
12//! - **Flexible Target Selection**: Specify target column by name or use defaults
13//! - **Train/Test Splitting**: Built-in splitting with shuffling and stratification support
14//! - **Cross-Validation**: K-fold cross-validation split generation
15//! 
16//! # DataLoader Architecture
17//! 
18//! The [`DataLoader`] struct uses a builder pattern with [`LoadOptions`] for configuration.
19//! It provides both high-level convenience methods and low-level control over the loading process.
20//!
21
22use crate::data::{
23    Dataset, FeatureMatrix, DataError, FeatureMatrixError,
24    preprocessing::{CategoricalEncoder, detect_categorical_columns},
25};
26use serde::{Deserialize, Serialize};
27use std::collections::HashMap;
28use std::fs::File;
29use std::io::{BufReader, BufWriter};
30use std::path::{Path, PathBuf};
31use thiserror::Error;
32use std::str::FromStr;
33use ndarray::Array2;
34use csv::ReaderBuilder;
35use parquet::arrow::arrow_reader::ArrowReaderBuilder;
36use parquet::file::reader::FileReader;
37use std::sync::Arc;
38use arrow::array::{Array, Float64Array, Int64Array, Float32Array, Int32Array};
39use arrow::datatypes::DataType;
40
/// Errors that can occur during data loading operations.
///
/// Variants marked `#[from]` are produced automatically by the `?` operator
/// from the wrapped error type. The remaining variants are constructed
/// explicitly by the loaders and carry a human-readable payload.
#[derive(Error, Debug)]
pub enum DataLoaderError {
    /// File system I/O error (for example, the input file could not be opened).
    #[error("IO error: {0}")]
    IoError(#[from] std::io::Error),

    /// Error reported by the `csv` crate while reading or parsing.
    #[error("CSV error: {0}")]
    CsvError(#[from] csv::Error),

    /// Error reported by `serde_json` while deserializing.
    #[error("JSON error: {0}")]
    JsonError(#[from] serde_json::Error),

    /// Error raised while assembling a `Dataset` from features and targets.
    #[error("Data error: {0}")]
    DataError(#[from] DataError),

    /// Error raised while constructing a `FeatureMatrix`.
    #[error("Feature matrix error: {0}")]
    FeatureMatrixError(#[from] FeatureMatrixError),

    /// The requested or detected file format is not one of [`DataFormat`].
    #[error("Unsupported format: {format}")]
    UnsupportedFormat { format: String },

    /// A single value could not be parsed; carries the column name, the raw
    /// text, and an explanatory message.
    #[error("Parse error: {field} = {value} ({message})")]
    ParseError {
        field: String,
        value: String,
        message: String,
    },

    /// The loaded data is structurally invalid (empty input, ragged rows,
    /// non-finite targets, shape mismatch, ...).
    #[error("Validation error: {0}")]
    ValidationError(String),

    /// A column requested by name was not found in the data.
    #[error("Missing column: {0}")]
    MissingColumn(String),

    /// The file itself is unusable, e.g. its extension cannot be determined.
    #[error("Invalid file: {0}")]
    InvalidFile(String),

    /// Categorical encoding or another preprocessing step failed.
    #[error("Preprocessing error: {0}")]
    PreprocessingError(String),

    /// Error reported by the `parquet` crate.
    #[error("Parquet error: {0}")]
    ParquetError(#[from] parquet::errors::ParquetError),
}
99
/// Result alias used throughout this module; the error type is always
/// [`DataLoaderError`].
pub type DataLoaderResult<T> = std::result::Result<T, DataLoaderError>;
102
/// Supported data file formats for loading and saving.
///
/// [`load_from_file()`] selects the variant by parsing the file extension
/// through this type's `FromStr` implementation.
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
pub enum DataFormat {
    /// Delimited text; the delimiter and header handling come from [`LoadOptions`].
    Csv,
    /// JSON array of objects.
    Json,
    /// Apache Parquet columnar format.
    Parquet,
    // Arrow, // Future support
}
116
117impl std::str::FromStr for DataFormat {
118    type Err = DataLoaderError;
119
120    /// Parses DataFormat from string (file extension).
121    fn from_str(s: &str) -> Result<Self, Self::Err> {
122        match s.to_lowercase().as_str() {
123            "csv" => Ok(DataFormat::Csv),
124            "json" => Ok(DataFormat::Json),
125            "parquet" => Ok(DataFormat::Parquet),
126            // "arrow" => Ok(DataFormat::Arrow),
127            _ => Err(DataLoaderError::UnsupportedFormat {
128                format: s.to_string(),
129            }),
130        }
131    }
132}
133
134impl std::fmt::Display for DataFormat {
135    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
136        match self {
137            DataFormat::Csv => write!(f, "csv"),
138            DataFormat::Json => write!(f, "json"),
139            DataFormat::Parquet => write!(f, "parquet"),
140        }
141    }
142}
143
144
/// Configuration options for data loading.
///
/// Built with chained `with_*` calls starting from [`LoadOptions::new()`].
/// Not every loader consults every field; see the per-field notes.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LoadOptions {
    /// Field delimiter byte for CSV files (default: `b','`).
    pub delimiter: u8,
    /// Whether the first CSV row contains column names (default: true).
    pub has_headers: bool,
    /// Name of the target column to predict. When `None`, the CSV and JSON
    /// loaders fall back to their own defaults.
    pub target_column: Option<String>,
    /// List of feature columns to use (None = all except target).
    pub feature_columns: Option<Vec<String>>,
    /// Number of records to skip before loading (default: 0).
    /// NOTE(review): among the loaders visible in this part of the file only
    /// the free `load_csv` function reads this field.
    pub skip_rows: usize,
    /// Maximum number of rows to keep (None = all).
    pub max_rows: Option<usize>,
    /// Text that marks a missing value; matching CSV cells become NaN
    /// (default: "NaN").
    pub missing_value: Option<String>,
    /// File encoding label (default: "utf-8").
    /// NOTE(review): the CSV and JSON loaders visible here never read this
    /// field — confirm whether any other code depends on it.
    pub encoding: String,
}
168
169impl Default for LoadOptions {
170    fn default() -> Self {
171        Self {
172            delimiter: b',',
173            has_headers: true,
174            target_column: None,
175            feature_columns: None,
176            skip_rows: 0,
177            max_rows: None,
178            missing_value: Some("NaN".to_string()),
179            encoding: "utf-8".to_string(),
180        }
181    }
182}
183
184impl LoadOptions {
185    /// Creates new LoadOptions with default values.
186    pub fn new() -> Self {
187        Self::default()
188    }
189
190    /// Sets CSV delimiter character.
191    pub fn with_delimiter(mut self, delimiter: u8) -> Self {
192        self.delimiter = delimiter;
193        self
194    }
195
196    /// Sets whether CSV has header row.
197    pub fn with_headers(mut self, has_headers: bool) -> Self {
198        self.has_headers = has_headers;
199        self
200    }
201
202    /// Sets target column by name.
203    pub fn with_target_column(mut self, target_column: &str) -> Self {
204        self.target_column = Some(target_column.to_string());
205        self
206    }
207
208    /// Sets explicit feature columns (excludes others).
209    pub fn with_feature_columns(mut self, feature_columns: &[&str]) -> Self {
210        self.feature_columns = Some(feature_columns.iter().map(|s| s.to_string()).collect());
211        self
212    }
213
214    /// Sets number of rows to skip at start.
215    pub fn with_skip_rows(mut self, skip_rows: usize) -> Self {
216        self.skip_rows = skip_rows;
217        self
218    }
219
220    /// Sets maximum rows to read.
221    pub fn with_max_rows(mut self, max_rows: usize) -> Self {
222        self.max_rows = Some(max_rows);
223        self
224    }
225
226    /// Sets string representation for missing values.
227    pub fn with_missing_value(mut self, missing_value: &str) -> Self {
228        self.missing_value = Some(missing_value.to_string());
229        self
230    }
231}
232
/// Main data loader: a thin, option-carrying front end over the module-level
/// loading, saving and splitting functions.
///
/// A [`DataLoader`] stores one [`LoadOptions`] value and passes it to the
/// loading functions it delegates to.
///
/// # Usage Patterns
///
/// - **Simple**: Use `new()` with defaults
/// - **Configured**: Use `with_options()` for custom settings
/// - **Advanced**: Use format-specific methods for full control
pub struct DataLoader {
    /// Options applied to every load performed through this loader.
    options: LoadOptions,
}
248
249impl DataLoader {
250    /// Creates a new DataLoader with default options.
251    /// 
252    /// # Returns
253    /// `Ok(DataLoader)` if creation succeeds (always succeeds currently)
254    pub fn new() -> DataLoaderResult<Self> {
255        Ok(Self {
256            options: LoadOptions::default(),
257        })
258    }
259
260    /// Creates a new DataLoader with custom options.
261    ///
262    /// # Parameters
263    /// - `options`: LoadOptions configuration
264    ///
265    /// # Returns
266    /// Configured DataLoader instance
267    pub fn with_options(options: LoadOptions) -> Self {
268        Self { options }
269    }
270
271
272    /// Loads dataset from file with auto-detected format.
273    /// 
274    /// Format is determined from file extension (.csv, .json, .parquet).
275    /// 
276    /// # Parameters
277    /// - `path`: Path to data file
278    /// 
279    /// # Returns
280    /// Loaded dataset
281    /// 
282    /// # Errors
283    /// - `DataLoaderError::InvalidFile` if extension cannot be determined
284    /// - `DataLoaderError::UnsupportedFormat` if format is not supported
285    /// - Format-specific errors from underlying loaders
286    pub fn load_file<P: AsRef<Path>>(&self, path: P) -> DataLoaderResult<Dataset> {
287        load_from_file(path, &self.options)
288    }
289
290    /// Loads dataset from CSV file with automatic categorical encoding.
291    /// 
292    /// This is the recommended method for loading CSV data. It handles:
293    /// - Automatic target column extraction
294    /// - Categorical feature detection and encoding
295    /// - Missing value and validation checks
296    /// 
297    /// # Parameters
298    /// - `path`: Path to CSV file
299    /// - `target_column`: Name of target column
300    /// - `categorical_columns`: Optional explicit categorical column indices (None = auto-detect)
301    /// - `categorical_threshold`: Ratio threshold for auto-detecting categorical columns (0.0-1.0)
302    /// 
303    /// # Returns
304    /// Dataset with encoded features and extracted targets
305    /// 
306    /// # Errors
307    /// - `DataLoaderError::MissingColumn` if target column not found
308    /// - `DataLoaderError::ValidationError` if data is malformed
309    /// - `DataLoaderError::PreprocessingError` if encoding fails
310    /// - `DataLoaderError::ParseError` if target values cannot be parsed
311    /// 
312    /// # Categorical Detection
313    /// If `categorical_columns` is `None`, automatically detects categorical columns
314    /// where unique values / total samples < `categorical_threshold`.
315    pub fn load_csv_with_categorical<P: AsRef<Path>>(
316        &self,
317        path: P,
318        target_column: &str,
319        categorical_columns: Option<&[usize]>,
320        categorical_threshold: f64,
321    ) -> DataLoaderResult<Dataset> {
322        let mut reader = ReaderBuilder::new()
323            .delimiter(self.options.delimiter)
324            .has_headers(self.options.has_headers)
325            .from_path(&path)
326            .map_err(DataLoaderError::CsvError)?;
327        
328        let headers: Vec<String> = reader
329            .headers()
330            .map_err(DataLoaderError::CsvError)?
331            .iter()
332            .map(|s| s.to_string())
333            .collect();
334        
335        // Find target column index
336        let target_col_idx = headers.iter()
337            .position(|h| h == target_column)
338            .ok_or_else(|| DataLoaderError::MissingColumn(
339                format!("Target column '{}' not found in CSV headers. Available columns: {:?}", 
340                       target_column, headers)
341            ))?;
342        
343        // Read all data as strings
344        let mut string_data = Vec::new();
345        for (row_idx, result) in reader.records().enumerate() {
346            let record = result.map_err(|e| DataLoaderError::CsvError(e))?;
347            let row: Vec<String> = record.iter().map(|s| s.to_string()).collect();
348            
349            if row.len() != headers.len() {
350                return Err(DataLoaderError::ValidationError(
351                    format!("Row {} has {} columns, expected {} columns. Row: {:?}", 
352                           row_idx, row.len(), headers.len(), row)
353                ));
354            }
355            string_data.push(row);
356        }
357        
358        if string_data.is_empty() {
359            return Err(DataLoaderError::ValidationError(
360                "CSV file contains no data rows".to_string()
361            ));
362        }
363        
364        // Determine feature columns (exclude target)
365        let n_features_total = headers.len();
366        let feature_col_indices: Vec<usize> = (0..n_features_total)
367            .filter(|&i| i != target_col_idx)
368            .collect();
369        
370        // Extract feature data (excluding target)
371        let feature_data: Vec<Vec<String>> = string_data.iter()
372            .map(|row| {
373                feature_col_indices.iter()
374                    .map(|&i| row[i].clone())
375                    .collect()
376            })
377            .collect();
378        
379        // Adjust categorical column indices for target exclusion
380        let cat_cols_adjusted = if let Some(cols) = categorical_columns {
381            cols.iter()
382                .filter_map(|&col_idx| {
383                    if col_idx == target_col_idx {
384                        None // Skip target column
385                    } else if col_idx < target_col_idx {
386                        Some(col_idx)
387                    } else {
388                        Some(col_idx - 1)
389                    }
390                })
391                .collect::<Vec<_>>()
392        } else {
393            // Auto-detect on feature columns only
394            detect_categorical_columns(&feature_data, categorical_threshold)
395        };
396        
397        // Encode categorical features
398        let mut encoder = CategoricalEncoder::new();
399        let encoded_features = encoder.fit_transform(&feature_data, &cat_cols_adjusted)
400            .map_err(|e| DataLoaderError::PreprocessingError(e.to_string()))?;
401        
402        // Extract and validate targets
403        let targets: Vec<f64> = string_data.iter()
404            .enumerate()
405            .map(|(row_idx, row)| {
406                row[target_col_idx].parse::<f64>()
407                    .map_err(|_| DataLoaderError::ParseError {
408                        field: target_column.to_string(),
409                        value: row[target_col_idx].clone(),
410                        message: format!("Failed to parse target column at row {} as f64", row_idx),
411                    })
412            })
413            .collect::<Result<Vec<_>, _>>()?;
414        
415        // Check for NaN/inf in targets
416        for (i, &target) in targets.iter().enumerate() {
417            if !target.is_finite() {
418                return Err(DataLoaderError::ValidationError(
419                    format!("Non-finite target value at row {}: {}", i, target)
420                ));
421            }
422        }
423        
424        // Create FeatureMatrix
425        let n_samples = encoded_features.len();
426        let n_features = encoded_features[0].len();
427        
428        let flat_data: Vec<f64> = encoded_features.into_iter().flatten().collect();
429        let array = Array2::from_shape_vec((n_samples, n_features), flat_data)
430            .map_err(|e| DataLoaderError::ValidationError(
431                format!("Failed to create feature matrix: {}", e)
432            ))?;
433        
434        // Get feature names (excluding target)
435        let feature_names: Vec<String> = feature_col_indices.iter()
436            .map(|&i| headers[i].clone())
437            .collect();
438        
439        let features = FeatureMatrix::with_feature_names(array, feature_names)?;
440        
441        Dataset::new(features, targets)
442            .map_err(DataLoaderError::DataError)
443    }
444
445    /// Original load_csv method - delegates to categorical version with no categorical columns.
446    /// 
447    /// Maintains backward compatibility for simple CSV loading.    
448    pub fn load_csv<P: AsRef<Path>>(&self, path: P) -> DataLoaderResult<Dataset> {
449        // Default to last column as target, no categorical encoding for backward compatibility
450        let target_col = self.options.target_column.as_deref().unwrap_or("");
451        if target_col.is_empty() {
452            load_csv(path, &self.options)
453        } else {
454            self.load_csv_with_categorical(path, target_col, Some(&[]), 0.0)
455        }
456    }
457
458    /// Loads dataset from JSON file.
459    /// 
460    /// Expects JSON array of objects where each object represents a sample.
461    /// 
462    /// # Parameters
463    /// - `path`: Path to JSON file
464    /// 
465    /// # Returns
466    /// Loaded dataset
467    /// 
468    /// # Errors
469    /// - `DataLoaderError::JsonError` if JSON parsing fails
470    /// - `DataLoaderError::ValidationError` if data is malformed
471    pub fn load_json<P: AsRef<Path>>(&self, path: P) -> DataLoaderResult<Dataset> {
472        load_json(path, &self.options)
473    }
474
475    /// Loads dataset from Parquet file (placeholder).
476    /// 
477    /// # Note
478    /// This method currently returns `UnsupportedFormat` error. To enable Parquet
479    /// support, add the `parquet` crate as a dependency and implement the logic.
480    pub fn load_parquet<P: AsRef<Path>>(&self, path: P) -> DataLoaderResult<Dataset> {
481        load_parquet(path, &self.options)
482    }
483
484    /// Saves dataset to file in specified format.
485    /// 
486    /// # Parameters
487    /// - `dataset`: Dataset to save
488    /// - `path`: Output file path
489    /// - `format`: DataFormat to use
490    /// 
491    /// # Returns
492    /// `Ok(())` on successful save
493    /// 
494    /// # Errors
495    /// - Format-specific save errors
496    pub fn save_dataset<P: AsRef<Path>>(
497        &self,
498        dataset: &Dataset,
499        path: P,
500        format: DataFormat,
501    ) -> DataLoaderResult<()> {
502        save_dataset(dataset, path, format)
503    }
504
505    /// Splits dataset into training and testing sets.
506    /// 
507    /// Convenience method that delegates to [`split_data()`] function.
508    /// 
509    /// # Parameters
510    /// - `dataset`: Dataset to split
511    /// - `test_size`: Fraction of samples for test set (0.0-1.0)
512    /// - `shuffle`: Whether to randomize sample order
513    /// - `random_state`: Optional seed for reproducible shuffling
514    /// 
515    /// # Returns
516    /// Tuple of (train_dataset, test_dataset)
517    /// 
518    /// # Errors
519    /// - `DataLoaderError::ValidationError` if test_size not in (0, 1)
520    /// - `DataLoaderError::ValidationError` if dataset too small
521    pub fn split_data(
522        &self,
523        dataset: &Dataset,
524        test_size: f64,
525        shuffle: bool,
526        random_state: Option<u64>,
527    ) -> DataLoaderResult<(Dataset, Dataset)> {
528        split_data(dataset, test_size, shuffle, random_state)
529    }
530
531    /// Creates k-fold cross-validation splits.
532    /// 
533    /// Generates `n_splits` train/test dataset pairs for cross-validation.
534    /// Each fold uses a different portion as test set.
535    /// 
536    /// # Parameters
537    /// - `dataset`: Dataset to split
538    /// - `n_splits`: Number of folds (k)
539    /// - `shuffle`: Whether to shuffle before splitting
540    /// - `random_state`: Optional seed for reproducible shuffling
541    /// 
542    /// # Returns
543    /// Vector of (train_dataset, test_dataset) tuples
544    /// 
545    /// # Errors
546    /// - `DataLoaderError::ValidationError` if n_splits < 2
547    /// - `DataLoaderError::ValidationError` if dataset too small
548    pub fn cross_validation_splits(
549        &self,
550        dataset: &Dataset,
551        n_splits: usize,
552        shuffle: bool,
553        random_state: Option<u64>,
554    ) -> DataLoaderResult<Vec<(Dataset, Dataset)>> {
555        cross_validation_split(dataset, n_splits, shuffle, random_state)
556    }
557}
558
559impl Default for DataLoader {
560    fn default() -> Self {
561        Self::new().unwrap()
562    }
563}
564
565// ============================================================================
566// Standalone Data Loading Functions
567// ============================================================================
568
569/// Loads dataset from file with auto-detected format (CSV, JSON, Parquet).
570/// 
571/// # Parameters
572/// - `path`: Path to data file
573/// - `options`: LoadOptions configuration
574/// 
575/// # Returns
576/// Loaded dataset
577/// 
578/// # Errors
579/// - `DataLoaderError::InvalidFile` if extension cannot be determined
580/// - `DataLoaderError::UnsupportedFormat` if format is not supported
581pub fn load_from_file<P: AsRef<Path>>(
582    path: P,
583    options: &LoadOptions,
584) -> DataLoaderResult<Dataset> {
585    let path = path.as_ref();
586    let extension = path
587        .extension()
588        .and_then(|ext| ext.to_str())
589        .ok_or_else(|| DataLoaderError::InvalidFile("Cannot determine file extension".to_string()))?;
590
591    let format = DataFormat::from_str(extension)?;
592
593    match format {
594        DataFormat::Csv => load_csv(path, options),
595        DataFormat::Json => load_json(path, options),
596        DataFormat::Parquet => load_parquet(path, options),
597    }
598}
599
600/// Loads dataset from CSV file (legacy implementation).
601/// 
602/// This is the core CSV loading function used by DataLoader. It handles:
603/// - Custom delimiters and headers
604/// - Missing value substitution
605/// - Target and feature column selection
606/// - Target parsing and validation
607/// 
608/// # Parameters
609/// - `path`: Path to CSV file
610/// - `options`: LoadOptions configuration
611/// 
612/// # Returns
613/// Dataset with extracted features and targets
614/// 
615/// # Errors
616/// - All CSV-related errors from csv crate
617/// - `DataLoaderError::MissingColumn` if target or features not found
618/// - `DataLoaderError::ValidationError` for malformed data
619pub fn load_csv<P: AsRef<Path>>(
620    path: P,
621    options: &LoadOptions,
622) -> DataLoaderResult<Dataset> {
623    let file = File::open(path)?;
624    let mut reader = csv::ReaderBuilder::new()
625        .delimiter(options.delimiter)
626        .has_headers(options.has_headers)
627        .from_reader(BufReader::new(file));
628
629    // Skip rows if specified
630    if options.skip_rows > 0 {
631        for _ in 0..options.skip_rows {
632            if reader.records().next().is_none() {
633                break;
634            }
635        }
636    }
637
638    let headers: Vec<String> = if options.has_headers {
639        reader
640            .headers()?
641            .iter()
642            .map(|s| s.to_string())
643            .collect()
644    } else {
645        // Generate default headers
646        let first_record = reader.records().next();
647        if let Some(record) = first_record {
648            (0..record?.len())
649                .map(|i| format!("feature_{}", i))
650                .collect()
651        } else {
652            return Err(DataLoaderError::ValidationError("Empty CSV file".to_string()));
653        }
654    };
655
656    let mut data = Vec::new();
657    let mut targets = Vec::new();
658    let mut record_count = 0;
659
660    // Determine target column index
661    let target_index = if let Some(ref target_col) = options.target_column {
662        headers
663            .iter()
664            .position(|h| h == target_col)
665            .ok_or_else(|| DataLoaderError::MissingColumn(target_col.clone()))?
666    } else {
667        // Use last column as target by default
668        headers.len() - 1
669    };
670
671    // Determine feature columns
672    let feature_indices: Vec<usize> = if let Some(ref feature_cols) = options.feature_columns {
673        feature_cols
674            .iter()
675            .map(|col| {
676                headers
677                    .iter()
678                    .position(|h| h == col)
679                    .ok_or_else(|| DataLoaderError::MissingColumn(col.clone()))
680            })
681            .collect::<Result<Vec<_>, _>>()?
682    } else {
683        // Use all columns except target
684        (0..headers.len())
685            .filter(|&i| i != target_index)
686            .collect()
687    };
688
689    for result in reader.records() {
690        if let Some(max_rows) = options.max_rows {
691            if record_count >= max_rows {
692                break;
693            }
694        }
695
696        let record = result?;
697        let mut row = Vec::with_capacity(feature_indices.len());
698
699        // Extract features
700        for &idx in &feature_indices {
701            let value_str = record.get(idx).unwrap_or("");
702            let value = if let Some(ref missing) = options.missing_value {
703                if value_str == missing {
704                    f64::NAN
705                } else {
706                    value_str.parse().unwrap_or(f64::NAN)
707                }
708            } else {
709                value_str.parse().unwrap_or(f64::NAN)
710            };
711            row.push(value);
712        }
713
714        // Extract target
715        let target_str = record.get(target_index).unwrap_or("");
716        let target = if let Some(ref missing) = options.missing_value {
717            if target_str == missing {
718                f64::NAN
719            } else {
720                target_str.parse().unwrap_or(f64::NAN)
721            }
722        } else {
723            target_str.parse().unwrap_or(f64::NAN)
724        };
725
726        // Skip rows with NaN targets
727        if !target.is_nan() {
728            data.push(row);
729            targets.push(target);
730            record_count += 1;
731        }
732    }
733
734    if data.is_empty() {
735        return Err(DataLoaderError::ValidationError("No valid data found".to_string()));
736    }
737
738    // Convert to FeatureMatrix
739    let n_samples = data.len();
740    let n_features = feature_indices.len();
741    let flat_data: Vec<f64> = data.into_iter().flatten().collect();
742    let array = ndarray::Array2::from_shape_vec((n_samples, n_features), flat_data)
743        .map_err(|e| DataLoaderError::ValidationError(e.to_string()))?;
744
745    let feature_names: Vec<String> = feature_indices
746        .iter()
747        .map(|&idx| headers[idx].clone())
748        .collect();
749
750    let features = FeatureMatrix::with_feature_names(array, feature_names)?;
751    Dataset::new(features, targets)
752        .map_err(DataLoaderError::DataError)
753}
754
755/// Loads dataset from JSON file.
756/// 
757/// Expects JSON array of objects where each object represents a sample.
758/// 
759/// # Parameters
760/// - `path`: Path to JSON file
761/// - `options`: LoadOptions configuration
762/// 
763/// # Returns
764/// Loaded dataset
765/// 
766/// # Errors
767/// - `DataLoaderError::JsonError` if JSON parsing fails
768/// - `DataLoaderError::ValidationError` if data is malformed
769/// 
770/// # JSON Format
771/// ```json
772/// [
773///   {"feature1": 1.0, "feature2": 2.0, "target": 0.0},
774///   {"feature1": 3.0, "feature2": 4.0, "target": 1.0}
775/// ]
776/// ```
777pub fn load_json<P: AsRef<Path>>(
778    path: P,
779    options: &LoadOptions,
780) -> DataLoaderResult<Dataset> {
781    let file = File::open(path)?;
782    let reader = BufReader::new(file);
783
784    #[derive(Deserialize)]
785    struct JsonRecord {
786        #[serde(flatten)]
787        values: HashMap<String, serde_json::Value>,
788    }
789
790    let records: Vec<JsonRecord> = serde_json::from_reader(reader)?;
791
792    if records.is_empty() {
793        return Err(DataLoaderError::ValidationError("Empty JSON file".to_string()));
794    }
795
796    // Determine target column
797    let target_column = if let Some(ref target_col) = options.target_column {
798        target_col.clone()
799    } else {
800        // Try to find a common target column name
801        let possible_targets = ["target", "label", "y", "class"];
802        possible_targets
803            .iter()
804            .find(|&&col| records[0].values.contains_key(col))
805            .map(|s| s.to_string())
806            .unwrap_or_else(|| {
807                // Use the last column as target
808                records[0]
809                    .values
810                    .keys()
811                    .last()
812                    .unwrap()
813                    .to_string()
814            })
815    };
816
817    // Determine feature columns
818    let feature_columns: Vec<String> = if let Some(ref feature_cols) = options.feature_columns {
819        feature_cols.clone()
820    } else {
821        records[0]
822            .values
823            .keys()
824            .filter(|&k| k != &target_column)
825            .cloned()
826            .collect()
827    };
828
829    let mut data = Vec::new();
830    let mut targets = Vec::new();
831    let mut record_count = 0;
832
833    for record in records {
834        if let Some(max_rows) = options.max_rows {
835            if record_count >= max_rows {
836                break;
837            }
838        }
839
840        let mut row = Vec::with_capacity(feature_columns.len());
841
842        // Extract features
843        for col in &feature_columns {
844            let value = record.values.get(col).and_then(|v| v.as_f64()).unwrap_or(f64::NAN);
845            row.push(value);
846        }
847
848        // Extract target
849        let target = record
850            .values
851            .get(&target_column)
852            .and_then(|v| v.as_f64())
853            .unwrap_or(f64::NAN);
854
855        // Skip rows with NaN targets
856        if !target.is_nan() {
857            data.push(row);
858            targets.push(target);
859            record_count += 1;
860        }
861    }
862
863    if data.is_empty() {
864        return Err(DataLoaderError::ValidationError("No valid data found".to_string()));
865    }
866
867    // Convert to FeatureMatrix
868    let n_samples = data.len();
869    let n_features = feature_columns.len();
870    let flat_data: Vec<f64> = data.into_iter().flatten().collect();
871    let array = ndarray::Array2::from_shape_vec((n_samples, n_features), flat_data)
872        .map_err(|e| DataLoaderError::ValidationError(e.to_string()))?;
873
874    let features = FeatureMatrix::with_feature_names(array, feature_columns)?;
875    Dataset::new(features, targets)
876        .map_err(DataLoaderError::DataError)
877}
878
879
880/// Loads a dataset from an Apache Parquet file.
881///
882/// This function reads Parquet data using the Arrow integration, which provides
883/// efficient columnar access and batch processing capabilities. It automatically
884/// handles schema inference, type conversion, and validation.
885///
886/// # Target Column Selection
887///
888/// The target column is determined by the following priority:
889/// 1. `options.target_column` if explicitly specified
890/// 2. First matching column from `["target", "label", "y", "class"]`
891/// 3. The last column in the file (fallback)
892///
893/// # Feature Column Selection
894///
895/// - If `options.feature_columns` is specified, uses only those columns
896/// - Otherwise, uses all columns except the target column
897///
898/// # Type Support
899///
900/// Converts the following Arrow types to `f64`:
901/// - `Float64` → direct conversion
902/// - `Float32` → cast to `f64`
903/// - `Int64` → cast to `f64`
904/// - `Int32` → cast to `f64`
905/// - Other types → `NaN` (skipped if in target, preserved if in features based on `missing_value`)
906///
907/// # Parameters
908/// - `path`: Path to the Parquet file
909/// - `options`: Configuration for loading behavior (target column, missing values, etc.)
910///
911/// # Returns
912/// `Ok(Dataset)` containing features and targets, or `Err(DataLoaderError)` if loading fails.
913///
914/// # Errors
915/// - `DataLoaderError::IoError` if file cannot be opened
916/// - `DataLoaderError::ParquetError` if Parquet reading fails
917/// - `DataLoaderError::MissingColumn` if configured target column doesn't exist
918/// - `DataLoaderError::ValidationError` if file is empty or malformed
919/// - `DataLoaderError::PreprocessingError` if data conversion fails
920/// - `DataLoaderError::DataError` if dataset validation fails
921///
922/// # Memory Efficiency
923///
924/// Reads data in batches of 8192 rows to limit memory usage on large files.
925/// Respects `options.max_rows` if set for additional memory control.
926///
927/// # Row Filtering
928///
929/// Rows are skipped if:
930/// - Target value is not finite (NaN or infinite)
931/// - Any feature value is not finite AND `missing_value` is not configured
932pub fn load_parquet<P: AsRef<Path>>(
933    path: P,
934    options: &LoadOptions,
935) -> DataLoaderResult<Dataset> {
936    let file = File::open(path)?;
937    let builder = ArrowReaderBuilder::try_new(file)?;
938    let schema = builder.schema().clone();
939    
940    let field_names: Vec<String> = schema.fields().iter()
941        .map(|f| f.name().clone())
942        .collect();
943    
944    if field_names.is_empty() {
945        return Err(DataLoaderError::ValidationError("Parquet file has no columns".to_string()));
946    }
947    
948    // Determine target column
949    let target_column = if let Some(ref target_col) = options.target_column {
950        target_col.clone()
951    } else {
952        // Try common target names
953        let possible_targets = ["target", "label", "y", "class"];
954        possible_targets
955            .iter()
956            .find(|&&col| field_names.contains(&col.to_string()))
957            .map(|s| s.to_string())
958            .unwrap_or_else(|| field_names.last().unwrap().clone())
959    };
960    
961    let target_idx = field_names.iter()
962        .position(|h| h == &target_column)
963        .ok_or_else(|| {
964            DataLoaderError::MissingColumn(
965                format!("Target column '{}' not found in Parquet file. Available columns: {:?}", 
966                       target_column, field_names)
967            )
968        })?;
969    
970    // Determine feature columns
971    let feature_names: Vec<String> = if let Some(ref feature_cols) = options.feature_columns {
972        feature_cols.clone()
973    } else {
974        field_names.iter()
975            .filter(|&name| name != &target_column)
976            .cloned()
977            .collect()
978    };
979    
980    // Map feature names to indices
981    let feature_indices: Vec<usize> = feature_names.iter()
982        .map(|name| {
983            field_names.iter()
984                .position(|f| f == name)
985                .ok_or_else(|| DataLoaderError::MissingColumn(name.clone()))
986        })
987        .collect::<Result<Vec<_>, _>>()?;
988    
989    // Read batches
990    let reader = builder.build()?;
991    let mut data: Vec<Vec<f64>> = Vec::new();
992    let mut targets: Vec<f64> = Vec::new();
993    
994    for batch_result in reader {
995        let batch = batch_result.map_err(|e| DataLoaderError::PreprocessingError(e.to_string()))?;
996        
997        // Check if we've reached max rows
998        if let Some(max_rows) = options.max_rows {
999            if data.len() >= max_rows {
1000                break;
1001            }
1002        }
1003        
1004        let num_rows = batch.num_rows();
1005        
1006        // Extract target column
1007        let target_column = batch.column(target_idx);
1008        
1009        // Process each row
1010        for row_idx in 0..num_rows {
1011            if let Some(max_rows) = options.max_rows {
1012                if data.len() >= max_rows {
1013                    break;
1014                }
1015            }
1016            
1017            // Extract features
1018            let mut feature_row = Vec::with_capacity(feature_indices.len());
1019            for &feature_idx in &feature_indices {
1020                let array = batch.column(feature_idx);
1021                let value = extract_numeric_value(array, row_idx)?;
1022                
1023                // Handle missing values
1024                let final_value = if !value.is_finite() {
1025                    options.missing_value.as_ref()
1026                        .and_then(|s| s.parse::<f64>().ok())
1027                        .unwrap_or(f64::NAN)
1028                } else {
1029                    value
1030                };
1031                feature_row.push(final_value);
1032            }
1033            
1034            // Extract target
1035            let target_value = extract_numeric_value(target_column, row_idx)?;
1036            
1037            // Skip rows with NaN targets
1038            if target_value.is_finite() {
1039                // Check for non-finite feature values
1040                let has_non_finite = feature_row.iter().any(|&x| !x.is_finite());
1041                if !has_non_finite {
1042                    data.push(feature_row);
1043                    targets.push(target_value);
1044                }
1045            }
1046        }
1047    }
1048    
1049    if data.is_empty() {
1050        return Err(DataLoaderError::ValidationError(
1051            "No valid data found in Parquet file".to_string()
1052        ));
1053    }
1054    
1055    // Create FeatureMatrix
1056    let n_samples = data.len();
1057    let n_features = feature_names.len();
1058    
1059    let flat_data: Vec<f64> = data.into_iter().flatten().collect();
1060    let array = Array2::from_shape_vec((n_samples, n_features), flat_data)
1061        .map_err(|e| DataLoaderError::ValidationError(
1062            format!("Failed to create feature matrix: {}", e)
1063        ))?;
1064    
1065    let features = FeatureMatrix::with_feature_names(array, feature_names)?;
1066    Dataset::new(features, targets)
1067        .map_err(DataLoaderError::DataError)
1068}
1069
1070/// Extracts a numeric value from an Arrow array at a specific row index.
1071/// 
1072/// This helper function converts Arrow's native types to f64 for consistency
1073/// with the rest of the gradient boosting implementation.
1074/// 
1075/// # Supported Types
1076/// 
1077/// - `Float64`: Direct extraction
1078/// - `Float32`: Casts to f64
1079/// - `Int64`: Casts to f64 (may lose precision for very large values)
1080/// - `Int32`: Casts to f64
1081/// - All others: Returns `NaN`
1082/// 
1083/// # Parameters
1084/// - `array`: Reference to a dynamic Arrow array (`dyn Array`)
1085/// - `row_idx`: Row index to extract (0-based)
1086/// 
1087/// # Returns
1088/// `Ok(f64)` value extracted and converted, or `Err` if the operation fails.
1089/// 
1090/// # Panics
1091/// This function will panic if `row_idx` is out of bounds for the given array.
1092/// The caller should ensure bounds checking.
1093/// 
1094/// # Safety
1095/// Downcasting to specific array types is safe when checked against `array.data_type()`.
1096fn extract_numeric_value(array: &dyn Array, row_idx: usize) -> DataLoaderResult<f64> {
1097    
1098    Ok(match array.data_type() {
1099        DataType::Float64 => {
1100            let arr = array.as_any().downcast_ref::<Float64Array>().unwrap();
1101            arr.value(row_idx)
1102        },
1103        DataType::Float32 => {
1104            let arr = array.as_any().downcast_ref::<Float32Array>().unwrap();
1105            arr.value(row_idx) as f64
1106        },
1107        DataType::Int64 => {
1108            let arr = array.as_any().downcast_ref::<Int64Array>().unwrap();
1109            arr.value(row_idx) as f64
1110        },
1111        DataType::Int32 => {
1112            let arr = array.as_any().downcast_ref::<Int32Array>().unwrap();
1113            arr.value(row_idx) as f64
1114        },
1115        _ => {
1116            // Try numeric conversion for other types
1117            if let Some(arr) = array.as_any().downcast_ref::<Float64Array>() {
1118                arr.value(row_idx)
1119            } else if let Some(arr) = array.as_any().downcast_ref::<Float32Array>() {
1120                arr.value(row_idx) as f64
1121            } else if let Some(arr) = array.as_any().downcast_ref::<Int64Array>() {
1122                arr.value(row_idx) as f64
1123            } else if let Some(arr) = array.as_any().downcast_ref::<Int32Array>() {
1124                arr.value(row_idx) as f64
1125            } else {
1126                // For unsupported types, return NaN
1127                f64::NAN
1128            }
1129        }
1130    })
1131}
1132
1133// ============================================================================
1134// Dataset Saving Functions
1135// ============================================================================
1136
1137/// Saves dataset to file in specified format.
1138/// 
1139/// # Parameters
1140/// - `dataset`: Dataset to save
1141/// - `path`: Output file path
1142/// - `format`: DataFormat to use
1143/// 
1144/// # Returns
1145/// `Ok(())` on successful save
1146/// 
1147/// # Errors
1148/// - Format-specific save errors
1149/// - I/O errors
1150pub fn save_dataset<P: AsRef<Path>>(
1151    dataset: &Dataset,
1152    path: P,
1153    format: DataFormat,
1154) -> DataLoaderResult<()> {
1155    match format {
1156        DataFormat::Csv => save_dataset_csv(dataset, path),
1157        DataFormat::Json => save_dataset_json(dataset, path),
1158        DataFormat::Parquet => save_dataset_parquet(dataset, path),
1159    }
1160}
1161
1162/// Saves dataset to CSV file.
1163///
1164/// # Parameters
1165/// - `dataset`: Dataset to save
1166/// - `path`: Output CSV file path
1167///
1168    /// # Format
1169/// - First row: feature_names_1, feature_names_2, ..., "target"
1170/// - Subsequent rows: feature values, target value
1171fn save_dataset_csv<P: AsRef<Path>>(dataset: &Dataset, path: P) -> DataLoaderResult<()> {
1172    let file = File::create(path)?;
1173    let mut writer = csv::Writer::from_writer(BufWriter::new(file));
1174
1175    // Write headers
1176    let mut headers: Vec<String> = dataset.features().feature_names().to_vec();
1177    headers.push("target".to_string());
1178    writer.write_record(&headers)?;
1179
1180    // Write data
1181    for i in 0..dataset.n_samples() {
1182        let sample = dataset.features().get_sample(i)?;
1183        let target = dataset.targets()[i];
1184
1185        let mut record: Vec<String> = sample.iter().map(|x| x.to_string()).collect();
1186        record.push(target.to_string());
1187
1188        writer.write_record(&record)?;
1189    }
1190
1191    writer.flush()?;
1192    Ok(())
1193}
1194
1195/// Saves dataset to JSON file.
1196///
1197/// # Parameters
1198/// - `dataset`: Dataset to save
1199/// - `path`: Output JSON file path
1200///
1201/// # Format
1202/// JSON array of objects, each containing all features and target:
1203/// ```json
1204/// [
1205///   {"feature1": 1.0, "feature2": 2.0, "target": 0.0},
1206///   {"feature1": 3.0, "feature2": 4.0, "target": 1.0}
1207/// ]
1208/// ```
1209fn save_dataset_json<P: AsRef<Path>>(dataset: &Dataset, path: P) -> DataLoaderResult<()> {
1210    use serde_json::Value;
1211
1212    let mut records = Vec::new();
1213
1214    for i in 0..dataset.n_samples() {
1215        let sample = dataset.features().get_sample(i)?;
1216        let target = dataset.targets()[i];
1217
1218        let mut record = serde_json::Map::new();
1219        
1220        // Add features
1221        for (j, &value) in sample.iter().enumerate() {
1222            let feature_name = dataset.features().feature_names()[j].clone();
1223            record.insert(feature_name, Value::from(value));
1224        }
1225        
1226        // Add target
1227        record.insert("target".to_string(), Value::from(target));
1228        
1229        records.push(Value::Object(record));
1230    }
1231
1232    let file = File::create(path)?;
1233    serde_json::to_writer_pretty(BufWriter::new(file), &records)?;
1234    Ok(())
1235}
1236
1237/// Saves dataset to Parquet file (placeholder).
1238///
1239/// # Note
1240/// This method currently returns `UnsupportedFormat` error.
1241fn save_dataset_parquet<P: AsRef<Path>>(_dataset: &Dataset, _path: P) -> DataLoaderResult<()> {
1242    // Placeholder implementation
1243    Err(DataLoaderError::UnsupportedFormat {
1244        format: "parquet".to_string(),
1245    })
1246}
1247
1248// ============================================================================
1249// Splitting and Cross-Validation Functions
1250// ============================================================================
1251
1252/// Splits dataset into training and testing sets.
1253/// 
1254/// # Parameters
1255/// - `dataset`: Dataset to split
1256/// - `test_size`: Fraction of samples for test set (0.0-1.0)
1257/// - `shuffle`: Whether to randomize sample order before splitting
1258/// - `random_state`: Optional seed for reproducible shuffling
1259/// 
1260/// # Returns
1261/// Tuple of (train_dataset, test_dataset)
1262/// 
1263/// # Errors
1264/// - `DataLoaderError::ValidationError` if test_size not in (0, 1)
1265/// - `DataLoaderError::ValidationError` if dataset too small
1266pub fn split_data(
1267    dataset: &Dataset,
1268    test_size: f64,
1269    shuffle: bool,
1270    random_state: Option<u64>,
1271) -> DataLoaderResult<(Dataset, Dataset)> {
1272    if !(0.0..1.0).contains(&test_size) {
1273        return Err(DataLoaderError::ValidationError(
1274            "test_size must be between 0 and 1".to_string(),
1275        ));
1276    }
1277
1278    dataset
1279        .train_test_split(test_size, shuffle)
1280        .map_err(DataLoaderError::DataError)
1281}
1282
1283/// Creates k-fold cross-validation splits.
1284/// 
1285/// Generates `n_splits` train/test dataset pairs for cross-validation.
1286/// Each fold uses a different, non-overlapping portion as test set.
1287/// 
1288/// # Parameters
1289/// - `dataset`: Dataset to split
1290/// - `n_splits`: Number of folds (k)
1291/// - `shuffle`: Whether to shuffle before splitting
1292/// - `random_state`: Optional seed for reproducible shuffling
1293/// 
1294/// # Returns
1295/// Vector of (train_dataset, test_dataset) tuples, length = n_splits
1296/// 
1297    /// # Errors
1298/// - `DataLoaderError::ValidationError` if n_splits < 2
1299/// - `DataLoaderError::ValidationError` if dataset too small
1300pub fn cross_validation_split(
1301    dataset: &Dataset,
1302    n_splits: usize,
1303    shuffle: bool,
1304    _random_state: Option<u64>,
1305) -> DataLoaderResult<Vec<(Dataset, Dataset)>> {
1306    if n_splits < 2 {
1307        return Err(DataLoaderError::ValidationError(
1308            "n_splits must be at least 2".to_string(),
1309        ));
1310    }
1311
1312    if dataset.n_samples() < n_splits {
1313        return Err(DataLoaderError::ValidationError(
1314            "Number of samples must be at least n_splits".to_string(),
1315        ));
1316    }
1317
1318    let mut splits = Vec::new();
1319    let fold_size = dataset.n_samples() / n_splits;
1320
1321    // Create indices
1322    let mut indices: Vec<usize> = (0..dataset.n_samples()).collect();
1323    if shuffle {
1324        use rand::seq::SliceRandom;
1325        use rand::thread_rng;
1326        indices.shuffle(&mut thread_rng());
1327    }
1328
1329    for fold in 0..n_splits {
1330        let test_start = fold * fold_size;
1331        let test_end = if fold == n_splits - 1 {
1332            dataset.n_samples()
1333        } else {
1334            (fold + 1) * fold_size
1335        };
1336
1337        let test_indices: Vec<usize> = indices[test_start..test_end].to_vec();
1338        let train_indices: Vec<usize> = indices[0..test_start]
1339            .iter()
1340            .chain(&indices[test_end..])
1341            .cloned()
1342            .collect();
1343
1344        let train_dataset = dataset.select_samples(&train_indices)?;
1345        let test_dataset = dataset.select_samples(&test_indices)?;
1346
1347        splits.push((train_dataset, test_dataset));
1348    }
1349
1350    Ok(splits)
1351}
1352