use std::collections::HashMap;
use std::path::PathBuf;
use serde::{Deserialize, Serialize};
use super::vulnerability::{Advisory, Severity};
#[derive(Debug, Default, Serialize, Deserialize)]
pub struct AdvisoryDatabase {
advisories: HashMap<String, Advisory>,
last_updated: Option<String>,
sources: Vec<String>,
}
impl AdvisoryDatabase {
    /// Creates an empty database.
    pub fn new() -> Self {
        Self::default()
    }

    /// Loads the database from the on-disk JSON cache (see `cache_path`).
    ///
    /// Returns an empty database when the cache file does not exist.
    ///
    /// # Errors
    /// Returns a message if the cache location cannot be determined, or if the
    /// file exists but cannot be read or parsed as JSON.
    pub fn load_from_cache() -> Result<Self, String> {
        let cache_path = Self::cache_path()?;
        // `Path::exists` also reports `false` when the metadata is inaccessible,
        // so such a cache is treated as absent rather than as an error.
        if !cache_path.exists() {
            return Ok(Self::new());
        }
        let content = std::fs::read_to_string(&cache_path)
            .map_err(|e| format!("Failed to read advisory cache: {}", e))?;
        serde_json::from_str(&content).map_err(|e| format!("Failed to parse advisory cache: {}", e))
    }

    /// Writes the database to the on-disk JSON cache, creating the cache
    /// directory if necessary.
    ///
    /// The JSON is first written to a process-specific temporary file next to
    /// the cache and then renamed over it. An interrupted or concurrent save
    /// therefore cannot leave a half-written cache for `load_from_cache` to
    /// fail on.
    ///
    /// # Errors
    /// Returns a message if the cache location cannot be determined, the
    /// directory cannot be created, serialization fails, or the file cannot be
    /// written or moved into place.
    pub fn save_to_cache(&self) -> Result<(), String> {
        let cache_path = Self::cache_path()?;
        if let Some(parent) = cache_path.parent() {
            std::fs::create_dir_all(parent)
                .map_err(|e| format!("Failed to create cache directory: {}", e))?;
        }
        let content = serde_json::to_string_pretty(self)
            .map_err(|e| format!("Failed to serialize advisory cache: {}", e))?;

        // e.g. `advisories.json` -> `advisories.json.<pid>.tmp`; including the pid
        // keeps concurrent processes from sharing a temp file.
        let tmp_path = cache_path.with_extension(format!("json.{}.tmp", std::process::id()));
        if let Err(e) = std::fs::write(&tmp_path, content) {
            // Best-effort cleanup; the write error is what gets reported.
            let _ = std::fs::remove_file(&tmp_path);
            return Err(format!("Failed to write advisory cache: {}", e));
        }
        std::fs::rename(&tmp_path, &cache_path).map_err(|e| {
            let _ = std::fs::remove_file(&tmp_path);
            format!("Failed to write advisory cache: {}", e)
        })
    }

    /// Location of the cache file: `security/advisories.json` inside the
    /// platform cache directory for the `io`/`linthis`/`linthis` project.
    ///
    /// # Errors
    /// Returns a message if the platform project directories are unavailable.
    fn cache_path() -> Result<PathBuf, String> {
        let dirs = directories::ProjectDirs::from("io", "linthis", "linthis")
            .ok_or_else(|| "Failed to get project directories".to_string())?;
        Ok(dirs.cache_dir().join("security").join("advisories.json"))
    }

    /// Inserts `advisory` under its `id`, replacing any advisory with the same id.
    pub fn add_advisory(&mut self, advisory: Advisory) {
        self.advisories.insert(advisory.id.clone(), advisory);
    }

    /// Looks up an advisory by its exact id.
    pub fn get_advisory(&self, id: &str) -> Option<&Advisory> {
        self.advisories.get(id)
    }

    /// Returns advisories whose id, title or description contains `keyword`,
    /// compared case-insensitively. Results are sorted by id.
    pub fn search(&self, keyword: &str) -> Vec<&Advisory> {
        let keyword_lower = keyword.to_lowercase();
        let hits = self
            .advisories
            .values()
            .filter(|a| {
                a.id.to_lowercase().contains(&keyword_lower)
                    || a.title.to_lowercase().contains(&keyword_lower)
                    || a.description.to_lowercase().contains(&keyword_lower)
            })
            .collect();
        Self::sorted_by_id(hits)
    }

