1use serde::{Deserialize, Serialize};
6use std::collections::HashMap;
7
8use super::types::{Severity, Vulnerability};
9use crate::{Error, Result};
10
11const OSV_API_URL: &str = "https://api.osv.dev/v1";
13
14pub struct OsvClient {
16 client: reqwest::Client,
17}
18
19impl OsvClient {
20 pub fn new() -> Self {
22 Self {
23 client: reqwest::Client::builder()
24 .user_agent("Pro/0.1.0")
25 .build()
26 .expect("Failed to create HTTP client"),
27 }
28 }
29
30 pub async fn query(&self, package: &str, version: &str) -> Result<Vec<Vulnerability>> {
32 let request = OsvQueryRequest {
33 package: OsvPackage {
34 name: package.to_string(),
35 ecosystem: "PyPI".to_string(),
36 },
37 version: version.to_string(),
38 };
39
40 let response = self
41 .client
42 .post(format!("{}/query", OSV_API_URL))
43 .json(&request)
44 .send()
45 .await
46 .map_err(Error::Network)?;
47
48 if !response.status().is_success() {
49 return Err(Error::Index(format!(
50 "OSV API error: HTTP {}",
51 response.status()
52 )));
53 }
54
55 let osv_response: OsvQueryResponse = response.json().await.map_err(Error::Network)?;
56
57 Ok(osv_response
58 .vulns
59 .unwrap_or_default()
60 .into_iter()
61 .map(|v| convert_osv_vuln(v, package))
62 .collect())
63 }
64
65 pub async fn query_batch(
68 &self,
69 packages: &[(&str, &str)], ) -> Result<HashMap<String, Vec<Vulnerability>>> {
71 if packages.is_empty() {
72 return Ok(HashMap::new());
73 }
74
75 let queries: Vec<OsvBatchQuery> = packages
77 .iter()
78 .map(|(name, version)| OsvBatchQuery {
79 package: OsvPackage {
80 name: name.to_string(),
81 ecosystem: "PyPI".to_string(),
82 },
83 version: version.to_string(),
84 })
85 .collect();
86
87 let request = OsvBatchRequest { queries };
88
89 let response = self
90 .client
91 .post(format!("{}/querybatch", OSV_API_URL))
92 .json(&request)
93 .send()
94 .await
95 .map_err(Error::Network)?;
96
97 if !response.status().is_success() {
98 return Err(Error::Index(format!(
99 "OSV API error: HTTP {}",
100 response.status()
101 )));
102 }
103
104 let batch_response: OsvBatchResponse = response.json().await.map_err(Error::Network)?;
105
106 let mut vulnerable_packages: Vec<(&str, &str)> = Vec::new();
108 for (i, result) in batch_response.results.iter().enumerate() {
109 if i < packages.len() && result.vulns.as_ref().is_some_and(|v| !v.is_empty()) {
110 vulnerable_packages.push(packages[i]);
111 }
112 }
113
114 let mut results = HashMap::new();
116 for (name, version) in vulnerable_packages {
117 match self.query(name, version).await {
118 Ok(vulns) => {
119 if !vulns.is_empty() {
120 results.insert(name.to_string(), vulns);
121 }
122 }
123 Err(e) => {
124 tracing::warn!("Failed to fetch vulnerability details for {}: {}", name, e);
125 }
126 }
127 }
128
129 Ok(results)
130 }
131}
132
133impl Default for OsvClient {
134 fn default() -> Self {
135 Self::new()
136 }
137}
138
139fn convert_osv_vuln(osv: OsvVulnerability, package: &str) -> Vulnerability {
141 let (severity, cvss_score) = extract_severity(&osv);
143
144 let fixed_version = osv
147 .affected
148 .iter()
149 .filter(|a| {
150 a.package
151 .as_ref()
152 .map(|p| p.name.to_lowercase() == package.to_lowercase())
153 .unwrap_or(true) })
155 .flat_map(|a| &a.ranges)
156 .flat_map(|r| &r.events)
157 .find_map(|e| e.fixed.clone());
158
159 let affected_versions: Vec<String> = osv
161 .affected
162 .iter()
163 .filter(|a| {
164 a.package
165 .as_ref()
166 .map(|p| p.name.to_lowercase() == package.to_lowercase())
167 .unwrap_or(false)
168 })
169 .flat_map(|a| &a.versions)
170 .cloned()
171 .collect();
172
173 Vulnerability {
174 id: osv.id.clone(),
175 aliases: osv.aliases.unwrap_or_default(),
176 summary: osv.summary.unwrap_or_else(|| osv.id.clone()),
177 details: osv.details.unwrap_or_default(),
178 severity,
179 cvss_score,
180 package: package.to_string(),
181 affected_versions,
182 fixed_version,
183 references: osv
184 .references
185 .unwrap_or_default()
186 .into_iter()
187 .map(|r| r.url)
188 .collect(),
189 published: osv.published,
190 modified: osv.modified,
191 }
192}
193
194fn extract_severity(osv: &OsvVulnerability) -> (Severity, Option<f32>) {
196 if let Some(severities) = &osv.severity {
198 for sev in severities {
199 if sev.severity_type == "CVSS_V3" || sev.severity_type == "CVSS_V2" {
201 if let Some(score) = parse_cvss_score(&sev.score) {
203 let severity = cvss_to_severity(score);
204 return (severity, Some(score));
205 }
206 if let Ok(score) = sev.score.parse::<f32>() {
208 let severity = cvss_to_severity(score);
209 return (severity, Some(score));
210 }
211 }
212 }
213 }
214
215 if let Some(db_specific) = &osv.database_specific {
217 if let Some(severity_str) = &db_specific.severity {
218 return (severity_str.parse().unwrap_or(Severity::Unknown), None);
219 }
220 if let Some(score) = db_specific.cvss_score {
222 let severity = cvss_to_severity(score);
223 return (severity, Some(score));
224 }
225 if let Some(cvss) = &db_specific.cvss {
227 if let Some(score) = cvss.get("score").and_then(|v| v.as_f64()) {
228 let severity = cvss_to_severity(score as f32);
229 return (severity, Some(score as f32));
230 }
231 }
232 }
233
234 for affected in &osv.affected {
236 if let Some(sev) = &affected.database_specific {
237 if let Some(severity_str) = sev.get("severity").and_then(|v| v.as_str()) {
238 return (severity_str.parse().unwrap_or(Severity::Unknown), None);
239 }
240 }
241 }
242
243 (Severity::Unknown, None)
244}
245
/// Parses a CVSS score from the `score` field of an OSV `severity` entry.
///
/// Surrounding whitespace is trimmed first. The following forms are accepted:
/// * A bare numeric score such as `"7.5"`, returned as-is.
/// * A CVSS v3.x vector (`CVSS:3.0/…` or `CVSS:3.1/…`), scored with the
///   CVSS v3.1 base equations.
/// * A CVSS v2 vector (`AV:N/AC:L/Au:N/C:P/I:P/A:P`, optionally wrapped in
///   parentheses), scored with the CVSS v2 base equation.
///
/// Only base metrics are used; temporal and environmental metrics are ignored.
/// Returns `None` for anything else, including vectors that lack a base metric,
/// vectors with unrecognised metric values, and CVSS v4 vectors.
fn parse_cvss_score(vector: &str) -> Option<f32> {
    let vector = vector.trim();
    if let Ok(score) = vector.parse::<f32>() {
        return Some(score);
    }

    let score = if vector.starts_with("CVSS:3.") {
        cvss_v3_base_score(vector)?
    } else if vector.contains("Au:") {
        // `Au` (authentication) only exists in CVSS v2.
        cvss_v2_base_score(vector.trim_start_matches('(').trim_end_matches(')'))?
    } else {
        return None;
    };
    Some(score as f32)
}

