Skip to main content

pro_core/audit/
osv.rs

//! OSV (Open Source Vulnerabilities) API client.
//!
//! Looks up known vulnerabilities for PyPI packages through the OSV REST API.
//!
//! Documentation: https://osv.dev/docs/
4
5use serde::{Deserialize, Serialize};
6use std::collections::HashMap;
7
8use super::types::{Severity, Vulnerability};
9use crate::{Error, Result};
10
11/// OSV API base URL
12const OSV_API_URL: &str = "https://api.osv.dev/v1";
13
14/// OSV API client
15pub struct OsvClient {
16    client: reqwest::Client,
17}
18
19impl OsvClient {
20    /// Create a new OSV client
21    pub fn new() -> Self {
22        Self {
23            client: reqwest::Client::builder()
24                .user_agent("Pro/0.1.0")
25                .build()
26                .expect("Failed to create HTTP client"),
27        }
28    }
29
30    /// Query vulnerabilities for a single package
31    pub async fn query(&self, package: &str, version: &str) -> Result<Vec<Vulnerability>> {
32        let request = OsvQueryRequest {
33            package: OsvPackage {
34                name: package.to_string(),
35                ecosystem: "PyPI".to_string(),
36            },
37            version: version.to_string(),
38        };
39
40        let response = self
41            .client
42            .post(format!("{}/query", OSV_API_URL))
43            .json(&request)
44            .send()
45            .await
46            .map_err(Error::Network)?;
47
48        if !response.status().is_success() {
49            return Err(Error::Index(format!(
50                "OSV API error: HTTP {}",
51                response.status()
52            )));
53        }
54
55        let osv_response: OsvQueryResponse = response.json().await.map_err(Error::Network)?;
56
57        Ok(osv_response
58            .vulns
59            .unwrap_or_default()
60            .into_iter()
61            .map(|v| convert_osv_vuln(v, package))
62            .collect())
63    }
64
65    /// Query vulnerabilities for multiple packages in batch
66    /// Uses batch API for detection, then fetches full details for affected packages
67    pub async fn query_batch(
68        &self,
69        packages: &[(&str, &str)], // (name, version)
70    ) -> Result<HashMap<String, Vec<Vulnerability>>> {
71        if packages.is_empty() {
72            return Ok(HashMap::new());
73        }
74
75        // First, use batch API to detect which packages have vulnerabilities
76        let queries: Vec<OsvBatchQuery> = packages
77            .iter()
78            .map(|(name, version)| OsvBatchQuery {
79                package: OsvPackage {
80                    name: name.to_string(),
81                    ecosystem: "PyPI".to_string(),
82                },
83                version: version.to_string(),
84            })
85            .collect();
86
87        let request = OsvBatchRequest { queries };
88
89        let response = self
90            .client
91            .post(format!("{}/querybatch", OSV_API_URL))
92            .json(&request)
93            .send()
94            .await
95            .map_err(Error::Network)?;
96
97        if !response.status().is_success() {
98            return Err(Error::Index(format!(
99                "OSV API error: HTTP {}",
100                response.status()
101            )));
102        }
103
104        let batch_response: OsvBatchResponse = response.json().await.map_err(Error::Network)?;
105
106        // Identify packages with vulnerabilities
107        let mut vulnerable_packages: Vec<(&str, &str)> = Vec::new();
108        for (i, result) in batch_response.results.iter().enumerate() {
109            if i < packages.len() && result.vulns.as_ref().is_some_and(|v| !v.is_empty()) {
110                vulnerable_packages.push(packages[i]);
111            }
112        }
113
114        // Fetch full details for vulnerable packages using single query API
115        let mut results = HashMap::new();
116        for (name, version) in vulnerable_packages {
117            match self.query(name, version).await {
118                Ok(vulns) => {
119                    if !vulns.is_empty() {
120                        results.insert(name.to_string(), vulns);
121                    }
122                }
123                Err(e) => {
124                    tracing::warn!("Failed to fetch vulnerability details for {}: {}", name, e);
125                }
126            }
127        }
128
129        Ok(results)
130    }
131}
132
133impl Default for OsvClient {
134    fn default() -> Self {
135        Self::new()
136    }
137}
138
139/// Convert OSV vulnerability to our Vulnerability type
140fn convert_osv_vuln(osv: OsvVulnerability, package: &str) -> Vulnerability {
141    // Extract severity from database_specific or severity field
142    let (severity, cvss_score) = extract_severity(&osv);
143
144    // Find the fixed version for this package
145    // First try to match by package name, then fallback to any fixed version
146    let fixed_version = osv
147        .affected
148        .iter()
149        .filter(|a| {
150            a.package
151                .as_ref()
152                .map(|p| p.name.to_lowercase() == package.to_lowercase())
153                .unwrap_or(true) // Include entries without package info
154        })
155        .flat_map(|a| &a.ranges)
156        .flat_map(|r| &r.events)
157        .find_map(|e| e.fixed.clone());
158
159    // Collect affected versions
160    let affected_versions: Vec<String> = osv
161        .affected
162        .iter()
163        .filter(|a| {
164            a.package
165                .as_ref()
166                .map(|p| p.name.to_lowercase() == package.to_lowercase())
167                .unwrap_or(false)
168        })
169        .flat_map(|a| &a.versions)
170        .cloned()
171        .collect();
172
173    Vulnerability {
174        id: osv.id.clone(),
175        aliases: osv.aliases.unwrap_or_default(),
176        summary: osv.summary.unwrap_or_else(|| osv.id.clone()),
177        details: osv.details.unwrap_or_default(),
178        severity,
179        cvss_score,
180        package: package.to_string(),
181        affected_versions,
182        fixed_version,
183        references: osv
184            .references
185            .unwrap_or_default()
186            .into_iter()
187            .map(|r| r.url)
188            .collect(),
189        published: osv.published,
190        modified: osv.modified,
191    }
192}
193
194/// Extract severity and CVSS score from OSV vulnerability
195fn extract_severity(osv: &OsvVulnerability) -> (Severity, Option<f32>) {
196    // Try to get severity from the severity field first (CVSS score)
197    if let Some(severities) = &osv.severity {
198        for sev in severities {
199            // Try parsing as CVSS vector string (e.g., "CVSS:3.1/AV:N/AC:L/PR:N/UI:N/S:U/C:N/I:L/A:N")
200            if sev.severity_type == "CVSS_V3" || sev.severity_type == "CVSS_V2" {
201                // Try to extract base score from CVSS vector
202                if let Some(score) = parse_cvss_score(&sev.score) {
203                    let severity = cvss_to_severity(score);
204                    return (severity, Some(score));
205                }
206                // Try parsing as direct score
207                if let Ok(score) = sev.score.parse::<f32>() {
208                    let severity = cvss_to_severity(score);
209                    return (severity, Some(score));
210                }
211            }
212        }
213    }
214
215    // Try database_specific.severity (GHSA uses this)
216    if let Some(db_specific) = &osv.database_specific {
217        if let Some(severity_str) = &db_specific.severity {
218            return (severity_str.parse().unwrap_or(Severity::Unknown), None);
219        }
220        // Try CVSS score in database_specific
221        if let Some(score) = db_specific.cvss_score {
222            let severity = cvss_to_severity(score);
223            return (severity, Some(score));
224        }
225        // Try cvss object with score field
226        if let Some(cvss) = &db_specific.cvss {
227            if let Some(score) = cvss.get("score").and_then(|v| v.as_f64()) {
228                let severity = cvss_to_severity(score as f32);
229                return (severity, Some(score as f32));
230            }
231        }
232    }
233
234    // Try to extract from affected packages' severity
235    for affected in &osv.affected {
236        if let Some(sev) = &affected.database_specific {
237            if let Some(severity_str) = sev.get("severity").and_then(|v| v.as_str()) {
238                return (severity_str.parse().unwrap_or(Severity::Unknown), None);
239            }
240        }
241    }
242
243    (Severity::Unknown, None)
244}
245
/// Try to parse a CVSS score from an OSV `severity[].score` string.
///
/// Two forms are understood:
///
/// * a bare number such as `"7.5"` (surrounding whitespace allowed), returned unchanged;
/// * a CVSS v3.0 / v3.1 vector such as
///   `"CVSS:3.1/AV:N/AC:L/PR:N/UI:N/S:U/C:N/I:L/A:N"`. Its *base* score is
///   computed with the CVSS v3.1 equations; temporal and environmental
///   metrics, if present, are ignored.
///
/// Returns `None` for anything else, including CVSS v2 and v4 vectors and v3
/// vectors with a missing or invalid base metric.
fn parse_cvss_score(vector: &str) -> Option<f32> {
    let vector = vector.trim();
    if let Ok(score) = vector.parse::<f32>() {
        return Some(score);
    }
    // Base scores lie in [0, 10] with one decimal, so narrowing to f32 is lossless in practice.
    cvss3_base_score(vector).map(|score| score as f32)
}

