1use csv::{ReaderBuilder, WriterBuilder};
2use ndarray::{Array2, Axis};
3use rand::seq::SliceRandom;
4use rand::thread_rng;
5use std::boxed::Box;
6use std::error::Error;
7use std::fs::File;
8use std::io::BufReader;
9
10pub fn train_test_split(
21 x: Array2<f64>,
22 y: Array2<f64>,
23 split_ratio: f64,
24) -> Result<(Array2<f64>, Array2<f64>, Array2<f64>, Array2<f64>), Box<dyn Error>> {
25 if split_ratio <= 0.0 || split_ratio >= 1.0 {
26 return Err("Split ratio should be between 0 and 1".into());
27 }
28
29 let num_samples = x.nrows();
30 let num_train = (num_samples as f64 * split_ratio).round() as usize;
31
32 let mut indices: Vec<usize> = (0..num_samples).collect();
33 let mut rng = thread_rng();
34 indices.shuffle(&mut rng);
35
36 let train_indices = &indices[..num_train];
37 let test_indices = &indices[num_train..];
38
39 let x_train = x.select(Axis(0), train_indices);
40 let y_train = y.select(Axis(0), train_indices);
41 let x_test = x.select(Axis(0), test_indices);
42 let y_test = y.select(Axis(0), test_indices);
43
44 Ok((x_train, y_train, x_test, y_test))
45}
46
47pub fn read(filepath: String) -> Result<Array2<f64>, Box<dyn Error>> {
55 let file = File::open(filepath)?;
56 let reader = BufReader::new(file);
57
58 let mut csv_reader = ReaderBuilder::new()
59 .has_headers(false) .from_reader(reader);
61
62 let mut data: Vec<f64> = Vec::new();
63 let mut rows = 0;
64 let mut cols = 0;
65
66 for result in csv_reader.records() {
67 let record = result?;
68 if rows == 0 {
69 cols = record.len();
70 }
71 let row: Vec<f64> = record
72 .iter()
73 .map(|s| s.parse::<f64>().unwrap_or(f64::NAN)) .collect();
75 data.extend(row);
76 rows += 1;
77 }
78
79 let array = Array2::from_shape_vec((rows, cols), data)?;
80 Ok(array)
81}
82
83pub fn read_input_output(
93 filepath: String,
94 output_columns: Vec<String>,
95 input_exclude_columns: Vec<String>,
96) -> Result<(Array2<f64>, Array2<f64>), Box<dyn Error>> {
97 let file = File::open(filepath)?;
98 let reader = BufReader::new(file);
99 let mut csv_reader = ReaderBuilder::new().has_headers(true).from_reader(reader);
100
101 let headers = csv_reader.headers()?.clone();
102 let mut output_indices = Vec::new();
103 let mut input_indices = Vec::new();
104
105 for (i, header) in headers.iter().enumerate() {
107 if output_columns.contains(&header.to_string()) {
108 output_indices.push(i);
109 } else if !input_exclude_columns.contains(&header.to_string()) {
110 input_indices.push(i);
111 }
112 }
113
114 let mut input_data: Vec<f64> = Vec::new();
115 let mut output_data: Vec<f64> = Vec::new();
116 let mut row_count = 0;
117 let input_cols = input_indices.len();
118 let output_cols = output_indices.len();
119
120 for result in csv_reader.records() {
121 let record = result?;
122 for &i in &input_indices {
123 input_data.push(record[i].parse::<f64>().unwrap_or(0.0));
124 }
125 for &i in &output_indices {
126 output_data.push(record[i].parse::<f64>().unwrap_or(0.0));
127 }
128 row_count += 1;
129 }
130
131 let input_array = Array2::from_shape_vec((row_count, input_cols), input_data)
133 .map_err(|_| "Shape mismatch in input array")?;
134 let output_array = Array2::from_shape_vec((row_count, output_cols), output_data)
135 .map_err(|_| "Shape mismatch in output array")?;
136
137 Ok((input_array, output_array))
138}
139
140pub fn csv_write(
150 filepath: String,
151 headers: Vec<String>,
152 array: &Array2<f64>,
153) -> Result<(), Box<dyn Error>> {
154 let mut writer = WriterBuilder::new().from_path(filepath)?;
155
156 for row in array.axis_iter(ndarray::Axis(0)) {
161 let row_strings: Vec<String> = row.iter().map(|&val| val.to_string()).collect();
162 writer.write_record(&row_strings)?;
163 }
164
165 writer.flush()?;
166 Ok(())
167}