Skip to main content

sonos_stream/subscription/
manager.rs

//! Subscription lifecycle management with SonosClient integration.
//!
//! This module manages UPnP event subscriptions by wrapping SonosClient's
//! `ManagedSubscription` and tracking them so the callback server can route events to them.
5
6use std::collections::HashMap;
7use std::sync::atomic::{AtomicBool, Ordering};
8use std::sync::Arc;
9use std::time::SystemTime;
10use tokio::sync::{Mutex, RwLock};
11
12use callback_server::firewall_detection::FirewallStatus;
13use sonos_api::{ManagedSubscription, Service, SonosClient};
14
15use crate::error::{SubscriptionError, SubscriptionResult};
16use crate::registry::{RegistrationId, SpeakerServicePair};
17
18/// Wrapper around ManagedSubscription with additional context for event streaming
19#[derive(Debug)]
20pub struct ManagedSubscriptionWrapper {
21    /// The actual SonosClient subscription
22    subscription: ManagedSubscription,
23
24    /// Registration ID this subscription belongs to
25    registration_id: RegistrationId,
26
27    /// Speaker/service pair for this subscription
28    speaker_service_pair: SpeakerServicePair,
29
30    /// Timestamp of the last event received for this subscription
31    last_event_time: Arc<Mutex<Option<SystemTime>>>,
32
33    /// Whether polling is currently active for this subscription
34    is_polling_active: Arc<AtomicBool>,
35
36    /// Creation timestamp
37    created_at: SystemTime,
38
39    /// Number of renewal attempts
40    renewal_count: Arc<Mutex<u32>>,
41}
42
43impl ManagedSubscriptionWrapper {
44    /// Create a new wrapper around a ManagedSubscription
45    pub fn new(
46        subscription: ManagedSubscription,
47        registration_id: RegistrationId,
48        speaker_service_pair: SpeakerServicePair,
49    ) -> Self {
50        Self {
51            subscription,
52            registration_id,
53            speaker_service_pair,
54            last_event_time: Arc::new(Mutex::new(None)),
55            is_polling_active: Arc::new(AtomicBool::new(false)),
56            created_at: SystemTime::now(),
57            renewal_count: Arc::new(Mutex::new(0)),
58        }
59    }
60
61    /// Get the registration ID
62    pub fn registration_id(&self) -> RegistrationId {
63        self.registration_id
64    }
65
66    /// Get the speaker/service pair
67    pub fn speaker_service_pair(&self) -> &SpeakerServicePair {
68        &self.speaker_service_pair
69    }
70
71    /// Get the UPnP subscription ID
72    pub fn subscription_id(&self) -> &str {
73        self.subscription.subscription_id()
74    }
75
76    /// Check if the subscription is active
77    pub fn is_active(&self) -> bool {
78        self.subscription.is_active()
79    }
80
81    /// Check if the subscription needs renewal
82    pub fn needs_renewal(&self) -> bool {
83        self.subscription.needs_renewal()
84    }
85
86    /// Renew the subscription
87    pub async fn renew(&self) -> SubscriptionResult<()> {
88        self.subscription
89            .renew()
90            .map_err(|e| SubscriptionError::RenewalFailed(e.to_string()))?;
91
92        // Increment renewal count
93        let mut count = self.renewal_count.lock().await;
94        *count += 1;
95
96        Ok(())
97    }
98
99    /// Unsubscribe and clean up
100    pub async fn unsubscribe(&self) -> SubscriptionResult<()> {
101        self.subscription
102            .unsubscribe()
103            .map_err(|e| SubscriptionError::NetworkError(e.to_string()))?;
104        Ok(())
105    }
106
107    /// Record that an event was received for this subscription
108    pub async fn record_event_received(&self) {
109        let mut last_event_time = self.last_event_time.lock().await;
110        *last_event_time = Some(SystemTime::now());
111    }
112
113    /// Get the time of the last event received
114    pub async fn last_event_time(&self) -> Option<SystemTime> {
115        let last_event_time = self.last_event_time.lock().await;
116        *last_event_time
117    }
118
119    /// Set whether polling is active for this subscription
120    pub fn set_polling_active(&self, active: bool) {
121        self.is_polling_active.store(active, Ordering::Relaxed);
122    }
123
124    /// Check if polling is active for this subscription
125    pub fn is_polling_active(&self) -> bool {
126        self.is_polling_active.load(Ordering::Relaxed)
127    }
128
129    /// Get creation timestamp
130    pub fn created_at(&self) -> SystemTime {
131        self.created_at
132    }
133
134    /// Get renewal count
135    pub async fn renewal_count(&self) -> u32 {
136        let count = self.renewal_count.lock().await;
137        *count
138    }
139}
140
141/// Manages subscriptions for registered speaker/service pairs
142pub struct SubscriptionManager {
143    /// SonosClient for creating and managing subscriptions
144    sonos_client: SonosClient,
145
146    /// Callback URL for UPnP event notifications
147    callback_url: String,
148
149    /// Active subscriptions indexed by registration ID
150    active_subscriptions: Arc<RwLock<HashMap<RegistrationId, Arc<ManagedSubscriptionWrapper>>>>,
151
152    /// Current firewall status (shared with other components)
153    firewall_status: Arc<RwLock<FirewallStatus>>,
154}
155
156impl SubscriptionManager {
157    /// Create a new SubscriptionManager
158    pub fn new(callback_url: String) -> Self {
159        Self {
160            sonos_client: SonosClient::new(),
161            callback_url,
162            active_subscriptions: Arc::new(RwLock::new(HashMap::new())),
163            firewall_status: Arc::new(RwLock::new(FirewallStatus::Unknown)),
164        }
165    }
166
167    /// Set the firewall status (called by firewall detection system)
168    pub async fn set_firewall_status(&self, status: FirewallStatus) {
169        let mut current_status = self.firewall_status.write().await;
170        *current_status = status;
171    }
172
173    /// Get the current firewall status
174    pub async fn firewall_status(&self) -> FirewallStatus {
175        let status = self.firewall_status.read().await;
176        *status
177    }
178
179    /// Create a subscription for a speaker/service pair
180    pub async fn create_subscription(
181        &self,
182        registration_id: RegistrationId,
183        pair: SpeakerServicePair,
184    ) -> SubscriptionResult<Arc<ManagedSubscriptionWrapper>> {
185        // Convert Service to the format expected by SonosClient (no conversion needed since we're using the same enum)
186        let service = pair.service;
187
188        // Create the subscription using SonosClient
189        let subscription = self
190            .sonos_client
191            .subscribe(&pair.speaker_ip.to_string(), service, &self.callback_url)
192            .map_err(|e| SubscriptionError::CreationFailed(e.to_string()))?;
193
194        // Wrap it with our additional context
195        let wrapper = Arc::new(ManagedSubscriptionWrapper::new(
196            subscription,
197            registration_id,
198            pair,
199        ));
200
201        // Store in our active subscriptions
202        let mut subscriptions = self.active_subscriptions.write().await;
203        subscriptions.insert(registration_id, Arc::clone(&wrapper));
204
205        Ok(wrapper)
206    }
207
208    /// Remove a subscription
209    pub async fn remove_subscription(
210        &self,
211        registration_id: RegistrationId,
212    ) -> SubscriptionResult<()> {
213        let mut subscriptions = self.active_subscriptions.write().await;
214
215        if let Some(wrapper) = subscriptions.remove(&registration_id) {
216            // Unsubscribe from the UPnP service
217            wrapper.unsubscribe().await?;
218        } else {
219            return Err(SubscriptionError::InvalidState);
220        }
221
222        Ok(())
223    }
224
225    /// Get a subscription by registration ID
226    pub async fn get_subscription(
227        &self,
228        registration_id: RegistrationId,
229    ) -> Option<Arc<ManagedSubscriptionWrapper>> {
230        let subscriptions = self.active_subscriptions.read().await;
231        subscriptions.get(&registration_id).cloned()
232    }
233
234    /// Get subscription by UPnP subscription ID (for event routing)
235    pub async fn get_subscription_by_sid(
236        &self,
237        subscription_id: &str,
238    ) -> Option<Arc<ManagedSubscriptionWrapper>> {
239        let subscriptions = self.active_subscriptions.read().await;
240        subscriptions
241            .values()
242            .find(|wrapper| wrapper.subscription_id() == subscription_id)
243            .cloned()
244    }
245
246    /// List all active subscriptions
247    pub async fn list_subscriptions(&self) -> Vec<Arc<ManagedSubscriptionWrapper>> {
248        let subscriptions = self.active_subscriptions.read().await;
249        subscriptions.values().cloned().collect()
250    }
251
252    /// Check for subscriptions that need renewal and renew them
253    pub async fn check_renewals(&self) -> SubscriptionResult<usize> {
254        let subscriptions = self.active_subscriptions.read().await;
255        let mut renewed_count = 0;
256
257        for wrapper in subscriptions.values() {
258            if wrapper.needs_renewal() {
259                match wrapper.renew().await {
260                    Ok(()) => {
261                        renewed_count += 1;
262                        eprintln!(
263                            "✅ Renewed subscription for {} {:?}",
264                            wrapper.speaker_service_pair.speaker_ip,
265                            wrapper.speaker_service_pair.service
266                        );
267                    }
268                    Err(e) => {
269                        eprintln!(
270                            "❌ Failed to renew subscription for {} {:?}: {}",
271                            wrapper.speaker_service_pair.speaker_ip,
272                            wrapper.speaker_service_pair.service,
273                            e
274                        );
275                        // Note: We continue processing other subscriptions even if one fails
276                    }
277                }
278            }
279        }
280
281        Ok(renewed_count)
282    }
283
284    /// Record that an event was received for a subscription
285    pub async fn record_event_received(&self, subscription_id: &str) {
286        if let Some(wrapper) = self.get_subscription_by_sid(subscription_id).await {
287            wrapper.record_event_received().await;
288        }
289    }
290
291    /// Get statistics about managed subscriptions
292    pub async fn stats(&self) -> SubscriptionStats {
293        let subscriptions = self.active_subscriptions.read().await;
294        let total_count = subscriptions.len();
295        let firewall_status = *self.firewall_status.read().await;
296
297        let mut service_counts = HashMap::new();
298        let mut polling_count = 0;
299        let mut renewal_count = 0;
300
301        for wrapper in subscriptions.values() {
302            *service_counts
303                .entry(wrapper.speaker_service_pair.service)
304                .or_insert(0) += 1;
305
306            if wrapper.is_polling_active() {
307                polling_count += 1;
308            }
309
310            renewal_count += wrapper.renewal_count().await;
311        }
312
313        SubscriptionStats {
314            total_subscriptions: total_count,
315            service_breakdown: service_counts,
316            polling_active_count: polling_count,
317            total_renewals: renewal_count,
318            firewall_status,
319        }
320    }
321
322    /// Shutdown all subscriptions
323    pub async fn shutdown(&self) -> SubscriptionResult<()> {
324        let mut subscriptions = self.active_subscriptions.write().await;
325
326        for (registration_id, wrapper) in subscriptions.drain() {
327            match wrapper.unsubscribe().await {
328                Ok(()) => {
329                    eprintln!("✅ Unsubscribed {registration_id}");
330                }
331                Err(e) => {
332                    eprintln!("❌ Failed to unsubscribe {registration_id}: {e}");
333                }
334            }
335        }
336
337        Ok(())
338    }
339}
340
341/// Statistics about subscription manager state
342#[derive(Debug)]
343pub struct SubscriptionStats {
344    pub total_subscriptions: usize,
345    pub service_breakdown: HashMap<Service, usize>,
346    pub polling_active_count: usize,
347    pub total_renewals: u32,
348    pub firewall_status: FirewallStatus,
349}
350
351impl std::fmt::Display for SubscriptionStats {
352    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
353        writeln!(f, "Subscription Manager Stats:")?;
354        writeln!(f, "  Total subscriptions: {}", self.total_subscriptions)?;
355        writeln!(f, "  Firewall status: {:?}", self.firewall_status)?;
356        writeln!(f, "  Polling active: {}", self.polling_active_count)?;
357        writeln!(f, "  Total renewals: {}", self.total_renewals)?;
358        writeln!(f, "  Service breakdown:")?;
359        for (service, count) in &self.service_breakdown {
360            writeln!(f, "    {service:?}: {count}")?;
361        }
362        Ok(())
363    }
364}
365
#[cfg(test)]
mod tests {
    use super::*;

