use super::types::{StorageError, StorageResult, UploadedFile};
use async_trait::async_trait;
use chrono::Utc;
use serde::{Deserialize, Serialize};
use std::fmt;
use tracing::{error, warn};
/// JSON sidecar record written next to every quarantined file.
///
/// Serialized with `serde_json::to_string_pretty` into `<quarantine name>.json`;
/// the field declaration order below is the key order in that JSON.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct QuarantineMetadata {
    /// When the file was quarantined, as an RFC 3339 timestamp (UTC).
    quarantined_at: String,
    /// Threat name reported by the scanner for this file.
    threat_name: String,
    /// The upload's `filename`, kept only as data (it is never used in a path).
    original_filename: String,
    /// The upload's `content_type`.
    original_mime_type: String,
    /// Length of the quarantined data in bytes.
    file_size: usize,
}
/// Verdict produced by a [`VirusScanner`] for one file.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ScanResult {
    /// No threat was found.
    Clean,
    /// A threat was found.
    Infected {
        /// Threat name as reported by the scanner.
        threat: String,
    },
    /// The scan completed without a usable verdict (for example, the scanner's
    /// reply could not be parsed).
    Error {
        /// Human-readable description of what went wrong.
        message: String,
    },
}

impl fmt::Display for ScanResult {
    /// Renders `Clean`, `Infected: <threat>` or `Scan error: <message>`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let (label, detail) = match self {
            Self::Clean => return f.write_str("Clean"),
            Self::Infected { threat } => ("Infected: ", threat),
            Self::Error { message } => ("Scan error: ", message),
        };
        f.write_str(label)?;
        f.write_str(detail)
    }
}
/// A virus-scanning backend that can check uploaded files.
///
/// The `Send + Sync` supertraits let implementations be shared and moved
/// between threads. In test builds `mockall::automock` also generates a
/// `MockVirusScanner` for this trait.
#[cfg_attr(test, mockall::automock)]
#[async_trait]
pub trait VirusScanner: Send + Sync {
    /// Scans `file` and returns the verdict.
    ///
    /// Implementations in this module return `Err` when the scan cannot be
    /// performed at all (e.g. the backend is unreachable or compiled out), and
    /// `Ok(ScanResult::Error { .. })` when a scan ran but its result was unusable.
    async fn scan(&self, file: &UploadedFile) -> StorageResult<ScanResult>;
    /// Static, human-readable name identifying this scanner.
    fn name(&self) -> &'static str;
    /// Reports whether the scanning backend can currently be used.
    async fn is_available(&self) -> bool;
}
/// Scanner stand-in that performs no scanning and reports every file as clean.
///
/// Not a security control: see [`NoOpScanner::is_development_only`].
#[derive(Debug, Clone, Default)]
pub struct NoOpScanner;

impl NoOpScanner {
    /// Builds the (stateless) no-op scanner.
    #[must_use]
    pub const fn new() -> Self {
        NoOpScanner
    }

    /// Marker that this scanner is meant for development only; always `true`.
    #[must_use]
    pub const fn is_development_only(&self) -> bool {
        true
    }
}
#[async_trait]
impl VirusScanner for NoOpScanner {
    /// Ignores the file and always answers [`ScanResult::Clean`].
    async fn scan(&self, _file: &UploadedFile) -> StorageResult<ScanResult> {
        let verdict = ScanResult::Clean;
        Ok(verdict)
    }

