extern crate rand;
use std::sync::{Arc, Mutex, Condvar};
use std::cmp::Ordering;
use std::collections::HashMap;
use std::collections::HashSet;
use std::collections::VecDeque;
use std::collections::BinaryHeap;
use std::sync::mpsc;
use std::thread;
use std::sync;
// Silent stand-in for a logging macro. The format string is discarded and
// nothing is printed; any further arguments are still evaluated so that the
// call sites keep type-checking and do not trigger unused-variable warnings.
macro_rules! debug {
    // Only a format string: expands to nothing.
    ($template:expr) => {};
    // Format string plus arguments: bind each argument to `_`, which
    // evaluates it without moving out of place expressions.
    ($template:expr, $($arg:expr),+) => {
        $( let _ = $arg; )+
    };
}
/// Envelope for one data/timestamp update travelling to the dispatcher
/// (carried inside `Message::Data`).
struct TaggedData<T> {
/// Name of the sender or broadcaster that produced the message.
from: String,
/// Name of the receiver the message is addressed to; `None` for a broadcast.
to: Option<String>,
/// Explicit timestamp for forwarded messages; `None` when the dispatcher
/// is expected to assign the next timestamp itself.
ts: Option<usize>,
/// The value to deliver; `None` makes this a timestamp-only update.
data: Option<T>,
}
/// Everything that can be sent over the channel to the dispatcher thread.
enum Message<T> {
/// A payload and/or timestamp update from a sender or broadcaster.
Data(TaggedData<T>),
/// A receiver with the given name was created; carries the queue state it
/// shares with the dispatcher.
ReceiverJoin(String, Arc<ReceiverInner<T>>),
/// The receiver with the given name has gone away.
ReceiverLeave(String),
/// A sender joined. First field: name of the targeted receiver, or `None`
/// for a broadcaster. Second field: the sender's own name.
SenderJoin(Option<String>, String),
/// A sender left. Fields have the same meaning as for `SenderJoin`.
SenderLeave(Option<String>, String),
}
/// Sending half of a clocked channel, addressed to a single named receiver.
pub struct ClockedSender<T> {
/// Name of the receiver that messages from this sender are addressed to.
target: String,
/// Name under which this sender is known to the dispatcher.
source: String,
/// Channel into the dispatcher thread.
dispatcher: mpsc::SyncSender<Message<T>>,
}
impl<T> Drop for ClockedSender<T> {
    /// Tells the dispatcher that this sender (identified by its target and
    /// source names) is going away.
    ///
    /// A destructor must not fail: if the dispatcher end of the channel has
    /// already been dropped there is nobody left to inform, so the send error
    /// is deliberately ignored instead of panicking (a panic here while the
    /// thread is already unwinding would abort the process).
    fn drop(&mut self) {
        let farewell = Message::SenderLeave(Some(self.target.clone()), self.source.clone());
        let _ = self.dispatcher.send(farewell);
    }
}
impl<T> ClockedSender<T> {
    /// Sends `data` to this sender's target receiver without a timestamp,
    /// leaving it to the dispatcher to assign one.
    ///
    /// # Panics
    ///
    /// Panics if the message cannot be handed to the dispatcher channel.
    pub fn send(&self, data: T) {
        self.dispatch_data(None, Some(data))
    }

    /// Forwards an update carrying the explicit timestamp `ts` to this
    /// sender's target receiver. `data` may be `None` to pass on the
    /// timestamp alone.
    ///
    /// # Panics
    ///
    /// Panics if the message cannot be handed to the dispatcher channel.
    pub fn forward(&self, data: Option<T>, ts: usize) {
        self.dispatch_data(Some(ts), data)
    }

    /// Creates another sender for the same target, registered with the
    /// dispatcher under the name `source`.
    ///
    /// # Panics
    ///
    /// Panics if the join message cannot be handed to the dispatcher channel.
    pub fn clone<V: Into<String>>(&self, source: V) -> ClockedSender<T> {
        let name: String = source.into();
        // Register the new sender before handing it out.
        let join = Message::SenderJoin(Some(self.target.clone()), name.clone());
        self.dispatcher.send(join).unwrap();
        ClockedSender {
            dispatcher: self.dispatcher.clone(),
            target: self.target.clone(),
            source: name,
        }
    }

