1mod query;
62
63use std::error::Error as StdError;
64use std::fmt::Write;
65use std::fs;
66use std::io;
67use std::path::{Path, PathBuf};
68use std::thread;
69use std::time::{Duration, Instant, SystemTime};
70
71use flagset::FlagSet;
72use serde::{Deserialize, Serialize};
73use serde_json as json;
74use thiserror::Error;
75
76use powerpack_detach as detach;
77use powerpack_env as env;
78
79pub use crate::query::{Query, QueryError, QueryPolicy};
80
/// File name of the serialized entry stored inside each key's cache directory.
const DATA: &str = "v1.json";
83
/// Error returned by [`Builder::try_build`].
#[derive(Debug, Error)]
#[non_exhaustive]
pub enum BuildError {
    /// No cache directory was configured and `env::try_workflow_cache_or_default`
    /// could not provide a default one.
    #[error("home directory not found")]
    NoHomeDir,
}
92
/// Error raised while refreshing a cache entry in `update`.
///
/// In this module it is never returned to users; `Cache::query` only logs it.
// NOTE(review): `#[non_exhaustive]` has no effect on a private enum.
#[derive(Debug, Error)]
#[non_exhaustive]
enum UpdateError {
    /// A filesystem operation failed (e.g. creating the directory or temp file,
    /// or renaming it into place).
    #[error("io error")]
    Io(#[from] io::Error),

    /// Serializing the entry failed. serde_json also reports write errors that
    /// occur during `to_writer` through this variant.
    #[error("serialization error")]
    Serialize(#[from] json::Error),

    /// The user-supplied update function returned an error.
    #[error("update fn failed: {0}")]
    UpdateFn(#[from] Box<dyn StdError + Send + Sync + 'static>),
}
109
/// Configures and constructs a [`Cache`].
#[derive(Debug, Clone)]
pub struct Builder {
    /// Root directory for entries; `None` means derive a default at build time.
    directory: Option<PathBuf>,
    /// Policy used by queries that do not specify their own.
    query_policy: FlagSet<QueryPolicy>,
    /// Time-to-live used by queries that do not specify their own.
    ttl: Duration,
    /// Initial poll duration used by queries that do not specify their own.
    initial_poll: Option<Duration>,
}
118
/// A file-backed cache; construct it with [`Builder`].
#[derive(Debug)]
pub struct Cache {
    /// Root directory; each query key is stored in its own subdirectory.
    directory: PathBuf,
    /// Policy used by queries that do not specify their own.
    query_policy: FlagSet<QueryPolicy>,
    /// Time-to-live used by queries that do not specify their own.
    ttl: Duration,
    /// Initial poll duration used by queries that do not specify their own.
    initial_poll: Option<Duration>,
}
127
/// On-disk representation of a cache entry, stored as JSON.
#[derive(Debug, Clone, Deserialize, Serialize)]
struct CacheData<'a, T> {
    /// When the entry was written; compared against the TTL to detect expiry.
    modified: SystemTime,
    /// Checksum supplied by the query that produced this entry, if any.
    // NOTE(review): this is deserialized by borrowing from the input buffer, so a
    // stored checksum containing JSON escape sequences would fail to load —
    // confirm checksums are always escape-free (e.g. hex digests).
    checksum: Option<&'a str>,
    /// The cached payload.
    data: T,
}
137
138impl Default for Builder {
139 #[inline]
140 fn default() -> Self {
141 Self::new()
142 }
143}
144
145impl Builder {
146 #[inline]
148 pub fn new() -> Self {
149 Builder {
150 directory: None,
151 query_policy: QueryPolicy::default_set(),
152 ttl: Duration::from_secs(60),
153 initial_poll: None,
154 }
155 }
156
157 #[inline]
167 pub fn directory(mut self, directory: impl Into<PathBuf>) -> Self {
168 self.directory = Some(directory.into());
169 self
170 }
171
172 pub fn policy(mut self, query_policy: impl Into<FlagSet<QueryPolicy>>) -> Self {
177 self.query_policy = query_policy.into();
178 self
179 }
180
181 #[inline]
190 pub fn ttl(mut self, ttl: Duration) -> Self {
191 self.ttl = ttl;
192 self
193 }
194
195 #[inline]
207 pub fn initial_poll(mut self, initial_poll: Duration) -> Self {
208 self.initial_poll = Some(initial_poll);
209 self
210 }
211
212 pub fn try_build(self) -> Result<Cache, BuildError> {
216 let Self {
217 directory,
218 query_policy,
219 ttl,
220 initial_poll,
221 } = self;
222
223 let directory = match directory {
224 Some(directory) => directory,
225 None => env::try_workflow_cache_or_default()
226 .ok_or(BuildError::NoHomeDir)?
227 .join("cache"),
228 };
229
230 Ok(Cache {
231 directory,
232 query_policy,
233 ttl,
234 initial_poll,
235 })
236 }
237
238 #[track_caller]
244 #[inline]
245 pub fn build(self) -> Cache {
246 self.try_build().expect("failed to build cache")
247 }
248}
249
/// A parsed (or unparseable) cache entry, together with the conditions that
/// decide whether it should be refreshed and/or returned to the caller.
struct CacheDataHolder<'a, T> {
    /// Outcome of deserializing the raw cache bytes.
    result: Result<CacheData<'a, T>, json::Error>,
    /// The bytes could not be deserialized.
    is_bad_data: bool,
    /// The caller supplied a checksum that differs from the stored one.
    is_checksum_mismatch: bool,
    /// The entry is older than the TTL, or its age could not be computed.
    is_expired: bool,
}
256
257impl<'a, T> CacheDataHolder<'a, T> {
258 fn build(data: &'a [u8], checksum: Option<&str>, ttl: Duration) -> Self
259 where
260 T: for<'de> Deserialize<'de>,
261 {
262 let result: Result<CacheData<T>, _> = json::from_slice(data);
263 match &result {
264 Ok(d) => {
265 let is_checksum_mismatch = checksum.is_some() && d.checksum != checksum;
266 let is_expired = d.modified.elapsed().map_or(true, |d| d > ttl);
267 Self {
268 result,
269 is_bad_data: false,
270 is_checksum_mismatch,
271 is_expired,
272 }
273 }
274 Err(_) => Self {
275 result,
276 is_bad_data: true,
277 is_checksum_mismatch: false,
278 is_expired: false,
279 },
280 }
281 }
282
283 fn should_update(&self, policy: FlagSet<QueryPolicy>) -> bool {
284 policy.contains(QueryPolicy::UpdateAlways)
285 || self.is_bad_data && policy.contains(QueryPolicy::UpdateBadData)
286 || self.is_checksum_mismatch && policy.contains(QueryPolicy::UpdateChecksumMismatch)
287 || self.is_expired && policy.contains(QueryPolicy::UpdateExpired)
288 }
289
290 #[rustfmt::skip]
291 fn should_return(&self, policy: FlagSet<QueryPolicy>) -> bool {
292 policy.contains(QueryPolicy::ReturnAlways) || {
293 (!self.is_bad_data || policy.contains(QueryPolicy::ReturnBadDataErr))
294 && (!self.is_checksum_mismatch || policy.contains(QueryPolicy::ReturnChecksumMismatch))
295 && (!self.is_expired || policy.contains(QueryPolicy::ReturnExpired))
296 }
297 }
298
299 fn into_result(self, policy: FlagSet<QueryPolicy>) -> Result<T, QueryError> {
300 if self.should_return(policy) {
301 Ok(self.result.map(|c| c.data)?)
302 } else {
303 Err(QueryError::Miss)
304 }
305 }
306}
307
308impl Cache {
309 pub fn query<'a, T, E>(&self, query: Query<'a, T, E>) -> Result<T, QueryError>
311 where
312 T: Serialize + for<'de> Deserialize<'de>,
313 E: Into<Box<dyn std::error::Error + Send + Sync>>,
314 {
315 let Query {
316 key,
317 checksum,
318 policy,
319 ttl,
320 initial_poll,
321 update_fn,
322 ..
323 } = query;
324
325 let directory = self.directory.join(key);
326 let path = directory.join(DATA);
327
328 let checksum = checksum.as_deref();
329 let policy = policy.unwrap_or(self.query_policy);
330 let ttl = ttl.unwrap_or(self.ttl);
331 let initial_poll = initial_poll.or(self.initial_poll).map(|d| {
332 let sleep = (d / 5).min(Duration::from_millis(100)).min(d);
333 (d, sleep)
334 });
335
336 let update_cache = update_fn.map(|f| {
337 || match update(&directory, &path, checksum, f) {
338 Ok(true) => log::info!("cache: updated {key}"),
339 Ok(false) => log::warn!("cache: another process updated {key}"),
340 Err(err) => log::error!("cache: failed to update {key}: {}", format_err(&err)),
341 }
342 });
343
344 match fs::read(&path) {
345 Ok(data) => {
346 let data = CacheDataHolder::build(&data, checksum, ttl);
347 if let Some(update_cache) = update_cache {
348 if data.should_update(policy) {
349 detach::spawn(update_cache)?;
350 }
351 }
352 data.into_result(policy)
353 }
354
355 Err(err) if err.kind() == io::ErrorKind::NotFound => {
356 if let Some(update_cache) = update_cache {
357 detach::spawn(update_cache)?;
358 }
359
360 if let Some((poll_duration, poll_sleep)) = initial_poll {
362 let start = Instant::now();
363 while Instant::now().duration_since(start) < poll_duration {
364 thread::sleep(poll_sleep);
365 match fs::read(&path) {
366 Ok(data) => {
367 let data = CacheDataHolder::build(&data, checksum, ttl);
368 return data.into_result(policy);
369 }
370 Err(err) if err.kind() == io::ErrorKind::NotFound => continue,
371 Err(err) => return Err(err.into()),
372 }
373 }
374 }
375
376 Err(QueryError::Miss)
377 }
378
379 Err(err) => Err(err.into()),
380 }
381 }
382}
383
384fn update<'a, T, E>(
385 directory: &Path,
386 path: &Path,
387 checksum: Option<&str>,
388 f: Box<dyn FnOnce() -> Result<T, E> + 'a>,
389) -> Result<bool, UpdateError>
390where
391 T: Serialize + for<'de> Deserialize<'de>,
392 E: Into<Box<dyn std::error::Error + Send + Sync>>,
393{
394 fs::create_dir_all(directory)?;
395 let tmp = path.with_extension("tmp");
396 match fmutex::try_lock(directory)? {
397 Some(_guard) => {
398 let data = f().map_err(Into::into)?;
399 let file = fs::File::create(&tmp)?;
400 let modified = SystemTime::now();
401 json::to_writer(
402 &file,
403 &CacheData {
404 checksum,
405 modified,
406 data,
407 },
408 )?;
409 fs::rename(tmp, path)?;
410 Ok(true)
411 }
412 None => Ok(false),
413 }
414}
415
/// Renders `err` and its chain of sources as one `": "`-separated line.
///
/// A source is skipped when its message already appears somewhere in the text
/// built so far. This avoids repeating messages from wrappers that already
/// embed their source, such as `#[error("update fn failed: {0}")]`.
fn format_err(err: &(dyn StdError + 'static)) -> String {
    let mut rendered = err.to_string();
    let mut cause_text = String::new();
    for cause in std::iter::successors(err.source(), |&prev| prev.source()) {
        cause_text.clear();
        write!(cause_text, "{cause}").expect("fmt error");
        if rendered.contains(cause_text.as_str()) {
            continue;
        }
        rendered.push_str(": ");
        rendered.push_str(&cause_text);
    }
    rendered
}