daedalus_core/channels/broadcast.rs

1use std::collections::VecDeque;
2use std::sync::atomic::{AtomicBool, AtomicU64, AtomicUsize, Ordering};
3use std::sync::{Arc, Mutex, Weak};
4
5use super::{Backpressure, ChannelRecv, ChannelSend, ChannelStats, CloseBehavior, RecvOutcome};
6
7#[cfg(feature = "metrics")]
8use crate::metrics::MetricsSink;
9
/// Per-receiver message queue.
///
/// Each `BroadcastReceiver` owns its `Subscriber` through an `Arc`, while the shared
/// channel state keeps only a `Weak` to it. Once the receiver is gone, the entry can no
/// longer be upgraded and is pruned from the subscriber list.
struct Subscriber<T> {
    /// Pending values, oldest first. `send` keeps this at most `capacity` long by
    /// evicting the front entry before pushing when it is full.
    buffer: Mutex<VecDeque<Arc<T>>>,
}
13
/// State shared by every sender and receiver handle of one broadcast channel.
struct BroadcastInner<T> {
    /// Weak handles to each receiver's queue; entries whose receiver is gone are
    /// dropped from the list when `send` walks it.
    subscribers: Mutex<Vec<Weak<Subscriber<T>>>>,
    /// Closed flag. It is only ever set to `true`, so closing is permanent.
    closed: AtomicBool,
    /// Number of live `BroadcastSender` handles (starts at 1; clone adds, drop subtracts).
    senders: AtomicUsize,
    /// Number of live `BroadcastReceiver` handles (starts at 1; `subscribe` adds, drop
    /// subtracts).
    receivers: AtomicUsize,
    /// Maximum number of values queued per subscriber.
    capacity: usize,
    /// Total pushes, counted once per subscriber a value was delivered to.
    enqueued: AtomicU64,
    /// Total values evicted from a full subscriber queue.
    dropped: AtomicU64,
    /// Total values popped by receivers (summed over all receivers).
    drained: AtomicU64,
    /// Rule deciding which handle counts reaching zero close the channel.
    close_behavior: CloseBehavior,
    /// Optional sink that receives counter increments.
    #[cfg(feature = "metrics")]
    metrics: Option<Arc<dyn MetricsSink>>,
}
27
28impl<T> BroadcastInner<T> {
29    fn new(capacity: usize, close_behavior: CloseBehavior) -> Self {
30        Self {
31            subscribers: Mutex::new(Vec::new()),
32            closed: AtomicBool::new(false),
33            senders: AtomicUsize::new(1),
34            receivers: AtomicUsize::new(1),
35            capacity,
36            enqueued: AtomicU64::new(0),
37            dropped: AtomicU64::new(0),
38            drained: AtomicU64::new(0),
39            close_behavior,
40            #[cfg(feature = "metrics")]
41            metrics: None,
42        }
43    }
44
45    #[cfg(feature = "metrics")]
46    fn new_with_metrics(
47        capacity: usize,
48        close_behavior: CloseBehavior,
49        metrics: Arc<dyn MetricsSink>,
50    ) -> Self {
51        Self {
52            subscribers: Mutex::new(Vec::new()),
53            closed: AtomicBool::new(false),
54            senders: AtomicUsize::new(1),
55            receivers: AtomicUsize::new(1),
56            capacity,
57            enqueued: AtomicU64::new(0),
58            dropped: AtomicU64::new(0),
59            drained: AtomicU64::new(0),
60            close_behavior,
61            metrics: Some(metrics),
62        }
63    }
64
65    fn mark_closed(&self) {
66        self.closed.store(true, Ordering::Release);
67    }
68
69    fn try_close(&self) {
70        match self.close_behavior {
71            CloseBehavior::FailFast => {
72                if self.senders.load(Ordering::Acquire) == 0
73                    || self.receivers.load(Ordering::Acquire) == 0
74                {
75                    self.mark_closed();
76                }
77            }
78            CloseBehavior::DrainUntilSendersDone => {
79                if self.senders.load(Ordering::Acquire) == 0 {
80                    self.mark_closed();
81                }
82            }
83        }
84    }
85
86    #[cfg(feature = "metrics")]
87    fn inc(&self, key: &'static str) {
88        if let Some(metrics) = &self.metrics {
89            metrics.increment(key, 1);
90        }
91    }
92}
93
/// Sending half of a broadcast channel.
///
/// Cloning it registers another sender. Each drop decrements the sender count and
/// re-evaluates whether the channel should close.
pub struct BroadcastSender<T> {
    inner: Arc<BroadcastInner<T>>,
}
97
98impl<T> Clone for BroadcastSender<T> {
99    fn clone(&self) -> Self {
100        self.inner.senders.fetch_add(1, Ordering::Relaxed);
101        Self {
102            inner: Arc::clone(&self.inner),
103        }
104    }
105}
106
107impl<T> Drop for BroadcastSender<T> {
108    fn drop(&mut self) {
109        self.inner.senders.fetch_sub(1, Ordering::Relaxed);
110        self.inner.try_close();
111    }
112}
113
114impl<T> std::fmt::Debug for BroadcastSender<T> {
115    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
116        f.debug_struct("BroadcastSender").finish_non_exhaustive()
117    }
118}
119
/// Receiving half of a broadcast channel, with its own bounded queue.
///
/// Further receivers are created with [`BroadcastSender::subscribe`]. Each receiver gets
/// its own `Arc` to every value sent while it is registered, subject to its queue's
/// capacity.
pub struct BroadcastReceiver<T> {
    /// Shared channel state.
    inner: Arc<BroadcastInner<T>>,
    /// This receiver's queue; the shared state holds only a `Weak` to it.
    subscriber: Arc<Subscriber<T>>,
}
124
125impl<T> Drop for BroadcastReceiver<T> {
126    fn drop(&mut self) {
127        self.inner.receivers.fetch_sub(1, Ordering::Relaxed);
128        self.inner.try_close();
129    }
130}
131
132pub fn broadcast<T: Send + Sync>(capacity: usize) -> (BroadcastSender<T>, BroadcastReceiver<T>) {
133    assert!(capacity > 0, "capacity must be greater than zero");
134    let inner = Arc::new(BroadcastInner::new(capacity, CloseBehavior::FailFast));
135    let recv = subscribe_inner(&inner);
136    (
137        BroadcastSender {
138            inner: Arc::clone(&inner),
139        },
140        recv,
141    )
142}
143
144pub fn broadcast_with_behavior<T: Send + Sync>(
145    capacity: usize,
146    close_behavior: CloseBehavior,
147) -> (BroadcastSender<T>, BroadcastReceiver<T>) {
148    assert!(capacity > 0, "capacity must be greater than zero");
149    let inner = Arc::new(BroadcastInner::new(capacity, close_behavior));
150    let recv = subscribe_inner(&inner);
151    (
152        BroadcastSender {
153            inner: Arc::clone(&inner),
154        },
155        recv,
156    )
157}
158
/// Creates a [`CloseBehavior::FailFast`] broadcast channel that reports counter events
/// (evictions, closed sends) to `metrics`.
///
/// # Panics
///
/// Panics if `capacity` is zero.
#[cfg(feature = "metrics")]
pub fn broadcast_with_metrics<T: Send + Sync>(
    capacity: usize,
    metrics: Arc<dyn MetricsSink>,
) -> (BroadcastSender<T>, BroadcastReceiver<T>) {
    // Single construction path: the FailFast case of the metrics-enabled builder.
    broadcast_with_metrics_and_behavior(capacity, CloseBehavior::FailFast, metrics)
}
178
/// Creates a broadcast channel with the given `close_behavior` that reports counter
/// events to `metrics`.
///
/// # Panics
///
/// Panics if `capacity` is zero.
#[cfg(feature = "metrics")]
pub fn broadcast_with_metrics_and_behavior<T: Send + Sync>(
    capacity: usize,
    close_behavior: CloseBehavior,
    metrics: Arc<dyn MetricsSink>,
) -> (BroadcastSender<T>, BroadcastReceiver<T>) {
    assert!(capacity > 0, "capacity must be greater than zero");
    let shared = Arc::new(BroadcastInner::new_with_metrics(
        capacity,
        close_behavior,
        metrics,
    ));
    let first_receiver = subscribe_inner(&shared);
    let sender = BroadcastSender { inner: shared };
    (sender, first_receiver)
}
199
200fn subscribe_inner<T: Send + Sync>(inner: &Arc<BroadcastInner<T>>) -> BroadcastReceiver<T> {
201    let subscriber = Arc::new(Subscriber {
202        buffer: Mutex::new(VecDeque::with_capacity(inner.capacity)),
203    });
204    {
205        let mut subs = inner
206            .subscribers
207            .lock()
208            .expect("broadcast subscriber list poisoned");
209        subs.push(Arc::downgrade(&subscriber));
210    }
211    BroadcastReceiver {
212        inner: Arc::clone(inner),
213        subscriber,
214    }
215}
216
217impl<T: Send + Sync> BroadcastSender<T> {
218    pub fn subscribe(&self) -> BroadcastReceiver<T> {
219        self.inner.receivers.fetch_add(1, Ordering::Relaxed);
220        subscribe_inner(&self.inner)
221    }
222}
223
224impl<T: Send + Sync> ChannelSend<Arc<T>> for BroadcastSender<T> {
225    fn send(&self, value: Arc<T>) -> Backpressure {
226        if self.inner.closed.load(Ordering::Acquire) {
227            #[cfg(feature = "metrics")]
228            self.inner.inc("channel.broadcast.closed");
229            return Backpressure::Closed;
230        }
231
232        let mut live = 0usize;
233        let mut upgraded = Vec::new();
234        {
235            let mut subs = self
236                .inner
237                .subscribers
238                .lock()
239                .expect("broadcast subscriber list poisoned");
240            subs.retain(|weak_sub| {
241                if let Some(sub) = weak_sub.upgrade() {
242                    upgraded.push(sub);
243                    true
244                } else {
245                    false
246                }
247            });
248        }
249
250        for sub in upgraded {
251            live += 1;
252            let mut buf = sub.buffer.lock().expect("broadcast buffer poisoned");
253            if buf.len() >= self.inner.capacity {
254                buf.pop_front();
255                #[cfg(feature = "metrics")]
256                self.inner.inc("channel.broadcast.dropped");
257                self.inner.dropped.fetch_add(1, Ordering::Relaxed);
258            }
259            buf.push_back(Arc::clone(&value));
260            self.inner.enqueued.fetch_add(1, Ordering::Relaxed);
261        }
262
263        if live == 0 {
264            self.inner.mark_closed();
265            #[cfg(feature = "metrics")]
266            self.inner.inc("channel.broadcast.closed");
267            Backpressure::Closed
268        } else {
269            Backpressure::Ok
270        }
271    }
272}
273
274impl<T: Send + Sync> ChannelRecv<Arc<T>> for BroadcastReceiver<T> {
275    fn try_recv(&self) -> RecvOutcome<Arc<T>> {
276        let mut buf = self
277            .subscriber
278            .buffer
279            .lock()
280            .expect("broadcast buffer poisoned");
281        match buf.pop_front() {
282            Some(v) => {
283                self.inner.drained.fetch_add(1, Ordering::Relaxed);
284                RecvOutcome::Data(v)
285            }
286            None if self.inner.closed.load(Ordering::Acquire) => RecvOutcome::Closed,
287            None => RecvOutcome::Empty,
288        }
289    }
290}
291
292impl<T: Send + Sync> BroadcastReceiver<T> {
293    pub fn stats(&self) -> ChannelStats {
294        ChannelStats {
295            enqueued: self.inner.enqueued.load(Ordering::Relaxed),
296            dropped: self.inner.dropped.load(Ordering::Relaxed),
297            drained: self.inner.drained.load(Ordering::Relaxed),
298            depth: self.subscriber.buffer.lock().map(|b| b.len()).unwrap_or(0),
299            closed: self.inner.closed.load(Ordering::Relaxed),
300        }
301    }
302}
303
#[cfg(test)]
mod tests {
    use super::*;
    use proptest::prelude::*;
    use std::sync::Arc;
    use std::thread;

