1use anofox_ml_core::Float;
2use ndarray::{Array1, Array2};
3use std::path::Path;
4use std::str::FromStr;
5
/// Result of [`read_csv`]: the feature matrix `x`, the target vector `y`
/// (`Some` only when a target column was requested), and the header names
/// (`Some` only when the file was read with a header row).
pub type CsvReadResult<F> = Result<(Array2<F>, Option<Array1<F>>, Option<Vec<String>>), CsvError>;
8
/// Options controlling how [`read_csv`] parses a file.
///
/// Defaults: a header row is expected, fields are separated by `,`, and no
/// target column is extracted.
#[derive(Debug, Clone)]
pub struct CsvReadOptions {
    /// Whether the first row holds column names rather than data.
    pub has_header: bool,
    /// Byte used to separate fields (e.g. `b','` or `b';'`).
    pub delimiter: u8,
    /// Zero-based index of the column to split off as the target vector.
    pub target_column: Option<usize>,
}

impl Default for CsvReadOptions {
    fn default() -> Self {
        CsvReadOptions {
            target_column: None,
            delimiter: b',',
            has_header: true,
        }
    }
}

impl CsvReadOptions {
    /// Create options with the default settings.
    pub fn new() -> Self {
        Default::default()
    }

    /// Set whether the file starts with a header row.
    pub fn with_header(self, has_header: bool) -> Self {
        Self { has_header, ..self }
    }

    /// Set the field delimiter byte.
    pub fn with_delimiter(self, delimiter: u8) -> Self {
        Self { delimiter, ..self }
    }

    /// Select the zero-based column to extract as the target vector.
    pub fn with_target_column(self, col: usize) -> Self {
        Self {
            target_column: Some(col),
            ..self
        }
    }
}
55
56fn parse_record_to_floats<F: Float + FromStr>(
58 record: &csv::StringRecord,
59 row_idx: usize,
60) -> Result<Vec<F>, CsvError> {
61 record
62 .iter()
63 .enumerate()
64 .map(|(col_idx, field)| {
65 let trimmed = field.trim();
66 F::from_str(trimmed).map_err(|_| {
67 CsvError::Parse(format!(
68 "cannot parse '{}' as float at row {}, col {}",
69 trimmed, row_idx, col_idx
70 ))
71 })
72 })
73 .collect()
74}
75
76fn validate_column_consistency<F: Float>(
78 all_values: &[Vec<F>],
79 n_cols: usize,
80) -> Result<(), CsvError> {
81 for (i, row) in all_values.iter().enumerate() {
82 if row.len() != n_cols {
83 return Err(CsvError::Parse(format!(
84 "row {} has {} columns, expected {}",
85 i,
86 row.len(),
87 n_cols
88 )));
89 }
90 }
91 Ok(())
92}
93
94fn split_features_and_target<F: Float>(
97 all_values: Vec<Vec<F>>,
98 target_col: usize,
99 n_rows: usize,
100 n_cols: usize,
101) -> Result<(Array2<F>, Array1<F>), CsvError> {
102 if target_col >= n_cols {
103 return Err(CsvError::Parse(format!(
104 "target_column {} out of range (file has {} columns)",
105 target_col, n_cols
106 )));
107 }
108
109 let feature_cols = n_cols - 1;
110 let mut x_data = Vec::with_capacity(n_rows * feature_cols);
111 let mut y_data = Vec::with_capacity(n_rows);
112
113 for row in &all_values {
114 y_data.push(row[target_col]);
115 for (j, &val) in row.iter().enumerate() {
116 if j != target_col {
117 x_data.push(val);
118 }
119 }
120 }
121
122 let x = Array2::from_shape_vec((n_rows, feature_cols), x_data)
123 .map_err(|e| CsvError::Parse(e.to_string()))?;
124 let y = Array1::from_vec(y_data);
125
126 Ok((x, y))
127}
128
129fn parse_headers(
131 reader: &mut csv::Reader<std::fs::File>,
132 has_header: bool,
133) -> Result<Option<Vec<String>>, CsvError> {
134 if has_header {
135 Ok(Some(
136 reader
137 .headers()
138 .map_err(|e| CsvError::Io(e.to_string()))?
139 .iter()
140 .map(|s| s.to_string())
141 .collect(),
142 ))
143 } else {
144 Ok(None)
145 }
146}
147
148fn read_all_records<F: Float + FromStr>(
150 reader: &mut csv::Reader<std::fs::File>,
151) -> Result<Vec<Vec<F>>, CsvError> {
152 let mut all_values: Vec<Vec<F>> = Vec::new();
153 for (row_idx, result) in reader.records().enumerate() {
154 let record = result.map_err(|e| CsvError::Parse(format!("row {}: {}", row_idx, e)))?;
155 all_values.push(parse_record_to_floats(&record, row_idx)?);
156 }
157 if all_values.is_empty() {
158 return Err(CsvError::Empty);
159 }
160 Ok(all_values)
161}
162
163fn assemble_result<F: Float>(
166 all_values: Vec<Vec<F>>,
167 target_column: Option<usize>,
168 headers: Option<Vec<String>>,
169) -> CsvReadResult<F> {
170 let n_rows = all_values.len();
171 let n_cols = all_values[0].len();
172 validate_column_consistency(&all_values, n_cols)?;
173
174 match target_column {
175 Some(target_col) => {
176 let (x, y) = split_features_and_target(all_values, target_col, n_rows, n_cols)?;
177 Ok((x, Some(y), headers))
178 }
179 None => {
180 let flat: Vec<F> = all_values.into_iter().flatten().collect();
181 let x = Array2::from_shape_vec((n_rows, n_cols), flat)
182 .map_err(|e| CsvError::Parse(e.to_string()))?;
183 Ok((x, None, headers))
184 }
185 }
186}
187
188pub fn read_csv<F, P>(path: P, options: &CsvReadOptions) -> CsvReadResult<F>
196where
197 F: Float + FromStr,
198 P: AsRef<Path>,
199{
200 let mut reader = csv::ReaderBuilder::new()
201 .has_headers(options.has_header)
202 .delimiter(options.delimiter)
203 .from_path(path.as_ref())
204 .map_err(|e| CsvError::Io(e.to_string()))?;
205
206 let headers = parse_headers(&mut reader, options.has_header)?;
207 let all_values = read_all_records(&mut reader)?;
208 assemble_result(all_values, options.target_column, headers)
209}
210
211pub fn read_csv_with_header<F, P>(
214 path: P,
215 target_column: usize,
216) -> Result<(Array2<F>, Array1<F>), CsvError>
217where
218 F: Float + FromStr,
219 P: AsRef<Path>,
220{
221 let options = CsvReadOptions::new().with_target_column(target_column);
222 let (x, y, _) = read_csv(path, &options)?;
223 match y {
224 Some(y) => Ok((x, y)),
225 None => Err(CsvError::Parse("target_column should have been set".into())),
226 }
227}
228
/// Errors produced while reading a CSV file.
#[derive(Debug)]
pub enum CsvError {
    /// The file could not be opened, or its header row could not be read.
    Io(String),
    /// A record was malformed, a field was not numeric, the rows differed in
    /// length, or the data could not be shaped into arrays.
    Parse(String),
    /// The file contained no data rows.
    Empty,
}

