barrage/
lib.rs

//! Barrage - an asynchronous broadcast channel. Each message sent is delivered to every receiver
//! mailbox (receivers that share a mailbox through [`SharedReceiver`] split messages between them).
//! When the channel reaches its capacity, send operations block, asynchronously wait, or fail,
//! depending on which kind of send was chosen. Cloned receivers only receive messages sent after
//! they were cloned.
//!
//! # Example
//!
//! ```rust
//! let (tx, rx1) = barrage::unbounded();
//! let rx2 = rx1.clone();
//! tx.send("Hello!").unwrap();
//! let rx3 = rx1.clone();
//! assert_eq!(rx1.recv(), Ok("Hello!"));
//! assert_eq!(rx2.recv(), Ok("Hello!"));
//! assert_eq!(rx3.try_recv(), Ok(None));
//! ```
17
18use std::future::Future;
19use std::pin::Pin;
20use std::task::{Context, Poll};
21use concurrent_queue::ConcurrentQueue;
22use event_listener::{Event, EventListener};
23use std::fmt::Debug;
24use std::sync::atomic::{AtomicUsize, Ordering};
25use std::sync::Arc;
26use facade::*;
27
28mod facade;
29
30type ReceiverQueue<T> = ConcurrentQueue<Arc<T>>;
31
32struct Shared<T> {
33    receiver_queues: RwLock<Vec<Arc<ReceiverQueue<T>>>>,
34    on_final_receive: Event,
35    on_send: Event,
36    n_receivers: AtomicUsize,
37    n_senders: AtomicUsize,
38    len: AtomicUsize,
39    capacity: Option<usize>,
40}
41
/// All senders have disconnected from the channel and there are no more messages waiting.
///
/// Returned by the receive methods once the mailbox is empty and every `Sender` has been dropped.
#[derive(Copy, Clone, Eq, PartialEq, Debug)]
pub struct Disconnected;

impl std::fmt::Display for Disconnected {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("receiving on an empty channel with no senders")
    }
}

// Lets callers propagate the error with `?` into `Box<dyn Error>` / `anyhow`-style results.
impl std::error::Error for Disconnected {}
45
46/// The broadcaster side of the channel.
47pub struct Sender<T: Clone + Unpin>(Arc<Shared<T>>);
48
49impl<T: Clone + Unpin> Clone for Sender<T> {
50    fn clone(&self) -> Self {
51        self.0.n_senders.fetch_add(1, Ordering::Relaxed);
52        Sender(self.0.clone())
53    }
54}
55
56impl<T: Clone + Unpin> Drop for Sender<T> {
57    fn drop(&mut self) {
58        if self.0.n_senders.fetch_sub(1, Ordering::Release) != 1 {
59            return;
60        }
61
62        self.0.on_send.notify(usize::MAX);
63    }
64}
65
/// The error returned by [`Sender::try_send`]. Both variants hand the unsent message back.
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
pub enum TrySendError<T> {
    /// Every receiver has been dropped, so the message can never be delivered.
    Disconnected(T),
    /// The channel is bounded and already at capacity.
    Full(T),
}

impl<T> std::fmt::Display for TrySendError<T> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            TrySendError::Disconnected(_) => f.write_str("sending on a channel with no receivers"),
            TrySendError::Full(_) => f.write_str("sending on a full channel"),
        }
    }
}

// `Debug` on the payload is required by `std::error::Error`'s supertrait bound.
impl<T: std::fmt::Debug> std::error::Error for TrySendError<T> {}
71
/// The error returned by [`Sender::send`] and [`Sender::send_async`] when every receiver has been
/// dropped. The unsent message is handed back in the tuple field.
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
pub struct SendError<T>(pub T);

impl<T> std::fmt::Display for SendError<T> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("sending on a channel with no receivers")
    }
}

