1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
//! Barrage - an asynchronous broadcast channel. Each message sent will be received by every receiver.
//! When the channel reaches its capacity, send operations will block, wait asynchronously, or fail
//! (depending on which send method was chosen). Cloned receivers will only receive messages sent
//! after they are cloned.
//!
//! # Example
//!
//! ```rust
//! let (tx, rx1) = barrage::unbounded();
//! let rx2 = rx1.clone();
//! tx.send("Hello!").unwrap();
//! let rx3 = rx1.clone();
//! assert_eq!(rx1.recv(), Ok("Hello!"));
//! assert_eq!(rx2.recv(), Ok("Hello!"));
//! assert_eq!(rx3.try_recv(), Ok(None));
//! ```

use std::future::Future;
use std::pin::Pin;
use std::task::{Context, Poll};
use concurrent_queue::ConcurrentQueue;
use event_listener::{Event, EventListener};
use std::fmt::Debug;
use facade::sync::atomic::{AtomicUsize, Ordering};
use facade::sync::Arc;
use facade::*;

mod facade;

/// A single mailbox. Messages are stored behind an `Arc` so that one allocation can be handed to
/// every mailbox a message is broadcast to.
type ReceiverQueue<T> = ConcurrentQueue<Arc<T>>;

/// State shared by all senders and receivers of one channel.
struct Shared<T> {
    /// Every registered mailbox; a send pushes one copy of the message into each entry.
    receiver_queues: RwLock<Vec<Arc<ReceiverQueue<T>>>>,
    /// Fired when the last outstanding copy of a message is received (freeing a slot) and when
    /// the last receiver goes away.
    on_final_receive: Event,
    /// Fired after every successful send and when the last sender goes away.
    on_send: Event,
    /// Number of registered mailboxes (shared receivers on one mailbox count once).
    n_receivers: AtomicUsize,
    /// Number of live `Sender` handles.
    n_senders: AtomicUsize,
    /// Number of messages whose last copy has not been received yet.
    len: AtomicUsize,
    /// Upper bound for `len`; `None` means the channel is unbounded.
    capacity: Option<usize>,
}

/// All senders have disconnected from the channel and there are no more messages waiting.
#[derive(Copy, Clone, Eq, PartialEq, Debug)]
pub struct Disconnected;

impl std::fmt::Display for Disconnected {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("receiving on an empty channel with no remaining senders")
    }
}

// Lets `Disconnected` flow through `?` into `Box<dyn Error>` and similar error types.
impl std::error::Error for Disconnected {}

/// The broadcasting half of the channel: each message it sends is pushed into every registered
/// receiver mailbox. Cloning it yields another sender for the same channel.
pub struct Sender<T>(Arc<Shared<T>>)
where
    T: Clone + Unpin;

impl<T: Clone + Unpin> Clone for Sender<T> {
    fn clone(&self) -> Self {
        self.0.n_senders.fetch_add(1, Ordering::Relaxed);
        Sender(self.0.clone())
    }
}

impl<T> Drop for Sender<T>
where
    T: Clone + Unpin,
{
    /// Unregisters this sender; when it was the last one, wakes every waiting receiver so it can
    /// observe the disconnection.
    fn drop(&mut self) {
        let was_last_sender = self.0.n_senders.fetch_sub(1, Ordering::Release) == 1;
        if was_last_sender {
            self.0.on_send.notify(usize::MAX);
        }
    }
}

/// Error returned by a non-blocking send. Both variants hand the unsent message back.
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
pub enum TrySendError<T> {
    /// Every receiver has been dropped, so the message cannot be delivered.
    Disconnected(T),
    /// The channel is bounded and currently at capacity.
    Full(T),
}

impl<T> TrySendError<T> {
    /// Recovers the message that could not be sent, whichever variant this is.
    pub fn into_inner(self) -> T {
        match self {
            TrySendError::Disconnected(item) | TrySendError::Full(item) => item,
        }
    }
}

