1use std::cmp::Ordering;
20use std::collections::{BTreeMap, BTreeSet};
21use std::path::Path;
22
23use fleetreach_core::{Severity, VulnFinding};
24
/// Exploitability data used to enrich and re-rank vulnerability findings.
///
/// Every collection is keyed by (or contains) CVE IDs such as `CVE-2021-44228`.
#[derive(Clone, Debug, Default)]
pub struct Enrichment {
    /// CVE IDs listed in CISA's Known Exploited Vulnerabilities catalog.
    pub kev: BTreeSet<String>,
    /// EPSS exploit-probability score per CVE ID.
    pub epss: BTreeMap<String, f32>,
    /// CVSS base score per CVE ID; consulted only to fill in findings whose severity is
    /// still unknown.
    pub cvss: BTreeMap<String, f32>,
}
37
38impl Enrichment {
39 #[cfg(feature = "network")]
44 pub fn fetch(cves: &[String], backfill_cves: &[String]) -> Result<Self, String> {
45 Ok(Self {
46 kev: parse_kev(&net::http_get(net::KEV_URL)?)?,
47 epss: net::fetch_epss(cves)?,
48 cvss: net::fetch_nvd_scores(backfill_cves),
49 })
50 }
51
52 pub fn from_files(kev_path: Option<&Path>, epss_path: Option<&Path>) -> Result<Self, String> {
56 let kev = match kev_path {
57 Some(p) => parse_kev(&read(p)?)?,
58 None => BTreeSet::new(),
59 };
60 let epss = match epss_path {
61 Some(p) => parse_epss_csv(&read(p)?),
62 None => BTreeMap::new(),
63 };
64 Ok(Self {
65 kev,
66 epss,
67 cvss: BTreeMap::new(),
68 })
69 }
70
71 pub fn apply(&self, findings: &mut [VulnFinding]) {
77 for finding in findings {
78 let cves: Vec<&String> = finding
79 .aliases
80 .iter()
81 .filter(|a| a.starts_with("CVE-"))
82 .collect();
83 finding.exploit.kev = cves.iter().any(|c| self.kev.contains(*c));
84 finding.exploit.epss = cves
85 .iter()
86 .filter_map(|c| self.epss.get(*c).copied())
87 .fold(None, |acc, v| Some(acc.map_or(v, |a: f32| a.max(v))));
88
89 if finding.severity == Severity::Unknown {
90 let worst = cves
91 .iter()
92 .filter_map(|c| self.cvss.get(*c).copied())
93 .fold(None, |acc, v| Some(acc.map_or(v, |a: f32| a.max(v))));
94 if let Some(score) = worst {
95 let sev = severity_from_score(f64::from(score));
96 if sev > Severity::Unknown {
97 finding.severity = sev;
98 finding.cvss_score = Some(score);
99 }
100 }
101 }
102 }
103 }
104}
105
106pub fn rank(findings: &mut [VulnFinding]) {
109 findings.sort_by(|a, b| {
110 let ae = a.exploit.epss.unwrap_or(-1.0);
111 let be = b.exploit.epss.unwrap_or(-1.0);
112 b.exploit
113 .kev
114 .cmp(&a.exploit.kev)
115 .then(be.partial_cmp(&ae).unwrap_or(Ordering::Equal))
116 .then(b.severity.cmp(&a.severity))
117 .then(a.advisory_id.cmp(&b.advisory_id))
118 });
119}
120
121fn severity_from_score(score: f64) -> Severity {
125 match score {
126 s if s >= 9.0 => Severity::Critical,
127 s if s >= 7.0 => Severity::High,
128 s if s >= 4.0 => Severity::Medium,
129 s if s > 0.0 => Severity::Low,
130 _ => Severity::Unknown,
131 }
132}
133
134fn parse_kev(body: &str) -> Result<BTreeSet<String>, String> {
135 let value: serde_json::Value =
136 serde_json::from_str(body).map_err(|e| format!("KEV JSON: {e}"))?;
137 let entries = value
138 .get("vulnerabilities")
139 .and_then(|v| v.as_array())
140 .ok_or("KEV JSON missing `vulnerabilities` array")?;
141 Ok(entries
142 .iter()
143 .filter_map(|e| e.get("cveID").and_then(|c| c.as_str()))
144 .map(String::from)
145 .collect())
146}
147
148#[cfg(feature = "network")]
/// Returns `true` iff `s` is a well-formed CVE ID: `CVE-`, a four-digit year, a dash, and
/// at least four ASCII digits.
///
/// Only ASCII digits are accepted after the prefix. This keeps IDs safe to splice into
/// request URLs: no `&`, spaces or other query-altering characters.
pub(crate) fn is_valid_cve(s: &str) -> bool {
    let all_digits = |part: &str| part.bytes().all(|b| b.is_ascii_digit());
    match s.strip_prefix("CVE-").and_then(|rest| rest.split_once('-')) {
        Some((year, seq)) => {
            year.len() == 4 && all_digits(year) && seq.len() >= 4 && all_digits(seq)
        }
        None => false,
    }
}
166
/// Parses an EPSS CSV export into a CVE ID → score map.
///
/// Lines starting with `#` (metadata) or `cve` (the column header) are skipped. Every other
/// line is read as `cve,score[,...]`, with surrounding whitespace trimmed from both fields.
/// Rows are ignored when the CVE field is empty or the score is not a finite number.
/// `f32::from_str` accepts `NaN`/`inf`, and a NaN would otherwise poison later max/sort
/// comparisons. If a CVE appears more than once, the last row wins.
fn parse_epss_csv(body: &str) -> BTreeMap<String, f32> {
    body.lines()
        .filter(|line| !(line.starts_with('#') || line.starts_with("cve")))
        .filter_map(|line| {
            let mut fields = line.split(',');
            let cve = fields.next()?.trim();
            let score = fields.next()?.trim().parse::<f32>().ok()?;
            (!cve.is_empty() && score.is_finite()).then(|| (cve.to_string(), score))
        })
        .collect()
}
183
/// Reads a whole UTF-8 text file.
///
/// # Errors
///
/// Returns `"reading <path>: <io error>"` if the file cannot be opened or is not valid UTF-8.
fn read(path: &Path) -> Result<String, String> {
    match std::fs::read_to_string(path) {
        Ok(contents) => Ok(contents),
        Err(err) => Err(format!("reading {}: {err}", path.display())),
    }
}
187
#[cfg(feature = "network")]
/// HTTP access to the KEV, EPSS and NVD feeds (only built with the `network` feature).
mod net {
    use std::collections::{BTreeMap, BTreeSet};
    use std::sync::OnceLock;
    use std::time::Duration;

