1use std::collections::HashMap;
28use std::sync::atomic::{AtomicU64, Ordering};
29use std::sync::RwLock;
30use std::time::{Duration, Instant};
31
32use tokio::sync::broadcast;
33
34use crate::subscription::event::ChangeEvent;
35
/// Opaque identifier handed out by `SubscriptionRegistry::create`.
///
/// Formats for display as `sub-<n>`.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub struct SubscriptionId(pub u64);

impl std::fmt::Display for SubscriptionId {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let SubscriptionId(raw) = *self;
        f.write_fmt(format_args!("sub-{raw}"))
    }
}
51
/// Lifecycle state of a subscription.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum SubscriptionState {
    /// Receiving events; included by `SubscriptionRegistry::get_senders_for_source`.
    Active,
    /// Excluded from sender fan-out until resumed.
    Paused,
    /// Terminal state. Note that `SubscriptionRegistry::cancel` removes entries
    /// outright rather than setting this variant.
    Cancelled,
}
66
/// Policy for a subscriber that cannot keep up with its source.
///
/// The registry in this file only stores this value (via
/// `SubscriptionConfig::backpressure`). The per-variant notes below follow the
/// variant names and should be confirmed against whatever consumes the config.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum BackpressureStrategy {
    /// Presumably: discard the oldest buffered event to make room.
    DropOldest,
    /// Presumably: discard the incoming event when the buffer is full.
    DropNewest,
    /// Presumably: make the producer wait for buffer space.
    Block,
    /// Presumably: deliver only a sample of events; the `usize` is the sampling
    /// parameter (e.g. "every Nth event" — unconfirmed).
    Sample(usize),
}
83
84#[derive(Debug, Clone)]
90pub struct SubscriptionConfig {
91 pub buffer_size: usize,
93 pub backpressure: BackpressureStrategy,
95 pub filter: Option<String>,
97 pub send_snapshot: bool,
99 pub max_batch_size: usize,
101 pub max_batch_delay_us: u64,
103}
104
105impl Default for SubscriptionConfig {
106 fn default() -> Self {
107 Self {
108 buffer_size: 1024,
109 backpressure: BackpressureStrategy::DropOldest,
110 filter: None,
111 send_snapshot: false,
112 max_batch_size: 64,
113 max_batch_delay_us: 100,
114 }
115 }
116}
117
/// Book-keeping the registry stores for one live subscription.
///
/// Entries are inserted by `SubscriptionRegistry::create` and removed outright
/// by `SubscriptionRegistry::cancel`.
#[derive(Debug)]
pub struct SubscriptionEntry {
    /// Registry-assigned identifier.
    pub id: SubscriptionId,
    /// Name the subscription was registered under; indexed for by-name lookups.
    pub source_name: String,
    /// Numeric source the subscription is attached to; indexed for sender fan-out.
    pub source_id: u32,
    /// Lifecycle state; only `Active` entries are returned by sender fan-out.
    pub state: SubscriptionState,
    /// Configuration supplied at creation time.
    pub config: SubscriptionConfig,
    /// Sending half of this subscription's broadcast channel.
    pub sender: broadcast::Sender<ChangeEvent>,
    /// Creation time; `SubscriptionMetrics::age` is measured from it.
    pub created_at: Instant,
    /// Running total fed by `SubscriptionRegistry::record_delivery`.
    pub events_delivered: u64,
    /// Running total fed by `SubscriptionRegistry::record_drop`.
    pub events_dropped: u64,
    /// Lag figure surfaced in metrics; no registry method shown here updates it.
    pub current_lag: u64,
}
146
147#[derive(Debug, Clone)]
153pub struct SubscriptionMetrics {
154 pub id: SubscriptionId,
156 pub source_name: String,
158 pub state: SubscriptionState,
160 pub events_delivered: u64,
162 pub events_dropped: u64,
164 pub current_lag: u64,
166 pub age: Duration,
168}
169
/// Registry of change-event subscriptions.
///
/// State lives behind `RwLock`s and an atomic counter, so every operation
/// takes `&self` and the registry can be shared across threads (e.g. in an
/// `Arc`). One primary map owns the entries; two secondary maps index
/// subscription ids by numeric source id and by source name.
pub struct SubscriptionRegistry {
    /// Primary store: subscription id -> entry.
    subscriptions: RwLock<HashMap<SubscriptionId, SubscriptionEntry>>,
    /// Index: source id -> ids of subscriptions attached to that source.
    by_source: RwLock<HashMap<u32, Vec<SubscriptionId>>>,
    /// Index: source name -> ids of subscriptions registered under that name.
    by_name: RwLock<HashMap<String, Vec<SubscriptionId>>>,
    /// Next id to hand out; initialised to 1 by `new`.
    next_id: AtomicU64,
}
200
201#[allow(clippy::missing_panics_doc)] impl SubscriptionRegistry {
203 #[must_use]
205 pub fn new() -> Self {
206 Self {
207 subscriptions: RwLock::new(HashMap::new()),
208 by_source: RwLock::new(HashMap::new()),
209 by_name: RwLock::new(HashMap::new()),
210 next_id: AtomicU64::new(1),
211 }
212 }
213
214 pub fn create(
225 &self,
226 source_name: String,
227 source_id: u32,
228 config: SubscriptionConfig,
229 ) -> (SubscriptionId, broadcast::Receiver<ChangeEvent>) {
230 let id = SubscriptionId(self.next_id.fetch_add(1, Ordering::Relaxed));
231 let (tx, rx) = broadcast::channel(config.buffer_size);
232
233 let entry = SubscriptionEntry {
234 id,
235 source_name: source_name.clone(),
236 source_id,
237 state: SubscriptionState::Active,
238 config,
239 sender: tx,
240 created_at: Instant::now(),
241 events_delivered: 0,
242 events_dropped: 0,
243 current_lag: 0,
244 };
245
246 self.subscriptions.write().unwrap().insert(id, entry);
248
249 self.by_source
251 .write()
252 .unwrap()
253 .entry(source_id)
254 .or_default()
255 .push(id);
256
257 self.by_name
259 .write()
260 .unwrap()
261 .entry(source_name)
262 .or_default()
263 .push(id);
264
265 (id, rx)
266 }
267
268 pub fn pause(&self, id: SubscriptionId) -> bool {
273 let mut subs = self.subscriptions.write().unwrap();
274 if let Some(entry) = subs.get_mut(&id) {
275 if entry.state == SubscriptionState::Active {
276 entry.state = SubscriptionState::Paused;
277 return true;
278 }
279 }
280 false
281 }
282
283 pub fn resume(&self, id: SubscriptionId) -> bool {
288 let mut subs = self.subscriptions.write().unwrap();
289 if let Some(entry) = subs.get_mut(&id) {
290 if entry.state == SubscriptionState::Paused {
291 entry.state = SubscriptionState::Active;
292 return true;
293 }
294 }
295 false
296 }
297
298 pub fn cancel(&self, id: SubscriptionId) -> bool {
302 let entry = self.subscriptions.write().unwrap().remove(&id);
303
304 if let Some(entry) = entry {
305 if let Some(ids) = self.by_source.write().unwrap().get_mut(&entry.source_id) {
307 ids.retain(|&i| i != id);
308 }
309
310 if let Some(ids) = self.by_name.write().unwrap().get_mut(&entry.source_name) {
312 ids.retain(|&i| i != id);
313 }
314
315 true
316 } else {
317 false
318 }
319 }
320
321 #[must_use]
326 pub fn get_senders_for_source(&self, source_id: u32) -> Vec<broadcast::Sender<ChangeEvent>> {
327 let by_source = self.by_source.read().unwrap();
328 let Some(ids) = by_source.get(&source_id) else {
329 return Vec::new();
330 };
331
332 let subs = self.subscriptions.read().unwrap();
333 ids.iter()
334 .filter_map(|id| {
335 subs.get(id).and_then(|entry| {
336 if entry.state == SubscriptionState::Active {
337 Some(entry.sender.clone())
338 } else {
339 None
340 }
341 })
342 })
343 .collect()
344 }
345
346 #[must_use]
348 pub fn get_subscriptions_by_name(&self, name: &str) -> Vec<SubscriptionId> {
349 let by_name = self.by_name.read().unwrap();
350 by_name.get(name).cloned().unwrap_or_default()
351 }
352
353 #[must_use]
355 pub fn subscription_count(&self) -> usize {
356 self.subscriptions.read().unwrap().len()
357 }
358
359 #[must_use]
361 pub fn active_count(&self) -> usize {
362 self.subscriptions
363 .read()
364 .unwrap()
365 .values()
366 .filter(|e| e.state == SubscriptionState::Active)
367 .count()
368 }
369
370 #[must_use]
372 pub fn metrics(&self, id: SubscriptionId) -> Option<SubscriptionMetrics> {
373 let subs = self.subscriptions.read().unwrap();
374 subs.get(&id).map(|entry| SubscriptionMetrics {
375 id: entry.id,
376 source_name: entry.source_name.clone(),
377 state: entry.state,
378 events_delivered: entry.events_delivered,
379 events_dropped: entry.events_dropped,
380 current_lag: entry.current_lag,
381 age: entry.created_at.elapsed(),
382 })
383 }
384
385 #[must_use]
387 pub fn state(&self, id: SubscriptionId) -> Option<SubscriptionState> {
388 self.subscriptions.read().unwrap().get(&id).map(|e| e.state)
389 }
390
391 pub fn record_delivery(&self, id: SubscriptionId, count: u64) {
395 if let Some(entry) = self.subscriptions.write().unwrap().get_mut(&id) {
396 entry.events_delivered += count;
397 }
398 }
399
400 pub fn record_drop(&self, id: SubscriptionId, count: u64) {
404 if let Some(entry) = self.subscriptions.write().unwrap().get_mut(&id) {
405 entry.events_dropped += count;
406 }
407 }
408}
409
410impl Default for SubscriptionRegistry {
411 fn default() -> Self {
412 Self::new()
413 }
414}
415
#[cfg(test)]
// `make_batch` casts a small `usize` row count to `i64`; test sizes are tiny,
// so the wrap lint is noise here.
#[allow(clippy::cast_possible_wrap)]
mod tests {
    use super::*;
    use std::sync::Arc;

