rtmp_rs/registry/
store.rs

1//! Stream registry implementation
2//!
3//! The central registry that manages all active streams and routes media
4//! from publishers to subscribers.
5
6use std::collections::HashMap;
7use std::sync::atomic::Ordering;
8use std::sync::Arc;
9use std::time::Instant;
10
11use tokio::sync::{broadcast, RwLock};
12
13use super::config::RegistryConfig;
14use super::entry::{StreamEntry, StreamState, StreamStats};
15use super::error::RegistryError;
16use super::frame::{BroadcastFrame, StreamKey};
17
18/// Central registry for all active streams
19///
20/// Thread-safe via `RwLock`. Read-heavy workloads (subscriber count checks,
21/// broadcasting) benefit from the concurrent read access.
22pub struct StreamRegistry {
23    /// Map of stream key to stream entry
24    streams: RwLock<HashMap<StreamKey, Arc<RwLock<StreamEntry>>>>,
25
26    /// Configuration
27    config: RegistryConfig,
28}
29
30impl StreamRegistry {
31    /// Create a new stream registry with default configuration
32    pub fn new() -> Self {
33        Self::with_config(RegistryConfig::default())
34    }
35
36    /// Create a new stream registry with custom configuration
37    pub fn with_config(config: RegistryConfig) -> Self {
38        Self {
39            streams: RwLock::new(HashMap::new()),
40            config,
41        }
42    }
43
44    /// Get the registry configuration
45    pub fn config(&self) -> &RegistryConfig {
46        &self.config
47    }
48
49    /// Register a publisher for a stream
50    ///
51    /// If the stream doesn't exist, it will be created.
52    /// If the stream exists and is in grace period, the publisher reclaims it.
53    /// Returns an error if the stream already has an active publisher.
54    pub async fn register_publisher(
55        &self,
56        key: &StreamKey,
57        session_id: u64,
58    ) -> Result<(), RegistryError> {
59        let mut streams = self.streams.write().await;
60
61        if let Some(entry_arc) = streams.get(key) {
62            let mut entry = entry_arc.write().await;
63
64            // Check if stream is available for publishing
65            match entry.state {
66                StreamState::Active if entry.publisher_id.is_some() => {
67                    return Err(RegistryError::StreamAlreadyPublishing(key.clone()));
68                }
69                StreamState::GracePeriod | StreamState::Idle | StreamState::Active => {
70                    // Reclaim or take over the stream
71                    entry.publisher_id = Some(session_id);
72                    entry.publisher_disconnected_at = None;
73                    entry.state = StreamState::Active;
74
75                    tracing::info!(
76                        stream = %key,
77                        session_id = session_id,
78                        subscribers = entry.subscriber_count(),
79                        "Publisher registered (existing stream)"
80                    );
81                }
82            }
83        } else {
84            // Create new stream entry
85            let mut entry = StreamEntry::new(&self.config);
86            entry.publisher_id = Some(session_id);
87            entry.state = StreamState::Active;
88
89            streams.insert(key.clone(), Arc::new(RwLock::new(entry)));
90
91            tracing::info!(
92                stream = %key,
93                session_id = session_id,
94                "Publisher registered (new stream)"
95            );
96        }
97
98        Ok(())
99    }
100
101    /// Unregister a publisher from a stream
102    ///
103    /// The stream enters grace period if there are active subscribers,
104    /// allowing the publisher to reconnect.
105    pub async fn unregister_publisher(&self, key: &StreamKey, session_id: u64) {
106        let streams = self.streams.read().await;
107
108        if let Some(entry_arc) = streams.get(key) {
109            let mut entry = entry_arc.write().await;
110
111            // Verify this is the actual publisher
112            if entry.publisher_id != Some(session_id) {
113                tracing::warn!(
114                    stream = %key,
115                    expected = ?entry.publisher_id,
116                    actual = session_id,
117                    "Publisher unregister mismatch"
118                );
119                return;
120            }
121
122            entry.publisher_id = None;
123            entry.publisher_disconnected_at = Some(Instant::now());
124
125            // If there are subscribers, enter grace period; otherwise go idle
126            if entry.subscriber_count() > 0 {
127                entry.state = StreamState::GracePeriod;
128                tracing::info!(
129                    stream = %key,
130                    session_id = session_id,
131                    subscribers = entry.subscriber_count(),
132                    grace_period_secs = self.config.publisher_grace_period.as_secs(),
133                    "Publisher disconnected, entering grace period"
134                );
135            } else {
136                entry.state = StreamState::Idle;
137                tracing::info!(
138                    stream = %key,
139                    session_id = session_id,
140                    "Publisher disconnected, no subscribers"
141                );
142            }
143        }
144    }
145
146    /// Subscribe to a stream
147    ///
148    /// Returns a broadcast receiver and catchup frames for the subscriber.
149    /// The catchup frames contain sequence headers and recent GOP data.
150    pub async fn subscribe(
151        &self,
152        key: &StreamKey,
153    ) -> Result<(broadcast::Receiver<BroadcastFrame>, Vec<BroadcastFrame>), RegistryError> {
154        let streams = self.streams.read().await;
155
156        let entry_arc = streams
157            .get(key)
158            .ok_or_else(|| RegistryError::StreamNotFound(key.clone()))?;
159
160        let entry = entry_arc.read().await;
161
162        // Allow subscription even during grace period (publisher might reconnect)
163        if entry.state == StreamState::Idle && entry.publisher_id.is_none() {
164            return Err(RegistryError::StreamNotActive(key.clone()));
165        }
166
167        // Get receiver and catchup frames
168        let rx = entry.subscribe();
169        let catchup = entry.get_catchup_frames();
170
171        // Increment subscriber count
172        entry.subscriber_count.fetch_add(1, Ordering::Relaxed);
173
174        tracing::info!(
175            stream = %key,
176            subscribers = entry.subscriber_count(),
177            catchup_frames = catchup.len(),
178            "Subscriber added"
179        );
180
181        Ok((rx, catchup))
182    }
183
184    /// Unsubscribe from a stream
185    pub async fn unsubscribe(&self, key: &StreamKey) {
186        let streams = self.streams.read().await;
187
188        if let Some(entry_arc) = streams.get(key) {
189            let entry = entry_arc.read().await;
190            let prev = entry.subscriber_count.fetch_sub(1, Ordering::Relaxed);
191
192            tracing::debug!(
193                stream = %key,
194                subscribers = prev.saturating_sub(1),
195                "Subscriber removed"
196            );
197        }
198    }
199
200    /// Broadcast a frame to all subscribers of a stream
201    ///
202    /// Also updates the GOP buffer and sequence headers as needed.
203    pub async fn broadcast(&self, key: &StreamKey, frame: BroadcastFrame) {
204        let streams = self.streams.read().await;
205
206        if let Some(entry_arc) = streams.get(key) {
207            let mut entry = entry_arc.write().await;
208
209            // Update cached headers and GOP buffer
210            entry.update_caches(&frame);
211
212            // Broadcast to subscribers
213            // Note: send() returns Ok(n) where n is number of receivers, or Err if no receivers
214            let _ = entry.send(frame);
215        }
216    }
217
218    /// Get sequence headers for a stream (video and audio decoder config)
219    ///
220    /// Used when resuming playback after pause to reinitialize decoders.
221    pub async fn get_sequence_headers(&self, key: &StreamKey) -> Vec<BroadcastFrame> {
222        let streams = self.streams.read().await;
223
224        if let Some(entry_arc) = streams.get(key) {
225            let entry = entry_arc.read().await;
226            let mut frames = Vec::with_capacity(2);
227
228            if let Some(ref video) = entry.video_header {
229                frames.push(video.clone());
230            }
231            if let Some(ref audio) = entry.audio_header {
232                frames.push(audio.clone());
233            }
234
235            frames
236        } else {
237            Vec::new()
238        }
239    }
240
241    /// Check if a stream exists and has an active publisher
242    pub async fn has_active_stream(&self, key: &StreamKey) -> bool {
243        let streams = self.streams.read().await;
244
245        if let Some(entry_arc) = streams.get(key) {
246            let entry = entry_arc.read().await;
247            entry.state == StreamState::Active && entry.publisher_id.is_some()
248        } else {
249            false
250        }
251    }
252
253    /// Check if a stream exists (active or in grace period)
254    pub async fn stream_exists(&self, key: &StreamKey) -> bool {
255        let streams = self.streams.read().await;
256
257        if let Some(entry_arc) = streams.get(key) {
258            let entry = entry_arc.read().await;
259            matches!(entry.state, StreamState::Active | StreamState::GracePeriod)
260        } else {
261            false
262        }
263    }
264
265    /// Get stream statistics
266    pub async fn get_stream_stats(&self, key: &StreamKey) -> Option<StreamStats> {
267        let streams = self.streams.read().await;
268
269        if let Some(entry_arc) = streams.get(key) {
270            let entry = entry_arc.read().await;
271            Some(StreamStats {
272                subscriber_count: entry.subscriber_count(),
273                has_publisher: entry.publisher_id.is_some(),
274                state: entry.state,
275                gop_frame_count: entry.gop_buffer.frame_count(),
276                gop_size_bytes: entry.gop_buffer.size(),
277            })
278        } else {
279            None
280        }
281    }
282
283    /// Get total number of streams
284    pub async fn stream_count(&self) -> usize {
285        self.streams.read().await.len()
286    }
287
288    /// Run cleanup task once
289    ///
290    /// Removes streams that have:
291    /// - Been in grace period longer than `publisher_grace_period`
292    /// - Been idle longer than `idle_stream_timeout`
293    pub async fn cleanup(&self) {
294        let mut streams = self.streams.write().await;
295        let now = Instant::now();
296
297        let keys_to_remove: Vec<StreamKey> = streams
298            .iter()
299            .filter_map(|(key, entry_arc)| {
300                // Try to get read lock without blocking
301                if let Ok(entry) = entry_arc.try_read() {
302                    let should_remove = match entry.state {
303                        StreamState::GracePeriod => {
304                            if let Some(disconnected_at) = entry.publisher_disconnected_at {
305                                now.duration_since(disconnected_at)
306                                    > self.config.publisher_grace_period
307                            } else {
308                                false
309                            }
310                        }
311                        StreamState::Idle => {
312                            if let Some(disconnected_at) = entry.publisher_disconnected_at {
313                                now.duration_since(disconnected_at)
314                                    > self.config.idle_stream_timeout
315                            } else {
316                                now.duration_since(entry.created_at)
317                                    > self.config.idle_stream_timeout
318                            }
319                        }
320                        StreamState::Active => false,
321                    };
322
323                    if should_remove {
324                        Some(key.clone())
325                    } else {
326                        None
327                    }
328                } else {
329                    None
330                }
331            })
332            .collect();
333
334        for key in keys_to_remove {
335            streams.remove(&key);
336            tracing::info!(stream = %key, "Stream removed by cleanup");
337        }
338    }
339
340    /// Spawn background cleanup task
341    ///
342    /// Returns a handle that can be used to abort the task.
343    pub fn spawn_cleanup_task(self: &Arc<Self>) -> tokio::task::JoinHandle<()> {
344        let registry = Arc::clone(self);
345        let interval = registry.config.cleanup_interval;
346
347        tokio::spawn(async move {
348            let mut ticker = tokio::time::interval(interval);
349            loop {
350                ticker.tick().await;
351                registry.cleanup().await;
352            }
353        })
354    }
355}
356
357impl Default for StreamRegistry {
358    fn default() -> Self {
359        Self::new()
360    }
361}
362
#[cfg(test)]
mod tests {
    use bytes::Bytes;

