ser_data_loadr/
lib.rs

use std::{
    collections::BTreeMap,
    fs::File,
    future::Future,
    io::{BufRead, BufReader, BufWriter, Read, Write},
    path::{Path, PathBuf},
    sync::{atomic::AtomicU64, Arc, RwLock},
    time::SystemTime,
};

use anyhow::Context;
use futures::{channel::oneshot, stream::FuturesUnordered, StreamExt};
use serde::{de::DeserializeOwned, Deserialize, Serialize};
14
15#[derive(Default)]
16struct LoaderTaskSet {
17    #[cfg(feature = "tokio")]
18    inner: tokio::task::JoinSet<anyhow::Result<()>>,
19}
20
21impl LoaderTaskSet {
22    fn spawn_task<F: FnOnce() -> anyhow::Result<()> + Send + 'static>(&mut self, f: F) {
23        #[cfg(feature = "tokio")]
24        {
25            self.inner.spawn_blocking(f);
26        }
27    }
28
29    async fn wait_all(self) -> anyhow::Result<()> {
30        #[cfg(feature = "tokio")]
31        {
32            let res = self.inner.join_all().await;
33            for res in res {
34                res?;
35            }
36            Ok(())
37        }
38    }
39}
40
41/// Data format to read and load data from a file
42pub trait DataFormat {
43    /// Read data from a reader
44    fn read_from<T: Data, R: Read + BufRead>(rdr: R) -> anyhow::Result<T>;
45
46    /// Read data from a file
47    fn read_from_file<T: Data>(path: &Path) -> anyhow::Result<T> {
48        let file = File::open(path).context("open file")?;
49        let rdr = BufReader::new(file);
50        Self::read_from(rdr)
51    }
52
53    /// Write data to a writer
54    fn write_to<T: Serialize, W: std::io::Write>(data: &T, wtr: W) -> anyhow::Result<()>;
55    /// Write data to a file
56    fn write_to_file<T: Serialize>(data: &T, path: &Path) -> anyhow::Result<()> {
57        let file = File::create(path).context("create file")?;
58        Self::write_to(data, file)
59    }
60}
61
62/// JSON data format
63pub struct JsonDataFormat;
64impl DataFormat for JsonDataFormat {
65    fn read_from<T: Data, R: Read + BufRead>(rdr: R) -> anyhow::Result<T> {
66        serde_json::from_reader(rdr).context("parse json")
67    }
68
69    fn write_to<T: Serialize, W: std::io::Write>(data: &T, wtr: W) -> anyhow::Result<()> {
70        serde_json::to_writer(wtr, data).context("write json")
71    }
72}
73
74/// Bincode data format
75pub struct BinCodeDataFormat;
76impl DataFormat for BinCodeDataFormat {
77    fn read_from<T: Data, R: Read + BufRead>(rdr: R) -> anyhow::Result<T> {
78        bincode::deserialize_from(rdr).context("parse bincode")
79    }
80
81    fn write_to<T: Serialize, W: std::io::Write>(data: &T, wtr: W) -> anyhow::Result<()> {
82        bincode::serialize_into(wtr, data).context("write bincode")
83    }
84}
85
86/// Data mapper, to map data from an `In` type to an `Out` type,
87/// this is useful for using the built-in caching mechanism of the `DataLoader`
88pub trait DataMapper: Send + 'static {
89    type In: Data;
90    type Out: Send + 'static;
91
92    fn map(self, data: Self::In) -> anyhow::Result<Self::Out>;
93}
94
95/// Create a data mapper from a function
96pub fn data_mapper_fn<F, In, Out>(f: F) -> MapperFn<F, In, Out> {
97    MapperFn::new(f)
98}
99
/// Data mapper backed by a function or closure `F`.
pub struct MapperFn<F, In, Out> {
    /// The wrapped function.
    f: F,
    /// Ties the otherwise unused `In` and `Out` parameters to the type.
    _t: std::marker::PhantomData<(In, Out)>,
}
105
106impl<F, In, Out> Clone for MapperFn<F, In, Out>
107where
108    F: Clone,
109{
110    fn clone(&self) -> Self {
111        Self {
112            f: self.f.clone(),
113            _t: std::marker::PhantomData,
114        }
115    }
116}
117
118impl<F, In, Out> MapperFn<F, In, Out> {
119    /// Create a new data mapper from a function
120    pub fn new(f: F) -> Self {
121        Self {
122            f,
123            _t: std::marker::PhantomData,
124        }
125    }
126}
127
128impl<F, In: Data, Out: Send + 'static> DataMapper for MapperFn<F, In, Out>
129where
130    F: FnOnce(In) -> anyhow::Result<Out> + Send + 'static,
131{
132    type In = In;
133    type Out = Out;
134
135    fn map(self, data: Self::In) -> anyhow::Result<Self::Out> {
136        (self.f)(data)
137    }
138}
139
140/// Data trait, to mark a type as a data type
141/// basically `DeserializeOwned + Send + 'static`
142pub trait Data: DeserializeOwned + Send + 'static {}
143impl<T: DeserializeOwned + Send + 'static> Data for T {}
144
145pub struct DataReceiver<T>(oneshot::Receiver<T>);
146
147impl<T> DataReceiver<T> {
148    /// Get the data, this will usually wailt until the data is loaded
149    pub fn get(mut self) -> T {
150        self.0
151            .try_recv()
152            .expect("Data recv closed")
153            .expect("Data recv no value")
154    }
155}
156
157impl<T> Future for DataReceiver<T> {
158    type Output = T;
159
160    fn poll(
161        mut self: std::pin::Pin<&mut Self>,
162        cx: &mut std::task::Context,
163    ) -> std::task::Poll<Self::Output> {
164        std::pin::Pin::new(&mut self.0).poll(cx).map(Result::unwrap)
165    }
166}
167
168/// Data format handler, to handle different data formats
169pub trait DataFormatHandler {
170    /// Load data from a file
171    fn load_from_file<T: Data>(p: &Path) -> anyhow::Result<T>;
172}
173
174/// Auto data format handler, to automatically detects the data format from the file extension
175pub struct AutoDataFormatHandler;
176impl DataFormatHandler for AutoDataFormatHandler {
177    fn load_from_file<T: Data>(p: &Path) -> anyhow::Result<T> {
178        let ext = p.extension().context("no extension")?;
179        match ext.to_string_lossy().to_lowercase().as_str() {
180            "json" => JsonDataFormat::read_from_file(p),
181            "bincode" => BinCodeDataFormat::read_from_file(p),
182            _ => anyhow::bail!("unknown extension: {:?}", ext),
183        }
184    }
185}
186
187/// Entry for the Manifest for the cache
188#[derive(Debug, Deserialize, Serialize, Clone)]
189struct DataManifestEntry {
190    last_changed: SystemTime,
191    cached_name: String,
192}
193
194/// Manifest for the cache
195#[derive(Debug, Deserialize, Serialize)]
196struct DataManifest {
197    pub entries: BTreeMap<PathBuf, DataManifestEntry>,
198    pub counter: u64,
199}
200
201/// Cache for the data loader
202struct Cache {
203    manifest: RwLock<DataManifest>,
204    dir: PathBuf,
205    counter: AtomicU64,
206}
207
208impl Cache {
209    /// Load the cache from a directory
210    fn load(dir: &Path) -> anyhow::Result<Self> {
211        if !dir.exists() {
212            std::fs::create_dir(dir).context("create cache dir")?;
213        }
214
215        let manifest_path = dir.join("manifest.json");
216        let manifest = if manifest_path.exists() {
217            JsonDataFormat::read_from_file(&manifest_path)?
218        } else {
219            DataManifest {
220                entries: BTreeMap::new(),
221                counter: 0,
222            }
223        };
224
225        let counter = manifest.counter;
226
227        Ok(Self {
228            manifest: RwLock::new(manifest),
229            dir: dir.to_owned(),
230            counter: counter.into(),
231        })
232    }
233
234    /// Save the cache
235    fn save(&self) -> anyhow::Result<()> {
236        let mut manifest = self.manifest.write().expect("write");
237        manifest.counter = self.counter.load(std::sync::atomic::Ordering::Relaxed);
238
239        let manifest_path = self.dir.join("manifest.json");
240        JsonDataFormat::write_to_file::<DataManifest>(&manifest, &manifest_path)
241    }
242
243    /// Update an entry in the cache
244    fn update_entry<F>(&self, path: &Path, update_cached: F) -> anyhow::Result<()>
245    where
246        F: FnOnce(&mut BufWriter<File>) -> anyhow::Result<()>,
247    {
248        let path = path.canonicalize().expect("canonicalize path");
249        let filename = path.file_name().expect("file_name").to_string_lossy();
250        let num = self
251            .counter
252            .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
253        let cache_file = format!("{num}_{filename}.cache");
254
255        let file = File::create(self.dir.join(&cache_file))?;
256        let mut file = BufWriter::new(file);
257        update_cached(&mut file)?;
258
259        // Remove the old file
260        if let Some(old) = self.get_entry(&path) {
261            let old_cache = self.dir.join(&old.cached_name);
262            std::fs::remove_file(old_cache).context("remove old cache file")?;
263        }
264
265        let last_changed = path
266            .metadata()
267            .context("metadata")?
268            .modified()
269            .expect("modified");
270        self.manifest.write().expect("write").entries.insert(
271            path.to_owned(),
272            DataManifestEntry {
273                last_changed,
274                cached_name: cache_file,
275            },
276        );
277
278        Ok(())
279    }
280
281    /// Get or update a data from the cache
282    fn get_or_update<T: Data + Serialize>(
283        &self,
284        path: &Path,
285        load: impl FnOnce(&Path) -> anyhow::Result<T>,
286    ) -> anyhow::Result<T> {
287        if let Some(entry) = self.get_validated_entry(path)? {
288            let cache_file = self.dir.join(&entry.cached_name);
289            let rdr = BufReader::new(File::open(cache_file).context("open cache file")?);
290            return BinCodeDataFormat::read_from(rdr);
291        }
292
293        let data = load(path)?;
294        self.update_entry(path, |file| BinCodeDataFormat::write_to(&data, file))?;
295        Ok(data)
296    }
297
298    /// Get an entry from the cache
299    fn get_entry(&self, path: &Path) -> Option<DataManifestEntry> {
300        let norm = path.canonicalize().expect("canonicalize path");
301        self.manifest
302            .read()
303            .expect("read")
304            .entries
305            .get(&norm)
306            .cloned()
307    }
308
309    /// Get a validated entry from the cache
310    fn get_validated_entry(&self, path: &Path) -> anyhow::Result<Option<DataManifestEntry>> {
311        Ok(match self.get_entry(path) {
312            Some(entry) => {
313                let meta = path.metadata().context("metadata")?;
314                if meta.modified().expect("modified") > entry.last_changed {
315                    None
316                } else {
317                    Some(entry)
318                }
319            }
320            None => None,
321        })
322    }
323}
324
325/// Data loader, to load data from files
326pub struct DataLoader<F> {
327    pending: LoaderTaskSet,
328    dir: PathBuf,
329    cache: Arc<Cache>,
330    _f: std::marker::PhantomData<F>,
331}
332
333impl<F: DataFormatHandler> DataLoader<F> {
334    /// Create a new data loader from a directory
335    pub fn new(dir: &Path) -> anyhow::Result<Self> {
336        anyhow::ensure!(dir.is_dir(), "Dir is not a directory: {:?}", dir);
337        let cache = Cache::load(&dir.join(".cache"))?;
338
339        Ok(Self {
340            pending: LoaderTaskSet::default(),
341            cache: Arc::new(cache),
342            dir: dir.to_owned(),
343            _f: std::marker::PhantomData,
344        })
345    }
346
347    /// Spawn a task to load data from a file
348    fn spawn<Out: Send + 'static>(
349        &mut self,
350        path: &Path,
351        f: impl FnOnce(&Path) -> anyhow::Result<Out> + Send + 'static,
352    ) -> DataReceiver<Out> {
353        let path = self.dir.join(path);
354
355        let (tx, rx) = oneshot::channel();
356        let path = path.to_owned();
357        self.pending.spawn_task(move || {
358            let out = f(&path).with_context(|| format!("load {path:?}"))?;
359            let _ = tx.send(out);
360            Ok(())
361        });
362
363        DataReceiver(rx)
364    }
365
366    /// Load Data from a file
367    pub fn load_file<T: Data>(&mut self, path: impl AsRef<Path>) -> DataReceiver<T> {
368        self.spawn(path.as_ref(), |path| F::load_from_file::<T>(path))
369    }
370
371    /// Load and map data from a file
372    pub fn load_map<M: DataMapper>(
373        &mut self,
374        path: impl AsRef<Path>,
375        mapper: M,
376    ) -> DataReceiver<M::Out> {
377        self.spawn::<M::Out>(path.as_ref(), move |path| {
378            let data = F::load_from_file::<M::In>(path)?;
379            let out = mapper.map(data)?;
380            Ok(out)
381        })
382    }
383
384    /// Load and map data from a file, but cache the result
385    /// so the mapper is only invoked if the file has changed
386    pub fn load_map_cached<M: DataMapper>(
387        &mut self,
388        path: impl AsRef<Path>,
389        mapper: M,
390    ) -> DataReceiver<M::Out>
391    where
392        M::Out: Serialize + Data,
393    {
394        let cache = self.cache.clone();
395        self.spawn::<M::Out>(path.as_ref(), move |path| {
396            cache.get_or_update(path, |path| {
397                let data = F::load_from_file::<M::In>(path)?;
398                let out = mapper.map(data)?;
399                Ok(out)
400            })
401        })
402    }
403
404    /// Loads data from a given path, with a custom Out type
405    pub fn load<Out: Send + 'static>(
406        &mut self,
407        path: &Path,
408        f: impl FnOnce(&Path) -> anyhow::Result<Out> + Send + 'static,
409    ) -> DataReceiver<Out> {
410        self.spawn(path, f)
411    }
412
413    /// Spawn all tasks from a given iterator of paths
414    fn spawn_all<Out: Send + 'static, P: AsRef<Path>>(
415        &mut self,
416        paths: impl Iterator<Item = P>,
417        f: impl Fn(&Path) -> anyhow::Result<Out> + Clone + Send + 'static,
418    ) -> DataReceiver<Vec<Out>> {
419        let mut tasks: FuturesUnordered<_> = paths.map(|path| self.spawn(path.as_ref(), f.clone())).collect();
420        let (tx, rx) = oneshot::channel();
421        tokio::spawn(async move {
422            let mut res = Vec::new();
423            while let Some(data) = tasks.next().await {
424                res.push(data);
425            }
426            let _ = tx.send(res);
427        });
428
429        DataReceiver(rx)
430    }
431
432    /// Load all files from a given iterator of paths
433    pub fn load_all_files<T: Data, P: AsRef<Path>>(
434        &mut self,
435        paths: impl Iterator<Item = P>,
436    ) -> DataReceiver<Vec<T>> {
437        self.spawn_all(paths, |path| F::load_from_file::<T>(path))
438    }
439
440
441    /// Load all files from a given iterator of paths, with a custom mapper
442    pub fn load_all_mapped<M: DataMapper + Clone, P: AsRef<Path>>(
443        &mut self,
444        paths: impl Iterator<Item = P>,
445        mapper: M,
446    ) -> DataReceiver<Vec<M::Out>> {
447        self.spawn_all(paths, move |path| {
448            let data = F::load_from_file::<M::In>(path)?;
449            let out = mapper.clone().map(data)?;
450            Ok(out)
451        })
452    }
453
454    /// Load all files from a given iterator of paths, with a custom Out type
455    pub fn load_all<Out: Send + 'static, P: AsRef<Path>>(
456        &mut self,
457        paths: impl Iterator<Item = P>,
458        f: impl Fn(&Path) -> anyhow::Result<Out> + Clone + Send + 'static,
459    ) -> DataReceiver<Vec<Out>> {
460        self.spawn_all(paths, f)
461    }
462
463    /// Wait for all pending tasks to finish and update the manifest
464    pub async fn wait_all(self) -> anyhow::Result<()> {
465        self.pending.wait_all().await?;
466        self.cache.save()?;
467        Ok(())
468    }
469}
470
471pub type AutoDataLoader = DataLoader<AutoDataFormatHandler>;
472
#[cfg(test)]
mod tests {
    use std::{sync::atomic::AtomicBool, time::Duration};

