1use crate::error::Result;
18use crate::model::Component;
19use serde::Serialize;
20use serde::de::DeserializeOwned;
21use sha2::{Digest, Sha256};
22use std::fs;
23use std::path::{Path, PathBuf};
24use std::sync::atomic::{AtomicBool, Ordering};
25use std::time::Duration;
26
/// Process-wide offline switch, read by [`offline_guard`] and by the cache's
/// TTL handling (stale entries are only served while it is set).
static OFFLINE: AtomicBool = AtomicBool::new(false);

/// Enables or disables offline mode for the whole process.
///
/// `Relaxed` ordering suffices: the flag is a standalone boolean and is not
/// used to publish any other data between threads.
pub fn set_offline(offline: bool) {
    AtomicBool::store(&OFFLINE, offline, Ordering::Relaxed);
}

/// Returns `true` while offline mode is enabled.
#[must_use]
pub fn is_offline() -> bool {
    AtomicBool::load(&OFFLINE, Ordering::Relaxed)
}
48
/// Version tag written into every cache envelope.
///
/// On read, an entry whose stored version differs from this value is treated
/// as a miss and deleted, so bump it whenever the cached payload format changes.
pub const CACHE_SCHEMA_VERSION: u32 = 1;
56
/// Returns the per-user cache base directory for the current platform.
///
/// - macOS: `$HOME/Library/Caches`
/// - Linux: `$XDG_CACHE_HOME`, falling back to `$HOME/.cache`
/// - Windows: `%LOCALAPPDATA%`
/// - other platforms: `$HOME/.cache`
///
/// Environment variables that are unset or set to an empty string are treated
/// identically (as absent). Per the XDG Base Directory specification, a relative
/// `XDG_CACHE_HOME` is ignored. Values are read as `OsString`, so non-UTF-8
/// paths work. Returns `None` when no usable variable is available.
#[must_use]
pub fn cache_dir() -> Option<PathBuf> {
    /// Reads `var` as a path; unset and empty values both yield `None`.
    fn env_path(var: &str) -> Option<PathBuf> {
        std::env::var_os(var)
            .filter(|value| !value.is_empty())
            .map(PathBuf::from)
    }

    #[cfg(target_os = "macos")]
    {
        env_path("HOME").map(|home| home.join("Library").join("Caches"))
    }
    #[cfg(target_os = "linux")]
    {
        env_path("XDG_CACHE_HOME")
            // The XDG spec says relative values must be ignored.
            .filter(|dir| dir.is_absolute())
            .or_else(|| env_path("HOME").map(|home| home.join(".cache")))
    }
    #[cfg(target_os = "windows")]
    {
        env_path("LOCALAPPDATA")
    }
    #[cfg(not(any(target_os = "macos", target_os = "linux", target_os = "windows")))]
    {
        env_path("HOME").map(|home| home.join(".cache"))
    }
}
99
100#[must_use]
107pub fn root_cache_dir() -> PathBuf {
108 cache_dir()
109 .unwrap_or_else(|| PathBuf::from(".cache"))
110 .join("sbom-tools")
111}
112
113#[must_use]
118pub fn namespaced_cache_dir(namespace: &str) -> PathBuf {
119 root_cache_dir().join(namespace)
120}
121
122pub fn offline_guard(what: &str) -> Result<()> {
138 if is_offline() {
139 return Err(crate::error::SbomDiffError::enrichment(
140 "offline mode",
141 crate::error::EnrichmentErrorKind::Offline(what.to_string()),
142 ));
143 }
144 Ok(())
145}
146
/// Builds a blocking HTTP client with the given per-request `timeout` and a
/// `<crate-name>/<crate-version>` User-Agent header.
///
/// # Errors
///
/// Propagates the error from `reqwest` if the client cannot be built.
#[cfg(feature = "enrichment")]
pub fn http_client(timeout: Duration) -> reqwest::Result<reqwest::blocking::Client> {
    const USER_AGENT: &str = concat!(env!("CARGO_PKG_NAME"), "/", env!("CARGO_PKG_VERSION"));

    let builder = reqwest::blocking::Client::builder()
        .user_agent(USER_AGENT)
        .timeout(timeout);
    builder.build()
}
158
/// Upper bound on any single retry delay, including server-provided `Retry-After` hints.
const MAX_BACKOFF: Duration = Duration::from_secs(30);
162
/// Default cap on a response body accepted by `read_bounded`: 256 MiB.
pub const MAX_RESPONSE_BYTES: u64 = 256 * 1024 * 1024;
172
/// Reads a response body into memory, rejecting bodies larger than
/// [`MAX_RESPONSE_BYTES`].
///
/// # Errors
///
/// Same as `read_bounded_with_max`: the body exceeds the cap, or reading it fails.
#[cfg(feature = "enrichment")]
pub fn read_bounded(response: reqwest::blocking::Response) -> Result<Vec<u8>> {
    let limit = MAX_RESPONSE_BYTES;
    read_bounded_with_max(response, limit)
}
183
184#[cfg(feature = "enrichment")]
189pub(crate) fn read_bounded_with_max(
190 response: reqwest::blocking::Response,
191 max_bytes: u64,
192) -> Result<Vec<u8>> {
193 if let Some(len) = response.content_length()
194 && len > max_bytes
195 {
196 return Err(oversized_error(len, max_bytes));
197 }
198
199 let bytes = response
200 .bytes()
201 .map_err(|e| network_error("reading response body", &e))?;
202
203 if bytes.len() as u64 > max_bytes {
204 return Err(oversized_error(bytes.len() as u64, max_bytes));
205 }
206
207 Ok(bytes.to_vec())
208}
209
/// Builds the "response too large" enrichment error for a body of `len` bytes
/// that exceeds the `max_bytes` cap; the byte counts go in the detail message.
#[cfg(feature = "enrichment")]
fn oversized_error(len: u64, max_bytes: u64) -> crate::error::SbomDiffError {
    let detail = format!("response body of {len} bytes exceeds the {max_bytes}-byte limit");
    crate::error::SbomDiffError::enrichment(
        "response too large",
        crate::error::EnrichmentErrorKind::NetworkError(detail),
    )
}
220
221#[must_use]
227pub fn backoff_delay(attempt: u32, retry_after: Option<Duration>) -> Duration {
228 if let Some(after) = retry_after {
229 return after.min(MAX_BACKOFF);
230 }
231 let secs = 1u64
232 .checked_shl(attempt.saturating_sub(1))
233 .unwrap_or(u64::MAX);
234 Duration::from_secs(secs).min(MAX_BACKOFF)
235}
236
/// Extracts a `Retry-After` delay given as an integer number of seconds.
///
/// Returns `None` when the header is absent, is not valid visible ASCII, or is
/// not a plain integer. In particular, the HTTP-date form of the header is not
/// parsed and yields `None`.
#[cfg(feature = "enrichment")]
fn parse_retry_after(headers: &reqwest::header::HeaderMap) -> Option<Duration> {
    let raw = headers.get(reqwest::header::RETRY_AFTER)?;
    let text = raw.to_str().ok()?;
    let seconds: u64 = text.trim().parse().ok()?;
    Some(Duration::from_secs(seconds))
}
252
/// Sends a GET request to `url`, retrying transient failures.
///
/// Makes at most `max_retries + 1` attempts. HTTP 429 and 5xx responses, and
/// transport-level errors, are retried after `backoff_delay`; on 429 a numeric
/// `Retry-After` header is honoured (capped). Builder errors, such as an
/// unparsable URL, are returned immediately, because every retry would fail
/// identically.
///
/// Once retries are exhausted, the last response is returned as-is, even if it
/// is a 429 or 5xx, so callers must still inspect the status.
///
/// # Errors
///
/// Returns an error if offline mode is enabled, or if the final attempt (or a
/// non-retryable attempt) fails at the transport level.
#[cfg(feature = "enrichment")]
pub fn get_with_retry(
    client: &reqwest::blocking::Client,
    url: &str,
    max_retries: u8,
) -> Result<reqwest::blocking::Response> {
    offline_guard(url)?;

    let final_attempt = u32::from(max_retries);
    for attempt in 0..=final_attempt {
        if attempt > 0 {
            tracing::debug!("retry attempt {attempt} for {url}");
        }
        let may_retry = attempt < final_attempt;

        match client.get(url).send() {
            Ok(response) => {
                let status = response.status();
                let throttled = status == reqwest::StatusCode::TOO_MANY_REQUESTS;
                if may_retry && (throttled || status.is_server_error()) {
                    let retry_after = if throttled {
                        parse_retry_after(response.headers())
                    } else {
                        None
                    };
                    // Drop the unread response before sleeping so it is not
                    // held for the whole backoff interval.
                    drop(response);
                    std::thread::sleep(backoff_delay(attempt + 1, retry_after));
                    continue;
                }
                return Ok(response);
            }
            // Builder errors are deterministic; retrying would only add sleeps.
            Err(e) if may_retry && !e.is_builder() => {
                tracing::debug!("request to {url} failed, will retry: {e}");
                std::thread::sleep(backoff_delay(attempt + 1, None));
            }
            Err(e) => return Err(network_error("request failed", &e)),
        }
    }

    // Not reached in practice: the final attempt always returns from the loop.
    Err(network_error_msg("retry loop returned no response"))
}
305
/// Wraps a `reqwest` error as an enrichment `NetworkError`, labelled with `context`.
#[cfg(feature = "enrichment")]
fn network_error(context: &str, err: &reqwest::Error) -> crate::error::SbomDiffError {
    let kind = crate::error::EnrichmentErrorKind::NetworkError(err.to_string());
    crate::error::SbomDiffError::enrichment(context, kind)
}
314
/// Builds an enrichment `NetworkError` from a plain message, using the generic
/// `"network"` context label.
#[cfg(feature = "enrichment")]
fn network_error_msg(msg: &str) -> crate::error::SbomDiffError {
    let kind = crate::error::EnrichmentErrorKind::NetworkError(msg.to_owned());
    crate::error::SbomDiffError::enrichment("network", kind)
}
323
324#[derive(Debug, Clone, Hash, PartialEq, Eq)]
330pub struct CacheKey {
331 pub purl: Option<String>,
333 pub name: String,
335 pub ecosystem: Option<String>,
337 pub version: Option<String>,
339}
340
341impl CacheKey {
342 #[must_use]
344 pub const fn new(
345 purl: Option<String>,
346 name: String,
347 ecosystem: Option<String>,
348 version: Option<String>,
349 ) -> Self {
350 Self {
351 purl,
352 name,
353 ecosystem,
354 version,
355 }
356 }
357
358 #[must_use]
360 pub fn to_filename(&self) -> String {
361 let mut hasher = Sha256::new();
362 hasher.update(format!(
363 "purl:{:?}|name:{}|eco:{:?}|ver:{:?}",
364 self.purl, self.name, self.ecosystem, self.version
365 ));
366 let hash = hasher.finalize();
367 let hex: String = hash.iter().map(|b| format!("{b:02x}")).collect();
368 format!("{hex}.json")
369 }
370
371 #[must_use]
373 pub const fn is_queryable(&self) -> bool {
374 self.purl.is_some() || (self.ecosystem.is_some() && self.version.is_some())
376 }
377}
378
379#[derive(Debug, Serialize, serde::Deserialize)]
388struct CacheEnvelope<T> {
389 schema_version: u32,
391 payload: T,
393}
394
395pub struct JsonCache<T> {
402 cache_dir: PathBuf,
403 ttl: Duration,
404 _marker: std::marker::PhantomData<fn() -> T>,
405}
406
407impl<T> JsonCache<T>
408where
409 T: Serialize + DeserializeOwned,
410{
411 pub fn new(cache_dir: PathBuf, ttl: Duration) -> Result<Self> {
413 if !cache_dir.exists() {
414 fs::create_dir_all(&cache_dir)?;
415 }
416 Ok(Self {
417 cache_dir,
418 ttl,
419 _marker: std::marker::PhantomData,
420 })
421 }
422
423 #[must_use]
425 pub fn path_for(&self, file_name: &str) -> PathBuf {
426 self.cache_dir.join(file_name)
427 }
428
429 #[must_use]
431 pub fn dir(&self) -> &Path {
432 &self.cache_dir
433 }
434
435 #[must_use]
445 pub fn get_named(&self, file_name: &str) -> Option<T> {
446 let (value, stale_by) = self.get_named_allow_stale(file_name)?;
447 if let Some(age) = stale_by {
448 tracing::warn!(
449 "serving stale cache entry {file_name}: {} day(s) past its TTL (offline mode)",
450 age.as_secs() / 86_400
451 );
452 }
453 Some(value)
454 }
455
456 #[must_use]
467 pub fn get_named_allow_stale(&self, file_name: &str) -> Option<(T, Option<Duration>)> {
468 let path = self.path_for(file_name);
469
470 let metadata = fs::metadata(&path).ok()?;
471 let modified = metadata.modified().ok()?;
472 let age = modified.elapsed().ok()?;
473 let mut stale_by = None;
474 if age > self.ttl {
475 if is_offline() {
479 stale_by = Some(age - self.ttl);
480 } else {
481 let _ = fs::remove_file(&path);
482 return None;
483 }
484 }
485
486 let data = fs::read_to_string(&path).ok()?;
487 let envelope: CacheEnvelope<T> = serde_json::from_str(&data).ok()?;
488 if envelope.schema_version != CACHE_SCHEMA_VERSION {
489 let _ = fs::remove_file(&path);
492 return None;
493 }
494 Some((envelope.payload, stale_by))
495 }
496
497 pub fn set_named<V: Serialize + ?Sized>(&self, file_name: &str, value: &V) -> Result<()> {
499 if !self.cache_dir.exists() {
500 fs::create_dir_all(&self.cache_dir)?;
501 }
502 let envelope = CacheEnvelope {
503 schema_version: CACHE_SCHEMA_VERSION,
504 payload: value,
505 };
506 let data = serde_json::to_string(&envelope)?;
507 write_atomic(&self.path_for(file_name), data.as_bytes())?;
508 Ok(())
509 }
510
511 #[must_use]
513 pub fn get(&self, key: &CacheKey) -> Option<T> {
514 self.get_named(&key.to_filename())
515 }
516
517 pub fn set<V: Serialize + ?Sized>(&self, key: &CacheKey, value: &V) -> Result<()> {
519 self.set_named(&key.to_filename(), value)
520 }
521
522 pub fn remove(&self, key: &CacheKey) -> Result<()> {
524 let path = self.path_for(&key.to_filename());
525 if path.exists() {
526 fs::remove_file(path)?;
527 }
528 Ok(())
529 }
530
531 pub fn clear(&self) -> Result<()> {
533 if self.cache_dir.exists() {
534 for entry in fs::read_dir(&self.cache_dir)? {
535 let entry = entry?;
536 if entry.path().extension().is_some_and(|e| e == "json") {
537 let _ = fs::remove_file(entry.path());
538 }
539 }
540 }
541 Ok(())
542 }
543
544 #[must_use]
546 pub fn stats(&self) -> CacheStats {
547 let mut stats = CacheStats::default();
548
549 if let Ok(entries) = fs::read_dir(&self.cache_dir) {
550 for entry in entries.flatten() {
551 if entry.path().extension().is_some_and(|e| e == "json") {
552 stats.total_entries += 1;
553 if let Ok(metadata) = entry.metadata() {
554 stats.total_size += metadata.len();
555 if let Ok(modified) = metadata.modified()
556 && let Ok(age) = modified.elapsed()
557 && age > self.ttl
558 {
559 stats.expired_entries += 1;
560 }
561 }
562 }
563 }
564 }
565
566 stats
567 }
568}
569
570fn write_atomic(path: &Path, bytes: &[u8]) -> Result<()> {
575 let parent = path.parent().unwrap_or_else(|| Path::new("."));
576 let file_name = path
577 .file_name()
578 .map_or_else(|| "cache".to_string(), |n| n.to_string_lossy().into_owned());
579 let tmp = parent.join(format!(".{file_name}.{}.tmp", std::process::id()));
580
581 fs::write(&tmp, bytes)?;
582 match fs::rename(&tmp, path) {
583 Ok(()) => Ok(()),
584 Err(e) => {
585 let _ = fs::remove_file(&tmp);
586 Err(e.into())
587 }
588 }
589}
590
/// Summary of a cache directory's `.json` entries, as produced by `JsonCache::stats`.
#[derive(Default, Debug)]
pub struct CacheStats {
    /// Number of `.json` files found.
    pub total_entries: usize,
    /// How many of those are older than the cache's TTL.
    pub expired_entries: usize,
    /// Combined size of those files, in bytes.
    pub total_size: u64,
}
601
602pub trait EnrichmentSource {
616 type Stats;
618
619 fn name(&self) -> &'static str;
621
622 fn cache_namespace(&self) -> &'static str;
624
625 fn cache_ttl(&self) -> Duration;
627
628 fn enrich(&mut self, components: &mut [Component]) -> Result<Self::Stats>;
630}
631
#[cfg(test)]
mod tests {
    use super::*;
    use serde::Deserialize;