    use super::*;

    #[tokio::test]
    async fn test_register_publisher() {
        let registry = StreamRegistry::new();
        let key = StreamKey::new("live", "test_stream");

        // Register publisher
        registry.register_publisher(&key, 1).await.unwrap();
        assert!(registry.has_active_stream(&key).await);

        // Can't register another publisher while the first is active
        let result = registry.register_publisher(&key, 2).await;
        assert!(matches!(
            result,
            Err(RegistryError::StreamAlreadyPublishing(_))
        ));
    }

    #[tokio::test]
    async fn test_unregister_wrong_session_is_ignored() {
        let registry = StreamRegistry::new();
        let key = StreamKey::new("live", "test_stream");

        registry.register_publisher(&key, 1).await.unwrap();

        // A session that does not own the stream cannot release it.
        registry.unregister_publisher(&key, 2).await;
        assert!(registry.has_active_stream(&key).await);
    }

    #[tokio::test]
    async fn test_subscribe_unknown_stream() {
        let registry = StreamRegistry::new();
        let key = StreamKey::new("live", "missing");

        assert!(matches!(
            registry.subscribe(&key).await,
            Err(RegistryError::StreamNotFound(_))
        ));
    }

    #[tokio::test]
    async fn test_subscribe_unsubscribe() {
        let registry = StreamRegistry::new();
        let key = StreamKey::new("live", "test_stream");

        // Need a publisher first
        registry.register_publisher(&key, 1).await.unwrap();

        // Subscribe
        let (mut rx, catchup) = registry.subscribe(&key).await.unwrap();
        assert!(catchup.is_empty()); // No data yet

        // Broadcast a frame
        let frame = BroadcastFrame::video(0, Bytes::from_static(&[0x17, 0x01]), true, false);
        registry.broadcast(&key, frame).await;

        // Receive the frame
        let received = rx.recv().await.unwrap();
        assert_eq!(received.timestamp, 0);
        assert!(received.is_keyframe);

        // Unsubscribe
        registry.unsubscribe(&key).await;

        let stats = registry.get_stream_stats(&key).await.unwrap();
        assert_eq!(stats.subscriber_count, 0);
    }

