Skip to main content

ai_coding_shield/scanner/
vuln_db.rs

1use serde::{Deserialize, Serialize};
2use reqwest::Client;
3
4
5#[derive(Debug, Serialize, Deserialize)]
6struct OsvQuery {
7    package: OsvPackage,
8}
9
10#[derive(Debug, Serialize, Deserialize)]
11struct OsvPackage {
12    name: String,
13    ecosystem: String,
14}
15
16#[derive(Debug, Serialize, Deserialize)]
17struct OsvResponse {
18    vulns: Option<Vec<OsvVuln>>,
19}
20
21#[derive(Debug, Serialize, Deserialize)]
22struct OsvVuln {
23    id: String,
24    summary: Option<String>,
25    details: Option<String>,
26    severity: Option<Vec<OsvSeverity>>,
27    database_specific: Option<OsvDatabaseSpecific>,
28}
29
30#[derive(Debug, Serialize, Deserialize)]
31struct OsvSeverity {
32    #[serde(rename = "type")]
33    severity_type: String,
34    score: String,
35}
36
37#[derive(Debug, Serialize, Deserialize)]
38struct OsvDatabaseSpecific {
39    severity: Option<String>, // LOW, MODERATE, HIGH, CRITICAL
40}
41
42pub struct VulnerabilityScanner {
43    client: Client,
44}
45
46impl VulnerabilityScanner {
47    pub fn new() -> Self {
48        Self {
49            client: Client::new(),
50        }
51    }
52
53    pub async fn check_package(&self, ecosystem: &str, name: &str) -> anyhow::Result<Option<String>> {
54        let query = OsvQuery {
55            package: OsvPackage {
56                name: name.to_string(),
57                ecosystem: ecosystem.to_string(),
58            },
59        };
60
61        let resp = self.client.post("https://api.osv.dev/v1/query")
62            .json(&query)
63            .send()
64            .await?;
65
66        if !resp.status().is_success() {
67            return Ok(None);
68        }
69
70        let body: OsvResponse = resp.json().await?;
71
72        if let Some(vulns) = body.vulns {
73            if vulns.is_empty() {
74                return Ok(None);
75            }
76
77            // Summarize vulnerabilities
78            let mut summary;
79            let count = vulns.len();
80            
81            // Find highest severity
82            let mut highest_sev = "LOW";
83            
84            for vuln in &vulns {
85                if let Some(db) = &vuln.database_specific {
86                     if let Some(sev) = &db.severity {
87                         if sev == "CRITICAL" { highest_sev = "CRITICAL"; }
88                         else if sev == "HIGH" && highest_sev != "CRITICAL" { highest_sev = "HIGH"; }
89                         else if sev == "MODERATE" && highest_sev != "CRITICAL" && highest_sev != "HIGH" { highest_sev = "MODERATE"; }
90                     }
91                }
92            }
93            
94            summary = format!("Found {} known vulnerabilities (Max: {}). ", count, highest_sev);
95            
96            // Add top 3 IDs
97            let ids: Vec<String> = vulns.iter().take(3).map(|v| v.id.clone()).collect();
98            summary.push_str(&format!("IDs: {}", ids.join(", ")));
99
100            return Ok(Some(summary));
101        }
102
103        Ok(None)
104    }
105}