use csv::{ReaderBuilder, StringRecord};
use smartcore::linalg::basic::matrix::DenseMatrix;
use std::error::Error;
use std::fmt;
use std::fs::File;
use std::num::ParseFloatError;
use std::path::Path;
/// Errors produced while loading numeric data from a CSV file.
#[derive(Debug)]
pub enum CsvError {
    /// An I/O failure, e.g. the input file could not be opened.
    Io(std::io::Error),
    /// The CSV reader failed on a record, or a cell could not be parsed as `f64`.
    Parse(Box<dyn Error + Send + Sync>),
    /// The data has an unusable shape (no rows, ragged rows, missing target/feature
    /// columns, or the matrix could not be constructed).
    Shape(String),
}

impl fmt::Display for CsvError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Each variant renders as "<category>: <payload>".
        let (label, detail): (&str, &dyn fmt::Display) = match self {
            Self::Io(inner) => ("I/O error", inner),
            Self::Parse(inner) => ("Parse error", inner),
            Self::Shape(message) => ("Shape error", message),
        };
        write!(f, "{label}: {detail}")
    }
}

impl Error for CsvError {
    /// Exposes the wrapped I/O or parse error; shape errors carry only a message.
    fn source(&self) -> Option<&(dyn Error + 'static)> {
        let cause: &(dyn Error + 'static) = match self {
            Self::Io(inner) => inner,
            Self::Parse(inner) => &**inner,
            Self::Shape(_) => return None,
        };
        Some(cause)
    }
}
/// Loads every data record of the CSV file at `path` as a row of `f64` features.
///
/// The first line is consumed as a header (see `build_csv_reader`); every column of each
/// remaining record is parsed as a feature, and all rows must have the same width.
///
/// # Errors
///
/// * [`CsvError::Io`] if the file cannot be opened.
/// * [`CsvError::Parse`] if the CSV reader fails on a record or a cell is not a valid `f64`.
/// * [`CsvError::Shape`] if there are no data rows, rows differ in width, or the matrix
///   cannot be built from the parsed rows.
pub fn load_csv_features<P: AsRef<Path>>(path: P) -> Result<DenseMatrix<f64>, CsvError> {
    // Width of the first accepted row; every later row is checked against it.
    let mut width: Option<usize> = None;

    // Collecting into `Result` stops at the first failing record, in file order.
    let rows = build_csv_reader(path.as_ref())?
        .records()
        .enumerate()
        .map(|(idx, outcome)| -> Result<Vec<f64>, CsvError> {
            let record = outcome.map_err(|err| CsvError::Parse(Box::new(err)))?;
            let values = parse_feature_row(&record, idx)?;
            ensure_consistent_width(&values, idx, &mut width)?;
            Ok(values)
        })
        .collect::<Result<Vec<Vec<f64>>, CsvError>>()?;

    if rows.is_empty() {
        return Err(CsvError::Shape("no rows found".to_string()));
    }

    DenseMatrix::from_2d_vec(&rows).map_err(|err| CsvError::Shape(err.to_string()))
}
/// Loads a labeled CSV file, splitting each record into features and a target value.
///
/// `target_col` is the 0-based column index of the target within each record. The remaining
/// columns, in their original order, form that row of the feature matrix. The first line is
/// consumed as a header (see `build_csv_reader`).
///
/// Returns the feature matrix together with the targets, one per row and in file order.
///
/// # Errors
///
/// * [`CsvError::Io`] if the file cannot be opened.
/// * [`CsvError::Parse`] if the CSV reader fails on a record or a cell is not a valid `f64`.
/// * [`CsvError::Shape`] if `target_col` is out of range for a row, a row has no feature
///   columns, rows differ in width, there are no data rows, or the matrix cannot be built.
pub fn load_labeled_csv<P: AsRef<Path>>(
    path: P,
    target_col: usize,
) -> Result<(DenseMatrix<f64>, Vec<f64>), CsvError> {
    let mut reader = build_csv_reader(path.as_ref())?;
    let mut width: Option<usize> = None;
    let mut rows: Vec<Vec<f64>> = Vec::new();
    let mut labels: Vec<f64> = Vec::new();

    for (idx, outcome) in reader.records().enumerate() {
        let record = match outcome {
            Ok(record) => record,
            Err(err) => return Err(CsvError::Parse(Box::new(err))),
        };
        let (row, label) = parse_labeled_row(&record, idx, target_col)?;
        ensure_consistent_width(&row, idx, &mut width)?;
        labels.push(label);
        rows.push(row);
    }

    if rows.is_empty() {
        return Err(CsvError::Shape("no rows found".to_string()));
    }

    DenseMatrix::from_2d_vec(&rows)
        .map(|matrix| (matrix, labels))
        .map_err(|err| CsvError::Shape(err.to_string()))
}
/// Opens `path` and wraps it in a CSV reader that treats the first line as a header.
///
/// The reader is `flexible`, so records of differing lengths are not rejected by the CSV
/// layer; callers validate widths themselves to produce row-specific error messages.
///
/// # Errors
///
/// Returns [`CsvError::Io`] if the file cannot be opened.
fn build_csv_reader(path: &Path) -> Result<csv::Reader<File>, CsvError> {
    let mut builder = ReaderBuilder::new();
    builder.has_headers(true).flexible(true);

    File::open(path)
        .map(|file| builder.from_reader(file))
        .map_err(CsvError::Io)
}
/// Parses every cell of `record` as an `f64` feature value, in column order.
///
/// `row_idx` is the 0-based data-record index; messages report it 1-based.
///
/// # Errors
///
/// * [`CsvError::Shape`] if the record has no fields.
/// * [`CsvError::Parse`] for the first cell (left to right) that is not a valid `f64`.
fn parse_feature_row(record: &StringRecord, row_idx: usize) -> Result<Vec<f64>, CsvError> {
    if record.is_empty() {
        let message = format!("row {}: expected at least one column", row_idx + 1);
        return Err(CsvError::Shape(message));
    }

    let mut values = Vec::with_capacity(record.len());
    for (col_idx, field) in record.iter().enumerate() {
        values.push(parse_numeric_field(field, row_idx, col_idx)?);
    }
    Ok(values)
}
/// Parses `record` into `(features, target)`, where the target is the cell at the 0-based
/// index `target_col` and the features are the remaining cells in their original order.
///
/// `row_idx` is the 0-based data-record index; messages report it 1-based.
///
/// # Errors
///
/// * [`CsvError::Shape`] if `target_col` is not a valid index for this record, or if the
///   record has no column besides the target.
/// * [`CsvError::Parse`] for the first cell (left to right, target included) that is not a
///   valid `f64`.
fn parse_labeled_row(
    record: &StringRecord,
    row_idx: usize,
    target_col: usize,
) -> Result<(Vec<f64>, f64), CsvError> {
    let width = record.len();

    if target_col >= width {
        return Err(CsvError::Shape(format!(
            "row {}: target column index {} out of bounds (row has {} columns)",
            row_idx + 1,
            target_col,
            width
        )));
    }
    if width < 2 {
        return Err(CsvError::Shape(format!(
            "row {}: expected at least one feature column in addition to the target",
            row_idx + 1
        )));
    }

    // Parse all cells in column order so the earliest bad cell is the one reported,
    // regardless of whether it is a feature or the target.
    let mut cells = record
        .iter()
        .enumerate()
        .map(|(col_idx, value)| parse_numeric_field(value, row_idx, col_idx))
        .collect::<Result<Vec<f64>, CsvError>>()?;

    // `target_col < width == cells.len()` was checked above, so this cannot panic.
    let target = cells.remove(target_col);
    Ok((cells, target))
}
/// Parses a single CSV cell as an `f64`, ignoring leading and trailing whitespace.
///
/// `row_idx` and `col_idx` are 0-based; the error reports them 1-based.
///
/// # Errors
///
/// Returns [`CsvError::Parse`] wrapping a `FloatParseError` (carrying the 1-based position
/// and the underlying `ParseFloatError`) if the trimmed cell is not a valid `f64`.
fn parse_numeric_field(value: &str, row_idx: usize, col_idx: usize) -> Result<f64, CsvError> {
    // `f64::from_str` rejects surrounding whitespace, so cells such as " 1.5" (common in
    // CSV written as `1.0, 2.0`) would otherwise fail with "invalid float literal".
    value.trim().parse::<f64>().map_err(|err: ParseFloatError| {
        CsvError::Parse(Box::new(FloatParseError::new(
            row_idx + 1,
            col_idx + 1,
            err,
        )))
    })
}
/// Checks that `row` is non-empty and as wide as every previously accepted row.
///
/// On the first call, when `expected_width` is `None`, it records `row.len()` as the
/// reference width. Later calls compare against that stored width.
///
/// # Errors
///
/// Returns [`CsvError::Shape`] if `row` is empty or its length differs from the
/// recorded width. `row_idx` is 0-based and reported 1-based.
fn ensure_consistent_width(
    row: &[f64],
    row_idx: usize,
    expected_width: &mut Option<usize>,
) -> Result<(), CsvError> {
    let found = row.len();
    if found == 0 {
        return Err(CsvError::Shape(format!(
            "row {}: expected at least one column",
            row_idx + 1
        )));
    }

    // The first row establishes the reference width, so it always matches itself.
    let width = *expected_width.get_or_insert(found);
    if width == found {
        Ok(())
    } else {
        Err(CsvError::Shape(format!(
            "row {}: expected {} columns but found {}",
            row_idx + 1,
            width,
            found
        )))
    }
}
/// A cell that failed to parse as `f64`, annotated with where it was found.
#[derive(Debug)]
struct FloatParseError {
    /// 1-based row number of the offending cell.
    row: usize,
    /// 1-based column number of the offending cell.
    column: usize,
    /// The error reported by `str::parse::<f64>`.
    source: ParseFloatError,
}

impl FloatParseError {
    /// Builds the error from a 1-based position and the underlying parse failure.
    fn new(row: usize, column: usize, source: ParseFloatError) -> Self {
        FloatParseError {
            row,
            column,
            source,
        }
    }
}

impl fmt::Display for FloatParseError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let FloatParseError {
            row,
            column,
            source,
        } = self;
        write!(
            f,
            "failed to parse float at row {}, column {}: {}",
            row, column, source
        )
    }
}

impl Error for FloatParseError {
    /// Exposes the underlying `ParseFloatError`.
    fn source(&self) -> Option<&(dyn Error + 'static)> {
        let cause: &(dyn Error + 'static) = &self.source;
        Some(cause)
    }
}