// infigraph_core/vuln/mod.rs

//! Dependency vulnerability scanning via the OSV (Open Source Vulnerabilities) API.
//!
//! Cross-references project dependencies (from `manifest::query_deps`) against
//! the OSV batch endpoint and reports known vulnerabilities together with their severity.

use std::collections::{HashMap, HashSet};

use anyhow::Result;
use serde::{Deserialize, Serialize};

use crate::manifest::DepEntry;

11// ── Public types ────────────────────────────────────────────────────────────
12
13#[derive(Debug, Clone, Serialize)]
14pub struct VulnEntry {
15    pub dep_name: String,
16    pub dep_version: String,
17    pub ecosystem: String,
18    pub vuln_id: String,
19    pub summary: String,
20    pub severity: String,
21    pub fixed_version: Option<String>,
22    pub url: String,
23}
24
25#[derive(Debug, Clone, Serialize)]
26pub struct VulnReport {
27    pub total_deps: usize,
28    pub vulnerable_deps: usize,
29    pub findings: Vec<VulnEntry>,
30}
31
32// ── OSV request/response models (serde) ─────────────────────────────────────
33
34#[derive(Debug, Serialize)]
35struct OsvBatchRequest {
36    queries: Vec<OsvQuery>,
37}
38
39#[derive(Debug, Serialize)]
40struct OsvQuery {
41    package: OsvPackage,
42    version: String,
43}
44
45#[derive(Debug, Serialize)]
46struct OsvPackage {
47    name: String,
48    ecosystem: String,
49}
50
51#[derive(Debug, Deserialize)]
52struct OsvBatchResponse {
53    results: Vec<OsvResultEntry>,
54}
55
56#[derive(Debug, Deserialize)]
57struct OsvResultEntry {
58    vulns: Option<Vec<OsvVuln>>,
59}
60
61#[derive(Debug, Deserialize)]
62struct OsvVuln {
63    id: String,
64    summary: Option<String>,
65    severity: Option<Vec<OsvSeverity>>,
66    affected: Option<Vec<OsvAffected>>,
67    references: Option<Vec<OsvReference>>,
68    database_specific: Option<OsvDatabaseSpecific>,
69}
70
71#[derive(Debug, Deserialize)]
72struct OsvSeverity {
73    #[serde(rename = "type")]
74    #[allow(dead_code)]
75    severity_type: Option<String>,
76    score: Option<String>,
77}
78
79#[derive(Debug, Deserialize)]
80struct OsvAffected {
81    ranges: Option<Vec<OsvRange>>,
82}
83
84#[derive(Debug, Deserialize)]
85struct OsvRange {
86    events: Option<Vec<OsvEvent>>,
87}
88
89#[derive(Debug, Deserialize)]
90struct OsvEvent {
91    fixed: Option<String>,
92}
93
94#[derive(Debug, Deserialize)]
95struct OsvReference {
96    #[serde(rename = "type")]
97    ref_type: Option<String>,
98    url: Option<String>,
99}
100
101#[derive(Debug, Deserialize)]
102struct OsvDatabaseSpecific {
103    severity: Option<String>,
104}
105
// ── Ecosystem mapping ───────────────────────────────────────────────────────

/// Translate a manifest-level ecosystem name into the spelling OSV expects.
///
/// Lookup is case-insensitive. Names with no entry in the table are returned
/// unchanged (original casing) so they can still be sent to OSV as-is.
fn map_ecosystem(eco: &str) -> &str {
    // (accepted aliases, OSV ecosystem name)
    const TABLE: &[(&[&str], &str)] = &[
        (&["npm"], "npm"),
        (&["cargo"], "crates.io"),
        (&["pip", "pypi"], "PyPI"),
        (&["maven", "gradle"], "Maven"),
        (&["gem"], "RubyGems"),
        (&["nuget"], "NuGet"),
        (&["go"], "Go"),
        (&["composer"], "Packagist"),
        (&["pub"], "Pub"),
    ];

    let needle = eco.to_lowercase();
    for &(aliases, osv_name) in TABLE {
        if aliases.iter().any(|alias| *alias == needle) {
            return osv_name;
        }
    }
    eco
}

// ── CVSS score extraction ───────────────────────────────────────────────────

