burn_dataset/dataset/
in_memory.rs

1use std::{
2    fs::File,
3    io::{BufRead, BufReader},
4    path::Path,
5};
6
7use serde::de::DeserializeOwned;
8
9use crate::Dataset;
10
/// A dataset that keeps every one of its items resident in RAM.
///
/// The items live in a plain `Vec<I>` owned by the dataset.
pub struct InMemDataset<I> {
    /// Backing storage holding all items, in insertion order.
    items: Vec<I>,
}
15
16impl<I> InMemDataset<I> {
17    /// Creates a new in memory dataset from the given items.
18    pub fn new(items: Vec<I>) -> Self {
19        InMemDataset { items }
20    }
21}
22
23impl<I> Dataset<I> for InMemDataset<I>
24where
25    I: Clone + Send + Sync,
26{
27    fn get(&self, index: usize) -> Option<I> {
28        self.items.get(index).cloned()
29    }
30    fn len(&self) -> usize {
31        self.items.len()
32    }
33}
34
35impl<I> InMemDataset<I>
36where
37    I: Clone + DeserializeOwned,
38{
39    /// Create from a dataset. All items are loaded in memory.
40    pub fn from_dataset(dataset: &impl Dataset<I>) -> Self {
41        let items: Vec<I> = dataset.iter().collect();
42        Self::new(items)
43    }
44
45    /// Create from a json rows file (one json per line).
46    ///
47    /// [Supported field types](https://docs.rs/serde_json/latest/serde_json/value/enum.Value.html)
48    pub fn from_json_rows<P: AsRef<Path>>(path: P) -> Result<Self, std::io::Error> {
49        let file = File::open(path)?;
50        let reader = BufReader::new(file);
51        let mut items = Vec::new();
52
53        for line in reader.lines() {
54            let item = serde_json::from_str(line.unwrap().as_str()).unwrap();
55            items.push(item);
56        }
57
58        let dataset = Self::new(items);
59
60        Ok(dataset)
61    }
62
63    /// Create from a csv file.
64    ///
65    /// The provided `csv::ReaderBuilder` can be configured to fit your csv format.
66    ///
67    /// The supported field types are: String, integer, float, and bool.
68    ///
69    /// See:
70    /// - [Reading with Serde](https://docs.rs/csv/latest/csv/tutorial/index.html#reading-with-serde)
71    /// - [Delimiters, quotes and variable length records](https://docs.rs/csv/latest/csv/tutorial/index.html#delimiters-quotes-and-variable-length-records)
72    pub fn from_csv<P: AsRef<Path>>(
73        path: P,
74        builder: &csv::ReaderBuilder,
75    ) -> Result<Self, std::io::Error> {
76        let mut rdr = builder.from_path(path)?;
77
78        let mut items = Vec::new();
79
80        for result in rdr.deserialize() {
81            let item: I = result?;
82            items.push(item);
83        }
84
85        let dataset = Self::new(items);
86
87        Ok(dataset)
88    }
89}
90
#[cfg(test)]
mod tests {
    // Unit tests for the `InMemDataset` constructors (`from_dataset`,
    // `from_json_rows`, `from_csv`) and for iteration over its items.

    use super::*;
    use crate::{SqliteDataset, test_data};

    use rstest::{fixture, rstest};
    use serde::{Deserialize, Serialize};

    // Fixture files, given as relative paths.
    const DB_FILE: &str = "tests/data/sqlite-dataset.db";
    const JSON_FILE: &str = "tests/data/dataset.json";
    const CSV_FILE: &str = "tests/data/dataset.csv";
    const CSV_FMT_FILE: &str = "tests/data/dataset-fmt.csv";

    type SqlDs = SqliteDataset<Sample>;

    /// Row type used by the sqlite and json fixtures.
    #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
    pub struct Sample {
        column_str: String,
        column_bytes: Vec<u8>,
        column_int: i64,
        column_bool: bool,
        column_float: f64,
    }

    /// Row type used by the csv fixtures (no bytes column).
    #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
    pub struct SampleCsv {
        column_str: String,
        column_int: i64,
        column_bool: bool,
        column_float: f64,
    }

    /// Opens the "train" split of the sqlite fixture database.
    #[fixture]
    fn train_dataset() -> SqlDs {
        SqliteDataset::from_db_file(DB_FILE, "train").unwrap()
    }

    #[rstest]
    pub fn from_dataset(train_dataset: SqlDs) {
        let dataset = InMemDataset::from_dataset(&train_dataset);

        let non_existing_record_index: usize = 10;
        let record_index: usize = 0;

        assert_eq!(train_dataset.get(non_existing_record_index), None);
        // Also check the in-memory copy, which is the type under test here.
        assert_eq!(dataset.get(non_existing_record_index), None);
        assert_eq!(dataset.get(record_index).unwrap().column_str, "HI1");
    }

    #[test]
    pub fn from_json_rows() {
        let dataset = InMemDataset::<Sample>::from_json_rows(JSON_FILE).unwrap();

        let non_existing_record_index: usize = 10;
        let record_index: usize = 1;

        assert_eq!(dataset.get(non_existing_record_index), None);

        let record = dataset.get(record_index).unwrap();
        assert_eq!(record.column_str, "HI2");
        assert!(!record.column_bool);
    }

    #[test]
    pub fn from_csv_rows() {
        let rdr = csv::ReaderBuilder::new();
        let dataset = InMemDataset::<SampleCsv>::from_csv(CSV_FILE, &rdr).unwrap();

        let non_existing_record_index: usize = 10;
        let record_index: usize = 1;

        assert_eq!(dataset.get(non_existing_record_index), None);

        let record = dataset.get(record_index).unwrap();
        assert_eq!(record.column_str, "HI2");
        assert_eq!(record.column_int, 1);
        assert!(!record.column_bool);
        assert_eq!(record.column_float, 1.0);
    }

    #[test]
    pub fn from_csv_rows_fmt() {
        // Space-delimited file without a header row.
        let mut rdr = csv::ReaderBuilder::new();
        let rdr = rdr.delimiter(b' ').has_headers(false);
        let dataset = InMemDataset::<SampleCsv>::from_csv(CSV_FMT_FILE, rdr).unwrap();

        let non_existing_record_index: usize = 10;
        let record_index: usize = 1;

        assert_eq!(dataset.get(non_existing_record_index), None);

        let record = dataset.get(record_index).unwrap();
        assert_eq!(record.column_str, "HI2");
        assert_eq!(record.column_int, 1);
        assert!(!record.column_bool);
        assert_eq!(record.column_float, 1.0);
    }

    #[test]
    pub fn given_in_memory_dataset_when_iterate_should_iterate_though_all_items() {
        let items_original = test_data::string_items();
        let dataset = InMemDataset::new(items_original.clone());

        let items: Vec<String> = dataset.iter().collect();

        assert_eq!(items_original, items);
    }
}
190        assert_eq!(items_original, items);
191    }
192}