use std::time::Duration;
use serde::Deserialize;
use tracing::warn;
use crate::types::{Finding, FindingCategory, Severity};
/// OSV single-package query endpoint; `CveChecker::check` POSTs one package/version query here.
const OSV_API_URL: &str = "https://api.osv.dev/v1/query";
/// Per-request timeout configured on the HTTP client built in `CveChecker::new`.
const REQUEST_TIMEOUT: Duration = Duration::from_secs(10);
/// Top-level body returned by the OSV query endpoint, as deserialized in `CveChecker::check`.
#[derive(Debug, Deserialize)]
struct OsvResponse {
// Missing `vulns` deserializes to an empty list. Presumably OSV omits the key when the
// package has no known vulnerabilities; confirm against the OSV API docs.
#[serde(default)]
vulns: Vec<OsvVuln>,
}
/// One vulnerability record from the OSV response; converted to a `Finding` by
/// `CveChecker::vuln_to_finding`.
#[derive(Debug, Deserialize)]
struct OsvVuln {
// Identifier shown in the finding title/description and used to build the osv.dev fallback link.
id: String,
// Preferred human-readable description text.
#[serde(default)]
summary: Option<String>,
// Fallback description text when `summary` is absent.
#[serde(default)]
details: Option<String>,
// Scored severity entries; inspected in order by `determine_severity`.
#[serde(default)]
severity: Vec<OsvSeverity>,
// External links; the first one with a URL becomes the finding's "More info" link.
#[serde(default)]
references: Vec<OsvReference>,
}
/// A single severity entry attached to an OSV vulnerability.
#[derive(Debug, Deserialize)]
struct OsvSeverity {
// JSON key is `type` (a Rust keyword). NOTE(review): per the OSV schema this names the
// scoring system (e.g. `CVSS_V3`) rather than a severity level; confirm.
#[serde(rename = "type", default)]
severity_type: Option<String>,
// Either a bare number, a CVSS vector string, or free text; `determine_severity` tries each
// interpretation.
#[serde(default)]
score: Option<String>,
}
/// An external reference (advisory, fix commit, etc.) listed on an OSV vulnerability.
#[derive(Debug, Deserialize)]
struct OsvReference {
// Deserialized for completeness/debug output but never read in this module (only `url` is
// used), hence the lint allowance.
#[allow(dead_code)]
#[serde(rename = "type", default)]
ref_type: Option<String>,
// Link target; may be absent in the payload.
#[serde(default)]
url: Option<String>,
}
/// Looks up known vulnerabilities for npm packages via the OSV API.
pub struct CveChecker {
// HTTP client used for every OSV query; constructed with `REQUEST_TIMEOUT` in `new`.
client: reqwest::Client,
}
impl CveChecker {
    /// Builds a checker whose HTTP client applies [`REQUEST_TIMEOUT`] to each request.
    ///
    /// If the configured builder fails, this falls back to `reqwest::Client::default()`,
    /// which does not carry the timeout.
    pub fn new() -> Self {
        let client = reqwest::Client::builder()
            .timeout(REQUEST_TIMEOUT)
            .build()
            .unwrap_or_default();
        Self { client }
    }

    /// Queries the OSV API for vulnerabilities affecting `name@version` in the npm ecosystem.
    ///
    /// Returns one [`Finding`] per vulnerability in the response. Every failure mode (a
    /// transport error, a non-success HTTP status, or a body that does not deserialize) is
    /// logged via `warn!` and produces an empty vector. Callers therefore cannot distinguish
    /// "lookup failed" from "no known vulnerabilities".
    pub async fn check(&self, name: &str, version: &str) -> Vec<Finding> {
        let body = serde_json::json!({
            "package": {
                "name": name,
                "ecosystem": "npm"
            },
            "version": version
        });
        let response = match self.client.post(OSV_API_URL).json(&body).send().await {
            Ok(r) => r,
            Err(e) => {
                warn!("OSV API request failed for {name}@{version}: {e}");
                return Vec::new();
            }
        };
        let status = response.status();
        if !status.is_success() {
            warn!("OSV API returned status {status} for {name}@{version}");
            return Vec::new();
        }
        let osv: OsvResponse = match response.json().await {
            Ok(r) => r,
            Err(e) => {
                warn!("Failed to parse OSV response for {name}@{version}: {e}");
                return Vec::new();
            }
        };
        osv.vulns
            .into_iter()
            .map(|vuln| self.vuln_to_finding(name, version, vuln))
            .collect()
    }

