Skip to main content

anofox_ml_io/
csv_reader.rs

1use anofox_ml_core::Float;
2use ndarray::{Array1, Array2};
3use std::path::Path;
4use std::str::FromStr;
5
6/// Result of reading a CSV file: (features, optional target, optional headers).
7pub type CsvReadResult<F> = Result<(Array2<F>, Option<Array1<F>>, Option<Vec<String>>), CsvError>;
8
/// Options controlling how [`read_csv`] parses a CSV file.
///
/// Construct with [`CsvReadOptions::new`] (or `Default::default()`) and refine
/// with the `with_*` builder methods. Options can be compared with `==`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CsvReadOptions {
    /// Whether the first row is a header (default: `true`).
    pub has_header: bool,
    /// Field delimiter byte (default: `b','`).
    pub delimiter: u8,
    /// Zero-based index of the column to use as the target variable (y). If
    /// `None`, no target is extracted and only the feature matrix X is
    /// returned.
    pub target_column: Option<usize>,
}
20
21impl Default for CsvReadOptions {
22    fn default() -> Self {
23        Self {
24            has_header: true,
25            delimiter: b',',
26            target_column: None,
27        }
28    }
29}
30
31impl CsvReadOptions {
32    /// Create default options.
33    pub fn new() -> Self {
34        Self::default()
35    }
36
37    /// Set whether the file has a header row.
38    pub fn with_header(mut self, has_header: bool) -> Self {
39        self.has_header = has_header;
40        self
41    }
42
43    /// Set the field delimiter.
44    pub fn with_delimiter(mut self, delimiter: u8) -> Self {
45        self.delimiter = delimiter;
46        self
47    }
48
49    /// Set the target column index (for supervised learning).
50    pub fn with_target_column(mut self, col: usize) -> Self {
51        self.target_column = Some(col);
52        self
53    }
54}
55
56/// Parse a single CSV record into a vector of floats.
57fn parse_record_to_floats<F: Float + FromStr>(
58    record: &csv::StringRecord,
59    row_idx: usize,
60) -> Result<Vec<F>, CsvError> {
61    record
62        .iter()
63        .enumerate()
64        .map(|(col_idx, field)| {
65            let trimmed = field.trim();
66            F::from_str(trimmed).map_err(|_| {
67                CsvError::Parse(format!(
68                    "cannot parse '{}' as float at row {}, col {}",
69                    trimmed, row_idx, col_idx
70                ))
71            })
72        })
73        .collect()
74}
75
76/// Validate that every row in `all_values` has exactly `n_cols` columns.
77fn validate_column_consistency<F: Float>(
78    all_values: &[Vec<F>],
79    n_cols: usize,
80) -> Result<(), CsvError> {
81    for (i, row) in all_values.iter().enumerate() {
82        if row.len() != n_cols {
83            return Err(CsvError::Parse(format!(
84                "row {} has {} columns, expected {}",
85                i,
86                row.len(),
87                n_cols
88            )));
89        }
90    }
91    Ok(())
92}
93
94/// Split the parsed values into a feature matrix and target vector by
95/// extracting the column at `target_col`.
96fn split_features_and_target<F: Float>(
97    all_values: Vec<Vec<F>>,
98    target_col: usize,
99    n_rows: usize,
100    n_cols: usize,
101) -> Result<(Array2<F>, Array1<F>), CsvError> {
102    if target_col >= n_cols {
103        return Err(CsvError::Parse(format!(
104            "target_column {} out of range (file has {} columns)",
105            target_col, n_cols
106        )));
107    }
108
109    let feature_cols = n_cols - 1;
110    let mut x_data = Vec::with_capacity(n_rows * feature_cols);
111    let mut y_data = Vec::with_capacity(n_rows);
112
113    for row in &all_values {
114        y_data.push(row[target_col]);
115        for (j, &val) in row.iter().enumerate() {
116            if j != target_col {
117                x_data.push(val);
118            }
119        }
120    }
121
122    let x = Array2::from_shape_vec((n_rows, feature_cols), x_data)
123        .map_err(|e| CsvError::Parse(e.to_string()))?;
124    let y = Array1::from_vec(y_data);
125
126    Ok((x, y))
127}
128
129/// Parse header names from the reader, if configured.
130fn parse_headers(
131    reader: &mut csv::Reader<std::fs::File>,
132    has_header: bool,
133) -> Result<Option<Vec<String>>, CsvError> {
134    if has_header {
135        Ok(Some(
136            reader
137                .headers()
138                .map_err(|e| CsvError::Io(e.to_string()))?
139                .iter()
140                .map(|s| s.to_string())
141                .collect(),
142        ))
143    } else {
144        Ok(None)
145    }
146}
147
148/// Read all CSV records into a vector of parsed float rows.
149fn read_all_records<F: Float + FromStr>(
150    reader: &mut csv::Reader<std::fs::File>,
151) -> Result<Vec<Vec<F>>, CsvError> {
152    let mut all_values: Vec<Vec<F>> = Vec::new();
153    for (row_idx, result) in reader.records().enumerate() {
154        let record = result.map_err(|e| CsvError::Parse(format!("row {}: {}", row_idx, e)))?;
155        all_values.push(parse_record_to_floats(&record, row_idx)?);
156    }
157    if all_values.is_empty() {
158        return Err(CsvError::Empty);
159    }
160    Ok(all_values)
161}
162
163/// Validate columns and assemble the final result, optionally splitting
164/// a target column out of the feature matrix.
165fn assemble_result<F: Float>(
166    all_values: Vec<Vec<F>>,
167    target_column: Option<usize>,
168    headers: Option<Vec<String>>,
169) -> CsvReadResult<F> {
170    let n_rows = all_values.len();
171    let n_cols = all_values[0].len();
172    validate_column_consistency(&all_values, n_cols)?;
173
174    match target_column {
175        Some(target_col) => {
176            let (x, y) = split_features_and_target(all_values, target_col, n_rows, n_cols)?;
177            Ok((x, Some(y), headers))
178        }
179        None => {
180            let flat: Vec<F> = all_values.into_iter().flatten().collect();
181            let x = Array2::from_shape_vec((n_rows, n_cols), flat)
182                .map_err(|e| CsvError::Parse(e.to_string()))?;
183            Ok((x, None, headers))
184        }
185    }
186}
187
188/// Read a CSV file into an ndarray feature matrix (and optionally a target vector).
189///
190/// Returns `(X, Option<y>, Option<header_names>)`.
191///
192/// - If `options.target_column` is set, that column is extracted as `y` and
193///   excluded from `X`.
194/// - If `options.has_header` is true, header names are returned.
195pub fn read_csv<F, P>(path: P, options: &CsvReadOptions) -> CsvReadResult<F>
196where
197    F: Float + FromStr,
198    P: AsRef<Path>,
199{
200    let mut reader = csv::ReaderBuilder::new()
201        .has_headers(options.has_header)
202        .delimiter(options.delimiter)
203        .from_path(path.as_ref())
204        .map_err(|e| CsvError::Io(e.to_string()))?;
205
206    let headers = parse_headers(&mut reader, options.has_header)?;
207    let all_values = read_all_records(&mut reader)?;
208    assemble_result(all_values, options.target_column, headers)
209}
210
211/// Convenience function: read a CSV file with headers, returning only the
212/// feature matrix and target vector.
213pub fn read_csv_with_header<F, P>(
214    path: P,
215    target_column: usize,
216) -> Result<(Array2<F>, Array1<F>), CsvError>
217where
218    F: Float + FromStr,
219    P: AsRef<Path>,
220{
221    let options = CsvReadOptions::new().with_target_column(target_column);
222    let (x, y, _) = read_csv(path, &options)?;
223    match y {
224        Some(y) => Ok((x, y)),
225        None => Err(CsvError::Parse("target_column should have been set".into())),
226    }
227}
228
/// Errors that can occur when reading CSV files.
///
/// The `Io` and `Parse` variants carry a human-readable description of the
/// failure. The type is `Clone` and comparable with `==`, which makes it easy
/// to assert on in tests.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum CsvError {
    /// I/O error (file not found, permission denied, etc.)
    Io(String),
    /// Parse error (invalid float, inconsistent columns, etc.)
    Parse(String),
    /// The CSV file contains no data rows.
    Empty,
}
239
240impl std::fmt::Display for CsvError {
241    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
242        match self {
243            CsvError::Io(msg) => write!(f, "CSV I/O error: {}", msg),
244            CsvError::Parse(msg) => write!(f, "CSV parse error: {}", msg),
245            CsvError::Empty => write!(f, "CSV file is empty"),
246        }
247    }
248}
249
250impl std::error::Error for CsvError {}
251
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use std::io::Write;

    /// Write `content` to a fresh temporary file and return its handle; the
    /// file lives as long as the returned handle.
    fn write_temp_csv(content: &str) -> tempfile::NamedTempFile {
        let mut file = tempfile::NamedTempFile::new().unwrap();
        file.write_all(content.as_bytes()).unwrap();
        file.flush().unwrap();
        file
    }

    #[test]
    fn test_read_csv_basic() {
        let csv = "a,b,c\n1.0,2.0,3.0\n4.0,5.0,6.0\n";
        let file = write_temp_csv(csv);
        let options = CsvReadOptions::new();
        let (x, y, headers): (Array2<f64>, _, _) = read_csv(file.path(), &options).unwrap();

        assert_eq!(x.shape(), &[2, 3]);
        assert_abs_diff_eq!(x[[0, 0]], 1.0);
        assert_abs_diff_eq!(x[[1, 2]], 6.0);
        assert!(y.is_none());
        assert_eq!(headers.unwrap(), vec!["a", "b", "c"]);
    }

    #[test]
    fn test_read_csv_with_target() {
        let csv = "f1,f2,label\n1.0,2.0,0.0\n3.0,4.0,1.0\n5.0,6.0,0.0\n";
        let file = write_temp_csv(csv);
        let options = CsvReadOptions::new().with_target_column(2);
        let (x, y, _): (Array2<f64>, _, _) = read_csv(file.path(), &options).unwrap();

        assert_eq!(x.shape(), &[3, 2]);
        assert_abs_diff_eq!(x[[0, 0]], 1.0);
        assert_abs_diff_eq!(x[[2, 1]], 6.0);

        let y = y.unwrap();
        assert_abs_diff_eq!(y[0], 0.0);
        assert_abs_diff_eq!(y[1], 1.0);
        assert_abs_diff_eq!(y[2], 0.0);
    }

    #[test]
    fn test_read_csv_no_header() {
        let csv = "1.0,2.0\n3.0,4.0\n";
        let file = write_temp_csv(csv);
        let options = CsvReadOptions::new().with_header(false);
        let (x, _, headers): (Array2<f64>, _, _) = read_csv(file.path(), &options).unwrap();

        assert_eq!(x.shape(), &[2, 2]);
        assert!(headers.is_none());
    }

    #[test]
    fn test_read_csv_semicolon_delimiter() {
        let csv = "a;b\n1.0;2.0\n3.0;4.0\n";
        let file = write_temp_csv(csv);
        let options = CsvReadOptions::new().with_delimiter(b';');
        let (x, _, _): (Array2<f64>, _, _) = read_csv(file.path(), &options).unwrap();

        assert_eq!(x.shape(), &[2, 2]);
        assert_abs_diff_eq!(x[[0, 1]], 2.0);
    }

    #[test]
    fn test_read_csv_convenience() {
        let csv = "f1,f2,y\n1.0,2.0,10.0\n3.0,4.0,20.0\n";
        let file = write_temp_csv(csv);
        let (x, y): (Array2<f64>, Array1<f64>) = read_csv_with_header(file.path(), 2).unwrap();

        assert_eq!(x.shape(), &[2, 2]);
        assert_abs_diff_eq!(y[0], 10.0);
        assert_abs_diff_eq!(y[1], 20.0);
    }

    #[test]
    fn test_read_csv_empty_file() {
        // Header only, no data rows.
        let csv = "a,b\n";
        let file = write_temp_csv(csv);
        let options = CsvReadOptions::new();
        let result: CsvReadResult<f64> = read_csv(file.path(), &options);
        assert!(matches!(result, Err(CsvError::Empty)));
    }

    #[test]
    fn test_read_csv_parse_error() {
        let csv = "a,b\n1.0,not_a_number\n";
        let file = write_temp_csv(csv);
        let options = CsvReadOptions::new();
        let result: CsvReadResult<f64> = read_csv(file.path(), &options);
        assert!(matches!(result, Err(CsvError::Parse(_))));
    }

    #[test]
    fn test_read_csv_target_column_out_of_range() {
        let csv = "a,b\n1.0,2.0\n";
        let file = write_temp_csv(csv);
        let options = CsvReadOptions::new().with_target_column(5);
        let result: CsvReadResult<f64> = read_csv(file.path(), &options);
        assert!(matches!(result, Err(CsvError::Parse(_))));
    }

    #[test]
    fn test_read_csv_ragged_rows() {
        let csv = "a,b\n1.0,2.0\n3.0\n";
        let file = write_temp_csv(csv);
        let options = CsvReadOptions::new();
        let result: CsvReadResult<f64> = read_csv(file.path(), &options);
        assert!(matches!(result, Err(CsvError::Parse(_))));
    }

    #[test]
    fn test_read_csv_trims_whitespace() {
        let csv = "a,b\n 1.5 , 2.5 \n";
        let file = write_temp_csv(csv);
        let options = CsvReadOptions::new();
        let (x, _, _): (Array2<f64>, _, _) = read_csv(file.path(), &options).unwrap();

        assert_eq!(x.shape(), &[1, 2]);
        assert_abs_diff_eq!(x[[0, 0]], 1.5);
        assert_abs_diff_eq!(x[[0, 1]], 2.5);
    }

    #[test]
    fn test_read_csv_headers_include_target() {
        // Headers are returned unfiltered even when a target column is split off.
        let csv = "f1,label\n1.0,0.0\n";
        let file = write_temp_csv(csv);
        let options = CsvReadOptions::new().with_target_column(1);
        let (x, y, headers): (Array2<f64>, _, _) = read_csv(file.path(), &options).unwrap();

        assert_eq!(x.shape(), &[1, 1]);
        assert_abs_diff_eq!(y.unwrap()[0], 0.0);
        assert_eq!(headers.unwrap(), vec!["f1", "label"]);
    }
}