/// Returns the value of metric `key` in a `/`-separated `KEY:VALUE` vector.
fn cvss_metric<'a>(vector: &'a str, key: &str) -> Option<&'a str> {
    vector
        .split('/')
        .filter_map(|part| part.split_once(':'))
        .find(|&(name, _)| name == key)
        .map(|(_, value)| value)
}

/// Computes the CVSS v3.1 base score for a v3.x vector string.
fn cvss_v3_base_score(vector: &str) -> Option<f64> {
    let scope_changed = match cvss_metric(vector, "S")? {
        "U" => false,
        "C" => true,
        _ => return None,
    };
    let attack_vector: f64 = match cvss_metric(vector, "AV")? {
        "N" => 0.85,
        "A" => 0.62,
        "L" => 0.55,
        "P" => 0.2,
        _ => return None,
    };
    let attack_complexity: f64 = match cvss_metric(vector, "AC")? {
        "L" => 0.77,
        "H" => 0.44,
        _ => return None,
    };
    // Privileges-required weights depend on whether the scope changes.
    let privileges: f64 = match (cvss_metric(vector, "PR")?, scope_changed) {
        ("N", _) => 0.85,
        ("L", false) => 0.62,
        ("L", true) => 0.68,
        ("H", false) => 0.27,
        ("H", true) => 0.5,
        _ => return None,
    };
    let user_interaction: f64 = match cvss_metric(vector, "UI")? {
        "N" => 0.85,
        "R" => 0.62,
        _ => return None,
    };
    let impact_weight = |key: &str| -> Option<f64> {
        match cvss_metric(vector, key)? {
            "H" => Some(0.56),
            "L" => Some(0.22),
            "N" => Some(0.0),
            _ => None,
        }
    };
    let confidentiality = impact_weight("C")?;
    let integrity = impact_weight("I")?;
    let availability = impact_weight("A")?;

    let iss = 1.0 - (1.0 - confidentiality) * (1.0 - integrity) * (1.0 - availability);
    let impact = if scope_changed {
        7.52 * (iss - 0.029) - 3.25 * (iss - 0.02).powi(15)
    } else {
        6.42 * iss
    };
    if impact <= 0.0 {
        return Some(0.0);
    }

    let exploitability =
        8.22 * attack_vector * attack_complexity * privileges * user_interaction;
    let combined = if scope_changed {
        1.08 * (impact + exploitability)
    } else {
        impact + exploitability
    };
    Some(cvss31_roundup(combined.min(10.0)))
}