    /// Returns all advisories with exactly the given severity, sorted by id.
    pub fn by_severity(&self, severity: Severity) -> Vec<&Advisory> {
        let hits = self
            .advisories
            .values()
            .filter(|a| a.severity == severity)
            .collect();
        Self::sorted_by_id(hits)
    }

    /// Returns the CVE id of every advisory that has one (per `Advisory::cve_id`),
    /// sorted lexicographically.
    pub fn get_all_cve_ids(&self) -> Vec<&str> {
        let mut ids: Vec<&str> = self
            .advisories
            .values()
            .filter_map(|a| a.cve_id())
            .collect();
        ids.sort_unstable();
        ids
    }

    /// Number of advisories in the database.
    pub fn len(&self) -> usize {
        self.advisories.len()
    }

    /// Whether the database holds no advisories.
    pub fn is_empty(&self) -> bool {
        self.advisories.is_empty()
    }

    /// Removes all advisories and forgets the last-updated marker.
    ///
    /// NOTE(review): `sources` is left untouched — confirm that is intended.
    pub fn clear(&mut self) {
        self.advisories.clear();
        self.last_updated = None;
    }

    /// Sorts advisory references by id so callers get a stable order. HashMap
    /// iteration order is randomized per process.
    fn sorted_by_id(mut advisories: Vec<&Advisory>) -> Vec<&Advisory> {
        advisories.sort_by(|a, b| a.id.cmp(&b.id));
        advisories
    }
}
#[allow(dead_code)]
#[derive(Debug, Default, Serialize, Deserialize)]
pub struct SuppressionList {
suppressions: HashMap<String, Suppression>,
}
/// One suppression entry: an identifier whose findings should be ignored,
/// with a justification and optional metadata.
// NOTE(review): dead_code is allowed without explanation — presumably some
// fields are not yet read outside this module; confirm.
#[allow(dead_code)]
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Suppression {
    /// Identifier being suppressed; `SuppressionList::add` uses it as the map key.
    pub id: String,
    /// Human-readable justification, returned by `SuppressionList::get_reason`.
    pub reason: String,
    /// Optional expiry date, expected as `YYYY-MM-DD`
    /// (it is parsed with `%Y-%m-%d` by `SuppressionList::is_suppressed`).
    pub expires: Option<String>,
    /// Who added the suppression; not read by code in this file.
    pub added_by: Option<String>,
    /// When the suppression was added; not read by code in this file.
    /// NOTE(review): format not established here — confirm.
    pub added_at: Option<String>,
}
// NOTE(review): dead_code is allowed presumably because parts of this API have
// no callers outside tests yet — confirm and narrow the allow once wired up.
#[allow(dead_code)]
impl SuppressionList {
    /// Creates an empty suppression list.
    pub fn new() -> Self {
        Self::default()
    }

    /// Loads suppressions from `path`.
    ///
    /// A missing file yields an empty list. A file whose extension is `toml`
    /// (in any ASCII letter case) is parsed as TOML. Anything else is parsed
    /// as JSON.
    ///
    /// # Errors
    /// Returns a message if the file exists but cannot be read or parsed.
    pub fn load_from_file(path: &std::path::Path) -> Result<Self, String> {
        if !path.exists() {
            return Ok(Self::new());
        }
        let content = std::fs::read_to_string(path)
            .map_err(|e| format!("Failed to read suppression file: {}", e))?;
        let is_toml = matches!(path.extension(), Some(ext) if ext.eq_ignore_ascii_case("toml"));
        if is_toml {
            toml::from_str(&content).map_err(|e| format!("Failed to parse suppression TOML: {}", e))
        } else {
            serde_json::from_str(&content)
                .map_err(|e| format!("Failed to parse suppression JSON: {}", e))
        }
    }