    /// Wraps `data`/`ts` in an envelope addressed to this sender's target and
    /// passes it to the dispatcher.
    fn dispatch_data(&self, ts: Option<usize>, data: Option<T>) {
        let envelope = TaggedData {
            from: self.source.clone(),
            to: Some(self.target.clone()),
            ts: ts,
            data: data,
        };
        self.dispatcher.send(Message::Data(envelope)).unwrap()
    }
}
impl<T: Clone> ClockedSender<T> {
    /// Converts this sender into a broadcaster.
    ///
    /// The broadcaster is registered with the dispatcher under this sender's
    /// name with a `_bcast` suffix. `self` is only dropped when the function
    /// returns, i.e. after the broadcaster's join message has been sent.
    ///
    /// # Panics
    ///
    /// Panics if the join message cannot be handed to the dispatcher channel.
    pub fn into_broadcaster(self) -> ClockedBroadcaster<T> {
        let bcast_name = format!("{}_bcast", self.source);
        let link = self.dispatcher.clone();
        link.send(Message::SenderJoin(None, bcast_name.clone())).unwrap();
        ClockedBroadcaster {
            dispatcher: link,
            source: bcast_name,
        }
    }
}
/// Sending handle whose messages are addressed to every receiver rather than
/// to a single named one.
pub struct ClockedBroadcaster<T: Clone> {
/// Name under which this broadcaster is known to the dispatcher.
source: String,
/// Channel into the dispatcher thread.
dispatcher: mpsc::SyncSender<Message<T>>,
}
impl<T: Clone> Drop for ClockedBroadcaster<T> {
    /// Tells the dispatcher that this broadcaster is going away.
    ///
    /// A destructor must not fail: if the dispatcher end of the channel has
    /// already been dropped there is nobody left to inform, so the send error
    /// is deliberately ignored instead of panicking (a panic here while the
    /// thread is already unwinding would abort the process).
    fn drop(&mut self) {
        let farewell = Message::SenderLeave(None, self.source.clone());
        let _ = self.dispatcher.send(farewell);
    }
}
impl<T: Clone> ClockedBroadcaster<T> {
    /// Broadcasts `data` without a timestamp, leaving it to the dispatcher to
    /// assign one.
    ///
    /// # Panics
    ///
    /// Panics if the message cannot be handed to the dispatcher channel.
    pub fn broadcast(&self, data: T) {
        self.dispatch_data(None, Some(data))
    }

    /// Broadcasts an update carrying the explicit timestamp `ts`. `data` may
    /// be `None` to pass on the timestamp alone.
    ///
    /// # Panics
    ///
    /// Panics if the message cannot be handed to the dispatcher channel.
    pub fn broadcast_forward(&self, data: Option<T>, ts: usize) {
        self.dispatch_data(Some(ts), data)
    }

    /// Creates another broadcaster on the same dispatcher, registered under
    /// the name `source`.
    ///
    /// # Panics
    ///
    /// Panics if the join message cannot be handed to the dispatcher channel.
    pub fn clone<V: Into<String>>(&self, source: V) -> ClockedBroadcaster<T> {
        let name: String = source.into();
        // Register the new broadcaster before handing it out.
        let join = Message::SenderJoin(None, name.clone());
        self.dispatcher.send(join).unwrap();
        ClockedBroadcaster {
            dispatcher: self.dispatcher.clone(),
            source: name,
        }
    }