/// Compute the base score of a CVSS v3.0/v3.1 vector (CVSS v3.1 specification, section 7.1).
fn cvss3_base_score(vector: &str) -> Option<f64> {
    const CIA: &[(&str, f64)] = &[("H", 0.56), ("L", 0.22), ("N", 0.0)];

    let metrics = vector
        .strip_prefix("CVSS:3.1/")
        .or_else(|| vector.strip_prefix("CVSS:3.0/"))?;

    let mut attack_vector: Option<f64> = None;
    let mut attack_complexity: Option<f64> = None;
    let mut privileges: Option<&str> = None;
    let mut user_interaction: Option<f64> = None;
    let mut scope_changed: Option<bool> = None;
    let mut confidentiality: Option<f64> = None;
    let mut integrity: Option<f64> = None;
    let mut availability: Option<f64> = None;

    for metric in metrics.split('/').filter(|m| !m.is_empty()) {
        let (key, value) = metric.split_once(':')?;
        match key {
            "AV" => {
                attack_vector = Some(cvss_weight(
                    value,
                    &[("N", 0.85), ("A", 0.62), ("L", 0.55), ("P", 0.2)],
                )?)
            }
            "AC" => attack_complexity = Some(cvss_weight(value, &[("L", 0.77), ("H", 0.44)])?),
            // PR's weight depends on Scope, so keep the raw code for now.
            "PR" => privileges = Some(value),
            "UI" => user_interaction = Some(cvss_weight(value, &[("N", 0.85), ("R", 0.62)])?),
            "S" => {
                scope_changed = Some(match value {
                    "U" => false,
                    "C" => true,
                    _ => return None,
                })
            }
            "C" => confidentiality = Some(cvss_weight(value, CIA)?),
            "I" => integrity = Some(cvss_weight(value, CIA)?),
            "A" => availability = Some(cvss_weight(value, CIA)?),
            // Temporal / environmental metrics do not affect the base score.
            _ => {}
        }
    }

    let scope_changed = scope_changed?;
    let privileges = match (privileges?, scope_changed) {
        ("N", _) => 0.85,
        ("L", false) => 0.62,
        ("L", true) => 0.68,
        ("H", false) => 0.27,
        ("H", true) => 0.5,
        _ => return None,
    };
    // Resolve every metric before any early return, so incomplete vectors yield `None`.
    let exploitability =
        8.22 * attack_vector? * attack_complexity? * privileges * user_interaction?;
    let iss = 1.0 - (1.0 - confidentiality?) * (1.0 - integrity?) * (1.0 - availability?);

    let impact = if scope_changed {
        7.52 * (iss - 0.029) - 3.25 * (iss - 0.02).powi(15)
    } else {
        6.42 * iss
    };
    if impact <= 0.0 {
        return Some(0.0);
    }

    let raw = if scope_changed {
        1.08 * (impact + exploitability)
    } else {
        impact + exploitability
    };
    Some(cvss_roundup(raw.min(10.0)))
}

