use serde::{de::DeserializeOwned, Serialize};
use std::path::{Path, PathBuf};
/// A fixed-length collection whose items are fetched one at a time by index.
///
/// Implementors only provide [`Dataset::len`] and [`Dataset::get`]; iteration
/// and on-disk caching are built on top of those two methods.
pub trait Dataset {
    /// The type handed out by [`Dataset::get`].
    type Item;

    /// Number of items in the dataset.
    fn len(&self) -> usize;

    /// Produces the item stored at `index`.
    fn get(&self, index: usize) -> Self::Item;

    /// Consumes `self` and wraps it in a [`CachedDataset`] that keeps its
    /// items under `cache_dir`, creating that directory if needed.
    ///
    /// # Panics
    /// Panics if `cache_dir` cannot be created.
    fn cached(self, cache_dir: impl Into<PathBuf>) -> CachedDataset<Self>
    where
        Self: Sized,
    {
        let directory: PathBuf = cache_dir.into();
        std::fs::create_dir_all(&directory).expect("Could not create cache directory");
        CachedDataset {
            cache_dir: directory,
            base_dataset: self,
        }
    }

    /// Lazily yields `get(0)`, `get(1)`, ..., `get(len() - 1)` in order.
    fn iter(&self) -> Box<dyn Iterator<Item = Self::Item> + '_> {
        let fetch = move |index| self.get(index);
        Box::new((0..self.len()).map(fetch))
    }
}
/// Any slice of cloneable values is a dataset; `get` hands out clones.
impl<T: Clone> Dataset for [T] {
    type Item = T;

    fn len(&self) -> usize {
        // Call the inherent slice method explicitly so it is obvious this
        // does not recurse into `Dataset::len`.
        <[T]>::len(self)
    }

    /// Clones the element at `index`; panics if `index` is out of bounds.
    fn get(&self, index: usize) -> T {
        let element: &T = &self[index];
        element.clone()
    }
}
/// A shared reference to a dataset is itself a dataset that forwards every
/// call to the referent.
impl<D: Dataset + ?Sized> Dataset for &D {
    type Item = D::Item;

    fn len(&self) -> usize {
        D::len(*self)
    }

    fn get(&self, index: usize) -> Self::Item {
        D::get(*self, index)
    }
}
/// A boxed dataset (including `Box<dyn Dataset<Item = _>>`) forwards every
/// call to the boxed value.
impl<D: Dataset + ?Sized> Dataset for Box<D> {
    type Item = D::Item;

    fn len(&self) -> usize {
        let inner: &D = &**self;
        D::len(inner)
    }

    fn get(&self, index: usize) -> Self::Item {
        let inner: &D = &**self;
        D::get(inner, index)
    }
}
/// A dataset whose items are computed on demand by calling `closure(index)`.
///
/// The bound on `F` stays on the struct because `T` appears only there;
/// without it `T` would be an unused type parameter.
pub struct ClosureDataset<T, F: Fn(usize) -> T> {
    /// Value reported by `len`.
    length: usize,
    /// Generator invoked for every `get`.
    closure: F,
}

impl<T, F: Fn(usize) -> T> ClosureDataset<T, F> {
    /// Creates a dataset reporting `length` items, where item `i` is `closure(i)`.
    pub fn new(length: usize, closure: F) -> Self {
        ClosureDataset { closure, length }
    }
}

impl<T, F> Dataset for ClosureDataset<T, F>
where
    F: Fn(usize) -> T,
{
    type Item = T;

    fn len(&self) -> usize {
        self.length
    }

    /// Calls the closure with `index`; `index` is not checked against `len`.
    fn get(&self, index: usize) -> T {
        let compute = &self.closure;
        compute(index)
    }
}
/// A dataset that memoizes the items of `base_dataset` as files named
/// `<index>.json` inside `cache_dir`.
///
/// Built by [`Dataset::cached`].
pub struct CachedDataset<D: Dataset> {
    base_dataset: D,
    cache_dir: PathBuf,
}