    #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
    struct Payload {
        value: u32,
        label: String,
    }

    fn sample() -> Payload {
        Payload {
            value: 7,
            label: "hello".to_string(),
        }
    }

    #[test]
    fn namespaced_cache_dir_includes_namespace() {
        let dir = namespaced_cache_dir("osv");
        let s = dir.to_string_lossy();
        assert!(s.contains("sbom-tools"));
        assert!(s.ends_with("osv"));
    }

    #[test]
    fn backoff_is_exponential_and_capped() {
        assert_eq!(backoff_delay(1, None), Duration::from_secs(1));
        assert_eq!(backoff_delay(2, None), Duration::from_secs(2));
        assert_eq!(backoff_delay(3, None), Duration::from_secs(4));
        // Large attempt counts must saturate at the cap, not overflow.
        assert_eq!(backoff_delay(20, None), MAX_BACKOFF);
    }

    #[test]
    fn backoff_honors_retry_after_but_caps_it() {
        assert_eq!(
            backoff_delay(1, Some(Duration::from_secs(3))),
            Duration::from_secs(3)
        );
        assert_eq!(
            backoff_delay(1, Some(Duration::from_secs(120))),
            MAX_BACKOFF
        );
    }

    #[test]
    fn cache_key_filename_is_deterministic_hex() {
        let key = CacheKey::new(Some("pkg:npm/a@1".into()), "a".into(), None, None);
        let name = key.to_filename();
        assert_eq!(name, key.clone().to_filename());
        assert_eq!(name.len(), 64 + ".json".len());
        assert!(
            name.strip_suffix(".json")
                .unwrap()
                .chars()
                .all(|c| c.is_ascii_hexdigit())
        );
        assert!(key.is_queryable());
        assert!(!CacheKey::new(None, "a".into(), Some("npm".into()), None).is_queryable());
        assert!(
            CacheKey::new(None, "a".into(), Some("npm".into()), Some("1".into())).is_queryable()
        );
    }