/// Look up the numeric weight for a metric value code in `table`.
fn cvss_weight(value: &str, table: &[(&str, f64)]) -> Option<f64> {
    table
        .iter()
        .find(|(code, _)| *code == value)
        .map(|&(_, weight)| weight)
}

/// CVSS v3.1 "Roundup": the smallest one-decimal number >= `value`.
///
/// The computation runs on an integer grid to avoid floating-point artefacts
/// (CVSS v3.1 specification, Appendix A).
fn cvss_roundup(value: f64) -> f64 {
    let scaled = (value * 100_000.0).round() as i64;
    if scaled % 10_000 == 0 {
        scaled as f64 / 100_000.0
    } else {
        (scaled / 10_000 + 1) as f64 / 10.0
    }
}
255
256/// Convert CVSS score to severity level
257fn cvss_to_severity(score: f32) -> Severity {
258    match score {
259        s if s >= 9.0 => Severity::Critical,
260        s if s >= 7.0 => Severity::High,
261        s if s >= 4.0 => Severity::Medium,
262        s if s > 0.0 => Severity::Low,
263        _ => Severity::Unknown,
264    }
265}
266
267// OSV API request/response types
268
269#[derive(Debug, Serialize)]
270struct OsvQueryRequest {
271    package: OsvPackage,
272    version: String,
273}
274
275#[derive(Debug, Serialize)]
276struct OsvPackage {
277    name: String,
278    ecosystem: String,
279}
280
281#[derive(Debug, Deserialize)]
282struct OsvQueryResponse {
283    vulns: Option<Vec<OsvVulnerability>>,
284}
285
286#[derive(Debug, Serialize)]
287struct OsvBatchRequest {
288    queries: Vec<OsvBatchQuery>,
289}
290
291#[derive(Debug, Serialize)]
292struct OsvBatchQuery {
293    package: OsvPackage,
294    version: String,
295}
296
297#[derive(Debug, Deserialize)]
298struct OsvBatchResponse {
299    results: Vec<OsvBatchResult>,
300}
301
302#[derive(Debug, Deserialize)]
303struct OsvBatchResult {
304    vulns: Option<Vec<OsvVulnerability>>,
305}
306
307#[derive(Debug, Deserialize)]
308struct OsvVulnerability {
309    id: String,
310    aliases: Option<Vec<String>>,
311    summary: Option<String>,
312    details: Option<String>,
313    severity: Option<Vec<OsvSeverity>>,
314    #[serde(default)]
315    affected: Vec<OsvAffected>,
316    references: Option<Vec<OsvReference>>,
317    database_specific: Option<OsvDatabaseSpecific>,
318    published: Option<String>,
319    modified: Option<String>,
320}
321
322#[derive(Debug, Deserialize)]
323struct OsvSeverity {
324    #[serde(rename = "type")]
325    severity_type: String,
326    score: String,
327}
328
329#[derive(Debug, Deserialize)]
330struct OsvAffected {
331    package: Option<OsvAffectedPackage>,
332    #[serde(default)]
333    ranges: Vec<OsvRange>,
334    #[serde(default)]
335    versions: Vec<String>,
336    #[serde(default)]
337    database_specific: Option<serde_json::Value>,
338}
339
340#[derive(Debug, Deserialize)]
341struct OsvAffectedPackage {
342    name: String,
343    #[allow(dead_code)]
344    ecosystem: String,
345}
346
347#[derive(Debug, Deserialize)]
348struct OsvRange {
349    #[serde(rename = "type")]
350    #[allow(dead_code)]
351    range_type: String,
352    events: Vec<OsvEvent>,
353}
354
355#[derive(Debug, Deserialize)]
356struct OsvEvent {
357    #[allow(dead_code)]
358    introduced: Option<String>,
359    fixed: Option<String>,
360}
361
362#[derive(Debug, Deserialize)]
363struct OsvReference {
364    url: String,
365}
366
367#[derive(Debug, Deserialize)]
368struct OsvDatabaseSpecific {
369    severity: Option<String>,
370    #[serde(default)]
371    cvss_score: Option<f32>,
372    #[serde(default)]
373    cvss: Option<serde_json::Value>,
374}
375
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_cvss_to_severity() {
        let cases = [
            (9.5, Severity::Critical),
            (9.0, Severity::Critical),
            (8.0, Severity::High),
            (7.0, Severity::High),
            (5.0, Severity::Medium),
            (4.0, Severity::Medium),
            (2.0, Severity::Low),
            (0.0, Severity::Unknown),
            (f32::NAN, Severity::Unknown),
        ];
        for (score, expected) in cases {
            assert_eq!(cvss_to_severity(score), expected, "score {score}");
        }
    }

    #[tokio::test]
    async fn test_query_batch_empty_input() {
        // Empty input returns immediately, so this runs without network access.
        let client = OsvClient::new();
        let results = client.query_batch(&[]).await.unwrap();
        assert!(results.is_empty());
    }

    #[tokio::test]
    #[ignore] // Requires network
    async fn test_query_known_vulnerable_package() {
        let client = OsvClient::new();
        // urllib3 < 1.26.5 has known vulnerabilities
        let vulns = client.query("urllib3", "1.26.0").await.unwrap();
        // Should find at least one vulnerability, each attributed to the queried package
        assert!(!vulns.is_empty());
        assert!(vulns.iter().all(|v| v.package == "urllib3"));
    }
}