use crate::diff::result::VulnerabilityDetail;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum VulnGroupStatus {
Introduced,
Resolved,
Persistent,
}
impl std::fmt::Display for VulnGroupStatus {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Self::Introduced => write!(f, "Introduced"),
Self::Resolved => write!(f, "Resolved"),
Self::Persistent => write!(f, "Persistent"),
}
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct VulnerabilityGroup {
pub component_id: String,
pub component_name: String,
pub component_version: Option<String>,
pub vulnerabilities: Vec<VulnerabilityDetail>,
pub max_severity: String,
pub max_cvss: Option<f32>,
pub severity_counts: HashMap<String, usize>,
pub status: VulnGroupStatus,
pub has_kev: bool,
pub has_ransomware_kev: bool,
}
impl VulnerabilityGroup {
#[must_use]
pub fn new(component_id: String, component_name: String, status: VulnGroupStatus) -> Self {
Self {
component_id,
component_name,
component_version: None,
vulnerabilities: Vec::new(),
max_severity: "Unknown".to_string(),
max_cvss: None,
severity_counts: HashMap::new(),
status,
has_kev: false,
has_ransomware_kev: false,
}
}
pub fn add_vulnerability(&mut self, vuln: VulnerabilityDetail) {
*self
.severity_counts
.entry(vuln.severity.clone())
.or_insert(0) += 1;
let vuln_priority = severity_priority(&vuln.severity);
let current_priority = severity_priority(&self.max_severity);
if vuln_priority < current_priority {
self.max_severity.clone_from(&vuln.severity);
}
if let Some(score) = vuln.cvss_score {
self.max_cvss = Some(self.max_cvss.map_or(score, |c| c.max(score)));
}
if self.component_version.is_none() {
self.component_version.clone_from(&vuln.version);
}
self.vulnerabilities.push(vuln);
}
#[must_use]
pub fn vuln_count(&self) -> usize {
self.vulnerabilities.len()
}
#[must_use]
pub fn has_critical(&self) -> bool {
self.severity_counts.get("Critical").copied().unwrap_or(0) > 0
}
#[must_use]
pub fn has_high(&self) -> bool {
self.severity_counts.get("High").copied().unwrap_or(0) > 0
}
#[must_use]
pub fn summary_line(&self) -> String {
let version_str = self
.component_version
.as_ref()
.map(|v| format!("@{v}"))
.unwrap_or_default();
let severity_badges: Vec<String> = ["Critical", "High", "Medium", "Low"]
.iter()
.filter_map(|sev| {
self.severity_counts.get(*sev).and_then(|&count| {
if count > 0 {
Some(format!("{}:{}", &sev[..1], count))
} else {
None
}
})
})
.collect();
format!(
"{}{}: {} CVEs [{}]",
self.component_name,
version_str,
self.vuln_count(),
severity_badges.join(" ")
)
}
}
fn severity_priority(severity: &str) -> u8 {
match severity.to_lowercase().as_str() {
"critical" => 0,
"high" => 1,
"medium" => 2,
"low" => 3,
"info" => 4,
"none" => 5,
_ => 6,
}
}
#[must_use]
pub fn group_vulnerabilities(
vulns: &[VulnerabilityDetail],
status: VulnGroupStatus,
) -> Vec<VulnerabilityGroup> {
let mut groups: HashMap<String, VulnerabilityGroup> = HashMap::new();
for vuln in vulns {
let group = groups.entry(vuln.component_id.clone()).or_insert_with(|| {
VulnerabilityGroup::new(
vuln.component_id.clone(),
vuln.component_name.clone(),
status,
)
});
group.add_vulnerability(vuln.clone());
}
let mut result: Vec<_> = groups.into_values().collect();
result.sort_by(|a, b| {
let sev_cmp = severity_priority(&a.max_severity).cmp(&severity_priority(&b.max_severity));
if sev_cmp == std::cmp::Ordering::Equal {
b.vuln_count().cmp(&a.vuln_count())
} else {
sev_cmp
}
});
result
}
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct VulnerabilityGroupedView {
pub introduced_groups: Vec<VulnerabilityGroup>,
pub resolved_groups: Vec<VulnerabilityGroup>,
pub persistent_groups: Vec<VulnerabilityGroup>,
}
impl VulnerabilityGroupedView {
#[must_use]
pub fn from_changes(
introduced: &[VulnerabilityDetail],
resolved: &[VulnerabilityDetail],
persistent: &[VulnerabilityDetail],
) -> Self {
Self {
introduced_groups: group_vulnerabilities(introduced, VulnGroupStatus::Introduced),
resolved_groups: group_vulnerabilities(resolved, VulnGroupStatus::Resolved),
persistent_groups: group_vulnerabilities(persistent, VulnGroupStatus::Persistent),
}
}
#[must_use]
pub fn total_groups(&self) -> usize {
self.introduced_groups.len() + self.resolved_groups.len() + self.persistent_groups.len()
}
pub fn total_vulns(&self) -> usize {
self.introduced_groups
.iter()
.map(VulnerabilityGroup::vuln_count)
.sum::<usize>()
+ self
.resolved_groups
.iter()
.map(VulnerabilityGroup::vuln_count)
.sum::<usize>()
+ self
.persistent_groups
.iter()
.map(VulnerabilityGroup::vuln_count)
.sum::<usize>()
}
#[must_use]
pub fn has_any_kev(&self) -> bool {
self.introduced_groups.iter().any(|g| g.has_kev)
|| self.resolved_groups.iter().any(|g| g.has_kev)
|| self.persistent_groups.iter().any(|g| g.has_kev)
}
}
#[cfg(test)]
mod tests {
use super::*;
fn make_vuln(id: &str, component_id: &str, severity: &str) -> VulnerabilityDetail {
VulnerabilityDetail {
id: id.to_string(),
source: "OSV".to_string(),
severity: severity.to_string(),
cvss_score: None,
component_id: component_id.to_string(),
component_canonical_id: None,
component_ref: None,
component_name: format!("{}-pkg", component_id),
version: Some("1.0.0".to_string()),
description: None,
remediation: None,
is_kev: false,
cwes: Vec::new(),
component_depth: None,
published_date: None,
kev_due_date: None,
days_since_published: None,
days_until_due: None,
vex_state: None,
vex_justification: None,
vex_impact_statement: None,
}
}
#[test]
fn test_group_vulnerabilities() {
let vulns = vec![
make_vuln("CVE-2024-0001", "lodash", "Critical"),
make_vuln("CVE-2024-0002", "lodash", "High"),
make_vuln("CVE-2024-0003", "lodash", "High"),
make_vuln("CVE-2024-0004", "express", "Medium"),
];
let groups = group_vulnerabilities(&vulns, VulnGroupStatus::Introduced);
assert_eq!(groups.len(), 2);
assert_eq!(groups[0].component_id, "lodash");
assert_eq!(groups[0].vuln_count(), 3);
assert_eq!(groups[0].max_severity, "Critical");
assert_eq!(groups[0].severity_counts.get("Critical"), Some(&1));
assert_eq!(groups[0].severity_counts.get("High"), Some(&2));
assert_eq!(groups[1].component_id, "express");
assert_eq!(groups[1].vuln_count(), 1);
}
#[test]
fn test_grouped_view() {
let introduced = vec![
make_vuln("CVE-2024-0001", "lodash", "High"),
make_vuln("CVE-2024-0002", "lodash", "Medium"),
];
let resolved = vec![make_vuln("CVE-2024-0003", "old-dep", "Critical")];
let persistent = vec![];
let view = VulnerabilityGroupedView::from_changes(&introduced, &resolved, &persistent);
assert_eq!(view.total_groups(), 2);
assert_eq!(view.total_vulns(), 3);
assert_eq!(view.introduced_groups.len(), 1);
assert_eq!(view.resolved_groups.len(), 1);
}
#[test]
fn test_summary_line() {
let mut group = VulnerabilityGroup::new(
"lodash".to_string(),
"lodash".to_string(),
VulnGroupStatus::Introduced,
);
group.add_vulnerability(make_vuln("CVE-1", "lodash", "Critical"));
group.add_vulnerability(make_vuln("CVE-2", "lodash", "High"));
group.add_vulnerability(make_vuln("CVE-3", "lodash", "High"));
let summary = group.summary_line();
assert!(summary.contains("lodash"));
assert!(summary.contains("3 CVEs"));
assert!(summary.contains("C:1"));
assert!(summary.contains("H:2"));
}
}