    use super::*;

    /// Loads `a.json` plainly, mapped, and mapped with caching, and checks
    /// that the cached mapper only runs when the file content was rewritten.
    #[tokio::test]
    async fn data_loader() {
        let path = Path::new("test_1");
        let _ = std::fs::create_dir(path);
        // Start without a cache so the first expectation below does not
        // depend on what an earlier run left behind.
        let _ = std::fs::remove_dir_all(path.join(".cache"));

        let a_file = path.join("a.json");
        let is_mapped = Arc::new(AtomicBool::new(false));

        const HELLO_WORLD: &str = "Hello, World...";
        const HELLO_UNIVERSE: &str = "Hello, Universe...";

        // (file content, whether the cached mapper is expected to run)
        let seq = [
            (HELLO_WORLD, true),
            (HELLO_WORLD, false),
            (HELLO_UNIVERSE, true),
            (HELLO_UNIVERSE, false),
            (HELLO_WORLD, true),
            (HELLO_UNIVERSE, true),
        ];

        for (inp, map) in seq {
            let json = format!("\"{inp}\"");
            // A missing file (first run) counts as empty content and is written below.
            let content = std::fs::read_to_string(&a_file).unwrap_or_default();
            if content != json {
                std::fs::write(&a_file, json).unwrap();
            }

            is_mapped.store(false, std::sync::atomic::Ordering::Relaxed);
            let mut loader = AutoDataLoader::new(path).unwrap();
            let txt = loader.load_file::<String>("a.json");
            let mapped = loader.load_map("a.json", data_mapper_fn(|s: String| Ok(s.len())));
            let is_mapped_ = is_mapped.clone();
            let mapped_cached = loader.load_map_cached(
                "a.json",
                data_mapper_fn(move |s: String| {
                    is_mapped_.store(true, std::sync::atomic::Ordering::Relaxed);
                    Ok(s.len())
                }),
            );
            loader.wait_all().await.unwrap();

            assert_eq!(map, is_mapped.load(std::sync::atomic::Ordering::Relaxed));
            assert_eq!(txt.get(), inp);
            assert_eq!(mapped.get(), inp.len());
            assert_eq!(mapped_cached.get(), inp.len());
            // Leave enough time for the next write to get a different
            // modification time, even with coarse file timestamps.
            std::thread::sleep(Duration::from_millis(20));
        }

        let mut loader = AutoDataLoader::new(path).unwrap();
        let all = loader.load_all_files::<String, _>(std::iter::repeat_n(Path::new("a.json"), 10));
        loader.wait_all().await.unwrap();

        let all = all.get();
        assert_eq!(all.len(), 10);
        for s in all {
            assert_eq!(s, HELLO_UNIVERSE);
        }
    }
}