use parking_lot::RwLock;
use std::collections::VecDeque;
use std::sync::Arc;
use std::sync::atomic::{AtomicU64, Ordering};
/// Creates a broadcast channel and returns its sending half.
///
/// `capacity` is forwarded unchanged to [`BroadcastSender::new`]. Receivers
/// are obtained from the returned sender through its `subscribe` methods.
#[must_use]
pub fn channel<T>(capacity: usize) -> BroadcastSender<T>
where
    T: Clone + Send + Sync,
{
    BroadcastSender::<T>::new(capacity)
}
/// State shared between the sender and every receiver of one channel.
///
/// It is always accessed through the `RwLock` that wraps it.
struct BroadcastState<T> {
/// Retained messages paired with their sequence numbers; new messages are
/// pushed at the back and the oldest are evicted from the front.
buffer: VecDeque<(u64, T)>,
/// Sequence number that the next sent message will be assigned.
sequence: u64,
/// Maximum number of messages kept in `buffer` before eviction starts.
capacity: usize,
/// Set to `true` when the sender is dropped.
closed: bool,
}
/// Sending half of the broadcast channel.
///
/// Dropping the sender marks the shared state as closed.
pub struct BroadcastSender<T> {
/// Buffer and bookkeeping shared with every receiver.
state: Arc<RwLock<BroadcastState<T>>>,
/// Copy of the shared sequence counter, updated on every `send`, so that it
/// can be read without taking the lock.
sequence: AtomicU64,
}
impl<T: Clone + Send + Sync> BroadcastSender<T> {
    /// Creates a sender whose buffer retains at most `capacity` messages.
    ///
    /// A `capacity` of zero is accepted: such a channel assigns sequence
    /// numbers but never retains a message.
    #[must_use]
    pub fn new(capacity: usize) -> Self {
        Self {
            state: Arc::new(RwLock::new(BroadcastState {
                buffer: VecDeque::with_capacity(capacity),
                sequence: 0,
                capacity,
                closed: false,
            })),
            sequence: AtomicU64::new(0),
        }
    }

    /// Appends `item` to the channel and returns the sequence number it was
    /// assigned.
    ///
    /// When the buffer is full the oldest retained messages are evicted to
    /// make room. With a capacity of zero nothing is stored, but the sequence
    /// counter still advances.
    pub fn send(&self, item: T) -> u64 {
        let mut state = self.state.write();
        let seq = state.sequence;
        state.sequence += 1;
        // Guard against capacity 0: without it the eviction loop below would
        // spin forever on an empty buffer while holding the write lock.
        if state.capacity > 0 {
            while state.buffer.len() >= state.capacity {
                state.buffer.pop_front();
            }
            state.buffer.push_back((seq, item));
        }
        // Publish the new counter for the lock-free readers (`subscribe`,
        // `sequence`) while the write lock is still held.
        self.sequence.store(seq + 1, Ordering::Release);
        seq
    }

    /// Returns a receiver that will only see messages sent after this call.
    #[must_use]
    pub fn subscribe(&self) -> BroadcastReceiver<T> {
        let current_seq = self.sequence.load(Ordering::Acquire);
        BroadcastReceiver {
            state: Arc::clone(&self.state),
            next_seq: current_seq,
        }
    }

    /// Returns a receiver positioned at the oldest message still retained.
    ///
    /// If nothing is retained the receiver starts at the current sequence
    /// number, i.e. it behaves like [`subscribe`](Self::subscribe).
    #[must_use]
    pub fn subscribe_from_start(&self) -> BroadcastReceiver<T> {
        let state = self.state.read();
        let start_seq = state
            .buffer
            .front()
            .map_or(state.sequence, |(seq, _)| *seq);
        drop(state);
        BroadcastReceiver {
            state: Arc::clone(&self.state),
            next_seq: start_seq,
        }
    }

    /// Returns the sequence number the next sent message will receive, which
    /// equals the number of messages sent so far.
    #[must_use]
    pub fn sequence(&self) -> u64 {
        self.sequence.load(Ordering::Acquire)
    }

    /// Returns the number of messages currently retained in the buffer.
    #[must_use]
    pub fn len(&self) -> usize {
        self.state.read().buffer.len()
    }

    /// Returns `true` when no messages are currently retained.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }
}
impl<T> Drop for BroadcastSender<T> {
    /// Flags the shared state as closed so that receivers can observe that
    /// the sender has gone away.
    fn drop(&mut self) {
        let mut shared = self.state.write();
        shared.closed = true;
    }
}
/// Receiving half of the broadcast channel.
///
/// Each receiver keeps its own read position, so receivers consume the
/// shared buffer independently of one another.
pub struct BroadcastReceiver<T> {
/// Buffer and bookkeeping shared with the sender and the other receivers.
state: Arc<RwLock<BroadcastState<T>>>,
/// Sequence number of the next message this receiver expects to read.
next_seq: u64,
}
impl<T: Clone> BroadcastReceiver<T> {
    /// Returns a clone of the next unread message together with its sequence
    /// number, or `None` when there is nothing new to read.
    ///
    /// If this receiver has fallen behind and the message it was waiting for
    /// has already been evicted, delivery resumes at the oldest message still
    /// retained instead of stalling; the gap is visible to the caller through
    /// the returned sequence number.
    pub fn recv(&mut self) -> Option<(u64, T)> {
        let state = self.state.read();
        let wanted = self.next_seq;
        // The buffer is ordered by sequence number, so the first entry with
        // `seq >= wanted` can be located by binary search.
        let idx = state.buffer.partition_point(|(seq, _)| *seq < wanted);
        let (seq, item) = state.buffer.get(idx)?;
        let delivered = (*seq, item.clone());
        self.next_seq = *seq + 1;
        Some(delivered)
    }