    #[test]
    fn cache_roundtrip_survives_reopen() {
        let tmp = tempfile::tempdir().unwrap();
        {
            let cache: JsonCache<Payload> =
                JsonCache::new(tmp.path().to_path_buf(), Duration::from_secs(3600)).unwrap();
            cache.set_named("entry", &sample()).unwrap();
        }
        // A fresh handle on the same directory must see the persisted entry.
        let cache: JsonCache<Payload> =
            JsonCache::new(tmp.path().to_path_buf(), Duration::from_secs(3600)).unwrap();
        assert_eq!(cache.get_named("entry"), Some(sample()));
    }

    #[test]
    fn atomic_write_leaves_no_temp_files() {
        let tmp = tempfile::tempdir().unwrap();
        let cache: JsonCache<Payload> =
            JsonCache::new(tmp.path().to_path_buf(), Duration::from_secs(3600)).unwrap();
        cache.set_named("entry", &sample()).unwrap();

        let tmp_files: Vec<_> = fs::read_dir(tmp.path())
            .unwrap()
            .flatten()
            .filter(|e| e.path().extension().is_some_and(|x| x == "tmp"))
            .collect();
        assert!(tmp_files.is_empty(), "temp file should be renamed away");
    }

    #[test]
    fn schema_version_mismatch_invalidates_entry() {
        let tmp = tempfile::tempdir().unwrap();
        let cache: JsonCache<Payload> =
            JsonCache::new(tmp.path().to_path_buf(), Duration::from_secs(3600)).unwrap();

        // A well-formed envelope whose version differs from the current one.
        let stale = format!(
            "{{\"schema_version\":{},\"payload\":{{\"value\":1,\"label\":\"x\"}}}}",
            CACHE_SCHEMA_VERSION + 1
        );
        let path = cache.path_for("entry");
        fs::write(&path, stale).unwrap();

        assert!(
            cache.get_named("entry").is_none(),
            "mismatched schema version must be a miss"
        );
        assert!(!path.exists(), "stale entry should be evicted");
    }

