use chrono::{DateTime, Utc};
use regex::Regex;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
/// Qualitative severity band of a CVE.
///
/// Variants are declared from least to most severe. The derived `PartialOrd`/`Ord`
/// follow declaration order, so `Critical` compares greatest. Keep this order:
/// `VulnerabilityReport::sort_by_severity` relies on it.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq, PartialOrd, Ord)]
pub enum CVESeverity {
    /// No severity. `from_cvss` maps scores of 0.0 or below (and NaN) here.
    None,
    /// `from_cvss`: score above 0.0 and below 4.0.
    Low,
    /// `from_cvss`: score in [4.0, 7.0).
    Medium,
    /// `from_cvss`: score in [7.0, 9.0).
    High,
    /// `from_cvss`: score of 9.0 or above.
    Critical,
}
impl CVESeverity {
pub fn from_cvss(score: f32) -> Self {
match score {
s if s >= 9.0 => CVESeverity::Critical,
s if s >= 7.0 => CVESeverity::High,
s if s >= 4.0 => CVESeverity::Medium,
s if s > 0.0 => CVESeverity::Low,
_ => CVESeverity::None,
}
}
}
/// A single vulnerability record, as stored in `VulnerabilityDatabase`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CVE {
    /// CVE identifier, e.g. `CVE-2023-38408`; also the key the database stores it under.
    pub id: String,
    /// Severity band. Supplied by the caller of `CVE::new`; it is not recomputed
    /// from `cvss_score`.
    pub severity: CVESeverity,
    /// CVSS base score; a report adds this to its `risk_score` per finding.
    pub cvss_score: f32,
    /// Human-readable summary of the vulnerability.
    pub description: String,
    /// Product names this CVE applies to. The database indexes them lowercased,
    /// so product lookups are case-insensitive.
    pub affected_products: Vec<String>,
    /// Version strings considered vulnerable; checked by `CVE::affects_version`.
    pub affected_versions: Vec<String>,
    /// Publication date, if known (`None` after `CVE::new`).
    pub published_date: Option<DateTime<Utc>>,
    /// Further-reading entries (presumably advisory URLs, TODO confirm); empty
    /// after `CVE::new`.
    pub references: Vec<String>,
}
impl CVE {
    /// Creates a CVE record with empty product, version and reference lists and
    /// no publication date.
    pub fn new(id: &str, severity: CVESeverity, cvss_score: f32, description: &str) -> Self {
        Self {
            id: id.to_string(),
            severity,
            cvss_score,
            description: description.to_string(),
            affected_products: Vec::new(),
            affected_versions: Vec::new(),
            published_date: None,
            references: Vec::new(),
        }
    }

    /// Returns `true` if `version` matches any entry in `affected_versions`.
    ///
    /// Either string may be a prefix of the other. This lets a coarse detected
    /// version (`"2.4"`) match a precise affected one (`"2.4.55"`), and lets an
    /// affected release (`"9.2"`) match its suffixed builds (`"9.2p1"`).
    ///
    /// A prefix match is rejected when the cut would split a number: `"9.1"` does
    /// not match `"9.10"`, and `"2.4.5"` does not match `"2.4.55"`. Matches in the
    /// middle of a string (`"4.55"` vs `"2.4.55"`) are not accepted. An empty or
    /// whitespace-only `version` never matches.
    pub fn affects_version(&self, version: &str) -> bool {
        let version = version.trim();
        if version.is_empty() {
            // Previously `v.contains("")` made an empty version match everything.
            return false;
        }
        self.affected_versions
            .iter()
            .map(|affected| affected.trim())
            .filter(|affected| !affected.is_empty())
            .any(|affected| {
                Self::is_version_prefix(affected, version)
                    || Self::is_version_prefix(version, affected)
            })
    }

    /// Returns `true` if `product` (compared case-insensitively) contains any
    /// non-blank entry of `affected_products` as a substring.
    pub fn affects_product(&self, product: &str) -> bool {
        let product_lower = product.to_lowercase();
        self.affected_products
            .iter()
            // A blank entry would be contained in every product name.
            .filter(|p| !p.trim().is_empty())
            .any(|p| product_lower.contains(&p.to_lowercase()))
    }