    /// Wraps `data`/`ts` in an envelope with no specific destination and
    /// passes it to the dispatcher.
    fn dispatch_data(&self, ts: Option<usize>, data: Option<T>) {
        let envelope = TaggedData {
            from: self.source.clone(),
            to: None,
            ts: ts,
            data: data,
        };
        self.dispatcher.send(Message::Data(envelope)).unwrap()
    }
}
/// Mutable state of one receiver's queue, protected by the mutex in
/// `ReceiverInner`.
struct QueueState<T> {
/// Pending items together with their timestamps, oldest first.
queue: VecDeque<(T, usize)>,
/// Timestamp most recently handed out to the receiver.
ts_head: usize,
/// Timestamp most recently announced by the dispatcher.
ts_tail: usize,
/// Set once no further updates will arrive; lets a drained receiver report
/// disconnection instead of waiting.
closed: bool,
/// Set when the receiver itself has gone away, so the dispatcher stops
/// queueing data for it.
left: bool,
}
/// Queue state shared (via `Arc`) between a receiver and the dispatcher.
struct ReceiverInner<T> {
/// The queue and its timestamps/flags.
mx: Mutex<QueueState<T>>,
/// Signalled whenever the state behind `mx` changes.
cond: Condvar,
}
/// Receiving half of a clocked channel; yields `(Option<T>, usize)` pairs of
/// optional data and timestamp.
pub struct ClockedReceiver<T: Send + 'static> {
/// Channel on which this receiver's name is sent when it is dropped.
leave: mpsc::SyncSender<String>,
/// Queue state shared with the dispatcher.
inner: Arc<ReceiverInner<T>>,
/// Name under which this receiver is known to the dispatcher.
name: String,
}
impl<T: Send + 'static> ClockedReceiver<T> {
    /// Builds a receiver called `name` with an empty, open queue.
    ///
    /// `leave` is the channel used to announce the receiver's departure when
    /// it is dropped; `bound` is used as the initial capacity of the queue.
    /// Both timestamps start at 0.
    fn new<V: Into<String>>(name: V,
                            leave: mpsc::SyncSender<String>,
                            bound: usize)
                            -> ClockedReceiver<T> {
        let initial = QueueState {
            queue: VecDeque::with_capacity(bound),
            ts_head: 0,
            ts_tail: 0,
            closed: false,
            left: false,
        };
        let shared = ReceiverInner {
            mx: Mutex::new(initial),
            cond: Condvar::new(),
        };
        ClockedReceiver {
            leave: leave,
            inner: Arc::new(shared),
            name: name.into(),
        }
    }
}
impl<T: Send + 'static> Iterator for ClockedReceiver<T> {
    type Item = (Option<T>, usize);

    /// Blocks in `recv` for the next update; iteration ends as soon as
    /// `recv` reports an error.
    fn next(&mut self) -> Option<Self::Item> {
        match self.recv() {
            Ok(update) => Some(update),
            Err(_) => None,
        }
    }
}
impl<T: Send + 'static> Drop for ClockedReceiver<T> {
    /// Announces this receiver's departure on the `leave` channel and then
    /// drains whatever is still queued for it.
    ///
    /// A destructor must not fail: if the other end of the `leave` channel is
    /// gone the error is ignored instead of panicking (a panic here while the
    /// thread is already unwinding would abort the process). In that case the
    /// drain is skipped as well, because it blocks until the queue is marked
    /// closed and nothing was told to do that.
    fn drop(&mut self) {
        use std::mem;
        // Take the name out of `self` so it can be sent without cloning.
        let name = mem::replace(&mut self.name, String::new());
        if self.leave.send(name).is_ok() {
            // Pull updates until `recv` reports that the queue is closed.
            while self.recv().is_ok() {}
        }
    }
}
impl<T: Send + 'static> ClockedReceiver<T> {
    /// Blocks until the next update is available.
    ///
    /// Returns `Ok((Some(value), ts))` for a queued item, or `Ok((None, ts))`
    /// when only the timestamp has advanced. Queued items are always
    /// delivered before a bare timestamp advance is reported.
    ///
    /// # Errors
    ///
    /// Returns `RecvError` once the queue is closed, empty, and the latest
    /// timestamp has already been reported.
    pub fn recv(&self) -> Result<(Option<T>, usize), mpsc::RecvError> {
        let mut state = self.inner.mx.lock().unwrap();
        // Nothing queued, no newer timestamp, and still open: wait.
        while state.ts_head == state.ts_tail && state.queue.is_empty() && !state.closed {
            state = self.inner.cond.wait(state).unwrap();
        }
        if let Some((t, ts)) = state.queue.pop_front() {
            state.ts_head = ts;
            // A slot was freed; wake whoever may be waiting on the queue.
            self.inner.cond.notify_one();
            return Ok((Some(t), ts));
        }
        if state.ts_head == state.ts_tail {
            // The wait loop can only have ended here because of `closed`.
            assert_eq!(state.closed, true);
            return Err(mpsc::RecvError);
        }
        // No data, but the dispatcher has announced a newer timestamp.
        state.ts_head = state.ts_tail;
        self.inner.cond.notify_one();
        Ok((None, state.ts_head))
    }

    /// Non-blocking variant of `recv`.
    ///
    /// Follows the same order as `recv`: a queued item is returned first,
    /// even when its timestamp equals the one reported last; then a bare
    /// timestamp advance; only then an error.
    ///
    /// # Errors
    ///
    /// Returns `TryRecvError::Empty` when nothing is pending and the queue is
    /// still open, and `TryRecvError::Disconnected` when nothing is pending
    /// and the queue is closed.
    pub fn try_recv(&self) -> Result<(Option<T>, usize), mpsc::TryRecvError> {
        let mut state = self.inner.mx.lock().unwrap();
        // Check the queue before comparing timestamps: an item queued with
        // the current timestamp leaves `ts_head == ts_tail` yet must still be
        // delivered.
        if let Some((t, ts)) = state.queue.pop_front() {
            state.ts_head = ts;
            self.inner.cond.notify_one();
            return Ok((Some(t), ts));
        }
        if state.ts_head == state.ts_tail {
            return Err(if state.closed {
                mpsc::TryRecvError::Disconnected
            } else {
                mpsc::TryRecvError::Empty
            });
        }
        // No data, but the dispatcher has announced a newer timestamp.
        state.ts_head = state.ts_tail;
        self.inner.cond.notify_one();
        Ok((None, state.ts_head))
    }
}
/// Handle from which sender/receiver pairs are created.
pub struct Dispatcher<T: Send> {
/// Channel into the dispatcher thread; cloned into every sender.
dispatcher: mpsc::SyncSender<Message<T>>,
/// Channel on which receivers announce that they have been dropped; cloned
/// into every receiver.
leave: mpsc::SyncSender<String>,
/// Queue bound passed on to every receiver that is created.
bound: usize,
}
impl<T: Send> Dispatcher<T> {
    /// Creates a connected sender/receiver pair.
    ///
    /// `sender` names the sending side and `receiver` the receiving side. The
    /// receiver is registered with the dispatcher first, then the sender is
    /// registered as attached to it.
    ///
    /// # Panics
    ///
    /// Panics if either registration message cannot be handed to the
    /// dispatcher channel.
    pub fn new<S1: Into<String>, S2: Into<String>>(&self,
                                                   sender: S1,
                                                   receiver: S2)
                                                   -> (ClockedSender<T>, ClockedReceiver<T>) {
        let sender_name: String = sender.into();
        let receiver_name: String = receiver.into();

        let tx = ClockedSender {
            dispatcher: self.dispatcher.clone(),
            target: receiver_name.clone(),
            source: sender_name.clone(),
        };
        let rx = ClockedReceiver::new(receiver_name.clone(), self.leave.clone(), self.bound);

        // The receiver must be known to the dispatcher before a sender can be
        // attached to it.
        let receiver_join = Message::ReceiverJoin(receiver_name.clone(), rx.inner.clone());
        self.dispatcher.send(receiver_join).unwrap();
        let sender_join = Message::SenderJoin(Some(receiver_name), sender_name);
        self.dispatcher.send(sender_join).unwrap();

        (tx, rx)
    }
}
/// A payload waiting for its timestamp to become the next one in sequence.
///
/// Equality and ordering look at `ts` only, and the ordering is reversed so
/// that a `BinaryHeap<Delayed<T>>` (a max-heap) yields the entry with the
/// smallest timestamp first.
struct Delayed<T> {
    ts: usize,
    data: T,
}