    use arrow_array::Int64Array;
    use arrow_schema::{DataType, Field, Schema};

    /// Builds a single-column (`v: Int64`) batch holding `0..n`.
    fn make_batch(n: usize) -> arrow_array::RecordBatch {
        let schema = Arc::new(Schema::new(vec![Field::new("v", DataType::Int64, false)]));
        let values: Vec<i64> = (0..n as i64).collect();
        let array = Int64Array::from(values);
        arrow_array::RecordBatch::try_new(schema, vec![Arc::new(array)]).unwrap()
    }

    #[test]
    fn test_registry_config_default() {
        let cfg = SubscriptionConfig::default();
        assert_eq!(cfg.buffer_size, 1024);
        assert_eq!(cfg.backpressure, BackpressureStrategy::DropOldest);
        assert!(cfg.filter.is_none());
        assert!(!cfg.send_snapshot);
        assert_eq!(cfg.max_batch_size, 64);
        assert_eq!(cfg.max_batch_delay_us, 100);
    }

    #[test]
    fn test_registry_create() {
        let reg = SubscriptionRegistry::new();
        let (id, _rx) = reg.create("mv_orders".into(), 0, SubscriptionConfig::default());
        assert_eq!(id.0, 1);
        assert_eq!(reg.subscription_count(), 1);
        assert_eq!(reg.active_count(), 1);
    }