impl<T> std::fmt::Display for TrySendError<T> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            TrySendError::Disconnected(_) => f.write_str("sending on a channel with no receivers"),
            TrySendError::Full(_) => f.write_str("sending on a full channel"),
        }
    }
}

// `Error` requires `Debug`, which the derive above only provides when `T: Debug`.
impl<T: std::fmt::Debug> std::error::Error for TrySendError<T> {}

/// Error returned by a blocking or asynchronous send when every receiver has been dropped.
/// The unsent message is handed back in field `0`.
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
pub struct SendError<T>(pub T);

impl<T> SendError<T> {
    /// Recovers the message that could not be sent.
    pub fn into_inner(self) -> T {
        self.0
    }
}

impl<T> std::fmt::Display for SendError<T> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("sending on a channel with no receivers")
    }
}

// `Error` requires `Debug`, which the derive above only provides when `T: Debug`.
impl<T: std::fmt::Debug> std::error::Error for SendError<T> {}

impl<T: Clone + Unpin> Sender<T> {
    /// Try to broadcast a message to all receivers without blocking.
    ///
    /// # Errors
    ///
    /// Returns [`TrySendError::Disconnected`] if every receiver has been dropped, or
    /// [`TrySendError::Full`] if the channel is bounded and already holds `capacity` messages.
    /// In both cases the message is handed back.
    pub fn try_send(&self, item: T) -> Result<(), TrySendError<T>> {
        let shared = &self.0;
        if shared.n_receivers.load(Ordering::Acquire) == 0 {
            return Err(TrySendError::Disconnected(item));
        }

        let queues = lock_read(&shared.receiver_queues);
        // The last receiver may have unregistered its mailbox after the counter check above.
        // Delivering to nobody would silently drop the message and leak a capacity slot.
        let (last, rest) = match queues.split_last() {
            Some(split) => split,
            None => return Err(TrySendError::Disconnected(item)),
        };

        // Reserve a slot *before* publishing. The compare-and-swap keeps `len` from ever exceeding
        // `capacity` when several senders race (a plain `len == capacity` check lets them
        // overshoot, after which the bound is never enforced again). It also guarantees that
        // receivers cannot decrement `len` before it has been incremented.
        match shared.capacity {
            Some(capacity) => {
                let mut len = shared.len.load(Ordering::Acquire);
                loop {
                    if len >= capacity {
                        return Err(TrySendError::Full(item));
                    }
                    match shared.len.compare_exchange(
                        len,
                        len + 1,
                        Ordering::AcqRel,
                        Ordering::Acquire,
                    ) {
                        Ok(_) => break,
                        Err(current) => len = current,
                    }
                }
            }
            None => {
                shared.len.fetch_add(1, Ordering::Release);
            }
        }

        let item = Arc::new(item);
        for queue in rest {
            let pushed = queue.push(Arc::clone(&item)).is_ok();
            assert!(pushed, "receiver queues are unbounded and never closed");
        }
        // Move the original `Arc` into the last mailbox rather than keeping it alive here.
        // Receivers detect the final copy via `Arc::strong_count == 1`. A lingering sender-side
        // reference would hide that final copy and leak a capacity slot.
        let pushed = last.push(item).is_ok();
        assert!(pushed, "receiver queues are unbounded and never closed");
        drop(queues);

        shared.on_send.notify(usize::MAX);

        Ok(())
    }

    /// Broadcast a message to all receivers, blocking the current thread while the channel is
    /// full.
    ///
    /// # Errors
    ///
    /// Returns [`SendError`] carrying the message back if every receiver has been dropped.
    pub fn send(&self, mut item: T) -> Result<(), SendError<T>> {
        loop {
            // Listen before trying, so a receive that frees space in between still wakes us.
            let event_listener = self.0.on_final_receive.listen();
            match self.try_send(item) {
                Ok(()) => break Ok(()),
                Err(TrySendError::Disconnected(item)) => break Err(SendError(item)),
                Err(TrySendError::Full(ret)) => {
                    item = ret;
                    event_listener.wait();
                }
            }
        }
    }

