use std::{
fs::OpenOptions,
io::{Read, Seek, Write},
path::PathBuf,
};
use anyhow::Result;
use bincode::{Decode, Encode};
use fs2::FileExt;
/// A file-backed cache for a single value of type `StoredData`.
///
/// The cache is described by two paths and a fingerprint string: one file
/// holds the fingerprint of the value that was last stored, the other holds
/// the bincode encoding of that value. `StoredData` must therefore be both
/// encodable and decodable with bincode.
pub struct FsCache<StoredData>
where
StoredData: Encode + Decode<()>,
{
// Location of the file that records the fingerprint of the stored value.
fingerprint_path: PathBuf,
// Location of the file that holds the encoded value itself.
data_path: PathBuf,
// Fingerprint this handle expects; compared against the recorded one.
fp: String,
// Ties the handle to `StoredData` without storing a value of that type.
_marker: std::marker::PhantomData<StoredData>,
}
impl<StoredData> FsCache<StoredData>
where
    StoredData: Encode + Decode<()>,
{
    /// Creates a cache handle rooted at `path`.
    ///
    /// The fingerprint is kept in `path` with its extension replaced by
    /// `fingerprint`, and the payload in `path` with its extension replaced by
    /// `data`. No file is touched until [`FsCache::load`] is called.
    ///
    /// # Errors
    ///
    /// This constructor currently never fails; the `Result` return type is
    /// kept so existing callers continue to compile.
    pub fn new(path: &std::path::Path, fingerprint: &str) -> anyhow::Result<Self> {
        Ok(Self {
            fingerprint_path: path.with_extension("fingerprint"),
            data_path: path.with_extension("data"),
            fp: fingerprint.to_string(),
            _marker: std::marker::PhantomData,
        })
    }

    /// Returns the cached value if the stored fingerprint matches this
    /// handle's fingerprint; otherwise runs `f`, stores its result, and
    /// returns it.
    ///
    /// `f` is called at most once, and only while this handle holds the
    /// exclusive lock on the fingerprint file, so concurrent loaders with the
    /// same fingerprint do not regenerate the value in parallel.
    ///
    /// An empty fingerprint file means "nothing valid is cached". As a
    /// consequence, a handle created with an empty fingerprint never hits the
    /// cache and regenerates on every call.
    ///
    /// # Errors
    ///
    /// Returns an error if any file cannot be opened, locked, read or
    /// written, if the fingerprint file is not valid UTF-8, or if the value
    /// cannot be encoded or decoded.
    pub fn load(&self, f: impl FnOnce() -> StoredData) -> Result<StoredData> {
        // Fast path: keep the shared fingerprint lock while the data is read.
        // Writers hold the exclusive fingerprint lock for the whole update, so
        // the payload cannot be replaced between the check and the read.
        {
            let mut shared = self.open_fingerprint()?;
            shared.lock_shared()?;
            let stored = Self::read_locked_fingerprint(&mut shared)?;
            if self.is_current(&stored) {
                return self.read_data();
            }
            // The shared lock is released here, before the exclusive lock is
            // requested on a fresh handle (no in-place lock upgrade).
        }

        // Slow path: serialize writers and re-check, because another loader
        // may have filled the cache while we were waiting for the lock.
        let mut exclusive = self.lock_fingerprint()?;
        let stored = Self::read_locked_fingerprint(&mut exclusive)?;
        if self.is_current(&stored) {
            return self.read_data();
        }

        let data = f();

        // Invalidate the old fingerprint before touching the data file. If the
        // write below fails or the process dies midway, later loaders see an
        // empty fingerprint and regenerate instead of trusting a half-written
        // data file.
        exclusive.set_len(0)?;
        exclusive.sync_all()?;

        self.write_data(&data)?;
        self.write_fingerprint(exclusive)?;
        Ok(data)
    }

    /// Reports whether `stored` (the fingerprint file's contents) vouches for
    /// this handle's fingerprint. An empty file never does.
    fn is_current(&self, stored: &str) -> bool {
        !stored.is_empty() && stored == self.fp.as_str()
    }

    /// Opens the fingerprint file for reading and writing, creating it if it
    /// does not exist and leaving any existing contents intact.
    fn open_fingerprint(&self) -> anyhow::Result<std::fs::File> {
        let f = OpenOptions::new()
            .read(true)
            .write(true)
            .create(true)
            .truncate(false)
            .open(&self.fingerprint_path)?;
        Ok(f)
    }

    /// Opens the fingerprint file and blocks until an exclusive lock on it is
    /// held. The lock lives as long as the returned handle.
    fn lock_fingerprint(&self) -> anyhow::Result<std::fs::File> {
        let f = self.open_fingerprint()?;
        f.lock()?;
        Ok(f)
    }

    /// Replaces the contents of the (exclusively locked) fingerprint file
    /// with this handle's fingerprint and flushes it to disk. Dropping `f` at
    /// the end releases the lock.
    fn write_fingerprint(&self, mut f: std::fs::File) -> anyhow::Result<()> {
        f.set_len(0)?;
        // The cursor may sit past the old contents after an earlier read.
        f.seek(std::io::SeekFrom::Start(0))?;
        f.write_all(self.fp.as_bytes())?;
        f.sync_all()?;
        Ok(())
    }

    /// Reads the remaining contents of an already locked fingerprint file.
    /// Callers pass a freshly opened handle, so this is the whole file.
    fn read_locked_fingerprint(f: &mut std::fs::File) -> anyhow::Result<String> {
        let mut contents = String::new();
        f.read_to_string(&mut contents)?;
        Ok(contents)
    }

    /// Reads and decodes the data file under a shared lock.
    fn read_data(&self) -> anyhow::Result<StoredData> {
        let mut f = OpenOptions::new().read(true).open(&self.data_path)?;
        f.lock_shared()?;
        let mut buf = Vec::new();
        f.read_to_end(&mut buf)?;
        let data = bincode::decode_from_slice(&buf, bincode::config::standard())?.0;
        Ok(data)
    }

    /// Encodes `data` and makes it the sole contents of the data file, under
    /// an exclusive lock, flushing to disk before returning.
    fn write_data(&self, data: &StoredData) -> anyhow::Result<()> {
        // Encode before opening so an encoding error leaves the file untouched.
        let encoded = bincode::encode_to_vec(data, bincode::config::standard())?;
        // Do not truncate on open: a reader may still hold the shared lock.
        let mut f = OpenOptions::new()
            .write(true)
            .create(true)
            .truncate(false)
            .open(&self.data_path)?;
        f.lock_exclusive()?;
        // Truncate only once the lock is held, so a shorter payload does not
        // leave the tail of the previous one behind.
        f.set_len(0)?;
        f.write_all(&encoded)?;
        f.sync_all()?;
        Ok(())
    }
}
#[cfg(test)]
mod tests {
    use serde::{Deserialize, Serialize};
    use tempfile::tempdir;

