use crate::error::{AdvisoryError, Result};
use chrono::{DateTime, NaiveDate, Utc};
use reqwest_middleware::{ClientBuilder, ClientWithMiddleware};
use reqwest_retry::{RetryTransientMiddleware, policies::ExponentialBackoff};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::time::Duration;
use tracing::{debug, info};
/// Endpoint of the FIRST.org EPSS API that every query in this module targets.
pub const EPSS_API_URL: &str = "https://api.first.org/data/v1/epss";
/// Overall per-request timeout handed to `reqwest::ClientBuilder::timeout` (30 s).
const REQUEST_TIMEOUT: Duration = Duration::from_millis(30_000);
/// Connection-establishment timeout handed to `reqwest::ClientBuilder::connect_timeout` (10 s).
const CONNECT_TIMEOUT: Duration = Duration::from_millis(10_000);
/// Client for the EPSS (Exploit Prediction Scoring System) API.
///
/// Wraps an HTTP client configured with timeouts and transient-failure
/// retries (see `EpssSource::new`). Derives `Clone` so a single configured
/// source can be handed to several callers or tasks.
#[derive(Clone)]
pub struct EpssSource {
    /// HTTP client with the retry middleware already attached.
    client: ClientWithMiddleware,
}
impl EpssSource {
pub fn new() -> Self {
let raw_client = reqwest::Client::builder()
.timeout(REQUEST_TIMEOUT)
.connect_timeout(CONNECT_TIMEOUT)
.build()
.unwrap_or_default();
let retry_policy = ExponentialBackoff::builder().build_with_max_retries(3);
let client = ClientBuilder::new(raw_client)
.with(RetryTransientMiddleware::new_with_policy(retry_policy))
.build();
Self { client }
}
pub async fn fetch_scores(&self, cve_ids: &[&str]) -> Result<HashMap<String, EpssScore>> {
if cve_ids.is_empty() {
return Ok(HashMap::new());
}
let cve_param = cve_ids.join(",");
let url = format!("{}?cve={}", EPSS_API_URL, cve_param);
debug!("Fetching EPSS scores for {} CVEs", cve_ids.len());
let response = self.client.get(&url).send().await?;
if !response.status().is_success() {
return Err(AdvisoryError::source_fetch(
"EPSS",
format!("HTTP {}", response.status()),
));
}
let epss_response: EpssResponse = response.json().await?;
let scores: HashMap<String, EpssScore> = epss_response
.data
.into_iter()
.map(|s| (s.cve.clone(), s))
.collect();
debug!("Retrieved {} EPSS scores", scores.len());
Ok(scores)
}
pub async fn fetch_score(&self, cve_id: &str) -> Result<Option<EpssScore>> {
let scores = self.fetch_scores(&[cve_id]).await?;
Ok(scores.get(cve_id).cloned())
}
pub async fn fetch_high_risk(
&self,
min_epss: f64,
limit: Option<u32>,
) -> Result<Vec<EpssScore>> {
let limit = limit.unwrap_or(100);
let url = format!("{}?epss-gt={}&limit={}", EPSS_API_URL, min_epss, limit);
info!("Fetching CVEs with EPSS > {}", min_epss);
let response = self.client.get(&url).send().await?;
if !response.status().is_success() {
return Err(AdvisoryError::source_fetch(
"EPSS",
format!("HTTP {}", response.status()),
));
}
let epss_response: EpssResponse = response.json().await?;
info!("Found {} high-risk CVEs", epss_response.data.len());
Ok(epss_response.data)
}
pub async fn fetch_top_percentile(
&self,
min_percentile: f64,
limit: Option<u32>,
) -> Result<Vec<EpssScore>> {
let limit = limit.unwrap_or(100);
let url = format!(
"{}?percentile-gt={}&limit={}",
EPSS_API_URL, min_percentile, limit
);
info!(
"Fetching CVEs in top {} percentile",
(1.0 - min_percentile) * 100.0
);
let response = self.client.get(&url).send().await?;
if !response.status().is_success() {
return Err(AdvisoryError::source_fetch(
"EPSS",
format!("HTTP {}", response.status()),
));
}
let epss_response: EpssResponse = response.json().await?;
Ok(epss_response.data)
}
pub async fn fetch_scores_batch(
&self,
cve_ids: &[String],
batch_size: usize,
) -> Result<HashMap<String, EpssScore>> {
let mut all_scores = HashMap::new();
for chunk in cve_ids.chunks(batch_size) {
let refs: Vec<&str> = chunk.iter().map(|s| s.as_str()).collect();
let scores = self.fetch_scores(&refs).await?;
all_scores.extend(scores);
}
Ok(all_scores)
}
}
impl Default for EpssSource {
fn default() -> Self {
Self::new()
}
}
/// JSON envelope returned by the EPSS API.
///
/// Only `data` is consumed by `EpssSource`; the remaining fields are decoded
/// for callers that want the metadata.
#[derive(Debug, Clone, Deserialize)]
pub struct EpssResponse {
    /// Status string reported by the API (not checked by this module).
    pub status: String,
    /// Numeric status, read from the hyphenated `status-code` JSON key.
    #[serde(rename = "status-code")]
    pub status_code: Option<i32>,
    /// API version string, if present.
    pub version: Option<String>,
    /// Presumably the total number of matching records — TODO confirm against the API docs.
    pub total: Option<u64>,
    /// Pagination offset echoed by the API, if present.
    pub offset: Option<u64>,
    /// Page size echoed by the API, if present.
    pub limit: Option<u64>,
    /// Score records contained in this response.
    pub data: Vec<EpssScore>,
}
/// One EPSS record for a single CVE.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EpssScore {
    /// CVE identifier this record belongs to.
    pub cve: String,
    /// EPSS score; parsed from the string form the API sends into an `f64`.
    #[serde(deserialize_with = "deserialize_f64_from_string")]
    pub epss: f64,
    /// Percentile of the score; parsed from a string into an `f64` like `epss`.
    #[serde(deserialize_with = "deserialize_f64_from_string")]
    pub percentile: f64,
    /// Raw score date string; `date_utc` interprets it as `%Y-%m-%d`.
    /// Missing in the input decodes to `None`.
    #[serde(default)]
    pub date: Option<String>,
}
impl EpssScore {
    /// Reports whether this record's percentile reaches `threshold`
    /// (inclusive). Any NaN involved makes the comparison `false`.
    pub fn is_top_percentile(&self, threshold: f64) -> bool {
        threshold <= self.percentile
    }