    /// Broadcast a message to all receivers. If the channel is full, the returned future waits
    /// asynchronously until space is available or every receiver has been dropped.
    pub fn send_async(&self, item: T) -> SendFut<T> {
        SendFut {
            item: Some(item),
            sender: self,
            event_listener: None,
        }
    }
}

/// Future returned by `Sender::send_async`. It resolves once the message has been broadcast, or
/// once it is known that every receiver has been dropped.
///
/// # Panics
///
/// Polling it again after it has returned `Poll::Ready` panics.
pub struct SendFut<'a, T>
where
    T: Clone + Unpin,
{
    /// The message still waiting to be sent; `None` once the future has completed.
    item: Option<T>,
    sender: &'a Sender<T>,
    /// Keeps the wake-up registration alive while the send is pending on a full channel.
    event_listener: Option<EventListener>,
}

impl<T> Future for SendFut<'_, T>
where
    T: Clone + Unpin,
{
    type Output = Result<(), SendError<T>>;

    /// Attempts the broadcast. On a full channel, registers for a wake-up when a slot frees and
    /// returns `Pending`.
    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        loop {
            let pending = self.item.take().expect("cannot poll completed send future");
            // Listen before trying, so a receive racing with the attempt still wakes this task.
            let mut listener = self.sender.0.on_final_receive.listen();

            let returned = match self.sender.try_send(pending) {
                Ok(()) => return Poll::Ready(Ok(())),
                Err(TrySendError::Disconnected(item)) => return Poll::Ready(Err(SendError(item))),
                Err(TrySendError::Full(item)) => item,
            };
            self.item = Some(returned);

            if Pin::new(&mut listener).poll(cx).is_pending() {
                self.event_listener = Some(listener);
                return Poll::Pending;
            }
            // The listener already fired: space may have opened up, so try again immediately.
        }
    }
}

/// The receiving half of the channel. It owns a mailbox into which every broadcast message is
/// copied.
///
/// The mailbox belongs to this one handle: if receive is called twice on it, each message goes
/// to only one of those calls. Clone the receiver if both calls must see every message.
pub struct Receiver<T>(ReceiverInner<T>)
where
    T: Clone + Unpin;

struct ReceiverInner<T: Clone + Unpin> {
    shared: Arc<Shared<T>>,
    queue: Arc<ConcurrentQueue<Arc<T>>>,
}

impl<T: Clone + Unpin> Drop for ReceiverInner<T> {
    /// Unregisters this mailbox once its last handle goes away. It releases the capacity held by
    /// messages no other mailbox still needs, and wakes blocked senders if this was the final
    /// receiver.
    fn drop(&mut self) {
        // 2 because 1 in the receiver inner & 1 in the shared list. Anything above that means
        // other shared receivers still read from this mailbox, so it must stay registered.
        if Arc::strong_count(&self.queue) > 2 {
            return;
        }

        {
            let mut receiver_queues = lock_write(&self.shared.receiver_queues);
            receiver_queues.retain(|other| !Arc::ptr_eq(&self.queue, other));
        }

        // Senders can no longer reach this mailbox. Any messages still queued here would be freed
        // along with it without ever decrementing `len`, permanently shrinking a bounded
        // channel's capacity. Release the ones whose last copy lived here.
        let mut released = 0usize;
        while let Ok(item) = self.queue.pop() {
            if Arc::strong_count(&item) == 1 {
                released += 1;
            }
        }
        if released > 0 {
            self.shared.len.fetch_sub(released, Ordering::Release);
            self.shared.on_final_receive.notify_additional(released);
        }

        if self.shared.n_receivers.fetch_sub(1, Ordering::Release) != 1 {
            return;
        }

        // Last receiver gone: wake blocked senders so they observe the disconnection.
        self.shared.on_final_receive.notify(self.shared.n_senders.load(Ordering::Acquire));
    }
}


