// Source: multi_skill/data_processing/dataset.rs

1use serde::{de::DeserializeOwned, Serialize};
2use std::path::{Path, PathBuf};
3
4/// Generic `Dataset` trait, modeled after PyTorch's `utils.data.Dataset`.
5/// It represents a collection of objects indexed in the range `0..len()`.
6pub trait Dataset {
7    /// The type of objects procured by the `Dataset`.
8    type Item;
9    /// The number of objects in the `Dataset`.
10    fn len(&self) -> usize;
11    /// Get the `index`'th element, where `0 <= index < len()`
12    fn get(&self, index: usize) -> Self::Item;
13
14    /// Modifies the dataset to check a cache directory before reading.
15    /// If the cache entry is present, it's used instead of the underlying `get()`.
16    /// If the cache entry is absent, it will be created after calling `get()`.
17    // Due to the `Sized` bound, calling `cached()` on `dyn Dataset` trait objects
18    // requires the `impl Dataset` implementations below, for `&T` and `Box T`.
19    // If we wanted to avoid this complication, `CachedDataset` could have simply stored
20    // a `Box<dyn Dataset>`, at the expense of a pointer indirection per method call.
21    // Basically, our optimization allows `CachedDataset` to store `Dataset`s by value
22    // (and statically dispatch its methods) or by pointer (with dynamic dispatch), as needed.
23    fn cached(self, cache_dir: impl Into<PathBuf>) -> CachedDataset<Self>
24    where
25        Self: Sized,
26    {
27        let cache_dir = cache_dir.into();
28        std::fs::create_dir_all(&cache_dir).expect("Could not create cache directory");
29        CachedDataset {
30            base_dataset: self,
31            cache_dir,
32        }
33    }
34
35    /// Produces an `Iterator` that produces the entire `Dataset` in indexed order.
36    // I don't know how to implement `IntoIterator` on `Dataset`, so this is the next best thing.
37    // The return type must be a concrete type (either `Box` or custom `DatasetIterator`, not `impl`),
38    // in case some `impl Dataset` overrides `iter()`.
39    fn iter(&self) -> Box<dyn Iterator<Item = Self::Item> + '_> {
40        Box::new((0..self.len()).map(move |i| self.get(i)))
41    }
42}
43
44/// A slice can act as an in-memory `Dataset`.
45impl<T: Clone> Dataset for [T] {
46    type Item = T;
47
48    fn len(&self) -> usize {
49        self.len()
50    }
51
52    fn get(&self, index: usize) -> T {
53        self[index].clone()
54    }
55}
56
57/// References to `Dataset`s are also `Dataset`s.
58impl<D: Dataset + ?Sized> Dataset for &D {
59    type Item = D::Item;
60
61    fn len(&self) -> usize {
62        (**self).len()
63    }
64
65    fn get(&self, index: usize) -> Self::Item {
66        (**self).get(index)
67    }
68}
69
70impl<D: Dataset + ?Sized> Dataset for Box<D> {
71    type Item = D::Item;
72
73    fn len(&self) -> usize {
74        (**self).len()
75    }
76
77    fn get(&self, index: usize) -> Self::Item {
78        (**self).get(index)
79    }
80}
81
/// A `Dataset` defined in terms of a closure, which acts as a "getter".
pub struct ClosureDataset<T, F>
where
    F: Fn(usize) -> T,
{
    // Number of items; valid indices are `0..length`.
    length: usize,
    // Getter invoked as `closure(index)` to produce each item.
    closure: F,
}
87
88impl<T, F: Fn(usize) -> T> ClosureDataset<T, F> {
89    pub fn new(length: usize, closure: F) -> Self {
90        Self { length, closure }
91    }
92}
93
94impl<T, F: Fn(usize) -> T> Dataset for ClosureDataset<T, F> {
95    type Item = T;
96
97    fn len(&self) -> usize {
98        self.length
99    }
100
101    fn get(&self, index: usize) -> T {
102        (self.closure)(index)
103    }
104}
105
/// A `Dataset` that uses a disk directory as its cache, useful when calls to `get()` are expensive.
/// Created using `Dataset::cached()`.
// The `D: Dataset` bound lives on the impl blocks rather than on the struct itself,
// following the API guideline that struct definitions shouldn't carry trait bounds;
// this relaxation is backward-compatible and lets the type be named without proving the bound.
pub struct CachedDataset<D> {
    // The underlying dataset, consulted only on cache misses.
    base_dataset: D,
    // Directory holding one `{index}.json` file per cached entry.
    cache_dir: PathBuf,
}
112
113impl<D: Dataset> Dataset for CachedDataset<D>
114where
115    D::Item: Serialize + DeserializeOwned,
116{
117    type Item = D::Item;
118
119    fn len(&self) -> usize {
120        self.base_dataset.len()
121    }
122
123    fn get(&self, index: usize) -> Self::Item {
124        let cache_file = self.cache_dir.join(format!("{}.json", index));
125        // Try to read the contest from the cache
126        match std::fs::read_to_string(&cache_file) {
127            Ok(cached_json) => serde_json::from_str(&cached_json).expect("Failed to read cache"),
128            Err(_) => {
129                // The contest doesn't appear in our cache, so request it from the base dataset
130                let contest = self.base_dataset.get(index);
131
132                // Write the contest to the cache
133                super::write_to_json(&contest, &cache_file).expect("Failed to write to cache");
134                println!("Codeforces contest successfully cached at {:?}", cache_file);
135
136                contest
137            }
138        }
139    }
140}
141
142/// Helper function to get data that is already stored inside a disk directory.
143pub fn get_dataset_from_disk<T: Serialize + DeserializeOwned>(
144    dataset_dir: impl AsRef<Path>,
145) -> impl Dataset<Item = T> {
146    // Check that the directory exists and count the number of JSON files
147    let ext = Some(std::ffi::OsStr::new("json"));
148    let dataset_dir = dataset_dir.as_ref();
149    let length = std::fs::read_dir(dataset_dir)
150        .unwrap_or_else(|_| panic!("There's no dataset at {:?}", dataset_dir))
151        .filter(|file| file.as_ref().unwrap().path().extension() == ext)
152        .count();
153    println!("Found {} JSON files at {:?}", length, dataset_dir);
154
155    // Every entry should already be in the directory; if not, we should panic
156    ClosureDataset::new(length, |i| {
157        panic!("Expected to find contest {} in the cache, but didn't", i)
158    })
159    .cached(dataset_dir)
160}
161
#[cfg(test)]
mod test {
    use super::*;

