Skip to main content

cc_audit/
cve_db.rs

1//! CVE database for known vulnerabilities in AI coding tools.
2//!
3//! This module provides functionality to load and query a database of known CVEs
4//! affecting MCP servers, AI coding assistants, and related tools.
5
6use crate::rules::{Category, Confidence, Finding, Location, Severity};
7use serde::{Deserialize, Serialize};
8use std::fs;
9use std::path::Path;
10use thiserror::Error;
11
/// Built-in CVE database, embedded into the binary at compile time.
///
/// `include_str!` resolves the path relative to this source file, so the JSON is
/// baked in during the build and no file access is needed at runtime. The contents
/// are parsed by [`CveDatabase::builtin`].
const BUILTIN_DATABASE: &str = include_str!("../data/cve-database.json");
14
/// Errors that can occur while loading a [`CveDatabase`].
#[derive(Debug, Error)]
pub enum CveDbError {
    /// The database file could not be read; produced by [`CveDatabase::from_file`].
    #[error("Failed to read CVE database file: {0}")]
    ReadFile(#[from] std::io::Error),

    /// The database JSON was malformed or did not match the expected schema.
    #[error("Failed to parse CVE database JSON: {0}")]
    ParseJson(#[from] serde_json::Error),

    /// A version requirement in the CVE identified by `cve_id` could not be parsed.
    ///
    /// NOTE(review): this variant is not constructed anywhere in this module —
    /// version strings are currently parsed leniently (unparseable segments are
    /// dropped). Confirm whether any caller still relies on it.
    #[error("Failed to parse version requirement for {cve_id}: {version}")]
    InvalidVersion { cve_id: String, version: String },
}
26
/// Affected product information in a CVE entry.
///
/// When the database is queried, `vendor` and `product` are compared
/// ASCII case-insensitively against the caller's values.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AffectedProduct {
    /// Vendor name, e.g. `anthropic`.
    pub vendor: String,
    /// Product name, e.g. `claude-code-vscode`.
    pub product: String,
    /// Requirement describing the affected releases, such as `< 1.5.0`.
    /// An operator (`<`, `<=`, `>`, `>=`, `=`) may precede the version;
    /// a bare version means an exact match.
    pub version_affected: String,
    /// First release containing the fix, if known. It is used to build the
    /// "Update to version … or later" recommendation on findings, and it is
    /// omitted from serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub version_fixed: Option<String>,
}
36
/// A CVE entry in the database.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CveEntry {
    /// CVE identifier, e.g. `CVE-2025-52882`; used as the finding id.
    pub id: String,
    /// Short human-readable title; used as the finding name.
    pub title: String,
    /// Full description; used as the finding message.
    pub description: String,
    /// Severity label. The values `critical`, `high`, `medium` and `low` are
    /// recognised case-insensitively; any other value is reported as medium.
    pub severity: String,
    /// CVSS score, if available; omitted from serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub cvss_score: Option<f32>,
    /// Products and version requirements affected by this CVE.
    pub affected_products: Vec<AffectedProduct>,
    /// Associated CWE identifiers; defaults to empty when absent from the JSON.
    #[serde(default)]
    pub cwe_ids: Vec<String>,
    /// Reference links (presumably URLs); defaults to empty when absent from the JSON.
    #[serde(default)]
    pub references: Vec<String>,
    /// Publication date exactly as stored in the database.
    /// NOTE(review): the format is not validated here — presumably ISO 8601; confirm.
    pub published_at: String,
}
53
/// On-disk JSON format of a CVE database.
///
/// This is the top-level object accepted by [`CveDatabase::from_json`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CveDatabaseFile {
    /// Version string of the database contents; not interpreted by this module.
    pub version: String,
    /// Timestamp of the last database update, kept as an unvalidated string.
    pub updated_at: String,
    /// All CVE entries in the database.
    pub entries: Vec<CveEntry>,
}
61
62/// CVE database for querying known vulnerabilities
63pub struct CveDatabase {
64    entries: Vec<CveEntry>,
65    version: String,
66    updated_at: String,
67}
68
69impl CveDatabase {
70    /// Load the built-in CVE database
71    pub fn builtin() -> Result<Self, CveDbError> {
72        Self::from_json(BUILTIN_DATABASE)
73    }
74
75    /// Load CVE database from a JSON file
76    pub fn from_file(path: &Path) -> Result<Self, CveDbError> {
77        let content = fs::read_to_string(path)?;
78        Self::from_json(&content)
79    }
80
81    /// Load CVE database from a JSON string
82    pub fn from_json(json: &str) -> Result<Self, CveDbError> {
83        let file: CveDatabaseFile = serde_json::from_str(json)?;
84        Ok(Self {
85            entries: file.entries,
86            version: file.version,
87            updated_at: file.updated_at,
88        })
89    }
90
91    /// Get database version
92    pub fn version(&self) -> &str {
93        &self.version
94    }
95
96    /// Get last update timestamp
97    pub fn updated_at(&self) -> &str {
98        &self.updated_at
99    }
100
101    /// Get all entries
102    pub fn entries(&self) -> &[CveEntry] {
103        &self.entries
104    }
105
106    /// Get entry count
107    pub fn len(&self) -> usize {
108        self.entries.len()
109    }
110
111    /// Check if database is empty
112    pub fn is_empty(&self) -> bool {
113        self.entries.is_empty()
114    }
115
116    /// Check if a product/version combination is affected by any CVE
117    /// Returns matching CVE entries
118    pub fn check_product(&self, vendor: &str, product: &str, version: &str) -> Vec<&CveEntry> {
119        self.entries
120            .iter()
121            .filter(|entry| {
122                entry.affected_products.iter().any(|p| {
123                    p.vendor.eq_ignore_ascii_case(vendor)
124                        && p.product.eq_ignore_ascii_case(product)
125                        && Self::version_matches(&p.version_affected, version)
126                })
127            })
128            .collect()
129    }
130
131    /// Check if a version string matches a version requirement
132    /// Supports: "< X.Y.Z", "<= X.Y.Z", "= X.Y.Z", ">= X.Y.Z", "> X.Y.Z"
133    fn version_matches(requirement: &str, version: &str) -> bool {
134        let requirement = requirement.trim();
135
136        // Parse the operator and version from the requirement
137        let (op, req_version) = if let Some(rest) = requirement.strip_prefix("<=") {
138            ("<=", rest.trim())
139        } else if let Some(rest) = requirement.strip_prefix(">=") {
140            (">=", rest.trim())
141        } else if let Some(rest) = requirement.strip_prefix('<') {
142            ("<", rest.trim())
143        } else if let Some(rest) = requirement.strip_prefix('>') {
144            (">", rest.trim())
145        } else if let Some(rest) = requirement.strip_prefix('=') {
146            ("=", rest.trim())
147        } else {
148            ("=", requirement) // Default to exact match
149        };
150
151        // Parse versions into comparable parts
152        let version_parts = Self::parse_version(version);
153        let req_parts = Self::parse_version(req_version);
154
155        match op {
156            "<" => Self::compare_versions(&version_parts, &req_parts) < 0,
157            "<=" => Self::compare_versions(&version_parts, &req_parts) <= 0,
158            ">" => Self::compare_versions(&version_parts, &req_parts) > 0,
159            ">=" => Self::compare_versions(&version_parts, &req_parts) >= 0,
160            _ => Self::compare_versions(&version_parts, &req_parts) == 0,
161        }
162    }
163
164    /// Parse version string into comparable parts
165    fn parse_version(version: &str) -> Vec<u32> {
166        version
167            .split(['.', '-', '_'])
168            .filter_map(|s| {
169                // Extract leading numeric part
170                let num_str: String = s.chars().take_while(|c| c.is_ascii_digit()).collect();
171                num_str.parse().ok()
172            })
173            .collect()
174    }
175
176    /// Compare two parsed versions
177    /// Returns: -1 if a < b, 0 if a == b, 1 if a > b
178    fn compare_versions(a: &[u32], b: &[u32]) -> i32 {
179        let max_len = a.len().max(b.len());
180        for i in 0..max_len {
181            let av = a.get(i).copied().unwrap_or(0);
182            let bv = b.get(i).copied().unwrap_or(0);
183            if av < bv {
184                return -1;
185            }
186            if av > bv {
187                return 1;
188            }
189        }
190        0
191    }
192
193    /// Create findings for matching CVEs
194    pub fn create_findings(
195        &self,
196        vendor: &str,
197        product: &str,
198        version: &str,
199        file_path: &str,
200        line: usize,
201    ) -> Vec<Finding> {
202        let matches = self.check_product(vendor, product, version);
203
204        matches
205            .into_iter()
206            .map(|cve| Finding {
207                id: cve.id.clone(),
208                severity: Self::parse_severity(&cve.severity),
209                category: Category::SupplyChain,
210                confidence: Confidence::Certain,
211                name: cve.title.clone(),
212                location: Location {
213                    file: file_path.to_string(),
214                    line,
215                    column: None,
216                },
217                code: format!("{}/{} v{}", vendor, product, version),
218                message: cve.description.clone(),
219                recommendation: if let Some(ref fixed) = cve
220                    .affected_products
221                    .iter()
222                    .find(|p| {
223                        p.vendor.eq_ignore_ascii_case(vendor)
224                            && p.product.eq_ignore_ascii_case(product)
225                    })
226                    .and_then(|p| p.version_fixed.clone())
227                {
228                    format!("Update to version {} or later", fixed)
229                } else {
230                    "Check for security updates from the vendor".to_string()
231                },
232                fix_hint: None,
233                cwe_ids: cve.cwe_ids.clone(),
234                rule_severity: None,
235                client: None,
236                context: None,
237            })
238            .collect()
239    }
240
241    fn parse_severity(s: &str) -> Severity {
242        match s.to_lowercase().as_str() {
243            "critical" => Severity::Critical,
244            "high" => Severity::High,
245            "medium" => Severity::Medium,
246            "low" => Severity::Low,
247            _ => Severity::Medium,
248        }
249    }
250}
251
252impl Default for CveDatabase {
253    fn default() -> Self {
254        Self::builtin().expect("Built-in CVE database should be valid")
255    }
256}
257
#[cfg(test)]
mod tests {
    use super::*;