impl<T: Clone + Unpin> Clone for Receiver<T> {
    fn clone(&self) -> Self {
        let queue = Arc::new(ConcurrentQueue::unbounded());
        let mut receiver_queues = lock_write(&self.0.shared.receiver_queues);
        receiver_queues.push(queue.clone());
        self.0.shared.n_receivers.fetch_add(1, Ordering::Release);
        let inner = ReceiverInner {
            shared: self.0.shared.clone(),
            queue,
        };

        Receiver(inner)
    }
}

impl<T: Clone + Unpin> ReceiverInner<T> {
    fn recv(&self) -> Result<T, Disconnected> {
        loop {
            let listener = self.shared.on_send.listen();
            match self.try_recv() {
                Ok(Some(item)) => break Ok(item),
                Ok(None) => listener.wait(),
                Err(_) => break Err(Disconnected),
            }
        }
    }

    fn try_recv(&self) -> Result<Option<T>, Disconnected> {
        match self.queue.pop() {
            Ok(item) => {
                if Arc::strong_count(&item) == 1 {
                    self.shared.len.fetch_sub(1, Ordering::Release);
                    self.shared.on_final_receive.notify_additional(1);
                }

                Ok(Some((&*item).clone()))
            },
            Err(_) if self.shared.n_senders.load(Ordering::Acquire) > 0 => Ok(None),
            Err(_) => Err(Disconnected),
        }
    }

    fn recv_async(&self) -> RecvFut<T> {
        RecvFut {
            receiver: &self,
            event_listener: None,
        }
    }
}

impl<T: Clone + Unpin> Receiver<T> {
    /// Receive a broadcast message. If there are none in the queue, it will block until another is
    /// sent or all senders disconnect.
    pub fn recv(&self) -> Result<T, Disconnected> {
        self.0.recv()
    }

    /// Try to receive a broadcast message. If there are none in the queue, it will return `None`, or
    /// if there are no senders it will return `Disconnected`.
    pub fn try_recv(&self) -> Result<Option<T>, Disconnected> {
        self.0.try_recv()
    }

    /// Receive a broadcast message. If there are none in the queue, it will asynchronously wait
    /// until another is sent or all senders disconnect.
    pub fn recv_async(&self) -> RecvFut<T> {
        self.0.recv_async()
    }

    /// Converts this receiver into a [shared receiver](struct.SharedReceiver.html).
    pub fn into_shared(self) -> SharedReceiver<T> {
        SharedReceiver(self.0)
    }
}

/// A receiver whose mailbox is shared with the shared receivers it was cloned from or into.
/// Any message delivered to that mailbox is received by only one of them, unlike plain
/// receivers, which each get their own copy.
pub struct SharedReceiver<T>(ReceiverInner<T>)
where
    T: Clone + Unpin;

impl<T: Clone + Unpin> Clone for SharedReceiver<T> {
    fn clone(&self) -> Self {
        let inner = ReceiverInner {
            shared: self.0.shared.clone(),
            queue: self.0.queue.clone(),
        };

        SharedReceiver(inner)
    }
}

impl<T: Clone + Unpin> SharedReceiver<T> {
    /// Upgrades this shared receiver into a full receiver with its own mailbox.
    ///
    /// If no other shared receiver uses this mailbox, it is already private to this handle and
    /// is kept, including any messages still waiting in it. Otherwise a new, empty mailbox is
    /// registered, and only messages sent from now on are delivered to it.
    pub fn upgrade(mut self) -> Receiver<T> {
        // 2 because 1 here & 1 in the shared list: nobody else reads this mailbox. Swapping it
        // out would orphan it in `receiver_queues`. It would still be counted in `n_receivers`
        // and keep filling with messages nobody drains, so a bounded channel would eventually
        // stay full for good.
        if Arc::strong_count(&self.0.queue) <= 2 {
            return Receiver(self.0);
        }

        // Same registration steps as `Receiver::clone`, applied to `self.0` in place so that no
        // `ReceiverInner` is dropped (and unregistered) along the way.
        let queue = Arc::new(ConcurrentQueue::unbounded());

        {
            let mut receiver_queues = lock_write(&self.0.shared.receiver_queues);
            receiver_queues.push(queue.clone());
        }

        self.0.shared.n_receivers.fetch_add(1, Ordering::Release);
        self.0.queue = queue;

        Receiver(self.0)
    }