/// Derive a severity label from an OSV `severity[].score` string.
///
/// Accepts either a bare numeric base score in `0.0..=10.0` (e.g. `"9.8"`) or a
/// CVSS v3.0/v3.1 vector (e.g. `"CVSS:3.1/AV:N/AC:L/PR:N/UI:N/S:U/C:H/I:H/A:H"`),
/// whose base score is computed with the CVSS v3.1 base-metric equations.
///
/// Returns `"UNKNOWN"` for anything else: CVSS v2/v4 vectors, incomplete or
/// malformed vectors, and NaN/infinite/out-of-range numbers. This lets the
/// caller fall back to another severity source.
fn severity_from_cvss(score_str: &str) -> &'static str {
    let score_str = score_str.trim();
    if let Ok(base) = score_str.parse::<f64>() {
        // `contains` is false for NaN, so this also rejects NaN and infinities.
        return if (0.0..=10.0).contains(&base) {
            cvss_to_label(base)
        } else {
            "UNKNOWN"
        };
    }
    cvss3_base_score(score_str).map_or("UNKNOWN", cvss_to_label)
}

/// Compute the CVSS v3.0/v3.1 base score (0.0–10.0) of `vector`.
///
/// Returns `None` unless the vector starts with `CVSS:3.0` or `CVSS:3.1` and
/// carries all eight base metrics with valid values. Temporal and
/// environmental metrics, if present, are ignored.
fn cvss3_base_score(vector: &str) -> Option<f64> {
    if !matches!(vector.split('/').next(), Some("CVSS:3.0" | "CVSS:3.1")) {
        return None;
    }

    let scope_changed = match cvss3_metric(vector, "S")? {
        "U" => false,
        "C" => true,
        _ => return None,
    };
    let attack_vector = match cvss3_metric(vector, "AV")? {
        "N" => 0.85,
        "A" => 0.62,
        "L" => 0.55,
        "P" => 0.2,
        _ => return None,
    };
    let attack_complexity = match cvss3_metric(vector, "AC")? {
        "L" => 0.77,
        "H" => 0.44,
        _ => return None,
    };
    // Privileges Required is weighted higher when the scope changes.
    let privileges = match (cvss3_metric(vector, "PR")?, scope_changed) {
        ("N", _) => 0.85,
        ("L", false) => 0.62,
        ("L", true) => 0.68,
        ("H", false) => 0.27,
        ("H", true) => 0.5,
        _ => return None,
    };
    let user_interaction = match cvss3_metric(vector, "UI")? {
        "N" => 0.85,
        "R" => 0.62,
        _ => return None,
    };
    let confidentiality = cvss3_cia_weight(cvss3_metric(vector, "C")?)?;
    let integrity = cvss3_cia_weight(cvss3_metric(vector, "I")?)?;
    let availability = cvss3_cia_weight(cvss3_metric(vector, "A")?)?;

    // Impact Sub-Score (ISS) and Impact, per the v3.1 specification.
    let iss = 1.0 - (1.0 - confidentiality) * (1.0 - integrity) * (1.0 - availability);
    let impact = if scope_changed {
        7.52 * (iss - 0.029) - 3.25 * (iss - 0.02).powi(15)
    } else {
        6.42 * iss
    };
    if impact <= 0.0 {
        return Some(0.0);
    }

    let exploitability =
        8.22 * attack_vector * attack_complexity * privileges * user_interaction;
    let combined = if scope_changed {
        1.08 * (impact + exploitability)
    } else {
        impact + exploitability
    };
    Some(cvss31_roundup(combined.min(10.0)))
}

/// Look up the value of base metric `key` (e.g. `"AV"`) in a CVSS vector.
fn cvss3_metric<'a>(vector: &'a str, key: &str) -> Option<&'a str> {
    vector
        .split('/')
        .skip(1)
        .filter_map(|part| part.split_once(':'))
        .find(|(name, _)| *name == key)
        .map(|(_, value)| value)
}

/// Weight of a Confidentiality/Integrity/Availability impact value.
fn cvss3_cia_weight(value: &str) -> Option<f64> {
    match value {
        "H" => Some(0.56),
        "L" => Some(0.22),
        "N" => Some(0.0),
        _ => None,
    }
}

