Skip to main content

pro_core/audit/
mod.rs

1//! Security audit module for vulnerability checking
2//!
3//! Checks packages against vulnerability databases and provides
4//! automated fix recommendations.
5
6pub mod osv;
7pub mod pypi;
8pub mod types;
9
10pub use osv::OsvClient;
11pub use pypi::PyPIClient;
12pub use types::*;
13
14use std::collections::{HashMap, HashSet};
15
16use crate::lockfile::Lockfile;
17use crate::pep::Requirement;
18use crate::resolver::Resolver;
19use crate::Result;
20
21/// Auditor for checking package vulnerabilities
22pub struct Auditor {
23    /// OSV client for vulnerability queries
24    osv_client: OsvClient,
25    /// PyPI client for yanked version detection
26    pypi_client: PyPIClient,
27    /// Audit configuration
28    config: AuditConfig,
29}
30
31/// Audit configuration
32#[derive(Debug, Clone, Default)]
33pub struct AuditConfig {
34    /// Vulnerability IDs to ignore (from CLI)
35    pub ignored_ids: HashSet<String>,
36    /// Ignored vulnerabilities with metadata (from config file)
37    pub ignored_vulnerabilities: Vec<IgnoredVulnerability>,
38    /// Whether to check for yanked versions
39    pub check_yanked: bool,
40}
41
42impl Auditor {
43    /// Create a new auditor
44    pub fn new() -> Self {
45        Self {
46            osv_client: OsvClient::new(),
47            pypi_client: PyPIClient::new(),
48            config: AuditConfig::default(),
49        }
50    }
51
52    /// Create an auditor with ignored vulnerabilities (from CLI)
53    pub fn with_ignored(ignored: Vec<String>) -> Self {
54        Self {
55            osv_client: OsvClient::new(),
56            pypi_client: PyPIClient::new(),
57            config: AuditConfig {
58                ignored_ids: ignored.into_iter().collect(),
59                ignored_vulnerabilities: Vec::new(),
60                check_yanked: true,
61            },
62        }
63    }
64
65    /// Create an auditor with full configuration
66    pub fn with_config(config: AuditConfig) -> Self {
67        Self {
68            osv_client: OsvClient::new(),
69            pypi_client: PyPIClient::new(),
70            config,
71        }
72    }
73
74    /// Add vulnerability IDs to ignore
75    pub fn ignore(&mut self, ids: impl IntoIterator<Item = String>) {
76        self.config.ignored_ids.extend(ids);
77    }
78
79    /// Check if a vulnerability ID should be ignored
80    fn is_ignored(&self, id: &str) -> bool {
81        // Check simple ignore list
82        if self.config.ignored_ids.contains(id) {
83            return true;
84        }
85
86        // Check configured ignores with expiration
87        for ignored in &self.config.ignored_vulnerabilities {
88            if ignored.id == id || ignored.id == "*" {
89                // Check expiration
90                if let Some(expires) = &ignored.expires {
91                    if let Ok(expiry_date) = chrono::NaiveDate::parse_from_str(expires, "%Y-%m-%d")
92                    {
93                        let today = chrono::Utc::now().date_naive();
94                        if today > expiry_date {
95                            // Ignore has expired
96                            continue;
97                        }
98                    }
99                }
100                return true;
101            }
102        }
103
104        false
105    }
106
107    /// Get all effective ignored IDs (for reporting)
108    pub fn effective_ignores(&self) -> Vec<&str> {
109        let mut ignores: Vec<&str> = self.config.ignored_ids.iter().map(|s| s.as_str()).collect();
110
111        for ignored in &self.config.ignored_vulnerabilities {
112            // Check expiration
113            let is_expired = ignored.expires.as_ref().is_some_and(|expires| {
114                chrono::NaiveDate::parse_from_str(expires, "%Y-%m-%d")
115                    .map(|expiry| chrono::Utc::now().date_naive() > expiry)
116                    .unwrap_or(false)
117            });
118
119            if !is_expired {
120                ignores.push(&ignored.id);
121            }
122        }
123
124        ignores
125    }
126
127    /// Audit packages from a lockfile
128    pub async fn audit_lockfile(&self, lockfile: &Lockfile) -> Result<AuditReport> {
129        let packages: Vec<(&str, &str)> = lockfile
130            .packages
131            .iter()
132            .map(|(name, pkg)| (name.as_str(), pkg.version.as_str()))
133            .collect();
134
135        self.audit_packages(&packages).await
136    }
137
138    /// Audit a list of packages
139    pub async fn audit_packages(&self, packages: &[(&str, &str)]) -> Result<AuditReport> {
140        tracing::info!(
141            "Checking {} packages for vulnerabilities...",
142            packages.len()
143        );
144
145        // Query OSV for all packages
146        let vuln_map = self.osv_client.query_batch(packages).await?;
147
148        let mut report = AuditReport::new();
149
150        for (name, version) in packages {
151            let vulnerabilities: Vec<Vulnerability> = vuln_map
152                .get(*name)
153                .cloned()
154                .unwrap_or_default()
155                .into_iter()
156                .filter(|v| !self.is_ignored(&v.id))
157                .filter(|v| !v.aliases.iter().any(|a| self.is_ignored(a)))
158                .collect();
159
160            let ignored_count =
161                vuln_map.get(*name).map(|vs| vs.len()).unwrap_or(0) - vulnerabilities.len();
162
163            if ignored_count > 0 {
164                report
165                    .ignored
166                    .push(format!("{} ({} ignored)", name, ignored_count));
167            }
168
169            report.packages.push(PackageAuditResult {
170                name: name.to_string(),
171                version: version.to_string(),
172                vulnerabilities,
173            });
174        }
175
176        // Check for yanked versions if enabled
177        if self.config.check_yanked {
178            let yanked = self.pypi_client.check_yanked_batch(packages).await?;
179            report.yanked_packages = yanked;
180        }
181
182        Ok(report)
183    }
184
185    /// Generate fix recommendations for vulnerable packages
186    pub async fn generate_fixes(
187        &self,
188        report: &AuditReport,
189        _lockfile: &Lockfile,
190    ) -> Result<Vec<FixRecommendation>> {
191        let mut fixes = Vec::new();
192
193        for pkg_result in report.vulnerable_packages() {
194            for vuln in &pkg_result.vulnerabilities {
195                if let Some(fixed_version) = &vuln.fixed_version {
196                    fixes.push(FixRecommendation {
197                        package: pkg_result.name.clone(),
198                        current_version: pkg_result.version.clone(),
199                        fixed_version: fixed_version.clone(),
200                        vulnerability_id: vuln.id.clone(),
201                        severity: vuln.severity,
202                        requires_parent_update: false, // Will be determined during resolution
203                    });
204                }
205            }
206        }
207
208        // Deduplicate fixes per package (take highest fixed version)
209        let mut package_fixes: HashMap<String, FixRecommendation> = HashMap::new();
210        for fix in fixes {
211            package_fixes
212                .entry(fix.package.clone())
213                .and_modify(|existing| {
214                    // Keep the fix with higher version or severity
215                    if fix.severity > existing.severity {
216                        *existing = fix.clone();
217                    }
218                })
219                .or_insert(fix);
220        }
221
222        Ok(package_fixes.into_values().collect())
223    }
224
225    /// Apply fixes by re-resolving dependencies
226    pub async fn apply_fixes(
227        &self,
228        fixes: &[FixRecommendation],
229        lockfile: &Lockfile,
230        force: bool,
231    ) -> Result<FixResult> {
232        if fixes.is_empty() {
233            return Ok(FixResult {
234                updated_lockfile: lockfile.clone(),
235                applied_fixes: vec![],
236                failed_fixes: vec![],
237                requires_force: vec![],
238            });
239        }
240
241        // Build new requirements with minimum versions from fixes
242        let mut requirements: Vec<Requirement> = Vec::new();
243        let mut min_versions: HashMap<String, String> = HashMap::new();
244
245        for fix in fixes {
246            min_versions.insert(fix.package.clone(), fix.fixed_version.clone());
247        }
248
249        // Create requirements for all packages, with minimum versions for vulnerable ones
250        for (name, pkg) in &lockfile.packages {
251            let version_spec = if let Some(min_ver) = min_versions.get(name) {
252                format!(">={}", min_ver)
253            } else {
254                format!(">={}", pkg.version)
255            };
256
257            let req_str = format!("{}{}", name, version_spec);
258            match Requirement::parse(&req_str) {
259                Ok(req) => requirements.push(req),
260                Err(e) => {
261                    tracing::warn!("Failed to parse requirement {}: {}", req_str, e);
262                }
263            }
264        }
265
266        // Try to resolve with new constraints
267        let resolver = Resolver::new();
268        match resolver.resolve(&requirements).await {
269            Ok(resolution) => {
270                let new_lockfile = Lockfile::from_resolution(&resolution);
271
272                // Determine which fixes were applied
273                let mut applied_fixes = Vec::new();
274                let mut requires_force = Vec::new();
275
276                for fix in fixes {
277                    if let Some(new_pkg) = new_lockfile.packages.get(&fix.package) {
278                        if new_pkg.version != fix.current_version {
279                            applied_fixes.push(AppliedFix {
280                                package: fix.package.clone(),
281                                from_version: fix.current_version.clone(),
282                                to_version: new_pkg.version.clone(),
283                                vulnerability_id: fix.vulnerability_id.clone(),
284                            });
285                        }
286                    }
287                }
288
289                // Check for packages that changed unexpectedly (might need force)
290                for (name, new_pkg) in &new_lockfile.packages {
291                    if let Some(old_pkg) = lockfile.packages.get(name) {
292                        if new_pkg.version != old_pkg.version && !min_versions.contains_key(name) {
293                            requires_force.push(format!(
294                                "{}: {} -> {} (transitive update)",
295                                name, old_pkg.version, new_pkg.version
296                            ));
297                        }
298                    }
299                }
300
301                // If there are transitive updates and force is not set, return early
302                if !requires_force.is_empty() && !force {
303                    return Ok(FixResult {
304                        updated_lockfile: lockfile.clone(),
305                        applied_fixes: vec![],
306                        failed_fixes: vec![],
307                        requires_force,
308                    });
309                }
310
311                Ok(FixResult {
312                    updated_lockfile: new_lockfile,
313                    applied_fixes,
314                    failed_fixes: vec![],
315                    requires_force: vec![],
316                })
317            }
318            Err(e) => {
319                // Resolution failed - all fixes failed
320                let failed_fixes: Vec<FailedFix> = fixes
321                    .iter()
322                    .map(|f| FailedFix {
323                        package: f.package.clone(),
324                        target_version: f.fixed_version.clone(),
325                        reason: format!("Resolution failed: {}", e),
326                    })
327                    .collect();
328
329                Ok(FixResult {
330                    updated_lockfile: lockfile.clone(),
331                    applied_fixes: vec![],
332                    failed_fixes,
333                    requires_force: vec![],
334                })
335            }
336        }
337    }
338}
339
340impl Default for Auditor {
341    fn default() -> Self {
342        Self::new()
343    }
344}
345
346/// Recommendation to fix a vulnerability
347#[derive(Debug, Clone)]
348pub struct FixRecommendation {
349    /// Package name
350    pub package: String,
351    /// Current installed version
352    pub current_version: String,
353    /// Version that fixes the vulnerability
354    pub fixed_version: String,
355    /// Vulnerability ID being fixed
356    pub vulnerability_id: String,
357    /// Severity of the vulnerability
358    pub severity: Severity,
359    /// Whether parent packages need updating too
360    pub requires_parent_update: bool,
361}
362
363/// Result of applying fixes
364#[derive(Debug, Clone)]
365pub struct FixResult {
366    /// Updated lockfile with fixes applied
367    pub updated_lockfile: Lockfile,
368    /// Successfully applied fixes
369    pub applied_fixes: Vec<AppliedFix>,
370    /// Fixes that failed to apply
371    pub failed_fixes: Vec<FailedFix>,
372    /// Changes that require --force flag
373    pub requires_force: Vec<String>,
374}
375
376impl FixResult {
377    /// Check if any fixes were applied
378    pub fn has_changes(&self) -> bool {
379        !self.applied_fixes.is_empty()
380    }
381
382    /// Check if force is required
383    pub fn needs_force(&self) -> bool {
384        !self.requires_force.is_empty()
385    }
386}
387
/// A fix that was successfully applied.
///
/// All fields are plain strings, so the type derives equality and hashing,
/// which lets callers compare or deduplicate results.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct AppliedFix {
    /// Package name.
    pub package: String,
    /// Version before the fix.
    pub from_version: String,
    /// Version after the fix.
    pub to_version: String,
    /// Vulnerability ID that the fix addressed.
    pub vulnerability_id: String,
}
400
/// A fix that could not be applied.
///
/// All fields are plain strings, so the type derives equality and hashing,
/// which lets callers compare or deduplicate results.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct FailedFix {
    /// Package name.
    pub package: String,
    /// Version that could not be installed.
    pub target_version: String,
    /// Human-readable reason for the failure.
    pub reason: String,
}
411
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a config-file ignore entry with an optional expiry date.
    fn ignore_entry(id: &str, expires: Option<&str>) -> IgnoredVulnerability {
        IgnoredVulnerability {
            id: id.to_string(),
            reason: Some("Testing".to_string()),
            expires: expires.map(|e| e.to_string()),
        }
    }

    #[test]
    fn test_auditor_ignore() {
        let mut auditor = Auditor::new();
        auditor.ignore(["CVE-2023-1234".to_string()]);
        assert!(auditor.config.ignored_ids.contains("CVE-2023-1234"));
    }

    #[test]
    fn test_auditor_constructors_yanked_flag() {
        // The derived AuditConfig::default leaves yanked checks off...
        assert!(!Auditor::new().config.check_yanked);
        // ...while the CLI constructor turns them on.
        let auditor = Auditor::with_ignored(vec!["CVE-2023-1234".to_string()]);
        assert!(auditor.config.check_yanked);
        assert!(auditor.is_ignored("CVE-2023-1234"));
    }

    #[test]
    fn test_auditor_is_ignored() {
        let config = AuditConfig {
            ignored_ids: vec!["CVE-2023-1234".to_string()].into_iter().collect(),
            ignored_vulnerabilities: vec![IgnoredVulnerability {
                id: "GHSA-xxxx".to_string(),
                reason: Some("Not applicable".to_string()),
                expires: None,
            }],
            check_yanked: true,
        };
        let auditor = Auditor::with_config(config);

        assert!(auditor.is_ignored("CVE-2023-1234"));
        assert!(auditor.is_ignored("GHSA-xxxx"));
        assert!(!auditor.is_ignored("CVE-2023-9999"));
    }

    #[test]
    fn test_auditor_expired_ignore() {
        let config = AuditConfig {
            ignored_ids: HashSet::new(),
            ignored_vulnerabilities: vec![
                ignore_entry("CVE-2023-EXPIRED", Some("2020-01-01")),
                ignore_entry("CVE-2023-ACTIVE", Some("2099-12-31")),
            ],
            check_yanked: true,
        };
        let auditor = Auditor::with_config(config);

        assert!(!auditor.is_ignored("CVE-2023-EXPIRED")); // Expired, should not be ignored
        assert!(auditor.is_ignored("CVE-2023-ACTIVE")); // Not expired, should be ignored
    }

    #[test]
    fn test_auditor_wildcard_ignore() {
        let active = Auditor::with_config(AuditConfig {
            ignored_ids: HashSet::new(),
            ignored_vulnerabilities: vec![ignore_entry("*", None)],
            check_yanked: false,
        });
        assert!(active.is_ignored("CVE-2023-9999"));
        assert!(active.is_ignored("GHSA-anything"));

        let expired = Auditor::with_config(AuditConfig {
            ignored_ids: HashSet::new(),
            ignored_vulnerabilities: vec![ignore_entry("*", Some("2020-01-01"))],
            check_yanked: false,
        });
        assert!(!expired.is_ignored("CVE-2023-9999"));
    }

    #[test]
    fn test_auditor_unparseable_expiry_stays_active() {
        let auditor = Auditor::with_config(AuditConfig {
            ignored_ids: HashSet::new(),
            ignored_vulnerabilities: vec![ignore_entry("CVE-2023-BADDATE", Some("not-a-date"))],
            check_yanked: false,
        });
        assert!(auditor.is_ignored("CVE-2023-BADDATE"));
    }

    #[test]
    fn test_effective_ignores_excludes_expired() {
        let auditor = Auditor::with_config(AuditConfig {
            ignored_ids: vec!["CVE-CLI".to_string()].into_iter().collect(),
            ignored_vulnerabilities: vec![
                ignore_entry("CVE-2023-EXPIRED", Some("2020-01-01")),
                ignore_entry("CVE-2023-ACTIVE", Some("2099-12-31")),
            ],
            check_yanked: false,
        });
        let ignores = auditor.effective_ignores();

        assert!(ignores.contains(&"CVE-CLI"));
        assert!(ignores.contains(&"CVE-2023-ACTIVE"));
        assert!(!ignores.contains(&"CVE-2023-EXPIRED"));
    }

    #[test]
    fn test_fix_result_has_changes() {
        let result = FixResult {
            updated_lockfile: Lockfile::new(),
            applied_fixes: vec![AppliedFix {
                package: "test".to_string(),
                from_version: "1.0.0".to_string(),
                to_version: "1.0.1".to_string(),
                vulnerability_id: "CVE-2023-1234".to_string(),
            }],
            failed_fixes: vec![],
            requires_force: vec![],
        };
        assert!(result.has_changes());
        assert!(!result.needs_force());
    }
}