    /// Converts one OSV vulnerability record into a [`Finding`] for `name@version`.
    ///
    /// The description text is the first non-blank of `summary` and `details`, or a
    /// placeholder if neither has content. The title uses the first non-blank line of that
    /// text, truncated to 80 characters. The link is the first non-blank reference URL, falling
    /// back to the vulnerability's osv.dev page.
    fn vuln_to_finding(&self, name: &str, version: &str, vuln: OsvVuln) -> Finding {
        /// Treats whitespace-only text the same as missing text.
        fn non_blank(text: Option<&str>) -> Option<&str> {
            text.filter(|t| !t.trim().is_empty())
        }

        let severity = self.determine_severity(&vuln);
        // A present-but-empty `summary` must not hide a useful `details`.
        let summary = non_blank(vuln.summary.as_deref())
            .or_else(|| non_blank(vuln.details.as_deref()))
            .unwrap_or("No description available");
        // `details` may be multi-paragraph text; keep the title to a single line.
        let headline = summary
            .lines()
            .map(str::trim)
            .find(|line| !line.is_empty())
            .unwrap_or(summary);
        let link = vuln
            .references
            .iter()
            .find_map(|r| non_blank(r.url.as_deref()))
            .map_or_else(
                || format!("https://osv.dev/vulnerability/{}", vuln.id),
                str::to_owned,
            );
        Finding {
            severity,
            category: FindingCategory::KnownVulnerability,
            title: format!("{}: {}", vuln.id, truncate(headline, 80)),
            description: format!(
                "Package {name}@{version} is affected by {}.\n\n{summary}\n\nMore info: {link}",
                vuln.id
            ),
            file: None,
            line: None,
            snippet: None,
        }
    }

    /// Maps the vulnerability's OSV `severity` entries to a [`Severity`].
    ///
    /// Entries are examined in order, and the first one that yields a rating wins:
    /// - a numeric `score` is bucketed by CVSS thresholds (see `severity_from_cvss`);
    /// - otherwise, a textual `score`, then the entry's `type`, is matched against severity
    ///   keywords (see `severity_from_label`).
    ///
    /// Returns `Severity::Medium` when no entry matches. This includes entries whose score is
    /// a CVSS vector string, because `parse_cvss_score` does not decode vectors.
    fn determine_severity(&self, vuln: &OsvVuln) -> Severity {
        for sev in &vuln.severity {
            if let Some(score) = sev.score.as_deref() {
                if let Some(cvss) = parse_cvss_score(score) {
                    return Self::severity_from_cvss(cvss);
                }
                if let Some(severity) = Self::severity_from_label(score) {
                    return severity;
                }
            }
            let type_rating = sev.severity_type.as_deref().and_then(Self::severity_from_label);
            if let Some(severity) = type_rating {
                return severity;
            }
        }
        Severity::Medium
    }

    /// Buckets a numeric CVSS base score.
    ///
    /// Thresholds: >= 9.0 critical, >= 7.0 high, >= 4.0 medium, anything lower is low.
    fn severity_from_cvss(score: f64) -> Severity {
        if score >= 9.0 {
            Severity::Critical
        } else if score >= 7.0 {
            Severity::High
        } else if score >= 4.0 {
            Severity::Medium
        } else {
            Severity::Low
        }
    }

    /// Matches a free-text severity label (e.g. `"HIGH"`, `"moderate"`) case-insensitively.
    ///
    /// Keywords are tried from most to least severe. Returns `None` when the text contains
    /// none of them.
    fn severity_from_label(label: &str) -> Option<Severity> {
        let upper = label.to_uppercase();
        if upper.contains("CRITICAL") {
            Some(Severity::Critical)
        } else if upper.contains("HIGH") {
            Some(Severity::High)
        } else if upper.contains("MODERATE") || upper.contains("MEDIUM") {
            Some(Severity::Medium)
        } else if upper.contains("LOW") {
            // Previously missing: a "LOW" label fell through to the Medium default.
            Some(Severity::Low)
        } else {
            None
        }
    }
}
/// Parses a bare numeric CVSS base score such as `"9.8"`.
///
/// Surrounding whitespace is ignored. Returns `None` for anything that is not a number in the
/// CVSS range `0.0..=10.0`. That covers:
/// - CVSS vector strings (`"CVSS:3.1/AV:N/..."`);
/// - `NaN` and infinities, which `f64::from_str` would otherwise accept;
/// - out-of-range values.
fn parse_cvss_score(s: &str) -> Option<f64> {
    s.trim()
        .parse::<f64>()
        .ok()
        // The range check also rejects NaN (all comparisons false) and +/-infinity.
        .filter(|score| (0.0..=10.0).contains(score))
}
/// Shortens `s` to at most `max` characters (Unicode scalar values), appending `"..."`
/// only when characters were actually dropped.
///
/// The limit counts `char`s rather than bytes. Multi-byte text is therefore never split
/// mid-character, and it is not marked as truncated merely because its byte length exceeds
/// `max`.
fn truncate(s: &str, max: usize) -> String {
    match s.char_indices().nth(max) {
        // A char exists at position `max`, so `s` has more than `max` chars; cut just before it.
        Some((cut, _)) => format!("{}...", &s[..cut]),
        None => s.to_string(),
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_parse_cvss_score_plain() {
        // Bare numeric scores should pass straight through.
        for (input, expected) in [("9.8", 9.8), ("7.0", 7.0)] {
            assert_eq!(parse_cvss_score(input), Some(expected));
        }
    }

    #[test]
    fn test_parse_cvss_score_vector() {
        // A vector string carries no precomputed score, so it is not parsed as one.
        let vector = "CVSS:3.1/AV:N/AC:L/PR:N/UI:N/S:U/C:H/I:H/A:H";
        assert!(parse_cvss_score(vector).is_none());
    }

    #[test]
    fn test_truncate_short() {
        let out = truncate("hello", 10);
        assert_eq!(out, "hello");
    }

    #[test]
    fn test_truncate_long() {
        // Ten retained characters plus the three-character ellipsis.
        let truncated = truncate(&"a".repeat(100), 10);
        assert_eq!(truncated.len(), 13);
    }
}