/// CVSS v3.1 `Roundup`: the smallest one-decimal number >= `value`.
///
/// The computation is done on an integer grid so that floating-point noise
/// (e.g. 4.000000001) does not push a score up a tenth. `value` is always
/// within 0.0..=10.0 here, so the `i64` cast cannot overflow.
fn cvss31_roundup(value: f64) -> f64 {
    let scaled = (value * 100_000.0).round() as i64;
    if scaled % 10_000 == 0 {
        scaled as f64 / 100_000.0
    } else {
        ((scaled / 10_000) + 1) as f64 / 10.0
    }
}

/// Map a CVSS base score to a qualitative label: 9.0+ `CRITICAL`,
/// 7.0+ `HIGH`, 4.0+ `MEDIUM`, anything lower (including 0.0) `LOW`.
fn cvss_to_label(base: f64) -> &'static str {
    if base >= 9.0 {
        "CRITICAL"
    } else if base >= 7.0 {
        "HIGH"
    } else if base >= 4.0 {
        "MEDIUM"
    } else {
        "LOW"
    }
}

165/// Determine severity label for a single OSV vuln entry.
166fn extract_severity(vuln: &OsvVuln) -> String {
167    // 1. Try CVSS scores from the severity array
168    if let Some(ref sev_list) = vuln.severity {
169        for s in sev_list {
170            if let Some(ref score) = s.score {
171                let label = severity_from_cvss(score);
172                if label != "UNKNOWN" {
173                    return label.to_string();
174                }
175            }
176        }
177    }
178    // 2. Try database_specific.severity
179    if let Some(ref db) = vuln.database_specific {
180        if let Some(ref sev) = db.severity {
181            return sev.to_uppercase();
182        }
183    }
184    "UNKNOWN".to_string()
185}
186
187/// Extract the first fixed version from affected ranges.
188fn extract_fixed_version(vuln: &OsvVuln) -> Option<String> {
189    if let Some(ref affected) = vuln.affected {
190        for a in affected {
191            if let Some(ref ranges) = a.ranges {
192                for r in ranges {
193                    if let Some(ref events) = r.events {
194                        for e in events {
195                            if let Some(ref fixed) = e.fixed {
196                                return Some(fixed.clone());
197                            }
198                        }
199                    }
200                }
201            }
202        }
203    }
204    None
205}
206
207/// Extract the best advisory URL from references.
208fn extract_url(vuln: &OsvVuln) -> String {
209    if let Some(ref refs) = vuln.references {
210        // Prefer ADVISORY type
211        for r in refs {
212            if r.ref_type.as_deref() == Some("ADVISORY") {
213                if let Some(ref url) = r.url {
214                    return url.clone();
215                }
216            }
217        }
218        // Fall back to first URL
219        for r in refs {
220            if let Some(ref url) = r.url {
221                return url.clone();
222            }
223        }
224    }
225    format!("https://osv.dev/vulnerability/{}", vuln.id)
226}
227
// ── Version cleaning ────────────────────────────────────────────────────────

/// Reduce a manifest version requirement to a single concrete version string.
///
/// Everything before the first ASCII digit is dropped (operators such as `^`,
/// `~`, `>=`, `==`, `~>` or a `v` prefix). Everything from the first whitespace,
/// `,`, `|`, `;`, `)` or `]` onward is also dropped. So compound requirements like
/// `">=1.2, <2"` or `"^1.0 || ^2.0"` yield their first version (`"1.2"`, `"1.0"`).
/// Returns `""` when the input contains no digit.
fn clean_version(version: &str) -> &str {
    let start = version.trim_start_matches(|c: char| !c.is_ascii_digit());
    let end = start
        .find(|c: char| c.is_whitespace() || matches!(c, ',' | '|' | ';' | ')' | ']'))
        .unwrap_or(start.len());
    &start[..end]
}

/// Returns true if `version` reduces to a concrete version usable in an OSV query.
///
/// Rejects inputs without any digit and wildcard versions such as `*`, `1.*`
/// or `1.x`, which do not name a single version.
fn is_valid_version(version: &str) -> bool {
    let cleaned = clean_version(version);
    !cleaned.is_empty()
        && !cleaned
            .split('.')
            .any(|part| matches!(part, "*" | "x" | "X"))
}

