delfi/
dataset.rs

1/*!
2Implementations on the [Dataset] struct
3*/
4
5use std::path::Path;
6
7use crate::Datapoint;
8use crate::Dataset;
9
10impl<const COLS: usize, Data: Datapoint<COLS>> Dataset<COLS, Data> {
11    /**
12    Function for creating new (empty) dataset
13    */
14    #[must_use]
15    pub fn new() -> Self {
16        Self {
17            labels: None,
18            data: Vec::new(),
19        }
20    }
21
22    /**
23    Push a new row to the dataset
24    */
25    pub fn push(&mut self, datapoint: Data) {
26        self.data.push(datapoint);
27    }
28
29    /**
30    Get current number of rows in dataset, which is equal to the number of datapoints, plus 1 if there is a header row
31    */
32    #[must_use]
33    pub fn n_rows(&self) -> usize {
34        match self.labels {
35            Some(_) => self.data.len() + 1,
36            None => self.data.len(),
37        }
38    }
39
40    /**
41    Get current number of rows in dataset, which is equal to the number of datapoints, plus 1 if there is a header row
42    */
43    #[must_use]
44    pub fn n_datapoints(&self) -> usize {
45        self.data.len()
46    }
47
48    /**
49    Get current number of rows in dataset
50    */
51    #[must_use]
52    pub fn n_columns(&self) -> usize {
53        COLS
54    }
55
56    /**
57    Get current labels
58    */
59    #[must_use]
60    pub fn get_labels(&self) -> Option<&[String; COLS]> {
61        self.labels.as_ref()
62    }
63
64    /**
65    Set labels for the given dataset.
66    Constructors return dataset with labels set to None unless otherwise specified.
67
68    ```
69    use delfi::Dataset;
70
71    let t = [0, 1, 2, 3, 4, 5];
72    let x = [2, 3, 5, 8, 12, 17];
73    let mut dataset = Dataset::from_columns([t, x]);
74    dataset.set_labels(["time", "length"]);
75    ```
76
77    Labels can also be turned off
78
79    ```
80    # use delfi::Dataset;
81    #
82    # let t = [0, 1, 2, 3, 4, 5];
83    # let x = [2, 3, 5, 8, 12, 17];
84    # let mut dataset = Dataset::from_columns([t, x]);
85    # dataset.set_labels(["time", "length"]);
86    #
87    dataset.set_labels(None);
88    ```
89
90    They also technically accept labels to be passed via `Some(_)` (but why would you?):
91
92    ```
93    # use delfi::Dataset;
94    #
95    # let t = [0, 1, 2, 3, 4, 5];
96    # let x = [2, 3, 5, 8, 12, 17];
97    # let mut dataset = Dataset::from_columns([t, x]);
98    #
99    dataset.set_labels(Some(["time", "length"]));
100    ```
101    */
102    pub fn set_labels<'a, Labels>(&mut self, labels: Labels)
103    where
104        Labels: Into<Option<[&'a str; COLS]>>,
105    {
106        let labels: Option<[String; COLS]> = labels.into().map(|labels| {
107            labels
108                .into_iter()
109                .map(ToOwned::to_owned)
110                .collect::<Vec<String>>()
111                .try_into()
112                .expect("Failed to coerce vec into array")
113        });
114        self.labels = labels;
115    }
116
117    /**
118    Take dataset, set labels, and return dataset. Useful when constructing datasets.
119
120    ```
121    use delfi::Dataset;
122
123    let t = [0, 1, 2, 3, 4, 5];
124    let x = [2, 3, 5, 8, 12, 17];
125    let _ = Dataset::from_columns([&t, &x]).with_labels(["time", "length"]);
126    ```
127
128    See set_labels() for detail on possible parameters.
129    */
130    #[must_use]
131    pub fn with_labels<'a, Labels>(mut self, labels: Labels) -> Self
132    where
133        Labels: Into<Option<[&'a str; COLS]>>,
134    {
135        self.set_labels(labels);
136        self
137    }
138
139    /**
140    Create a dataset from an iterator over datapoints
141    */
142    #[must_use]
143    pub fn from_datapoints<IntoIter, Iter>(rows: IntoIter) -> Self
144    where
145        IntoIter: IntoIterator<Item = Data, IntoIter = Iter>,
146        Iter: Iterator<Item = Data>,
147        Data: Datapoint<COLS>,
148    {
149        Self {
150            labels: None,
151            data: rows.into_iter().collect(),
152        }
153    }
154}
155
156/**
157Default is equivalent to new
158*/
159impl<const COLS: usize, Data: Datapoint<COLS>> Default for Dataset<COLS, Data> {
160    fn default() -> Self {
161        Self::new()
162    }
163}
164
165impl<const COLS: usize, DataElement: ToString> Dataset<COLS, [DataElement; COLS]> {
166    /**
167    Takes in a set of columns and creates a dataset from these.
168
169    # Examples
170    ```
171    use delfi::Dataset;
172
173    let t = [0, 1, 2, 3, 4, 5];
174    let x = [2, 3, 5, 8, 12, 17];
175    let _ = Dataset::from_columns([t, x]);
176    ```
177    */
178    pub fn from_columns<IntoIter, Iter>(columns: [IntoIter; COLS]) -> Self
179    where
180        IntoIter: IntoIterator<Item = DataElement, IntoIter = Iter>,
181        Iter: Iterator<Item = DataElement>,
182    {
183        let mut columns: [Iter; COLS] = columns
184            .into_iter()
185            .map(IntoIterator::into_iter)
186            .collect::<Vec<Iter>>()
187            .try_into()
188            .map_err(|_| ())
189            .expect("Failed to coerce vec into array");
190        let mut data = Vec::new();
191        'outer: loop {
192            let mut temp = Vec::with_capacity(COLS);
193            for col in columns.iter_mut() {
194                if let Some(data) = col.next() {
195                    temp.push(data);
196                } else {
197                    break 'outer;
198                }
199            }
200            // map_err is required to avoid restricting Debug to be implemented for IntoIterator
201            let row: [DataElement; COLS] = temp
202                .try_into()
203                .map_err(|_| ())
204                .expect("Failed to coerce vec into array");
205            data.push(row);
206        }
207
208        let labels = None;
209
210        Dataset { labels, data }
211    }
212}
213
214impl<const COLS: usize, Data: Datapoint<COLS>> Dataset<COLS, Data> {
215    /**
216    Saves a dataset to a given file. The filepath must be valid.
217    Accepts anything path-like.
218
219    # Examples
220    ```
221    # use delfi::Dataset;
222    #
223    # let t = [0, 1, 2, 3, 4, 5];
224    # let x = [2, 3, 5, 8, 12, 17];
225    # let dataset = Dataset::from_columns([t, x]);
226    #
227    dataset.save("./resources/data/examples/save-short.csv").unwrap();
228    ```
229
230    ```
231    # use delfi::Dataset;
232    #
233    # let t = [0, 1, 2, 3, 4, 5];
234    # let x = [2, 3, 5, 8, 12, 17];
235    # let dataset = Dataset::from_columns([t, x]);
236    #
237    let directory = std::fs::canonicalize("./resources/data/examples/").unwrap();
238    let filepath = directory.join("save-long.csv");
239    dataset.save(&filepath).unwrap();
240    ```
241
242    */
243    pub fn save<P: AsRef<Path>>(self, filepath: P) -> Result<(), std::io::Error> {
244        let mut writer = csv::Writer::from_path(filepath)?;
245        if let Some(labels) = self.labels {
246            writer.write_record(&labels)?;
247        }
248        for datapoint in self.data {
249            writer.write_record(datapoint.record())?;
250        }
251        writer.flush()?;
252        Ok(())
253    }
254}
255
#[cfg(test)]
mod tests {
    use super::*;

    // An empty dataset grows by one datapoint per push
    #[test]
    fn new() {
        let mut set = Dataset::new();
        assert_eq!(set.n_datapoints(), 0);
        set.push([1, 2, 3]);
        assert_eq!(set.n_datapoints(), 1);
        set.push([3, 4, 5]);
        assert_eq!(set.n_datapoints(), 2);
        assert_eq!(set.n_columns(), 3);
    }

    // Labels start out unset and are stored as owned strings once set
    #[test]
    fn labels() {
        let first = [2, 3, 4];
        let second = [5, 6, 7];
        let mut set = Dataset::from_columns([first, second]);
        assert_eq!(set.get_labels(), None);
        set.set_labels(["x", "y"]);
        assert_eq!(
            set.get_labels(),
            Some(&[String::from("x"), String::from("y")])
        );
    }

    // Setting labels adds a header row but leaves datapoints and columns untouched
    #[test]
    fn size() {
        let mut set = Dataset::new();
        set.push([1, 2, 3]);
        set.push([3, 4, 5]);
        assert_eq!(set.n_columns(), 3);
        assert_eq!(set.n_datapoints(), 2);
        assert_eq!(set.n_rows(), 2);
        set.set_labels(["a", "b", "c"]);
        assert_eq!(set.n_columns(), 3);
        assert_eq!(set.n_datapoints(), 2);
        assert_eq!(set.n_rows(), 3);
    }

    // Shared check for the constructor tests: every fixture is 2 columns by 3 rows
    fn check_size<const COLS: usize, Data>(set: Dataset<COLS, Data>)
    where
        Data: Datapoint<COLS>,
    {
        assert_eq!(set.n_columns(), 2);
        assert_eq!(set.n_rows(), 3);
    }

    // Row-wise constructors
    #[test]
    fn from_datapoints_array() {
        let rows = [[1, 2], [3, 4], [5, 6]];
        let set = Dataset::from_datapoints(rows);
        println!("{:?}", set);
        check_size(set);
    }

    #[test]
    fn from_datapoints_iterator() {
        let rows = [[1, 2], [3, 4], [5, 6]].into_iter();
        let set = Dataset::from_datapoints(rows);
        println!("{:?}", set);
        check_size(set);
    }

    #[test]
    fn from_datapoints_vec() {
        let rows = vec![[1, 2], [3, 4], [5, 6]];
        let set = Dataset::from_datapoints(rows);
        println!("{:?}", set);
        check_size(set);
    }

    // Column-wise constructors
    #[test]
    fn from_columns_array() {
        let cols = [[1, 3, 5], [2, 4, 6]];
        let set = Dataset::from_columns(cols);
        println!("{:?}", set);
        check_size(set);
    }

    #[test]
    fn from_columns_iterator() {
        let cols = [[1, 3, 5].into_iter(), [2, 4, 6].into_iter()];
        let set = Dataset::from_columns(cols);
        println!("{:?}", set);
        check_size(set);
    }

    #[test]
    fn from_columns_vec() {
        let cols = [vec![1, 3, 5], vec![2, 4, 6]];
        let set = Dataset::from_columns(cols);
        println!("{:?}", set);
        check_size(set);
    }
}
353}