impl std::fmt::Display for CsvError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Io(detail) => write!(f, "CSV I/O error: {}", detail),
            Self::Parse(detail) => write!(f, "CSV parse error: {}", detail),
            Self::Empty => f.write_str("CSV file is empty"),
        }
    }
}

// Variants carry only messages, never an underlying error value, so the
// default `source()` (returning `None`) is the right behavior.
impl std::error::Error for CsvError {}
251
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use std::io::Write;

    /// Write `content` to a fresh temporary file and return its handle.
    fn write_temp_csv(content: &str) -> tempfile::NamedTempFile {
        let mut file = tempfile::NamedTempFile::new().unwrap();
        file.write_all(content.as_bytes()).unwrap();
        file.flush().unwrap();
        file
    }

    #[test]
    fn test_read_csv_basic() {
        let csv = "a,b,c\n1.0,2.0,3.0\n4.0,5.0,6.0\n";
        let file = write_temp_csv(csv);
        let options = CsvReadOptions::new();
        let (x, y, headers): (Array2<f64>, _, _) = read_csv(file.path(), &options).unwrap();

        assert_eq!(x.shape(), &[2, 3]);
        assert_abs_diff_eq!(x[[0, 0]], 1.0);
        assert_abs_diff_eq!(x[[1, 2]], 6.0);
        assert!(y.is_none());
        assert_eq!(headers.unwrap(), vec!["a", "b", "c"]);
    }

    #[test]
    fn test_read_csv_with_target() {
        let csv = "f1,f2,label\n1.0,2.0,0.0\n3.0,4.0,1.0\n5.0,6.0,0.0\n";
        let file = write_temp_csv(csv);
        let options = CsvReadOptions::new().with_target_column(2);
        let (x, y, _): (Array2<f64>, _, _) = read_csv(file.path(), &options).unwrap();

        assert_eq!(x.shape(), &[3, 2]);
        assert_abs_diff_eq!(x[[0, 0]], 1.0);
        assert_abs_diff_eq!(x[[2, 1]], 6.0);

        let y = y.unwrap();
        assert_abs_diff_eq!(y[0], 0.0);
        assert_abs_diff_eq!(y[1], 1.0);
        assert_abs_diff_eq!(y[2], 0.0);
    }

    #[test]
    fn test_read_csv_target_in_first_column() {
        let csv = "y,a,b\n1.0,2.0,3.0\n4.0,5.0,6.0\n";
        let file = write_temp_csv(csv);
        let options = CsvReadOptions::new().with_target_column(0);
        let (x, y, _): (Array2<f64>, _, _) = read_csv(file.path(), &options).unwrap();

        assert_eq!(x.shape(), &[2, 2]);
        assert_abs_diff_eq!(x[[0, 0]], 2.0);
        assert_abs_diff_eq!(x[[0, 1]], 3.0);
        assert_abs_diff_eq!(x[[1, 1]], 6.0);

        let y = y.unwrap();
        assert_abs_diff_eq!(y[0], 1.0);
        assert_abs_diff_eq!(y[1], 4.0);
    }

    #[test]
    fn test_read_csv_no_header() {
        let csv = "1.0,2.0\n3.0,4.0\n";
        let file = write_temp_csv(csv);
        let options = CsvReadOptions::new().with_header(false);
        let (x, _, headers): (Array2<f64>, _, _) = read_csv(file.path(), &options).unwrap();

        assert_eq!(x.shape(), &[2, 2]);
        assert!(headers.is_none());
    }

    #[test]
    fn test_read_csv_semicolon_delimiter() {
        let csv = "a;b\n1.0;2.0\n3.0;4.0\n";
        let file = write_temp_csv(csv);
        let options = CsvReadOptions::new().with_delimiter(b';');
        let (x, _, _): (Array2<f64>, _, _) = read_csv(file.path(), &options).unwrap();

        assert_eq!(x.shape(), &[2, 2]);
        assert_abs_diff_eq!(x[[0, 1]], 2.0);
    }

    #[test]
    fn test_read_csv_trims_fields() {
        let csv = "a,b\n 1.5 , 2.5 \n";
        let file = write_temp_csv(csv);
        let options = CsvReadOptions::new();
        let (x, _, _): (Array2<f64>, _, _) = read_csv(file.path(), &options).unwrap();

        assert_eq!(x.shape(), &[1, 2]);
        assert_abs_diff_eq!(x[[0, 0]], 1.5);
        assert_abs_diff_eq!(x[[0, 1]], 2.5);
    }

    #[test]
    fn test_read_csv_convenience() {
        let csv = "f1,f2,y\n1.0,2.0,10.0\n3.0,4.0,20.0\n";
        let file = write_temp_csv(csv);
        let (x, y): (Array2<f64>, Array1<f64>) = read_csv_with_header(file.path(), 2).unwrap();

        assert_eq!(x.shape(), &[2, 2]);
        assert_abs_diff_eq!(y[0], 10.0);
        assert_abs_diff_eq!(y[1], 20.0);
    }

    #[test]
    fn test_read_csv_with_header_bad_target() {
        let file = write_temp_csv("a,b\n1.0,2.0\n");
        let result: Result<(Array2<f64>, Array1<f64>), _> = read_csv_with_header(file.path(), 2);
        assert!(matches!(result, Err(CsvError::Parse(_))));
    }

    #[test]
    fn test_read_csv_empty_file() {
        let csv = "a,b\n";
        let file = write_temp_csv(csv);
        let options = CsvReadOptions::new();
        let result: Result<(Array2<f64>, _, _), _> = read_csv(file.path(), &options);
        assert!(matches!(result, Err(CsvError::Empty)));
    }

    #[test]
    fn test_read_csv_parse_error() {
        let csv = "a,b\n1.0,not_a_number\n";
        let file = write_temp_csv(csv);
        let options = CsvReadOptions::new();
        let result: Result<(Array2<f64>, _, _), _> = read_csv(file.path(), &options);
        assert!(matches!(result, Err(CsvError::Parse(_))));
    }

    #[test]
    fn test_read_csv_target_out_of_range() {
        let file = write_temp_csv("a,b\n1.0,2.0\n");
        let options = CsvReadOptions::new().with_target_column(5);
        let result: Result<(Array2<f64>, _, _), _> = read_csv(file.path(), &options);
        assert!(matches!(result, Err(CsvError::Parse(_))));
    }

    #[test]
    fn test_read_csv_ragged_rows() {
        let file = write_temp_csv("a,b\n1.0,2.0\n3.0\n");
        let options = CsvReadOptions::new();
        let result: Result<(Array2<f64>, _, _), _> = read_csv(file.path(), &options);
        assert!(matches!(result, Err(CsvError::Parse(_))));
    }

    #[test]
    fn test_read_csv_missing_file() {
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("missing.csv");
        let options = CsvReadOptions::new();
        let result: Result<(Array2<f64>, _, _), _> = read_csv(&path, &options);
        assert!(matches!(result, Err(CsvError::Io(_))));
    }

    #[test]
    fn test_csv_error_display() {
        assert_eq!(CsvError::Empty.to_string(), "CSV file is empty");
        assert_eq!(CsvError::Parse("bad".into()).to_string(), "CSV parse error: bad");
        assert_eq!(CsvError::Io("gone".into()).to_string(), "CSV I/O error: gone");
    }
}