    #[test]
    fn test_registry_create_multiple() {
        let reg = SubscriptionRegistry::new();
        let (id1, _rx1) = reg.create("mv_orders".into(), 0, SubscriptionConfig::default());
        let (id2, _rx2) = reg.create("mv_orders".into(), 0, SubscriptionConfig::default());
        let (id3, _rx3) = reg.create("mv_trades".into(), 1, SubscriptionConfig::default());

        assert_ne!(id1, id2);
        assert_ne!(id2, id3);
        assert_eq!(reg.subscription_count(), 3);

        let senders_0 = reg.get_senders_for_source(0);
        assert_eq!(senders_0.len(), 2);
        let senders_1 = reg.get_senders_for_source(1);
        assert_eq!(senders_1.len(), 1);
    }

    #[test]
    fn test_registry_pause_resume() {
        let reg = SubscriptionRegistry::new();
        let (id, _rx) = reg.create("mv_orders".into(), 0, SubscriptionConfig::default());

        assert!(reg.pause(id));
        assert_eq!(reg.state(id), Some(SubscriptionState::Paused));
        assert_eq!(reg.active_count(), 0);

        // Pausing an already-paused subscription is rejected.
        assert!(!reg.pause(id));

        assert!(reg.resume(id));
        assert_eq!(reg.state(id), Some(SubscriptionState::Active));
        assert_eq!(reg.active_count(), 1);

        // Resuming an already-active subscription is rejected.
        assert!(!reg.resume(id));
    }