    #[test]
    fn corrupt_entry_is_a_miss() {
        let tmp = tempfile::tempdir().unwrap();
        let cache: JsonCache<Payload> =
            JsonCache::new(tmp.path().to_path_buf(), Duration::from_secs(3600)).unwrap();
        fs::write(cache.path_for("entry"), "{not json").unwrap();
        assert!(
            cache.get_named("entry").is_none(),
            "unparsable entry must be a miss"
        );
    }

    #[test]
    fn ttl_expiry_evicts_entry() {
        let tmp = tempfile::tempdir().unwrap();
        let cache: JsonCache<Payload> =
            JsonCache::new(tmp.path().to_path_buf(), Duration::from_secs(3600)).unwrap();
        cache.set_named("entry", &sample()).unwrap();
        assert!(cache.get_named("entry").is_some());

        // Backdate the entry instead of sleeping past a tiny TTL. That keeps the
        // test fast and immune to coarse mtime granularity (1-2 s on some
        // filesystems), which could otherwise expire a "fresh" entry at once.
        let two_hours_ago = std::time::SystemTime::now() - Duration::from_secs(2 * 3600);
        fs::File::options()
            .write(true)
            .open(cache.path_for("entry"))
            .unwrap()
            .set_modified(two_hours_ago)
            .unwrap();

        assert!(
            cache.get_named("entry").is_none(),
            "entry past TTL must be a miss"
        );
        assert!(
            !cache.path_for("entry").exists(),
            "expired entry should be evicted"
        );
    }

