#![warn(missing_docs)]
extern crate uuid;
use std::sync::{mpsc, Arc, Mutex};
use std::collections::HashMap;
/// A publish/subscribe hub that broadcasts values of type `T`.
///
/// Each value passed to `send` is cloned once for every subscription that is
/// registered at that moment. The derived `Clone` copies only the `Arc`, so
/// all clones of a `PubSub` publish to the same set of subscribers.
#[derive(Clone)]
pub struct PubSub<T: Clone> {
// Registry shared with every `Subscription`: maps a subscription's id to the
// sending half of that subscription's channel.
senders: Arc<Mutex<HashMap<uuid::Uuid, mpsc::Sender<T>>>>,
}
/// The receiving side of a `PubSub`, created by `PubSub::subscribe`.
///
/// A subscription owns its own channel, so every subscription sees every
/// value published while it is registered. Dropping the subscription removes
/// its entry from the shared registry.
pub struct Subscription<T: Clone> {
// Receiving half of this subscription's private channel.
receiver: mpsc::Receiver<T>,
// Handle to the registry shared with the `PubSub`; used to unregister on
// drop and to register a new subscription on clone.
senders: Arc<Mutex<HashMap<uuid::Uuid, mpsc::Sender<T>>>>,
// Key under which this subscription's sender is stored in `senders`.
id: uuid::Uuid,
}
impl<T: Clone> PubSub<T> {
pub fn new() -> PubSub<T> {
PubSub { senders: Arc::new(Mutex::new(HashMap::new())) }
}
pub fn send(&self, it: T) -> Result<(), mpsc::SendError<T>> {
let senders = self.senders.lock().unwrap();
for (_, sender) in senders.iter() {
match sender.send(it.clone()) {
Ok(_) => {}
Err(err) => return Err(err),
}
}
Ok(())
}
pub fn subscribe(&self) -> Subscription<T> {
let id = uuid::Uuid::new_v4();
let (send, recv) = mpsc::channel();
{
let mut senders = self.senders.lock().unwrap();
senders.insert(id, send);
}
Subscription {
receiver: recv,
senders: self.senders.clone(),
id: id,
}
}
}
impl<T: Clone> Subscription<T> {
/// Waits for the next published value.
///
/// Delegates to `mpsc::Receiver::recv`, which blocks the calling thread until
/// a value is available and returns `Err` only once the channel is
/// disconnected.
///
/// NOTE(review): this subscription's sender is stored in the registry that
/// the subscription itself keeps alive through its `senders` handle, so
/// dropping every `PubSub` does not appear to disconnect the channel --
/// confirm whether callers rely on `Err` to detect shutdown.
pub fn recv(&self) -> Result<T, mpsc::RecvError> {
self.receiver.recv()
}
/// Returns a pending value without blocking.
///
/// Delegates to `mpsc::Receiver::try_recv`: yields the next queued value, or
/// a `TryRecvError` when nothing is queued at the moment.
pub fn try_recv(&self) -> Result<T, mpsc::TryRecvError> {
self.receiver.try_recv()
}
/// Returns a blocking iterator over incoming values.
///
/// Delegates to `mpsc::Receiver::iter`: each step waits for the next value
/// and the iterator ends when the channel is disconnected.
///
/// NOTE(review): for the reason described on `recv`, the channel does not
/// appear to disconnect while this subscription exists, so the iterator may
/// never end on its own -- confirm.
pub fn iter(&self) -> mpsc::Iter<T> {
self.receiver.iter()
}
}
impl<T: Clone> Drop for Subscription<T> {
    /// Unregisters this subscription so that `PubSub::send` stops delivering
    /// values to it.
    fn drop(&mut self) {
        // Never panic here: if the lock was poisoned by a panic on this very
        // thread, this destructor runs during unwinding and a second panic
        // would abort the process. The registry is only modified by single
        // `insert`/`remove` calls, so a poisoned map is still consistent and
        // the entry can be removed as usual.
        self.senders
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner())
            .remove(&self.id);
    }
}
impl<T: Clone> Clone for Subscription<T> {
fn clone(&self) -> Self {
PubSub { senders: self.senders.clone() }.subscribe()
}
}
#[cfg(test)]
mod tests {
    use std;
    use super::*;

    // Every subscription must receive every published value, regardless of
    // how many subscriptions exist.
    #[test]
    fn many_senders() {
        use std::sync::atomic::{AtomicUsize, Ordering};

        let send = PubSub::new();
        let recv = send.subscribe();
        let threads = 5;
        let pulses = 50;
        let received = std::sync::Arc::new(AtomicUsize::new(0));

        // All subscriptions are registered on this thread before anything is
        // published, so each worker is guaranteed to see all `pulses` values.
        let workers: Vec<_> = (0..threads)
            .map(|_| {
                let recv = recv.clone();
                let received = received.clone();
                std::thread::spawn(move || {
                    // Receive a fixed number of values instead of looping
                    // until an error: the channel stays connected for as long
                    // as the subscription exists, so an open-ended loop would
                    // never finish and the thread could not be joined.
                    for _ in 0..pulses {
                        recv.recv().expect("subscription channel disconnected");
                        received.fetch_add(1, Ordering::AcqRel);
                    }
                })
            })
            .collect();

        let mut accum = 0;
        for _ in 0..pulses {
            accum += 1;
            send.send(accum).unwrap();
        }

        // Joining replaces the former fixed sleep: the assertion no longer
        // depends on scheduling speed, and worker panics fail the test.
        for worker in workers {
            worker.join().expect("receiver thread panicked");
        }
        assert_eq!(received.load(Ordering::Acquire), threads * pulses);
    }
}