burn_dataset/dataset/
in_memory.rs1use std::{
2 fs::File,
3 io::{BufRead, BufReader},
4 path::Path,
5};
6
7use serde::de::DeserializeOwned;
8
9use crate::Dataset;
10
11pub struct InMemDataset<I> {
13 items: Vec<I>,
14}
15
16impl<I> InMemDataset<I> {
17 pub fn new(items: Vec<I>) -> Self {
19 InMemDataset { items }
20 }
21}
22
23impl<I> Dataset<I> for InMemDataset<I>
24where
25 I: Clone + Send + Sync,
26{
27 fn get(&self, index: usize) -> Option<I> {
28 self.items.get(index).cloned()
29 }
30 fn len(&self) -> usize {
31 self.items.len()
32 }
33}
34
35impl<I> InMemDataset<I>
36where
37 I: Clone + DeserializeOwned,
38{
39 pub fn from_dataset(dataset: &impl Dataset<I>) -> Self {
41 let items: Vec<I> = dataset.iter().collect();
42 Self::new(items)
43 }
44
45 pub fn from_json_rows<P: AsRef<Path>>(path: P) -> Result<Self, std::io::Error> {
49 let file = File::open(path)?;
50 let reader = BufReader::new(file);
51 let mut items = Vec::new();
52
53 for line in reader.lines() {
54 let item = serde_json::from_str(line.unwrap().as_str()).unwrap();
55 items.push(item);
56 }
57
58 let dataset = Self::new(items);
59
60 Ok(dataset)
61 }
62
63 pub fn from_csv<P: AsRef<Path>>(
73 path: P,
74 builder: &csv::ReaderBuilder,
75 ) -> Result<Self, std::io::Error> {
76 let mut rdr = builder.from_path(path)?;
77
78 let mut items = Vec::new();
79
80 for result in rdr.deserialize() {
81 let item: I = result?;
82 items.push(item);
83 }
84
85 let dataset = Self::new(items);
86
87 Ok(dataset)
88 }
89}
90
#[cfg(test)]
mod tests {

    use super::*;
    use crate::{SqliteDataset, test_data};

    use rstest::{fixture, rstest};
    use serde::{Deserialize, Serialize};

    const DB_FILE: &str = "tests/data/sqlite-dataset.db";
    const JSON_FILE: &str = "tests/data/dataset.json";
    const CSV_FILE: &str = "tests/data/dataset.csv";
    const CSV_FMT_FILE: &str = "tests/data/dataset-fmt.csv";

    type SqlDs = SqliteDataset<Sample>;

    /// Item type read from the SQLite and JSON fixtures.
    #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
    pub struct Sample {
        column_str: String,
        column_bytes: Vec<u8>,
        column_int: i64,
        column_bool: bool,
        column_float: f64,
    }

    /// Like `Sample` but without `column_bytes`; item type read from the CSV fixtures.
    #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
    pub struct SampleCsv {
        column_str: String,
        column_int: i64,
        column_bool: bool,
        column_float: f64,
    }

    #[fixture]
    fn train_dataset() -> SqlDs {
        SqliteDataset::from_db_file(DB_FILE, "train").unwrap()
    }

    #[rstest]
    pub fn from_dataset(train_dataset: SqlDs) {
        let dataset = InMemDataset::from_dataset(&train_dataset);

        let non_existing_record_index: usize = 10;
        let record_index: usize = 0;

        // Precondition on the source, then check the in-memory copy itself mirrors it.
        assert_eq!(train_dataset.get(non_existing_record_index), None);
        assert_eq!(dataset.len(), train_dataset.len());
        assert_eq!(dataset.get(non_existing_record_index), None);
        assert_eq!(dataset.get(record_index).unwrap().column_str, "HI1");
    }

    #[test]
    pub fn from_json_rows() {
        let dataset = InMemDataset::<Sample>::from_json_rows(JSON_FILE).unwrap();

        let non_existing_record_index: usize = 10;
        let record_index: usize = 1;

        assert_eq!(dataset.get(non_existing_record_index), None);
        assert_eq!(dataset.get(record_index).unwrap().column_str, "HI2");
        assert!(!dataset.get(record_index).unwrap().column_bool);
    }

    #[test]
    pub fn from_csv_rows() {
        let rdr = csv::ReaderBuilder::new();
        let dataset = InMemDataset::<SampleCsv>::from_csv(CSV_FILE, &rdr).unwrap();

        let non_existing_record_index: usize = 10;
        let record_index: usize = 1;

        assert_eq!(dataset.get(non_existing_record_index), None);
        assert_eq!(dataset.get(record_index).unwrap().column_str, "HI2");
        assert_eq!(dataset.get(record_index).unwrap().column_int, 1);
        assert!(!dataset.get(record_index).unwrap().column_bool);
        assert_eq!(dataset.get(record_index).unwrap().column_float, 1.0);
    }

    #[test]
    pub fn from_csv_rows_fmt() {
        let mut rdr = csv::ReaderBuilder::new();
        let rdr = rdr.delimiter(b' ').has_headers(false);
        let dataset = InMemDataset::<SampleCsv>::from_csv(CSV_FMT_FILE, rdr).unwrap();

        let non_existing_record_index: usize = 10;
        let record_index: usize = 1;

        assert_eq!(dataset.get(non_existing_record_index), None);
        assert_eq!(dataset.get(record_index).unwrap().column_str, "HI2");
        assert_eq!(dataset.get(record_index).unwrap().column_int, 1);
        assert!(!dataset.get(record_index).unwrap().column_bool);
        assert_eq!(dataset.get(record_index).unwrap().column_float, 1.0);
    }

    #[test]
    pub fn given_in_memory_dataset_when_iterate_should_iterate_through_all_items() {
        let items_original = test_data::string_items();
        let dataset = InMemDataset::new(items_original.clone());

        let items: Vec<String> = dataset.iter().collect();

        assert_eq!(items_original, items);
    }
}