// `Debug` on the payload is required by `std::error::Error`'s supertrait bound.
impl<T: std::fmt::Debug> std::error::Error for SendError<T> {}
74
75impl<T: Clone + Unpin> Sender<T> {
76    /// Try to broadcast a message to all receivers. If the message cap is reached, this will fail.
77    pub fn try_send(&self, item: T) -> Result<(), TrySendError<T>> {
78        if self.0.n_receivers.load(Ordering::Acquire) == 0 {
79            return Err(TrySendError::Disconnected(item));
80        }
81
82        let shared = &self.0;
83        if shared.capacity.map(|c| c == shared.len.load(Ordering::Acquire)).unwrap_or(false) {
84            return Err(TrySendError::Full(item));
85        }
86
87        let item = Arc::new(item);
88
89        // This isn't a for loop because idx is only advanced for present queues
90        for q in lock_read(&shared.receiver_queues).iter() {
91            assert!(q.push(item.clone()).is_ok());
92        }
93
94        shared.len.fetch_add(1, Ordering::Release);
95        shared.on_send.notify(usize::MAX);
96
97        Ok(())
98    }
99
100    /// Broadcast a message to all receivers. If the message cap is reached, this will block until
101    /// the queue is no longer full.
102    pub fn send(&self, mut item: T) -> Result<(), SendError<T>> {
103        loop {
104            let event_listener = self.0.on_final_receive.listen();
105            match self.try_send(item) {
106                Ok(()) => break Ok(()),
107                Err(TrySendError::Disconnected(item)) => break Err(SendError(item)),
108                Err(TrySendError::Full(ret)) => {
109                    item = ret;
110                    event_listener.wait();
111                }
112            }
113        }
114    }
115
116    /// Broadcast a message to all receivers. If the message cap is reached, this will
117    /// asynchronously wait until the queue is no longer full.
118    pub fn send_async(&self, item: T) -> SendFut<T> {
119        SendFut {
120            item: Some(item),
121            sender: self,
122            event_listener: None,
123        }
124    }
125}
126
127/// The future representing an asynchronous broadcast operation.
128///
129/// # Panics
130///
131/// This will panic if polled after returning `Poll::Ready`.
132pub struct SendFut<'a, T: Clone + Unpin> {
133    item: Option<T>,
134    sender: &'a Sender<T>,
135    event_listener: Option<EventListener>,
136}
137
138impl<'a, T: Clone + Unpin> Future for SendFut<'a, T> {
139    type Output = Result<(), SendError<T>>;
140
141    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
142        let poll = loop {
143            let item = self.item.take().expect("cannot poll completed send future");
144            let mut listener = self.sender.0.on_final_receive.listen();
145            break match self.sender.try_send(item) {
146                Ok(()) => Poll::Ready(Ok(())),
147                Err(TrySendError::Disconnected(item)) => Poll::Ready(Err(SendError(item))),
148                Err(TrySendError::Full(ret)) => {
149                    self.item.replace(ret);
150
151                    if let Poll::Ready(_) = Pin::new(&mut listener).poll(cx) {
152                        continue;
153                    }
154
155                    self.event_listener = Some(listener);
156                    Poll::Pending
157                }
158            }
159        };
160
161        poll
162    }
163}
164
165/// The receiver side of the channel. This will receive every message broadcast.
166///
167/// If receive is called twice on the same receiver, only one receive will receive the broadcast
168/// message. If both must receive it, clone the receiver.
169pub struct Receiver<T: Clone + Unpin>(ReceiverInner<T>);
170
171struct ReceiverInner<T: Clone + Unpin> {
172    shared: Arc<Shared<T>>,
173    queue: Arc<ConcurrentQueue<Arc<T>>>,
174}
175
176impl<T: Clone + Unpin> Drop for ReceiverInner<T> {
177    fn drop(&mut self) {
178        // 2 because 1 in the receiver inner & 1 in the shared
179        if Arc::strong_count(&self.queue) > 2 {
180            return;
181        }
182
183        let mut receiver_queues = lock_write(&self.shared.receiver_queues);
184        receiver_queues.retain(|other| !Arc::ptr_eq(&self.queue, other));
185
186        if self.shared.n_receivers.fetch_sub(1, Ordering::Release) != 1 {
187            return;
188        }
189
190        self.shared.on_final_receive.notify(self.shared.n_senders.load(Ordering::Acquire));
191    }
192}
193
194
195impl<T: Clone + Unpin> Clone for Receiver<T> {
196    fn clone(&self) -> Self {
197        let queue = Arc::new(ConcurrentQueue::unbounded());
198        let mut receiver_queues = lock_write(&self.0.shared.receiver_queues);
199        receiver_queues.push(queue.clone());
200        self.0.shared.n_receivers.fetch_add(1, Ordering::Release);
201        let inner = ReceiverInner {
202            shared: self.0.shared.clone(),
203            queue,
204        };
205
206        Receiver(inner)
207    }
208}
209
210impl<T: Clone + Unpin> ReceiverInner<T> {
211    fn recv(&self) -> Result<T, Disconnected> {
212        loop {
213            let listener = self.shared.on_send.listen();
214            match self.try_recv() {
215                Ok(Some(item)) => break Ok(item),
216                Ok(None) => listener.wait(),
217                Err(_) => break Err(Disconnected),
218            }
219        }
220    }
221
222    fn try_recv(&self) -> Result<Option<T>, Disconnected> {
223        match self.queue.pop() {
224            Ok(item) => {
225                let old_len = self.shared.len.load(Ordering::SeqCst);
226                let weak = Arc::downgrade(&item);
227                let inner = (&*item).clone();
228                drop(item);
229
230                if weak.strong_count() == 0 &&
231                    self.shared.len.compare_exchange(
232                        old_len,
233                        old_len - 1,
234                        Ordering::Release,
235                        Ordering::Relaxed
236                    ).is_ok()
237                {
238                    self.shared.on_final_receive.notify_additional(1);
239                }
240
241                Ok(Some(inner))
242            },
243            Err(_) if self.shared.n_senders.load(Ordering::Acquire) > 0 => Ok(None),
244            Err(_) => Err(Disconnected),
245        }
246    }
247
248    fn recv_async(&self) -> RecvFut<T> {
249        RecvFut {
250            receiver: &self,
251            event_listener: None,
252        }
253    }
254}
255
256impl<T: Clone + Unpin> Receiver<T> {
257    /// Receive a broadcast message. If there are none in the queue, it will block until another is
258    /// sent or all senders disconnect.
259    pub fn recv(&self) -> Result<T, Disconnected> {
260        self.0.recv()
261    }
262
263    /// Try to receive a broadcast message. If there are none in the queue, it will return `None`, or
264    /// if there are no senders it will return `Disconnected`.
265    pub fn try_recv(&self) -> Result<Option<T>, Disconnected> {
266        self.0.try_recv()
267    }
268
269    /// Receive a broadcast message. If there are none in the queue, it will asynchronously wait
270    /// until another is sent or all senders disconnect.
271    pub fn recv_async(&self) -> RecvFut<T> {
272        self.0.recv_async()
273    }
274
275    /// Converts this receiver into a [shared receiver](struct.SharedReceiver.html).
276    pub fn into_shared(self) -> SharedReceiver<T> {
277        SharedReceiver(self.0)
278    }
279}
280
281/// A shared receiver is similar to a receiver, but it shares a mailbox with the other shared
282/// receivers from which it originates (was cloned from). Thus, only one shared receiver with the
283/// same mailbox will receive a broadcast.
284pub struct SharedReceiver<T: Clone + Unpin>(ReceiverInner<T>);
285
286impl<T: Clone + Unpin> Clone for SharedReceiver<T> {
287    fn clone(&self) -> Self {
288        let inner = ReceiverInner {
289            shared: self.0.shared.clone(),
290            queue: self.0.queue.clone(),
291        };
292
293        SharedReceiver(inner)
294    }
295}
296
297impl<T: Clone + Unpin> SharedReceiver<T> {
298    /// Upgrades this shared receiver into a full receiver with its own mailbox
299    pub fn upgrade(mut self) -> Receiver<T> {
300        // Duplicated clone logic to avoid dropping receiver again
301        let queue = Arc::new(ConcurrentQueue::unbounded());
302
303        {
304            let mut receiver_queues = lock_write(&self.0.shared.receiver_queues);
305            receiver_queues.push(queue.clone());
306        }
307
308        self.0.shared.n_receivers.fetch_add(1, Ordering::Release);
309        self.0.queue = queue;
310
311        Receiver(self.0)
312    }
313
314    /// Checks whether this shared receiver shares a mailbox with another.
315    pub fn same_mailbox(&self, other: &SharedReceiver<T>) -> bool {
316        Arc::ptr_eq(&self.0.queue, &other.0.queue)
317    }
318
319    /// Receive a broadcast message. If there are none in the queue, it will block until another is
320    /// sent or all senders disconnect.
321    pub fn recv(&self) -> Result<T, Disconnected> {
322        self.0.recv()
323    }
324
325    /// Try to receive a broadcast message. If there are none in the queue, it will return `None`, or
326    /// if there are no senders it will return `Disconnected`.
327    pub fn try_recv(&self) -> Result<Option<T>, Disconnected> {
328        self.0.try_recv()
329    }
330
331    /// Receive a broadcast message. If there are none in the queue, it will asynchronously wait
332    /// until another is sent or all senders disconnect.
333    pub fn recv_async(&self) -> RecvFut<T> {
334        self.0.recv_async()
335    }
336}
337
338/// The future representing an asynchronous receive operation.
339pub struct RecvFut<'a, T: Clone + Unpin> {
340    receiver: &'a ReceiverInner<T>,
341    event_listener: Option<EventListener>,
342}
343
344impl<'a, T: Clone + Unpin> Future for RecvFut<'a, T> {
345    type Output = Result<T, Disconnected>;
346
347    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
348        loop {
349            let mut listener = self.receiver.shared.on_send.listen();
350            break match self.receiver.try_recv() {
351                Ok(Some(item)) => Poll::Ready(Ok(item)),
352                Ok(None) => {
353                    if let Poll::Ready(_) = Pin::new(&mut listener).poll(cx) {
354                        continue;
355                    }
356                    self.event_listener = Some(listener);
357                    Poll::Pending
358                },
359                Err(_) => Poll::Ready(Err(Disconnected)),
360            }
361        }
362    }
363}
364
365/// Create a new channel with the given capacity. If `None` is passed, it will be unbounded.
366pub fn new<T: Clone + Unpin>(capacity: Option<usize>) -> (Sender<T>, Receiver<T>) {
367    let receiver_queue = Arc::new(ConcurrentQueue::unbounded());
368
369    let shared = Shared {
370        receiver_queues: RwLock::new(vec![receiver_queue.clone()]),
371        on_final_receive: Event::new(),
372        on_send: Event::new(),
373        n_receivers: AtomicUsize::new(1),
374        n_senders: AtomicUsize::new(1),
375        len: AtomicUsize::new(0),
376        capacity,
377    };
378    let shared = Arc::new(shared);
379    let receiver_inner = ReceiverInner {
380        shared: shared.clone(),
381        queue: receiver_queue
382    };
383
384    (Sender(shared), Receiver(receiver_inner))
385}
386
387/// Create a bounded channel of the given capacity.
388pub fn bounded<T: Clone + Unpin>(capacity: usize) -> (Sender<T>, Receiver<T>) {
389    new(Some(capacity))
390}
391
392/// Create an unbounded channel.
393pub fn unbounded<T: Clone + Unpin>() -> (Sender<T>, Receiver<T>) {
394    new(None)
395}