use std::collections::HashMap;
use std::net::IpAddr;
use std::sync::Arc;
use std::time::{Duration, SystemTime};
use tokio::sync::{mpsc, RwLock};
use tracing::{debug, info, warn};
/// Result of firewall detection for a single device.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum FirewallStatus {
    /// No conclusion has been reached yet; this is the [`Default`] value.
    #[default]
    Unknown,
    /// An event was received from the device.
    Accessible,
    /// The detection window elapsed without any event from the device.
    Blocked,
    /// Detection failed. Nothing in this file's coordinator assigns this value.
    Error,
}
/// Tunables for [`FirewallDetectionCoordinator`].
#[derive(Clone, Debug)]
pub struct FirewallDetectionConfig {
    /// How long to wait for a device's first event before declaring it blocked.
    pub event_wait_timeout: Duration,
    /// When `true`, a completed detection result is reused on later subscriptions.
    pub enable_caching: bool,
    /// Upper bound on the number of devices tracked at once.
    pub max_cached_devices: usize,
}

impl Default for FirewallDetectionConfig {
    /// Waits 15 seconds for events, caches results and tracks at most 100 devices.
    fn default() -> Self {
        let event_wait_timeout = Duration::from_secs(15);
        let max_cached_devices = 100;
        Self {
            event_wait_timeout,
            enable_caching: true,
            max_cached_devices,
        }
    }
}
/// Detection bookkeeping for one device.
#[derive(Clone, Debug)]
pub struct DeviceFirewallState {
    /// Address of the device this entry describes.
    pub device_ip: IpAddr,
    /// Current detection status of the device.
    pub status: FirewallStatus,
    /// Wall-clock time at which the current detection window was opened.
    pub first_subscription_time: SystemTime,
    /// When the first event arrived; `None` until one is received.
    pub first_event_time: Option<SystemTime>,
    /// Set once a final status has been recorded; completed entries are no longer
    /// considered by the timeout monitor.
    pub detection_completed: bool,
    /// How long after `first_subscription_time` the device is marked blocked if no
    /// event has arrived.
    pub timeout_duration: Duration,
}
/// Outcome of a finished detection, sent on the coordinator's internal result channel.
#[derive(Clone, Debug)]
pub struct DetectionResult {
    /// Device the result applies to.
    pub device_ip: IpAddr,
    /// Final status assigned to the device.
    pub status: FirewallStatus,
    /// Why detection concluded with `status`.
    pub reason: DetectionReason,
}
/// Cause attached to a [`DetectionResult`].
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum DetectionReason {
    /// An event from the device arrived before the timeout.
    EventReceived,
    /// The detection window elapsed with no event.
    Timeout,
    /// Subscribing to the device failed. Not emitted by the coordinator in this file.
    SubscriptionFailed,
}
/// Tracks per-device firewall detection state and reports detection outcomes.
///
/// Owns a background timeout-monitor task spawned by `new`. Dropping the coordinator
/// aborts that task.
pub struct FirewallDetectionCoordinator {
    /// Detection state keyed by device IP.
    ///
    /// The outer lock guards the map. Each entry has its own lock, so a single device
    /// can be updated while only read-locking the map.
    device_states: Arc<RwLock<HashMap<IpAddr, Arc<RwLock<DeviceFirewallState>>>>>,
    /// Configuration supplied at construction.
    config: FirewallDetectionConfig,
    /// Sender for finished detection results.
    detection_complete_tx: mpsc::UnboundedSender<DetectionResult>,
    /// Handle of the background timeout-monitor task. It is aborted in `Drop`.
    _timeout_task_handle: tokio::task::JoinHandle<()>,
}