impl<D: Dataset> Dataset for CachedDataset<D>
where
    D::Item: Serialize + DeserializeOwned,
{
    type Item = D::Item;

    fn len(&self) -> usize {
        self.base_dataset.len()
    }

    /// Returns item `index`.
    ///
    /// If `<cache_dir>/<index>.json` can be read, the item is deserialized
    /// from it. Otherwise it is computed by the base dataset and written to
    /// that file before being returned.
    ///
    /// # Panics
    /// Panics if an existing cache file does not deserialize into the item
    /// type, or if writing a freshly computed item to the cache fails.
    fn get(&self, index: usize) -> Self::Item {
        let cache_file = self.cache_dir.join(format!("{}.json", index));
        match std::fs::read_to_string(&cache_file) {
            Ok(cached_json) => serde_json::from_str(&cached_json).unwrap_or_else(|err| {
                // Name the offending file so a corrupt entry can be located and deleted.
                panic!("Failed to read cache at {:?}: {}", cache_file, err)
            }),
            // Any read failure (most commonly: the file does not exist yet) is a cache miss.
            Err(_) => {
                let item = self.base_dataset.get(index);
                super::write_to_json(&item, &cache_file).expect("Failed to write to cache");
                println!("Item {} successfully cached at {:?}", index, cache_file);
                item
            }
        }
    }
}
/// Opens a dataset that a [`CachedDataset`] previously wrote to `dataset_dir`.
///
/// The length is the number of files in `dataset_dir` named `<index>.json`,
/// where `<index>` parses as a `usize`. Items are read from disk on each
/// `get`. An index whose file is missing panics instead of being recomputed,
/// because there is no base dataset to recompute it from.
///
/// # Panics
/// Panics if `dataset_dir` cannot be read (for example, it does not exist)
/// or if one of its directory entries cannot be read.
pub fn get_dataset_from_disk<T: Serialize + DeserializeOwned>(
    dataset_dir: impl AsRef<Path>,
) -> impl Dataset<Item = T> {
    let ext = Some(std::ffi::OsStr::new("json"));
    let dataset_dir = dataset_dir.as_ref();
    let length = std::fs::read_dir(dataset_dir)
        .unwrap_or_else(|_| panic!("There's no dataset at {:?}", dataset_dir))
        .map(|entry| {
            entry.unwrap_or_else(|err| {
                panic!("Could not read an entry of {:?}: {}", dataset_dir, err)
            })
        })
        .filter(|entry| {
            let path = entry.path();
            // Only `<index>.json` files are items served by `CachedDataset`.
            // Counting other JSON files (metadata, notes, ...) would make
            // `len()` too large, and `iter()` would panic on the last index.
            path.extension() == ext
                && path
                    .file_stem()
                    .and_then(|stem| stem.to_str())
                    .and_then(|stem| stem.parse::<usize>().ok())
                    .is_some()
        })
        .count();
    println!("Found {} JSON files at {:?}", length, dataset_dir);
    ClosureDataset::new(length, |i| {
        panic!("Expected to find contest {} in the cache, but didn't", i)
    })
    .cached(dataset_dir)
}
#[cfg(test)]
mod test {
    use super::*;

    /// Removes the wrapped directory when dropped, so a failing assertion
    /// does not leave cache files behind that would break the next run.
    struct TempDirGuard(std::path::PathBuf);

    impl Drop for TempDirGuard {
        fn drop(&mut self) {
            // Best effort: the test body may already have removed the directory.
            let _ = std::fs::remove_dir_all(&self.0);
        }
    }

    #[test]
    fn test_in_memory_dataset() {
        let vec = vec![5.7, 9.2, -1.5];
        let dataset: Box<dyn Dataset<Item = f64>> = Box::new(vec.as_slice());
        assert_eq!(dataset.len(), vec.len());
        for (data_val, &vec_val) in dataset.iter().zip(vec.iter()) {
            assert_eq!(data_val, vec_val);
        }
    }

    #[test]
    fn test_closure_dataset() {
        let dataset = ClosureDataset::new(10, |x| x * x);
        assert_eq!(dataset.len(), 10);
        for (idx, val) in dataset.iter().enumerate() {
            assert_eq!(val, idx * idx);
        }
    }

    #[test]
    fn test_cached_dataset() {
        let length = 5;
        // Per-process directory under the system temp dir instead of a fixed
        // path in the working directory, so runs cannot interfere with each other.
        let cache_dir = std::env::temp_dir().join(format!(
            "temp_dir_containing_squares_{}",
            std::process::id()
        ));
        // Clear anything left behind by an earlier aborted run.
        let _ = std::fs::remove_dir_all(&cache_dir);
        let _guard = TempDirGuard(cache_dir.clone());
        let cache = || std::fs::read_dir(&cache_dir);
        let fancy_item = |idx: usize| (idx.checked_sub(2), vec![idx * idx; idx]);
        assert!(cache().is_err());
        let data_from_fn = ClosureDataset::new(length, fancy_item).cached(cache_dir.clone());
        assert_eq!(cache().unwrap().count(), 0);
        let data_into_vec = data_from_fn.iter().collect::<Vec<_>>();
        assert_eq!(cache().unwrap().count(), length);
        let data_from_disk = get_dataset_from_disk(&cache_dir);
        assert_eq!(data_from_fn.len(), length);
        assert_eq!(data_into_vec.len(), length);
        assert_eq!(data_from_disk.len(), length);
        for idx in 0..length {
            let expected = fancy_item(idx);
            let data_from_disk_val: (Option<usize>, Vec<usize>) = data_from_disk.get(idx);
            assert_eq!(data_from_fn.get(idx), expected);
            assert_eq!(data_into_vec[idx], expected);
            assert_eq!(data_from_disk_val, expected);
        }
        assert_eq!(cache().unwrap().count(), length);
        std::fs::remove_dir_all(&cache_dir).unwrap();
        assert!(cache().is_err());
    }
}