    use super::*;

    /// Minimal payload used to exercise the cache round trip.
    #[derive(Encode, Decode, PartialEq, Debug, Serialize, Deserialize)]
    struct TestData {
        value: String,
    }

    /// Opens a fresh cache handle for `path` and `fingerprint`, then loads
    /// through it, offering `fallback` as the value to generate on a miss.
    fn load_via_new_handle(
        path: &std::path::Path,
        fingerprint: &str,
        fallback: &str,
    ) -> TestData {
        FsCache::<TestData>::new(path, fingerprint)
            .unwrap()
            .load(|| TestData {
                value: fallback.to_string(),
            })
            .unwrap()
    }

    /// A miss generates, a hit with the same fingerprint reuses, and a new
    /// fingerprint regenerates.
    #[test]
    fn test_fs_cache() {
        let tmp = tempdir().unwrap();
        let base = tmp.path().join("test_cache");

        // First load: nothing cached yet, so the generator runs.
        let first = load_via_new_handle(&base, "v1", "Hello, World!");
        assert_eq!(first.value, "Hello, World!");

        // Same fingerprint: the stored value wins over the generator.
        let cached = load_via_new_handle(&base, "v1", "This should not be used");
        assert_eq!(cached.value, "Hello, World!");

        // Different fingerprint: the generator runs again.
        let regenerated = load_via_new_handle(&base, "v2", "New value");
        assert_eq!(regenerated.value, "New value");

        tmp.close().unwrap();
    }

    /// Eight threads race to fill an empty cache with the same fingerprint;
    /// each must observe the same value.
    #[test]
    fn test_concurrent_create_same_fingerprint() {
        use std::{sync::Arc, thread};

        let tmp = tempdir().unwrap();
        let shared_path = Arc::new(tmp.path().join("test_cache_concurrent"));

        let workers: Vec<_> = (0..8)
            .map(|_| {
                let p = Arc::clone(&shared_path);
                thread::spawn(move || {
                    let got = load_via_new_handle(&p, "cfp", "Concurrent Hello");
                    assert_eq!(got.value, "Concurrent Hello");
                })
            })
            .collect();

        for worker in workers {
            worker.join().expect("thread panicked");
        }
        tmp.close().unwrap();
    }

    /// After one writer has filled the cache, sixteen threads each load ten
    /// times and must always get the stored value.
    #[test]
    fn test_concurrent_readers_after_write() {
        use std::{sync::Arc, thread};

        let tmp = tempdir().unwrap();
        let base = tmp.path().join("test_cache_readers");

        // Seed the cache from the main thread.
        let writer = FsCache::<TestData>::new(&base, "r1").unwrap();
        let seeded = writer
            .load(|| TestData {
                value: "Reader Hello".to_string(),
            })
            .unwrap();
        assert_eq!(seeded.value, "Reader Hello");

        let shared_path = Arc::new(base);
        let readers: Vec<_> = (0..16)
            .map(|_| {
                let p = Arc::clone(&shared_path);
                thread::spawn(move || {
                    // One handle per thread, reused across all ten loads.
                    let handle = FsCache::<TestData>::new(&p, "r1").unwrap();
                    for _ in 0..10 {
                        let seen = handle
                            .load(|| TestData {
                                value: "Should not be used".to_string(),
                            })
                            .unwrap();
                        assert_eq!(seen.value, "Reader Hello");
                    }
                })
            })
            .collect();

        for reader in readers {
            reader.join().expect("reader thread panicked");
        }
        tmp.close().unwrap();
    }
}