1mod query;
62
63use std::error::Error as StdError;
64use std::fs;
65use std::fs::TryLockError;
66use std::io;
67use std::path::{Path, PathBuf};
68use std::thread;
69use std::time::{Duration, Instant, SystemTime};
70
71use flagset::FlagSet;
72use serde::{Deserialize, Serialize};
73use serde_json as json;
74use thiserror::Error;
75
76use powerpack_detach as detach;
77use powerpack_env as env;
78
79pub use crate::query::{Query, QueryError, QueryPolicy};
80
/// File name, inside each key's cache directory, of the JSON file that holds
/// that key's cached entry. `update` writes it (via a temporary file that is
/// renamed into place) and `Cache::query` reads it.
const DATA: &str = "v1.json";
83
84#[derive(Debug, Error)]
86#[non_exhaustive]
87pub enum BuildError {
88 #[error("home directory not found")]
90 NoHomeDir,
91}
92
93#[derive(Debug, Error)]
95#[non_exhaustive]
96enum UpdateError {
97 #[error("io error")]
99 Io(#[from] io::Error),
100
101 #[error("serialization error")]
103 Serialize(#[from] json::Error),
104
105 #[error("update fn failed: {0}")]
107 UpdateFn(#[from] Box<dyn StdError + Send + Sync + 'static>),
108}
109
110#[derive(Debug, Clone)]
112pub struct Builder {
113 directory: Option<PathBuf>,
114 query_policy: FlagSet<QueryPolicy>,
115 ttl: Duration,
116 initial_poll: Option<Duration>,
117}
118
119#[derive(Debug)]
121pub struct Cache {
122 directory: PathBuf,
123 query_policy: FlagSet<QueryPolicy>,
124 ttl: Duration,
125 initial_poll: Option<Duration>,
126}
127
128#[derive(Debug, Clone, Deserialize, Serialize)]
132struct CacheData<'a, T> {
133 modified: SystemTime,
134 checksum: Option<&'a str>,
135 data: T,
136}
137
138impl Default for Builder {
139 #[inline]
140 fn default() -> Self {
141 Self::new()
142 }
143}
144
145impl Builder {
146 #[inline]
148 pub fn new() -> Self {
149 Builder {
150 directory: None,
151 query_policy: QueryPolicy::default_set(),
152 ttl: Duration::from_secs(60),
153 initial_poll: None,
154 }
155 }
156
157 #[inline]
167 pub fn directory(mut self, directory: impl Into<PathBuf>) -> Self {
168 self.directory = Some(directory.into());
169 self
170 }
171
172 pub fn policy(mut self, query_policy: impl Into<FlagSet<QueryPolicy>>) -> Self {
177 self.query_policy = query_policy.into();
178 self
179 }
180
181 #[inline]
190 pub fn ttl(mut self, ttl: Duration) -> Self {
191 self.ttl = ttl;
192 self
193 }
194
195 #[inline]
207 pub fn initial_poll(mut self, initial_poll: Duration) -> Self {
208 self.initial_poll = Some(initial_poll);
209 self
210 }
211
212 pub fn try_build(self) -> Result<Cache, BuildError> {
216 let Self {
217 directory,
218 query_policy,
219 ttl,
220 initial_poll,
221 } = self;
222
223 let directory = match directory {
224 Some(directory) => directory,
225 None => env::try_workflow_cache_or_default()
226 .ok_or(BuildError::NoHomeDir)?
227 .join("cache"),
228 };
229
230 Ok(Cache {
231 directory,
232 query_policy,
233 ttl,
234 initial_poll,
235 })
236 }
237
238 #[track_caller]
244 #[inline]
245 pub fn build(self) -> Cache {
246 self.try_build().expect("failed to build cache")
247 }
248}
249
250struct CacheDataHolder<'a, T> {
251 result: Result<CacheData<'a, T>, json::Error>,
252 is_bad_data: bool,
253 is_checksum_mismatch: bool,
254 is_expired: bool,
255}
256
257impl<'a, T> CacheDataHolder<'a, T> {
258 fn build(data: &'a [u8], checksum: Option<&str>, ttl: Duration) -> Self
259 where
260 T: for<'de> Deserialize<'de>,
261 {
262 let result: Result<CacheData<T>, _> = json::from_slice(data);
263 match &result {
264 Ok(d) => {
265 let is_checksum_mismatch = checksum.is_some() && d.checksum != checksum;
266 let is_expired = d.modified.elapsed().map_or(true, |d| d > ttl);
267 Self {
268 result,
269 is_bad_data: false,
270 is_checksum_mismatch,
271 is_expired,
272 }
273 }
274 Err(_) => Self {
275 result,
276 is_bad_data: true,
277 is_checksum_mismatch: false,
278 is_expired: false,
279 },
280 }
281 }
282
283 fn should_update(&self, policy: FlagSet<QueryPolicy>) -> bool {
284 policy.contains(QueryPolicy::UpdateAlways)
285 || self.is_bad_data && policy.contains(QueryPolicy::UpdateBadData)
286 || self.is_checksum_mismatch && policy.contains(QueryPolicy::UpdateChecksumMismatch)
287 || self.is_expired && policy.contains(QueryPolicy::UpdateExpired)
288 }
289
290 #[rustfmt::skip]
291 fn should_return(&self, policy: FlagSet<QueryPolicy>) -> bool {
292 policy.contains(QueryPolicy::ReturnAlways) || {
293 (!self.is_bad_data || policy.contains(QueryPolicy::ReturnBadDataErr))
294 && (!self.is_checksum_mismatch || policy.contains(QueryPolicy::ReturnChecksumMismatch))
295 && (!self.is_expired || policy.contains(QueryPolicy::ReturnExpired))
296 }
297 }
298
299 fn into_result(self, policy: FlagSet<QueryPolicy>) -> Result<T, QueryError> {
300 if self.should_return(policy) {
301 Ok(self.result.map(|c| c.data)?)
302 } else {
303 Err(QueryError::Miss)
304 }
305 }
306}
307
308impl Cache {
309 pub fn query<'a, T, E>(&self, query: Query<'a, T, E>) -> Result<T, QueryError>
311 where
312 T: Serialize + for<'de> Deserialize<'de>,
313 E: Into<Box<dyn std::error::Error + Send + Sync>>,
314 {
315 let Query {
316 key,
317 checksum,
318 policy,
319 ttl,
320 initial_poll,
321 update_fn,
322 ..
323 } = query;
324
325 let directory = self.directory.join(key);
326 let path = directory.join(DATA);
327
328 let checksum = checksum.as_deref();
329 let policy = policy.unwrap_or(self.query_policy);
330 let ttl = ttl.unwrap_or(self.ttl);
331 let initial_poll = initial_poll.or(self.initial_poll).map(|d| {
332 let sleep = (d / 5).min(Duration::from_millis(100)).min(d);
333 (d, sleep)
334 });
335
336 let update_cache = update_fn.map(|f| {
337 || match update(&directory, &path, checksum, f) {
338 Ok(true) => log::info!("cache: updated {key}"),
339 Ok(false) => log::debug!("cache: another process updated {key}"),
340 Err(err) => log::error!(
341 "cache: failed to update {key}: {}",
342 detach::format_err(&err)
343 ),
344 }
345 });
346
347 match fs::read(&path) {
348 Ok(data) => {
349 let data = CacheDataHolder::build(&data, checksum, ttl);
350 if let Some(update_cache) = update_cache
351 && data.should_update(policy)
352 {
353 detach::spawn(update_cache)?;
354 }
355 data.into_result(policy)
356 }
357
358 Err(err) if err.kind() == io::ErrorKind::NotFound => {
359 if let Some(update_cache) = update_cache {
360 detach::spawn(update_cache)?;
361 }
362
363 if let Some((poll_duration, poll_sleep)) = initial_poll {
365 let start = Instant::now();
366 while Instant::now().duration_since(start) < poll_duration {
367 thread::sleep(poll_sleep);
368 match fs::read(&path) {
369 Ok(data) => {
370 let data = CacheDataHolder::build(&data, checksum, ttl);
371 return data.into_result(policy);
372 }
373 Err(err) if err.kind() == io::ErrorKind::NotFound => continue,
374 Err(err) => return Err(err.into()),
375 }
376 }
377 }
378
379 Err(QueryError::Miss)
380 }
381
382 Err(err) => Err(err.into()),
383 }
384 }
385}
386
387fn update<'a, T, E>(
388 directory: &Path,
389 path: &Path,
390 checksum: Option<&str>,
391 f: Box<dyn FnOnce() -> Result<T, E> + 'a>,
392) -> Result<bool, UpdateError>
393where
394 T: Serialize + for<'de> Deserialize<'de>,
395 E: Into<Box<dyn std::error::Error + Send + Sync>>,
396{
397 fs::create_dir_all(directory)?;
398 let tmp = path.with_extension("tmp");
399
400 match fs::File::open(directory)?.try_lock() {
401 Ok(()) => {
402 let data = f().map_err(Into::into)?;
403 let file = fs::File::create(&tmp)?;
404 let modified = SystemTime::now();
405 json::to_writer(
406 &file,
407 &CacheData {
408 checksum,
409 modified,
410 data,
411 },
412 )?;
413 fs::rename(tmp, path)?;
414 Ok(true)
415 }
416 Err(TryLockError::Error(err)) => Err(err.into()),
417 Err(TryLockError::WouldBlock) => Ok(false),
418 }
419}