1use std::collections::HashMap;
7use std::sync::atomic::Ordering;
8use std::sync::Arc;
9use std::time::Instant;
10
11use tokio::sync::{broadcast, RwLock};
12
13use super::config::RegistryConfig;
14use super::entry::{StreamEntry, StreamState, StreamStats};
15use super::error::RegistryError;
16use super::frame::{BroadcastFrame, StreamKey};
17
18pub struct StreamRegistry {
23 streams: RwLock<HashMap<StreamKey, Arc<RwLock<StreamEntry>>>>,
25
26 config: RegistryConfig,
28}
29
30impl StreamRegistry {
31 pub fn new() -> Self {
33 Self::with_config(RegistryConfig::default())
34 }
35
36 pub fn with_config(config: RegistryConfig) -> Self {
38 Self {
39 streams: RwLock::new(HashMap::new()),
40 config,
41 }
42 }
43
44 pub fn config(&self) -> &RegistryConfig {
46 &self.config
47 }
48
49 pub async fn register_publisher(
55 &self,
56 key: &StreamKey,
57 session_id: u64,
58 ) -> Result<(), RegistryError> {
59 let mut streams = self.streams.write().await;
60
61 if let Some(entry_arc) = streams.get(key) {
62 let mut entry = entry_arc.write().await;
63
64 match entry.state {
66 StreamState::Active if entry.publisher_id.is_some() => {
67 return Err(RegistryError::StreamAlreadyPublishing(key.clone()));
68 }
69 StreamState::GracePeriod | StreamState::Idle | StreamState::Active => {
70 entry.publisher_id = Some(session_id);
72 entry.publisher_disconnected_at = None;
73 entry.state = StreamState::Active;
74
75 tracing::info!(
76 stream = %key,
77 session_id = session_id,
78 subscribers = entry.subscriber_count(),
79 "Publisher registered (existing stream)"
80 );
81 }
82 }
83 } else {
84 let mut entry = StreamEntry::new(&self.config);
86 entry.publisher_id = Some(session_id);
87 entry.state = StreamState::Active;
88
89 streams.insert(key.clone(), Arc::new(RwLock::new(entry)));
90
91 tracing::info!(
92 stream = %key,
93 session_id = session_id,
94 "Publisher registered (new stream)"
95 );
96 }
97
98 Ok(())
99 }
100
101 pub async fn unregister_publisher(&self, key: &StreamKey, session_id: u64) {
106 let streams = self.streams.read().await;
107
108 if let Some(entry_arc) = streams.get(key) {
109 let mut entry = entry_arc.write().await;
110
111 if entry.publisher_id != Some(session_id) {
113 tracing::warn!(
114 stream = %key,
115 expected = ?entry.publisher_id,
116 actual = session_id,
117 "Publisher unregister mismatch"
118 );
119 return;
120 }
121
122 entry.publisher_id = None;
123 entry.publisher_disconnected_at = Some(Instant::now());
124
125 if entry.subscriber_count() > 0 {
127 entry.state = StreamState::GracePeriod;
128 tracing::info!(
129 stream = %key,
130 session_id = session_id,
131 subscribers = entry.subscriber_count(),
132 grace_period_secs = self.config.publisher_grace_period.as_secs(),
133 "Publisher disconnected, entering grace period"
134 );
135 } else {
136 entry.state = StreamState::Idle;
137 tracing::info!(
138 stream = %key,
139 session_id = session_id,
140 "Publisher disconnected, no subscribers"
141 );
142 }
143 }
144 }
145
146 pub async fn subscribe(
151 &self,
152 key: &StreamKey,
153 ) -> Result<(broadcast::Receiver<BroadcastFrame>, Vec<BroadcastFrame>), RegistryError> {
154 let streams = self.streams.read().await;
155
156 let entry_arc = streams
157 .get(key)
158 .ok_or_else(|| RegistryError::StreamNotFound(key.clone()))?;
159
160 let entry = entry_arc.read().await;
161
162 if entry.state == StreamState::Idle && entry.publisher_id.is_none() {
164 return Err(RegistryError::StreamNotActive(key.clone()));
165 }
166
167 let rx = entry.subscribe();
169 let catchup = entry.get_catchup_frames();
170
171 entry.subscriber_count.fetch_add(1, Ordering::Relaxed);
173
174 tracing::info!(
175 stream = %key,
176 subscribers = entry.subscriber_count(),
177 catchup_frames = catchup.len(),
178 "Subscriber added"
179 );
180
181 Ok((rx, catchup))
182 }
183
184 pub async fn unsubscribe(&self, key: &StreamKey) {
186 let streams = self.streams.read().await;
187
188 if let Some(entry_arc) = streams.get(key) {
189 let entry = entry_arc.read().await;
190 let prev = entry.subscriber_count.fetch_sub(1, Ordering::Relaxed);
191
192 tracing::debug!(
193 stream = %key,
194 subscribers = prev.saturating_sub(1),
195 "Subscriber removed"
196 );
197 }
198 }
199
200 pub async fn broadcast(&self, key: &StreamKey, frame: BroadcastFrame) {
204 let streams = self.streams.read().await;
205
206 if let Some(entry_arc) = streams.get(key) {
207 let mut entry = entry_arc.write().await;
208
209 entry.update_caches(&frame);
211
212 let _ = entry.send(frame);
215 }
216 }
217
218 pub async fn get_sequence_headers(&self, key: &StreamKey) -> Vec<BroadcastFrame> {
222 let streams = self.streams.read().await;
223
224 if let Some(entry_arc) = streams.get(key) {
225 let entry = entry_arc.read().await;
226 let mut frames = Vec::with_capacity(2);
227
228 if let Some(ref video) = entry.video_header {
229 frames.push(video.clone());
230 }
231 if let Some(ref audio) = entry.audio_header {
232 frames.push(audio.clone());
233 }
234
235 frames
236 } else {
237 Vec::new()
238 }
239 }
240
241 pub async fn has_active_stream(&self, key: &StreamKey) -> bool {
243 let streams = self.streams.read().await;
244
245 if let Some(entry_arc) = streams.get(key) {
246 let entry = entry_arc.read().await;
247 entry.state == StreamState::Active && entry.publisher_id.is_some()
248 } else {
249 false
250 }
251 }
252
253 pub async fn stream_exists(&self, key: &StreamKey) -> bool {
255 let streams = self.streams.read().await;
256
257 if let Some(entry_arc) = streams.get(key) {
258 let entry = entry_arc.read().await;
259 matches!(entry.state, StreamState::Active | StreamState::GracePeriod)
260 } else {
261 false
262 }
263 }
264
265 pub async fn get_stream_stats(&self, key: &StreamKey) -> Option<StreamStats> {
267 let streams = self.streams.read().await;
268
269 if let Some(entry_arc) = streams.get(key) {
270 let entry = entry_arc.read().await;
271 Some(StreamStats {
272 subscriber_count: entry.subscriber_count(),
273 has_publisher: entry.publisher_id.is_some(),
274 state: entry.state,
275 gop_frame_count: entry.gop_buffer.frame_count(),
276 gop_size_bytes: entry.gop_buffer.size(),
277 })
278 } else {
279 None
280 }
281 }
282
283 pub async fn stream_count(&self) -> usize {
285 self.streams.read().await.len()
286 }
287
288 pub async fn cleanup(&self) {
294 let mut streams = self.streams.write().await;
295 let now = Instant::now();
296
297 let keys_to_remove: Vec<StreamKey> = streams
298 .iter()
299 .filter_map(|(key, entry_arc)| {
300 if let Ok(entry) = entry_arc.try_read() {
302 let should_remove = match entry.state {
303 StreamState::GracePeriod => {
304 if let Some(disconnected_at) = entry.publisher_disconnected_at {
305 now.duration_since(disconnected_at)
306 > self.config.publisher_grace_period
307 } else {
308 false
309 }
310 }
311 StreamState::Idle => {
312 if let Some(disconnected_at) = entry.publisher_disconnected_at {
313 now.duration_since(disconnected_at)
314 > self.config.idle_stream_timeout
315 } else {
316 now.duration_since(entry.created_at)
317 > self.config.idle_stream_timeout
318 }
319 }
320 StreamState::Active => false,
321 };
322
323 if should_remove {
324 Some(key.clone())
325 } else {
326 None
327 }
328 } else {
329 None
330 }
331 })
332 .collect();
333
334 for key in keys_to_remove {
335 streams.remove(&key);
336 tracing::info!(stream = %key, "Stream removed by cleanup");
337 }
338 }
339
340 pub fn spawn_cleanup_task(self: &Arc<Self>) -> tokio::task::JoinHandle<()> {
344 let registry = Arc::clone(self);
345 let interval = registry.config.cleanup_interval;
346
347 tokio::spawn(async move {
348 let mut ticker = tokio::time::interval(interval);
349 loop {
350 ticker.tick().await;
351 registry.cleanup().await;
352 }
353 })
354 }
355}
356
357impl Default for StreamRegistry {
358 fn default() -> Self {
359 Self::new()
360 }
361}
362
#[cfg(test)]
mod tests {
    use bytes::Bytes;

