use crate::ports::outbound::VulnerabilityRepository;
use crate::sbom_generation::domain::{Package, PackageVulnerabilities};
use crate::shared::Result;
use indicatif::{ProgressBar, ProgressStyle};
use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
use std::sync::Arc;
use std::thread;
use std::time::Duration;
/// Application use case that looks up known vulnerabilities for a set of packages.
///
/// The lookup itself is delegated to a [`VulnerabilityRepository`]; this type wraps it
/// with a terminal progress bar and provides a small summary helper.
pub struct CheckVulnerabilitiesUseCase<R: VulnerabilityRepository> {
    vulnerability_repository: R,
}

/// Stops and joins the progress-bar thread when dropped.
///
/// Tying shutdown to `Drop` makes it happen on every exit path of
/// `check_with_progress`: normal return, an early error return, a panic, or the future
/// being dropped before it completes. Previously an error from the repository skipped
/// the shutdown code and left the thread redrawing the bar forever.
struct ProgressThreadGuard {
    done: Arc<AtomicBool>,
    handle: Option<thread::JoinHandle<()>>,
}

impl Drop for ProgressThreadGuard {
    fn drop(&mut self) {
        self.done.store(true, Ordering::Relaxed);
        if let Some(handle) = self.handle.take() {
            // The worker re-checks `done` every 50ms, so this join blocks only briefly.
            // The worker only draws the progress bar, so a panic there is ignored.
            let _ = handle.join();
        }
    }
}

/// Spawns a thread that renders a progress bar from `current` / `total` until `done` is set.
///
/// While `total` is still 0 the bar only animates its spinner; once it is non-zero the
/// bar's length and position track the counters. The bar is cleared when the thread exits.
fn spawn_progress_bar(
    current: Arc<AtomicUsize>,
    total: Arc<AtomicUsize>,
    done: Arc<AtomicBool>,
) -> thread::JoinHandle<()> {
    thread::spawn(move || {
        let pb = ProgressBar::new(0);
        pb.set_style(
            ProgressStyle::default_bar()
                .template(" {spinner:.green} [{bar:40.cyan/blue}] {pos}/{len} - {msg}")
                .expect("Failed to set progress bar template")
                .progress_chars("=>-"),
        );
        pb.set_message("Fetching vulnerability details...");
        // Relaxed ordering suffices: the counters only drive the display, and the
        // guard's `join` is what synchronizes shutdown.
        while !done.load(Ordering::Relaxed) {
            let pos = current.load(Ordering::Relaxed);
            let len = total.load(Ordering::Relaxed);
            if len > 0 {
                pb.set_length(len as u64);
                pb.set_position(pos as u64);
            } else {
                // Total not reported yet: just keep the spinner moving.
                pb.tick();
            }
            thread::sleep(Duration::from_millis(50));
        }
        pb.finish_and_clear();
    })
}

impl<R: VulnerabilityRepository> CheckVulnerabilitiesUseCase<R> {
    /// Creates a use case backed by `vulnerability_repository`.
    pub fn new(vulnerability_repository: R) -> Self {
        Self {
            vulnerability_repository,
        }
    }

    /// Fetches vulnerabilities for `packages`, showing a terminal progress bar meanwhile.
    ///
    /// The repository reports progress through a `(current, total)` callback. A background
    /// thread polls those values every 50ms and redraws the bar. The thread is stopped and
    /// the bar cleared before this returns, whether the fetch succeeds or fails.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by the repository's
    /// `fetch_vulnerabilities_with_progress`.
    pub async fn check_with_progress(
        &self,
        packages: Vec<Package>,
    ) -> Result<Vec<PackageVulnerabilities>> {
        let current = Arc::new(AtomicUsize::new(0));
        let total = Arc::new(AtomicUsize::new(0));
        let done = Arc::new(AtomicBool::new(false));

        // Created before the fetch so that any exit path (including `?`, panic, or
        // cancellation of this future) stops the rendering thread.
        let progress_guard = ProgressThreadGuard {
            done: Arc::clone(&done),
            handle: Some(spawn_progress_bar(
                Arc::clone(&current),
                Arc::clone(&total),
                done,
            )),
        };

        let progress_callback: Box<dyn Fn(usize, usize) + Send> =
            Box::new(move |pos: usize, len: usize| {
                current.store(pos, Ordering::Relaxed);
                total.store(len, Ordering::Relaxed);
            });

        let result = self
            .vulnerability_repository
            .fetch_vulnerabilities_with_progress(packages, progress_callback)
            .await;

        // Stop and join the progress thread before handing the result (Ok or Err) back.
        drop(progress_guard);
        result
    }

    /// Summarizes fetched results as `(total_vulnerabilities, affected_packages)`.
    ///
    /// `total_vulnerabilities` sums the vulnerability counts of all entries.
    /// `affected_packages` is the number of entries in `vulnerabilities`; an entry whose
    /// vulnerability list is empty is still counted.
    /// NOTE(review): confirm the repository only returns entries that have at least one
    /// vulnerability, otherwise `affected_packages` over-counts.
    pub fn summarize(vulnerabilities: &[PackageVulnerabilities]) -> (usize, usize) {
        let total_vulns: usize = vulnerabilities
            .iter()
            .map(|entry| entry.vulnerabilities().len())
            .sum();
        (total_vulns, vulnerabilities.len())
    }
}
#[cfg(test)]
mod tests {
    //! Unit tests for `CheckVulnerabilitiesUseCase`, driven by an in-memory repository stub.

