1mod query;
62
63use std::error::Error as StdError;
64use std::fs;
65use std::io;
66use std::path::{Path, PathBuf};
67use std::thread;
68use std::time::{Duration, Instant, SystemTime};
69
70use flagset::FlagSet;
71use serde::{Deserialize, Serialize};
72use serde_json as json;
73use thiserror::Error;
74
75use powerpack_detach as detach;
76use powerpack_env as env;
77
78pub use crate::query::{Query, QueryError, QueryPolicy};
79
/// File name of the data file stored inside each key's cache directory.
const DATA: &str = "v1.json";
82
/// Error returned when a [`Cache`] cannot be constructed from a [`Builder`].
#[derive(Debug, Error)]
#[non_exhaustive]
pub enum BuildError {
    /// No directory was set on the builder and no default cache location
    /// could be determined.
    #[error("home directory not found")]
    NoHomeDir,
}
91
/// Internal error raised while refreshing a cache entry.
#[derive(Debug, Error)]
#[non_exhaustive]
enum UpdateError {
    /// A filesystem operation failed.
    #[error("io error")]
    Io(#[from] io::Error),

    /// The cache entry could not be serialized to JSON.
    #[error("serialization error")]
    Serialize(#[from] json::Error),

    /// The user-supplied update function returned an error.
    #[error("update fn failed: {0}")]
    UpdateFn(#[from] Box<dyn StdError + Send + Sync + 'static>),
}
108
/// Collects the settings used to construct a [`Cache`].
#[derive(Debug, Clone)]
pub struct Builder {
    /// Explicit cache directory; when `None` a default is resolved at build
    /// time.
    directory: Option<PathBuf>,
    /// Default policy applied to queries that do not set their own.
    query_policy: FlagSet<QueryPolicy>,
    /// Default time after which a cache entry counts as expired.
    ttl: Duration,
    /// Default time a query waits for a missing entry to appear, if any.
    initial_poll: Option<Duration>,
}
117
/// A file-backed cache that stores one JSON data file per key.
#[derive(Debug)]
pub struct Cache {
    /// Root directory; each key gets its own subdirectory below it.
    directory: PathBuf,
    /// Policy used by queries that do not specify one.
    query_policy: FlagSet<QueryPolicy>,
    /// TTL used by queries that do not specify one.
    ttl: Duration,
    /// Initial poll duration used by queries that do not specify one.
    initial_poll: Option<Duration>,
}
126
/// The record that is serialized to and deserialized from the data file.
#[derive(Debug, Clone, Deserialize, Serialize)]
struct CacheData<'a, T> {
    /// Time at which the entry was written.
    modified: SystemTime,
    /// Checksum supplied by the query that produced the entry, if any.
    checksum: Option<&'a str>,
    /// The cached payload.
    data: T,
}
136
137impl Default for Builder {
138 #[inline]
139 fn default() -> Self {
140 Self::new()
141 }
142}
143
144impl Builder {
145 #[inline]
147 pub fn new() -> Self {
148 Builder {
149 directory: None,
150 query_policy: QueryPolicy::default_set(),
151 ttl: Duration::from_secs(60),
152 initial_poll: None,
153 }
154 }
155
156 #[inline]
166 pub fn directory(mut self, directory: impl Into<PathBuf>) -> Self {
167 self.directory = Some(directory.into());
168 self
169 }
170
171 pub fn policy(mut self, query_policy: impl Into<FlagSet<QueryPolicy>>) -> Self {
176 self.query_policy = query_policy.into();
177 self
178 }
179
180 #[inline]
189 pub fn ttl(mut self, ttl: Duration) -> Self {
190 self.ttl = ttl;
191 self
192 }
193
194 #[inline]
206 pub fn initial_poll(mut self, initial_poll: Duration) -> Self {
207 self.initial_poll = Some(initial_poll);
208 self
209 }
210
211 pub fn try_build(self) -> Result<Cache, BuildError> {
215 let Self {
216 directory,
217 query_policy,
218 ttl,
219 initial_poll,
220 } = self;
221
222 let directory = match directory {
223 Some(directory) => directory,
224 None => env::try_workflow_cache_or_default()
225 .ok_or(BuildError::NoHomeDir)?
226 .join("cache"),
227 };
228
229 Ok(Cache {
230 directory,
231 query_policy,
232 ttl,
233 initial_poll,
234 })
235 }
236
237 #[track_caller]
243 #[inline]
244 pub fn build(self) -> Cache {
245 self.try_build().expect("failed to build cache")
246 }
247}
248
/// The outcome of parsing a data file, plus the conditions that the query
/// policy is evaluated against.
struct CacheDataHolder<'a, T> {
    /// The parsed entry, or the JSON error raised while parsing it.
    result: Result<CacheData<'a, T>, json::Error>,
    /// The file contents could not be deserialized.
    is_bad_data: bool,
    /// A checksum was requested and the stored one differs from it.
    is_checksum_mismatch: bool,
    /// The entry is older than the TTL (or its age could not be computed).
    is_expired: bool,
}
255
256impl<'a, T> CacheDataHolder<'a, T> {
257 fn build(data: &'a [u8], checksum: Option<&str>, ttl: Duration) -> Self
258 where
259 T: for<'de> Deserialize<'de>,
260 {
261 let result: Result<CacheData<T>, _> = json::from_slice(data);
262 match &result {
263 Ok(d) => {
264 let is_checksum_mismatch = checksum.is_some() && d.checksum != checksum;
265 let is_expired = d.modified.elapsed().map_or(true, |d| d > ttl);
266 Self {
267 result,
268 is_bad_data: false,
269 is_checksum_mismatch,
270 is_expired,
271 }
272 }
273 Err(_) => Self {
274 result,
275 is_bad_data: true,
276 is_checksum_mismatch: false,
277 is_expired: false,
278 },
279 }
280 }
281
282 fn should_update(&self, policy: FlagSet<QueryPolicy>) -> bool {
283 policy.contains(QueryPolicy::UpdateAlways)
284 || self.is_bad_data && policy.contains(QueryPolicy::UpdateBadData)
285 || self.is_checksum_mismatch && policy.contains(QueryPolicy::UpdateChecksumMismatch)
286 || self.is_expired && policy.contains(QueryPolicy::UpdateExpired)
287 }
288
289 #[rustfmt::skip]
290 fn should_return(&self, policy: FlagSet<QueryPolicy>) -> bool {
291 policy.contains(QueryPolicy::ReturnAlways) || {
292 (!self.is_bad_data || policy.contains(QueryPolicy::ReturnBadDataErr))
293 && (!self.is_checksum_mismatch || policy.contains(QueryPolicy::ReturnChecksumMismatch))
294 && (!self.is_expired || policy.contains(QueryPolicy::ReturnExpired))
295 }
296 }
297
298 fn into_result(self, policy: FlagSet<QueryPolicy>) -> Result<T, QueryError> {
299 if self.should_return(policy) {
300 Ok(self.result.map(|c| c.data)?)
301 } else {
302 Err(QueryError::Miss)
303 }
304 }
305}
306
307impl Cache {
308 pub fn query<'a, T, E>(&self, query: Query<'a, T, E>) -> Result<T, QueryError>
310 where
311 T: Serialize + for<'de> Deserialize<'de>,
312 E: Into<Box<dyn std::error::Error + Send + Sync>>,
313 {
314 let Query {
315 key,
316 checksum,
317 policy,
318 ttl,
319 initial_poll,
320 update_fn,
321 ..
322 } = query;
323
324 let directory = self.directory.join(key);
325 let path = directory.join(DATA);
326
327 let checksum = checksum.as_deref();
328 let policy = policy.unwrap_or(self.query_policy);
329 let ttl = ttl.unwrap_or(self.ttl);
330 let initial_poll = initial_poll.or(self.initial_poll).map(|d| {
331 let sleep = (d / 5).min(Duration::from_millis(100)).min(d);
332 (d, sleep)
333 });
334
335 let update_cache = update_fn.map(|f| {
336 || match update(&directory, &path, checksum, f) {
337 Ok(true) => log::info!("cache: updated {key}"),
338 Ok(false) => log::debug!("cache: another process updated {key}"),
339 Err(err) => log::error!(
340 "cache: failed to update {key}: {}",
341 detach::format_err(&err)
342 ),
343 }
344 });
345
346 match fs::read(&path) {
347 Ok(data) => {
348 let data = CacheDataHolder::build(&data, checksum, ttl);
349 if let Some(update_cache) = update_cache {
350 if data.should_update(policy) {
351 detach::spawn(update_cache)?;
352 }
353 }
354 data.into_result(policy)
355 }
356
357 Err(err) if err.kind() == io::ErrorKind::NotFound => {
358 if let Some(update_cache) = update_cache {
359 detach::spawn(update_cache)?;
360 }
361
362 if let Some((poll_duration, poll_sleep)) = initial_poll {
364 let start = Instant::now();
365 while Instant::now().duration_since(start) < poll_duration {
366 thread::sleep(poll_sleep);
367 match fs::read(&path) {
368 Ok(data) => {
369 let data = CacheDataHolder::build(&data, checksum, ttl);
370 return data.into_result(policy);
371 }
372 Err(err) if err.kind() == io::ErrorKind::NotFound => continue,
373 Err(err) => return Err(err.into()),
374 }
375 }
376 }
377
378 Err(QueryError::Miss)
379 }
380
381 Err(err) => Err(err.into()),
382 }
383 }
384}
385
386fn update<'a, T, E>(
387 directory: &Path,
388 path: &Path,
389 checksum: Option<&str>,
390 f: Box<dyn FnOnce() -> Result<T, E> + 'a>,
391) -> Result<bool, UpdateError>
392where
393 T: Serialize + for<'de> Deserialize<'de>,
394 E: Into<Box<dyn std::error::Error + Send + Sync>>,
395{
396 fs::create_dir_all(directory)?;
397 let tmp = path.with_extension("tmp");
398 match fmutex::try_lock(directory)? {
399 Some(_guard) => {
400 let data = f().map_err(Into::into)?;
401 let file = fs::File::create(&tmp)?;
402 let modified = SystemTime::now();
403 json::to_writer(
404 &file,
405 &CacheData {
406 checksum,
407 modified,
408 data,
409 },
410 )?;
411 fs::rename(tmp, path)?;
412 Ok(true)
413 }
414 None => Ok(false),
415 }
416}