use std::collections::HashMap;
use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::sync::{mpsc, RwLock};
use callback_server::{FirewallDetectionCoordinator, FirewallStatus};
use tracing::debug;
use crate::broker::PollingReason;
use crate::registry::{RegistrationId, SpeakerServicePair};
/// Per-registration bookkeeping kept by `EventDetector`.
struct MonitoredRegistration {
/// Instant of the most recent activity: set to `Instant::now()` when the
/// subscription is registered and again each time `record_event` is called.
last_event_time: Instant,
/// Speaker/service pair for this registration; cloned into any
/// `PollingRequest` emitted for it by the monitoring task.
pair: SpeakerServicePair,
/// Set to `true` once the monitoring task has successfully sent a
/// `PollingAction::Start` request for this registration. Entries with this
/// flag set are skipped by later timeout scans.
polling_activated: bool,
}
/// Tracks when each registration last produced an event and requests polling
/// for registrations that have been silent for longer than `event_timeout`.
pub struct EventDetector {
/// Monitored registrations keyed by id. Wrapped in `Arc<RwLock<..>>` so the
/// task spawned by `start_monitoring` can share the map with `&self` methods.
registrations: Arc<RwLock<HashMap<RegistrationId, MonitoredRegistration>>>,
/// Maximum time a registration may go without an event before it is
/// treated as timed out.
event_timeout: Duration,
/// Window consulted by `should_stop_polling`: an event newer than this
/// makes that method return `true`.
polling_activation_delay: Duration,
/// Optional source of per-device firewall status, used only by
/// `evaluate_firewall_status`.
firewall_coordinator: Option<Arc<FirewallDetectionCoordinator>>,
/// Optional channel on which the monitoring task emits `PollingRequest`s.
/// When `None`, timeouts are detected but nothing is sent.
polling_request_sender: Option<mpsc::UnboundedSender<PollingRequest>>,
}
/// A request to start or stop polling for one registration.
///
/// Instances are produced by `EventDetector::evaluate_firewall_status` and by
/// the monitoring task, which sends them over the polling request channel.
#[derive(Debug, Clone)]
pub struct PollingRequest {
/// Registration the request applies to.
pub registration_id: RegistrationId,
/// Speaker/service pair associated with the registration.
pub speaker_service_pair: SpeakerServicePair,
/// Whether polling should be started or stopped.
pub action: PollingAction,
/// Why the request was raised (e.g. event timeout, firewall blocked).
pub reason: PollingReason,
}
/// Action carried by a [`PollingRequest`].
///
/// The enum has no payload, so it is `Copy` and comparable with `==`;
/// callers no longer need `matches!` just to test the variant.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PollingAction {
    /// Begin polling for the registration.
    Start,
    /// Cease polling for the registration.
    Stop,
}
impl EventDetector {
pub fn new(event_timeout: Duration, polling_activation_delay: Duration) -> Self {
Self {
registrations: Arc::new(RwLock::new(HashMap::new())),
event_timeout,
polling_activation_delay,
firewall_coordinator: None,
polling_request_sender: None,
}
}
pub fn set_firewall_coordinator(&mut self, coordinator: Arc<FirewallDetectionCoordinator>) {
self.firewall_coordinator = Some(coordinator);
}
pub fn set_polling_request_sender(&mut self, sender: mpsc::UnboundedSender<PollingRequest>) {
self.polling_request_sender = Some(sender);
}
pub async fn record_event(&self, registration_id: RegistrationId) {
let mut registrations = self.registrations.write().await;
if let Some(reg) = registrations.get_mut(®istration_id) {
reg.last_event_time = Instant::now();
}
}
pub async fn should_start_polling(&self, registration_id: RegistrationId) -> bool {
let registrations = self.registrations.read().await;
registrations
.get(®istration_id)
.map(|reg| reg.last_event_time.elapsed() > self.event_timeout)
.unwrap_or(false)
}
pub async fn should_stop_polling(&self, registration_id: RegistrationId) -> bool {
let registrations = self.registrations.read().await;
registrations
.get(®istration_id)
.map(|reg| reg.last_event_time.elapsed() <= self.polling_activation_delay)
.unwrap_or(false)
}
pub async fn evaluate_firewall_status(
&self,
registration_id: RegistrationId,
pair: &SpeakerServicePair,
) -> Option<PollingRequest> {
if let Some(firewall_coordinator) = &self.firewall_coordinator {
let status = firewall_coordinator
.get_device_status(pair.speaker_ip)
.await;
match status {
FirewallStatus::Blocked => {
Some(PollingRequest {
registration_id,
speaker_service_pair: pair.clone(),
action: PollingAction::Start,
reason: PollingReason::FirewallBlocked,
})
}
FirewallStatus::Accessible => {
None
}
FirewallStatus::Unknown => {
None
}
FirewallStatus::Error => {
Some(PollingRequest {
registration_id,
speaker_service_pair: pair.clone(),
action: PollingAction::Start,
reason: PollingReason::NetworkIssues,
})
}
}
} else {
None
}
}
pub async fn start_monitoring(&self) -> tokio::task::JoinHandle<()> {
let registrations = Arc::clone(&self.registrations);
let event_timeout = self.event_timeout;
let polling_request_sender = self.polling_request_sender.clone();
let check_interval = (event_timeout / 3).max(Duration::from_secs(1));
tokio::spawn(async move {
let mut interval = tokio::time::interval(check_interval);
loop {
interval.tick().await;
let now = Instant::now();
let timed_out: Vec<(RegistrationId, SpeakerServicePair)> = {
let regs = registrations.read().await;
regs.iter()
.filter(|(_, reg)| {
!reg.polling_activated
&& now.duration_since(reg.last_event_time) > event_timeout
})
.map(|(id, reg)| (*id, reg.pair.clone()))
.collect()
};
for (registration_id, pair) in timed_out {
if let Some(sender) = &polling_request_sender {
let request = PollingRequest {
registration_id,
speaker_service_pair: pair,
action: PollingAction::Start,
reason: PollingReason::EventTimeout,
};
if sender.send(request).is_ok() {
let mut regs = registrations.write().await;
if let Some(reg) = regs.get_mut(®istration_id) {
reg.polling_activated = true;
}
debug!(
registration_id = %registration_id,
"Event timeout detected, sent polling request"
);
}
}
}
}
})
}
pub async fn register_subscription(
&self,
registration_id: RegistrationId,
pair: SpeakerServicePair,
) {
let mut registrations = self.registrations.write().await;
registrations.insert(
registration_id,
MonitoredRegistration {
last_event_time: Instant::now(),
pair,
polling_activated: false,
},
);
}
pub async fn unregister_subscription(&self, registration_id: RegistrationId) {
let mut registrations = self.registrations.write().await;
registrations.remove(®istration_id);
}
pub async fn stats(&self) -> EventDetectorStats {
let registrations = self.registrations.read().await;
let total_monitored = registrations.len();
let now = Instant::now();
let mut timeout_count = 0;
let mut recent_events_count = 0;
for reg in registrations.values() {
let elapsed = now.duration_since(reg.last_event_time);
if elapsed > self.event_timeout {
timeout_count += 1;
} else if elapsed <= Duration::from_secs(60) {
recent_events_count += 1;
}
}
let firewall_status = FirewallStatus::Unknown;
EventDetectorStats {
total_monitored,
timeout_count,
recent_events_count,
firewall_status,
event_timeout: self.event_timeout,
}
}
}
/// Point-in-time summary produced by `EventDetector::stats`.
#[derive(Debug)]
pub struct EventDetectorStats {
/// Number of registrations currently being monitored.
pub total_monitored: usize,
/// Registrations whose last event is older than `event_timeout`.
pub timeout_count: usize,
/// Registrations that are not timed out and whose last event is at most
/// 60 seconds old.
pub recent_events_count: usize,
/// Firewall status reported with the snapshot. `EventDetector::stats`
/// currently always fills this with `FirewallStatus::Unknown`.
pub firewall_status: FirewallStatus,
/// The detector's configured event timeout at the time of the snapshot.
pub event_timeout: Duration,
}
impl std::fmt::Display for EventDetectorStats {
    /// Writes a multi-line, human-readable report: a header line followed by
    /// one line per field, each terminated by a newline.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Destructure so that adding a field to the struct without deciding
        // how to display it becomes a compile error.
        let Self {
            total_monitored,
            timeout_count,
            recent_events_count,
            firewall_status,
            event_timeout,
        } = self;