    #[test]
    fn broadcast_to_multiple_subscribers() {
        let (tx, rx1) = broadcast::<u32>(2);
        let rx2 = tx.subscribe();

        let payload = Arc::new(5);
        assert_eq!(tx.send(Arc::clone(&payload)), Backpressure::Ok);

        assert_eq!(rx1.try_recv(), RecvOutcome::Data(Arc::new(5)));
        assert_eq!(rx2.try_recv(), RecvOutcome::Data(Arc::new(5)));
    }

    #[test]
    fn broadcast_drops_oldest() {
        let (tx, rx) = broadcast::<u32>(1);
        tx.send(Arc::new(1));
        tx.send(Arc::new(2));
        assert_eq!(rx.try_recv(), RecvOutcome::Data(Arc::new(2)));
    }

    #[test]
    fn late_subscriber_sees_only_later_values() {
        let (tx, rx1) = broadcast::<u32>(4);
        assert_eq!(tx.send(Arc::new(1)), Backpressure::Ok);
        let rx2 = tx.subscribe();
        assert_eq!(tx.send(Arc::new(2)), Backpressure::Ok);

        assert_eq!(rx1.try_recv(), RecvOutcome::Data(Arc::new(1)));
        assert_eq!(rx1.try_recv(), RecvOutcome::Data(Arc::new(2)));
        assert_eq!(rx2.try_recv(), RecvOutcome::Data(Arc::new(2)));
        assert_eq!(rx2.try_recv(), RecvOutcome::Empty);
    }