    /// Returns clones of every retained message this receiver has not read
    /// yet, oldest first, and advances the read position past them.
    ///
    /// The result is empty when there is nothing new.
    pub fn recv_all(&mut self) -> Vec<(u64, T)> {
        let state = self.state.read();
        let wanted = self.next_seq;
        // `partition_point` never exceeds `len()`, so the range is always valid.
        let start = state.buffer.partition_point(|(seq, _)| *seq < wanted);
        let pending: Vec<(u64, T)> = state.buffer.range(start..).cloned().collect();
        if let Some((last_seq, _)) = pending.last() {
            self.next_seq = *last_seq + 1;
        }
        pending
    }

    /// Returns `false` once the shared state has been marked closed, which
    /// happens when the sender is dropped.
    #[must_use]
    pub fn is_connected(&self) -> bool {
        !self.state.read().closed
    }

    /// Returns how many sequence numbers have been assigned at or after this
    /// receiver's read position, including messages that may already have
    /// been evicted from the buffer.
    #[must_use]
    pub fn lag(&self) -> u64 {
        let state = self.state.read();
        state.sequence.saturating_sub(self.next_seq)
    }

    /// Returns the sequence number this receiver expects to read next.
    #[must_use]
    pub fn next_sequence(&self) -> u64 {
        self.next_seq
    }
}
impl<T: Clone> Clone for BroadcastReceiver<T> {
fn clone(&self) -> Self {
Self {
state: Arc::clone(&self.state),
next_seq: self.next_seq,
}
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Every receiver subscribed before a send observes the same message.
    #[test]
    fn test_basic_broadcast() {
        let sender = channel::<u64>(16);
        let mut first = sender.subscribe();
        let mut second = sender.subscribe();
        sender.send(42);
        let expected = Some((0, 42));
        assert_eq!(first.recv(), expected);
        assert_eq!(second.recv(), expected);
    }

    /// Messages are delivered in send order, followed by `None`.
    #[test]
    fn test_multiple_messages() {
        let sender = channel::<u64>(16);
        let mut receiver = sender.subscribe();
        for value in 1..=3 {
            sender.send(value);
        }
        for &pair in &[(0, 1), (1, 2), (2, 3)] {
            assert_eq!(receiver.recv(), Some(pair));
        }
        assert_eq!(receiver.recv(), None);
    }

    /// A receiver created after some sends only sees later messages.
    #[test]
    fn test_late_subscriber() {
        let sender = channel::<u64>(16);
        sender.send(1);
        sender.send(2);
        let mut late = sender.subscribe();
        sender.send(3);
        let received = late.recv();
        assert_eq!(received, Some((2, 3)));
    }

    /// `subscribe_from_start` replays the messages that are still retained.
    #[test]
    fn test_subscribe_from_start() {
        let sender = channel::<u64>(16);
        for value in 1..=2 {
            sender.send(value);
        }
        let mut replay = sender.subscribe_from_start();
        assert_eq!(replay.recv(), Some((0, 1)));
        assert_eq!(replay.recv(), Some((1, 2)));
    }

    /// `recv_all` drains every unread message in one call.
    #[test]
    fn test_recv_all() {
        let sender = channel::<u64>(16);
        let mut receiver = sender.subscribe();
        for value in 1..=3 {
            sender.send(value);
        }
        assert_eq!(receiver.recv_all(), vec![(0, 1), (1, 2), (2, 3)]);
    }

    /// Once capacity is exceeded the oldest message is no longer delivered.
    #[test]
    fn test_capacity_overflow() {
        let sender = channel::<u64>(3);
        let mut receiver = sender.subscribe_from_start();
        for value in 1..=4 {
            sender.send(value);
        }
        let survivors = receiver.recv_all();
        assert_eq!(survivors, vec![(1, 2), (2, 3), (3, 4)]);
    }

    /// `lag` counts unread messages and shrinks as they are consumed.
    #[test]
    fn test_lag() {
        let sender = channel::<u64>(16);
        let mut receiver = sender.subscribe();
        for value in 1..=3 {
            sender.send(value);
        }
        let pending_before = receiver.lag();
        assert_eq!(pending_before, 3);
        let _ = receiver.recv();
        let pending_after = receiver.lag();
        assert_eq!(pending_after, 2);
    }

    /// Dropping the sender is visible to receivers through `is_connected`.
    #[test]
    fn test_disconnect() {
        let sender = channel::<u64>(16);
        let receiver = sender.subscribe();
        assert!(receiver.is_connected());
        drop(sender);
        assert!(!receiver.is_connected());
    }
}