    /// CVE the built-in database reports for `anthropic/claude-code-vscode` below 1.5.0.
    const VSCODE_CVE: &str = "CVE-2025-52882";

    #[test]
    fn test_load_builtin_database() {
        let db = CveDatabase::builtin().unwrap();
        assert!(!db.is_empty());
        // Version should be a semver-like string (e.g., "1.0.0", "1.0.1")
        assert!(db.version().starts_with("1."));
    }

    #[test]
    fn test_version_comparison_less_than() {
        assert!(CveDatabase::version_matches("< 1.5.0", "1.4.9"));
        assert!(CveDatabase::version_matches("< 1.5.0", "1.4.0"));
        assert!(CveDatabase::version_matches("< 1.5.0", "0.9.0"));
        assert!(!CveDatabase::version_matches("< 1.5.0", "1.5.0"));
        assert!(!CveDatabase::version_matches("< 1.5.0", "1.5.1"));
        assert!(!CveDatabase::version_matches("< 1.5.0", "2.0.0"));
    }

    #[test]
    fn test_version_comparison_less_than_or_equal() {
        assert!(CveDatabase::version_matches("<= 1.5.0", "1.4.9"));
        assert!(CveDatabase::version_matches("<= 1.5.0", "1.5.0"));
        assert!(!CveDatabase::version_matches("<= 1.5.0", "1.5.1"));
    }