    /// Fixed display name.
    fn name(&self) -> &'static str {
        "NoOp Scanner"
    }

    /// There is no backend to reach, so this is always available.
    async fn is_available(&self) -> bool {
        true
    }
}
/// How [`ClamAvScanner`] connects to the clamd daemon.
#[cfg(feature = "clamav")]
#[derive(Debug, Clone)]
pub enum ClamAvConnection {
    /// Connect over TCP to `host:port`.
    Tcp {
        /// Hostname or IP address of the clamd server.
        host: String,
        /// TCP port clamd listens on.
        port: u16,
    },
    /// Connect through a Unix domain socket. This variant only exists on Unix
    /// targets, so code matching on it must gate those arms with `#[cfg(unix)]`.
    #[cfg(unix)]
    Socket {
        /// Filesystem path of the clamd socket; must be valid UTF-8 to be used.
        path: std::path::PathBuf,
    },
}
/// Scanner backed by a ClamAV `clamd` daemon.
#[cfg(feature = "clamav")]
#[derive(Debug, Clone)]
pub struct ClamAvScanner {
    /// Where to reach clamd.
    connection: ClamAvConnection,
}

#[cfg(feature = "clamav")]
impl ClamAvScanner {
    /// Creates a scanner that talks to clamd via `connection`.
    #[must_use]
    pub const fn new(connection: ClamAvConnection) -> Self {
        Self { connection }
    }

    /// Creates a scanner that connects to clamd over TCP at `localhost:3310`.
    #[must_use]
    pub fn default_tcp() -> Self {
        let connection = ClamAvConnection::Tcp {
            host: String::from("localhost"),
            port: 3310,
        };
        Self::new(connection)
    }

    /// Creates a scanner that connects through the socket at
    /// `/var/run/clamav/clamd.sock` (Unix only).
    #[must_use]
    #[cfg(unix)]
    pub fn default_socket() -> Self {
        let path = std::path::PathBuf::from("/var/run/clamav/clamd.sock");
        Self::new(ClamAvConnection::Socket { path })
    }
}
/// Formats a `host:port` address for clamd.
///
/// Bare IPv6 literals are bracketed (`::1` becomes `[::1]:3310`) so the port
/// separator stays unambiguous.
#[cfg(feature = "clamav")]
fn tcp_host_address(host: &str, port: u16) -> String {
    if host.contains(':') && !host.starts_with('[') {
        format!("[{host}]:{port}")
    } else {
        format!("{host}:{port}")
    }
}

/// Decodes a raw clamd reply, dropping any NUL terminator and surrounding whitespace.
#[cfg(feature = "clamav")]
fn clamd_reply_text(response: &[u8]) -> String {
    String::from_utf8_lossy(response)
        .trim_matches(|c: char| c == '\0' || c.is_whitespace())
        .to_string()
}

/// Extracts the signature name from a detection reply like `stream: Eicar-Signature FOUND`.
///
/// Returns the whole (already trimmed) reply if it has any other shape.
#[cfg(feature = "clamav")]
fn clamd_threat_name(reply: &str) -> String {
    reply
        .strip_suffix(" FOUND")
        .map_or(reply, |rest| {
            rest.split_once(": ").map_or(rest, |(_, name)| name)
        })
        .to_string()
}

#[cfg(feature = "clamav")]
#[async_trait]
impl VirusScanner for ClamAvScanner {
    /// Streams `file.data` to clamd and classifies the reply.
    ///
    /// # Errors
    ///
    /// Returns [`StorageError::Other`] in two cases:
    /// - clamd cannot be reached or the transfer fails;
    /// - (Unix only) the socket path is not valid UTF-8.
    ///
    /// Replies that clamd marks as `ERROR`, or that cannot be parsed, are reported
    /// as `Ok(ScanResult::Error { .. })`.
    async fn scan(&self, file: &UploadedFile) -> StorageResult<ScanResult> {
        #[cfg(unix)]
        use clamav_client::tokio::Socket;
        use clamav_client::tokio::{scan_buffer, Tcp};

        let data = &file.data;
        // `Socket` only exists on Unix, so its arm is the only one that needs a cfg;
        // an arm naming it under `cfg(not(unix))` would fail to compile there.
        let response = match &self.connection {
            ClamAvConnection::Tcp { host, port } => {
                let host_address = tcp_host_address(host, *port);
                let clamd = Tcp {
                    host_address: &host_address,
                };
                scan_buffer(data, clamd, None)
                    .await
                    .map_err(|e| StorageError::Other(format!("ClamAV scan failed: {}", e)))?
            }
            #[cfg(unix)]
            ClamAvConnection::Socket { path } => {
                let path_str = path
                    .to_str()
                    .ok_or_else(|| StorageError::Other("Invalid socket path".to_string()))?;
                let clamd = Socket {
                    socket_path: path_str,
                };
                scan_buffer(data, clamd, None)
                    .await
                    .map_err(|e| StorageError::Other(format!("ClamAV scan failed: {}", e)))?
            }
        };

        let reply = clamd_reply_text(&response);
        // clamd signals failures (e.g. an oversized stream) with a reply ending in
        // "ERROR"; those are scan errors, not detections.
        if reply.ends_with("ERROR") {
            return Ok(ScanResult::Error {
                message: format!("ClamAV reported an error: {reply}"),
            });
        }
        match clamav_client::clean(&response) {
            Ok(true) => Ok(ScanResult::Clean),
            Ok(false) => Ok(ScanResult::Infected {
                threat: clamd_threat_name(&reply),
            }),
            Err(e) => Ok(ScanResult::Error {
                message: format!("Failed to parse scan result: {}", e),
            }),
        }
    }

