ferrite_rs/csv_io/
mod.rs

1use csv::{ReaderBuilder, WriterBuilder};
2use ndarray::{Array2, Axis};
3use rand::seq::SliceRandom;
4use rand::thread_rng;
5use std::boxed::Box;
6use std::error::Error;
7use std::fs::File;
8use std::io::BufReader;
9
10/// Function to split dataset into train and test sets
11///
12/// # Parameters:
13/// - `x: Array2<f64>` - Feature matrix
14/// - `y: Array2<f64>` - Target matrix
15/// - `split_ratio: f64` - Ratio for the training set (e.g., 0.8 for 80% train, 20% test)
16///
17/// # Returns:
18/// - `Result<(Array2<f64>, Array2<f64>, Array2<f64>, Array2<f64>), Box<dyn Error>>`
19///   - Tuple containing (x_train, y_train, x_test, y_test)
20pub fn train_test_split(
21    x: Array2<f64>,
22    y: Array2<f64>,
23    split_ratio: f64,
24) -> Result<(Array2<f64>, Array2<f64>, Array2<f64>, Array2<f64>), Box<dyn Error>> {
25    if split_ratio <= 0.0 || split_ratio >= 1.0 {
26        return Err("Split ratio should be between 0 and 1".into());
27    }
28
29    let num_samples = x.nrows();
30    let num_train = (num_samples as f64 * split_ratio).round() as usize;
31
32    let mut indices: Vec<usize> = (0..num_samples).collect();
33    let mut rng = thread_rng();
34    indices.shuffle(&mut rng);
35
36    let train_indices = &indices[..num_train];
37    let test_indices = &indices[num_train..];
38
39    let x_train = x.select(Axis(0), train_indices);
40    let y_train = y.select(Axis(0), train_indices);
41    let x_test = x.select(Axis(0), test_indices);
42    let y_test = y.select(Axis(0), test_indices);
43
44    Ok((x_train, y_train, x_test, y_test))
45}
46
47/// Function to read and parse a CSV file without headers
48///
49/// # Parameters:
50/// - `filepath: String` - Relative path of the CSV file wrt Cargo.toml file of the project
51///
52/// # Returns:
53/// - `Result<Array2<f64>, Box<dyn Error>>` - 2D Array of the CSV file without headers
54pub fn read(filepath: String) -> Result<Array2<f64>, Box<dyn Error>> {
55    let file = File::open(filepath)?;
56    let reader = BufReader::new(file);
57
58    let mut csv_reader = ReaderBuilder::new()
59        .has_headers(false) // Ignore headers
60        .from_reader(reader);
61
62    let mut data: Vec<f64> = Vec::new();
63    let mut rows = 0;
64    let mut cols = 0;
65
66    for result in csv_reader.records() {
67        let record = result?;
68        if rows == 0 {
69            cols = record.len();
70        }
71        let row: Vec<f64> = record
72            .iter()
73            .map(|s| s.parse::<f64>().unwrap_or(f64::NAN)) // Convert to f64
74            .collect();
75        data.extend(row);
76        rows += 1;
77    }
78
79    let array = Array2::from_shape_vec((rows, cols), data)?;
80    Ok(array)
81}
82
83/// Function to parse CSV and extract input & output columns
84///
85/// # Parameters:
86/// - `filepath: String` - Relative path of the CSV file wrt Cargo.toml file of the project
87/// - `output_columns: Vec<String>` - Column names to extract as output
88/// - `input_exclude_columns: Vec<String>` - Column names to be excluded from input
89///
90/// # Returns:
91/// - `Result<(Array2<f64>, Array2<f64>), Box<dyn Error>>` - Tuple (Input Array, Output Array)
92pub fn read_input_output(
93    filepath: String,
94    output_columns: Vec<String>,
95    input_exclude_columns: Vec<String>,
96) -> Result<(Array2<f64>, Array2<f64>), Box<dyn Error>> {
97    let file = File::open(filepath)?;
98    let reader = BufReader::new(file);
99    let mut csv_reader = ReaderBuilder::new().has_headers(true).from_reader(reader);
100
101    let headers = csv_reader.headers()?.clone();
102    let mut output_indices = Vec::new();
103    let mut input_indices = Vec::new();
104
105    // Identify column indices for input and output
106    for (i, header) in headers.iter().enumerate() {
107        if output_columns.contains(&header.to_string()) {
108            output_indices.push(i);
109        } else if !input_exclude_columns.contains(&header.to_string()) {
110            input_indices.push(i);
111        }
112    }
113
114    let mut input_data: Vec<f64> = Vec::new();
115    let mut output_data: Vec<f64> = Vec::new();
116    let mut row_count = 0;
117    let input_cols = input_indices.len();
118    let output_cols = output_indices.len();
119
120    for result in csv_reader.records() {
121        let record = result?;
122        for &i in &input_indices {
123            input_data.push(record[i].parse::<f64>().unwrap_or(0.0));
124        }
125        for &i in &output_indices {
126            output_data.push(record[i].parse::<f64>().unwrap_or(0.0));
127        }
128        row_count += 1;
129    }
130
131    // Convert to heap-allocated Array2
132    let input_array = Array2::from_shape_vec((row_count, input_cols), input_data)
133        .map_err(|_| "Shape mismatch in input array")?;
134    let output_array = Array2::from_shape_vec((row_count, output_cols), output_data)
135        .map_err(|_| "Shape mismatch in output array")?;
136
137    Ok((input_array, output_array))
138}
139
140/// Function to save a 2D Array as a CSV file
141///
142/// # Parameters:
143/// - `filepath: String` - Relative path of the CSV file wrt Cargo.toml file of the project
144/// - `headers: Vec<String>` - Column names
145/// - `array: Array2<f64>` - The array to be saved
146///
147/// # Returns:
148/// - `Result<(), Box<dyn Error>>`
149pub fn csv_write(
150    filepath: String,
151    headers: Vec<String>,
152    array: &Array2<f64>,
153) -> Result<(), Box<dyn Error>> {
154    let mut writer = WriterBuilder::new().from_path(filepath)?;
155
156    // Write headers
157    // writer.write_record(&headers)?;
158
159    // Write data
160    for row in array.axis_iter(ndarray::Axis(0)) {
161        let row_strings: Vec<String> = row.iter().map(|&val| val.to_string()).collect();
162        writer.write_record(&row_strings)?;
163    }
164
165    writer.flush()?;
166    Ok(())
167}