impl<T> PartialEq for Delayed<T> {
    /// Two entries are equal when their timestamps are; payloads are ignored.
    fn eq(&self, rhs: &Delayed<T>) -> bool {
        self.ts == rhs.ts
    }
}

impl<T> Eq for Delayed<T> {}

impl<T> PartialOrd for Delayed<T> {
    fn partial_cmp(&self, rhs: &Delayed<T>) -> Option<Ordering> {
        Some(self.cmp(rhs))
    }
}

impl<T> Ord for Delayed<T> {
    /// Reverse timestamp order: the smaller timestamp compares as greater.
    fn cmp(&self, rhs: &Delayed<T>) -> Ordering {
        self.ts.cmp(&rhs.ts).reverse()
    }
}
/// Dispatcher-side bookkeeping for one registered receiver.
struct Target<T> {
/// Queue state shared with the receiver itself.
channel: Arc<ReceiverInner<T>>,
/// Messages addressed to this receiver whose timestamp is not yet next in
/// sequence.
delayed: sync::Mutex<BinaryHeap<Delayed<T>>>,
/// Names of the senders currently attached to this receiver.
senders: HashSet<String>,
}
/// State owned by the dispatcher thread.
struct DispatchInner<T> {
/// Registered receivers by name. Behind an `Arc<RwLock<..>>` because the
/// control thread reads it as well.
targets: sync::Arc<sync::RwLock<HashMap<String, Target<T>>>>,
/// Names of the currently registered receivers.
destinations: HashSet<String>,
/// Names of the currently registered broadcasters.
broadcasters: HashSet<String>,
/// Broadcast messages whose timestamp is not yet next in sequence.
bdelay: BinaryHeap<Delayed<T>>,
/// `None` until the first data message; afterwards `Some(true)` if senders
/// supply timestamps and `Some(false)` if the dispatcher assigns them.
forwarding: Option<bool>,
/// Maximum number of items queued per receiver before the dispatcher waits.
bound: usize,
/// Label identifying this dispatcher in debug output.
id: String,
/// Most recently issued timestamp.
counter: usize,
}
impl<T: Clone> DispatchInner<T> {
    /// Announces timestamp `ts` to every registered receiver and queues
    /// `data` for the receivers it is addressed to.
    ///
    /// `to == None` addresses all receivers; `data == None` makes this a
    /// timestamp-only update. For an addressed receiver the call waits while
    /// its queue holds `bound` items. A receiver that has been marked as
    /// `left` is skipped entirely (neither data nor timestamp is recorded).
    fn notify(&self, to: Option<&String>, ts: usize, data: Option<T>) {
        let tgts = self.targets.read().unwrap();
        for (tn, t) in tgts.iter() {
            let mut state = t.channel.mx.lock().unwrap();
            debug!("{}: notifying {} about {}", self.id, tn, ts);
            let addressed = to.map_or(true, |name| name == tn);
            if data.is_some() && addressed {
                debug!("{}: including data", self.id);
                // Respect the queue bound, but stop waiting as soon as the
                // receiver is known to have left.
                while state.queue.len() == self.bound && !state.left {
                    state = t.channel.cond.wait(state).unwrap();
                }
                if state.left {
                    t.channel.cond.notify_one();
                    continue;
                }
                state.queue.push_back((data.clone().unwrap(), ts));
            }
            state.ts_tail = ts;
            t.channel.cond.notify_one();
            drop(state);
        }
    }