    #[tokio::test]
    async fn test_idle_stream_rejects_subscribers() {
        let registry = StreamRegistry::new();
        let key = StreamKey::new("live", "test_stream");

        // Publisher leaves with nobody watching: the stream goes idle, not into grace.
        registry.register_publisher(&key, 1).await.unwrap();
        registry.unregister_publisher(&key, 1).await;

        let stats = registry.get_stream_stats(&key).await.unwrap();
        assert_eq!(stats.state, StreamState::Idle);
        assert!(!registry.stream_exists(&key).await);
        assert!(matches!(
            registry.subscribe(&key).await,
            Err(RegistryError::StreamNotActive(_))
        ));
    }

    #[tokio::test]
    async fn test_grace_period() {
        let config =
            RegistryConfig::default().publisher_grace_period(std::time::Duration::from_millis(100));
        let registry = StreamRegistry::with_config(config);
        let key = StreamKey::new("live", "test_stream");

        // Register publisher and subscriber
        registry.register_publisher(&key, 1).await.unwrap();
        let (_rx, _) = registry.subscribe(&key).await.unwrap();

        // Publisher disconnects
        registry.unregister_publisher(&key, 1).await;

        // Stream should be in grace period
        let stats = registry.get_stream_stats(&key).await.unwrap();
        assert_eq!(stats.state, StreamState::GracePeriod);

        // Stream still exists
        assert!(registry.stream_exists(&key).await);

        // New subscriber can still join
        let result = registry.subscribe(&key).await;
        assert!(result.is_ok());

        // Once the grace period elapses without a reconnect, cleanup removes it.
        tokio::time::sleep(std::time::Duration::from_millis(150)).await;
        registry.cleanup().await;
        assert!(!registry.stream_exists(&key).await);
        assert!(registry.get_stream_stats(&key).await.is_none());
    }