    #[test]
    fn test_version_comparison_greater_than() {
        assert!(CveDatabase::version_matches("> 1.5.0", "1.5.1"));
        assert!(CveDatabase::version_matches("> 1.5.0", "2.0.0"));
        assert!(!CveDatabase::version_matches("> 1.5.0", "1.5.0"));
        assert!(!CveDatabase::version_matches("> 1.5.0", "1.4.9"));
    }

    #[test]
    fn test_version_comparison_equal() {
        assert!(CveDatabase::version_matches("= 1.5.0", "1.5.0"));
        assert!(!CveDatabase::version_matches("= 1.5.0", "1.5.1"));
        assert!(!CveDatabase::version_matches("= 1.5.0", "1.4.9"));
    }

    #[test]
    fn test_version_comparison_greater_than_or_equal() {
        assert!(CveDatabase::version_matches(">= 1.5.0", "1.5.0"));
        assert!(CveDatabase::version_matches(">= 1.5.0", "1.5.1"));
        assert!(CveDatabase::version_matches(">= 1.5.0", "2.0.0"));
        assert!(!CveDatabase::version_matches(">= 1.5.0", "1.4.9"));
        assert!(!CveDatabase::version_matches(">= 1.5.0", "1.4.0"));
    }

    #[test]
    fn test_version_comparison_exact_match_no_operator() {
        // A requirement without an operator is an exact match.
        assert!(CveDatabase::version_matches("1.5.0", "1.5.0"));
        assert!(!CveDatabase::version_matches("1.5.0", "1.5.1"));
        assert!(!CveDatabase::version_matches("1.5.0", "1.4.9"));
    }

    #[test]
    fn test_version_comparison_without_whitespace() {
        assert!(CveDatabase::version_matches("<1.5.0", "1.4.9"));
        assert!(CveDatabase::version_matches(">=1.5.0", "1.5.0"));
        assert!(!CveDatabase::version_matches("<1.5.0", "1.5.0"));
    }

    #[test]
    fn test_version_comparison_numeric_components() {
        // Components compare numerically, not lexically: 1.10.0 is newer than 1.9.9.
        assert!(CveDatabase::version_matches("< 1.10.0", "1.9.9"));
        assert!(!CveDatabase::version_matches("< 1.10.0", "1.10.0"));
        assert!(CveDatabase::version_matches("> 1.9.0", "1.10.0"));
    }