    use super::*;

    /// The stream key shared by every test in this module.
    fn stream_key() -> StreamKey {
        StreamKey::new("live", "test_stream")
    }

    /// A second publisher is rejected while the first one is still attached.
    #[tokio::test]
    async fn test_register_publisher() {
        let reg = StreamRegistry::new();
        let stream = stream_key();

        reg.register_publisher(&stream, 1).await.unwrap();
        assert!(reg.has_active_stream(&stream).await);

        let second = reg.register_publisher(&stream, 2).await;
        assert!(matches!(
            second,
            Err(RegistryError::StreamAlreadyPublishing(_))
        ));
    }

    /// A frame broadcast after subscribing reaches the receiver, and unsubscribing brings
    /// the subscriber count back to zero.
    #[tokio::test]
    async fn test_subscribe_unsubscribe() {
        let reg = StreamRegistry::new();
        let stream = stream_key();
        reg.register_publisher(&stream, 1).await.unwrap();

        let (mut rx, catchup) = reg.subscribe(&stream).await.unwrap();
        // Nothing has been broadcast yet, so there is nothing to catch up on.
        assert!(catchup.is_empty());

        let keyframe = BroadcastFrame::video(0, Bytes::from_static(&[0x17, 0x01]), true, false);
        reg.broadcast(&stream, keyframe).await;

        let got = rx.recv().await.unwrap();
        assert_eq!(got.timestamp, 0);
        assert!(got.is_keyframe);

        reg.unsubscribe(&stream).await;

        let stats = reg.get_stream_stats(&stream).await.unwrap();
        assert_eq!(stats.subscriber_count, 0);
    }

