1use std::collections::VecDeque;
2use std::sync::atomic::{AtomicBool, AtomicU64, AtomicUsize, Ordering};
3use std::sync::{Arc, Mutex, Weak};
4
5use super::{Backpressure, ChannelRecv, ChannelSend, ChannelStats, CloseBehavior, RecvOutcome};
6
7#[cfg(feature = "metrics")]
8use crate::metrics::MetricsSink;
9
10struct Subscriber<T> {
11 buffer: Mutex<VecDeque<Arc<T>>>,
12}
13
14struct BroadcastInner<T> {
15 subscribers: Mutex<Vec<Weak<Subscriber<T>>>>,
16 closed: AtomicBool,
17 senders: AtomicUsize,
18 receivers: AtomicUsize,
19 capacity: usize,
20 enqueued: AtomicU64,
21 dropped: AtomicU64,
22 drained: AtomicU64,
23 close_behavior: CloseBehavior,
24 #[cfg(feature = "metrics")]
25 metrics: Option<Arc<dyn MetricsSink>>,
26}
27
28impl<T> BroadcastInner<T> {
29 fn new(capacity: usize, close_behavior: CloseBehavior) -> Self {
30 Self {
31 subscribers: Mutex::new(Vec::new()),
32 closed: AtomicBool::new(false),
33 senders: AtomicUsize::new(1),
34 receivers: AtomicUsize::new(1),
35 capacity,
36 enqueued: AtomicU64::new(0),
37 dropped: AtomicU64::new(0),
38 drained: AtomicU64::new(0),
39 close_behavior,
40 #[cfg(feature = "metrics")]
41 metrics: None,
42 }
43 }
44
45 #[cfg(feature = "metrics")]
46 fn new_with_metrics(
47 capacity: usize,
48 close_behavior: CloseBehavior,
49 metrics: Arc<dyn MetricsSink>,
50 ) -> Self {
51 Self {
52 subscribers: Mutex::new(Vec::new()),
53 closed: AtomicBool::new(false),
54 senders: AtomicUsize::new(1),
55 receivers: AtomicUsize::new(1),
56 capacity,
57 enqueued: AtomicU64::new(0),
58 dropped: AtomicU64::new(0),
59 drained: AtomicU64::new(0),
60 close_behavior,
61 metrics: Some(metrics),
62 }
63 }
64
65 fn mark_closed(&self) {
66 self.closed.store(true, Ordering::Release);
67 }
68
69 fn try_close(&self) {
70 match self.close_behavior {
71 CloseBehavior::FailFast => {
72 if self.senders.load(Ordering::Acquire) == 0
73 || self.receivers.load(Ordering::Acquire) == 0
74 {
75 self.mark_closed();
76 }
77 }
78 CloseBehavior::DrainUntilSendersDone => {
79 if self.senders.load(Ordering::Acquire) == 0 {
80 self.mark_closed();
81 }
82 }
83 }
84 }
85
86 #[cfg(feature = "metrics")]
87 fn inc(&self, key: &'static str) {
88 if let Some(metrics) = &self.metrics {
89 metrics.increment(key, 1);
90 }
91 }
92}
93
94pub struct BroadcastSender<T> {
95 inner: Arc<BroadcastInner<T>>,
96}
97
98impl<T> Clone for BroadcastSender<T> {
99 fn clone(&self) -> Self {
100 self.inner.senders.fetch_add(1, Ordering::Relaxed);
101 Self {
102 inner: Arc::clone(&self.inner),
103 }
104 }
105}
106
107impl<T> Drop for BroadcastSender<T> {
108 fn drop(&mut self) {
109 self.inner.senders.fetch_sub(1, Ordering::Relaxed);
110 self.inner.try_close();
111 }
112}
113
114impl<T> std::fmt::Debug for BroadcastSender<T> {
115 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
116 f.debug_struct("BroadcastSender").finish_non_exhaustive()
117 }
118}
119
120pub struct BroadcastReceiver<T> {
121 inner: Arc<BroadcastInner<T>>,
122 subscriber: Arc<Subscriber<T>>,
123}
124
125impl<T> Drop for BroadcastReceiver<T> {
126 fn drop(&mut self) {
127 self.inner.receivers.fetch_sub(1, Ordering::Relaxed);
128 self.inner.try_close();
129 }
130}
131
132pub fn broadcast<T: Send + Sync>(capacity: usize) -> (BroadcastSender<T>, BroadcastReceiver<T>) {
133 assert!(capacity > 0, "capacity must be greater than zero");
134 let inner = Arc::new(BroadcastInner::new(capacity, CloseBehavior::FailFast));
135 let recv = subscribe_inner(&inner);
136 (
137 BroadcastSender {
138 inner: Arc::clone(&inner),
139 },
140 recv,
141 )
142}
143
144pub fn broadcast_with_behavior<T: Send + Sync>(
145 capacity: usize,
146 close_behavior: CloseBehavior,
147) -> (BroadcastSender<T>, BroadcastReceiver<T>) {
148 assert!(capacity > 0, "capacity must be greater than zero");
149 let inner = Arc::new(BroadcastInner::new(capacity, close_behavior));
150 let recv = subscribe_inner(&inner);
151 (
152 BroadcastSender {
153 inner: Arc::clone(&inner),
154 },
155 recv,
156 )
157}
158
/// Creates a `CloseBehavior::FailFast` broadcast channel that reports to `metrics`.
///
/// # Panics
///
/// Panics if `capacity` is zero.
#[cfg(feature = "metrics")]
pub fn broadcast_with_metrics<T: Send + Sync>(
    capacity: usize,
    metrics: Arc<dyn MetricsSink>,
) -> (BroadcastSender<T>, BroadcastReceiver<T>) {
    // Delegate instead of duplicating the construction logic.
    broadcast_with_metrics_and_behavior(capacity, CloseBehavior::FailFast, metrics)
}
178
/// Creates a broadcast channel with an explicit close policy that reports to `metrics`.
///
/// Returns the single initial sender and receiver.
///
/// # Panics
///
/// Panics if `capacity` is zero.
#[cfg(feature = "metrics")]
pub fn broadcast_with_metrics_and_behavior<T: Send + Sync>(
    capacity: usize,
    close_behavior: CloseBehavior,
    metrics: Arc<dyn MetricsSink>,
) -> (BroadcastSender<T>, BroadcastReceiver<T>) {
    assert!(capacity > 0, "capacity must be greater than zero");
    let shared = Arc::new(BroadcastInner::new_with_metrics(
        capacity,
        close_behavior,
        metrics,
    ));
    let receiver = subscribe_inner(&shared);
    let sender = BroadcastSender { inner: shared };
    (sender, receiver)
}
199
200fn subscribe_inner<T: Send + Sync>(inner: &Arc<BroadcastInner<T>>) -> BroadcastReceiver<T> {
201 let subscriber = Arc::new(Subscriber {
202 buffer: Mutex::new(VecDeque::with_capacity(inner.capacity)),
203 });
204 {
205 let mut subs = inner
206 .subscribers
207 .lock()
208 .expect("broadcast subscriber list poisoned");
209 subs.push(Arc::downgrade(&subscriber));
210 }
211 BroadcastReceiver {
212 inner: Arc::clone(inner),
213 subscriber,
214 }
215}
216
217impl<T: Send + Sync> BroadcastSender<T> {
218 pub fn subscribe(&self) -> BroadcastReceiver<T> {
219 self.inner.receivers.fetch_add(1, Ordering::Relaxed);
220 subscribe_inner(&self.inner)
221 }
222}
223
224impl<T: Send + Sync> ChannelSend<Arc<T>> for BroadcastSender<T> {
225 fn send(&self, value: Arc<T>) -> Backpressure {
226 if self.inner.closed.load(Ordering::Acquire) {
227 #[cfg(feature = "metrics")]
228 self.inner.inc("channel.broadcast.closed");
229 return Backpressure::Closed;
230 }
231
232 let mut live = 0usize;
233 let mut upgraded = Vec::new();
234 {
235 let mut subs = self
236 .inner
237 .subscribers
238 .lock()
239 .expect("broadcast subscriber list poisoned");
240 subs.retain(|weak_sub| {
241 if let Some(sub) = weak_sub.upgrade() {
242 upgraded.push(sub);
243 true
244 } else {
245 false
246 }
247 });
248 }
249
250 for sub in upgraded {
251 live += 1;
252 let mut buf = sub.buffer.lock().expect("broadcast buffer poisoned");
253 if buf.len() >= self.inner.capacity {
254 buf.pop_front();
255 #[cfg(feature = "metrics")]
256 self.inner.inc("channel.broadcast.dropped");
257 self.inner.dropped.fetch_add(1, Ordering::Relaxed);
258 }
259 buf.push_back(Arc::clone(&value));
260 self.inner.enqueued.fetch_add(1, Ordering::Relaxed);
261 }
262
263 if live == 0 {
264 self.inner.mark_closed();
265 #[cfg(feature = "metrics")]
266 self.inner.inc("channel.broadcast.closed");
267 Backpressure::Closed
268 } else {
269 Backpressure::Ok
270 }
271 }
272}
273
274impl<T: Send + Sync> ChannelRecv<Arc<T>> for BroadcastReceiver<T> {
275 fn try_recv(&self) -> RecvOutcome<Arc<T>> {
276 let mut buf = self
277 .subscriber
278 .buffer
279 .lock()
280 .expect("broadcast buffer poisoned");
281 match buf.pop_front() {
282 Some(v) => {
283 self.inner.drained.fetch_add(1, Ordering::Relaxed);
284 RecvOutcome::Data(v)
285 }
286 None if self.inner.closed.load(Ordering::Acquire) => RecvOutcome::Closed,
287 None => RecvOutcome::Empty,
288 }
289 }
290}
291
292impl<T: Send + Sync> BroadcastReceiver<T> {
293 pub fn stats(&self) -> ChannelStats {
294 ChannelStats {
295 enqueued: self.inner.enqueued.load(Ordering::Relaxed),
296 dropped: self.inner.dropped.load(Ordering::Relaxed),
297 drained: self.inner.drained.load(Ordering::Relaxed),
298 depth: self.subscriber.buffer.lock().map(|b| b.len()).unwrap_or(0),
299 closed: self.inner.closed.load(Ordering::Relaxed),
300 }
301 }
302}
303
#[cfg(test)]
mod tests {
    use super::*;
    use proptest::prelude::*;
    use std::sync::Arc;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::thread;

