use crate::circuit_breaker::config::{CircuitBreakerConfig, FallbackBehavior};
use crate::circuit_breaker::state::{BreakerMetrics, BreakerState};
use crate::core::{FileInput, ScanError, ScanResult, Scanner};
use async_trait::async_trait;
use std::fmt;
use std::sync::RwLock;
use std::time::Instant;
/// Wraps a [`Scanner`] and stops forwarding requests to it after repeated failures.
///
/// The breaker moves between `Closed`, `Open` and `HalfOpen` states (see [`BreakerState`])
/// according to `config`, and records transitions and outcomes in [`BreakerMetrics`].
/// Both the state and the metrics sit behind `RwLock`s so the breaker can be shared by
/// reference between concurrent callers.
pub struct CircuitBreaker<S: Scanner> {
    /// The scanner whose calls are being guarded.
    inner: S,
    /// Thresholds, durations, failure policy and fallback behavior.
    config: CircuitBreakerConfig,
    /// Current breaker state; callers in this file recover from lock poisoning.
    state: RwLock<BreakerState>,
    /// Counters for successes, failures, rejections and open/close transitions.
    metrics: RwLock<BreakerMetrics>,
}
impl<S: Scanner> CircuitBreaker<S> {
pub fn new(scanner: S, config: CircuitBreakerConfig) -> Self {
Self {
inner: scanner,
state: RwLock::new(BreakerState::closed()),
config,
metrics: RwLock::new(BreakerMetrics::new()),
}
}
pub fn with_defaults(scanner: S) -> Self {
Self::new(scanner, CircuitBreakerConfig::default())
}
pub fn state(&self) -> BreakerState {
self.state
.read()
.unwrap_or_else(|poisoned| poisoned.into_inner())
.clone()
}
pub fn metrics(&self) -> BreakerMetrics {
self.metrics
.read()
.unwrap_or_else(|poisoned| poisoned.into_inner())
.clone()
}
pub fn force_open(&self) {
let until = Instant::now() + self.config.open_duration;
*self
.state
.write()
.unwrap_or_else(|poisoned| poisoned.into_inner()) = BreakerState::Open {
opened_at: Instant::now(),
until,
};
self.metrics
.write()
.unwrap_or_else(|poisoned| poisoned.into_inner())
.record_opened();
}
pub fn force_close(&self) {
*self
.state
.write()
.unwrap_or_else(|poisoned| poisoned.into_inner()) = BreakerState::closed();
self.metrics
.write()
.unwrap_or_else(|poisoned| poisoned.into_inner())
.record_closed();
}
pub fn reset(&self) {
*self
.state
.write()
.unwrap_or_else(|poisoned| poisoned.into_inner()) = BreakerState::closed();
*self
.metrics
.write()
.unwrap_or_else(|poisoned| poisoned.into_inner()) = BreakerMetrics::new();
}
pub fn inner(&self) -> &S {
&self.inner
}
pub fn config(&self) -> &CircuitBreakerConfig {
&self.config
}
fn should_allow_request(&self) -> Result<(), ScanError> {
let mut state = self
.state
.write()
.unwrap_or_else(|poisoned| poisoned.into_inner());
let now = Instant::now();
match &*state {
BreakerState::Closed { .. } => Ok(()),
BreakerState::Open { until, .. } => {
if now >= *until {
*state = BreakerState::HalfOpen {
success_count: 0,
probe_count: 1,
};
Ok(())
} else {
Err(ScanError::CircuitOpen {
engine: self.inner.name().to_string(),
recovery_hint: Some(format!("Circuit may recover in {:?}", *until - now)),
})
}
}
BreakerState::HalfOpen {
success_count,
probe_count,
} => {
if *probe_count < self.config.half_open_max_probes {
*state = BreakerState::HalfOpen {
success_count: *success_count,
probe_count: probe_count + 1,
};
Ok(())
} else {
Err(ScanError::CircuitOpen {
engine: self.inner.name().to_string(),
recovery_hint: Some("Maximum probes in progress".to_string()),
})
}
}
}
}
fn record_success(&self) {
let mut state = self
.state
.write()
.unwrap_or_else(|poisoned| poisoned.into_inner());
self.metrics
.write()
.unwrap_or_else(|poisoned| poisoned.into_inner())
.record_success();
match &*state {
BreakerState::Closed { .. } => {
*state = BreakerState::closed();
}
BreakerState::HalfOpen {
success_count,
probe_count,
} => {
let new_success_count = success_count + 1;
if new_success_count >= self.config.success_threshold {
*state = BreakerState::closed();
self.metrics
.write()
.unwrap_or_else(|poisoned| poisoned.into_inner())
.record_closed();
} else {
*state = BreakerState::HalfOpen {
success_count: new_success_count,
probe_count: *probe_count,
};
}
}
BreakerState::Open { .. } => {
}
}
}
fn record_failure(&self, error: &ScanError) {
if !self.config.failure_policy.should_count(error) {
return;
}
let mut state = self
.state
.write()
.unwrap_or_else(|poisoned| poisoned.into_inner());
self.metrics
.write()
.unwrap_or_else(|poisoned| poisoned.into_inner())
.record_failure();
match &*state {
BreakerState::Closed { failure_count } => {
let new_count = failure_count + 1;
if new_count >= self.config.failure_threshold {
let until = Instant::now() + self.config.open_duration;
*state = BreakerState::Open {
opened_at: Instant::now(),
until,
};
self.metrics
.write()
.unwrap_or_else(|poisoned| poisoned.into_inner())
.record_opened();
} else {
*state = BreakerState::Closed {
failure_count: new_count,
};
}
}
BreakerState::HalfOpen { .. } => {
let until = Instant::now() + self.config.open_duration;
*state = BreakerState::Open {
opened_at: Instant::now(),
until,
};
self.metrics
.write()
.unwrap_or_else(|poisoned| poisoned.into_inner())
.record_opened();
}
BreakerState::Open { .. } => {
}
}
}
async fn handle_open_circuit(&self, input: &FileInput) -> Result<ScanResult, ScanError> {
self.metrics
.write()
.unwrap_or_else(|poisoned| poisoned.into_inner())
.record_rejected();
match &self.config.fallback_behavior {
FallbackBehavior::FailClosed => Err(ScanError::CircuitOpen {
engine: self.inner.name().to_string(),
recovery_hint: Some("Circuit is open; scan rejected".to_string()),
}),
FallbackBehavior::FailOpen => {
tracing::warn!(
engine = self.inner.name(),
"Circuit open, allowing file through (fail-open mode)"
);
use crate::core::{FileHash, FileHasher, FileMetadata, ScanContext, ScanOutcome};
use std::time::Duration;
let hasher = FileHasher::new();
let hash = match input {
FileInput::Path(path) => hasher.hash_file(path)?,
FileInput::Bytes { data, .. } => hasher.hash_bytes(data),
FileInput::Stream { .. } => FileHash::new("unknown-stream"),
};
let metadata = FileMetadata::new(input.size_hint().unwrap_or(0), hash);
let context = ScanContext::new();
let mut result = ScanResult::new(
ScanOutcome::Clean,
metadata,
format!("{}-failopen", self.inner.name()),
Duration::ZERO,
context,
);
result.details.insert(
"warning".to_string(),
serde_json::Value::String(
"Scan skipped due to circuit breaker; file allowed through fail-open policy"
.to_string(),
),
);
Ok(result)
}
FallbackBehavior::Fallback(fallback) => {
tracing::info!(
primary = self.inner.name(),
fallback = fallback.name(),
"Using fallback scanner due to open circuit"
);
fallback.scan(input).await
}
}
}
}
impl<S: Scanner> fmt::Debug for CircuitBreaker<S> {
    /// Shows the wrapped scanner, a snapshot of the current state, and the configuration.
    ///
    /// Metrics are not included. A poisoned state lock is recovered rather than propagated.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let current = self
            .state
            .read()
            .unwrap_or_else(|poisoned| poisoned.into_inner());
        let mut builder = f.debug_struct("CircuitBreaker");
        builder.field("inner", &self.inner);
        builder.field("state", &*current);
        builder.field("config", &self.config);
        builder.finish()
    }
}
#[async_trait]
impl<S: Scanner> Scanner for CircuitBreaker<S> {
    /// Reports the wrapped scanner's name, so the breaker is transparent to callers.
    fn name(&self) -> &str {
        let wrapped = &self.inner;
        wrapped.name()
    }

