//! cross-proc-cache 0.1.0
//!
//! Small cross-process cache helper used in the Liturgy workspace.
use std::{
    fs::OpenOptions,
    io::{Read, Seek, Write},
    path::PathBuf,
};

use anyhow::Result;
use bincode::{Decode, Encode};
use fs2::FileExt;
/// A fingerprinted, file-backed cache that can be shared between processes.
///
/// The cached value of type `StoredData` is bincode-encoded into a `.data`
/// file that sits next to a `.fingerprint` file. [`FsCache::load`] reuses the
/// stored value when the on-disk fingerprint matches the one given to
/// [`FsCache::new`], and regenerates it otherwise. File locks on the
/// fingerprint and data files coordinate concurrent users.
///
/// Trait bounds live on the `impl` block rather than on the struct, so the
/// type can be named without repeating them.
pub struct FsCache<StoredData> {
    /// `<path>.fingerprint`: holds the fingerprint of the cached data.
    fingerprint_path: PathBuf,
    /// `<path>.data`: holds the bincode-encoded cached value.
    data_path: PathBuf,
    /// Fingerprint the on-disk cache must match to be reused.
    fp: String,
    /// Ties the cache to `StoredData` without storing a value.
    _marker: std::marker::PhantomData<StoredData>,
}

// Fingerprint-based cache: a fingerprint string is stored in a
// `*.fingerprint` file next to the `*.data` file. `load` regenerates and
// rewrites `*.data` (and then updates the fingerprint) when the provided
// fingerprint doesn't match the on-disk one.
impl<StoredData> FsCache<StoredData>
where
    StoredData: Encode + Decode<()>,
{
    /// Creates a handle to the fingerprinted cache rooted at `path`.
    ///
    /// The cache uses two sibling files derived from `path` by replacing (or
    /// adding) its extension: `<path>.fingerprint` holds the fingerprint and
    /// `<path>.data` holds the bincode-encoded payload. No file is opened
    /// here; all I/O is deferred to [`FsCache::load`].
    ///
    /// # Errors
    ///
    /// This constructor currently never fails; the `Result` is part of the
    /// public signature.
    pub fn new(path: &std::path::Path, fingerprint: &str) -> anyhow::Result<Self> {
        Ok(Self {
            fingerprint_path: path.with_extension("fingerprint"),
            data_path: path.with_extension("data"),
            fp: fingerprint.to_string(),
            _marker: std::marker::PhantomData,
        })
    }

    /// Returns the cached data, regenerating it with `f` if it is stale.
    ///
    /// If the on-disk fingerprint equals the fingerprint passed to
    /// [`FsCache::new`] and the data file decodes successfully, the stored
    /// data is returned and `f` is not called. Otherwise `f` is called once,
    /// its result is written to the data file, and the fingerprint file is
    /// updated.
    ///
    /// Readers hold a shared lock on the fingerprint file across the whole
    /// check-then-read, and a regenerating caller holds an exclusive lock on
    /// it across the whole check-then-write. As a result, a reader never
    /// pairs a matching fingerprint with data written for a different one.
    ///
    /// # Errors
    ///
    /// Returns an error if the cache files cannot be opened, locked, read or
    /// written, or if the value produced by `f` cannot be encoded. A data
    /// file that is missing or fails to decode is not an error: it is
    /// regenerated.
    pub fn load(&self, f: impl FnOnce() -> StoredData) -> Result<StoredData> {
        // Fast path: check and read under a shared lock, so concurrent
        // readers do not serialize behind one another.
        {
            let mut fp_file = self.open_fingerprint()?;
            fp_file.lock_shared()?;
            if let Some(data) = self.read_if_current(&mut fp_file)? {
                return Ok(data);
            }
            // `fp_file` is dropped at the end of this scope, which releases
            // the shared lock. This must happen before the exclusive lock is
            // requested below. Flock-style locks taken through separate
            // handles conflict even within one process, so still holding the
            // shared lock would make the exclusive request wait on ourselves.
        }

        // Slow path: take the exclusive lock and re-check, because another
        // process may have regenerated the cache while we were waiting.
        let mut fp_file = self.lock_fingerprint()?;
        if let Some(data) = self.read_if_current(&mut fp_file)? {
            return Ok(data);
        }

        // Invalidate the fingerprint before touching the data. A crash (or a
        // panic in `f`) part-way through then leaves a cache that will be
        // regenerated, rather than an old fingerprint paired with new or
        // partially written data.
        fp_file.set_len(0)?;
        fp_file.sync_all()?;

        let data = f();
        self.write_data(&data)?;
        self.write_fingerprint(fp_file)?;
        Ok(data)
    }

    /// Reads the stored data if the fingerprint behind `fp_file` is current.
    ///
    /// `fp_file` must already be locked (shared or exclusive). Returns
    /// `Ok(None)` when the fingerprint differs, or when the data file cannot
    /// be read or decoded; either case tells the caller to regenerate.
    fn read_if_current(&self, fp_file: &mut std::fs::File) -> anyhow::Result<Option<StoredData>> {
        if Self::read_locked_fingerprint(fp_file)? != self.fp {
            return Ok(None);
        }
        // A missing or undecodable data file means "regenerate", not "fail".
        Ok(self.read_data().ok())
    }

    /// Opens the fingerprint file, creating it empty if it does not exist.
    ///
    /// The handle is opened read+write so that locks can be placed on it
    /// reliably across platforms. Existing contents are never truncated on
    /// open.
    fn open_fingerprint(&self) -> anyhow::Result<std::fs::File> {
        let f = OpenOptions::new()
            .read(true)
            .write(true)
            .create(true)
            .truncate(false)
            .open(&self.fingerprint_path)?;
        Ok(f)
    }

    /// Opens the fingerprint file and blocks until an exclusive lock is held.
    ///
    /// The lock is released when the returned handle is dropped.
    fn lock_fingerprint(&self) -> anyhow::Result<std::fs::File> {
        let f = self.open_fingerprint()?;
        f.lock_exclusive()?;
        Ok(f)
    }

    /// Replaces the fingerprint file's contents with this cache's fingerprint.
    ///
    /// The write is flushed to disk before returning. `f` is consumed, so
    /// the lock it holds is released only once the write is durable.
    fn write_fingerprint(&self, mut f: std::fs::File) -> anyhow::Result<()> {
        // Overwrite any existing contents from the start of the file.
        f.set_len(0)?;
        f.seek(std::io::SeekFrom::Start(0))?;
        f.write_all(self.fp.as_bytes())?;
        f.sync_all()?;
        Ok(())
    }

    /// Reads the whole fingerprint file through an already-locked handle.
    ///
    /// Reading always starts from the beginning of the file, whatever the
    /// handle's current position. Invalid UTF-8 (for example, from a
    /// corrupted file) is decoded lossily, so it simply fails to match
    /// instead of aborting the load.
    fn read_locked_fingerprint(f: &mut std::fs::File) -> anyhow::Result<String> {
        f.rewind()?;
        let mut bytes = Vec::new();
        f.read_to_end(&mut bytes)?;
        Ok(String::from_utf8_lossy(&bytes).into_owned())
    }

    /// Decodes the data file under a shared lock.
    ///
    /// # Errors
    ///
    /// Fails if the data file is missing or unreadable, or if its contents
    /// do not decode as `StoredData`.
    fn read_data(&self) -> anyhow::Result<StoredData> {
        let mut f = OpenOptions::new().read(true).open(&self.data_path)?;
        f.lock_shared()?;
        let mut buf = Vec::new();
        f.read_to_end(&mut buf)?;
        let data = bincode::decode_from_slice(&buf, bincode::config::standard())?.0;
        Ok(data)
    }

    /// Encodes `data` and replaces the data file's contents with it.
    ///
    /// The write happens under an exclusive lock and is flushed to disk
    /// before returning.
    fn write_data(&self, data: &StoredData) -> anyhow::Result<()> {
        // Encode before locking, so the lock is held only for the I/O.
        let encoded = bincode::encode_to_vec(data, bincode::config::standard())?;
        let mut f = OpenOptions::new()
            .write(true)
            .create(true)
            .truncate(false)
            .open(&self.data_path)?;
        f.lock_exclusive()?;
        // Truncate only once the exclusive lock is held. Truncating at open
        // time would pull the contents out from under a reader holding a
        // shared lock. Not truncating at all would leave stale trailing bytes
        // whenever the new payload is shorter than the old one.
        f.set_len(0)?;
        f.write_all(&encoded)?;
        f.sync_all()?;
        Ok(())
    }
}