impl Drop for FirewallDetectionCoordinator {
    /// Stops the background timeout monitor.
    ///
    /// Dropping a `JoinHandle` only detaches its task. The monitor loops forever and holds
    /// clones of the state map and the result sender. Without this abort it would outlive
    /// the coordinator, and the result-logging task spawned in `new` would never see its
    /// channel close.
    fn drop(&mut self) {
        self._timeout_task_handle.abort();
    }
}
impl FirewallDetectionCoordinator {
pub fn new(config: FirewallDetectionConfig) -> Self {
let (detection_complete_tx, mut detection_complete_rx) = mpsc::unbounded_channel();
let device_states = Arc::new(RwLock::new(HashMap::new()));
let timeout_task_handle = {
let device_states = device_states.clone();
let detection_complete_tx = detection_complete_tx.clone();
tokio::spawn(async move {
Self::monitor_timeouts(device_states, detection_complete_tx).await;
})
};
tokio::spawn(async move {
while let Some(result) = detection_complete_rx.recv().await {
match result.reason {
DetectionReason::EventReceived => {
info!(
device_ip = %result.device_ip,
reason = ?result.reason,
status = ?result.status,
"Firewall detection: Events accessible from device"
);
}
DetectionReason::Timeout => {
warn!(
device_ip = %result.device_ip,
reason = ?result.reason,
status = ?result.status,
"Firewall detection: No events received within timeout"
);
}
DetectionReason::SubscriptionFailed => {
warn!(
device_ip = %result.device_ip,
reason = ?result.reason,
status = ?result.status,
"Firewall detection: Subscription failed for device"
);
}
}
}
});
Self {
device_states,
config,
detection_complete_tx,
_timeout_task_handle: timeout_task_handle,
}
}
pub async fn on_first_subscription(&self, device_ip: IpAddr) -> FirewallStatus {
if !self.config.enable_caching {
self.start_detection_for_device(device_ip).await;
return FirewallStatus::Unknown;
}
let device_states = self.device_states.read().await;
if let Some(state_arc) = device_states.get(&device_ip) {
let state = state_arc.read().await;
if state.detection_completed {
debug!(
device_ip = %device_ip,
status = ?state.status,
"Firewall detection: Using cached status for device"
);
return state.status;
}
}
drop(device_states);
self.start_detection_for_device(device_ip).await;
debug!(
device_ip = %device_ip,
timeout = ?self.config.event_wait_timeout,
"Firewall detection: Started monitoring device for events"
);
FirewallStatus::Unknown
}
pub async fn on_event_received(&self, device_ip: IpAddr) {
let device_states = self.device_states.read().await;
if let Some(state_arc) = device_states.get(&device_ip) {
let mut state = state_arc.write().await;
if !state.detection_completed {
state.first_event_time = Some(SystemTime::now());
state.status = FirewallStatus::Accessible;
state.detection_completed = true;
let elapsed = SystemTime::now()
.duration_since(state.first_subscription_time)
.unwrap_or(Duration::ZERO);
let _ = self.detection_complete_tx.send(DetectionResult {
device_ip,
status: FirewallStatus::Accessible,
reason: DetectionReason::EventReceived,
});
info!(
device_ip = %device_ip,
elapsed = ?elapsed,
status = ?FirewallStatus::Accessible,
"Firewall detection: Event received from device, marking as accessible"
);
}
}
}
pub async fn get_device_status(&self, device_ip: IpAddr) -> FirewallStatus {
let device_states = self.device_states.read().await;
if let Some(state_arc) = device_states.get(&device_ip) {
let state = state_arc.read().await;
state.status
} else {
FirewallStatus::Unknown
}
}
pub async fn clear_device_cache(&self, device_ip: IpAddr) {
let mut device_states = self.device_states.write().await;
device_states.remove(&device_ip);
debug!(
device_ip = %device_ip,
"Firewall detection: Cleared cache for device"
);
}
async fn start_detection_for_device(&self, device_ip: IpAddr) {
let mut device_states = self.device_states.write().await;
let new_state = Arc::new(RwLock::new(DeviceFirewallState {
device_ip,
status: FirewallStatus::Unknown,
first_subscription_time: SystemTime::now(),
first_event_time: None,
detection_completed: false,
timeout_duration: self.config.event_wait_timeout,
}));
if device_states.len() >= self.config.max_cached_devices {
if let Some(oldest_ip) = device_states.keys().next().copied() {
device_states.remove(&oldest_ip);
debug!(
oldest_ip = %oldest_ip,
cache_size = self.config.max_cached_devices,
"Firewall detection: Removed oldest cached entry due to cache being full"
);
}
}
device_states.insert(device_ip, new_state);
}
async fn monitor_timeouts(
device_states: Arc<RwLock<HashMap<IpAddr, Arc<RwLock<DeviceFirewallState>>>>>,
detection_complete_tx: mpsc::UnboundedSender<DetectionResult>,
) {
let mut interval = tokio::time::interval(Duration::from_secs(1));
loop {
interval.tick().await;
let device_states_read = device_states.read().await;
for (device_ip, state_arc) in device_states_read.iter() {
let mut state = state_arc.write().await;
if !state.detection_completed {
let elapsed = SystemTime::now()
.duration_since(state.first_subscription_time)
.unwrap_or(Duration::ZERO);
if elapsed >= state.timeout_duration {
state.status = FirewallStatus::Blocked;
state.detection_completed = true;
let _ = detection_complete_tx.send(DetectionResult {
device_ip: *device_ip,
status: FirewallStatus::Blocked,
reason: DetectionReason::Timeout,
});
warn!(
device_ip = %device_ip,
timeout = ?state.timeout_duration,
status = ?FirewallStatus::Blocked,
"Firewall detection: No events received within timeout, marking as blocked"
);
}
}
}
}
}
pub async fn get_stats(&self) -> CoordinatorStats {
let device_states = self.device_states.read().await;
let mut stats = CoordinatorStats {
total_devices: device_states.len(),
accessible_devices: 0,
blocked_devices: 0,
unknown_devices: 0,
error_devices: 0,
};
for state_arc in device_states.values() {
let state = state_arc.read().await;
match state.status {
FirewallStatus::Accessible => stats.accessible_devices += 1,
FirewallStatus::Blocked => stats.blocked_devices += 1,
FirewallStatus::Unknown => stats.unknown_devices += 1,
FirewallStatus::Error => stats.error_devices += 1,
}
}
stats
}
}
/// Counts of tracked devices, as returned by `FirewallDetectionCoordinator::get_stats`.
#[derive(Clone, Debug)]
pub struct CoordinatorStats {
    /// Number of devices currently tracked.
    pub total_devices: usize,
    /// Devices whose status is `Accessible`.
    pub accessible_devices: usize,
    /// Devices whose status is `Blocked`.
    pub blocked_devices: usize,
    /// Devices whose status is still `Unknown`.
    pub unknown_devices: usize,
    /// Devices whose status is `Error`.
    pub error_devices: usize,
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::net::Ipv4Addr;