    #[test]
    fn test_registry_cancel() {
        let reg = SubscriptionRegistry::new();
        let (id, _rx) = reg.create("mv_orders".into(), 0, SubscriptionConfig::default());
        assert_eq!(reg.subscription_count(), 1);

        assert!(reg.cancel(id));
        assert_eq!(reg.subscription_count(), 0);
        assert_eq!(reg.active_count(), 0);

        let senders = reg.get_senders_for_source(0);
        assert!(senders.is_empty());

        let by_name = reg.get_subscriptions_by_name("mv_orders");
        assert!(by_name.is_empty());
    }

    #[test]
    fn test_registry_cancel_nonexistent() {
        let reg = SubscriptionRegistry::new();
        assert!(!reg.cancel(SubscriptionId(999)));
    }

    #[test]
    fn test_registry_get_senders() {
        let reg = SubscriptionRegistry::new();
        let (_, _rx1) = reg.create("mv_a".into(), 0, SubscriptionConfig::default());
        let (_, _rx2) = reg.create("mv_b".into(), 0, SubscriptionConfig::default());

        let senders = reg.get_senders_for_source(0);
        assert_eq!(senders.len(), 2);
    }

    #[test]
    fn test_registry_get_senders_paused_excluded() {
        let reg = SubscriptionRegistry::new();
        let (id1, _rx1) = reg.create("mv_a".into(), 0, SubscriptionConfig::default());
        let (_, _rx2) = reg.create("mv_b".into(), 0, SubscriptionConfig::default());

        reg.pause(id1);
        let senders = reg.get_senders_for_source(0);
        assert_eq!(senders.len(), 1);
    }

    #[test]
    fn test_registry_get_senders_no_source() {
        let reg = SubscriptionRegistry::new();
        let senders = reg.get_senders_for_source(42);
        assert!(senders.is_empty());
    }

    #[test]
    fn test_registry_subscription_count() {
        let reg = SubscriptionRegistry::new();
        assert_eq!(reg.subscription_count(), 0);
        assert_eq!(reg.active_count(), 0);

        let (id1, _rx1) = reg.create("mv_a".into(), 0, SubscriptionConfig::default());
        let (_, _rx2) = reg.create("mv_b".into(), 1, SubscriptionConfig::default());
        assert_eq!(reg.subscription_count(), 2);
        assert_eq!(reg.active_count(), 2);

        reg.pause(id1);
        assert_eq!(reg.subscription_count(), 2);
        assert_eq!(reg.active_count(), 1);
    }

    #[test]
    fn test_registry_metrics() {
        let reg = SubscriptionRegistry::new();
        let (id, _rx) = reg.create("mv_orders".into(), 0, SubscriptionConfig::default());

        let m = reg.metrics(id).unwrap();
        assert_eq!(m.id, id);
        assert_eq!(m.source_name, "mv_orders");
        assert_eq!(m.state, SubscriptionState::Active);
        assert_eq!(m.events_delivered, 0);
        assert_eq!(m.events_dropped, 0);
        assert_eq!(m.current_lag, 0);

        assert!(reg.metrics(SubscriptionId(999)).is_none());
    }