241// ── OSV API client ──────────────────────────────────────────────────────────
242
243const OSV_BATCH_URL: &str = "https://api.osv.dev/v1/querybatch";
244const OSV_BATCH_SIZE: usize = 1000;
245
246/// Send a batch of queries to the OSV API and return per-query results.
247fn query_osv_batch(queries: &[OsvQuery]) -> Result<Vec<OsvResultEntry>> {
248    if queries.is_empty() {
249        return Ok(Vec::new());
250    }
251
252    let body = OsvBatchRequest {
253        queries: queries
254            .iter()
255            .map(|q| OsvQuery {
256                package: OsvPackage {
257                    name: q.package.name.clone(),
258                    ecosystem: q.package.ecosystem.clone(),
259                },
260                version: q.version.clone(),
261            })
262            .collect(),
263    };
264
265    let body_json = serde_json::to_string(&body)?;
266
267    let resp = ureq::post(OSV_BATCH_URL)
268        .set("Content-Type", "application/json")
269        .send_string(&body_json);
270
271    match resp {
272        Ok(response) => {
273            let text = response.into_string()?;
274            let batch_resp: OsvBatchResponse = serde_json::from_str(&text)?;
275            Ok(batch_resp.results)
276        }
277        Err(e) => {
278            eprintln!("Warning: OSV API request failed: {e}");
279            // Return empty results for each query so indexing is preserved
280            Ok(queries
281                .iter()
282                .map(|_| OsvResultEntry { vulns: None })
283                .collect())
284        }
285    }
286}
287
288// ── Main scan entry point ───────────────────────────────────────────────────
289
290/// Scan a list of dependencies against the OSV vulnerability database.
291///
292/// Maps ecosystem names, batches queries (up to 1000 per request), parses
293/// responses, and returns a structured report.
294pub fn scan_deps(deps: &[DepEntry]) -> Result<VulnReport> {
295    // Build queries, filtering out deps with unusable versions
296    let valid_deps: Vec<&DepEntry> = deps
297        .iter()
298        .filter(|d| is_valid_version(&d.version))
299        .collect();
300
301    let queries: Vec<OsvQuery> = valid_deps
302        .iter()
303        .map(|d| OsvQuery {
304            package: OsvPackage {
305                name: d.name.clone(),
306                ecosystem: map_ecosystem(&d.ecosystem).to_string(),
307            },
308            version: clean_version(&d.version).to_string(),
309        })
310        .collect();
311
312    // Send in batches
313    let mut all_results: Vec<OsvResultEntry> = Vec::with_capacity(queries.len());
314    for chunk in queries.chunks(OSV_BATCH_SIZE) {
315        let batch_results = query_osv_batch(chunk)?;
316        all_results.extend(batch_results);
317    }
318
319    // Parse results
320    let mut findings = Vec::new();
321    let mut vulnerable_dep_names = std::collections::HashSet::new();
322
323    for (i, result) in all_results.iter().enumerate() {
324        if i >= valid_deps.len() {
325            break;
326        }
327        let dep = valid_deps[i];
328
329        if let Some(ref vulns) = result.vulns {
330            for vuln in vulns {
331                vulnerable_dep_names.insert(format!("{}@{}", dep.name, dep.version));
332                findings.push(VulnEntry {
333                    dep_name: dep.name.clone(),
334                    dep_version: clean_version(&dep.version).to_string(),
335                    ecosystem: dep.ecosystem.clone(),
336                    vuln_id: vuln.id.clone(),
337                    summary: vuln.summary.clone().unwrap_or_default(),
338                    severity: extract_severity(vuln),
339                    fixed_version: extract_fixed_version(vuln),
340                    url: extract_url(vuln),
341                });
342            }
343        }
344    }
345
346    // Sort findings: CRITICAL first, then HIGH, MEDIUM, LOW, UNKNOWN
347    findings.sort_by(|a, b| {
348        severity_rank(&a.severity)
349            .cmp(&severity_rank(&b.severity))
350            .then(a.dep_name.cmp(&b.dep_name))
351    });
352
353    Ok(VulnReport {
354        total_deps: deps.len(),
355        vulnerable_deps: vulnerable_dep_names.len(),
356        findings,
357    })
358}
359
/// Sort rank for a severity label; lower means more severe.
///
/// Ranks are:
/// - `CRITICAL` → 0
/// - `HIGH` → 1
/// - `MEDIUM`, or its GitHub-advisory spelling `MODERATE` → 2
/// - `LOW` → 3
/// - anything else, including `UNKNOWN` → 4
///
/// Matching is exact and case-sensitive; `filter_by_severity` upper-cases
/// user input before calling this.
fn severity_rank(s: &str) -> u8 {
    match s {
        "CRITICAL" => 0,
        "HIGH" => 1,
        "MEDIUM" | "MODERATE" => 2,
        "LOW" => 3,
        _ => 4,
    }
}

