use crate::error::{Error, Result};
use reqwest::Client;
use serde::{Deserialize, Serialize};
use std::str::FromStr;
use std::time::Duration;
use crate::http::{TIMEOUT_SECS, USER_AGENT};
/// Base URL of the WPVulnerability API. `VulnerabilityClient` appends the
/// endpoint path (`/core/<version>/`, `/plugin/<slug>/`, `/theme/<slug>/`).
const WPVULN_API: &str = "https://www.wpvulnerability.net";
/// Severity rating attached to a [`Vulnerability`].
///
/// Variants are declared from least to most severe, so the derived `Ord`
/// ranks `Low < Medium < High < Critical`; `VulnerabilityReport::max_severity`
/// relies on this ordering. Do not reorder the variants. Serialized in
/// lowercase (`"low"`, `"medium"`, `"high"`, `"critical"`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize, Default)]
#[serde(rename_all = "lowercase")]
pub enum Severity {
    /// Least severe; also the `Default` value.
    #[default]
    Low,
    Medium,
    High,
    /// Most severe.
    Critical,
}
impl Severity {
pub fn from_cvss(score: f32) -> Self {
match score {
s if s >= 9.0 => Severity::Critical,
s if s >= 7.0 => Severity::High,
s if s >= 4.0 => Severity::Medium,
_ => Severity::Low,
}
}
}
impl FromStr for Severity {
type Err = Error;
fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
match s.to_lowercase().as_str() {
"low" => Ok(Self::Low),
"medium" => Ok(Self::Medium),
"high" => Ok(Self::High),
"critical" => Ok(Self::Critical),
_ => Err(Error::InvalidSeverity(s.to_string())),
}
}
}
impl std::fmt::Display for Severity {
    /// Writes the capitalised severity name: `Low`, `Medium`, `High` or
    /// `Critical`.
    ///
    /// The name goes through `Formatter::pad`, so width, fill, alignment and
    /// precision specifiers are honoured (e.g. `{:<8}` for tabular output).
    /// Writing the literal with `write!` would silently ignore them.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match self {
            Severity::Low => "Low",
            Severity::Medium => "Medium",
            Severity::High => "High",
            Severity::Critical => "Critical",
        };
        f.pad(label)
    }
}
/// A single known vulnerability for WordPress core, a plugin, or a theme.
///
/// Serde serializes fields in declaration order, so reordering them changes
/// the serialized output.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Vulnerability {
    /// When built from API data, this is the first source id starting with
    /// `CVE-`, or the API entry's UUID when no such id exists.
    pub id: String,
    /// Human-readable title of the vulnerability.
    pub title: String,
    /// Severity bucket. When built from API data, it is derived from
    /// `cvss_score`, falling back to `Medium` when no score is available.
    pub severity: Severity,
    /// CVSS score reported by the API, if any.
    pub cvss_score: Option<f32>,
    /// Highest affected version, compared inclusively when filtering.
    /// `None` means no upper bound is known, so every version counts as
    /// affected.
    pub affected_max: Option<String>,
    /// Version that fixes the issue, if known.
    // NOTE(review): `VulnerabilityClient` in this module always sets this to `None`.
    pub fixed_in: Option<String>,
    /// Reference links collected from the entry's sources.
    pub references: Vec<String>,
}
/// Collection of vulnerabilities returned by one lookup (core version,
/// plugin, or theme).
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct VulnerabilityReport {
    /// Vulnerabilities in the order the API returned them.
    pub vulnerabilities: Vec<Vulnerability>,
}
impl VulnerabilityReport {
    /// Returns `true` when the report lists no vulnerabilities.
    pub fn is_empty(&self) -> bool {
        self.vulnerabilities.is_empty()
    }

    /// Returns the highest severity in the report, or `None` when it is empty.
    pub fn max_severity(&self) -> Option<Severity> {
        self.vulnerabilities
            .iter()
            .map(|vuln| vuln.severity)
            .reduce(std::cmp::max)
    }

    /// Counts the vulnerabilities whose severity equals `severity` exactly.
    pub fn count_by_severity(&self, severity: Severity) -> usize {
        self.vulnerabilities
            .iter()
            .fold(0, |tally, vuln| tally + usize::from(vuln.severity == severity))
    }