    #[test]
    fn drain_until_senders_done_closes_after_last_sender() {
        let (tx, rx) = broadcast_with_behavior::<u32>(4, CloseBehavior::DrainUntilSendersDone);
        let tx2 = tx.clone();
        assert_eq!(tx.send(Arc::new(7)), Backpressure::Ok);
        drop(tx);

        // One sender remains, so the channel is still open.
        assert_eq!(rx.try_recv(), RecvOutcome::Data(Arc::new(7)));
        assert_eq!(rx.try_recv(), RecvOutcome::Empty);

        drop(tx2);
        assert_eq!(rx.try_recv(), RecvOutcome::Closed);
    }

    proptest! {
        #[test]
        fn broadcast_respects_per_subscriber_capacity(values in proptest::collection::vec(any::<u8>(), 2..20)) {
            let (tx, rx1) = broadcast::<u8>(2);
            let rx2 = tx.subscribe();
            for v in &values {
                let _ = tx.send(Arc::new(*v));
            }
            let mut seen1 = Vec::new();
            while let RecvOutcome::Data(v) = rx1.try_recv() {
                seen1.push(*v);
            }
            let mut seen2 = Vec::new();
            while let RecvOutcome::Data(v) = rx2.try_recv() {
                seen2.push(*v);
            }
            // Each subscriber keeps only the last 2 items due to capacity=2
            let expected: Vec<u8> = values.into_iter().rev().take(2).collect::<Vec<_>>().into_iter().rev().collect();
            prop_assert_eq!(seen1, expected.clone());
            prop_assert_eq!(seen2, expected);
        }
    }

