use crate::rows::TestCase;
use std::path::Path;
/// Options controlling how the CSV readers in this module tokenise a file.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CsvParserOptions {
    /// Field separator byte, e.g. `b','` or `b';'`.
    pub delimiter: u8,
    /// When `true`, the first record is treated as a header row and is not
    /// returned as data.
    pub has_headers: bool,
}

impl Default for CsvParserOptions {
    /// Comma-delimited input whose first row is a header.
    fn default() -> Self {
        Self {
            delimiter: b',',
            has_headers: true,
        }
    }
}
/// Builds a `map_err` closure for a failure to open a CSV file at `$path`.
///
/// The resulting `std::io::Error` names the offending path. If the
/// `csv::Error` wraps an I/O error (e.g. the file does not exist or is not
/// readable), that error's `std::io::ErrorKind` is preserved, so callers can
/// distinguish a missing or unreadable file from bad data. Any other failure
/// is reported as `InvalidData`.
macro_rules! cannot_read_csv {
    ($path:expr) => {
        |e: csv::Error| {
            let kind = if let csv::ErrorKind::Io(io_err) = e.kind() {
                io_err.kind()
            } else {
                std::io::ErrorKind::InvalidData
            };
            std::io::Error::new(
                kind,
                format!("cannot read CSV {}: {}", $path.as_ref().display(), e),
            )
        }
    };
}
/// Builds a `map_err` closure for a CSV record that could not be read.
///
/// `$row_idx` is the zero-based record index; the message shows it one-based.
/// The produced error is always tagged `std::io::ErrorKind::InvalidData`.
macro_rules! cannot_parse_csv_row {
    ($row_idx:expr) => {
        |err| {
            let message = format!("cannot parse CSV row {}: {err}", $row_idx + 1);
            std::io::Error::new(std::io::ErrorKind::InvalidData, message)
        }
    };
}
/// Builds a `map_err` closure for a cell that failed to parse.
///
/// `$row_idx` and `$col_idx` are zero-based and shown one-based in the message.
/// `$type_name` is the name of the target type, so the message reports the type
/// that was actually requested instead of a fixed one.
macro_rules! cannot_parse_cell {
    ($row_idx:expr, $col_idx:expr, $type_name:expr) => {
        |e| {
            std::io::Error::new(
                std::io::ErrorKind::InvalidData,
                format!(
                    "cannot parse row {}, column {} as {}: {e}",
                    $row_idx + 1,
                    $col_idx + 1,
                    $type_name
                ),
            )
        }
    };
}

