use std::collections::HashMap;
use std::path::{Path, PathBuf};
use std::time::Instant;
use serde::{Deserialize, Serialize};
use super::languages::{
GoSecurityScanner, JavaSecurityScanner, NodeSecurityScanner, PythonSecurityScanner,
RustSecurityScanner,
};
use super::vulnerability::{Severity, Vulnerability};
/// A per-language vulnerability scanner driven by `SecurityScanner`.
pub trait LanguageSecurityScanner {
    /// Whether this scanner can run in the current environment.
    /// `SecurityScanner::scan` records this in `ScanResult::scanner_status`
    /// and skips the scanner when it is `false`.
    fn is_available(&self) -> bool;
    /// Scanner name; used as the key in `ScanResult::scanner_status` and as the
    /// prefix of error strings collected by `SecurityScanner::scan`.
    fn name(&self) -> &str;
    /// Language label. `SecurityScanner::fix` routes findings to the scanner whose
    /// label equals `Vulnerability::language` — presumably scanners tag their own
    /// findings with this same string; TODO confirm.
    fn language(&self) -> &str;
    /// Returns whether this scanner applies to `path`; `SecurityScanner::scan`
    /// skips the scanner when this is `false`.
    fn detect(&self, path: &Path) -> bool;
    /// Scans `path` and returns the findings. Errors are plain strings;
    /// `SecurityScanner::scan` records them in `ScanResult::errors` rather than aborting.
    fn scan(&self, path: &Path, options: &ScanOptions) -> Result<Vec<Vulnerability>, String>;
    /// Attempts to remediate `vulnerabilities` under `path`. `SecurityScanner::fix`
    /// passes only findings whose language matches `language()`.
    fn fix(&self, path: &Path, vulnerabilities: &[Vulnerability]) -> Result<FixResult, String>;
}
/// Options controlling a `SecurityScanner::scan` run; also passed through to each
/// language scanner's `scan`.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct ScanOptions {
    /// Project path handed to every scanner's `detect` and `scan`.
    pub path: PathBuf,
    /// Optional minimum severity as text; parsed by `get_severity_threshold`
    /// (`None` maps to `Severity::Unknown`).
    pub severity_threshold: Option<String>,
    /// Not read in this file; presumably asks scanners to include dev-only
    /// dependencies — TODO confirm against the language scanners.
    pub include_dev: bool,
    /// Not read in this file; presumably restricts scanning to these package
    /// names — TODO confirm against the language scanners.
    pub packages: Vec<String>,
    /// Advisory IDs to drop: `SecurityScanner::scan` skips any finding whose
    /// `advisory.id` is listed here.
    pub ignore: Vec<String>,
    /// Output format name. `ScanOptions::new` sets `"human"`; `Default` leaves it empty.
    pub format: String,
    /// Not read in this file; NOTE(review): presumably requests SBOM output — confirm.
    pub generate_sbom: bool,
    /// Not read in this file; NOTE(review): presumably a severity name that should
    /// make the run fail — confirm with the CLI layer.
    pub fail_on: Option<String>,
    /// Not read in this file; presumably enables extra diagnostic output.
    pub verbose: bool,
}
impl ScanOptions {
    /// Creates options rooted at `path` using the `"human"` output format.
    /// Every other field keeps its `Default` value.
    pub fn new(path: PathBuf) -> Self {
        let format = String::from("human");
        Self {
            format,
            path,
            ..Self::default()
        }
    }