        writeln!(f, "Event Detector Stats:")?;
        writeln!(f, " Total monitored: {}", total_monitored)?;
        writeln!(f, " Timeout count: {}", timeout_count)?;
        writeln!(f, " Recent events: {}", recent_events_count)?;
        writeln!(f, " Firewall status: {:?}", firewall_status)?;
        writeln!(f, " Event timeout: {:?}", event_timeout)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The constructor stores both durations unchanged.
    #[tokio::test]
    async fn test_event_detector_creation() {
        let detector = EventDetector::new(Duration::from_secs(30), Duration::from_secs(5));
        assert_eq!(detector.event_timeout, Duration::from_secs(30));
        assert_eq!(detector.polling_activation_delay, Duration::from_secs(5));
    }

    /// Unknown registrations never ask for polling, and a registration with a
    /// fresh event is not timed out.
    #[tokio::test]
    async fn test_event_recording() {
        let detector = EventDetector::new(Duration::from_secs(30), Duration::from_secs(5));
        let registration_id = RegistrationId::new(1);
        let pair = SpeakerServicePair::new(
            "192.168.1.100".parse().unwrap(),
            sonos_api::Service::AVTransport,
        );
        assert!(!detector.should_start_polling(registration_id).await);
        detector.register_subscription(registration_id, pair).await;
        detector.record_event(registration_id).await;
        assert!(!detector.should_start_polling(registration_id).await);
    }