    /// Fixed display name.
    fn name(&self) -> &'static str {
        "ClamAV Scanner"
    }

    /// Pings clamd and reports whether it answered with the expected `PONG` reply.
    ///
    /// Any connection failure, or an unusable socket path, yields `false`.
    async fn is_available(&self) -> bool {
        #[cfg(unix)]
        use clamav_client::tokio::Socket;
        use clamav_client::tokio::{ping, Tcp};
        use clamav_client::PONG;

        match &self.connection {
            ClamAvConnection::Tcp { host, port } => {
                let host_address = tcp_host_address(host, *port);
                let clamd = Tcp {
                    host_address: &host_address,
                };
                matches!(ping(clamd).await, Ok(response) if response == *PONG)
            }
            #[cfg(unix)]
            ClamAvConnection::Socket { path } => {
                let Some(path_str) = path.to_str() else {
                    return false;
                };
                let clamd = Socket {
                    socket_path: path_str,
                };
                matches!(ping(clamd).await, Ok(response) if response == *PONG)
            }
        }
    }
}
/// Placeholder ClamAV scanner used when the crate is built without the
/// `clamav` feature.
///
/// Its `VirusScanner` implementation reports itself unavailable and fails every scan.
#[cfg(not(feature = "clamav"))]
#[derive(Debug, Clone, Default)]
pub struct ClamAvScanner;

#[cfg(not(feature = "clamav"))]
impl ClamAvScanner {
    /// Builds the (stateless) placeholder scanner.
    #[must_use]
    pub const fn new() -> Self {
        ClamAvScanner
    }
}
#[cfg(not(feature = "clamav"))]
#[async_trait]
impl VirusScanner for ClamAvScanner {
    /// Always fails, because ClamAV support was compiled out.
    async fn scan(&self, _file: &UploadedFile) -> StorageResult<ScanResult> {
        let reason = "ClamAV support not enabled. Recompile with 'clamav' feature.";
        Err(StorageError::Other(reason.to_string()))
    }

    /// Fixed display name that flags the scanner as disabled.
    fn name(&self) -> &'static str {
        "ClamAV Scanner (disabled)"
    }

    /// Never available without the `clamav` feature.
    async fn is_available(&self) -> bool {
        false
    }
}
/// Decorator that copies files its inner scanner flags as infected into a
/// quarantine directory, then returns the inner verdict unchanged.
///
/// The `VirusScanner` bound lives on the impls that need it rather than on the
/// type, per the usual Rust guideline; constructing the wrapper needs no bound.
#[derive(Debug)]
pub struct QuarantineScanner<S> {
    /// Scanner whose verdicts are returned as-is.
    inner: S,
    /// Directory that receives quarantined files and their JSON metadata;
    /// it is created on demand when the first file is quarantined.
    quarantine_path: std::path::PathBuf,
}