    /// Replays delayed messages for as long as one of them carries the next
    /// timestamp in sequence (`counter + 1`). Per-receiver heaps are
    /// consulted before the broadcast heap.
    ///
    /// # Panics
    ///
    /// Panics if the dispatcher is not in forwarding mode.
    fn process_delayed(&mut self) {
        assert!(self.forwarding.unwrap_or(false));
        debug!("{}: processing delayed after {}", self.id, self.counter);
        loop {
            let next = self.bdelay.peek().map(|d| d.ts);
            debug!("{}: next from bcast is {:?}", self.id, next);
            // Smallest pending timestamp across all per-receiver heaps,
            // together with the name of the receiver that owns it.
            let tnext = {
                let tgts = self.targets.read().unwrap();
                let t = self.destinations
                    .iter()
                    .map(|to| {
                        let t = &tgts[to];
                        (to,
                         t.delayed
                             .lock()
                             .unwrap()
                             .peek()
                             .map(|d| d.ts))
                    })
                    .filter_map(|(to, ts)| ts.map(move |ts| (to, ts)))
                    .min_by_key(|&(_, ts)| ts);
                t.map(|(to, ts)| (to.to_owned(), ts))
            };
            debug!("{}: next from tdelay is {:?}", self.id, tnext);
            if let Some((to, tnext)) = tnext {
                if tnext == self.counter + 1 {
                    debug!("{}: forwarding from tdelay", self.id);
                    let d = {
                        let tgts = self.targets.read().unwrap();
                        let mut x = tgts[to.as_str()].delayed.lock().unwrap();
                        x.pop().unwrap()
                    };
                    self.notify(Some(&to), d.ts, Some(d.data));
                    self.counter += 1;
                    continue;
                }
            }
            if let Some(ts) = next {
                if ts == self.counter + 1 {
                    debug!("{}: forwarding from bdelay", self.id);
                    let d = self.bdelay.pop().unwrap();
                    self.notify(None, d.ts, Some(d.data));
                    self.counter += 1;
                    continue;
                }
            }
            break;
        }
        debug!("{}: done replaying", self.id);
    }

