use crate::core::{
FileHasher, FileInput, FileMetadata, ScanContext, ScanError, ScanOutcome, ScanResult, Scanner,
ThreatInfo,
};
use async_trait::async_trait;
use std::collections::HashMap;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::RwLock;
use std::time::Duration;
/// Test double for [`Scanner`] whose results are configured up front.
///
/// Each scan hashes the input with `FileHasher` and looks up the resulting
/// `blake3` string in `responses`; when no entry matches, `default_outcome` is
/// returned. Latency, a deterministic simulated failure rate, and the health
/// state reported by `health_check` can also be configured.
#[derive(Debug)]
pub struct MockScanner {
    /// Engine name returned by `Scanner::name` and attached to results/errors.
    name: String,
    /// Canned outcomes keyed by the `blake3` hash string of the scanned input.
    responses: RwLock<HashMap<String, ScanOutcome>>,
    /// Outcome used when the input's hash has no entry in `responses`.
    default_outcome: ScanOutcome,
    /// Optional simulated scan delay; also reported as the scan duration.
    latency: Option<Duration>,
    /// Target fraction of scans (clamped to `[0, 1]` by `with_fail_rate`)
    /// that return a simulated engine failure.
    fail_rate: f32,
    /// Number of `scan` calls made so far, including failed ones.
    scan_count: AtomicU64,
    /// When `true`, `health_check` reports the engine as unavailable.
    unhealthy: RwLock<bool>,
}
impl MockScanner {
    /// Step of the deterministic failure sequence: the fractional part of the
    /// golden ratio. Multiplying successive scan counts by an irrational step
    /// and keeping the fractional part spreads them roughly evenly over `[0, 1)`.
    ///
    /// Kept as `f64`: in `f32` the literal exceeds the type's precision and the
    /// fractional part of `count * step` degrades quickly as `count` grows.
    const FAILURE_SEQUENCE_STEP: f64 = 0.618_033_988_749_895;

    /// Creates a healthy scanner named `"mock"` that reports every file as
    /// clean, with no latency, no simulated failures and no canned responses.
    #[must_use]
    pub fn new() -> Self {
        Self {
            name: "mock".to_string(),
            responses: RwLock::new(HashMap::new()),
            default_outcome: ScanOutcome::Clean,
            latency: None,
            fail_rate: 0.0,
            scan_count: AtomicU64::new(0),
            unhealthy: RwLock::new(false),
        }
    }

    /// Alias for [`MockScanner::new`]: every file without a canned response is clean.
    #[must_use]
    pub fn new_clean() -> Self {
        Self::new()
    }

    /// Creates a scanner whose default outcome is `Infected` with `threats`.
    #[must_use]
    pub fn new_infected(threats: Vec<ThreatInfo>) -> Self {
        Self {
            default_outcome: ScanOutcome::Infected { threats },
            ..Self::new()
        }
    }

    /// Sets the engine name reported by the scanner.
    #[must_use]
    pub fn with_name(mut self, name: impl Into<String>) -> Self {
        self.name = name.into();
        self
    }

    /// Sets the outcome returned for inputs without a canned response.
    #[must_use]
    pub fn with_default_outcome(mut self, outcome: ScanOutcome) -> Self {
        self.default_outcome = outcome;
        self
    }

    /// Registers `outcome` for inputs whose `blake3` hash string equals `hash`.
    ///
    /// The builder owns `self`, so the map is reached through `get_mut` with
    /// no runtime locking.
    #[must_use]
    pub fn with_response(mut self, hash: impl Into<String>, outcome: ScanOutcome) -> Self {
        self.responses
            .get_mut()
            .unwrap_or_else(|poisoned| poisoned.into_inner())
            .insert(hash.into(), outcome);
        self
    }

    /// Sets a simulated delay applied to every non-failing scan.
    #[must_use]
    pub fn with_latency(mut self, latency: Duration) -> Self {
        self.latency = Some(latency);
        self
    }

    /// Sets the fraction of scans that fail with a simulated engine error.
    ///
    /// `rate` is clamped to `[0.0, 1.0]`; NaN is treated as `0.0` (never fail).
    #[must_use]
    pub fn with_fail_rate(mut self, rate: f32) -> Self {
        // `f32::clamp` propagates NaN, so normalize it explicitly.
        self.fail_rate = if rate.is_nan() { 0.0 } else { rate.clamp(0.0, 1.0) };
        self
    }