#[cfg(test)]
mod tests {
    use std::thread;

    use tempfile::tempdir;

    use super::*;

    #[derive(Encode, Decode, PartialEq, Debug)]
    struct TestData {
        value: String,
    }

    /// Builds a `TestData` holding `value`.
    fn sample(value: &str) -> TestData {
        TestData {
            value: value.to_string(),
        }
    }

    #[test]
    fn test_fs_cache() {
        let dir = tempdir().unwrap();
        let path = dir.path().join("test_cache");

        // A fresh fingerprint generates the data.
        let cache = FsCache::new(&path, "v1").unwrap();
        let data = cache.load(|| sample("Hello, World!")).unwrap();
        assert_eq!(data.value, "Hello, World!");

        // A new handle with the same fingerprint reuses the stored data.
        let cache = FsCache::new(&path, "v1").unwrap();
        let data = cache.load(|| sample("This should not be used")).unwrap();
        assert_eq!(data.value, "Hello, World!");

        // A different fingerprint regenerates and overwrites the data.
        let cache = FsCache::new(&path, "v2").unwrap();
        let data = cache.load(|| sample("New value")).unwrap();
        assert_eq!(data.value, "New value");

        dir.close().unwrap();
    }

    #[test]
    fn test_concurrent_create_same_fingerprint() {
        let dir = tempdir().unwrap();
        let path = dir.path().join("test_cache_concurrent");

        // Several threads race to populate the same cache with the same
        // fingerprint. Each must observe the expected value and none may
        // panic. `thread::scope` lets the threads borrow `path` directly.
        thread::scope(|s| {
            for _ in 0..8 {
                s.spawn(|| {
                    let cache = FsCache::new(&path, "cfp").unwrap();
                    let data = cache.load(|| sample("Concurrent Hello")).unwrap();
                    assert_eq!(data.value, "Concurrent Hello");
                });
            }
        });

        dir.close().unwrap();
    }

    #[test]
    fn test_concurrent_readers_after_write() {
        let dir = tempdir().unwrap();
        let path = dir.path().join("test_cache_readers");

        // Populate the cache once.
        let cache = FsCache::new(&path, "r1").unwrap();
        let data = cache.load(|| sample("Reader Hello")).unwrap();
        assert_eq!(data.value, "Reader Hello");

        // Many threads repeatedly reopen and read the populated cache; their
        // generator must never be the value they observe.
        thread::scope(|s| {
            for _ in 0..16 {
                s.spawn(|| {
                    let reader = FsCache::new(&path, "r1").unwrap();
                    for _ in 0..10 {
                        let d = reader.load(|| sample("Should not be used")).unwrap();
                        assert_eq!(d.value, "Reader Hello");
                    }
                });
            }
        });

        dir.close().unwrap();
    }
}