370/// Filter report findings by minimum severity.
371pub fn filter_by_severity(report: &mut VulnReport, min_severity: &str) {
372    let min_rank = severity_rank(&min_severity.to_uppercase());
373    report
374        .findings
375        .retain(|f| severity_rank(&f.severity) <= min_rank);
376    let mut names = std::collections::HashSet::new();
377    for f in &report.findings {
378        names.insert(format!("{}@{}", f.dep_name, f.dep_version));
379    }
380    report.vulnerable_deps = names.len();
381}
382
383/// Filter report findings by ecosystem.
384pub fn filter_by_ecosystem(report: &mut VulnReport, ecosystem: &str) {
385    report
386        .findings
387        .retain(|f| f.ecosystem.eq_ignore_ascii_case(ecosystem));
388    let mut names = std::collections::HashSet::new();
389    for f in &report.findings {
390        names.insert(format!("{}@{}", f.dep_name, f.dep_version));
391    }
392    report.vulnerable_deps = names.len();
393}
394
395/// Format the report as a human-readable table.
396pub fn format_table(report: &VulnReport) -> String {
397    if report.findings.is_empty() {
398        return format!(
399            "Vulnerability Scan Results\n\n  No vulnerabilities found ({} dependencies scanned)\n",
400            report.total_deps
401        );
402    }
403
404    let mut out = String::from("Vulnerability Scan Results\n\n");
405
406    // Header
407    out.push_str(&format!(
408        "  {:<20} {:<12} {:<18} {:<10} {}\n",
409        "Dep", "Version", "Vuln ID", "Severity", "Summary"
410    ));
411
412    // Findings
413    for f in &report.findings {
414        let summary_truncated = if f.summary.len() > 60 {
415            format!("{}...", &f.summary[..57])
416        } else {
417            f.summary.clone()
418        };
419        out.push_str(&format!(
420            "  {:<20} {:<12} {:<18} {:<10} {}\n",
421            truncate_str(&f.dep_name, 20),
422            truncate_str(&f.dep_version, 12),
423            truncate_str(&f.vuln_id, 18),
424            &f.severity,
425            summary_truncated,
426        ));
427    }
428
429    out.push_str(&format!(
430        "\n  {} vulnerable dependencies found (out of {} scanned)\n",
431        report.vulnerable_deps, report.total_deps
432    ));
433
434    out
435}
436
437/// Format the report as JSON.
438pub fn format_json(report: &VulnReport) -> String {
439    serde_json::to_string_pretty(report).unwrap_or_else(|_| "{}".to_string())
440}
441
/// Shorten `s` to at most `max` characters, marking the cut with `...`.
///
/// Lengths are counted in `char`s and the cut always lands on a character
/// boundary, so multi-byte UTF-8 input cannot cause a slicing panic. When
/// `max` is smaller than the 3-character ellipsis, the first `max` characters
/// are returned without one; the old `max - 3` would underflow in that case.
fn truncate_str(s: &str, max: usize) -> String {
    const ELLIPSIS: &str = "...";
    if s.chars().count() <= max {
        return s.to_string();
    }
    if max < ELLIPSIS.len() {
        return s.chars().take(max).collect();
    }
    // Byte offset of the first char that no longer fits before the ellipsis.
    let keep = max - ELLIPSIS.len();
    let cut = s.char_indices().nth(keep).map_or(s.len(), |(idx, _)| idx);
    format!("{}{}", &s[..cut], ELLIPSIS)
}