use super::{Dataset, DatasetError, DatasetMetadata};
use csv::{ReaderBuilder, StringRecord, Trim};
use nalgebra::DVector;
use std::path::Path;
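/// Configurable loader that turns CSV data into a [`Dataset`] of `f64`
/// feature and target vectors.
///
/// A minimal usage sketch (the file path and the single-target setup are
/// illustrative, not part of the API):
///
/// ```ignore
/// let dataset = CsvLoader::new()
///     .with_headers(true)
///     .with_last_n_targets(1)
///     .from_path("data/train.csv")
///     .await?;
/// ```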
#[derive(Debug, Clone)]
pub struct CsvLoader {
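    /// Whether the first row contains column names.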
pub has_headers: bool,
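    /// Field delimiter byte, e.g. `b','` or `b'\t'`.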
pub delimiter: u8,
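    /// Which columns, if any, are treated as targets rather than features.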
pub target_columns: TargetColumns,
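    /// Column names to exclude from the feature matrix (requires headers).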
pub skip_columns: Vec<String>,
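    /// Upper bound on the number of data rows to read; `None` reads everything.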
pub max_rows: Option<usize>,
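    /// Reserved for automatic type inference; not consulted by the loading
    /// path in this module.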
pub infer_types: bool,
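    /// Tokens interpreted as missing values and imputed as `0.0`.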
pub missing_values: Vec<String>,
}
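/// Strategy for selecting target columns: none, explicit zero-based indices,
/// header names, or the last `n` columns.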
#[derive(Debug, Clone)]
pub enum TargetColumns {
None,
Indices(Vec<usize>),
Names(Vec<String>),
Last(usize),
}
impl Default for CsvLoader {
fn default() -> Self {
Self {
has_headers: true,
delimiter: b',',
target_columns: TargetColumns::None,
skip_columns: Vec::new(),
max_rows: None,
infer_types: true,
missing_values: vec![
"".to_string(),
"NA".to_string(),
"N/A".to_string(),
"null".to_string(),
"NULL".to_string(),
"nan".to_string(),
"NaN".to_string(),
],
}
}
}
impl CsvLoader {
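    /// Creates a loader with the default configuration.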
pub fn new() -> Self {
Self::default()
}
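    /// Sets whether the first row is treated as a header row.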
pub fn with_headers(mut self, has_headers: bool) -> Self {
self.has_headers = has_headers;
self
}
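    /// Sets the field delimiter, e.g. `b'\t'` for tab-separated files.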
pub fn with_delimiter(mut self, delimiter: u8) -> Self {
self.delimiter = delimiter;
self
}
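    /// Selects target columns by zero-based index.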
pub fn with_target_indices(mut self, indices: Vec<usize>) -> Self {
self.target_columns = TargetColumns::Indices(indices);
self
}
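    /// Selects target columns by header name; requires headers to be present.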
pub fn with_target_names(mut self, names: Vec<String>) -> Self {
self.target_columns = TargetColumns::Names(names);
self
}
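    /// Treats the last `n` columns as targets.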
pub fn with_last_n_targets(mut self, n: usize) -> Self {
self.target_columns = TargetColumns::Last(n);
self
}
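    /// Excludes the named columns from the feature matrix; only effective
    /// when headers are present.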
pub fn with_skip_columns(mut self, skip_columns: Vec<String>) -> Self {
self.skip_columns = skip_columns;
self
}
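    /// Limits the number of data rows read from the source.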
pub fn with_max_rows(mut self, max_rows: usize) -> Self {
self.max_rows = Some(max_rows);
self
}
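    /// Reads the file at `path` asynchronously and parses it as CSV.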
pub async fn from_path<P: AsRef<Path>>(&self, path: P) -> Result<Dataset, DatasetError> {
let content = tokio::fs::read_to_string(path.as_ref()).await?;
self.from_string(&content).await
}
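    /// Parses CSV data held in memory into a [`Dataset`].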
pub async fn from_string(&self, data: &str) -> Result<Dataset, DatasetError> {
use std::io::Cursor;
let cursor = Cursor::new(data.as_bytes());
let mut csv_reader = ReaderBuilder::new()
.has_headers(self.has_headers)
.delimiter(self.delimiter)
.trim(Trim::All)
.from_reader(cursor);
let headers = if self.has_headers {
match csv_reader.headers() {
Ok(h) => Some(h.clone()),
Err(e) => return Err(DatasetError::CsvError(format!("Failed to read headers: {}", e))),
}
} else {
None
};
        // Analyze columns on a separate reader over the same data so that,
        // when the file has no header row, the first data record is not
        // consumed before `process_records` runs.
        let mut probe_reader = ReaderBuilder::new()
            .has_headers(self.has_headers)
            .delimiter(self.delimiter)
            .trim(Trim::All)
            .from_reader(Cursor::new(data.as_bytes()));
        let column_info = self.analyze_columns(&headers, &mut probe_reader)?;
let (features, targets, _stats) = self.process_records(csv_reader, &column_info).await?;
let metadata = DatasetMetadata {
sample_count: features.len(),
feature_count: column_info.feature_columns.len(),
target_count: if targets.is_some() { Some(column_info.target_columns.len()) } else { None },
source: None,
data_type: Some("CSV".to_string()),
};
let mut dataset = Dataset::new(features, targets);
dataset.feature_names = Some(column_info.feature_names);
dataset.target_names = column_info.target_names;
dataset.metadata = metadata;
Ok(dataset)
}
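    /// Resolves feature, target, and skipped column indices plus their names.
    /// When no headers are present, one record is read from `csv_reader` to
    /// determine the column count.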
    fn analyze_columns<R>(
        &self,
        headers: &Option<StringRecord>,
        csv_reader: &mut csv::Reader<R>,
    ) -> Result<ColumnInfo, DatasetError>
where
R: std::io::Read,
{
let total_columns = if let Some(ref headers) = headers {
headers.len()
} else {
let mut record = StringRecord::new();
if let Err(e) = csv_reader.read_record(&mut record) {
return Err(DatasetError::CsvError(format!("Failed to read CSV record: {}", e)));
}
if record.is_empty() {
return Err(DatasetError::Validation("Empty CSV file".to_string()));
}
record.len()
};
let target_indices = match &self.target_columns {
TargetColumns::None => Vec::new(),
TargetColumns::Indices(indices) => indices.clone(),
TargetColumns::Names(names) => {
if let Some(ref headers) = headers {
let mut indices = Vec::new();
for name in names {
if let Some(index) = headers.iter().position(|h| h == name) {
indices.push(index);
} else {
return Err(DatasetError::Configuration(
format!("Target column '{}' not found in headers", name)
));
}
}
indices
} else {
return Err(DatasetError::Configuration(
"Cannot use target column names without headers".to_string()
));
}
},
TargetColumns::Last(n) => {
if *n > total_columns {
return Err(DatasetError::Configuration(
format!("Cannot use last {} columns as targets, only {} columns available", n, total_columns)
));
}
(total_columns - n..total_columns).collect()
},
};
let skip_indices: Vec<usize> = if !self.skip_columns.is_empty() {
if let Some(ref headers) = headers {
let mut indices = Vec::new();
for name in &self.skip_columns {
if let Some(index) = headers.iter().position(|h| h == name) {
indices.push(index);
}
}
indices
            } else {
                // Skipping columns by name requires headers; without them the
                // request is ignored.
                Vec::new()
            }
} else {
Vec::new()
};
let mut feature_indices = Vec::new();
for i in 0..total_columns {
if !target_indices.contains(&i) && !skip_indices.contains(&i) {
feature_indices.push(i);
}
}
        let feature_names = if let Some(ref headers) = headers {
            feature_indices
                .iter()
                .map(|&i| {
                    headers
                        .get(i)
                        .map(str::to_string)
                        .unwrap_or_else(|| format!("feature_{}", i))
                })
                .collect()
        } else {
            feature_indices
                .iter()
                .map(|i| format!("feature_{}", i))
                .collect()
        };
        let target_names = if !target_indices.is_empty() {
            Some(if let Some(ref headers) = headers {
                target_indices
                    .iter()
                    .map(|&i| {
                        headers
                            .get(i)
                            .map(str::to_string)
                            .unwrap_or_else(|| format!("target_{}", i))
                    })
                    .collect()
            } else {
                target_indices
                    .iter()
                    .map(|i| format!("target_{}", i))
                    .collect()
            })
        } else {
            None
        };
Ok(ColumnInfo {
feature_columns: feature_indices,
target_columns: target_indices,
feature_names,
target_names,
})
}
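    /// Parses every record into feature (and optional target) vectors,
    /// yielding to the async runtime every 10,000 rows.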
    async fn process_records<R>(
        &self,
        mut csv_reader: csv::Reader<R>,
        column_info: &ColumnInfo,
    ) -> Result<(Vec<DVector<f64>>, Option<Vec<DVector<f64>>>, DatasetStats), DatasetError>
where
R: std::io::Read,
{
let mut features = Vec::new();
let mut targets = if column_info.target_columns.is_empty() { None } else { Some(Vec::new()) };
let mut record = StringRecord::new();
let mut row_count = 0;
        loop {
            if let Some(max_rows) = self.max_rows {
                if row_count >= max_rows {
                    break;
                }
            }
            match csv_reader.read_record(&mut record) {
                Ok(true) => {}
                Ok(false) => break,
                Err(e) => {
                    return Err(DatasetError::CsvError(format!(
                        "Failed to read CSV record: {}",
                        e
                    )))
                }
            }
let mut feature_values = Vec::with_capacity(column_info.feature_columns.len());
for &col_idx in &column_info.feature_columns {
let value = record.get(col_idx).unwrap_or("").trim();
let parsed_value = self.parse_value(value)?;
feature_values.push(parsed_value);
}
features.push(DVector::from_vec(feature_values));
if let Some(ref mut targets_vec) = targets {
let mut target_values = Vec::with_capacity(column_info.target_columns.len());
for &col_idx in &column_info.target_columns {
let value = record.get(col_idx).unwrap_or("").trim();
let parsed_value = self.parse_value(value)?;
target_values.push(parsed_value);
}
targets_vec.push(DVector::from_vec(target_values));
}
row_count += 1;
if row_count % 10000 == 0 {
tokio::task::yield_now().await;
}
}
Ok((features, targets, DatasetStats {}))
}
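    /// Parses a single cell: missing-value tokens become `0.0`, numeric
    /// strings are parsed as `f64`, and common boolean encodings map to
    /// `1.0`/`0.0`.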
fn parse_value(&self, value: &str) -> Result<f64, DatasetError> {
if self.missing_values.iter().any(|mv| mv == value) {
            // Missing values are imputed as 0.0.
            Ok(0.0)
        } else if let Ok(parsed) = value.parse::<f64>() {
Ok(parsed)
} else {
            // Plain "1"/"0" are already handled by the numeric parse above,
            // so only word-like boolean encodings need to be matched here.
            match value.to_lowercase().as_str() {
                "true" | "yes" | "y" => Ok(1.0),
                "false" | "no" | "n" => Ok(0.0),
_ => Err(DatasetError::Parse(
format!("Cannot parse value '{}' as number", value)
))
}
}
}
}
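/// Resolved column layout produced by `analyze_columns`.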
#[derive(Debug)]
struct ColumnInfo {
feature_columns: Vec<usize>,
target_columns: Vec<usize>,
feature_names: Vec<String>,
target_names: Option<Vec<String>>,
}
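/// Placeholder for per-load statistics; currently carries no data.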
#[derive(Debug)]
struct DatasetStats {}