    /// Returns a copy of the report restricted to entries affecting `version`.
    ///
    /// With `None` the full report is cloned unchanged. Otherwise the kept
    /// entries are those for which `version_is_affected` returns `true`, in
    /// their original order.
    pub fn filter_by_version(&self, version: Option<&str>) -> Self {
        match version {
            None => self.clone(),
            Some(installed) => {
                let mut kept = Vec::new();
                for vuln in &self.vulnerabilities {
                    if version_is_affected(installed, vuln.affected_max.as_deref()) {
                        kept.push(vuln.clone());
                    }
                }
                Self {
                    vulnerabilities: kept,
                }
            }
        }
    }
}
/// Reports whether `installed` falls within the range capped by
/// `affected_max` (inclusive).
///
/// A missing `affected_max` means no upper bound is known, so every installed
/// version is treated as affected.
///
/// Versions are compared numerically, component by component:
/// - `5.10` sorts after `5.9`, and `10.0` sorts after `9.9.9`.
/// - Missing trailing components count as zero, so `6.4` equals `6.4.0`.
///
/// WordPress core and plugin versions are frequently not three-part SemVer
/// (e.g. `6.4`), and a plain string comparison would order them wrongly.
/// As in SemVer, a `-suffix` pre-release sorts before its release
/// (`6.5-RC1 < 6.5`), and `+build` metadata is ignored. When either version
/// has a non-numeric core, the plain string comparison is used as a last
/// resort.
fn version_is_affected(installed: &str, affected_max: Option<&str>) -> bool {
    let Some(max) = affected_max else {
        return true;
    };
    match compare_versions(installed, max) {
        Some(ordering) => ordering.is_le(),
        None => installed <= max,
    }
}

/// Compares two version strings numerically (see `version_is_affected`).
///
/// Returns `None` when either numeric core is not made up solely of
/// dot-separated non-negative integers, so the caller can fall back.
fn compare_versions(a: &str, b: &str) -> Option<std::cmp::Ordering> {
    use std::cmp::Ordering;

    let (a_core, a_pre) = split_version(a);
    let (b_core, b_pre) = split_version(b);
    let a_parts = parse_numeric_core(a_core)?;
    let b_parts = parse_numeric_core(b_core)?;

    // Missing trailing components count as zero, so `6.4` == `6.4.0`.
    let width = a_parts.len().max(b_parts.len());
    let core_order = (0..width)
        .map(|i| {
            let lhs = a_parts.get(i).copied().unwrap_or(0);
            let rhs = b_parts.get(i).copied().unwrap_or(0);
            lhs.cmp(&rhs)
        })
        .find(|ord| ord.is_ne())
        .unwrap_or(Ordering::Equal);

    // Equal cores: a pre-release precedes the corresponding release.
    Some(core_order.then_with(|| match (a_pre, b_pre) {
        (None, None) => Ordering::Equal,
        (Some(_), None) => Ordering::Less,
        (None, Some(_)) => Ordering::Greater,
        (Some(lhs), Some(rhs)) => compare_pre_release(lhs, rhs),
    }))
}

/// Splits a trimmed version into its numeric core and optional pre-release
/// tag, discarding any `+build` metadata.
fn split_version(version: &str) -> (&str, Option<&str>) {
    let version = version.trim();
    let without_build = version.split_once('+').map_or(version, |(head, _)| head);
    match without_build.split_once('-') {
        Some((core, pre)) => (core, Some(pre)),
        None => (without_build, None),
    }
}

/// Parses dotted numeric components (`"6.4.2"` -> `[6, 4, 2]`). Returns
/// `None` if any component is empty or not a non-negative integer.
fn parse_numeric_core(core: &str) -> Option<Vec<u64>> {
    core.split('.').map(|part| part.parse::<u64>().ok()).collect()
}