    /// Checks whether this shared receiver shares a mailbox with another.
    pub fn same_mailbox(&self, other: &SharedReceiver<T>) -> bool {
        Arc::ptr_eq(&self.0.queue, &other.0.queue)
    }

    /// Receive a broadcast message. If there are none in the mailbox, it will block until another
    /// is sent or all senders disconnect.
    pub fn recv(&self) -> Result<T, Disconnected> {
        self.0.recv()
    }

    /// Try to receive a broadcast message. If there are none in the mailbox, it will return
    /// `None`, or if there are no senders it will return `Disconnected`.
    pub fn try_recv(&self) -> Result<Option<T>, Disconnected> {
        self.0.try_recv()
    }

    /// Receive a broadcast message. If there are none in the mailbox, it will asynchronously wait
    /// until another is sent or all senders disconnect.
    pub fn recv_async(&self) -> RecvFut<T> {
        self.0.recv_async()
    }
}

/// Future returned by the `recv_async` methods. It resolves to the next message, or to
/// `Disconnected` once the mailbox is empty and no senders remain.
pub struct RecvFut<'a, T>
where
    T: Clone + Unpin,
{
    receiver: &'a ReceiverInner<T>,
    /// Keeps the wake-up registration alive while waiting for a send.
    event_listener: Option<EventListener>,
}

impl<T> Future for RecvFut<'_, T>
where
    T: Clone + Unpin,
{
    type Output = Result<T, Disconnected>;

    /// Tries to receive. If the mailbox is empty, registers for a wake-up on the next send and
    /// returns `Pending`.
    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        loop {
            // Register interest first, so a send racing with `try_recv` still wakes this task.
            let mut listener = self.receiver.shared.on_send.listen();

            match self.receiver.try_recv() {
                Ok(Some(item)) => return Poll::Ready(Ok(item)),
                Err(_) => return Poll::Ready(Err(Disconnected)),
                Ok(None) => {}
            }

            if Pin::new(&mut listener).poll(cx).is_pending() {
                self.event_listener = Some(listener);
                return Poll::Pending;
            }
            // The listener already fired: a message may have arrived, so check again.
        }
    }
}

/// Create a new channel with the given capacity. If `None` is passed, it will be unbounded.
pub fn new<T: Clone + Unpin>(capacity: Option<usize>) -> (Sender<T>, Receiver<T>) {
    let receiver_queue = Arc::new(ConcurrentQueue::unbounded());

    let shared = Shared {
        receiver_queues: RwLock::new(vec![receiver_queue.clone()]),
        on_final_receive: Event::new(),
        on_send: Event::new(),
        n_receivers: AtomicUsize::new(1),
        n_senders: AtomicUsize::new(1),
        len: AtomicUsize::new(0),
        capacity,
    };
    let shared = Arc::new(shared);
    let receiver_inner = ReceiverInner {
        shared: shared.clone(),
        queue: receiver_queue
    };

    (Sender(shared), Receiver(receiver_inner))
}

/// Creates a channel limited to `capacity` messages still waiting to be received. Once it is
/// full, sends block, wait asynchronously, or fail, depending on the send method used.
pub fn bounded<T: Clone + Unpin>(capacity: usize) -> (Sender<T>, Receiver<T>) {
    let limit = Some(capacity);
    new(limit)
}

/// Creates a channel with no capacity limit, so sends never have to wait for space.
pub fn unbounded<T: Clone + Unpin>() -> (Sender<T>, Receiver<T>) {
    new::<T>(None)
}