    /// Returns how many times `scan` has been called, including calls that
    /// returned an error.
    pub fn scan_count(&self) -> u64 {
        self.scan_count.load(Ordering::Relaxed)
    }

    /// Sets whether `health_check` succeeds (`true`) or reports the engine
    /// as unavailable (`false`).
    pub fn set_healthy(&self, healthy: bool) {
        *self
            .unhealthy
            .write()
            .unwrap_or_else(|poisoned| poisoned.into_inner()) = !healthy;
    }

    /// Makes subsequent `health_check` calls fail.
    pub fn make_unhealthy(&self) {
        self.set_healthy(false);
    }

    /// Makes subsequent `health_check` calls succeed.
    pub fn make_healthy(&self) {
        self.set_healthy(true);
    }

    /// Registers (or replaces) a canned outcome on an already-built scanner.
    pub fn add_response(&self, hash: impl Into<String>, outcome: ScanOutcome) {
        self.responses
            .write()
            .unwrap_or_else(|poisoned| poisoned.into_inner())
            .insert(hash.into(), outcome);
    }

    /// Removes all canned outcomes; every scan then uses the default outcome.
    pub fn clear_responses(&self) {
        self.responses
            .write()
            .unwrap_or_else(|poisoned| poisoned.into_inner())
            .clear();
    }

    /// Decides deterministically whether the current scan should fail.
    ///
    /// Reads the scan counter, which `scan` increments before calling this.
    /// Concurrent scans may read the same counter value, so the failure
    /// pattern is only exactly reproducible for sequential scans.
    fn should_fail(&self) -> bool {
        if self.fail_rate <= 0.0 {
            return false;
        }
        if self.fail_rate >= 1.0 {
            return true;
        }
        let count = self.scan_count.load(Ordering::Relaxed);
        (count as f64 * Self::FAILURE_SEQUENCE_STEP).fract() < f64::from(self.fail_rate)
    }
}
impl Default for MockScanner {
fn default() -> Self {
Self::new()
}
}
#[async_trait]
impl Scanner for MockScanner {
    fn name(&self) -> &str {
        self.name.as_str()
    }

    /// Simulates scanning `input`.
    ///
    /// Steps, in order:
    /// 1. Increment the scan counter.
    /// 2. Possibly inject a simulated failure.
    /// 3. Wait for the configured latency.
    /// 4. Hash the input.
    /// 5. Pick the canned outcome for that hash, or the default outcome.
    ///
    /// The reported duration is the configured latency, or 1 ms when none is set.
    ///
    /// # Errors
    ///
    /// - `engine_unavailable` when the configured fail rate triggers.
    /// - `internal` for `FileInput::Stream`, which the mock cannot hash.
    /// - Any error produced while hashing a `FileInput::Path`.
    async fn scan(&self, input: &FileInput) -> Result<ScanResult, ScanError> {
        self.scan_count.fetch_add(1, Ordering::Relaxed);

        if self.should_fail() {
            return Err(ScanError::engine_unavailable(&self.name, "simulated failure"));
        }

        if let Some(delay) = self.latency {
            #[cfg(feature = "tokio-runtime")]
            tokio::time::sleep(delay).await;
            // No async timer is available without tokio, so this fallback
            // blocks the current thread for the simulated delay.
            #[cfg(not(feature = "tokio-runtime"))]
            std::thread::sleep(delay);
        }

        let hasher = FileHasher::new();
        let file_hash = match input {
            FileInput::Bytes { data, .. } => hasher.hash_bytes(data),
            FileInput::Path(path) => hasher.hash_file(path)?,
            FileInput::Stream { .. } => {
                return Err(ScanError::internal("Mock scanner does not support streaming"));
            }
        };

        // The read guard is a temporary of the match scrutinee and is released
        // before the result is built.
        let outcome = match self
            .responses
            .read()
            .unwrap_or_else(|poisoned| poisoned.into_inner())
            .get(&file_hash.blake3)
        {
            Some(canned) => canned.clone(),
            None => self.default_outcome.clone(),
        };

        let metadata = FileMetadata::new(input.size_hint().unwrap_or(0), file_hash);
        let reported_duration = self.latency.unwrap_or(Duration::from_millis(1));

        Ok(ScanResult::new(
            outcome,
            metadata,
            self.name.clone(),
            reported_duration,
            ScanContext::new(),
        ))
    }