    /// Pops every value currently buffered for `rx`, in arrival order.
    fn drain<T: Copy + Send + Sync>(rx: &BroadcastReceiver<T>) -> Vec<T> {
        let mut out = Vec::new();
        while let RecvOutcome::Data(v) = rx.try_recv() {
            out.push(*v);
        }
        out
    }

    /// Spawns a thread that counts values received by `rx` into `seen`.
    ///
    /// The thread stops when the channel reports `Closed`, or when the buffer is empty
    /// and at least `target` values have been counted.
    fn spawn_counter(
        rx: BroadcastReceiver<u32>,
        seen: Arc<AtomicUsize>,
        target: usize,
    ) -> thread::JoinHandle<()> {
        thread::spawn(move || loop {
            match rx.try_recv() {
                RecvOutcome::Data(_) => {
                    seen.fetch_add(1, Ordering::Relaxed);
                }
                RecvOutcome::Empty => {
                    if seen.load(Ordering::Relaxed) >= target {
                        break;
                    }
                    thread::yield_now();
                }
                RecvOutcome::Closed => break,
            }
        })
    }

    #[test]
    fn broadcast_to_multiple_subscribers() {
        let (tx, rx1) = broadcast::<u32>(2);
        let rx2 = tx.subscribe();

        let payload = Arc::new(5);
        assert_eq!(tx.send(Arc::clone(&payload)), Backpressure::Ok);

        assert_eq!(rx1.try_recv(), RecvOutcome::Data(Arc::new(5)));
        assert_eq!(rx2.try_recv(), RecvOutcome::Data(Arc::new(5)));
    }