    #[test]
    fn test_version_comparison_missing_components_are_zero() {
        assert!(CveDatabase::version_matches("= 1.5", "1.5.0"));
        assert!(CveDatabase::version_matches("< 1.5", "1.4.9"));
        assert!(!CveDatabase::version_matches("< 1.5", "1.5.0"));
    }

    #[test]
    fn test_check_product_matches() {
        let db = CveDatabase::builtin().unwrap();
        let matches = db.check_product("anthropic", "claude-code-vscode", "1.4.0");
        assert!(!matches.is_empty());
        assert!(matches.iter().any(|e| e.id == VSCODE_CVE));
    }

    #[test]
    fn test_check_product_no_match_fixed_version() {
        // Only the CVE fixed in 1.5.0 is asserted absent, so this test keeps its
        // meaning if the database later gains CVEs that affect 1.5.0 itself.
        let db = CveDatabase::builtin().unwrap();
        let matches = db.check_product("anthropic", "claude-code-vscode", "1.5.0");
        assert!(!matches.iter().any(|e| e.id == VSCODE_CVE));
    }

    #[test]
    fn test_check_product_case_insensitive() {
        let db = CveDatabase::builtin().unwrap();
        let matches = db.check_product("Anthropic", "Claude-Code-VSCode", "1.4.0");
        assert!(!matches.is_empty());
    }

    #[test]
    fn test_check_product_unknown_product() {
        let db = CveDatabase::builtin().unwrap();
        assert!(db
            .check_product("anthropic", "no-such-product", "0.0.1")
            .is_empty());
    }

    #[test]
    fn test_create_findings() {
        let db = CveDatabase::builtin().unwrap();
        let findings = db.create_findings(
            "anthropic",
            "claude-code-vscode",
            "1.4.0",
            "package.json",
            10,
        );
        assert!(!findings.is_empty());

        // Look the finding up by id so the test does not depend on database order.
        let finding = findings
            .iter()
            .find(|f| f.id == VSCODE_CVE)
            .expect("expected a finding for CVE-2025-52882");
        assert_eq!(finding.severity, Severity::Critical);
        assert_eq!(finding.category, Category::SupplyChain);
        assert_eq!(finding.location.file, "package.json");
        assert_eq!(finding.location.line, 10);
        assert_eq!(finding.code, "anthropic/claude-code-vscode v1.4.0");
        assert!(finding.recommendation.contains("1.5.0"));
    }

    #[test]
    fn test_parse_version_with_prerelease() {
        let parts = CveDatabase::parse_version("1.5.0-beta.1");
        assert_eq!(parts, vec![1, 5, 0, 1]);
    }

    #[test]
    fn test_parse_version_plain() {
        assert_eq!(CveDatabase::parse_version("1.5.0"), vec![1, 5, 0]);
        assert_eq!(CveDatabase::parse_version("10.0.23"), vec![10, 0, 23]);
    }

    #[test]
    fn test_entry_count() {
        let db = CveDatabase::builtin().unwrap();
        // Database should have at least the initial 7 CVEs (may grow over time)
        assert!(db.len() >= 7);
    }

    #[test]
    fn test_updated_at() {
        let db = CveDatabase::builtin().unwrap();
        let updated = db.updated_at();
        // Should be an ISO 8601 date string (e.g., "2025-01-26T00:00:00Z")
        assert!(!updated.is_empty());
        // Validate the year is reasonable (2024-2030); `get` avoids a slicing panic
        // on short or non-ASCII strings so the assertion message is reported instead.
        let year: i32 = updated
            .get(..4)
            .and_then(|y| y.parse().ok())
            .unwrap_or(0);
        assert!(
            (2024..=2030).contains(&year),
            "Unexpected year in updated_at: {updated}"
        );
    }

    #[test]
    fn test_entries() {
        let db = CveDatabase::builtin().unwrap();
        let entries = db.entries();
        assert!(!entries.is_empty());
        // First entry should have a CVE ID
        assert!(entries[0].id.starts_with("CVE-"));
    }

    #[test]
    fn test_from_file() {
        use std::io::Write;
        use tempfile::NamedTempFile;

        // Create a temporary file with valid CVE database JSON
        let mut temp_file = NamedTempFile::new().unwrap();
        let json = r#"{
            "version": "1.0.0",
            "updated_at": "2025-01-01",
            "entries": []
        }"#;
        temp_file.write_all(json.as_bytes()).unwrap();

        let db = CveDatabase::from_file(temp_file.path()).unwrap();
        assert_eq!(db.version(), "1.0.0");
        assert_eq!(db.updated_at(), "2025-01-01");
        assert!(db.is_empty());
    }

    #[test]
    fn test_from_file_invalid_path() {
        let result = CveDatabase::from_file(Path::new("/nonexistent/file.json"));
        assert!(result.is_err());
    }
}