    /// Builds an address on the 192.168.1.0/24 test subnet.
    fn lan_ip(host: u8) -> IpAddr {
        IpAddr::V4(Ipv4Addr::new(192, 168, 1, host))
    }

    /// Coordinator with the default configuration. Must run inside a Tokio runtime.
    fn default_coordinator() -> FirewallDetectionCoordinator {
        FirewallDetectionCoordinator::new(FirewallDetectionConfig::default())
    }

    #[tokio::test]
    async fn test_coordinator_creation() {
        let _coordinator = default_coordinator();
    }

    #[tokio::test]
    async fn test_first_subscription_starts_monitoring() {
        let coordinator = default_coordinator();
        let device = lan_ip(100);
        assert_eq!(
            coordinator.on_first_subscription(device).await,
            FirewallStatus::Unknown
        );
        assert_eq!(
            coordinator.get_device_status(device).await,
            FirewallStatus::Unknown
        );
    }

    #[tokio::test]
    async fn test_event_received_marks_accessible() {
        let coordinator = default_coordinator();
        let device = lan_ip(100);
        coordinator.on_first_subscription(device).await;
        coordinator.on_event_received(device).await;
        assert_eq!(
            coordinator.get_device_status(device).await,
            FirewallStatus::Accessible
        );
    }

    #[tokio::test]
    async fn test_timeout_marks_blocked() {
        let coordinator = FirewallDetectionCoordinator::new(FirewallDetectionConfig {
            event_wait_timeout: Duration::from_millis(100),
            ..Default::default()
        });
        let device = lan_ip(100);
        coordinator.on_first_subscription(device).await;
        // Real sleep: detection elapsed time is measured with SystemTime, which paused
        // Tokio time would not advance.
        tokio::time::sleep(Duration::from_millis(1200)).await;
        assert_eq!(
            coordinator.get_device_status(device).await,
            FirewallStatus::Blocked
        );
    }

    #[tokio::test]
    async fn test_cached_status_reused() {
        let coordinator = default_coordinator();
        let device = lan_ip(100);
        coordinator.on_first_subscription(device).await;
        coordinator.on_event_received(device).await;
        assert_eq!(
            coordinator.on_first_subscription(device).await,
            FirewallStatus::Accessible
        );
    }

    #[tokio::test]
    async fn test_clear_device_cache() {
        let coordinator = default_coordinator();
        let device = lan_ip(100);
        coordinator.on_first_subscription(device).await;
        coordinator.on_event_received(device).await;
        assert_eq!(
            coordinator.get_device_status(device).await,
            FirewallStatus::Accessible
        );
        coordinator.clear_device_cache(device).await;
        assert_eq!(
            coordinator.get_device_status(device).await,
            FirewallStatus::Unknown
        );
    }

    #[tokio::test]
    async fn test_stats() {
        let coordinator = default_coordinator();
        let (reachable, silent) = (lan_ip(100), lan_ip(101));
        coordinator.on_first_subscription(reachable).await;
        coordinator.on_event_received(reachable).await;
        coordinator.on_first_subscription(silent).await;

        let stats = coordinator.get_stats().await;
        assert_eq!(stats.total_devices, 2);
        assert_eq!(stats.accessible_devices, 1);
        assert_eq!(stats.unknown_devices, 1);
        assert_eq!(stats.blocked_devices, 0);
        assert_eq!(stats.error_devices, 0);
    }
}