    pub(super) const KEV_URL: &str =
        "https://www.cisa.gov/sites/default/files/feeds/known_exploited_vulnerabilities.json";
    const EPSS_API: &str = "https://api.first.org/data/v1/epss";
    const NVD_API: &str = "https://services.nvd.nist.gov/rest/json/cves/2.0";

    /// CVE IDs sent per EPSS request. NOTE(review): assumed to match the EPSS API's default
    /// page size, so one response covers the whole batch — confirm before raising.
    const EPSS_BATCH: usize = 100;

    /// Pause between NVD requests. NVD allows roughly 50 requests per 30 s with an API key
    /// and 5 per 30 s without one.
    const NVD_DELAY_WITH_KEY: Duration = Duration::from_millis(700);
    const NVD_DELAY_ANONYMOUS: Duration = Duration::from_millis(6000);

    /// Process-wide HTTP agent with connect/read/write and overall timeouts, built once.
    fn agent() -> &'static ureq::Agent {
        static AGENT: OnceLock<ureq::Agent> = OnceLock::new();
        AGENT.get_or_init(|| {
            ureq::AgentBuilder::new()
                .timeout_connect(Duration::from_secs(10))
                .timeout_read(Duration::from_secs(30))
                .timeout_write(Duration::from_secs(30))
                .timeout(Duration::from_secs(60))
                .build()
        })
    }

    /// GETs `url` and returns the body, sending `api_key` as the NVD `apiKey` header if set.
    fn request(url: &str, api_key: Option<&str>) -> Result<String, String> {
        let mut req = agent().get(url);
        if let Some(key) = api_key {
            req = req.set("apiKey", key);
        }
        req.call()
            .map_err(|e| format!("GET {url}: {e}"))?
            .into_string()
            .map_err(|e| format!("reading {url}: {e}"))
    }

    /// GETs `url` and returns the response body as a string.
    ///
    /// # Errors
    ///
    /// Returns `"GET <url>: …"` on transport/HTTP failure and `"reading <url>: …"` if the
    /// body cannot be read.
    pub(super) fn http_get(url: &str) -> Result<String, String> {
        request(url, None)
    }

    /// Fetches EPSS scores for the well-formed, de-duplicated CVE IDs in `cves`.
    ///
    /// Malformed IDs are dropped before any URL is built. IDs are sent `EPSS_BATCH` per
    /// request.
    ///
    /// # Errors
    ///
    /// Fails on the first request or JSON parse error.
    pub(super) fn fetch_epss(cves: &[String]) -> Result<BTreeMap<String, f32>, String> {
        let cve_ids: Vec<&str> = cves
            .iter()
            .map(String::as_str)
            .filter(|c| super::is_valid_cve(c))
            .collect::<BTreeSet<_>>()
            .into_iter()
            .collect();
        let mut out = BTreeMap::new();
        // `chunks` never yields an empty slice, so empty input simply makes no requests.
        for chunk in cve_ids.chunks(EPSS_BATCH) {
            let url = format!("{EPSS_API}?cve={}", chunk.join(","));
            merge_epss_json(&http_get(&url)?, &mut out)?;
        }
        Ok(out)
    }

    /// Merges the `data[].{cve, epss}` rows of an EPSS API response into `out`.
    ///
    /// The code expects `epss` as a JSON string. Rows with a missing field, an unparsable
    /// score or a non-finite score are skipped. A response without a `data` array adds
    /// nothing.
    fn merge_epss_json(body: &str, out: &mut BTreeMap<String, f32>) -> Result<(), String> {
        let value: serde_json::Value =
            serde_json::from_str(body).map_err(|e| format!("EPSS JSON: {e}"))?;
        let Some(rows) = value.get("data").and_then(serde_json::Value::as_array) else {
            return Ok(());
        };
        for row in rows {
            let cve = row.get("cve").and_then(serde_json::Value::as_str);
            let score = row
                .get("epss")
                .and_then(serde_json::Value::as_str)
                .and_then(|s| s.parse::<f32>().ok())
                .filter(|s| s.is_finite());
            if let (Some(cve), Some(score)) = (cve, score) {
                out.insert(cve.to_string(), score);
            }
        }
        Ok(())
    }

    /// Looks up NVD CVSS base scores for the well-formed, de-duplicated CVE IDs in `cves`.
    ///
    /// Requests are made one CVE at a time, with a delay between them to respect NVD rate
    /// limits. The shorter delay applies when a non-blank `NVD_API_KEY` environment
    /// variable is set. Best-effort: a failed request or an unscored CVE is silently
    /// omitted from the result.
    pub(super) fn fetch_nvd_scores(cves: &[String]) -> BTreeMap<String, f32> {
        // An empty or whitespace-only key would be sent as a bogus header while also
        // shortening the delay, so treat it as absent.
        let api_key = std::env::var("NVD_API_KEY")
            .ok()
            .map(|key| key.trim().to_owned())
            .filter(|key| !key.is_empty());
        let delay = if api_key.is_some() {
            NVD_DELAY_WITH_KEY
        } else {
            NVD_DELAY_ANONYMOUS
        };
        // De-duplicate: every repeated ID would otherwise cost a full rate-limit delay.
        let cve_ids: BTreeSet<&str> = cves
            .iter()
            .map(String::as_str)
            .filter(|c| super::is_valid_cve(c))
            .collect();

        let mut out = BTreeMap::new();
        for (i, cve) in cve_ids.into_iter().enumerate() {
            if i > 0 {
                std::thread::sleep(delay);
            }
            let url = format!("{NVD_API}?cveId={cve}");
            let Ok(body) = request(&url, api_key.as_deref()) else {
                continue;
            };
            if let Some(score) = parse_nvd_score(&body, cve) {
                out.insert(cve.to_owned(), score);
            }
        }
        out
    }

    /// Extracts the CVSS base score for `cve` from an NVD CVE API 2.0 response.
    ///
    /// Metric versions are tried newest first: v4.0, v3.1, v3.0, then v2. Within a version,
    /// the entry with `"type": "Primary"` (NVD's own assessment) is preferred over
    /// `"Secondary"` (CNA-supplied) entries. If no Primary entry exists, the first entry is
    /// used. Returns `None` for malformed JSON, an absent CVE, or no usable score.
    pub(super) fn parse_nvd_score(body: &str, cve: &str) -> Option<f32> {
        let value: serde_json::Value = serde_json::from_str(body).ok()?;
        let metrics = value
            .get("vulnerabilities")?
            .as_array()?
            .iter()
            .find(|v| v.pointer("/cve/id").and_then(serde_json::Value::as_str) == Some(cve))?
            .pointer("/cve/metrics")?;
        [
            "cvssMetricV40",
            "cvssMetricV31",
            "cvssMetricV30",
            "cvssMetricV2",
        ]
        .iter()
        .find_map(|key| {
            let entries = metrics.get(key)?.as_array()?;
            let chosen = entries
                .iter()
                .find(|m| m.get("type").and_then(serde_json::Value::as_str) == Some("Primary"))
                .or_else(|| entries.first())?;
            chosen.pointer("/cvssData/baseScore")?.as_f64()
        })
        .map(|s| s as f32)
    }
}
329
#[cfg(test)]
mod tests {
    use super::*;