    #[tokio::test]
    async fn test_publisher_reconnect() {
        let registry = StreamRegistry::new();
        let key = StreamKey::new("live", "test_stream");

        // Register publisher
        registry.register_publisher(&key, 1).await.unwrap();

        // Add subscriber
        let (_rx, _) = registry.subscribe(&key).await.unwrap();

        // Publisher disconnects
        registry.unregister_publisher(&key, 1).await;

        // New publisher takes over
        registry.register_publisher(&key, 2).await.unwrap();

        let stats = registry.get_stream_stats(&key).await.unwrap();
        assert!(stats.has_publisher);
        assert_eq!(stats.state, StreamState::Active);
        assert_eq!(stats.subscriber_count, 1); // Subscriber still there
    }

    #[tokio::test]
    async fn test_catchup_frames() {
        let registry = StreamRegistry::new();
        let key = StreamKey::new("live", "test_stream");

        registry.register_publisher(&key, 1).await.unwrap();

        // Broadcast sequence headers
        let video_header = BroadcastFrame::video(0, Bytes::from_static(&[0x17, 0x00]), true, true);
        let audio_header = BroadcastFrame::audio(0, Bytes::from_static(&[0xAF, 0x00]), true);
        registry.broadcast(&key, video_header).await;
        registry.broadcast(&key, audio_header).await;

        // Broadcast a keyframe
        let keyframe = BroadcastFrame::video(33, Bytes::from_static(&[0x17, 0x01]), true, false);
        registry.broadcast(&key, keyframe).await;

        // Late joiner subscribes
        let (_rx, catchup) = registry.subscribe(&key).await.unwrap();

        // Should receive headers + keyframe
        assert_eq!(catchup.len(), 3);
        assert!(catchup[0].is_header); // video header
        assert!(catchup[1].is_header); // audio header
        assert!(catchup[2].is_keyframe); // keyframe
    }
}