    /// Returns `true` if `prefix` is a prefix of `full` and the cut does not fall
    /// between two adjacent ASCII digits (i.e. does not split a number).
    fn is_version_prefix(prefix: &str, full: &str) -> bool {
        full.strip_prefix(prefix).map_or(false, |rest| {
            let cuts_number = prefix.ends_with(|c: char| c.is_ascii_digit())
                && rest.starts_with(|c: char| c.is_ascii_digit());
            !cuts_number
        })
    }
}
/// In-memory CVE store with a lowercase product-name index for fast lookup.
pub struct VulnerabilityDatabase {
    /// All known CVEs, keyed by CVE id.
    cves: HashMap<String, CVE>,
    /// Lowercased product name -> ids of the CVEs that list that product.
    product_index: HashMap<String, Vec<String>>,
}
impl VulnerabilityDatabase {
    /// Creates a database pre-seeded with the built-in CVE records.
    pub fn new() -> Self {
        let mut db = Self {
            cves: HashMap::new(),
            product_index: HashMap::new(),
        };
        db.load_default_cves();
        db
    }

    /// Inserts the small built-in set of sample CVE records.
    fn load_default_cves(&mut self) {
        let mut ssh_cve = CVE::new(
            "CVE-2023-38408",
            CVESeverity::High,
            7.5,
            "OpenSSH before 9.3p2 allows PKCS#11-hosted keys to be used without authorization",
        );
        ssh_cve.affected_products = vec!["openssh".to_string()];
        ssh_cve.affected_versions = vec!["9.3p1".to_string(), "9.2".to_string(), "9.1".to_string()];
        self.add_cve(ssh_cve);
        let mut apache_cve = CVE::new(
            "CVE-2023-25690",
            CVESeverity::Critical,
            9.8,
            "Apache HTTP Server mod_proxy HTTP request smuggling vulnerability",
        );
        apache_cve.affected_products = vec!["apache".to_string(), "httpd".to_string()];
        apache_cve.affected_versions = vec!["2.4.55".to_string(), "2.4.54".to_string()];
        self.add_cve(apache_cve);
        let mut nginx_cve = CVE::new(
            "CVE-2022-41741",
            CVESeverity::High,
            7.8,
            "NGINX ngx_http_mp4_module vulnerability allows local code execution",
        );
        nginx_cve.affected_products = vec!["nginx".to_string()];
        nginx_cve.affected_versions = vec!["1.23.1".to_string(), "1.22.0".to_string()];
        self.add_cve(nginx_cve);
        let mut mysql_cve = CVE::new(
            "CVE-2023-21980",
            CVESeverity::Medium,
            6.5,
            "MySQL Server authentication bypass vulnerability",
        );
        mysql_cve.affected_products = vec!["mysql".to_string()];
        mysql_cve.affected_versions = vec!["8.0.32".to_string(), "8.0.31".to_string()];
        self.add_cve(mysql_cve);
        let mut postgres_cve = CVE::new(
            "CVE-2023-2454",
            CVESeverity::High,
            8.8,
            "PostgreSQL allows privilege escalation through CREATE SCHEMA ... AUTHORIZATION",
        );
        postgres_cve.affected_products = vec!["postgresql".to_string(), "postgres".to_string()];
        postgres_cve.affected_versions = vec!["15.2".to_string(), "14.7".to_string()];
        self.add_cve(postgres_cve);
    }

    /// Inserts `cve`, replacing any existing record with the same id, and
    /// indexes it under each of its lowercased product names.
    ///
    /// Each id appears at most once per product key. When a record is replaced,
    /// the old record's index entries are removed first, so lookups never return
    /// duplicates or products the new record no longer lists.
    pub fn add_cve(&mut self, cve: CVE) {
        let cve_id = cve.id.clone();
        if let Some(previous) = self.cves.remove(&cve_id) {
            self.unindex(&previous);
        }
        for product in &cve.affected_products {
            let ids = self
                .product_index
                .entry(product.to_lowercase())
                .or_default();
            // The same product may be listed twice with different casing.
            if !ids.contains(&cve_id) {
                ids.push(cve_id.clone());
            }
        }
        self.cves.insert(cve_id, cve);
    }

    /// Removes `cve.id` from the index entry of every product `cve` lists,
    /// dropping entries that become empty.
    fn unindex(&mut self, cve: &CVE) {
        for product in &cve.affected_products {
            let key = product.to_lowercase();
            let now_empty = match self.product_index.get_mut(&key) {
                Some(ids) => {
                    ids.retain(|id| id != &cve.id);
                    ids.is_empty()
                }
                None => false,
            };
            if now_empty {
                self.product_index.remove(&key);
            }
        }
    }