/// Orders two pre-release tags by SemVer precedence.
///
/// - Dot-separated identifiers are compared left to right.
/// - Numeric identifiers compare numerically and sort below alphanumeric ones.
/// - Alphanumeric identifiers compare in ASCII order.
/// - A tag that is a prefix of a longer one sorts first.
fn compare_pre_release(a: &str, b: &str) -> std::cmp::Ordering {
    use std::cmp::Ordering;

    let mut left = a.split('.');
    let mut right = b.split('.');
    loop {
        let ord = match (left.next(), right.next()) {
            (None, None) => return Ordering::Equal,
            (None, Some(_)) => return Ordering::Less,
            (Some(_), None) => return Ordering::Greater,
            (Some(x), Some(y)) => match (x.parse::<u64>(), y.parse::<u64>()) {
                (Ok(m), Ok(n)) => m.cmp(&n),
                (Ok(_), Err(_)) => Ordering::Less,
                (Err(_), Ok(_)) => Ordering::Greater,
                (Err(_), Err(_)) => x.cmp(y),
            },
        };
        if ord.is_ne() {
            return ord;
        }
    }
}
/// Top-level JSON envelope that `fetch_vulns` decodes for every
/// core/plugin/theme lookup.
#[derive(Debug, Deserialize)]
struct WpVulnApiResponse {
    /// API status flag; `fetch_vulns` treats any non-zero value as a failed
    /// lookup.
    error: i32,
    // Deserialized but not read anywhere in this module.
    #[allow(dead_code)]
    message: Option<String>,
    /// Lookup payload; `fetch_vulns` treats `None` as a failed lookup.
    data: Option<WpVulnData>,
}
/// `data` payload of an API response.
#[derive(Debug, Deserialize)]
struct WpVulnData {
    // Deserialized but not read anywhere in this module.
    #[allow(dead_code)]
    name: Option<String>,
    /// Raw vulnerability entries. `fetch_vulns` treats a missing or `null`
    /// list as empty.
    vulnerability: Option<Vec<WpVulnEntry>>,
}
/// One raw vulnerability entry; `convert_entry` turns it into a
/// [`Vulnerability`].
#[derive(Debug, Deserialize)]
struct WpVulnEntry {
    /// API identifier; used as the vulnerability id when no `CVE-` source id
    /// exists.
    uuid: String,
    /// Entry title; preferred over any source description for the title.
    name: Option<String>,
    /// Affected version bounds.
    operator: Option<WpVulnOperator>,
    /// Sources whose ids, links and descriptions feed the converted
    /// vulnerability.
    source: Option<Vec<WpVulnSource>>,
    /// Impact metrics, including CVSS.
    impact: Option<WpVulnImpact>,
}
/// Affected version bounds of an entry.
#[derive(Debug, Deserialize)]
struct WpVulnOperator {
    /// Upper bound, copied into `Vulnerability::affected_max`.
    max_version: Option<String>,
    // Deserialized but not read: only the upper bound is used when filtering.
    #[allow(dead_code)]
    min_version: Option<String>,
}
/// One source attached to an entry.
#[derive(Debug, Deserialize)]
struct WpVulnSource {
    /// Source identifier (JSON key `id`). The first id starting with `CVE-`
    /// becomes the vulnerability id.
    #[serde(rename = "id")]
    source_id: Option<String>,
    // Deserialized but not read anywhere in this module.
    #[allow(dead_code)]
    name: Option<String>,
    /// URL collected into `Vulnerability::references`.
    link: Option<String>,
    /// Free-text description. Only the first source's description is used,
    /// as a fallback title.
    description: Option<String>,
}
/// Impact metrics of an entry. The container-level `#[serde(default)]` lets
/// any field be absent from the JSON object.
#[derive(Debug, Deserialize, Default)]
#[serde(default)]
struct WpVulnImpact {
    /// CVSS data, if present.
    cvss: Option<WpVulnCvss>,
    // Deserialized but not read anywhere in this module.
    #[allow(dead_code)]
    cwe: Option<Vec<WpVulnCwe>>,
}
/// CVSS data from an entry's `impact` object.
#[derive(Debug, Deserialize)]
struct WpVulnCvss {
    /// CVSS score, given as a JSON number or as a numeric string.
    ///
    /// `default` is required here. With a custom `deserialize_with`, serde no
    /// longer treats a missing `Option` field as `None`; it reports "missing
    /// field `score`" instead. That error would abort decoding of the whole
    /// API response, and `fetch_vulns` would return `None` for the component.
    #[serde(default, deserialize_with = "deserialize_score")]
    score: Option<f32>,
}
/// Deserializes a CVSS score given either as a JSON number or as a string.
///
/// Each of the following yields `None` rather than an error:
/// - `null`;
/// - a blank or unparseable string;
/// - a non-finite value (e.g. `"NaN"`, `"inf"`).
///
/// A hard error here would fail decoding of the entire API response, so
/// `fetch_vulns` would discard every vulnerability for that component over
/// one malformed score. A NaN score would otherwise fail every threshold in
/// `Severity::from_cvss` and be silently ranked `Low`. JSON values that are
/// neither strings nor numbers (objects, booleans, arrays) are still an error.
fn deserialize_score<'de, D>(deserializer: D) -> std::result::Result<Option<f32>, D::Error>
where
    D: serde::Deserializer<'de>,
{
    #[derive(Deserialize)]
    #[serde(untagged)]
    enum StringOrNumber {
        String(String),
        Number(f32),
    }

    let score = match Option::<StringOrNumber>::deserialize(deserializer)? {
        Some(StringOrNumber::String(s)) => s.trim().parse::<f32>().ok(),
        Some(StringOrNumber::Number(n)) => Some(n),
        None => None,
    };
    Ok(score.filter(|value| value.is_finite()))
}
/// One element of an impact's `cwe` list.
#[derive(Debug, Deserialize)]
struct WpVulnCwe {
    // Deserialized but not read anywhere in this module.
    #[allow(dead_code)]
    cwe: Option<String>,
}
/// Async client for the WPVulnerability API (`WPVULN_API`). It looks up known
/// vulnerabilities for WordPress core versions, plugins and themes.
pub struct VulnerabilityClient {
    /// HTTP client built in `new` with the crate's `USER_AGENT` and
    /// `TIMEOUT_SECS` timeout.
    client: Client,
}
impl VulnerabilityClient {
pub fn new() -> Result<Self> {
let client = Client::builder()
.user_agent(USER_AGENT)
.timeout(Duration::from_secs(TIMEOUT_SECS))
.build()
.map_err(|e| Error::HttpClient(e.to_string()))?;
Ok(Self { client })
}
pub async fn fetch_core_vulns(&self, version: &str) -> Option<VulnerabilityReport> {
let url = format!("{}/core/{}/", WPVULN_API, version);
self.fetch_vulns(&url).await
}
pub async fn fetch_plugin_vulns(&self, slug: &str) -> Option<VulnerabilityReport> {
let encoded_slug = urlencoding::encode(slug);
let url = format!("{}/plugin/{}/", WPVULN_API, encoded_slug);
self.fetch_vulns(&url).await
}
pub async fn fetch_theme_vulns(&self, slug: &str) -> Option<VulnerabilityReport> {
let encoded_slug = urlencoding::encode(slug);
let url = format!("{}/theme/{}/", WPVULN_API, encoded_slug);
self.fetch_vulns(&url).await
}
async fn fetch_vulns(&self, url: &str) -> Option<VulnerabilityReport> {
let response = self.client.get(url).send().await.ok()?;
let body = response.text().await.ok()?;
let api_response: WpVulnApiResponse = serde_json::from_str(&body).ok()?;
if api_response.error != 0 {
return None;
}
let data = api_response.data?;
let vulns = data.vulnerability.unwrap_or_default();
let vulnerabilities: Vec<Vulnerability> = vulns
.into_iter()
.map(|entry| self.convert_entry(entry))
.collect();
Some(VulnerabilityReport { vulnerabilities })
}
fn convert_entry(&self, entry: WpVulnEntry) -> Vulnerability {
let cvss_score = entry.impact.and_then(|i| i.cvss).and_then(|c| c.score);
let severity = cvss_score
.map(Severity::from_cvss)
.unwrap_or(Severity::Medium);
let affected_max = entry.operator.and_then(|o| o.max_version);
let (id, references, title) = if let Some(ref sources) = entry.source {
let cve = sources
.iter()
.find(|s| {
s.source_id
.as_ref()
.is_some_and(|id| id.starts_with("CVE-"))
})
.and_then(|s| s.source_id.clone());
let refs: Vec<String> = sources.iter().filter_map(|s| s.link.clone()).collect();
let desc = entry.name.clone().unwrap_or_else(|| {
sources
.first()
.and_then(|s| s.description.clone())
.unwrap_or_else(|| "Unknown vulnerability".to_string())
});
(cve.unwrap_or_else(|| entry.uuid.clone()), refs, desc)
} else {
let desc = entry
.name
.unwrap_or_else(|| "Unknown vulnerability".to_string());
(entry.uuid, Vec::new(), desc)
};
Vulnerability {
id,
title,
severity,
cvss_score,
affected_max,
fixed_in: None, references,
}
}
}
impl Default for VulnerabilityClient {
    /// Equivalent to [`VulnerabilityClient::new`], but panics on failure.
    ///
    /// # Panics
    ///
    /// Panics if the underlying HTTP client cannot be built.
    fn default() -> Self {
        match Self::new() {
            Ok(client) => client,
            Err(err) => panic!("Failed to create vulnerability client: {err:?}"),
        }
    }
}