    /// Succeeds unless the scanner was marked unhealthy via `set_healthy(false)`
    /// or `make_unhealthy`.
    async fn health_check(&self) -> Result<(), ScanError> {
        let is_unhealthy = *self
            .unhealthy
            .read()
            .unwrap_or_else(|poisoned| poisoned.into_inner());

        if is_unhealthy {
            Err(ScanError::engine_unavailable(
                &self.name,
                "mock scanner is unhealthy",
            ))
        } else {
            Ok(())
        }
    }

    /// Fixed 100 MiB limit.
    fn max_file_size(&self) -> Option<u64> {
        const MAX_BYTES: u64 = 100 * 1024 * 1024;
        Some(MAX_BYTES)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::ThreatSeverity;

    #[tokio::test]
    async fn test_mock_scanner_clean() {
        let scanner = MockScanner::new_clean();
        let input = FileInput::from_bytes(b"test data".to_vec());
        let result = scanner.scan(&input).await.unwrap();
        assert!(result.is_clean());
        assert_eq!(scanner.scan_count(), 1);
    }

    #[tokio::test]
    async fn test_mock_scanner_infected() {
        let threats = vec![ThreatInfo::new(
            "Test.Malware",
            ThreatSeverity::High,
            "mock",
        )];
        let scanner = MockScanner::new_infected(threats);
        let input = FileInput::from_bytes(b"malicious data".to_vec());
        let result = scanner.scan(&input).await.unwrap();
        assert!(result.is_infected());
        assert_eq!(result.threats().unwrap().len(), 1);
    }

    #[tokio::test]
    async fn test_mock_scanner_health_check() {
        let scanner = MockScanner::new();
        assert!(scanner.health_check().await.is_ok());
        scanner.make_unhealthy();
        assert!(scanner.health_check().await.is_err());
        scanner.make_healthy();
        assert!(scanner.health_check().await.is_ok());
    }

    /// An input whose hash has no canned response falls back to the default outcome.
    #[tokio::test]
    async fn test_mock_scanner_custom_response() {
        let scanner = MockScanner::new()
            .with_default_outcome(ScanOutcome::Clean)
            .with_response(
                "known-malware-hash",
                ScanOutcome::Infected {
                    threats: vec![ThreatInfo::new(
                        "Known.Malware",
                        ThreatSeverity::Critical,
                        "mock",
                    )],
                },
            );
        let input = FileInput::from_bytes(b"unknown file".to_vec());
        let result = scanner.scan(&input).await.unwrap();
        assert!(result.is_clean());
    }

    /// A canned response keyed by the input's real hash overrides the default outcome.
    #[tokio::test]
    async fn test_mock_scanner_custom_response_matches_hash() {
        let payload = b"known bad payload".to_vec();
        let known = FileHasher::new().hash_bytes(&payload);
        let scanner = MockScanner::new().with_response(
            known.blake3.clone(),
            ScanOutcome::Infected {
                threats: vec![ThreatInfo::new(
                    "Known.Malware",
                    ThreatSeverity::Critical,
                    "mock",
                )],
            },
        );
        let result = scanner
            .scan(&FileInput::from_bytes(payload))
            .await
            .unwrap();
        assert!(result.is_infected());
        assert_eq!(result.threats().unwrap().len(), 1);
    }

    /// `clear_responses` removes canned outcomes added at runtime.
    #[tokio::test]
    async fn test_mock_scanner_clear_responses_restores_default() {
        let payload = b"temporarily flagged".to_vec();
        let known = FileHasher::new().hash_bytes(&payload);
        let scanner = MockScanner::new();
        scanner.add_response(
            known.blake3.clone(),
            ScanOutcome::Infected {
                threats: vec![ThreatInfo::new(
                    "Temp.Malware",
                    ThreatSeverity::High,
                    "mock",
                )],
            },
        );
        let input = FileInput::from_bytes(payload);
        assert!(scanner.scan(&input).await.unwrap().is_infected());
        scanner.clear_responses();
        assert!(scanner.scan(&input).await.unwrap().is_clean());
    }

    /// A fail rate of 1.0 fails every scan, and failed scans are still counted.
    #[tokio::test]
    async fn test_mock_scanner_full_fail_rate() {
        let scanner = MockScanner::new().with_fail_rate(1.0);
        let input = FileInput::from_bytes(b"test data".to_vec());
        assert!(scanner.scan(&input).await.is_err());
        assert!(scanner.scan(&input).await.is_err());
        assert_eq!(scanner.scan_count(), 2);
    }
}