/// Parses a single CSV cell as `T`.
///
/// `row_idx` and `col_idx` are zero-based and are used only for the error
/// message.
///
/// # Errors
///
/// Returns an `InvalidData` error naming the one-based row and column and the
/// target type (via `std::any::type_name`) when `T::from_str` fails.
fn parse_cell<T>(cell: &str, row_idx: usize, col_idx: usize) -> std::io::Result<T>
where
    T: std::str::FromStr,
    <T as std::str::FromStr>::Err: std::fmt::Display,
{
    cell.parse::<T>().map_err(cannot_parse_cell!(
        row_idx,
        col_idx,
        std::any::type_name::<T>()
    ))
}
/// Reads every data record of the CSV file at `path` and parses each cell as `T`.
///
/// Each inner vector holds one record's cells in column order; the outer vector
/// is in file order. Tokenisation follows `options` (delimiter, header row).
///
/// # Errors
///
/// Returns an `std::io::Error` if the file cannot be opened, a record cannot be
/// read, or any cell fails to parse as `T`. Rows and columns are reported
/// one-based.
pub fn read_csv_vectors<T>(
    path: impl AsRef<Path>,
    options: &CsvParserOptions,
) -> std::io::Result<Vec<Vec<T>>>
where
    T: std::str::FromStr,
    <T as std::str::FromStr>::Err: std::fmt::Display,
{
    let mut reader = csv::ReaderBuilder::new()
        .has_headers(options.has_headers)
        .delimiter(options.delimiter)
        .from_path(path.as_ref())
        .map_err(cannot_read_csv!(path))?;
    // Short-circuits on the first failing record or cell, in file order.
    reader
        .records()
        .enumerate()
        .map(|(row_idx, row)| {
            let record = row.map_err(cannot_parse_csv_row!(row_idx))?;
            record
                .iter()
                .enumerate()
                .map(|(col_idx, cell)| parse_cell(cell, row_idx, col_idx))
                .collect::<std::io::Result<Vec<T>>>()
        })
        .collect()
}
/// Reads a CSV file in which each row is one test case.
///
/// Every column except the last is an input, and the last column is the single
/// expected value. Rows are numbered from 1 in error messages and count data
/// records only; a header row enabled by `options.has_headers` is not returned.
///
/// # Errors
///
/// Returns an `std::io::Error` if any of these happen:
/// - the file cannot be opened;
/// - a record cannot be read;
/// - a row has fewer than two columns;
/// - a cell fails to parse as `T`.
pub fn read_csv_cases<T>(
    path: impl AsRef<Path>,
    options: &CsvParserOptions,
) -> std::io::Result<Vec<TestCase<T>>>
where
    T: std::str::FromStr,
    <T as std::str::FromStr>::Err: std::fmt::Display,
{
    let mut reader = csv::ReaderBuilder::new()
        .delimiter(options.delimiter)
        .has_headers(options.has_headers)
        .from_path(path.as_ref())
        .map_err(cannot_read_csv!(path))?;
    let mut out = Vec::new();
    for (row_idx, row) in reader.records().enumerate() {
        let record = row.map_err(cannot_parse_csv_row!(row_idx))?;
        if record.len() < 2 {
            return Err(std::io::Error::new(
                std::io::ErrorKind::InvalidData,
                format!(
                    "row {} has {} column(s); expected at least 2",
                    row_idx + 1,
                    record.len()
                ),
            ));
        }
        // The check above guarantees at least one input column, and that
        // `expected_col` is a valid index, so no "missing value" path exists.
        let expected_col = record.len() - 1;
        let inputs = record
            .iter()
            .take(expected_col)
            .enumerate()
            .map(|(col_idx, cell)| parse_cell(cell, row_idx, col_idx))
            .collect::<std::io::Result<Vec<T>>>()?;
        let expected = parse_cell(&record[expected_col], row_idx, expected_col)?;
        out.push(TestCase {
            inputs,
            expected: vec![expected],
        });
    }
    Ok(out)
}
/// Builds test cases from two CSV files.
///
/// Row `i` of `inputs_path` supplies the inputs and row `i` of `expected_path`
/// supplies the expected values. Both files are read with the same `options`;
/// every column of each row is parsed as `T`.
///
/// # Errors
///
/// Returns an `std::io::Error` in two cases:
/// - Reading or parsing either file fails. The message is prefixed with
///   `inputs CSV:` or `expected CSV:` to show which file was at fault, and the
///   original error kind is kept.
/// - The two files have a different number of data rows (`InvalidData`).
pub fn read_split_csv_cases<T>(
    inputs_path: impl AsRef<Path>,
    expected_path: impl AsRef<Path>,
    options: &CsvParserOptions,
) -> std::io::Result<Vec<TestCase<T>>>
where
    T: std::str::FromStr,
    <T as std::str::FromStr>::Err: std::fmt::Display,
{
    // Row/column errors from `read_csv_vectors` do not name their file, so tag
    // them with the side that failed while keeping the original error kind.
    fn tag(side: &'static str) -> impl FnOnce(std::io::Error) -> std::io::Error {
        move |e| std::io::Error::new(e.kind(), format!("{side} CSV: {e}"))
    }

    let inputs_vectors = read_csv_vectors(inputs_path, options).map_err(tag("inputs"))?;
    let expected_vectors = read_csv_vectors(expected_path, options).map_err(tag("expected"))?;
    if inputs_vectors.len() != expected_vectors.len() {
        return Err(std::io::Error::new(
            std::io::ErrorKind::InvalidData,
            format!(
                "row count mismatch: inputs has {} rows, expected has {} rows",
                inputs_vectors.len(),
                expected_vectors.len()
            ),
        ));
    }
    Ok(inputs_vectors
        .into_iter()
        .zip(expected_vectors)
        .map(|(inputs, expected)| TestCase { inputs, expected })
        .collect())
}
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::tempdir;

    /// Options for header-less fixtures with the default delimiter.
    fn no_headers() -> CsvParserOptions {
        CsvParserOptions {
            has_headers: false,
            ..Default::default()
        }
    }

    #[test]
    fn read_accuracy_csv_last_column_expected() {
        let tmp = tempdir().expect("tempdir");
        let csv_path = tmp.path().join("cases.csv");
        std::fs::write(&csv_path, "x,y,expected\n1.0,2.0,3.0\n2.0,5.0,7.0\n").expect("write");
        let cases = read_csv_cases::<f64>(&csv_path, &CsvParserOptions::default()).expect("parse");
        assert_eq!(cases.len(), 2);
        assert_eq!(cases[0].inputs, vec![1.0, 2.0]);
        assert_eq!(cases[0].expected, vec![3.0]);
    }

    #[test]
    fn read_csv_empty_file() {
        let tmp = tempdir().expect("tempdir");
        let csv_path = tmp.path().join("empty.csv");
        std::fs::write(&csv_path, "").expect("write");
        let cases = read_csv_cases::<f64>(&csv_path, &CsvParserOptions::default()).expect("parse");
        assert_eq!(cases.len(), 0);
    }

    #[test]
    fn read_csv_only_headers() {
        let tmp = tempdir().expect("tempdir");
        let csv_path = tmp.path().join("headers.csv");
        std::fs::write(&csv_path, "col1,col2,expected\n").expect("write");
        let cases = read_csv_cases::<f64>(&csv_path, &CsvParserOptions::default()).expect("parse");
        assert_eq!(cases.len(), 0);
    }

    #[test]
    fn read_csv_invalid_numeric() {
        let tmp = tempdir().expect("tempdir");
        let csv_path = tmp.path().join("invalid.csv");
        std::fs::write(&csv_path, "1.0,abc,3.0\n").expect("write");
        let res = read_csv_cases::<f64>(&csv_path, &no_headers());
        assert!(res.is_err());
        assert!(res
            .unwrap_err()
            .to_string()
            .contains("cannot parse row 1, column 2"));
    }

    #[test]
    fn read_csv_single_column_rejected() {
        let tmp = tempdir().expect("tempdir");
        let csv_path = tmp.path().join("single.csv");
        std::fs::write(&csv_path, "1.0\n").expect("write");
        let err = read_csv_cases::<f64>(&csv_path, &no_headers())
            .expect_err("a single-column row cannot form a test case");
        assert!(err.to_string().contains("expected at least 2"));
    }

    #[test]
    fn read_csv_missing_file() {
        let tmp = tempdir().expect("tempdir");
        let csv_path = tmp.path().join("does_not_exist.csv");
        let err = read_csv_cases::<f64>(&csv_path, &CsvParserOptions::default())
            .expect_err("missing file must be reported");
        assert!(err.to_string().contains("cannot read CSV"));
    }

    #[test]
    fn read_csv_different_delimiter() {
        let tmp = tempdir().expect("tempdir");
        let csv_path = tmp.path().join("semi.csv");
        std::fs::write(&csv_path, "1.0;2.0;3.0\n").expect("write");
        let options = CsvParserOptions {
            delimiter: b';',
            has_headers: false,
        };
        let cases = read_csv_cases::<f64>(&csv_path, &options).expect("parse");
        assert_eq!(cases[0].inputs, vec![1.0, 2.0]);
        assert_eq!(cases[0].expected, vec![3.0]);
    }

    #[test]
    fn read_csv_pathological_numbers() {
        let tmp = tempdir().expect("tempdir");
        let csv_path = tmp.path().join("patho.csv");
        std::fs::write(&csv_path, "NaN,inf,-inf,1.0\n").expect("write");
        let cases = read_csv_cases::<f64>(&csv_path, &no_headers()).expect("parse");
        assert!(cases[0].inputs[0].is_nan());
        assert!(cases[0].inputs[1].is_infinite());
        assert!(cases[0].inputs[2].is_infinite());
        assert_eq!(cases[0].expected, vec![1.0]);
    }

    #[test]
    fn read_csv_generic_f32() {
        let tmp = tempdir().expect("tempdir");
        let csv_path = tmp.path().join("f32.csv");
        std::fs::write(&csv_path, "1.0,2.0\n").expect("write");
        let cases = read_csv_cases::<f32>(&csv_path, &no_headers()).expect("parse");
        assert_eq!(cases[0].inputs, vec![1.0f32]);
        assert_eq!(cases[0].expected, vec![2.0f32]);
    }

    #[test]
    fn read_csv_vectors_all_columns() {
        let tmp = tempdir().expect("tempdir");
        let csv_path = tmp.path().join("vectors.csv");
        std::fs::write(&csv_path, "a,b\n1,2\n3,4\n").expect("write");
        let rows =
            read_csv_vectors::<i64>(&csv_path, &CsvParserOptions::default()).expect("parse");
        assert_eq!(rows, vec![vec![1, 2], vec![3, 4]]);
    }

    #[test]
    fn test_read_separate_csvs() {
        let tmp = tempdir().expect("tempdir");
        let inputs_path = tmp.path().join("inputs.csv");
        let expected_path = tmp.path().join("expected.csv");
        std::fs::write(&inputs_path, "x,y\n1.0,2.0\n2.0,5.0\n").expect("write");
        std::fs::write(&expected_path, "expected\n3.0\n7.0\n").expect("write");
        let cases =
            read_split_csv_cases::<f64>(&inputs_path, &expected_path, &CsvParserOptions::default())
                .expect("parse");
        assert_eq!(cases.len(), 2);
        assert_eq!(cases[0].inputs, vec![1.0, 2.0]);
        assert_eq!(cases[0].expected, vec![3.0]);
    }

    #[test]
    fn test_read_separate_csvs_mismatch() {
        let tmp = tempdir().expect("tempdir");
        let inputs_path = tmp.path().join("inputs.csv");
        let expected_path = tmp.path().join("expected.csv");
        std::fs::write(&inputs_path, "x,y\n1.0,2.0\n2.0,5.0\n").expect("write");
        std::fs::write(&expected_path, "expected\n3.0\n").expect("write");
        let res =
            read_split_csv_cases::<f64>(&inputs_path, &expected_path, &CsvParserOptions::default());
        assert!(res.is_err());
        assert!(res.unwrap_err().to_string().contains("row count mismatch"));
    }

    #[test]
    fn test_read_separate_csvs_parse_error() {
        let tmp = tempdir().expect("tempdir");
        let inputs_path = tmp.path().join("inputs.csv");
        let expected_path = tmp.path().join("expected.csv");
        std::fs::write(&inputs_path, "x,y\n1.0,abc\n").expect("write");
        std::fs::write(&expected_path, "expected\n3.0\n").expect("write");
        let err =
            read_split_csv_cases::<f64>(&inputs_path, &expected_path, &CsvParserOptions::default())
                .expect_err("bad input cell must be reported");
        assert!(err.to_string().contains("cannot parse row 1, column 2"));
    }
}