    /// Trimmed NVD API 2.0 response carrying both a CVSS v3.1 and a v2 score.
    #[cfg(feature = "network")]
    const NVD_BODY: &str = r#"{
        "vulnerabilities": [{
            "cve": {
                "id": "CVE-2022-0778",
                "metrics": {
                    "cvssMetricV31": [{ "cvssData": { "baseScore": 7.5 } }],
                    "cvssMetricV2": [{ "cvssData": { "baseScore": 5.0 } }]
                }
            }
        }]
    }"#;

    #[cfg(feature = "network")]
    #[test]
    fn parses_nvd_score_prefers_v31() {
        assert_eq!(net::parse_nvd_score(NVD_BODY, "CVE-2022-0778"), Some(7.5));
    }

    #[cfg(feature = "network")]
    #[test]
    fn nvd_score_falls_back_to_older_cvss_versions() {
        let body = r#"{"vulnerabilities":[{"cve":{"id":"CVE-1","metrics":{
            "cvssMetricV2":[{"cvssData":{"baseScore":9.1}}]}}}]}"#;
        assert_eq!(net::parse_nvd_score(body, "CVE-1"), Some(9.1));
    }

    #[cfg(feature = "network")]
    #[test]
    fn nvd_score_reads_cvss_v40() {
        let body = r#"{"vulnerabilities":[{"cve":{"id":"CVE-1","metrics":{
            "cvssMetricV40":[{"cvssData":{"baseScore":6.3}}]}}}]}"#;
        assert_eq!(net::parse_nvd_score(body, "CVE-1"), Some(6.3));
    }

    #[cfg(feature = "network")]
    #[test]
    fn nvd_score_none_when_cve_absent_or_unscored() {
        assert_eq!(net::parse_nvd_score(NVD_BODY, "CVE-9999-9999"), None);
        let empty = r#"{"vulnerabilities":[{"cve":{"id":"CVE-1","metrics":{}}}]}"#;
        assert_eq!(net::parse_nvd_score(empty, "CVE-1"), None);
    }

    #[cfg(feature = "network")]
    #[test]
    fn cve_validation_blocks_url_injection() {
        assert!(is_valid_cve("CVE-2022-0778"));
        assert!(is_valid_cve("CVE-2026-12345678"));
        assert!(!is_valid_cve("CVE-2022-0778&inject=1"));
        assert!(!is_valid_cve("CVE-2022-0778 OR 1=1"));
        assert!(!is_valid_cve("CVE-22-1"));
        assert!(!is_valid_cve("CVE-2022-abc"));
        assert!(!is_valid_cve("GHSA-xxxx"));
        assert!(!is_valid_cve("CVE-"));
    }

    #[test]
    fn score_bands_match_cvss_v3() {
        assert_eq!(severity_from_score(0.0), Severity::Unknown);
        assert_eq!(severity_from_score(3.9), Severity::Low);
        assert_eq!(severity_from_score(4.0), Severity::Medium);
        assert_eq!(severity_from_score(7.0), Severity::High);
        assert_eq!(severity_from_score(9.0), Severity::Critical);
    }

    #[test]
    fn severity_edges_outside_the_bands() {
        // NaN and negative scores must never be promoted to a real severity.
        assert_eq!(severity_from_score(f64::NAN), Severity::Unknown);
        assert_eq!(severity_from_score(-1.0), Severity::Unknown);
        assert_eq!(severity_from_score(0.1), Severity::Low);
        assert_eq!(severity_from_score(10.0), Severity::Critical);
    }

    #[test]
    fn parses_kev_catalog_ids() {
        let body = r#"{"vulnerabilities":[
            {"cveID":"CVE-2021-44228"},
            {"cveID":"CVE-2023-4863"},
            {"vendorProject":"no id here"}
        ]}"#;
        let kev = parse_kev(body).unwrap();
        assert_eq!(kev.len(), 2);
        assert!(kev.contains("CVE-2021-44228"));
        assert!(kev.contains("CVE-2023-4863"));
    }

    #[test]
    fn kev_parse_rejects_malformed_documents() {
        assert!(parse_kev("not json").is_err());
        assert!(parse_kev(r#"{"catalogVersion":"1"}"#).is_err());
        assert!(parse_kev(r#"{"vulnerabilities":{}}"#).is_err());
    }

    #[test]
    fn parses_epss_csv_skipping_metadata_and_header() {
        let body = "#model_version:v2023.03.01,score_date:2024-01-01\n\
                    cve,epss,percentile\n\
                    CVE-2022-0778,0.5,0.9\n\
                    CVE-2021-44228, 0.97 ,0.99\n\
                    garbage\n";
        let epss = parse_epss_csv(body);
        assert_eq!(epss.len(), 2);
        assert_eq!(epss.get("CVE-2022-0778"), Some(&0.5));
        assert_eq!(epss.get("CVE-2021-44228"), Some(&0.97));
        assert!(parse_epss_csv("").is_empty());
    }
}