    /// Looks up a CVE by its exact id.
    pub fn get_cve(&self, id: &str) -> Option<&CVE> {
        self.cves.get(id)
    }

    /// Returns every CVE indexed under `product` (exact name, case-insensitive).
    /// Returns an empty vector for unknown products.
    pub fn find_by_product(&self, product: &str) -> Vec<&CVE> {
        let product_lower = product.to_lowercase();
        self.product_index
            .get(&product_lower)
            .map(|ids| ids.iter().filter_map(|id| self.cves.get(id)).collect())
            .unwrap_or_default()
    }

    /// Like `find_by_product`, keeping only CVEs whose `affects_version`
    /// accepts `version`.
    pub fn find_by_product_version(&self, product: &str, version: &str) -> Vec<&CVE> {
        self.find_by_product(product)
            .into_iter()
            .filter(|cve| cve.affects_version(version))
            .collect()
    }
}
impl Default for VulnerabilityDatabase {
fn default() -> Self {
Self::new()
}
}
/// One CVE matched against one observed service.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct VulnerabilityFinding {
    /// The matched CVE record (an owned copy from the database).
    pub cve: CVE,
    /// Port the service was observed on.
    pub port: u16,
    /// Service name exactly as passed to the scanner.
    pub service: String,
    /// Version extracted from the banner; `None` when no banner was given or no
    /// version could be extracted.
    pub version: Option<String>,
    /// Heuristic match confidence. The scanner in this file uses 0.9 when a
    /// version was extracted and 0.5 otherwise.
    pub confidence: f32,
    /// Attack-vector label; the scanner in this file always sets "Network".
    pub exploitability: String,
    /// Human-readable remediation advice.
    pub remediation: String,
}
/// Aggregated scan results for a single target.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct VulnerabilityReport {
    /// Target identifier (e.g. a host address) the report describes.
    pub target: String,
    /// Report creation time (`Utc::now()` at `VulnerabilityReport::new`).
    pub scan_time: DateTime<Utc>,
    /// All findings, in insertion order until `sort_by_severity` is called.
    pub findings: Vec<VulnerabilityFinding>,
    /// Sum of `cvss_score` over every added finding (not normalized).
    pub risk_score: f32,
    /// Number of `Critical` findings.
    pub critical_count: usize,
    /// Number of `High` findings.
    pub high_count: usize,
    /// Number of `Medium` findings.
    pub medium_count: usize,
    /// Number of `Low` findings. `CVESeverity::None` findings are not counted anywhere.
    pub low_count: usize,
}
impl VulnerabilityReport {
    /// Starts an empty report for `target`, timestamped with the current UTC time.
    pub fn new(target: &str) -> Self {
        Self {
            target: String::from(target),
            scan_time: Utc::now(),
            findings: Vec::new(),
            risk_score: 0.0,
            critical_count: 0,
            high_count: 0,
            medium_count: 0,
            low_count: 0,
        }
    }

    /// Records `finding`.
    ///
    /// Increments the counter for its severity (`CVESeverity::None` is not
    /// counted), adds its CVSS score to `risk_score`, and appends it to `findings`.
    pub fn add_finding(&mut self, finding: VulnerabilityFinding) {
        let bucket = match finding.cve.severity {
            CVESeverity::Critical => Some(&mut self.critical_count),
            CVESeverity::High => Some(&mut self.high_count),
            CVESeverity::Medium => Some(&mut self.medium_count),
            CVESeverity::Low => Some(&mut self.low_count),
            CVESeverity::None => None,
        };
        if let Some(count) = bucket {
            *count += 1;
        }
        self.risk_score += finding.cve.cvss_score;
        self.findings.push(finding);
    }

    /// One-line, human-readable overview of the per-severity counts and risk score.
    pub fn summary(&self) -> String {
        let Self {
            target,
            critical_count,
            high_count,
            medium_count,
            low_count,
            risk_score,
            ..
        } = self;
        format!(
            "Target: {} | Critical: {} | High: {} | Medium: {} | Low: {} | Risk Score: {:.1}",
            target, critical_count, high_count, medium_count, low_count, risk_score
        )
    }

    /// Serializes the whole report as pretty-printed JSON.
    pub fn to_json(&self) -> Result<String, serde_json::Error> {
        let json = serde_json::to_string_pretty(self);
        json
    }