    /// Losing the publisher while a subscriber is attached puts the stream into its grace
    /// period. During that period it still exists and still accepts subscribers.
    #[tokio::test]
    async fn test_grace_period() {
        let config =
            RegistryConfig::default().publisher_grace_period(std::time::Duration::from_millis(100));
        let reg = StreamRegistry::with_config(config);
        let stream = stream_key();

        reg.register_publisher(&stream, 1).await.unwrap();
        let (_rx, _) = reg.subscribe(&stream).await.unwrap();

        reg.unregister_publisher(&stream, 1).await;

        let stats = reg.get_stream_stats(&stream).await.unwrap();
        assert_eq!(stats.state, StreamState::GracePeriod);
        assert!(reg.stream_exists(&stream).await);

        let late_subscriber = reg.subscribe(&stream).await;
        assert!(late_subscriber.is_ok());
    }

    /// A new publisher taking over a grace-period stream reactivates it and keeps the
    /// existing subscriber.
    #[tokio::test]
    async fn test_publisher_reconnect() {
        let reg = StreamRegistry::new();
        let stream = stream_key();

        reg.register_publisher(&stream, 1).await.unwrap();
        let (_rx, _) = reg.subscribe(&stream).await.unwrap();

        reg.unregister_publisher(&stream, 1).await;
        reg.register_publisher(&stream, 2).await.unwrap();

        let stats = reg.get_stream_stats(&stream).await.unwrap();
        assert!(stats.has_publisher);
        assert_eq!(stats.state, StreamState::Active);
        assert_eq!(stats.subscriber_count, 1);
    }

    /// A late subscriber receives both sequence headers followed by the cached keyframe.
    #[tokio::test]
    async fn test_catchup_frames() {
        let reg = StreamRegistry::new();
        let stream = stream_key();
        reg.register_publisher(&stream, 1).await.unwrap();

        let video_header = BroadcastFrame::video(0, Bytes::from_static(&[0x17, 0x00]), true, true);
        let audio_header = BroadcastFrame::audio(0, Bytes::from_static(&[0xAF, 0x00]), true);
        reg.broadcast(&stream, video_header).await;
        reg.broadcast(&stream, audio_header).await;

        let keyframe = BroadcastFrame::video(33, Bytes::from_static(&[0x17, 0x01]), true, false);
        reg.broadcast(&stream, keyframe).await;

        let (_rx, catchup) = reg.subscribe(&stream).await.unwrap();

        assert_eq!(catchup.len(), 3);
        assert!(catchup[0].is_header);
        assert!(catchup[1].is_header);
        assert!(catchup[2].is_keyframe);
    }
}