1use std::fs;
8use std::path::Path;
9
10use crate::dataset::{Dataset, Sample};
11
12#[derive(Debug)]
27pub struct CsvDataset {
28 samples: Vec<Sample>,
29 feature_shape: Vec<usize>,
30 target_shape: Vec<usize>,
31}
32
/// Options controlling how delimited text is turned into samples.
///
/// Build one with [`CsvConfig::default`] and refine it with the chained
/// setter methods; each setter consumes the config and returns the updated one.
#[derive(Debug, Clone)]
pub struct CsvConfig {
    /// Whether the first non-blank line is a header row to be skipped.
    pub has_header: bool,
    /// Zero-based indices of the feature columns; empty means "use the loader's default".
    pub feature_cols: Vec<usize>,
    /// Zero-based indices of the target columns; empty means "use the loader's default".
    pub target_cols: Vec<usize>,
    /// Field separator, as a single byte.
    pub delimiter: u8,
}

impl Default for CsvConfig {
    /// Comma-separated input with a header row and no explicit column selection.
    fn default() -> Self {
        let delimiter = b',';
        let has_header = true;
        CsvConfig {
            has_header,
            feature_cols: Vec::new(),
            target_cols: Vec::new(),
            delimiter,
        }
    }
}

impl CsvConfig {
    /// Sets whether the input starts with a header row.
    pub fn has_header(self, h: bool) -> Self {
        Self {
            has_header: h,
            ..self
        }
    }

    /// Replaces the list of feature column indices.
    pub fn feature_cols(self, cols: Vec<usize>) -> Self {
        Self {
            feature_cols: cols,
            ..self
        }
    }

    /// Replaces the list of target column indices.
    pub fn target_cols(self, cols: Vec<usize>) -> Self {
        Self {
            target_cols: cols,
            ..self
        }
    }

    /// Sets the field separator byte.
    pub fn delimiter(self, d: u8) -> Self {
        Self {
            delimiter: d,
            ..self
        }
    }
}
75
76impl CsvDataset {
77 pub fn load<P: AsRef<Path>>(path: P, config: CsvConfig) -> Result<Self, String> {
79 let content = fs::read_to_string(path.as_ref())
80 .map_err(|e| format!("CsvDataset: failed to read {:?}: {}", path.as_ref(), e))?;
81 Self::from_string(&content, config)
82 }
83
84 pub fn from_string(content: &str, config: CsvConfig) -> Result<Self, String> {
86 let delim = config.delimiter as char;
87 let lines: Vec<&str> = content.lines().filter(|l| !l.trim().is_empty()).collect();
88
89 if lines.is_empty() {
90 return Err("CsvDataset: empty CSV".to_string());
91 }
92
93 let start = if config.has_header { 1 } else { 0 };
94 if start >= lines.len() {
95 return Err("CsvDataset: CSV has only a header, no data".to_string());
96 }
97
98 let first_row: Vec<&str> = lines[start].split(delim).collect();
100 let num_cols = first_row.len();
101
102 let feat_cols = if config.feature_cols.is_empty() {
103 (0..num_cols.saturating_sub(1)).collect::<Vec<_>>()
105 } else {
106 config.feature_cols
107 };
108
109 let tgt_cols = if config.target_cols.is_empty() {
110 vec![num_cols - 1]
112 } else {
113 config.target_cols
114 };
115
116 let mut samples = Vec::with_capacity(lines.len() - start);
117
118 for (line_no, &line) in lines[start..].iter().enumerate() {
119 let cols: Vec<&str> = line.split(delim).collect();
120 if cols.len() != num_cols {
121 return Err(format!(
122 "CsvDataset: line {} has {} columns, expected {}",
123 line_no + start + 1,
124 cols.len(),
125 num_cols
126 ));
127 }
128
129 let mut features = Vec::with_capacity(feat_cols.len());
130 for &c in &feat_cols {
131 let val: f64 = cols[c].trim().parse().map_err(|e| {
132 format!(
133 "CsvDataset: line {}, col {}: parse error: {}",
134 line_no + start + 1,
135 c,
136 e
137 )
138 })?;
139 features.push(val);
140 }
141
142 let mut target = Vec::with_capacity(tgt_cols.len());
143 for &c in &tgt_cols {
144 let val: f64 = cols[c].trim().parse().map_err(|e| {
145 format!(
146 "CsvDataset: line {}, col {}: parse error: {}",
147 line_no + start + 1,
148 c,
149 e
150 )
151 })?;
152 target.push(val);
153 }
154
155 samples.push(Sample {
156 features,
157 feature_shape: vec![feat_cols.len()],
158 target,
159 target_shape: vec![tgt_cols.len()],
160 });
161 }
162
163 let feature_shape = vec![feat_cols.len()];
164 let target_shape = vec![tgt_cols.len()];
165
166 Ok(Self {
167 samples,
168 feature_shape,
169 target_shape,
170 })
171 }
172}
173
174impl Dataset for CsvDataset {
175 fn len(&self) -> usize {
176 self.samples.len()
177 }
178
179 fn get(&self, index: usize) -> Sample {
180 self.samples[index].clone()
181 }
182
183 fn feature_shape(&self) -> &[usize] {
184 &self.feature_shape
185 }
186
187 fn target_shape(&self) -> &[usize] {
188 &self.target_shape
189 }
190
191 fn name(&self) -> &str {
192 "csv"
193 }
194}
195
// Unit tests for CSV parsing: default and explicit column selection, custom
// delimiters, and each error path that can be triggered from in-memory text.
#[cfg(test)]
mod tests {
    use super::*;