/// CVSS v3.1 `Roundup`: the smallest one-decimal number >= `value`.
///
/// It works in integer hundred-thousandths to avoid floating-point artefacts.
fn cvss31_roundup(value: f64) -> f64 {
    let scaled = (value * 100_000.0).round() as i64;
    if scaled % 10_000 == 0 {
        scaled as f64 / 100_000.0
    } else {
        (scaled / 10_000 + 1) as f64 / 10.0
    }
}

/// Computes the CVSS v2 base score for a v2 vector string.
fn cvss_v2_base_score(vector: &str) -> Option<f64> {
    let access_vector: f64 = match cvss_metric(vector, "AV")? {
        "L" => 0.395,
        "A" => 0.646,
        "N" => 1.0,
        _ => return None,
    };
    let access_complexity: f64 = match cvss_metric(vector, "AC")? {
        "H" => 0.35,
        "M" => 0.61,
        "L" => 0.71,
        _ => return None,
    };
    let authentication: f64 = match cvss_metric(vector, "Au")? {
        "M" => 0.45,
        "S" => 0.56,
        "N" => 0.704,
        _ => return None,
    };
    let impact_weight = |key: &str| -> Option<f64> {
        match cvss_metric(vector, key)? {
            "N" => Some(0.0),
            "P" => Some(0.275),
            "C" => Some(0.66),
            _ => None,
        }
    };
    let confidentiality = impact_weight("C")?;
    let integrity = impact_weight("I")?;
    let availability = impact_weight("A")?;

    let impact =
        10.41 * (1.0 - (1.0 - confidentiality) * (1.0 - integrity) * (1.0 - availability));
    let exploitability = 20.0 * access_vector * access_complexity * authentication;
    let f_impact = if impact <= 0.0 { 0.0 } else { 1.176 };
    let score = (0.6 * impact + 0.4 * exploitability - 1.5) * f_impact;
    Some((score * 10.0).round() / 10.0)
}
255
256fn cvss_to_severity(score: f32) -> Severity {
258 match score {
259 s if s >= 9.0 => Severity::Critical,
260 s if s >= 7.0 => Severity::High,
261 s if s >= 4.0 => Severity::Medium,
262 s if s > 0.0 => Severity::Low,
263 _ => Severity::Unknown,
264 }
265}
266
/// JSON body for `POST {OSV_API_URL}/query`: one package at one version.
#[derive(Debug, Serialize)]
struct OsvQueryRequest {
    /// Package being looked up.
    package: OsvPackage,
    /// Version string to check.
    version: String,
}
274
/// Package identifier sent to OSV inside query requests.
#[derive(Debug, Serialize)]
struct OsvPackage {
    /// Package name as supplied by the caller.
    name: String,
    /// OSV ecosystem name; this module always sends `"PyPI"`.
    ecosystem: String,
}
280
/// Response body of the single-package `/query` endpoint.
#[derive(Debug, Deserialize)]
struct OsvQueryResponse {
    /// Matching advisories. A missing or null list is treated by the caller as
    /// "no vulnerabilities".
    vulns: Option<Vec<OsvVulnerability>>,
}
285
/// JSON body for `POST {OSV_API_URL}/querybatch`.
#[derive(Debug, Serialize)]
struct OsvBatchRequest {
    /// One query per package. The response's `results` are matched back to
    /// these queries by position.
    queries: Vec<OsvBatchQuery>,
}
290
/// One entry of a batch request; it serialises to the same shape as
/// `OsvQueryRequest`.
#[derive(Debug, Serialize)]
struct OsvBatchQuery {
    /// Package being screened.
    package: OsvPackage,
    /// Version string to check.
    version: String,
}
296
/// Response body of `/querybatch`.
#[derive(Debug, Deserialize)]
struct OsvBatchResponse {
    /// Per-query results. `query_batch` assumes `results[i]` answers the i-th
    /// submitted query.
    results: Vec<OsvBatchResult>,
}
301
/// Result for one query within a batch response.
#[derive(Debug, Deserialize)]
struct OsvBatchResult {
    /// Advisories reported for the query. `query_batch` only checks whether this
    /// list is non-empty, then re-fetches full records with `query`.
    vulns: Option<Vec<OsvVulnerability>>,
}
306
/// One OSV advisory record.
///
/// Only the fields this module reads are modelled. Serde ignores any other
/// JSON fields because `deny_unknown_fields` is not set.
#[derive(Debug, Deserialize)]
struct OsvVulnerability {
    /// Advisory identifier; also used as the summary when `summary` is absent.
    id: String,
    /// Other identifiers for the same issue.
    aliases: Option<Vec<String>>,
    /// Short description, if provided.
    summary: Option<String>,
    /// Long-form description, if provided.
    details: Option<String>,
    /// Structured severity entries, consumed by `extract_severity`.
    severity: Option<Vec<OsvSeverity>>,
    /// Affected packages with their ranges and versions; empty when the field is absent.
    #[serde(default)]
    affected: Vec<OsvAffected>,
    /// Reference links; only each entry's `url` is modelled.
    references: Option<Vec<OsvReference>>,
    /// Database-specific extras consulted for severity information.
    database_specific: Option<OsvDatabaseSpecific>,
    /// Publication timestamp, passed through as the raw string.
    published: Option<String>,
    /// Last-modified timestamp, passed through as the raw string.
    modified: Option<String>,
}
321
/// One entry of an advisory's `severity` array.
#[derive(Debug, Deserialize)]
struct OsvSeverity {
    /// Scoring system, read from the JSON key `type` (for example `"CVSS_V3"`
    /// or `"CVSS_V2"`).
    #[serde(rename = "type")]
    severity_type: String,
    /// Score string as published. Per the OSV schema, `CVSS_*` types carry a
    /// vector string here rather than a bare number.
    score: String,
}
328
/// One `affected` entry: a package plus the ranges and versions it is affected in.
#[derive(Debug, Deserialize)]
struct OsvAffected {
    /// Affected package. It is optional in this model, so entries without it
    /// still deserialize.
    package: Option<OsvAffectedPackage>,
    /// Version ranges, each described by a list of events; empty when absent.
    #[serde(default)]
    ranges: Vec<OsvRange>,
    /// Explicitly enumerated affected versions; empty when absent.
    #[serde(default)]
    versions: Vec<String>,
    /// Free-form per-entry data. A string `severity` key in it serves as a
    /// last-resort severity label.
    #[serde(default)]
    database_specific: Option<serde_json::Value>,
}
339
/// Package identity inside an `affected` entry.
#[derive(Debug, Deserialize)]
struct OsvAffectedPackage {
    /// Package name, compared against the queried name when converting advisories.
    name: String,
    // Required in the JSON (no serde default) but never read by this module;
    // the allow keeps the unused-field lint quiet.
    #[allow(dead_code)]
    ecosystem: String,
}
346
/// A version range expressed as an ordered list of events.
#[derive(Debug, Deserialize)]
struct OsvRange {
    /// Range kind, read from the JSON key `type` (for example `"ECOSYSTEM"`,
    /// `"SEMVER"` or `"GIT"`).
    // The allow keeps the build warning-free if no code reads this field.
    #[serde(rename = "type")]
    #[allow(dead_code)]
    range_type: String,
    /// Events delimiting the range; a `fixed` event is used for the fixed version.
    events: Vec<OsvEvent>,
}
354
/// A single range event. Both fields are optional; only `fixed` is read here.
#[derive(Debug, Deserialize)]
struct OsvEvent {
    // Point where the vulnerable range starts. It is deserialized but not read
    // anywhere in this module, hence the allow.
    #[allow(dead_code)]
    introduced: Option<String>,
    /// Point where the issue is fixed: a version, or a commit hash for `GIT` ranges.
    fixed: Option<String>,
}
361
/// A reference link attached to an advisory; only the URL is kept.
#[derive(Debug, Deserialize)]
struct OsvReference {
    /// Link target, copied into `Vulnerability::references`.
    url: String,
}
366
/// Top-level `database_specific` fields that may carry severity information.
#[derive(Debug, Deserialize)]
struct OsvDatabaseSpecific {
    /// Textual severity label (for example `"HIGH"`), parsed into `Severity`.
    severity: Option<String>,
    /// Numeric CVSS score, if the source database provides one.
    #[serde(default)]
    cvss_score: Option<f32>,
    /// Raw CVSS object. Only its numeric `score` key is read, as a score fallback.
    #[serde(default)]
    cvss: Option<serde_json::Value>,
}
375
#[cfg(test)]
mod tests {
    use super::*;

    /// Checks each severity bucket boundary, including the inclusive lower edges.
    #[test]
    fn test_cvss_to_severity() {
        let cases = [
            (9.5, Severity::Critical),
            (9.0, Severity::Critical),
            (8.0, Severity::High),
            (7.0, Severity::High),
            (5.0, Severity::Medium),
            (4.0, Severity::Medium),
            (2.0, Severity::Low),
            (0.0, Severity::Unknown),
        ];
        for (score, expected) in cases {
            assert_eq!(cvss_to_severity(score), expected);
        }
    }

    /// Hits the live OSV API, so it only runs with `cargo test -- --ignored`.
    /// urllib3 1.26.0 is expected to have published advisories.
    #[tokio::test]
    #[ignore]
    async fn test_query_known_vulnerable_package() {
        let vulns = OsvClient::new()
            .query("urllib3", "1.26.0")
            .await
            .unwrap();
        assert!(!vulns.is_empty());
    }
}