    #[cfg(feature = "enrichment")]
    #[test]
    fn read_bounded_rejects_oversized_body() {
        use httpmock::prelude::*;

        let server = MockServer::start();
        let body = "x".repeat(1024);
        let mock = server.mock(|when, then| {
            when.method(GET).path("/big");
            then.status(200).body(&body);
        });

        let client = http_client(Duration::from_secs(5)).unwrap();
        let resp = get_with_retry(&client, &format!("{}/big", server.base_url()), 0).unwrap();
        mock.assert();

        let err =
            read_bounded_with_max(resp, 16).expect_err("a body over the cap must be rejected");
        assert!(
            err.to_string().contains("too large"),
            "error must explain the size-cap rejection, got: {err}"
        );
        let detail = std::error::Error::source(&err)
            .map(ToString::to_string)
            .unwrap_or_default();
        assert!(
            detail.contains("exceeds"),
            "source must carry the byte-precise detail, got: {detail}"
        );
    }

    #[cfg(feature = "enrichment")]
    #[test]
    fn read_bounded_accepts_body_within_cap() {
        use httpmock::prelude::*;

        let server = MockServer::start();
        let mock = server.mock(|when, then| {
            when.method(GET).path("/ok");
            then.status(200).body("hello");
        });

        let client = http_client(Duration::from_secs(5)).unwrap();
        let resp = get_with_retry(&client, &format!("{}/ok", server.base_url()), 0).unwrap();
        mock.assert();

        let bytes = read_bounded_with_max(resp, MAX_RESPONSE_BYTES).unwrap();
        assert_eq!(bytes, b"hello");
    }

    #[test]
    fn clear_and_stats() {
        let tmp = tempfile::tempdir().unwrap();
        let cache: JsonCache<Payload> =
            JsonCache::new(tmp.path().to_path_buf(), Duration::from_secs(3600)).unwrap();
        // `stats`/`clear` only consider `.json` files, so name the entries accordingly.
        cache.set_named("a.json", &sample()).unwrap();
        cache.set_named("b.json", &sample()).unwrap();
        assert_eq!(cache.stats().total_entries, 2);
        cache.clear().unwrap();
        assert_eq!(cache.stats().total_entries, 0);
    }
}