    /// Returns `true` if `id` has a suppression that is currently in effect.
    ///
    /// - A suppression without `expires` never lapses.
    /// - Otherwise `expires` is parsed as `YYYY-MM-DD`, ignoring surrounding
    ///   whitespace. The suppression stays active through that day, using the
    ///   local date.
    /// - An `expires` value that does not parse is treated as already expired.
    ///   This way a typo cannot silently turn a temporary suppression of a
    ///   vulnerability into a permanent one.
    pub fn is_suppressed(&self, id: &str) -> bool {
        match self.suppressions.get(id) {
            None => false,
            Some(suppression) => match suppression.expires.as_deref() {
                None => true,
                Some(expires) => {
                    match chrono::NaiveDate::parse_from_str(expires.trim(), "%Y-%m-%d") {
                        Ok(expiry) => chrono::Local::now().date_naive() <= expiry,
                        // Fail closed on malformed dates (see doc comment).
                        Err(_) => false,
                    }
                }
            },
        }
    }

    /// Returns the recorded reason for `id`'s suppression, whether or not it
    /// has expired.
    pub fn get_reason(&self, id: &str) -> Option<&str> {
        self.suppressions.get(id).map(|s| s.reason.as_str())
    }

    /// Adds `suppression` under its `id`, replacing any existing entry for that id.
    pub fn add(&mut self, suppression: Suppression) {
        self.suppressions.insert(suppression.id.clone(), suppression);
    }

    /// Removes and returns the suppression stored under `id`, if any.
    pub fn remove(&mut self, id: &str) -> Option<Suppression> {
        self.suppressions.remove(id)
    }

    /// Iterates over all suppressions, including expired ones, in unspecified order.
    pub fn all(&self) -> impl Iterator<Item = &Suppression> {
        self.suppressions.values()
    }

    /// Number of stored suppressions, including expired ones.
    pub fn len(&self) -> usize {
        self.suppressions.len()
    }

    /// Whether no suppressions are stored.
    pub fn is_empty(&self) -> bool {
        self.suppressions.is_empty()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a High-severity advisory with the given id and fixed text fields.
    fn sample_advisory(id: &str) -> Advisory {
        Advisory {
            id: id.to_string(),
            aliases: vec![],
            title: "Test vulnerability".to_string(),
            description: "A test vulnerability description".to_string(),
            severity: Severity::High,
            cvss_score: Some(7.5),
            cvss_vector: None,
            url: None,
            published: None,
            updated: None,
            cwe_ids: vec![],
            references: vec![],
        }
    }

    /// Builds a suppression for `id` with an optional expiry date.
    fn sample_suppression(id: &str, expires: Option<&str>) -> Suppression {
        Suppression {
            id: id.to_string(),
            reason: "False positive".to_string(),
            expires: expires.map(str::to_string),
            added_by: None,
            added_at: None,
        }
    }

    #[test]
    fn test_advisory_database() {
        let mut db = AdvisoryDatabase::new();
        assert!(db.is_empty());
        db.add_advisory(sample_advisory("CVE-2024-1234"));
        assert_eq!(db.len(), 1);
        let found = db.get_advisory("CVE-2024-1234");
        assert!(found.is_some());
        assert_eq!(found.unwrap().title, "Test vulnerability");
    }

    #[test]
    fn test_advisory_search_and_severity() {
        let mut db = AdvisoryDatabase::new();
        db.add_advisory(sample_advisory("CVE-2024-1234"));
        // Matching is case-insensitive across id, title and description.
        assert_eq!(db.search("cve-2024").len(), 1);
        assert_eq!(db.search("VULNERABILITY DESCRIPTION").len(), 1);
        assert!(db.search("no-such-text").is_empty());
        assert_eq!(db.by_severity(Severity::High).len(), 1);
    }

    #[test]
    fn test_suppression_list() {
        let mut list = SuppressionList::new();
        assert!(list.is_empty());
        list.add(sample_suppression("CVE-2024-1234", None));
        assert!(list.is_suppressed("CVE-2024-1234"));
        assert!(!list.is_suppressed("CVE-2024-9999"));
        assert_eq!(list.get_reason("CVE-2024-1234"), Some("False positive"));
        assert!(list.remove("CVE-2024-1234").is_some());
        assert!(!list.is_suppressed("CVE-2024-1234"));
    }

    #[test]
    fn test_suppression_expiry() {
        let mut list = SuppressionList::new();
        list.add(sample_suppression("CVE-PAST", Some("2000-01-01")));
        list.add(sample_suppression("CVE-FUTURE", Some("9999-12-31")));
        assert!(!list.is_suppressed("CVE-PAST"));
        assert!(list.is_suppressed("CVE-FUTURE"));
    }
}