    /// Orders findings from most to least severe. The sort is stable, so
    /// findings of equal severity keep their insertion order.
    pub fn sort_by_severity(&mut self) {
        self.findings
            .sort_by_key(|finding| std::cmp::Reverse(finding.cve.severity));
    }
}
/// Matches service banners against the built-in vulnerability database.
pub struct VulnerabilityScanner {
    database: VulnerabilityDatabase,
    /// Lowercase service/product name -> regex whose first capture group is the version.
    version_patterns: HashMap<String, Regex>,
    /// Fallback used when no service-specific pattern matches: the first
    /// `N.N` or `N.N.N` run of digits in the banner.
    generic_version: Regex,
}

impl VulnerabilityScanner {
    /// Creates a scanner backed by the default database and the built-in banner patterns.
    pub fn new() -> Self {
        let mut scanner = Self {
            database: VulnerabilityDatabase::new(),
            version_patterns: HashMap::new(),
            // Compiled once here instead of on every `extract_version` call.
            generic_version: Regex::new(r"([\d]+\.[\d]+(?:\.[\d]+)?)")
                .expect("generic version regex is a valid constant"),
        };
        scanner.load_version_patterns();
        scanner
    }

    /// Registers the service-specific banner patterns.
    ///
    /// Each pattern is registered under every name it may be looked up by: the
    /// short service name (e.g. `ssh`, `apache`) and the product names the
    /// database indexes (`openssh`, `httpd`, `postgres`).
    ///
    /// Without the aliases, a service reported as `openssh` fell back to the
    /// generic pattern. That pattern drops suffixes such as `p1`, and on a raw
    /// `SSH-2.0-OpenSSH_9.3p1` banner it captures the protocol version `2.0`.
    fn load_version_patterns(&mut self) {
        const PATTERNS: &[(&[&str], &str)] = &[
            (&["ssh", "openssh"], r"(?i)openssh[_\s]*([\d.p]+)"),
            (&["apache", "httpd"], r"(?i)apache[/\s]*([\d.]+)"),
            (&["nginx"], r"(?i)nginx[/\s]*([\d.]+)"),
            (&["mysql"], r"(?i)mysql[/\s]*([\d.]+)"),
            (&["postgresql", "postgres"], r"(?i)postgres(?:ql)?[/\s]*([\d.]+)"),
        ];
        for (names, pattern) in PATTERNS {
            let regex = Regex::new(pattern).expect("built-in version regex is a valid constant");
            for name in names.iter() {
                self.version_patterns
                    .insert((*name).to_string(), regex.clone());
            }
        }
    }

    /// Extracts a version string for `service` from `banner`.
    ///
    /// Tries the service-specific pattern first, looked up by lowercased service
    /// name. If there is none, or it does not match, falls back to the first
    /// `major.minor[.patch]` run in the banner. Returns `None` when neither matches.
    pub fn extract_version(&self, service: &str, banner: &str) -> Option<String> {
        let specific = self
            .version_patterns
            .get(&service.to_lowercase())
            .and_then(|pattern| pattern.captures(banner))
            .and_then(|caps| caps.get(1))
            .map(|m| m.as_str().to_string());
        specific.or_else(|| {
            self.generic_version
                .captures(banner)
                .and_then(|caps| caps.get(1))
                .map(|m| m.as_str().to_string())
        })
    }

    /// Produces one finding per CVE that applies to `service` on `port`.
    ///
    /// With an extractable version, CVEs are filtered by version (confidence 0.9).
    /// Otherwise every CVE indexed for the product is reported (confidence 0.5).
    pub fn scan_service(
        &self,
        port: u16,
        service: &str,
        banner: Option<&str>,
    ) -> Vec<VulnerabilityFinding> {
        let version = banner.and_then(|b| self.extract_version(service, b));
        let (cves, confidence) = match version.as_deref() {
            Some(v) => (self.database.find_by_product_version(service, v), 0.9),
            None => (self.database.find_by_product(service), 0.5),
        };
        cves.into_iter()
            .map(|cve| VulnerabilityFinding {
                cve: cve.clone(),
                port,
                service: service.to_string(),
                version: version.clone(),
                confidence,
                exploitability: "Network".to_string(),
                remediation: format!(
                    "Update {} to the latest patched version. See {} references for details.",
                    service, cve.id
                ),
            })
            .collect()
    }