    #[test]
    fn test_registry_record_delivery_and_drop() {
        let reg = SubscriptionRegistry::new();
        let (id, _rx) = reg.create("mv_a".into(), 0, SubscriptionConfig::default());

        reg.record_delivery(id, 10);
        reg.record_delivery(id, 5);
        reg.record_drop(id, 2);

        let m = reg.metrics(id).unwrap();
        assert_eq!(m.events_delivered, 15);
        assert_eq!(m.events_dropped, 2);
    }

    #[test]
    fn test_registry_thread_safety() {
        let reg = Arc::new(SubscriptionRegistry::new());
        let mut handles = Vec::new();

        for t in 0..4u32 {
            let reg = Arc::clone(&reg);
            handles.push(std::thread::spawn(move || {
                let mut ids = Vec::new();
                for i in 0..100u32 {
                    let name = format!("mv_{t}_{i}");
                    let (id, _rx) = reg.create(name, t, SubscriptionConfig::default());
                    ids.push(id);
                }
                ids
            }));
        }

        let all_ids: Vec<Vec<SubscriptionId>> =
            handles.into_iter().map(|h| h.join().unwrap()).collect();

        assert_eq!(reg.subscription_count(), 400);

        // Ids handed out concurrently must all be distinct.
        let mut flat: Vec<u64> = all_ids.iter().flatten().map(|id| id.0).collect();
        flat.sort_unstable();
        flat.dedup();
        assert_eq!(flat.len(), 400);

        for t in 0..4u32 {
            let senders = reg.get_senders_for_source(t);
            assert_eq!(senders.len(), 100);
        }

        for id in &all_ids[0][..50] {
            assert!(reg.cancel(*id));
        }
        assert_eq!(reg.subscription_count(), 350);
        assert_eq!(reg.get_senders_for_source(0).len(), 50);
    }

    #[test]
    fn test_registry_with_notification_hub() {
        use crate::subscription::NotificationHub;

        let mut hub = NotificationHub::new(4, 64);
        let reg = SubscriptionRegistry::new();

        let source_id = hub.register_source().unwrap();
        let (sub_id, _rx) =
            reg.create("mv_orders".into(), source_id, SubscriptionConfig::default());

        let senders = reg.get_senders_for_source(source_id);
        assert_eq!(senders.len(), 1);

        assert!(hub.notify_source(
            source_id,
            crate::subscription::EventType::Insert,
            10,
            1000,
            0,
        ));

        let mut count = 0;
        hub.drain_notifications(|_n| count += 1);
        assert_eq!(count, 1);

        reg.cancel(sub_id);
        assert!(reg.get_senders_for_source(source_id).is_empty());
    }

    #[test]
    fn test_registry_broadcast_delivery() {
        let reg = SubscriptionRegistry::new();
        let (_, mut rx1) = reg.create("mv_a".into(), 0, SubscriptionConfig::default());
        let (_, mut rx2) = reg.create("mv_a".into(), 0, SubscriptionConfig::default());

        let senders = reg.get_senders_for_source(0);
        assert_eq!(senders.len(), 2);

        let batch = Arc::new(make_batch(5));
        let event = ChangeEvent::insert(batch, 1000, 1);

        for sender in &senders {
            sender.send(event.clone()).unwrap();
        }

        let e1 = rx1.try_recv().unwrap();
        assert_eq!(e1.timestamp(), 1000);
        assert_eq!(e1.sequence(), Some(1));
        assert_eq!(e1.row_count(), 5);

        let e2 = rx2.try_recv().unwrap();
        assert_eq!(e2.timestamp(), 1000);
        assert_eq!(e2.sequence(), Some(1));
    }

    #[test]
    fn test_subscription_id_display() {
        let id = SubscriptionId(42);
        assert_eq!(format!("{id}"), "sub-42");
    }
}