impl<S> QuarantineScanner<S> {
    /// Wraps `scanner`, storing detected files under `quarantine_path`.
    ///
    /// Performs no I/O.
    #[must_use]
    pub const fn new(scanner: S, quarantine_path: std::path::PathBuf) -> Self {
        Self {
            inner: scanner,
            quarantine_path,
        }
    }
}
#[async_trait]
impl<S: VirusScanner> VirusScanner for QuarantineScanner<S> {
    /// Scans `file` with the wrapped scanner and quarantines it if it is infected.
    ///
    /// Quarantining is best-effort. If it fails, the failure is logged and the
    /// inner verdict is still returned unchanged. An infected file therefore stays
    /// reported as infected even when the quarantine directory is unusable.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by the wrapped scanner.
    async fn scan(&self, file: &UploadedFile) -> StorageResult<ScanResult> {
        let result = self.inner.scan(file).await?;
        if let ScanResult::Infected { threat } = &result {
            if let Err(e) = self.quarantine_file(file, threat).await {
                // A single record carrying the file, the threat and the cause.
                error!(
                    "Failed to quarantine infected file '{}' (threat: '{}'): {}",
                    file.filename, threat, e
                );
            }
        }
        Ok(result)
    }

    /// Fixed display name, independent of the wrapped scanner.
    fn name(&self) -> &'static str {
        "Quarantine Scanner"
    }

    /// Available exactly when the wrapped scanner is.
    async fn is_available(&self) -> bool {
        self.inner.is_available().await
    }
}
impl<S: VirusScanner> QuarantineScanner<S> {
    /// Copies an infected upload into the quarantine directory with a JSON sidecar.
    ///
    /// Two files are written under `quarantine_path`:
    /// - `<YYYYmmdd_HHMMSS>_<uuid v4>`: the raw upload bytes (no extension);
    /// - `<same name>.json`: a pretty-printed [`QuarantineMetadata`] record.
    ///
    /// # Errors
    ///
    /// Returns [`StorageError::Other`] if any of these steps fails:
    /// - the directory cannot be created;
    /// - the metadata cannot be serialized;
    /// - either file cannot be written.
    ///
    /// If the metadata write fails, the data file already written is removed
    /// (best effort), so a reported failure leaves no unlabelled sample behind.
    async fn quarantine_file(&self, file: &UploadedFile, threat: &str) -> StorageResult<()> {
        tokio::fs::create_dir_all(&self.quarantine_path)
            .await
            .map_err(|e| {
                StorageError::Other(format!("Failed to create quarantine directory: {e}"))
            })?;

        // Read the clock once so the file-name timestamp and `quarantined_at` agree.
        let now = Utc::now();
        let unique_id = uuid::Uuid::new_v4();
        let timestamp = now.format("%Y%m%d_%H%M%S");
        // Only generated components go into the path; the original (presumably
        // client-supplied) filename is kept solely in the metadata.
        let quarantine_filename = format!("{timestamp}_{unique_id}");
        let quarantine_file_path = self.quarantine_path.join(&quarantine_filename);
        let metadata_path = self.quarantine_path.join(format!("{quarantine_filename}.json"));

        let metadata = QuarantineMetadata {
            quarantined_at: now.to_rfc3339(),
            threat_name: threat.to_string(),
            original_filename: file.filename.clone(),
            original_mime_type: file.content_type.clone(),
            file_size: file.data.len(),
        };
        // Serialize before touching the disk so a serialization failure leaves nothing behind.
        let metadata_json = serde_json::to_string_pretty(&metadata).map_err(|e| {
            StorageError::Other(format!("Failed to serialize quarantine metadata: {e}"))
        })?;

        tokio::fs::write(&quarantine_file_path, &file.data)
            .await
            .map_err(|e| StorageError::Other(format!("Failed to write quarantined file: {e}")))?;

        if let Err(e) = tokio::fs::write(&metadata_path, metadata_json).await {
            // Roll back the data file so a failed quarantine does not leave a sample
            // that has no metadata describing it.
            if let Err(cleanup_err) = tokio::fs::remove_file(&quarantine_file_path).await {
                warn!(
                    "Failed to remove quarantined file '{}' after metadata write failure: {cleanup_err}",
                    quarantine_file_path.display()
                );
            }
            return Err(StorageError::Other(format!(
                "Failed to write quarantine metadata: {e}"
            )));
        }

        warn!(
            "File '{}' quarantined as '{}' - Threat: {}",
            file.filename, quarantine_filename, threat
        );
        Ok(())
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_noop_scanner_always_clean() {
        let file = UploadedFile::new("test.txt", "text/plain", b"harmless data".to_vec());
        let scanner = NoOpScanner::new();
        let result = scanner.scan(&file).await.unwrap();
        assert_eq!(result, ScanResult::Clean);
    }

    #[tokio::test]
    async fn test_noop_scanner_available() {
        let scanner = NoOpScanner::new();
        assert!(scanner.is_available().await);
    }

    #[tokio::test]
    async fn test_noop_scanner_name() {
        let scanner = NoOpScanner::new();
        assert_eq!(scanner.name(), "NoOp Scanner");
    }

    #[cfg(feature = "clamav")]
    #[tokio::test]
    async fn test_clamav_scanner_tcp_not_available() {
        let scanner = ClamAvScanner::new(ClamAvConnection::Tcp {
            host: "nonexistent.invalid".to_string(),
            port: 9999,
        });
        assert!(!scanner.is_available().await);
    }

    #[cfg(all(feature = "clamav", unix))]
    #[tokio::test]
    async fn test_clamav_scanner_socket_not_available() {
        let scanner = ClamAvScanner::new(ClamAvConnection::Socket {
            path: "/nonexistent/path.sock".into(),
        });
        assert!(!scanner.is_available().await);
    }

    #[cfg(feature = "clamav")]
    #[tokio::test]
    async fn test_clamav_scanner_default_tcp() {
        let scanner = ClamAvScanner::default_tcp();
        assert_eq!(scanner.name(), "ClamAV Scanner");
    }

    #[cfg(all(feature = "clamav", unix))]
    #[tokio::test]
    async fn test_clamav_scanner_default_socket() {
        let scanner = ClamAvScanner::default_socket();
        assert_eq!(scanner.name(), "ClamAV Scanner");
    }

    #[cfg(feature = "clamav")]
    #[tokio::test]
    async fn test_clamav_scanner_scan_connection_refused() {
        let file = UploadedFile::new("test.txt", "text/plain", b"test data".to_vec());
        // Nothing is expected to listen on this port, so the connection fails.
        let scanner = ClamAvScanner::new(ClamAvConnection::Tcp {
            host: "localhost".to_string(),
            port: 9999,
        });
        // Match exhaustively so a different error variant (or an Ok) fails the test
        // instead of silently skipping the message check.
        match scanner.scan(&file).await {
            Err(StorageError::Other(msg)) => {
                assert!(msg.contains("ClamAV scan failed"), "unexpected message: {msg}");
            }
            other => panic!("expected StorageError::Other, got {other:?}"),
        }
    }

    #[cfg(not(feature = "clamav"))]
    #[tokio::test]
    async fn test_clamav_scanner_disabled() {
        let file = UploadedFile::new("test.txt", "text/plain", b"test data".to_vec());
        let scanner = ClamAvScanner::new();
        match scanner.scan(&file).await {
            Err(StorageError::Other(msg)) => {
                assert!(msg.contains("not enabled"), "unexpected message: {msg}");
            }
            other => panic!("expected StorageError::Other, got {other:?}"),
        }
    }

    #[cfg(not(feature = "clamav"))]
    #[tokio::test]
    async fn test_clamav_scanner_disabled_not_available() {
        let scanner = ClamAvScanner::new();
        assert!(!scanner.is_available().await);
        assert_eq!(scanner.name(), "ClamAV Scanner (disabled)");
    }

    #[test]
    fn test_scan_result_display() {
        assert_eq!(ScanResult::Clean.to_string(), "Clean");
        assert_eq!(
            ScanResult::Infected {
                threat: "EICAR".to_string()
            }
            .to_string(),
            "Infected: EICAR"
        );
        assert_eq!(
            ScanResult::Error {
                message: "Scanner offline".to_string()
            }
            .to_string(),
            "Scan error: Scanner offline"
        );
    }

    #[tokio::test]
    async fn test_quarantine_scanner_wraps_inner() {
        let file = UploadedFile::new("test.txt", "text/plain", b"test".to_vec());
        let scanner = QuarantineScanner::new(
            NoOpScanner::new(),
            std::path::PathBuf::from("/tmp/quarantine"),
        );
        let result = scanner.scan(&file).await.unwrap();
        assert_eq!(result, ScanResult::Clean);
    }

    /// Test double that reports every file as infected with a fixed threat.
    #[derive(Debug, Clone)]
    struct MockInfectedScanner {
        threat: String,
    }

    impl MockInfectedScanner {
        fn new(threat: impl Into<String>) -> Self {
            Self {
                threat: threat.into(),
            }
        }
    }

    #[async_trait]
    impl VirusScanner for MockInfectedScanner {
        async fn scan(&self, _file: &UploadedFile) -> StorageResult<ScanResult> {
            Ok(ScanResult::Infected {
                threat: self.threat.clone(),
            })
        }

        fn name(&self) -> &'static str {
            "Mock Infected Scanner"
        }

        async fn is_available(&self) -> bool {
            true
        }
    }