    /// Header row is skipped; the last column becomes the target by default.
    #[test]
    fn csv_with_header() {
        let csv = "a,b,c\n1.0,2.0,0.0\n3.0,4.0,1.0\n5.0,6.0,0.0\n";
        let config = CsvConfig::default();
        let ds = CsvDataset::from_string(csv, config).unwrap();
        assert_eq!(ds.len(), 3);
        assert_eq!(ds.feature_shape(), &[2]);
        assert_eq!(ds.target_shape(), &[1]);
        assert_eq!(ds.get(0).features, vec![1.0, 2.0]);
        assert_eq!(ds.get(0).target, vec![0.0]);
        assert_eq!(ds.get(2).features, vec![5.0, 6.0]);
    }

    /// Without a header the first line is already data.
    #[test]
    fn csv_no_header() {
        let csv = "1.0,2.0,3.0\n4.0,5.0,6.0\n";
        let config = CsvConfig::default().has_header(false);
        let ds = CsvDataset::from_string(csv, config).unwrap();
        assert_eq!(ds.len(), 2);
        assert_eq!(ds.get(0).features, vec![1.0, 2.0]);
        assert_eq!(ds.get(0).target, vec![3.0]);
    }

    /// Explicit feature/target indices are honoured in the order given.
    #[test]
    fn csv_custom_columns() {
        let csv = "a,b,c,d\n1,2,3,4\n5,6,7,8\n";
        let config = CsvConfig::default()
            .feature_cols(vec![0, 2])
            .target_cols(vec![1, 3]);
        let ds = CsvDataset::from_string(csv, config).unwrap();
        assert_eq!(ds.feature_shape(), &[2]);
        assert_eq!(ds.target_shape(), &[2]);
        assert_eq!(ds.get(0).features, vec![1.0, 3.0]);
        assert_eq!(ds.get(0).target, vec![2.0, 4.0]);
    }

    /// A non-comma delimiter splits both features and targets correctly.
    #[test]
    fn csv_tab_delimiter() {
        let csv = "a\tb\tc\n1.0\t2.0\t0.0\n3.0\t4.0\t1.0\n";
        let config = CsvConfig::default().delimiter(b'\t');
        let ds = CsvDataset::from_string(csv, config).unwrap();
        assert_eq!(ds.len(), 2);
        assert_eq!(ds.get(0).features, vec![1.0, 2.0]);
        assert_eq!(ds.get(0).target, vec![0.0]);
        assert_eq!(ds.get(1).target, vec![1.0]);
    }

    /// Blank lines between rows are ignored rather than treated as data.
    #[test]
    fn csv_skips_blank_lines() {
        let csv = "a,b,c\n\n1.0,2.0,0.0\n   \n3.0,4.0,1.0\n";
        let ds = CsvDataset::from_string(csv, CsvConfig::default()).unwrap();
        assert_eq!(ds.len(), 2);
        assert_eq!(ds.get(1).features, vec![3.0, 4.0]);
    }

    /// A non-numeric field in a selected column is reported as a parse error.
    #[test]
    fn csv_parse_error() {
        let csv = "a,b,c\n1.0,hello,0.0\n";
        let config = CsvConfig::default();
        let result = CsvDataset::from_string(csv, config);
        assert!(result.is_err());
        assert!(result.unwrap_err().contains("parse error"));
    }

    /// A row whose field count differs from the first data row is rejected.
    #[test]
    fn csv_column_count_mismatch() {
        let csv = "a,b,c\n1,2,3\n4,5\n";
        let err = CsvDataset::from_string(csv, CsvConfig::default()).unwrap_err();
        assert!(err.contains("has 2 columns, expected 3"));
    }

    /// A header with no data rows is an error, not an empty dataset.
    #[test]
    fn csv_header_only() {
        let csv = "a,b,c\n";
        let err = CsvDataset::from_string(csv, CsvConfig::default()).unwrap_err();
        assert!(err.contains("only a header"));
    }

    /// Empty input is rejected with the dedicated message.
    #[test]
    fn csv_empty() {
        let csv = "";
        let result = CsvDataset::from_string(csv, CsvConfig::default());
        assert!(result.is_err());
        assert!(result.unwrap_err().contains("empty CSV"));
    }
}