1use std::{
2 fs::OpenOptions,
3 io::{Read, Seek, Write},
4 path::PathBuf,
5};
6
7use anyhow::Result;
8use bincode::{Decode, Encode};
9use fs2::FileExt;
/// A single-value cache persisted on disk and keyed by a fingerprint string.
///
/// The value is stored bincode-encoded in a data file, next to a fingerprint
/// file recording which fingerprint that data was built for. Trait bounds live
/// on the `impl` block rather than here, so the bare type places no
/// requirements on `StoredData`.
pub struct FsCache<StoredData> {
    /// File recording the fingerprint the stored data was built for.
    fingerprint_path: PathBuf,
    /// File holding the bincode-encoded payload.
    data_path: PathBuf,
    /// Fingerprint this handle expects to find on disk.
    fp: String,
    /// The handle never owns a `StoredData`; `fn() -> StoredData` records the
    /// type parameter without making `Send`/`Sync` depend on `StoredData`.
    _marker: std::marker::PhantomData<fn() -> StoredData>,
}
19
20impl<StoredData> FsCache<StoredData>
25where
26 StoredData: Encode + Decode<()>,
27{
28 pub fn new(path: &std::path::Path, fingerprint: &str) -> anyhow::Result<Self> {
29 let fingerprint_path = path.with_extension("fingerprint");
30 let data_path = path.with_extension("data");
31 Ok(Self {
32 fingerprint_path,
33 data_path,
34 fp: fingerprint.to_string(),
35 _marker: std::marker::PhantomData,
36 })
37 }
38 pub fn load(&self, f: impl Fn() -> StoredData) -> Result<StoredData> {
44 let fingerprint = &self.fp;
45
46 let existing_fp = self.read_fingerprint()?;
48 if existing_fp == *fingerprint {
49 return self.read_data();
51 }
53
54 let mut fp = self.lock_fingerprint()?;
55 let existing_fp = Self::read_locked_fingerprint(&mut fp)?;
56 if existing_fp == *fingerprint {
57 return self.read_data();
59 }
61
62 let data = f();
65 self.write_data(&data)?;
66 self.write_fingerprint(fp)?;
67 Ok(data)
68 }
69
70 fn read_fingerprint(&self) -> anyhow::Result<String> {
71 let mut f = OpenOptions::new()
73 .read(true)
74 .write(true)
75 .create(true)
76 .truncate(false)
77 .open(&self.fingerprint_path)?;
78 f.lock_shared()?;
79 let mut contents = String::new();
80 f.read_to_string(&mut contents)?;
81 if contents.is_empty() {
82 Ok(String::new())
83 } else {
84 Ok(contents)
85 }
86 }
87
88 fn lock_fingerprint(&self) -> anyhow::Result<std::fs::File> {
89 let f = OpenOptions::new()
90 .read(true)
91 .write(true)
92 .create(true)
93 .truncate(false)
94 .open(&self.fingerprint_path)?;
95 f.lock()?;
96 Ok(f)
97 }
98
99 fn write_fingerprint(&self, mut f: std::fs::File) -> anyhow::Result<()> {
100 f.set_len(0)?;
102 f.seek(std::io::SeekFrom::Start(0))?;
103 f.write_all(self.fp.as_bytes())?;
104 f.sync_all()?;
105 Ok(())
106 }
107
108 fn read_locked_fingerprint(f: &mut std::fs::File) -> anyhow::Result<String> {
109 let mut contents = String::new();
110 f.read_to_string(&mut contents)?;
111 if contents.is_empty() {
112 Ok(String::new())
113 } else {
114 Ok(contents)
115 }
116 }
117
118 fn read_data(&self) -> anyhow::Result<StoredData> {
119 let mut f = OpenOptions::new().read(true).open(&self.data_path)?;
120 f.lock_shared()?;
121 let mut buf = Vec::new();
122 f.read_to_end(&mut buf)?;
123 let data = bincode::decode_from_slice(&buf, bincode::config::standard())?.0;
124 Ok(data)
125 }
126
127 fn write_data(&self, data: &StoredData) -> anyhow::Result<()> {
128 let mut f = OpenOptions::new()
129 .write(true)
130 .create(true)
131 .truncate(false)
132 .open(&self.data_path)?;
133 f.lock_exclusive()?;
134 let encoded = bincode::encode_to_vec(data, bincode::config::standard())?;
135 f.write_all(&encoded)?;
136 f.sync_all()?;
137 Ok(())
138 }
139}
140
#[cfg(test)]
mod tests {
    use serde::{Deserialize, Serialize};
    use tempfile::tempdir;

    use super::*;

    /// Minimal payload used to exercise the cache round-trip.
    #[derive(Encode, Decode, PartialEq, Debug, Serialize, Deserialize)]
    struct TestData {
        value: String,
    }

    /// Opens a fresh cache handle at `path` for `fingerprint` and loads it.
    /// `fallback` supplies the value only if the cache has to be (re)built.
    fn load_value(path: &std::path::Path, fingerprint: &str, fallback: &str) -> String {
        FsCache::new(path, fingerprint)
            .unwrap()
            .load(|| TestData {
                value: fallback.to_string(),
            })
            .unwrap()
            .value
    }

    #[test]
    fn test_fs_cache() {
        let dir = tempdir().unwrap();
        let path = dir.path().join("test_cache");

        // Cold cache: the closure's value is built and persisted.
        assert_eq!(load_value(&path, "v1", "Hello, World!"), "Hello, World!");
        // Same fingerprint: the persisted value wins over the closure.
        assert_eq!(
            load_value(&path, "v1", "This should not be used"),
            "Hello, World!"
        );
        // New fingerprint: the cache is rebuilt from the closure.
        assert_eq!(load_value(&path, "v2", "New value"), "New value");

        dir.close().unwrap();
    }

    #[test]
    fn test_concurrent_create_same_fingerprint() {
        let dir = tempdir().unwrap();
        let path = dir.path().join("test_cache_concurrent");

        std::thread::scope(|s| {
            let path = &path;
            let workers: Vec<_> = (0..8)
                .map(|_| {
                    s.spawn(move || {
                        assert_eq!(
                            load_value(path, "cfp", "Concurrent Hello"),
                            "Concurrent Hello"
                        );
                    })
                })
                .collect();
            for worker in workers {
                worker.join().expect("thread panicked");
            }
        });

        dir.close().unwrap();
    }

    #[test]
    fn test_concurrent_readers_after_write() {
        let dir = tempdir().unwrap();
        let path = dir.path().join("test_cache_readers");

        // Populate the cache once before any reader starts.
        assert_eq!(load_value(&path, "r1", "Reader Hello"), "Reader Hello");

        std::thread::scope(|s| {
            let path = &path;
            let readers: Vec<_> = (0..16)
                .map(|_| {
                    s.spawn(move || {
                        let cache = FsCache::new(path, "r1").unwrap();
                        for _ in 0..10 {
                            let got = cache
                                .load(|| TestData {
                                    value: "Should not be used".to_string(),
                                })
                                .unwrap();
                            assert_eq!(got.value, "Reader Hello");
                        }
                    })
                })
                .collect();
            for reader in readers {
                reader.join().expect("reader thread panicked");
            }
        });

        dir.close().unwrap();
    }
}