    /// Maps the EPSS score onto a coarse risk bucket.
    ///
    /// Lower bounds are inclusive: `>= 0.7` critical, `>= 0.4` high,
    /// `>= 0.1` medium; anything else, NaN included, is low.
    pub fn risk_category(&self) -> EpssRiskCategory {
        let value = self.epss;
        if value >= 0.7 {
            EpssRiskCategory::Critical
        } else if value >= 0.4 {
            EpssRiskCategory::High
        } else if value >= 0.1 {
            EpssRiskCategory::Medium
        } else {
            EpssRiskCategory::Low
        }
    }

    /// Interprets `date` as a `%Y-%m-%d` calendar day and returns midnight UTC
    /// of that day. Yields `None` if the date is missing or does not parse.
    pub fn date_utc(&self) -> Option<DateTime<Utc>> {
        let text = self.date.as_deref()?;
        let day = NaiveDate::parse_from_str(text, "%Y-%m-%d").ok()?;
        day.and_hms_opt(0, 0, 0).map(|midnight| midnight.and_utc())
    }
}
/// Coarse risk bucket produced by `EpssScore::risk_category`.
///
/// Variants are declared from least to most severe, so the derived `Ord`
/// ranks them `Low < Medium < High < Critical`; `Hash` allows use as a map key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum EpssRiskCategory {
    /// EPSS score below 0.1 (or NaN).
    Low,
    /// EPSS score in `[0.1, 0.4)`.
    Medium,
    /// EPSS score in `[0.4, 0.7)`.
    High,
    /// EPSS score of at least 0.7.
    Critical,
}
fn deserialize_f64_from_string<'de, D>(deserializer: D) -> std::result::Result<f64, D::Error>
where
D: serde::Deserializer<'de>,
{
let s: String = String::deserialize(deserializer)?;
s.parse().map_err(serde::de::Error::custom)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an undated score record for `cve` with the given values.
    fn score(cve: &str, epss: f64, percentile: f64) -> EpssScore {
        EpssScore {
            cve: cve.to_string(),
            epss,
            percentile,
            date: None,
        }
    }

    /// Builds a score record carrying only the given raw date string.
    fn dated(date: Option<&str>) -> EpssScore {
        EpssScore {
            date: date.map(str::to_string),
            ..score("CVE-2024-0000", 0.0, 0.0)
        }
    }

    #[test]
    fn test_epss_risk_category() {
        let s = score("CVE-2024-1234", 0.75, 0.98);
        assert_eq!(s.risk_category(), EpssRiskCategory::Critical);
        assert!(s.is_top_percentile(0.95));
    }

    #[test]
    fn test_epss_low_risk() {
        let s = score("CVE-2024-5678", 0.05, 0.3);
        assert_eq!(s.risk_category(), EpssRiskCategory::Low);
        assert!(!s.is_top_percentile(0.95));
    }

    #[test]
    fn test_risk_category_boundaries_are_inclusive() {
        let cases = [
            (1.0, EpssRiskCategory::Critical),
            (0.7, EpssRiskCategory::Critical),
            (0.69, EpssRiskCategory::High),
            (0.4, EpssRiskCategory::High),
            (0.39, EpssRiskCategory::Medium),
            (0.1, EpssRiskCategory::Medium),
            (0.09, EpssRiskCategory::Low),
            (0.0, EpssRiskCategory::Low),
        ];
        for (epss, expected) in cases {
            assert_eq!(
                score("CVE-2024-0001", epss, 0.5).risk_category(),
                expected,
                "epss = {}",
                epss
            );
        }
    }

    #[test]
    fn test_nan_epss_is_low() {
        let s = score("CVE-2024-0002", f64::NAN, f64::NAN);
        assert_eq!(s.risk_category(), EpssRiskCategory::Low);
        assert!(!s.is_top_percentile(0.5));
    }

    #[test]
    fn test_is_top_percentile_is_inclusive() {
        assert!(score("CVE-2024-0003", 0.2, 0.95).is_top_percentile(0.95));
        assert!(!score("CVE-2024-0003", 0.2, 0.9499).is_top_percentile(0.95));
    }

    #[test]
    fn test_date_utc() {
        // 2024-03-01T00:00:00Z
        assert_eq!(
            dated(Some("2024-03-01")).date_utc().map(|d| d.timestamp()),
            Some(1_709_251_200)
        );
        assert_eq!(dated(Some("03/01/2024")).date_utc(), None);
        assert_eq!(dated(None).date_utc(), None);
    }
}