Skip to main content

cc_audit/
cve_db.rs

1//! CVE database for known vulnerabilities in AI coding tools.
2//!
3//! This module provides functionality to load and query a database of known CVEs
4//! affecting MCP servers, AI coding assistants, and related tools.
5
6use crate::rules::{Category, Confidence, Finding, Location, Severity};
7use serde::{Deserialize, Serialize};
8use std::fs;
9use std::path::Path;
10use thiserror::Error;
11
12/// Built-in CVE database (embedded at compile time)
13const BUILTIN_DATABASE: &str = include_str!("../data/cve-database.json");
14
15#[derive(Debug, Error)]
16pub enum CveDbError {
17    #[error("Failed to read CVE database file: {0}")]
18    ReadFile(#[from] std::io::Error),
19
20    #[error("Failed to parse CVE database JSON: {0}")]
21    ParseJson(#[from] serde_json::Error),
22
23    #[error("Failed to parse version requirement for {cve_id}: {version}")]
24    InvalidVersion { cve_id: String, version: String },
25}
26
27/// Affected product information in a CVE entry
28#[derive(Debug, Clone, Serialize, Deserialize)]
29pub struct AffectedProduct {
30    pub vendor: String,
31    pub product: String,
32    pub version_affected: String,
33    #[serde(skip_serializing_if = "Option::is_none")]
34    pub version_fixed: Option<String>,
35}
36
37/// A CVE entry in the database
38#[derive(Debug, Clone, Serialize, Deserialize)]
39pub struct CveEntry {
40    pub id: String,
41    pub title: String,
42    pub description: String,
43    pub severity: String,
44    #[serde(skip_serializing_if = "Option::is_none")]
45    pub cvss_score: Option<f32>,
46    pub affected_products: Vec<AffectedProduct>,
47    #[serde(default)]
48    pub cwe_ids: Vec<String>,
49    #[serde(default)]
50    pub references: Vec<String>,
51    pub published_at: String,
52}
53
54/// CVE database file format
55#[derive(Debug, Clone, Serialize, Deserialize)]
56pub struct CveDatabaseFile {
57    pub version: String,
58    pub updated_at: String,
59    pub entries: Vec<CveEntry>,
60}
61
62/// CVE database for querying known vulnerabilities
63pub struct CveDatabase {
64    entries: Vec<CveEntry>,
65    version: String,
66    updated_at: String,
67}
68
69impl CveDatabase {
70    /// Load the built-in CVE database
71    pub fn builtin() -> Result<Self, CveDbError> {
72        Self::from_json(BUILTIN_DATABASE)
73    }
74
75    /// Load CVE database from a JSON file
76    pub fn from_file(path: &Path) -> Result<Self, CveDbError> {
77        let content = fs::read_to_string(path)?;
78        Self::from_json(&content)
79    }
80
81    /// Load CVE database from a JSON string
82    pub fn from_json(json: &str) -> Result<Self, CveDbError> {
83        let file: CveDatabaseFile = serde_json::from_str(json)?;
84        Ok(Self {
85            entries: file.entries,
86            version: file.version,
87            updated_at: file.updated_at,
88        })
89    }
90
91    /// Get database version
92    pub fn version(&self) -> &str {
93        &self.version
94    }
95
96    /// Get last update timestamp
97    pub fn updated_at(&self) -> &str {
98        &self.updated_at
99    }
100
101    /// Get all entries
102    pub fn entries(&self) -> &[CveEntry] {
103        &self.entries
104    }
105
106    /// Get entry count
107    pub fn len(&self) -> usize {
108        self.entries.len()
109    }
110
111    /// Check if database is empty
112    pub fn is_empty(&self) -> bool {
113        self.entries.is_empty()
114    }
115
116    /// Check if a product/version combination is affected by any CVE
117    /// Returns matching CVE entries
118    pub fn check_product(&self, vendor: &str, product: &str, version: &str) -> Vec<&CveEntry> {
119        self.entries
120            .iter()
121            .filter(|entry| {
122                entry.affected_products.iter().any(|p| {
123                    p.vendor.eq_ignore_ascii_case(vendor)
124                        && p.product.eq_ignore_ascii_case(product)
125                        && Self::version_matches(&p.version_affected, version)
126                })
127            })
128            .collect()
129    }
130
131    /// Check if a version string matches a version requirement
132    /// Supports: "< X.Y.Z", "<= X.Y.Z", "= X.Y.Z", ">= X.Y.Z", "> X.Y.Z"
133    fn version_matches(requirement: &str, version: &str) -> bool {
134        let requirement = requirement.trim();
135
136        // Parse the operator and version from the requirement
137        let (op, req_version) = if let Some(rest) = requirement.strip_prefix("<=") {
138            ("<=", rest.trim())
139        } else if let Some(rest) = requirement.strip_prefix(">=") {
140            (">=", rest.trim())
141        } else if let Some(rest) = requirement.strip_prefix('<') {
142            ("<", rest.trim())
143        } else if let Some(rest) = requirement.strip_prefix('>') {
144            (">", rest.trim())
145        } else if let Some(rest) = requirement.strip_prefix('=') {
146            ("=", rest.trim())
147        } else {
148            ("=", requirement) // Default to exact match
149        };
150
151        // Parse versions into comparable parts
152        let version_parts = Self::parse_version(version);
153        let req_parts = Self::parse_version(req_version);
154
155        match op {
156            "<" => Self::compare_versions(&version_parts, &req_parts) < 0,
157            "<=" => Self::compare_versions(&version_parts, &req_parts) <= 0,
158            ">" => Self::compare_versions(&version_parts, &req_parts) > 0,
159            ">=" => Self::compare_versions(&version_parts, &req_parts) >= 0,
160            _ => Self::compare_versions(&version_parts, &req_parts) == 0,
161        }
162    }
163
164    /// Parse version string into comparable parts
165    fn parse_version(version: &str) -> Vec<u32> {
166        version
167            .split(['.', '-', '_'])
168            .filter_map(|s| {
169                // Extract leading numeric part
170                let num_str: String = s.chars().take_while(|c| c.is_ascii_digit()).collect();
171                num_str.parse().ok()
172            })
173            .collect()
174    }
175
176    /// Compare two parsed versions
177    /// Returns: -1 if a < b, 0 if a == b, 1 if a > b
178    fn compare_versions(a: &[u32], b: &[u32]) -> i32 {
179        let max_len = a.len().max(b.len());
180        for i in 0..max_len {
181            let av = a.get(i).copied().unwrap_or(0);
182            let bv = b.get(i).copied().unwrap_or(0);
183            if av < bv {
184                return -1;
185            }
186            if av > bv {
187                return 1;
188            }
189        }
190        0
191    }
192
193    /// Create findings for matching CVEs
194    pub fn create_findings(
195        &self,
196        vendor: &str,
197        product: &str,
198        version: &str,
199        file_path: &str,
200        line: usize,
201    ) -> Vec<Finding> {
202        let matches = self.check_product(vendor, product, version);
203
204        matches
205            .into_iter()
206            .map(|cve| Finding {
207                id: cve.id.clone(),
208                severity: Self::parse_severity(&cve.severity),
209                category: Category::SupplyChain,
210                confidence: Confidence::Certain,
211                name: cve.title.clone(),
212                location: Location {
213                    file: file_path.to_string(),
214                    line,
215                    column: None,
216                },
217                code: format!("{}/{} v{}", vendor, product, version),
218                message: cve.description.clone(),
219                recommendation: if let Some(ref fixed) = cve
220                    .affected_products
221                    .iter()
222                    .find(|p| {
223                        p.vendor.eq_ignore_ascii_case(vendor)
224                            && p.product.eq_ignore_ascii_case(product)
225                    })
226                    .and_then(|p| p.version_fixed.clone())
227                {
228                    format!("Update to version {} or later", fixed)
229                } else {
230                    "Check for security updates from the vendor".to_string()
231                },
232                fix_hint: None,
233                cwe_ids: cve.cwe_ids.clone(),
234                rule_severity: None,
235                client: None,
236                context: None,
237            })
238            .collect()
239    }
240
241    fn parse_severity(s: &str) -> Severity {
242        match s.to_lowercase().as_str() {
243            "critical" => Severity::Critical,
244            "high" => Severity::High,
245            "medium" => Severity::Medium,
246            "low" => Severity::Low,
247            _ => Severity::Medium,
248        }
249    }
250}
251
252impl Default for CveDatabase {
253    fn default() -> Self {
254        Self::builtin().expect("Built-in CVE database should be valid")
255    }
256}
257
#[cfg(test)]
mod tests {
    use super::*;

    // --- Loading -----------------------------------------------------------

    #[test]
    fn test_load_builtin_database() {
        let db = CveDatabase::builtin().unwrap();
        assert!(!db.is_empty());
        assert_eq!(db.version(), "1.0.0");
    }

    #[test]
    fn test_entry_count() {
        let db = CveDatabase::builtin().unwrap();
        assert_eq!(db.len(), 7); // 7 CVEs in built-in database
    }

    // --- Version requirement matching -------------------------------------

    #[test]
    fn test_version_comparison_less_than() {
        assert!(CveDatabase::version_matches("< 1.5.0", "1.4.9"));
        assert!(CveDatabase::version_matches("< 1.5.0", "1.4.0"));
        assert!(CveDatabase::version_matches("< 1.5.0", "0.9.0"));
        assert!(!CveDatabase::version_matches("< 1.5.0", "1.5.0"));
        assert!(!CveDatabase::version_matches("< 1.5.0", "1.5.1"));
        assert!(!CveDatabase::version_matches("< 1.5.0", "2.0.0"));
    }

    #[test]
    fn test_version_comparison_less_than_or_equal() {
        assert!(CveDatabase::version_matches("<= 1.5.0", "1.4.9"));
        assert!(CveDatabase::version_matches("<= 1.5.0", "1.5.0"));
        assert!(!CveDatabase::version_matches("<= 1.5.0", "1.5.1"));
    }

    #[test]
    fn test_version_comparison_greater_than() {
        assert!(CveDatabase::version_matches("> 1.5.0", "1.5.1"));
        assert!(CveDatabase::version_matches("> 1.5.0", "2.0.0"));
        assert!(!CveDatabase::version_matches("> 1.5.0", "1.5.0"));
        assert!(!CveDatabase::version_matches("> 1.5.0", "1.4.9"));
    }

    #[test]
    fn test_version_comparison_greater_than_or_equal() {
        assert!(CveDatabase::version_matches(">= 1.5.0", "1.5.0"));
        assert!(CveDatabase::version_matches(">= 1.5.0", "1.5.1"));
        assert!(CveDatabase::version_matches(">= 1.5.0", "2.0.0"));
        assert!(!CveDatabase::version_matches(">= 1.5.0", "1.4.9"));
    }

    #[test]
    fn test_version_comparison_equal() {
        assert!(CveDatabase::version_matches("= 1.5.0", "1.5.0"));
        assert!(!CveDatabase::version_matches("= 1.5.0", "1.5.1"));
        assert!(!CveDatabase::version_matches("= 1.5.0", "1.4.9"));
    }

    #[test]
    fn test_version_comparison_defaults_to_exact_match() {
        // A requirement without an operator is an exact match.
        assert!(CveDatabase::version_matches("1.5.0", "1.5.0"));
        assert!(!CveDatabase::version_matches("1.5.0", "1.5.1"));
        assert!(!CveDatabase::version_matches("1.5.0", "1.4.9"));
    }

    #[test]
    fn test_version_comparison_pads_missing_components() {
        // Missing trailing components count as zero on either side.
        assert!(CveDatabase::version_matches("= 1.5", "1.5.0"));
        assert!(CveDatabase::version_matches("= 1.5.0", "1.5"));
        assert!(CveDatabase::version_matches("< 1.5", "1.4.9"));
        assert!(!CveDatabase::version_matches("< 1.5", "1.5.0"));
    }

    #[test]
    fn test_parse_version_with_prerelease() {
        let parts = CveDatabase::parse_version("1.5.0-beta.1");
        assert_eq!(parts, vec![1, 5, 0, 1]);
    }

    // --- Product lookup ------------------------------------------------------

    #[test]
    fn test_check_product_matches() {
        let db = CveDatabase::builtin().unwrap();
        let matches = db.check_product("anthropic", "claude-code-vscode", "1.4.0");
        assert!(!matches.is_empty());
        assert!(matches.iter().any(|e| e.id == "CVE-2025-52882"));
    }

    #[test]
    fn test_check_product_no_match_fixed_version() {
        let db = CveDatabase::builtin().unwrap();
        let matches = db.check_product("anthropic", "claude-code-vscode", "1.5.0");
        assert!(matches.is_empty());
    }

    #[test]
    fn test_check_product_case_insensitive() {
        let db = CveDatabase::builtin().unwrap();
        let matches = db.check_product("Anthropic", "Claude-Code-VSCode", "1.4.0");
        assert!(!matches.is_empty());
    }

    // --- Finding generation ------------------------------------------------

    #[test]
    fn test_create_findings() {
        let db = CveDatabase::builtin().unwrap();
        let findings = db.create_findings(
            "anthropic",
            "claude-code-vscode",
            "1.4.0",
            "package.json",
            10,
        );
        assert!(!findings.is_empty());

        let finding = &findings[0];
        assert_eq!(finding.id, "CVE-2025-52882");
        assert_eq!(finding.severity, Severity::Critical);
        assert_eq!(finding.category, Category::SupplyChain);
        assert!(finding.recommendation.contains("1.5.0"));
    }
}
353}