    /// Scans through the breaker.
    ///
    /// Rejected requests are routed to `handle_open_circuit`. Admitted requests run on the
    /// inner scanner, and their outcome is fed back into the breaker before it is returned
    /// unchanged.
    async fn scan(&self, input: &FileInput) -> Result<ScanResult, ScanError> {
        let admitted = self.should_allow_request().is_ok();
        if !admitted {
            return self.handle_open_circuit(input).await;
        }
        let outcome = self.inner.scan(input).await;
        match &outcome {
            Ok(_) => self.record_success(),
            Err(err) => self.record_failure(err),
        }
        outcome
    }

    /// Runs the inner scanner's health check and returns its result unchanged.
    ///
    /// While the breaker is half-open, the result is also counted as a success or failure,
    /// so a health check can help close or re-open the circuit.
    async fn health_check(&self) -> Result<(), ScanError> {
        let outcome = self.inner.health_check().await;
        if self.state().is_half_open() {
            match &outcome {
                Ok(()) => self.record_success(),
                Err(err) => self.record_failure(err),
            }
        }
        outcome
    }

    /// Forwards to the wrapped scanner.
    fn max_file_size(&self) -> Option<u64> {
        let wrapped = &self.inner;
        wrapped.max_file_size()
    }

    /// Forwards to the wrapped scanner.
    async fn signature_version(&self) -> Option<String> {
        let wrapped = &self.inner;
        wrapped.signature_version().await
    }