    /// `stats().total_monitored` follows register / unregister.
    #[tokio::test]
    async fn test_subscription_registration() {
        let detector = EventDetector::new(Duration::from_secs(30), Duration::from_secs(5));
        let registration_id = RegistrationId::new(1);
        let pair = SpeakerServicePair::new(
            "192.168.1.100".parse().unwrap(),
            sonos_api::Service::AVTransport,
        );
        detector.register_subscription(registration_id, pair).await;
        let stats = detector.stats().await;
        assert_eq!(stats.total_monitored, 1);
        detector.unregister_subscription(registration_id).await;
        let stats = detector.stats().await;
        assert_eq!(stats.total_monitored, 0);
    }

    /// The internal map holds the registered pair and drops it on unregister.
    #[tokio::test]
    async fn test_register_and_unregister() {
        let detector = EventDetector::new(Duration::from_secs(30), Duration::from_secs(5));
        let registration_id = RegistrationId::new(1);
        let pair = SpeakerServicePair::new(
            "192.168.1.100".parse().unwrap(),
            sonos_api::Service::AVTransport,
        );
        detector
            .register_subscription(registration_id, pair.clone())
            .await;
        let regs = detector.registrations.read().await;
        assert!(regs.contains_key(&registration_id));
        assert_eq!(regs[&registration_id].pair.speaker_ip, pair.speaker_ip);
        // Release the read guard before `unregister_subscription` takes the
        // write lock, otherwise the test would deadlock.
        drop(regs);
        detector.unregister_subscription(registration_id).await;
        let regs = detector.registrations.read().await;
        assert!(!regs.contains_key(&registration_id));
    }

    /// A registration whose last event is far in the past triggers a
    /// Start / EventTimeout request from the monitoring task.
    #[tokio::test]
    async fn test_event_timeout_sends_polling_request() {
        use tokio::sync::mpsc;
        let mut detector = EventDetector::new(Duration::from_millis(50), Duration::from_secs(5));
        let (sender, mut receiver) = mpsc::unbounded_channel();
        detector.set_polling_request_sender(sender);
        let detector = Arc::new(detector);
        let registration_id = RegistrationId::new(42);
        let pair = SpeakerServicePair::new(
            "192.168.1.100".parse().unwrap(),
            sonos_api::Service::RenderingControl,
        );
        detector
            .register_subscription(registration_id, pair.clone())
            .await;
        // Backdate the last event so the very first monitor tick sees a
        // timeout instead of the test having to sleep.
        {
            let mut regs = detector.registrations.write().await;
            if let Some(reg) = regs.get_mut(&registration_id) {
                reg.last_event_time = Instant::now() - Duration::from_secs(60);
            }
        }
        let monitor = detector.start_monitoring().await;
        let request = tokio::time::timeout(Duration::from_secs(2), receiver.recv()).await;
        assert!(
            request.is_ok(),
            "Should receive a polling request within timeout"
        );
        let request = request.unwrap().expect("Channel should have a message");
        assert_eq!(request.registration_id, registration_id);
        assert_eq!(request.speaker_service_pair.speaker_ip, pair.speaker_ip);
        assert!(matches!(request.action, PollingAction::Start));
        assert_eq!(request.reason, PollingReason::EventTimeout);
        // The monitor loops forever; stop it explicitly rather than leaving
        // it detached.
        monitor.abort();
    }
}