1pub mod osv;
7pub mod pypi;
8pub mod types;
9
10pub use osv::OsvClient;
11pub use pypi::PyPIClient;
12pub use types::*;
13
14use std::collections::{HashMap, HashSet};
15
16use crate::lockfile::Lockfile;
17use crate::pep::Requirement;
18use crate::resolver::Resolver;
19use crate::Result;
20
21pub struct Auditor {
23 osv_client: OsvClient,
25 pypi_client: PyPIClient,
27 config: AuditConfig,
29}
30
31#[derive(Debug, Clone, Default)]
33pub struct AuditConfig {
34 pub ignored_ids: HashSet<String>,
36 pub ignored_vulnerabilities: Vec<IgnoredVulnerability>,
38 pub check_yanked: bool,
40}
41
42impl Auditor {
43 pub fn new() -> Self {
45 Self {
46 osv_client: OsvClient::new(),
47 pypi_client: PyPIClient::new(),
48 config: AuditConfig::default(),
49 }
50 }
51
52 pub fn with_ignored(ignored: Vec<String>) -> Self {
54 Self {
55 osv_client: OsvClient::new(),
56 pypi_client: PyPIClient::new(),
57 config: AuditConfig {
58 ignored_ids: ignored.into_iter().collect(),
59 ignored_vulnerabilities: Vec::new(),
60 check_yanked: true,
61 },
62 }
63 }
64
65 pub fn with_config(config: AuditConfig) -> Self {
67 Self {
68 osv_client: OsvClient::new(),
69 pypi_client: PyPIClient::new(),
70 config,
71 }
72 }
73
74 pub fn ignore(&mut self, ids: impl IntoIterator<Item = String>) {
76 self.config.ignored_ids.extend(ids);
77 }
78
79 fn is_ignored(&self, id: &str) -> bool {
81 if self.config.ignored_ids.contains(id) {
83 return true;
84 }
85
86 for ignored in &self.config.ignored_vulnerabilities {
88 if ignored.id == id || ignored.id == "*" {
89 if let Some(expires) = &ignored.expires {
91 if let Ok(expiry_date) = chrono::NaiveDate::parse_from_str(expires, "%Y-%m-%d")
92 {
93 let today = chrono::Utc::now().date_naive();
94 if today > expiry_date {
95 continue;
97 }
98 }
99 }
100 return true;
101 }
102 }
103
104 false
105 }
106
107 pub fn effective_ignores(&self) -> Vec<&str> {
109 let mut ignores: Vec<&str> = self.config.ignored_ids.iter().map(|s| s.as_str()).collect();
110
111 for ignored in &self.config.ignored_vulnerabilities {
112 let is_expired = ignored.expires.as_ref().is_some_and(|expires| {
114 chrono::NaiveDate::parse_from_str(expires, "%Y-%m-%d")
115 .map(|expiry| chrono::Utc::now().date_naive() > expiry)
116 .unwrap_or(false)
117 });
118
119 if !is_expired {
120 ignores.push(&ignored.id);
121 }
122 }
123
124 ignores
125 }
126
127 pub async fn audit_lockfile(&self, lockfile: &Lockfile) -> Result<AuditReport> {
129 let packages: Vec<(&str, &str)> = lockfile
130 .packages
131 .iter()
132 .map(|(name, pkg)| (name.as_str(), pkg.version.as_str()))
133 .collect();
134
135 self.audit_packages(&packages).await
136 }
137
138 pub async fn audit_packages(&self, packages: &[(&str, &str)]) -> Result<AuditReport> {
140 tracing::info!(
141 "Checking {} packages for vulnerabilities...",
142 packages.len()
143 );
144
145 let vuln_map = self.osv_client.query_batch(packages).await?;
147
148 let mut report = AuditReport::new();
149
150 for (name, version) in packages {
151 let vulnerabilities: Vec<Vulnerability> = vuln_map
152 .get(*name)
153 .cloned()
154 .unwrap_or_default()
155 .into_iter()
156 .filter(|v| !self.is_ignored(&v.id))
157 .filter(|v| !v.aliases.iter().any(|a| self.is_ignored(a)))
158 .collect();
159
160 let ignored_count =
161 vuln_map.get(*name).map(|vs| vs.len()).unwrap_or(0) - vulnerabilities.len();
162
163 if ignored_count > 0 {
164 report
165 .ignored
166 .push(format!("{} ({} ignored)", name, ignored_count));
167 }
168
169 report.packages.push(PackageAuditResult {
170 name: name.to_string(),
171 version: version.to_string(),
172 vulnerabilities,
173 });
174 }
175
176 if self.config.check_yanked {
178 let yanked = self.pypi_client.check_yanked_batch(packages).await?;
179 report.yanked_packages = yanked;
180 }
181
182 Ok(report)
183 }
184
185 pub async fn generate_fixes(
187 &self,
188 report: &AuditReport,
189 _lockfile: &Lockfile,
190 ) -> Result<Vec<FixRecommendation>> {
191 let mut fixes = Vec::new();
192
193 for pkg_result in report.vulnerable_packages() {
194 for vuln in &pkg_result.vulnerabilities {
195 if let Some(fixed_version) = &vuln.fixed_version {
196 fixes.push(FixRecommendation {
197 package: pkg_result.name.clone(),
198 current_version: pkg_result.version.clone(),
199 fixed_version: fixed_version.clone(),
200 vulnerability_id: vuln.id.clone(),
201 severity: vuln.severity,
202 requires_parent_update: false, });
204 }
205 }
206 }
207
208 let mut package_fixes: HashMap<String, FixRecommendation> = HashMap::new();
210 for fix in fixes {
211 package_fixes
212 .entry(fix.package.clone())
213 .and_modify(|existing| {
214 if fix.severity > existing.severity {
216 *existing = fix.clone();
217 }
218 })
219 .or_insert(fix);
220 }
221
222 Ok(package_fixes.into_values().collect())
223 }
224
225 pub async fn apply_fixes(
227 &self,
228 fixes: &[FixRecommendation],
229 lockfile: &Lockfile,
230 force: bool,
231 ) -> Result<FixResult> {
232 if fixes.is_empty() {
233 return Ok(FixResult {
234 updated_lockfile: lockfile.clone(),
235 applied_fixes: vec![],
236 failed_fixes: vec![],
237 requires_force: vec![],
238 });
239 }
240
241 let mut requirements: Vec<Requirement> = Vec::new();
243 let mut min_versions: HashMap<String, String> = HashMap::new();
244
245 for fix in fixes {
246 min_versions.insert(fix.package.clone(), fix.fixed_version.clone());
247 }
248
249 for (name, pkg) in &lockfile.packages {
251 let version_spec = if let Some(min_ver) = min_versions.get(name) {
252 format!(">={}", min_ver)
253 } else {
254 format!(">={}", pkg.version)
255 };
256
257 let req_str = format!("{}{}", name, version_spec);
258 match Requirement::parse(&req_str) {
259 Ok(req) => requirements.push(req),
260 Err(e) => {
261 tracing::warn!("Failed to parse requirement {}: {}", req_str, e);
262 }
263 }
264 }
265
266 let resolver = Resolver::new();
268 match resolver.resolve(&requirements).await {
269 Ok(resolution) => {
270 let new_lockfile = Lockfile::from_resolution(&resolution);
271
272 let mut applied_fixes = Vec::new();
274 let mut requires_force = Vec::new();
275
276 for fix in fixes {
277 if let Some(new_pkg) = new_lockfile.packages.get(&fix.package) {
278 if new_pkg.version != fix.current_version {
279 applied_fixes.push(AppliedFix {
280 package: fix.package.clone(),
281 from_version: fix.current_version.clone(),
282 to_version: new_pkg.version.clone(),
283 vulnerability_id: fix.vulnerability_id.clone(),
284 });
285 }
286 }
287 }
288
289 for (name, new_pkg) in &new_lockfile.packages {
291 if let Some(old_pkg) = lockfile.packages.get(name) {
292 if new_pkg.version != old_pkg.version && !min_versions.contains_key(name) {
293 requires_force.push(format!(
294 "{}: {} -> {} (transitive update)",
295 name, old_pkg.version, new_pkg.version
296 ));
297 }
298 }
299 }
300
301 if !requires_force.is_empty() && !force {
303 return Ok(FixResult {
304 updated_lockfile: lockfile.clone(),
305 applied_fixes: vec![],
306 failed_fixes: vec![],
307 requires_force,
308 });
309 }
310
311 Ok(FixResult {
312 updated_lockfile: new_lockfile,
313 applied_fixes,
314 failed_fixes: vec![],
315 requires_force: vec![],
316 })
317 }
318 Err(e) => {
319 let failed_fixes: Vec<FailedFix> = fixes
321 .iter()
322 .map(|f| FailedFix {
323 package: f.package.clone(),
324 target_version: f.fixed_version.clone(),
325 reason: format!("Resolution failed: {}", e),
326 })
327 .collect();
328
329 Ok(FixResult {
330 updated_lockfile: lockfile.clone(),
331 applied_fixes: vec![],
332 failed_fixes,
333 requires_force: vec![],
334 })
335 }
336 }
337 }
338}
339
340impl Default for Auditor {
341 fn default() -> Self {
342 Self::new()
343 }
344}
345
346#[derive(Debug, Clone)]
348pub struct FixRecommendation {
349 pub package: String,
351 pub current_version: String,
353 pub fixed_version: String,
355 pub vulnerability_id: String,
357 pub severity: Severity,
359 pub requires_parent_update: bool,
361}
362
363#[derive(Debug, Clone)]
365pub struct FixResult {
366 pub updated_lockfile: Lockfile,
368 pub applied_fixes: Vec<AppliedFix>,
370 pub failed_fixes: Vec<FailedFix>,
372 pub requires_force: Vec<String>,
374}
375
376impl FixResult {
377 pub fn has_changes(&self) -> bool {
379 !self.applied_fixes.is_empty()
380 }
381
382 pub fn needs_force(&self) -> bool {
384 !self.requires_force.is_empty()
385 }
386}
387
/// A fix that actually moved a package to a new version.
#[derive(Debug, Clone)]
pub struct AppliedFix {
    /// Name of the upgraded package.
    pub package: String,
    /// Version before the fix.
    pub from_version: String,
    /// Version chosen by the resolver.
    pub to_version: String,
    /// Vulnerability the upgrade was made for.
    pub vulnerability_id: String,
}
400
/// A fix that could not be applied.
#[derive(Debug, Clone)]
pub struct FailedFix {
    /// Name of the package the fix targeted.
    pub package: String,
    /// Minimum version the fix tried to reach.
    pub target_version: String,
    /// Human-readable explanation of the failure.
    pub reason: String,
}
411
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an ignore entry with a fixed reason and the given expiry.
    fn ignore_entry(id: &str, expires: Option<&str>) -> IgnoredVulnerability {
        IgnoredVulnerability {
            id: id.to_string(),
            reason: Some("Testing".to_string()),
            expires: expires.map(str::to_string),
        }
    }

    #[test]
    fn test_auditor_ignore() {
        let mut auditor = Auditor::new();
        auditor.ignore(["CVE-2023-1234".to_string()]);
        assert!(auditor.config.ignored_ids.contains("CVE-2023-1234"));
    }

    #[test]
    fn test_auditor_is_ignored() {
        let config = AuditConfig {
            ignored_ids: vec!["CVE-2023-1234".to_string()].into_iter().collect(),
            ignored_vulnerabilities: vec![IgnoredVulnerability {
                id: "GHSA-xxxx".to_string(),
                reason: Some("Not applicable".to_string()),
                expires: None,
            }],
            check_yanked: true,
        };
        let auditor = Auditor::with_config(config);

        assert!(auditor.is_ignored("CVE-2023-1234"));
        assert!(auditor.is_ignored("GHSA-xxxx"));
        assert!(!auditor.is_ignored("CVE-2023-9999"));
    }

    #[test]
    fn test_auditor_expired_ignore() {
        let config = AuditConfig {
            ignored_ids: HashSet::new(),
            ignored_vulnerabilities: vec![
                ignore_entry("CVE-2023-EXPIRED", Some("2020-01-01")),
                ignore_entry("CVE-2023-ACTIVE", Some("2099-12-31")),
            ],
            check_yanked: true,
        };
        let auditor = Auditor::with_config(config);

        assert!(!auditor.is_ignored("CVE-2023-EXPIRED"));
        assert!(auditor.is_ignored("CVE-2023-ACTIVE"));
    }

    #[test]
    fn test_auditor_wildcard_ignore() {
        let config = AuditConfig {
            ignored_vulnerabilities: vec![ignore_entry("*", None)],
            ..AuditConfig::default()
        };
        let auditor = Auditor::with_config(config);

        assert!(auditor.is_ignored("CVE-2024-0001"));
        assert!(auditor.is_ignored("GHSA-anything"));
    }

    #[test]
    fn test_auditor_unparseable_expiry_never_expires() {
        let config = AuditConfig {
            ignored_vulnerabilities: vec![ignore_entry("CVE-2023-BADDATE", Some("31/12/2020"))],
            ..AuditConfig::default()
        };
        let auditor = Auditor::with_config(config);

        assert!(auditor.is_ignored("CVE-2023-BADDATE"));
        assert_eq!(auditor.effective_ignores(), vec!["CVE-2023-BADDATE"]);
    }

    #[test]
    fn test_effective_ignores_skips_expired() {
        let config = AuditConfig {
            ignored_ids: HashSet::from(["CVE-2023-1234".to_string()]),
            ignored_vulnerabilities: vec![
                ignore_entry("CVE-2023-EXPIRED", Some("2020-01-01")),
                ignore_entry("CVE-2023-ACTIVE", Some("2099-12-31")),
            ],
            check_yanked: false,
        };
        let auditor = Auditor::with_config(config);

        let mut effective = auditor.effective_ignores();
        effective.sort_unstable();
        assert_eq!(effective, vec!["CVE-2023-1234", "CVE-2023-ACTIVE"]);
    }

    #[test]
    fn test_with_ignored_enables_yanked_check() {
        let auditor = Auditor::with_ignored(vec!["CVE-2023-1234".to_string()]);
        assert!(auditor.config.check_yanked);
        assert!(auditor.is_ignored("CVE-2023-1234"));

        // `new()` relies on `AuditConfig::default()`, which leaves it off.
        assert!(!Auditor::new().config.check_yanked);
    }

    #[test]
    fn test_fix_result_has_changes() {
        let result = FixResult {
            updated_lockfile: Lockfile::new(),
            applied_fixes: vec![AppliedFix {
                package: "test".to_string(),
                from_version: "1.0.0".to_string(),
                to_version: "1.0.1".to_string(),
                vulnerability_id: "CVE-2023-1234".to_string(),
            }],
            failed_fixes: vec![],
            requires_force: vec![],
        };
        assert!(result.has_changes());
        assert!(!result.needs_force());
    }

    #[test]
    fn test_fix_result_needs_force() {
        let result = FixResult {
            updated_lockfile: Lockfile::new(),
            applied_fixes: vec![],
            failed_fixes: vec![],
            requires_force: vec!["dep: 1.0.0 -> 2.0.0 (transitive update)".to_string()],
        };
        assert!(!result.has_changes());
        assert!(result.needs_force());
    }
}