    /// Parses `severity_threshold` with `Severity::from_str`.
    /// Falls back to `Severity::Unknown` when no threshold was supplied.
    pub fn get_severity_threshold(&self) -> Severity {
        match &self.severity_threshold {
            Some(level) => Severity::from_str(level),
            None => Severity::Unknown,
        }
    }
}
/// Outcome of a fix attempt. `SecurityScanner::fix` concatenates the per-language
/// results into one of these.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct FixResult {
    /// Entries a language scanner reported as fixed (format defined by the scanner).
    pub fixed: Vec<String>,
    /// Entries a language scanner reported as not fixed (format defined by the scanner).
    pub unfixed: Vec<String>,
    /// Commands reported by the scanners; NOTE(review): presumably shell commands
    /// the user can run to apply fixes — confirm.
    pub commands: Vec<String>,
    /// `true` if any language scanner requested manual review.
    pub needs_review: bool,
    /// Free-form messages, including per-language fix failures.
    pub messages: Vec<String>,
}
/// Aggregated output of `SecurityScanner::scan`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ScanResult {
    /// Findings from all scanners, minus those whose advisory ID is in `ScanOptions::ignore`.
    pub vulnerabilities: Vec<Vulnerability>,
    /// Finding count keyed by the severity's `to_string()` form.
    pub by_severity: HashMap<String, usize>,
    /// Finding count keyed by `Vulnerability::language`.
    pub by_language: HashMap<String, usize>,
    /// Wall-clock duration of the scan, in milliseconds.
    pub duration_ms: u64,
    /// Languages of scanners that were available and detected the project path.
    pub languages_scanned: Vec<String>,
    /// Scanner name -> availability, for every registered scanner.
    pub scanner_status: HashMap<String, bool>,
    /// Per-scanner failures, formatted as `"<scanner name>: <error>"`.
    pub errors: Vec<String>,
    /// NOTE(review): not assigned by `SecurityScanner::scan` in this file, so it
    /// stays 0 unless populated elsewhere — confirm whether that is intended.
    pub total_packages: usize,
    /// Number of distinct `name@version` pairs across all affected packages.
    pub vulnerable_packages: usize,
}
impl ScanResult {
pub fn new() -> Self {
Self {
vulnerabilities: Vec::new(),
by_severity: HashMap::new(),
by_language: HashMap::new(),
duration_ms: 0,
languages_scanned: Vec::new(),
scanner_status: HashMap::new(),
errors: Vec::new(),
total_packages: 0,
vulnerable_packages: 0,
}
}
pub fn filter_by_severity(&self, threshold: Severity) -> Vec<&Vulnerability> {
self.vulnerabilities
.iter()
.filter(|v| v.meets_severity_threshold(&threshold))
.collect()
}
pub fn critical_high_count(&self) -> usize {
self.vulnerabilities
.iter()
.filter(|v| matches!(v.severity(), Severity::Critical | Severity::High))
.count()
}
pub fn has_vulnerabilities_above(&self, threshold: Severity) -> bool {
self.vulnerabilities
.iter()
.any(|v| v.meets_severity_threshold(&threshold))
}
}
impl Default for ScanResult {
fn default() -> Self {
Self::new()
}
}
/// Runs every registered [`LanguageSecurityScanner`] and merges their results.
pub struct SecurityScanner {
    /// Registered scanners, in the order `scan` visits them.
    scanners: Vec<Box<dyn LanguageSecurityScanner>>,
}
impl Default for SecurityScanner {
fn default() -> Self {
Self::new()
}
}
impl SecurityScanner {
pub fn new() -> Self {
let scanners: Vec<Box<dyn LanguageSecurityScanner>> = vec![
Box::new(RustSecurityScanner::new()),
Box::new(NodeSecurityScanner::new()),
Box::new(PythonSecurityScanner::new()),
Box::new(GoSecurityScanner::new()),
Box::new(JavaSecurityScanner::new()),
];
Self { scanners }
}
pub fn available_scanners(&self) -> Vec<(&str, &str, bool)> {
self.scanners
.iter()
.map(|s| (s.name(), s.language(), s.is_available()))
.collect()
}
pub fn detect_languages(&self, path: &Path) -> Vec<String> {
self.scanners
.iter()
.filter(|s| s.detect(path))
.map(|s| s.language().to_string())
.collect()
}
pub fn scan(&self, options: &ScanOptions) -> Result<ScanResult, String> {
let start = Instant::now();
let mut result = ScanResult::new();
let path = &options.path;
for scanner in &self.scanners {
let is_available = scanner.is_available();
result
.scanner_status
.insert(scanner.name().to_string(), is_available);
if !is_available {
continue;
}
if !scanner.detect(path) {
continue;
}
result
.languages_scanned
.push(scanner.language().to_string());
match scanner.scan(path, options) {
Ok(vulns) => {
for vuln in vulns {
if options.ignore.contains(&vuln.advisory.id) {
continue;
}
*result
.by_severity
.entry(vuln.severity().to_string())
.or_insert(0) += 1;
*result.by_language.entry(vuln.language.clone()).or_insert(0) += 1;
result.vulnerabilities.push(vuln);
}
}
Err(e) => {
result.errors.push(format!("{}: {}", scanner.name(), e));
}
}
}
result.duration_ms = start.elapsed().as_millis() as u64;
result.vulnerable_packages = result
.vulnerabilities
.iter()
.flat_map(|v| &v.affected_packages)
.map(|p| format!("{}@{}", p.name, p.version))
.collect::<std::collections::HashSet<_>>()
.len();
Ok(result)
}
pub fn fix(&self, path: &Path, result: &ScanResult) -> Result<FixResult, String> {
let mut fix_result = FixResult::default();
let mut by_language: HashMap<String, Vec<&Vulnerability>> = HashMap::new();
for vuln in &result.vulnerabilities {
by_language
.entry(vuln.language.clone())
.or_default()
.push(vuln);
}
for (language, vulns) in by_language {
if let Some(scanner) = self.scanners.iter().find(|s| s.language() == language) {
let vuln_refs: Vec<Vulnerability> = vulns.into_iter().cloned().collect();
match scanner.fix(path, &vuln_refs) {
Ok(lang_result) => {
fix_result.fixed.extend(lang_result.fixed);
fix_result.unfixed.extend(lang_result.unfixed);
fix_result.commands.extend(lang_result.commands);
fix_result.messages.extend(lang_result.messages);
if lang_result.needs_review {
fix_result.needs_review = true;
}
}
Err(e) => {
fix_result
.messages
.push(format!("Failed to fix {} vulnerabilities: {}", language, e));
}
}
}
}
Ok(fix_result)
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// `Default` leaves `format` empty and collections empty.
    #[test]
    fn test_scan_options_default() {
        let options = ScanOptions::default();
        assert_eq!(options.format, "");
        assert!(!options.include_dev);
        assert!(options.ignore.is_empty());
    }

    /// `ScanOptions::new` stores the path and selects the "human" format.
    #[test]
    fn test_scan_options_new_sets_human_format() {
        let options = ScanOptions::new(PathBuf::from("project"));
        assert_eq!(options.path, PathBuf::from("project"));
        assert_eq!(options.format, "human");
        assert!(options.severity_threshold.is_none());
        // No threshold supplied -> falls back to `Severity::Unknown`.
        assert!(matches!(
            options.get_severity_threshold(),
            Severity::Unknown
        ));
    }

    /// A fresh `ScanResult` is empty and its query helpers agree.
    #[test]
    fn test_scan_result_default() {
        let result = ScanResult::new();
        assert!(result.vulnerabilities.is_empty());
        assert_eq!(result.duration_ms, 0);
        assert!(result.by_severity.is_empty());
        assert!(result.errors.is_empty());
        assert_eq!(result.vulnerable_packages, 0);
        assert_eq!(result.critical_high_count(), 0);
        assert!(result.filter_by_severity(Severity::Unknown).is_empty());
        assert!(!result.has_vulnerabilities_above(Severity::Unknown));
    }

    /// All built-in language scanners are registered.
    #[test]
    fn test_security_scanner_creation() {
        let scanner = SecurityScanner::new();
        let available = scanner.available_scanners();
        assert_eq!(available.len(), 5);
    }

    /// Fixing a result with no findings is a no-op and never invokes a scanner.
    #[test]
    fn test_fix_with_no_vulnerabilities_is_noop() {
        let scanner = SecurityScanner::new();
        let outcome = scanner
            .fix(Path::new("."), &ScanResult::new())
            .expect("fix on an empty result should not fail");
        assert!(outcome.fixed.is_empty());
        assert!(outcome.unfixed.is_empty());
        assert!(outcome.commands.is_empty());
        assert!(outcome.messages.is_empty());
        assert!(!outcome.needs_review);
    }
}