    use super::*;
    use crate::ports::outbound::ProgressCallback;
    use crate::sbom_generation::domain::vulnerability::{CvssScore, Severity, Vulnerability};
    use async_trait::async_trait;

    /// Repository stub that ignores its inputs and always answers with the same canned data.
    struct MockVulnerabilityRepository {
        vulnerabilities: Vec<PackageVulnerabilities>,
    }

    impl MockVulnerabilityRepository {
        /// Builds a stub that answers every query with `vulnerabilities`.
        fn returning(vulnerabilities: Vec<PackageVulnerabilities>) -> Self {
            Self { vulnerabilities }
        }

        /// The canned response shared by both trait methods.
        fn canned(&self) -> Result<Vec<PackageVulnerabilities>> {
            Ok(self.vulnerabilities.clone())
        }
    }

    #[async_trait]
    impl VulnerabilityRepository for MockVulnerabilityRepository {
        async fn fetch_vulnerabilities(
            &self,
            _packages: Vec<Package>,
        ) -> Result<Vec<PackageVulnerabilities>> {
            self.canned()
        }

        async fn fetch_vulnerabilities_with_progress(
            &self,
            _packages: Vec<Package>,
            _progress_callback: ProgressCallback<'static>,
        ) -> Result<Vec<PackageVulnerabilities>> {
            self.canned()
        }
    }

    /// Shorthand for the associated `summarize` fn; the repository type is irrelevant to it.
    fn summarize(vulnerabilities: &[PackageVulnerabilities]) -> (usize, usize) {
        CheckVulnerabilitiesUseCase::<MockVulnerabilityRepository>::summarize(vulnerabilities)
    }

    /// Runs `check_with_progress` against a stub that returns `canned`, unwrapping the result.
    async fn check_against(
        canned: Vec<PackageVulnerabilities>,
        packages: Vec<Package>,
    ) -> Vec<PackageVulnerabilities> {
        CheckVulnerabilitiesUseCase::new(MockVulnerabilityRepository::returning(canned))
            .check_with_progress(packages)
            .await
            .unwrap()
    }

    fn make_package(name: &str, version: &str) -> Package {
        Package::new(name.to_string(), version.to_string()).unwrap()
    }

    fn make_vulnerability(id: &str, severity: Severity, cvss: Option<f32>) -> Vulnerability {
        let description = format!("Test vulnerability {}", id);
        Vulnerability::new(
            id.to_string(),
            cvss.map(|score| CvssScore::new(score).unwrap()),
            severity,
            None,
            Some(description),
        )
        .unwrap()
    }

    fn make_pkg_vulns(
        name: &str,
        version: &str,
        vulns: Vec<Vulnerability>,
    ) -> PackageVulnerabilities {
        PackageVulnerabilities::new(name.to_string(), version.to_string(), vulns)
    }

    /// Two affected packages with one advisory each.
    fn two_affected_packages() -> Vec<PackageVulnerabilities> {
        vec![
            make_pkg_vulns(
                "requests",
                "2.31.0",
                vec![make_vulnerability("CVE-2024-0001", Severity::High, Some(7.5))],
            ),
            make_pkg_vulns(
                "urllib3",
                "1.26.0",
                vec![make_vulnerability("CVE-2024-0002", Severity::Critical, Some(9.8))],
            ),
        ]
    }

    #[test]
    fn test_summarize_empty() {
        assert_eq!(summarize(&[]), (0, 0));
    }

    #[test]
    fn test_summarize_single_package_single_vuln() {
        let affected = make_pkg_vulns(
            "requests",
            "2.31.0",
            vec![make_vulnerability("CVE-2024-0001", Severity::High, Some(7.5))],
        );
        assert_eq!(summarize(&[affected]), (1, 1));
    }

    #[test]
    fn test_summarize_single_package_multiple_vulns() {
        let vulns = vec![
            make_vulnerability("CVE-2024-0001", Severity::High, Some(7.5)),
            make_vulnerability("CVE-2024-0002", Severity::Critical, Some(9.8)),
            make_vulnerability("CVE-2024-0003", Severity::Low, Some(2.0)),
        ];
        let affected = make_pkg_vulns("requests", "2.31.0", vulns);
        assert_eq!(summarize(&[affected]), (3, 1));
    }

    #[test]
    fn test_summarize_multiple_packages() {
        assert_eq!(summarize(&two_affected_packages()), (2, 2));
    }

    #[tokio::test]
    async fn test_check_with_progress_no_vulnerabilities() {
        let found = check_against(vec![], vec![make_package("requests", "2.31.0")]).await;
        assert!(found.is_empty());
    }

    #[tokio::test]
    async fn test_check_with_progress_with_vulnerabilities() {
        let canned = vec![make_pkg_vulns(
            "requests",
            "2.31.0",
            vec![make_vulnerability("CVE-2024-0001", Severity::Critical, Some(9.8))],
        )];
        let found = check_against(canned, vec![make_package("requests", "2.31.0")]).await;
        assert_eq!(found.len(), 1);
        let first = &found[0];
        assert_eq!(first.package_name(), "requests");
        assert_eq!(first.vulnerabilities().len(), 1);
    }

    #[tokio::test]
    async fn test_check_with_progress_multiple_packages() {
        let packages = vec![
            make_package("requests", "2.31.0"),
            make_package("urllib3", "1.26.0"),
        ];
        let found = check_against(two_affected_packages(), packages).await;
        assert_eq!(found.len(), 2);
    }
}