1use std::path::Path;
6
7use crate::Datapoint;
8use crate::Dataset;
9
10impl<const COLS: usize, Data: Datapoint<COLS>> Dataset<COLS, Data> {
11 #[must_use]
15 pub fn new() -> Self {
16 Self {
17 labels: None,
18 data: Vec::new(),
19 }
20 }
21
22 pub fn push(&mut self, datapoint: Data) {
26 self.data.push(datapoint);
27 }
28
29 #[must_use]
33 pub fn n_rows(&self) -> usize {
34 match self.labels {
35 Some(_) => self.data.len() + 1,
36 None => self.data.len(),
37 }
38 }
39
40 #[must_use]
44 pub fn n_datapoints(&self) -> usize {
45 self.data.len()
46 }
47
48 #[must_use]
52 pub fn n_columns(&self) -> usize {
53 COLS
54 }
55
56 #[must_use]
60 pub fn get_labels(&self) -> Option<&[String; COLS]> {
61 self.labels.as_ref()
62 }
63
64 pub fn set_labels<'a, Labels>(&mut self, labels: Labels)
103 where
104 Labels: Into<Option<[&'a str; COLS]>>,
105 {
106 let labels: Option<[String; COLS]> = labels.into().map(|labels| {
107 labels
108 .into_iter()
109 .map(ToOwned::to_owned)
110 .collect::<Vec<String>>()
111 .try_into()
112 .expect("Failed to coerce vec into array")
113 });
114 self.labels = labels;
115 }
116
117 #[must_use]
131 pub fn with_labels<'a, Labels>(mut self, labels: Labels) -> Self
132 where
133 Labels: Into<Option<[&'a str; COLS]>>,
134 {
135 self.set_labels(labels);
136 self
137 }
138
139 #[must_use]
143 pub fn from_datapoints<IntoIter, Iter>(rows: IntoIter) -> Self
144 where
145 IntoIter: IntoIterator<Item = Data, IntoIter = Iter>,
146 Iter: Iterator<Item = Data>,
147 Data: Datapoint<COLS>,
148 {
149 Self {
150 labels: None,
151 data: rows.into_iter().collect(),
152 }
153 }
154}
155
156impl<const COLS: usize, Data: Datapoint<COLS>> Default for Dataset<COLS, Data> {
160 fn default() -> Self {
161 Self::new()
162 }
163}
164
165impl<const COLS: usize, DataElement: ToString> Dataset<COLS, [DataElement; COLS]> {
166 pub fn from_columns<IntoIter, Iter>(columns: [IntoIter; COLS]) -> Self
179 where
180 IntoIter: IntoIterator<Item = DataElement, IntoIter = Iter>,
181 Iter: Iterator<Item = DataElement>,
182 {
183 let mut columns: [Iter; COLS] = columns
184 .into_iter()
185 .map(IntoIterator::into_iter)
186 .collect::<Vec<Iter>>()
187 .try_into()
188 .map_err(|_| ())
189 .expect("Failed to coerce vec into array");
190 let mut data = Vec::new();
191 'outer: loop {
192 let mut temp = Vec::with_capacity(COLS);
193 for col in columns.iter_mut() {
194 if let Some(data) = col.next() {
195 temp.push(data);
196 } else {
197 break 'outer;
198 }
199 }
200 let row: [DataElement; COLS] = temp
202 .try_into()
203 .map_err(|_| ())
204 .expect("Failed to coerce vec into array");
205 data.push(row);
206 }
207
208 let labels = None;
209
210 Dataset { labels, data }
211 }
212}
213
214impl<const COLS: usize, Data: Datapoint<COLS>> Dataset<COLS, Data> {
215 pub fn save<P: AsRef<Path>>(self, filepath: P) -> Result<(), std::io::Error> {
244 let mut writer = csv::Writer::from_path(filepath)?;
245 if let Some(labels) = self.labels {
246 writer.write_record(&labels)?;
247 }
248 for datapoint in self.data {
249 writer.write_record(datapoint.record())?;
250 }
251 writer.flush()?;
252 Ok(())
253 }
254}
255
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts the shape shared by the unlabelled 3-row, 2-column datasets built in the
    /// `from_*` tests below.
    fn check_size<const COLS: usize, Data: Datapoint<COLS>>(dataset: &Dataset<COLS, Data>) {
        assert_eq!(dataset.n_columns(), 2);
        assert_eq!(dataset.n_datapoints(), 3);
        // No label row, so rows and datapoints coincide.
        assert_eq!(dataset.n_rows(), 3);
        assert!(dataset.get_labels().is_none());
    }

    #[test]
    fn new() {
        let mut dataset = Dataset::new();
        assert_eq!(dataset.n_datapoints(), 0);
        dataset.push([1, 2, 3]);
        assert_eq!(dataset.n_datapoints(), 1);
        dataset.push([3, 4, 5]);
        assert_eq!(dataset.n_datapoints(), 2);
        assert_eq!(dataset.n_columns(), 3);
    }

    #[test]
    fn default_is_empty() {
        let dataset: Dataset<3, [i32; 3]> = Dataset::default();
        assert_eq!(dataset.n_datapoints(), 0);
        assert_eq!(dataset.n_rows(), 0);
        assert!(dataset.get_labels().is_none());
    }

    #[test]
    fn labels() {
        let x = [2, 3, 4];
        let y = [5, 6, 7];
        let mut dataset = Dataset::from_columns([x, y]);
        assert_eq!(dataset.get_labels(), None);
        dataset.set_labels(["x", "y"]);
        assert_eq!(
            dataset.get_labels(),
            Some(&[String::from("x"), String::from("y")])
        );
        // Passing `None` removes the labels again.
        dataset.set_labels(None::<[&str; 2]>);
        assert_eq!(dataset.get_labels(), None);
    }

    #[test]
    fn with_labels() {
        let dataset =
            Dataset::<2, [i32; 2]>::from_datapoints([[1, 2], [3, 4]]).with_labels(["x", "y"]);
        assert_eq!(dataset.n_datapoints(), 2);
        assert_eq!(dataset.n_rows(), 3);
        assert_eq!(
            dataset.get_labels(),
            Some(&[String::from("x"), String::from("y")])
        );
    }

    #[test]
    fn size() {
        let mut dataset = Dataset::new();
        dataset.push([1, 2, 3]);
        dataset.push([3, 4, 5]);
        assert_eq!(dataset.n_columns(), 3);
        assert_eq!(dataset.n_datapoints(), 2);
        assert_eq!(dataset.n_rows(), 2);
        dataset.set_labels(["a", "b", "c"]);
        assert_eq!(dataset.n_columns(), 3);
        assert_eq!(dataset.n_datapoints(), 2);
        assert_eq!(dataset.n_rows(), 3);
    }

    #[test]
    fn from_datapoints_array() {
        let array = [[1, 2], [3, 4], [5, 6]];
        let dataset = Dataset::from_datapoints(array);
        check_size(&dataset);
    }

    #[test]
    fn from_datapoints_iterator() {
        let iterator = [[1, 2], [3, 4], [5, 6]].into_iter();
        let dataset = Dataset::from_datapoints(iterator);
        check_size(&dataset);
    }

    #[test]
    fn from_datapoints_vec() {
        let vector = vec![[1, 2], [3, 4], [5, 6]];
        let dataset = Dataset::from_datapoints(vector);
        check_size(&dataset);
    }

    #[test]
    fn from_columns_array() {
        let array = [[1, 3, 5], [2, 4, 6]];
        let dataset = Dataset::from_columns(array);
        check_size(&dataset);
    }

    #[test]
    fn from_columns_iterator() {
        let iterator = [[1, 3, 5].into_iter(), [2, 4, 6].into_iter()];
        let dataset = Dataset::from_columns(iterator);
        check_size(&dataset);
    }

    #[test]
    fn from_columns_vec() {
        let vector = [vec![1, 3, 5], vec![2, 4, 6]];
        let dataset = Dataset::from_columns(vector);
        check_size(&dataset);
    }

    #[test]
    fn from_columns_stops_at_shortest_column() {
        let dataset = Dataset::from_columns([vec![1, 2, 3], vec![4, 5]]);
        assert_eq!(dataset.n_columns(), 2);
        assert_eq!(dataset.n_datapoints(), 2);
    }
}