    /// Processes one message from the dispatcher channel.
    ///
    /// * `Data` without a timestamp is stamped with the next counter value
    ///   and delivered immediately. `Data` with a timestamp is delivered if
    ///   the timestamp is the current or the next one (followed by a replay
    ///   of delayed messages), and is otherwise parked until its turn;
    ///   parked updates without a payload are discarded.
    /// * `ReceiverJoin` / `ReceiverLeave` add and remove a receiver.
    /// * `SenderJoin` / `SenderLeave` maintain the per-receiver sender sets
    ///   and the broadcaster set. After a `SenderLeave`, if no broadcasters
    ///   remain, every receiver with no senders and nothing parked is closed.
    ///
    /// Messages that refer to a receiver which has already left are ignored
    /// rather than treated as fatal, so that one departed receiver cannot
    /// take the dispatcher thread down.
    ///
    /// # Panics
    ///
    /// Panics if timestamped and untimestamped data are mixed, if a
    /// timestamp lies before the current counter, or if a receiver name is
    /// registered twice.
    fn absorb(&mut self, m: Message<T>) {
        match m {
            Message::Data(td) => {
                debug!("{}: got message with ts {:?} from {} for {:?}",
                       self.id,
                       td.ts,
                       td.from,
                       td.to);
                // The first data message fixes the mode; later ones must agree.
                if self.forwarding.is_some() {
                    assert!(self.forwarding.unwrap() == td.ts.is_some(),
                            "one sender sent timestamp, another did not");
                } else {
                    self.forwarding = Some(td.ts.is_some())
                }
                let ts = match td.ts {
                    Some(ts) => ts,
                    None => {
                        // Untimestamped: stamp with the next counter value.
                        self.counter += 1;
                        self.notify(td.to.as_ref(), self.counter, td.data);
                        return;
                    }
                };
                assert!(ts >= self.counter);
                if ts == self.counter + 1 {
                    self.counter = ts;
                }
                if ts == self.counter {
                    self.notify(td.to.as_ref(), ts, td.data);
                    self.process_delayed();
                    return;
                }
                // The timestamp lies further ahead: park the payload.
                if let Some(data) = td.data {
                    match td.to {
                        Some(ref to) => {
                            debug!("{}: delayed in {:?}", self.id, to);
                            let tgts = self.targets.read().unwrap();
                            // The receiver may have left already; its data is
                            // then dropped, just as `notify` would skip it.
                            if let Some(target) = tgts.get(to) {
                                target.delayed.lock().unwrap().push(Delayed {
                                    ts: ts,
                                    data: data,
                                });
                            }
                            drop(tgts);
                        }
                        None => {
                            debug!("{}: delayed in bcast", self.id);
                            self.bdelay.push(Delayed {
                                ts: ts,
                                data: data,
                            });
                        }
                    }
                }
            }
            Message::ReceiverJoin(name, inner) => {
                debug!("{}: receiver {} joined", self.id, name);
                if !self.destinations.insert(name.clone()) {
                    panic!("receiver {} already exists!", name);
                }
                let mut tgts = self.targets.write().unwrap();
                tgts.insert(name,
                            Target {
                                channel: inner,
                                senders: HashSet::new(),
                                delayed: sync::Mutex::new(BinaryHeap::new()),
                            });
            }
            Message::ReceiverLeave(name) => {
                debug!("{}: receiver {} left", self.id, name);
                let mut tgts = self.targets.write().unwrap();
                tgts.remove(&*name);
                self.destinations.remove(&*name);
            }
            Message::SenderJoin(target, source) => {
                debug!("{}: sender {} for {:?} joined", self.id, source, target);
                if let Some(target) = target {
                    let mut tgts = self.targets.write().unwrap();
                    // A sender can be cloned after its receiver has left;
                    // there is then nothing to attach it to.
                    if let Some(t) = tgts.get_mut(&*target) {
                        t.senders.insert(source);
                    }
                    drop(tgts);
                } else {
                    self.broadcasters.insert(source);
                }
            }
            Message::SenderLeave(target, source) => {
                debug!("{}: sender {} for {:?} left", self.id, source, target);
                if let Some(ref target) = target {
                    let mut tgts = self.targets.write().unwrap();
                    if let Some(target) = tgts.get_mut(target.as_str()) {
                        target.senders.remove(&*source);
                    }
                    drop(tgts);
                } else {
                    self.broadcasters.remove(&*source);
                }
                // While a broadcaster exists any receiver may still get data.
                if self.broadcasters.is_empty() {
                    let mut tgts = self.targets.write().unwrap();
                    for (tn, t) in tgts.iter_mut()
                        .filter(|&(_, ref t)| {
                            t.senders.is_empty() && t.delayed.lock().unwrap().is_empty()
                        }) {
                        debug!("{}: closing now-done channel {}", self.id, tn);
                        let mut state = t.channel.mx.lock().unwrap();
                        state.closed = true;
                        t.channel.cond.notify_one();
                        drop(state);
                    }
                }
            }
        }
    }
}
/// Creates a dispatcher with the given `bound` whose timestamp counter
/// starts at 0. Shorthand for `new_with_seed(bound, 0)`.
pub fn new<T: Clone + Send + 'static>(bound: usize) -> Dispatcher<T> {
    let initial_timestamp = 0;
    new_with_seed::<T>(bound, initial_timestamp)
}
/// Creates a dispatcher whose timestamp counter starts at `seed`.
///
/// `bound` is used both as the capacity of the channel into the dispatcher
/// and as the per-receiver queue limit. Two background threads are spawned:
/// a control thread that reacts to receivers being dropped, and the
/// dispatcher thread that processes every message sent by senders,
/// broadcasters and the control thread.
pub fn new_with_seed<T: Clone + Send + 'static>(bound: usize, seed: usize) -> Dispatcher<T> {
use rand::{thread_rng, Rng};
// Channel carrying all messages into the dispatcher thread.
let (stx, srx) = mpsc::sync_channel(bound);
let mut d = DispatchInner {
targets: sync::Arc::new(sync::RwLock::new(HashMap::new())),
destinations: HashSet::new(),
bdelay: BinaryHeap::new(),
broadcasters: HashSet::new(),
forwarding: None,
bound: bound,
// Two characters from the thread-local RNG, used to label this dispatcher
// in debug output.
id: thread_rng().gen_ascii_chars().take(2).collect(),
counter: seed,
};
let id = d.id.clone();
// Handles for the control thread: the shared target map and its own sender
// into the dispatcher channel.
let c_targets = d.targets.clone();
let c_stx = stx.clone();
// Zero-capacity channel on which dropped receivers announce their name.
let (ctx, crx) = mpsc::sync_channel::<String>(0);
// Control thread: marks departed receivers and tells the dispatcher to
// remove them.
thread::spawn(move || {
// Departures still to be handled in this round.
let mut leaving = Vec::new();
// Departures whose receiver is not in the target map yet; retried next round.
let mut leaving_ = Vec::new();
'recv: loop {
// Poll first so that pending retries are not held up by a blocking receive.
let left = crx.try_recv();
match left {
Ok(left) => {
leaving.push(left);
}
Err(..) if !leaving.is_empty() => {
// Nothing new arrived, but earlier departures are still unresolved:
// fall through and retry them instead of blocking or exiting.
}
Err(mpsc::TryRecvError::Disconnected) => {
// Every `leave` sender is gone and nothing is pending: shut down.
break 'recv;
}
Err(mpsc::TryRecvError::Empty) => {
// Nothing pending at all: block until a receiver leaves.
let left = crx.recv();
if let Ok(left) = left {
leaving.push(left);
} else {
break 'recv;
}
}
}
for left in leaving.drain(..) {
debug!("{} control: dealing with departure of receiver {}",
id,
left);
let targets = c_targets.read().unwrap();
if let Some(t) = targets.get(&*left) {
// Flag the queue so that neither side keeps waiting on it.
let mut state = t.channel.mx.lock().unwrap();
state.left = true;
state.closed = true;
t.channel.cond.notify_one();
drop(state);
let ctx = c_stx.clone();
// The removal request is sent from a helper thread, presumably so
// that a full dispatcher channel cannot stall this loop.
thread::spawn(move || {
ctx.send(Message::ReceiverLeave(left)).unwrap();
});
} else {
// The receiver is not in the target map (yet); try again later.
leaving_.push(left);
}
}
leaving.extend(leaving_.drain(..));
}
});
// Dispatcher thread: handles messages until every sender into the channel
// has been dropped.
thread::spawn(move || {
for m in srx.iter() {
d.absorb(m);
}
});
Dispatcher {
dispatcher: stx,
leave: ctx,
bound: bound,
}
}
// End-to-end tests that exercise the dispatcher through its public handles.
#[cfg(test)]
mod tests {
// Dropping one receiver must not prevent delivery to another receiver.
#[test]
fn can_send_after_recv_drop() {
let d = super::new(1);
let (tx_a, rx_a) = d.new("atx", "arx");
let (tx_b, rx_b) = d.new("btx", "brx");
let _ = tx_a;
drop(rx_a);
tx_b.send(10);
assert_eq!(rx_b.recv().unwrap().0.unwrap(), 10);
}
// A sender pushing more items than the bound allows must not stall the
// dispatcher once its receiver is dropped: a later broadcast still has to
// reach the remaining receiver.
#[test]
fn recv_drop_unblocks_sender() {
use std::thread;
use std::time::Duration;
let d = super::new(1);
let (tx_a, rx_a) = d.new("atx", "arx");
let (tx_b, rx_b) = d.new("btx", "brx");
let tx_a = tx_a.into_broadcaster();
thread::spawn(move || {
for _ in 0..20 {
tx_b.send("b");
}
});
// Give the sending thread time to fill the queue.
thread::sleep(Duration::from_millis(200));
drop(rx_b);
tx_a.broadcast("a");
// Skip timestamp-only updates until the broadcast payload shows up.
loop {
let rx = rx_a.recv();
assert!(rx.is_ok());
let rx = rx.unwrap();
if rx.0.is_some() {
assert_eq!(rx.0, Some("a"));
break;
}
}
}
// Same as `can_send_after_recv_drop`, but with an explicit timestamp.
#[test]
fn can_forward_after_recv_drop() {
let d = super::new(1);
let (tx_a, rx_a) = d.new("atx", "arx");
let (tx_b, rx_b) = d.new("btx", "brx");
let _ = tx_a;
drop(rx_a);
tx_b.forward(Some(10), 1);
assert_eq!(rx_b.recv(), Ok((Some(10), 1)));
}
// After all senders are gone, each receiver first sees what was forwarded
// (data for `arx`, the bare timestamp for `brx`) and then disconnection.
#[test]
fn forward_with_no_senders() {
use std::sync::mpsc;
let d = super::new(1);
let (tx_a, rx_a) = d.new("atx", "arx");
let (tx_b, rx_b) = d.new("btx", "brx");
tx_a.forward(Some(1), 1);
drop(tx_a);
drop(tx_b);
assert_eq!(rx_a.recv(), Ok((Some(1), 1)));
assert_eq!(rx_b.recv(), Ok((None, 1)));
assert_eq!(rx_a.recv(), Err(mpsc::RecvError));
assert_eq!(rx_b.recv(), Err(mpsc::RecvError));
}
// Broadcast items are delivered exactly once each, followed by disconnection
// once the broadcaster is dropped.
#[test]
fn broadcast_dupe_termination() {
use std::sync::mpsc;
let d = super::new(1);
let (tx, rx) = d.new("tx", "rx");
let tx = tx.into_broadcaster();
tx.broadcast_forward(Some("a"), 1);
tx.broadcast_forward(Some("b"), 2);
drop(tx);
assert_eq!(rx.recv(), Ok((Some("a"), 1)));
assert_eq!(rx.recv(), Ok((Some("b"), 2)));
assert_eq!(rx.recv(), Err(mpsc::RecvError));
}
// Two threads forward interleaved timestamps; the receiver must observe them
// in timestamp order regardless of scheduling. Repeated to shake out races.
#[test]
fn multisend_thread_interleaving() {
use std::thread;
for _ in 0..1000 {
let d = super::new(20);
let (tx_a, rx) = d.new("tx_a", "rx");
let tx_b = tx_a.clone("tx_b");
let t_a = thread::spawn(move || {
tx_a.forward(Some("c_1"), 1);
tx_a.forward(Some("c_3"), 3);
tx_a.forward(Some("a_1"), 5);
});
let t_b = thread::spawn(move || {
tx_b.forward(Some("c_2"), 2);
tx_b.forward(Some("b_1"), 4);
tx_b.forward(Some("a_2"), 6);
});
assert_eq!(rx.recv(), Ok((Some("c_1"), 1)));
assert_eq!(rx.recv(), Ok((Some("c_2"), 2)));
assert_eq!(rx.recv(), Ok((Some("c_3"), 3)));
assert_eq!(rx.recv(), Ok((Some("b_1"), 4)));
assert_eq!(rx.recv(), Ok((Some("a_1"), 5)));
assert_eq!(rx.recv(), Ok((Some("a_2"), 6)));
t_a.join().unwrap();
t_b.join().unwrap();
}
}
// The first dispatcher-assigned timestamp is `seed + 1`.
#[test]
fn test_new_with_seed() {
let d = super::new_with_seed(1, 69105);
let (tx, rx) = d.new("tx", "rx");
tx.send("a");
assert_eq!(rx.recv(), Ok((Some("a"), 69106)));
}
}