    /// Callback URL shared by the manager tests; no server is contacted.
    const CALLBACK_URL: &str = "http://192.168.1.50:3400/callback";

    #[test]
    fn test_subscription_wrapper_creation() {
        // ManagedSubscription cannot be built without a real device, so only the
        // network-free pieces a wrapper is built from are checked here.
        let _reg_id = RegistrationId::new(1);
        let pair = SpeakerServicePair::new("192.168.1.100".parse().unwrap(), Service::AVTransport);

        assert_eq!(pair.speaker_ip.to_string(), "192.168.1.100");
        assert_eq!(pair.service, Service::AVTransport);
    }

    #[tokio::test]
    async fn test_subscription_manager_creation() {
        let manager = SubscriptionManager::new(CALLBACK_URL.to_string());

        // Initial state.
        assert_eq!(manager.firewall_status().await, FirewallStatus::Unknown);
        assert_eq!(manager.list_subscriptions().await.len(), 0);

        // Firewall status updates are visible to later reads.
        manager
            .set_firewall_status(FirewallStatus::Accessible)
            .await;
        assert_eq!(manager.firewall_status().await, FirewallStatus::Accessible);
    }

    #[tokio::test]
    async fn test_subscription_stats() {
        let manager = SubscriptionManager::new(CALLBACK_URL.to_string());

        let stats = manager.stats().await;
        assert_eq!(stats.total_subscriptions, 0);
        assert_eq!(stats.polling_active_count, 0);
        assert_eq!(stats.total_renewals, 0);
        assert!(stats.service_breakdown.is_empty());
        assert_eq!(stats.firewall_status, FirewallStatus::Unknown);
    }

    #[tokio::test]
    async fn test_empty_manager_operations() {
        let manager = SubscriptionManager::new(CALLBACK_URL.to_string());
        let reg_id = RegistrationId::new(1);

        // Lookups on an empty manager find nothing.
        assert!(manager.get_subscription(reg_id).await.is_none());
        assert!(manager
            .get_subscription_by_sid("uuid:does-not-exist")
            .await
            .is_none());

        // Recording an event for an unknown SID is a no-op.
        manager.record_event_received("uuid:does-not-exist").await;

        // Nothing is due for renewal.
        assert!(matches!(manager.check_renewals().await, Ok(0)));

        // Removing an unknown registration is reported as InvalidState.
        assert!(matches!(
            manager.remove_subscription(reg_id).await,
            Err(SubscriptionError::InvalidState)
        ));

        // Shutting down an empty manager succeeds and leaves it empty.
        assert!(manager.shutdown().await.is_ok());
        assert_eq!(manager.list_subscriptions().await.len(), 0);
    }
}