Skip to main content

pysentry/cache/
audit.rs

1// SPDX-License-Identifier: MIT
2
3use super::storage::{Cache, CacheBucket, CacheEntry, Freshness};
4use crate::types::{ResolutionCacheEntry, ResolverType};
5use anyhow::Result;
6use chrono::{DateTime, Utc};
7use rustc_hash::FxHasher;
8use serde::{Deserialize, Serialize};
9use std::collections::HashMap;
10use std::hash::Hasher;
11use std::time::Duration;
12
/// Metadata record stored alongside the cached vulnerability database.
///
/// Serialized to and parsed from JSON by `AuditCache::write_metadata` and
/// `AuditCache::read_metadata`.
#[derive(Debug, Serialize, Deserialize)]
pub struct DatabaseMetadata {
    /// UTC timestamp; presumably when the database was last refreshed (confirm with the writer).
    pub last_updated: DateTime<Utc>,
    /// Version string; presumably of the database (format is not visible in this file).
    pub version: String,
    /// Presumably the number of advisories contained in the database.
    pub advisory_count: usize,
}
19
/// Audit-specific facade over a [`Cache`].
///
/// Its methods address entries by a [`CacheBucket`] plus a string key and
/// hand back [`CacheEntry`] handles.
pub struct AuditCache {
    /// Underlying storage that produces the `CacheEntry` handles.
    cache: Cache,
}
23
24impl AuditCache {
25    pub fn new(cache_dir: std::path::PathBuf) -> Self {
26        Self {
27            cache: Cache::new(cache_dir),
28        }
29    }
30
31    pub fn database_entry(&self, source: &str) -> CacheEntry {
32        self.cache.entry(
33            CacheBucket::VulnerabilityDatabase,
34            &format!("{source}-database"),
35        )
36    }
37
38    pub fn metadata_entry(&self) -> CacheEntry {
39        self.cache.entry(CacheBucket::VulnerabilityDatabase, "meta")
40    }
41
42    pub fn index_entry(&self) -> CacheEntry {
43        self.cache
44            .entry(CacheBucket::VulnerabilityDatabase, "index")
45    }
46
47    pub fn should_refresh(&self, ttl_hours: u64) -> Result<bool> {
48        let meta_entry = self.metadata_entry();
49        let ttl = Duration::from_secs(ttl_hours * 3600);
50
51        match meta_entry.freshness(ttl) {
52            Ok(Freshness::Fresh) => Ok(false),
53            _ => Ok(true), // Stale or doesn't exist
54        }
55    }
56
57    pub async fn read_metadata(&self) -> Result<Option<DatabaseMetadata>> {
58        let entry = self.metadata_entry();
59        let content = match entry.read().await {
60            Ok(data) => data,
61            Err(_) => return Ok(None),
62        };
63
64        let metadata: DatabaseMetadata = serde_json::from_slice(&content)?;
65        Ok(Some(metadata))
66    }
67
68    pub async fn write_metadata(&self, metadata: &DatabaseMetadata) -> Result<()> {
69        let entry = self.metadata_entry();
70        let content = serde_json::to_vec_pretty(metadata)?;
71        entry.write(&content).await?;
72        Ok(())
73    }
74
75    // Resolution Cache Methods
76
77    /// Generate cache key for resolution caching
78    pub fn generate_resolution_cache_key(
79        &self,
80        requirements_content: &str,
81        resolver_type: &ResolverType,
82        resolver_version: &str,
83        python_version: &str,
84        platform: &str,
85        environment_markers: &HashMap<String, String>,
86    ) -> String {
87        let mut hasher = FxHasher::default();
88        hasher.write(requirements_content.as_bytes());
89        hasher.write(resolver_type.to_string().as_bytes());
90        hasher.write(resolver_version.as_bytes());
91        hasher.write(python_version.as_bytes());
92        hasher.write(platform.as_bytes());
93
94        let mut marker_items: Vec<_> = environment_markers.iter().collect();
95        marker_items.sort_by_key(|(k, _)| *k);
96        for (key, value) in marker_items {
97            hasher.write(key.as_bytes());
98            hasher.write(value.as_bytes());
99        }
100
101        let hash = hasher.finish();
102        let content_hash = format!("{hash:x}");
103
104        format!(
105            "{}-py{}-{}-{}",
106            resolver_type, python_version, platform, &content_hash
107        )
108    }
109
110    pub fn resolution_entry(&self, cache_key: &str) -> CacheEntry {
111        self.cache.entry(
112            CacheBucket::DependencyResolution,
113            &format!("{cache_key}.resolution"),
114        )
115    }
116
117    pub fn should_refresh_resolution(&self, cache_key: &str, ttl_hours: u64) -> Result<bool> {
118        let entry = self.resolution_entry(cache_key);
119        let ttl = Duration::from_secs(ttl_hours * 3600);
120
121        match entry.freshness(ttl) {
122            Ok(Freshness::Fresh) => Ok(false),
123            _ => Ok(true), // Stale or doesn't exist
124        }
125    }
126
127    pub async fn read_resolution_cache(
128        &self,
129        cache_key: &str,
130    ) -> Result<Option<ResolutionCacheEntry>> {
131        let entry = self.resolution_entry(cache_key);
132        let content = match entry.read().await {
133            Ok(data) => data,
134            Err(_) => return Ok(None),
135        };
136
137        let cache_entry: ResolutionCacheEntry = serde_json::from_slice(&content)?;
138        Ok(Some(cache_entry))
139    }
140
141    pub async fn write_resolution_cache(
142        &self,
143        cache_key: &str,
144        cache_entry: &ResolutionCacheEntry,
145    ) -> Result<()> {
146        let entry = self.resolution_entry(cache_key);
147        let content = serde_json::to_vec_pretty(cache_entry)?;
148        entry.write(&content).await?;
149        Ok(())
150    }
151
152    pub async fn clear_resolution_cache_entry(&self, cache_key: &str) -> Result<()> {
153        let entry = self.resolution_entry(cache_key);
154        let cache_file = entry.path();
155
156        if cache_file.exists() {
157            if let Err(e) = tokio::fs::remove_file(&cache_file).await {
158                tracing::warn!("Failed to remove cache entry {:?}: {}", cache_file, e);
159                return Err(anyhow::anyhow!("Failed to remove cache entry: {}", e));
160            } else {
161                tracing::debug!("Cleared resolution cache entry: {:?}", cache_file);
162            }
163        }
164
165        Ok(())
166    }
167
168    pub async fn clear_resolution_cache(&self, resolver_type: Option<ResolverType>) -> Result<()> {
169        use fs_err as fs;
170
171        let entry = self.cache.entry(CacheBucket::DependencyResolution, "");
172        let cache_dir = entry
173            .path()
174            .parent()
175            .ok_or_else(|| anyhow::anyhow!("Invalid cache directory"))?;
176
177        if !cache_dir.exists() {
178            return Ok(()); // Nothing to clear
179        }
180
181        let entries = fs::read_dir(cache_dir)?;
182        for entry in entries {
183            let entry = entry?;
184            let path = entry.path();
185
186            if let Some(file_name) = path.file_name().and_then(|n| n.to_str()) {
187                if file_name.ends_with(".resolution.cache") {
188                    if let Some(ref rt) = resolver_type {
189                        if !file_name.starts_with(&rt.to_string()) {
190                            continue;
191                        }
192                    }
193
194                    if let Err(e) = fs::remove_file(&path) {
195                        tracing::warn!("Failed to remove cache file {:?}: {}", path, e);
196                    } else {
197                        tracing::debug!("Cleared resolution cache file: {:?}", path);
198                    }
199                }
200            }
201        }
202
203        Ok(())
204    }
205
206    pub async fn get_resolution_cache_stats(&self) -> Result<ResolutionCacheStats> {
207        use fs_err as fs;
208
209        let entry = self.cache.entry(CacheBucket::DependencyResolution, "");
210        let cache_dir = entry
211            .path()
212            .parent()
213            .ok_or_else(|| anyhow::anyhow!("Invalid cache directory"))?;
214
215        let mut stats = ResolutionCacheStats::default();
216
217        if !cache_dir.exists() {
218            return Ok(stats);
219        }
220
221        let entries = fs::read_dir(cache_dir)?;
222        for entry in entries.flatten() {
223            let path = entry.path();
224            if let Some(file_name) = path.file_name().and_then(|n| n.to_str()) {
225                if file_name.ends_with(".resolution.cache") {
226                    stats.total_entries += 1;
227                    if let Ok(metadata) = fs::metadata(&path) {
228                        stats.total_size_bytes += metadata.len();
229
230                        // Count by resolver type
231                        if file_name.starts_with("uv-") {
232                            stats.uv_entries += 1;
233                        } else if file_name.starts_with("pip-tools-") {
234                            stats.pip_tools_entries += 1;
235                        }
236                    }
237                }
238            }
239        }
240
241        Ok(stats)
242    }
243}
244
/// Resolution cache statistics
///
/// Produced by `AuditCache::get_resolution_cache_stats`, which scans files whose
/// name ends with `.resolution.cache`. All fields default to zero.
#[derive(Debug, Default)]
pub struct ResolutionCacheStats {
    /// Number of resolution cache files found.
    pub total_entries: usize,
    /// Sum of the sizes, in bytes, of the files whose metadata could be read.
    pub total_size_bytes: u64,
    /// Number of counted files whose name starts with `uv-`.
    pub uv_entries: usize,
    /// Number of counted files whose name starts with `pip-tools-`.
    pub pip_tools_entries: usize,
}
253
/// Normalize package name per PEP 503 for consistent cache keys
/// https://peps.python.org/pep-0503/#normalized-names
///
/// The name is lowercased and every *run* of `-`, `_` and `.` characters is
/// collapsed into a single `-`, matching the reference
/// `re.sub(r"[-_.]+", "-", name).lower()`. For example `My__Package` and
/// `my-_.package` both normalize to `my-package`.
fn normalize_package_name(name: &str) -> String {
    let lowered = name.to_lowercase();
    let mut normalized = String::with_capacity(lowered.len());
    let mut previous_was_separator = false;

    for ch in lowered.chars() {
        if matches!(ch, '-' | '_' | '.') {
            // Emit one hyphen for the whole run of separators.
            if !previous_was_separator {
                normalized.push('-');
            }
            previous_was_separator = true;
        } else {
            normalized.push(ch);
            previous_was_separator = false;
        }
    }

    normalized
}
259
260impl AuditCache {
261    // Project Status Cache Methods (PEP 792)
262
263    /// Get a cache entry for project status
264    pub fn project_status_entry(&self, package_name: &str) -> CacheEntry {
265        let normalized = normalize_package_name(package_name);
266        self.cache.entry(
267            CacheBucket::ProjectStatus,
268            &format!("status-{}", normalized),
269        )
270    }
271
272    /// Check if project status cache should be refreshed
273    pub fn should_refresh_project_status(&self, package_name: &str, ttl_hours: u64) -> bool {
274        let entry = self.project_status_entry(package_name);
275        let ttl = std::time::Duration::from_secs(ttl_hours * 3600);
276
277        !matches!(entry.freshness(ttl), Ok(Freshness::Fresh))
278    }
279
280    pub fn feedback_entry(&self) -> CacheEntry {
281        self.cache
282            .entry(CacheBucket::UserMessages, "last_feedback_shown")
283    }
284
285    pub async fn should_show_feedback(&self) -> bool {
286        let entry = self.feedback_entry();
287        let one_day = Duration::from_secs(24 * 3600);
288
289        match entry.freshness(one_day) {
290            Ok(Freshness::Fresh) => false, // Shown recently, don't show
291            _ => true,                     // Stale or doesn't exist, show feedback
292        }
293    }
294
295    pub async fn record_feedback_shown(&self) -> Result<()> {
296        let entry = self.feedback_entry();
297        let now = Utc::now();
298        let timestamp = serde_json::to_vec(&now)?;
299        entry.write(&timestamp).await?;
300        Ok(())
301    }
302
303    pub fn update_check_entry(&self) -> CacheEntry {
304        self.cache
305            .entry(CacheBucket::UserMessages, "last_update_check")
306    }
307
308    pub async fn should_check_for_updates(&self) -> bool {
309        let entry = self.update_check_entry();
310        let one_day = Duration::from_secs(24 * 3600);
311
312        match entry.freshness(one_day) {
313            Ok(Freshness::Fresh) => false, // Checked recently, don't check
314            _ => true,                     // Stale or doesn't exist, check for updates
315        }
316    }
317
318    pub async fn record_update_check(&self) -> Result<()> {
319        let entry = self.update_check_entry();
320        let now = Utc::now();
321        let timestamp = serde_json::to_vec(&now)?;
322        entry.write(&timestamp).await?;
323        Ok(())
324    }
325}
326
327impl Clone for AuditCache {
328    fn clone(&self) -> Self {
329        Self {
330            cache: self.cache.clone(),
331        }
332    }
333}
334
335#[cfg(test)]
336mod tests {
337    use super::*;
338    use crate::types::{ResolvedDependency, ResolverType};
339    use std::collections::HashMap;
340    use tempfile::tempdir;
341
342    #[tokio::test]
343    async fn test_resolution_cache_key_generation() {
344        let temp_dir = tempdir().unwrap();
345        let cache = AuditCache::new(temp_dir.path().to_path_buf());
346
347        let requirements_content = "requests>=2.25.0\nclick==8.0.0";
348        let resolver_type = ResolverType::Uv;
349        let resolver_version = "0.4.29";
350        let python_version = "3.12";
351        let platform = "linux-x86_64";
352        let environment_markers = HashMap::new();
353
354        let key1 = cache.generate_resolution_cache_key(
355            requirements_content,
356            &resolver_type,
357            resolver_version,
358            python_version,
359            platform,
360            &environment_markers,
361        );
362
363        let key2 = cache.generate_resolution_cache_key(
364            requirements_content,
365            &resolver_type,
366            resolver_version,
367            python_version,
368            platform,
369            &environment_markers,
370        );
371
372        assert_eq!(key1, key2);
373        assert!(key1.starts_with("uv-py3.12-linux-x86_64-"));
374
375        let key3 = cache.generate_resolution_cache_key(
376            "different-content",
377            &resolver_type,
378            resolver_version,
379            python_version,
380            platform,
381            &environment_markers,
382        );
383        assert_ne!(key1, key3);
384    }
385
    /// A resolution entry written to the cache can be read back with the same
    /// output, resolver type, versions and dependency list.
    #[tokio::test]
    async fn test_resolution_cache_write_read() {
        let temp_dir = tempdir().unwrap();
        let cache = AuditCache::new(temp_dir.path().to_path_buf());

        let cache_key = "test-cache-key";
        // Entry with a single direct dependency.
        let cache_entry = ResolutionCacheEntry {
            output: "requests==2.31.0".to_string(),
            resolver_type: ResolverType::Uv,
            resolver_version: "0.4.29".to_string(),
            python_version: "3.12".to_string(),
            dependencies: vec![ResolvedDependency {
                name: "requests".to_string(),
                version: "2.31.0".to_string(),
                is_direct: true,
                source_file: std::path::PathBuf::from("requirements.txt"),
                extras: vec![],
                markers: None,
            }],
        };

        cache
            .write_resolution_cache(cache_key, &cache_entry)
            .await
            .unwrap();

        // First unwrap: the read succeeded; second unwrap: an entry was found.
        let read_entry = cache
            .read_resolution_cache(cache_key)
            .await
            .unwrap()
            .unwrap();

        assert_eq!(read_entry.output, cache_entry.output);
        // Resolver types are compared through their string form.
        assert_eq!(
            read_entry.resolver_type.to_string(),
            cache_entry.resolver_type.to_string()
        );
        assert_eq!(read_entry.resolver_version, cache_entry.resolver_version);
        assert_eq!(read_entry.python_version, cache_entry.python_version);
        assert_eq!(
            read_entry.dependencies.len(),
            cache_entry.dependencies.len()
        );
        assert_eq!(read_entry.dependencies[0].name, "requests");
    }
431
    /// `should_refresh_resolution` asks for a refresh when nothing is cached, stops
    /// asking right after a write with a 24-hour TTL, and asks again with a zero TTL.
    #[tokio::test]
    async fn test_resolution_cache_freshness() {
        let temp_dir = tempdir().unwrap();
        let cache = AuditCache::new(temp_dir.path().to_path_buf());

        let cache_key = "test-freshness";

        // Nothing has been written yet, so a refresh is required.
        assert!(cache.should_refresh_resolution(cache_key, 24).unwrap());

        let cache_entry = ResolutionCacheEntry {
            output: "".to_string(),
            resolver_type: ResolverType::Uv,
            resolver_version: "0.4.29".to_string(),
            python_version: "3.12".to_string(),
            dependencies: vec![],
        };

        cache
            .write_resolution_cache(cache_key, &cache_entry)
            .await
            .unwrap();

        // Just written: fresh under a 24-hour TTL.
        assert!(!cache.should_refresh_resolution(cache_key, 24).unwrap());

        // A zero TTL is expected to make even a just-written entry need a refresh.
        assert!(cache.should_refresh_resolution(cache_key, 0).unwrap());
    }
458
    /// Clearing with a resolver filter removes only that resolver's entries;
    /// clearing with `None` removes the rest.
    #[tokio::test]
    async fn test_resolution_cache_clear() {
        let temp_dir = tempdir().unwrap();
        let cache = AuditCache::new(temp_dir.path().to_path_buf());

        let uv_entry = ResolutionCacheEntry {
            output: "uv-test-output".to_string(),
            resolver_type: ResolverType::Uv,
            resolver_version: "0.4.29".to_string(),
            python_version: "3.12".to_string(),
            dependencies: vec![],
        };

        let pip_tools_entry = ResolutionCacheEntry {
            output: "pip-tools-test-output".to_string(),
            resolver_type: ResolverType::PipTools,
            resolver_version: "7.4.1".to_string(),
            python_version: "3.12".to_string(),
            dependencies: vec![],
        };

        // The cache keys start with the resolver name, which the filtered clear matches on.
        cache
            .write_resolution_cache("uv-test", &uv_entry)
            .await
            .unwrap();
        cache
            .write_resolution_cache("pip-tools-test", &pip_tools_entry)
            .await
            .unwrap();

        // Both entries are present before any clearing.
        assert!(cache
            .read_resolution_cache("uv-test")
            .await
            .unwrap()
            .is_some());
        assert!(cache
            .read_resolution_cache("pip-tools-test")
            .await
            .unwrap()
            .is_some());

        // Filtered clear: only the uv entry goes away.
        cache
            .clear_resolution_cache(Some(ResolverType::Uv))
            .await
            .unwrap();

        assert!(cache
            .read_resolution_cache("uv-test")
            .await
            .unwrap()
            .is_none());
        assert!(cache
            .read_resolution_cache("pip-tools-test")
            .await
            .unwrap()
            .is_some());

        // Unfiltered clear: nothing is left.
        cache.clear_resolution_cache(None).await.unwrap();

        assert!(cache
            .read_resolution_cache("uv-test")
            .await
            .unwrap()
            .is_none());
        assert!(cache
            .read_resolution_cache("pip-tools-test")
            .await
            .unwrap()
            .is_none());
    }
529
    /// Statistics are all zero for an empty cache and reflect a single uv entry
    /// after one write.
    #[tokio::test]
    async fn test_resolution_cache_stats() {
        let temp_dir = tempdir().unwrap();
        let cache = AuditCache::new(temp_dir.path().to_path_buf());

        // Empty cache: every counter is zero.
        let stats = cache.get_resolution_cache_stats().await.unwrap();
        assert_eq!(stats.total_entries, 0);
        assert_eq!(stats.total_size_bytes, 0);
        assert_eq!(stats.uv_entries, 0);
        assert_eq!(stats.pip_tools_entries, 0);

        let uv_entry = ResolutionCacheEntry {
            output: "test-uv-output".to_string(),
            resolver_type: ResolverType::Uv,
            resolver_version: "0.4.29".to_string(),
            python_version: "3.12".to_string(),
            dependencies: vec![],
        };

        // The key starts with `uv-`, so the entry is counted as a uv entry.
        cache
            .write_resolution_cache("uv-py3.12-linux-x86_64-abc123", &uv_entry)
            .await
            .unwrap();

        let stats = cache.get_resolution_cache_stats().await.unwrap();
        assert_eq!(stats.total_entries, 1);
        assert!(stats.total_size_bytes > 0);
        assert_eq!(stats.uv_entries, 1);
        assert_eq!(stats.pip_tools_entries, 0);
    }
560
561    #[test]
562    fn test_normalize_package_name() {
563        use super::normalize_package_name;
564
565        // Case normalization
566        assert_eq!(normalize_package_name("Django"), "django");
567        assert_eq!(normalize_package_name("DJANGO"), "django");
568
569        // Separator normalization (underscore -> hyphen)
570        assert_eq!(normalize_package_name("my_package"), "my-package");
571
572        // Separator normalization (dot -> hyphen)
573        assert_eq!(normalize_package_name("my.package"), "my-package");
574
575        // Combined: case + separator
576        assert_eq!(normalize_package_name("My-Package"), "my-package");
577        assert_eq!(normalize_package_name("My_Package"), "my-package");
578        assert_eq!(normalize_package_name("My.Package"), "my-package");
579
580        // Multiple separators
581        assert_eq!(
582            normalize_package_name("some_complex.package-name"),
583            "some-complex-package-name"
584        );
585
586        // Already normalized
587        assert_eq!(normalize_package_name("requests"), "requests");
588        assert_eq!(normalize_package_name("my-package"), "my-package");
589    }
590}