    #[test]
    fn broadcast_mpmc_smoke() {
        const PRODUCERS: u32 = 4;
        const PER_PRODUCER: u32 = 50;
        let produced = (PRODUCERS * PER_PRODUCER) as usize;

        let (tx, rx1) = broadcast::<u32>(4);
        let rx2 = tx.subscribe();

        // Each producer owns its own `BroadcastSender` clone, so sender counting and
        // close-on-last-drop are exercised across threads.
        let producers: Vec<_> = (0..PRODUCERS)
            .map(|offset| {
                let txc = tx.clone();
                thread::spawn(move || {
                    for i in 0..PER_PRODUCER {
                        assert_eq!(txc.send(Arc::new(i + offset * 1_000)), Backpressure::Ok);
                    }
                })
            })
            .collect();

        // Drain until the channel reports `Closed`, then hand the receiver back so its
        // stats can be checked after both consumers have finished.
        let consume = |rx: BroadcastReceiver<u32>| {
            thread::spawn(move || {
                let mut seen = 0usize;
                loop {
                    match rx.try_recv() {
                        RecvOutcome::Data(_) => seen += 1,
                        RecvOutcome::Empty => thread::yield_now(),
                        RecvOutcome::Closed => return (seen, rx),
                    }
                }
            })
        };
        let c1 = consume(rx1);
        let c2 = consume(rx2);

        for p in producers {
            p.join().unwrap();
        }
        drop(tx);
        let (seen1, rx1) = c1.join().unwrap();
        let (seen2, rx2) = c2.join().unwrap();

        assert!(seen1 <= produced);
        assert!(seen2 <= produced);

        let stats = rx1.stats();
        assert!(stats.closed);
        assert_eq!(stats.depth, 0);
        assert_eq!(rx2.stats().depth, 0);
        // Every send reached both subscribers, and each copy was either received or evicted.
        assert_eq!(stats.enqueued, 2 * produced as u64);
        assert_eq!(stats.dropped + stats.drained, stats.enqueued);
        assert_eq!(stats.drained, (seen1 + seen2) as u64);
    }
}
426
#[cfg(all(test, feature = "metrics"))]
mod metric_tests {
    use super::*;
    use crate::metrics::InMemoryMetrics;
    use std::sync::Arc;

    /// An eviction bumps `channel.broadcast.dropped`; a send after the only receiver is
    /// gone is rejected and bumps `channel.broadcast.closed`.
    #[test]
    fn metrics_record_drops_and_closed() {
        let sink = Arc::new(InMemoryMetrics::default());
        let (tx, rx) = {
            let shared: Arc<dyn crate::metrics::MetricsSink> = sink.clone();
            broadcast_with_metrics::<u32>(1, shared)
        };

        // Capacity 1: the second send evicts the first value.
        for v in [1, 2] {
            tx.send(Arc::new(v));
        }
        assert_eq!(sink.counter("channel.broadcast.dropped"), 1);

        drop(rx);
        assert_eq!(tx.send(Arc::new(3)), Backpressure::Closed);
        assert_eq!(sink.counter("channel.broadcast.closed"), 1);
    }
}
445}