    #[test]
    fn broadcast_drops_oldest() {
        let (tx, rx) = broadcast::<u32>(1);
        assert_eq!(tx.send(Arc::new(1)), Backpressure::Ok);
        assert_eq!(tx.send(Arc::new(2)), Backpressure::Ok);
        assert_eq!(rx.stats().dropped, 1);
        assert_eq!(rx.try_recv(), RecvOutcome::Data(Arc::new(2)));
        // Only the newest value survived the eviction.
        assert_eq!(rx.try_recv(), RecvOutcome::Empty);
    }

    proptest! {
        #[test]
        fn broadcast_respects_per_subscriber_capacity(
            values in proptest::collection::vec(any::<u8>(), 2..20)
        ) {
            let (tx, rx1) = broadcast::<u8>(2);
            let rx2 = tx.subscribe();
            for v in &values {
                prop_assert_eq!(tx.send(Arc::new(*v)), Backpressure::Ok);
            }
            // With capacity 2 and drop-oldest, each subscriber keeps the last two values.
            let expected = values[values.len() - 2..].to_vec();
            prop_assert_eq!(drain(&rx1), expected.clone());
            prop_assert_eq!(drain(&rx2), expected);
        }
    }

    #[test]
    fn broadcast_mpmc_smoke() {
        const PRODUCERS: u32 = 4;
        const PER_PRODUCER: u32 = 50;
        let produced = (PRODUCERS * PER_PRODUCER) as usize;

        let (tx, rx1) = broadcast::<u32>(4);
        let rx2 = tx.subscribe();
        let tx = Arc::new(tx);
        let seen1 = Arc::new(AtomicUsize::new(0));
        let seen2 = Arc::new(AtomicUsize::new(0));

        let producers: Vec<_> = (0..PRODUCERS)
            .map(|offset| {
                let txc = Arc::clone(&tx);
                thread::spawn(move || {
                    for i in 0..PER_PRODUCER {
                        let _ = txc.send(Arc::new(i + offset * 1_000));
                    }
                })
            })
            .collect();

        let h1 = spawn_counter(rx1, Arc::clone(&seen1), produced);
        let h2 = spawn_counter(rx2, Arc::clone(&seen2), produced);

        for h in producers {
            h.join().unwrap();
        }
        // Dropping the last sender closes the channel, letting both counters exit.
        drop(tx);
        h1.join().unwrap();
        h2.join().unwrap();

        // Capacity 4 means values may be evicted, but never duplicated.
        assert!(seen1.load(Ordering::Relaxed) <= produced);
        assert!(seen2.load(Ordering::Relaxed) <= produced);
    }
}
426
#[cfg(all(test, feature = "metrics"))]
mod metric_tests {
    use super::*;
    use crate::metrics::InMemoryMetrics;
    use std::sync::Arc;

    #[test]
    fn metrics_record_drops_and_closed() {
        let metrics = Arc::new(InMemoryMetrics::default());
        let collector: Arc<dyn crate::metrics::MetricsSink> = metrics.clone();
        let (tx, rx) = broadcast_with_metrics::<u32>(1, collector);

        // Capacity 1: the second send evicts the first value.
        assert_eq!(tx.send(Arc::new(1)), Backpressure::Ok);
        assert_eq!(tx.send(Arc::new(2)), Backpressure::Ok);
        assert_eq!(metrics.counter("channel.broadcast.dropped"), 1);

        // FailFast: dropping the only receiver closes the channel, so the next send is
        // rejected and the rejection is counted exactly once.
        drop(rx);
        assert_eq!(tx.send(Arc::new(3)), Backpressure::Closed);
        assert_eq!(metrics.counter("channel.broadcast.closed"), 1);
    }
}