    /// Scans every `(port, service, banner)` triple and returns a report for
    /// `target` with findings sorted from most to least severe.
    pub fn generate_report(
        &self,
        target: &str,
        services: &[(u16, String, Option<String>)],
    ) -> VulnerabilityReport {
        let mut report = VulnerabilityReport::new(target);
        for (port, service, banner) in services {
            for finding in self.scan_service(*port, service, banner.as_deref()) {
                report.add_finding(finding);
            }
        }
        report.sort_by_severity();
        report
    }
}
impl Default for VulnerabilityScanner {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_cve_severity_from_cvss() {
        assert_eq!(CVESeverity::from_cvss(9.5), CVESeverity::Critical);
        assert_eq!(CVESeverity::from_cvss(7.5), CVESeverity::High);
        assert_eq!(CVESeverity::from_cvss(5.0), CVESeverity::Medium);
        assert_eq!(CVESeverity::from_cvss(2.0), CVESeverity::Low);
        assert_eq!(CVESeverity::from_cvss(0.0), CVESeverity::None);
    }

    #[test]
    fn test_database_lookup() {
        let db = VulnerabilityDatabase::new();
        let cves = db.find_by_product("openssh");
        assert!(!cves.is_empty());
        let cves = db.find_by_product("apache");
        assert!(!cves.is_empty());
        // Product lookups are case-insensitive; unknown products yield nothing.
        assert_eq!(db.find_by_product("OpenSSH").len(), 1);
        assert!(db.find_by_product("lighttpd").is_empty());
        assert!(db.get_cve("CVE-2023-25690").is_some());
    }

    #[test]
    fn test_version_extraction() {
        let scanner = VulnerabilityScanner::new();
        let version = scanner.extract_version("ssh", "OpenSSH_8.9p1 Ubuntu-3ubuntu0.1");
        assert_eq!(version, Some("8.9p1".to_string()));
        let version = scanner.extract_version("nginx", "nginx/1.18.0");
        assert_eq!(version, Some("1.18.0".to_string()));
        let version = scanner.extract_version("apache", "Apache/2.4.52 (Ubuntu)");
        assert_eq!(version, Some("2.4.52".to_string()));
    }

    #[test]
    fn test_vulnerability_scan() {
        let scanner = VulnerabilityScanner::new();
        let findings = scanner.scan_service(22, "openssh", Some("OpenSSH_9.3p1"));
        // The previous assertion (`!x || x`) could never fail.
        assert!(findings.iter().any(|f| f.cve.id == "CVE-2023-38408"));
        for finding in &findings {
            assert_eq!(finding.port, 22);
            assert_eq!(finding.service, "openssh");
            assert!(finding.version.is_some());
            assert_eq!(finding.confidence, 0.9);
        }
    }

    #[test]
    fn test_scan_without_banner_uses_product_lookup() {
        let scanner = VulnerabilityScanner::new();
        let findings = scanner.scan_service(3306, "mysql", None);
        assert_eq!(findings.len(), 1);
        assert!(findings[0].version.is_none());
        assert_eq!(findings[0].confidence, 0.5);
    }

    #[test]
    fn test_scan_unknown_service() {
        let scanner = VulnerabilityScanner::new();
        assert!(scanner
            .scan_service(80, "lighttpd", Some("lighttpd/1.4.59"))
            .is_empty());
    }

    #[test]
    fn test_report_generation() {
        let scanner = VulnerabilityScanner::new();
        let services = vec![
            (22, "openssh".to_string(), Some("OpenSSH_9.3p1".to_string())),
            (80, "apache".to_string(), Some("Apache/2.4.55".to_string())),
            (443, "nginx".to_string(), Some("nginx/1.22.0".to_string())),
        ];
        let report = scanner.generate_report("192.168.1.1", &services);
        assert_eq!(report.target, "192.168.1.1");
        assert_eq!(report.findings.len(), 3);
        assert_eq!(report.critical_count, 1);
        assert_eq!(report.high_count, 2);
        // Findings are sorted most-severe first.
        assert_eq!(report.findings[0].cve.severity, CVESeverity::Critical);
        assert!((report.risk_score - 25.1).abs() < 1e-3);
    }

    #[test]
    fn test_report_summary() {
        let mut report = VulnerabilityReport::new("test-target");
        let cve = CVE::new("CVE-2023-0001", CVESeverity::Critical, 9.8, "Test CVE");
        report.add_finding(VulnerabilityFinding {
            cve,
            port: 22,
            service: "ssh".to_string(),
            version: Some("1.0".to_string()),
            confidence: 0.9,
            exploitability: "Network".to_string(),
            remediation: "Update".to_string(),
        });
        assert_eq!(report.critical_count, 1);
        assert!(report.summary().contains("Critical: 1"));
    }
}