    /// Forwards to the wrapped scanner.
    fn supports_streaming(&self) -> bool {
        let wrapped = &self.inner;
        wrapped.supports_streaming()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::backends::MockScanner;
    use std::time::Duration;

    /// A clean scan through a closed breaker returns the inner result and stays closed.
    #[tokio::test]
    async fn test_circuit_breaker_passes_through() {
        let scanner = MockScanner::new_clean();
        let breaker = CircuitBreaker::with_defaults(scanner);
        let input = FileInput::from_bytes(b"test".to_vec());
        let result = breaker.scan(&input).await.unwrap();
        assert!(result.is_clean());
        assert!(breaker.state().is_closed());
    }

    /// Reaching `failure_threshold` consecutive counted failures trips the breaker exactly once.
    #[tokio::test]
    async fn test_circuit_opens_on_failures() {
        let scanner = MockScanner::new().with_fail_rate(1.0);
        let config = CircuitBreakerConfig::default().with_failure_threshold(3);
        let breaker = CircuitBreaker::new(scanner, config);
        let input = FileInput::from_bytes(b"test".to_vec());
        for _ in 0..3 {
            let _ = breaker.scan(&input).await;
        }
        assert!(breaker.state().is_open());
        assert_eq!(breaker.metrics().times_opened, 1);
    }

    /// With the default fallback behavior, an open breaker rejects scans with `CircuitOpen`.
    #[tokio::test]
    async fn test_circuit_rejects_when_open() {
        let scanner = MockScanner::new_clean();
        let breaker = CircuitBreaker::with_defaults(scanner);
        breaker.force_open();
        assert!(breaker.state().is_open());
        let input = FileInput::from_bytes(b"test".to_vec());
        let result = breaker.scan(&input).await;
        assert!(matches!(result, Err(ScanError::CircuitOpen { .. })));
    }

    /// Once the open period elapses, the next request is admitted as a probe.
    #[tokio::test]
    async fn test_circuit_transitions_to_half_open() {
        let scanner = MockScanner::new_clean();
        let config = CircuitBreakerConfig::default().with_open_duration(Duration::from_millis(10));
        let breaker = CircuitBreaker::new(scanner, config);
        breaker.force_open();
        assert!(breaker.state().is_open());
        tokio::time::sleep(Duration::from_millis(20)).await;
        let input = FileInput::from_bytes(b"test".to_vec());
        let result = breaker.scan(&input).await;
        assert!(result.is_ok());
        // The probe succeeded, so the breaker has left `Open`. It is now half-open or closed,
        // depending on the default `success_threshold`.
        assert!(!breaker.state().is_open());
    }

    /// Manual overrides move the breaker between closed and open.
    #[test]
    fn test_force_open_close() {
        let scanner = MockScanner::new_clean();
        let breaker = CircuitBreaker::with_defaults(scanner);
        assert!(breaker.state().is_closed());
        breaker.force_open();
        assert!(breaker.state().is_open());
        breaker.force_close();
        assert!(breaker.state().is_closed());
    }

    /// `reset` closes the breaker and discards previously recorded metrics.
    #[test]
    fn test_reset_clears_state_and_metrics() {
        let breaker = CircuitBreaker::with_defaults(MockScanner::new_clean());
        breaker.force_open();
        assert_eq!(breaker.metrics().times_opened, 1);
        breaker.reset();
        assert!(breaker.state().is_closed());
        assert_eq!(breaker.metrics().times_opened, 0);
    }
}