    // A boxed slice works as a `Dataset`: length matches the source `Vec`,
    // and `iter()` yields the elements in index order.
    #[test]
    fn test_in_memory_dataset() {
        let vec = vec![5.7, 9.2, -1.5];
        let dataset: Box<dyn Dataset<Item = f64>> = Box::new(vec.as_slice());

        assert_eq!(dataset.len(), vec.len());
        for (data_val, &vec_val) in dataset.iter().zip(vec.iter()) {
            assert_eq!(data_val, vec_val);
        }
    }

    // `ClosureDataset` serves exactly `closure(i)` for each index in range.
    #[test]
    fn test_closure_dataset() {
        let dataset = ClosureDataset::new(10, |x| x * x);

        for (idx, val) in dataset.iter().enumerate() {
            assert_eq!(val, idx * idx);
        }
    }

    // End-to-end exercise of `CachedDataset`: an empty cache is created, filled
    // by iterating the base dataset, then re-read both through the original
    // `CachedDataset` and via `get_dataset_from_disk`. The assertions on
    // `cache()` pin the expected filesystem state at each stage, so statement
    // order matters here.
    // NOTE(review): uses a fixed directory name in the current working
    // directory, and won't clean up if an assertion fails midway — so this test
    // can't run concurrently with itself and may need manual cleanup after a
    // failure.
    #[test]
    fn test_cached_dataset() {
        let length = 5;
        let cache_dir = "temp_dir_containing_squares";
        let cache = || std::fs::read_dir(cache_dir);
        // An arbitrary serde-serializable item; `checked_sub` exercises `None`
        // round-tripping through JSON for idx < 2.
        let fancy_item = |idx: usize| (idx.checked_sub(2), vec![idx * idx; idx]);

        // Create a new directory
        assert!(cache().is_err());
        let data_from_fn = ClosureDataset::new(length, fancy_item).cached(cache_dir);

        // Write into both a Vec and an empty directory
        assert_eq!(cache().unwrap().count(), 0);
        let data_into_vec = data_from_fn.iter().collect::<Vec<_>>();

        // Read from a filled directory
        assert_eq!(cache().unwrap().count(), length);
        let data_from_disk = get_dataset_from_disk(cache_dir);

        // Check all three views into the data for correctness
        assert_eq!(data_from_fn.len(), length);
        assert_eq!(data_into_vec.len(), length);
        assert_eq!(data_from_disk.len(), length);
        for idx in 0..length {
            let expected = fancy_item(idx);
            let data_from_disk_val: (Option<usize>, Vec<usize>) = data_from_disk.get(idx);
            assert_eq!(data_from_fn.get(idx), expected);
            assert_eq!(data_into_vec[idx], expected);
            assert_eq!(data_from_disk_val, expected);
        }

        // Trash the directory
        assert_eq!(cache().unwrap().count(), length);
        std::fs::remove_dir_all(cache_dir).unwrap();
        assert!(cache().is_err());
    }
}