// multi_skill/data_processing/dataset.rs
use serde::{de::DeserializeOwned, Serialize};
use std::path::{Path, PathBuf};
3
4pub trait Dataset {
7 type Item;
9 fn len(&self) -> usize;
11 fn get(&self, index: usize) -> Self::Item;
13
14 fn cached(self, cache_dir: impl Into<PathBuf>) -> CachedDataset<Self>
24 where
25 Self: Sized,
26 {
27 let cache_dir = cache_dir.into();
28 std::fs::create_dir_all(&cache_dir).expect("Could not create cache directory");
29 CachedDataset {
30 base_dataset: self,
31 cache_dir,
32 }
33 }
34
35 fn iter(&self) -> Box<dyn Iterator<Item = Self::Item> + '_> {
40 Box::new((0..self.len()).map(move |i| self.get(i)))
41 }
42}
43
44impl<T: Clone> Dataset for [T] {
46 type Item = T;
47
48 fn len(&self) -> usize {
49 self.len()
50 }
51
52 fn get(&self, index: usize) -> T {
53 self[index].clone()
54 }
55}
56
57impl<D: Dataset + ?Sized> Dataset for &D {
59 type Item = D::Item;
60
61 fn len(&self) -> usize {
62 (**self).len()
63 }
64
65 fn get(&self, index: usize) -> Self::Item {
66 (**self).get(index)
67 }
68}
69
70impl<D: Dataset + ?Sized> Dataset for Box<D> {
71 type Item = D::Item;
72
73 fn len(&self) -> usize {
74 (**self).len()
75 }
76
77 fn get(&self, index: usize) -> Self::Item {
78 (**self).get(index)
79 }
80}
81
/// A dataset whose items are computed by calling a closure with the requested index.
pub struct ClosureDataset<T, F>
where
    F: Fn(usize) -> T,
{
    /// Number of items the dataset reports.
    length: usize,
    /// Function called with an index to produce the item at that index.
    closure: F,
}
87
88impl<T, F: Fn(usize) -> T> ClosureDataset<T, F> {
89 pub fn new(length: usize, closure: F) -> Self {
90 Self { length, closure }
91 }
92}
93
94impl<T, F: Fn(usize) -> T> Dataset for ClosureDataset<T, F> {
95 type Item = T;
96
97 fn len(&self) -> usize {
98 self.length
99 }
100
101 fn get(&self, index: usize) -> T {
102 (self.closure)(index)
103 }
104}
105
/// A dataset wrapper that pairs a base dataset with a directory used to cache its items.
///
/// The struct itself places no trait bound on `D`; the bounds live on the impl blocks
/// and constructors that actually need them.
pub struct CachedDataset<D> {
    /// The wrapped dataset that items are ultimately obtained from.
    base_dataset: D,
    /// Directory in which cached items are stored.
    cache_dir: PathBuf,
}
112
113impl<D: Dataset> Dataset for CachedDataset<D>
114where
115 D::Item: Serialize + DeserializeOwned,
116{
117 type Item = D::Item;
118
119 fn len(&self) -> usize {
120 self.base_dataset.len()
121 }
122
123 fn get(&self, index: usize) -> Self::Item {
124 let cache_file = self.cache_dir.join(format!("{}.json", index));
125 match std::fs::read_to_string(&cache_file) {
127 Ok(cached_json) => serde_json::from_str(&cached_json).expect("Failed to read cache"),
128 Err(_) => {
129 let contest = self.base_dataset.get(index);
131
132 super::write_to_json(&contest, &cache_file).expect("Failed to write to cache");
134 println!("Codeforces contest successfully cached at {:?}", cache_file);
135
136 contest
137 }
138 }
139 }
140}
141
142pub fn get_dataset_from_disk<T: Serialize + DeserializeOwned>(
144 dataset_dir: impl AsRef<Path>,
145) -> impl Dataset<Item = T> {
146 let ext = Some(std::ffi::OsStr::new("json"));
148 let dataset_dir = dataset_dir.as_ref();
149 let length = std::fs::read_dir(dataset_dir)
150 .unwrap_or_else(|_| panic!("There's no dataset at {:?}", dataset_dir))
151 .filter(|file| file.as_ref().unwrap().path().extension() == ext)
152 .count();
153 println!("Found {} JSON files at {:?}", length, dataset_dir);
154
155 ClosureDataset::new(length, |i| {
157 panic!("Expected to find contest {} in the cache, but didn't", i)
158 })
159 .cached(dataset_dir)
160}
161
#[cfg(test)]
mod test {
    use super::*;

    /// A boxed slice reference acts as a dataset with the same length and elements
    /// as the vector backing it.
    #[test]
    fn test_in_memory_dataset() {
        let vec = vec![5.7, 9.2, -1.5];
        let dataset: Box<dyn Dataset<Item = f64>> = Box::new(vec.as_slice());

        assert_eq!(dataset.len(), vec.len());
        for (data_val, &vec_val) in dataset.iter().zip(vec.iter()) {
            assert_eq!(data_val, vec_val);
        }
    }

    /// Iterating a closure-backed dataset yields the closure's value for each index in order.
    #[test]
    fn test_closure_dataset() {
        let dataset = ClosureDataset::new(10, |x| x * x);

        for (idx, val) in dataset.iter().enumerate() {
            assert_eq!(val, idx * idx);
        }
    }

    /// Round-trips items through the on-disk cache and reads them back three ways:
    /// from the caching wrapper, from a collected vector, and from a disk-only dataset.
    #[test]
    fn test_cached_dataset() {
        let length = 5;
        let cache_dir = "temp_dir_containing_squares";
        let cache = || std::fs::read_dir(cache_dir);
        let fancy_item = |idx: usize| (idx.checked_sub(2), vec![idx * idx; idx]);

        // An earlier run that panicked midway leaves the directory behind, which would make
        // every later run fail at the precondition below. Clear any such leftover first; the
        // error from removing a directory that does not exist is expected and ignored.
        let _ = std::fs::remove_dir_all(cache_dir);

        // Precondition: the cache directory does not exist yet.
        assert!(cache().is_err());
        let data_from_fn = ClosureDataset::new(length, fancy_item).cached(cache_dir);

        // Creating the wrapper creates the directory but caches nothing.
        assert_eq!(cache().unwrap().count(), 0);
        let data_into_vec = data_from_fn.iter().collect::<Vec<_>>();

        // Iterating once writes one cache file per item.
        assert_eq!(cache().unwrap().count(), length);
        let data_from_disk = get_dataset_from_disk(cache_dir);

        assert_eq!(data_from_fn.len(), length);
        assert_eq!(data_into_vec.len(), length);
        assert_eq!(data_from_disk.len(), length);
        for idx in 0..length {
            let expected = fancy_item(idx);
            let data_from_disk_val: (Option<usize>, Vec<usize>) = data_from_disk.get(idx);
            assert_eq!(data_from_fn.get(idx), expected);
            assert_eq!(data_into_vec[idx], expected);
            assert_eq!(data_from_disk_val, expected);
        }

        // Reading back must not create additional files; then clean up after ourselves.
        assert_eq!(cache().unwrap().count(), length);
        std::fs::remove_dir_all(cache_dir).unwrap();
        assert!(cache().is_err());
    }
}