    /// Test double whose backend is never available.
    #[derive(Debug, Clone)]
    struct MockUnavailableScanner;

    #[async_trait]
    impl VirusScanner for MockUnavailableScanner {
        async fn scan(&self, _file: &UploadedFile) -> StorageResult<ScanResult> {
            Ok(ScanResult::Clean)
        }

        fn name(&self) -> &'static str {
            "Mock Unavailable Scanner"
        }

        async fn is_available(&self) -> bool {
            false
        }
    }

    #[tokio::test]
    async fn test_quarantine_scanner_quarantines_infected_files() {
        let temp_dir = tempfile::tempdir().unwrap();
        let quarantine_path = temp_dir.path().to_path_buf();
        let file = UploadedFile::new(
            "malware.exe",
            "application/octet-stream",
            b"EICAR test file".to_vec(),
        );
        let scanner = QuarantineScanner::new(
            MockInfectedScanner::new("EICAR.Test.Signature"),
            quarantine_path.clone(),
        );
        let result = scanner.scan(&file).await.unwrap();
        assert!(matches!(result, ScanResult::Infected { .. }));
        assert!(quarantine_path.exists());
        let entries: Vec<_> = std::fs::read_dir(&quarantine_path)
            .unwrap()
            .collect::<Result<Vec<_>, _>>()
            .unwrap();
        assert_eq!(entries.len(), 2, "Should have quarantine file and metadata");
        let metadata_file = entries
            .iter()
            .find(|e| e.path().extension().and_then(|s| s.to_str()) == Some("json"))
            .expect("Should have metadata JSON file");
        let metadata_json = std::fs::read_to_string(metadata_file.path()).unwrap();
        let metadata: QuarantineMetadata = serde_json::from_str(&metadata_json).unwrap();
        assert_eq!(metadata.threat_name, "EICAR.Test.Signature");
        assert_eq!(metadata.original_filename, "malware.exe");
        assert_eq!(metadata.original_mime_type, "application/octet-stream");
        assert_eq!(metadata.file_size, b"EICAR test file".len());
        let data_file = entries
            .iter()
            .find(|e| e.path().extension().is_none())
            .expect("Should have quarantine data file");
        let quarantined_data = std::fs::read(data_file.path()).unwrap();
        assert_eq!(quarantined_data, b"EICAR test file");
    }

    #[tokio::test]
    async fn test_quarantine_scanner_clean_files_not_quarantined() {
        let temp_dir = tempfile::tempdir().unwrap();
        let quarantine_path = temp_dir.path().to_path_buf();
        let file = UploadedFile::new("clean.txt", "text/plain", b"clean data".to_vec());
        let scanner = QuarantineScanner::new(NoOpScanner::new(), quarantine_path.clone());
        let result = scanner.scan(&file).await.unwrap();
        assert_eq!(result, ScanResult::Clean);
        let entries: Vec<_> = std::fs::read_dir(&quarantine_path)
            .unwrap()
            .collect::<Result<Vec<_>, _>>()
            .unwrap();
        assert_eq!(entries.len(), 0, "Clean files should not be quarantined");
    }

    #[tokio::test]
    async fn test_quarantine_scanner_creates_directory() {
        let temp_dir = tempfile::tempdir().unwrap();
        let quarantine_path = temp_dir.path().join("nested").join("quarantine");
        assert!(!quarantine_path.exists());
        let file = UploadedFile::new("malware.bin", "application/octet-stream", b"bad".to_vec());
        let scanner = QuarantineScanner::new(
            MockInfectedScanner::new("Test.Virus"),
            quarantine_path.clone(),
        );
        scanner.scan(&file).await.unwrap();
        assert!(quarantine_path.exists());
        assert!(quarantine_path.is_dir());
    }

    #[tokio::test]
    async fn test_quarantine_scanner_unique_filenames() {
        let temp_dir = tempfile::tempdir().unwrap();
        let quarantine_path = temp_dir.path().to_path_buf();
        let scanner = QuarantineScanner::new(
            MockInfectedScanner::new("Test.Virus"),
            quarantine_path.clone(),
        );
        let file1 = UploadedFile::new("malware.exe", "application/octet-stream", b"bad1".to_vec());
        let file2 = UploadedFile::new("malware.exe", "application/octet-stream", b"bad2".to_vec());
        scanner.scan(&file1).await.unwrap();
        scanner.scan(&file2).await.unwrap();
        let entries: Vec<_> = std::fs::read_dir(&quarantine_path)
            .unwrap()
            .collect::<Result<Vec<_>, _>>()
            .unwrap();
        assert_eq!(entries.len(), 4, "Should have 4 files (2 files + 2 metadata)");
        let mut filenames: Vec<_> = entries
            .iter()
            .map(|e| e.file_name().to_string_lossy().to_string())
            .collect();
        filenames.sort();
        filenames.dedup();
        assert_eq!(filenames.len(), 4, "All quarantined files should have unique names");
    }

    #[tokio::test]
    async fn test_quarantine_scanner_name() {
        let scanner = QuarantineScanner::new(
            NoOpScanner::new(),
            std::path::PathBuf::from("/tmp/quarantine"),
        );
        assert_eq!(scanner.name(), "Quarantine Scanner");
    }

    #[tokio::test]
    async fn test_quarantine_scanner_availability() {
        let available_scanner = QuarantineScanner::new(
            NoOpScanner::new(),
            std::path::PathBuf::from("/tmp/quarantine"),
        );
        assert!(available_scanner.is_available().await);
        // Detecting threats has no bearing on availability.
        let infected_scanner = QuarantineScanner::new(
            MockInfectedScanner::new("test"),
            std::path::PathBuf::from("/tmp/quarantine"),
        );
        assert!(infected_scanner.is_available().await);
        // An unavailable inner scanner must make the wrapper unavailable too,
        // proving availability is delegated rather than hard-coded.
        let unavailable_scanner = QuarantineScanner::new(
            MockUnavailableScanner,
            std::path::PathBuf::from("/tmp/quarantine"),
        );
        assert!(!unavailable_scanner.is_available().await);
    }
}