use crate::runtime::execution::ExecutionState;
use crate::runtime::task::clock::VectorClock;
use crate::runtime::task::{TaskId, DEFAULT_INLINE_TASKS};
use crate::runtime::thread;
use smallvec::SmallVec;
use std::cell::RefCell;
use std::fmt::Debug;
use std::rc::Rc;
use std::result::Result;
pub use std::sync::mpsc::{RecvError, RecvTimeoutError, SendError, TryRecvError, TrySendError};
use std::sync::Arc;
use std::time::Duration;
use tracing::trace;
/// Inline capacity of the `SmallVec` buffers in `ChannelState` that hold queued messages and
/// receiver clocks; longer queues spill to the heap.
const MAX_INLINE_MESSAGES: usize = 32;
/// Creates an unbounded channel and returns its sending and receiving halves.
///
/// Both halves share one underlying channel created with no capacity bound.
pub fn channel<T>() -> (Sender<T>, Receiver<T>) {
    let shared = Arc::new(Channel::new(None));
    let receiver = Receiver {
        inner: Arc::clone(&shared),
    };
    let sender = Sender { inner: shared };
    (sender, receiver)
}
/// Creates a bounded channel and returns its sending and receiving halves.
///
/// A positive `bound` caps how many messages may be queued before senders block. A `bound`
/// of zero makes a rendezvous channel, where a send waits for a receiver to pair with.
pub fn sync_channel<T>(bound: usize) -> (SyncSender<T>, Receiver<T>) {
    let shared = Arc::new(Channel::new(Some(bound)));
    let (tx_inner, rx_inner) = (Arc::clone(&shared), Arc::clone(&shared));
    (SyncSender { inner: tx_inner }, Receiver { inner: rx_inner })
}
/// The channel object shared by every sender and receiver handle of one channel.
#[derive(Debug)]
struct Channel<T> {
    /// Capacity limit: `None` for unbounded channels, `Some(0)` for rendezvous channels.
    bound: Option<usize>,
    /// Mutable bookkeeping, accessed through `RefCell` borrows.
    state: Rc<RefCell<ChannelState<T>>>,
}
/// A queued message together with the vector clock recorded when it was enqueued.
struct TimestampedValue<T> {
    /// The message payload handed to the receiver.
    value: T,
    /// Clock returned by `increment_clock` at send time; the receiver joins it into its own.
    clock: VectorClock,
}
impl<T> TimestampedValue<T> {
fn new(value: T, clock: VectorClock) -> Self {
Self { value, clock }
}
}
/// Mutable bookkeeping for one channel: queued messages, handle counts, and blocked tasks.
struct ChannelState<T> {
    /// Messages waiting to be received, oldest first.
    messages: SmallVec<[TimestampedValue<T>; MAX_INLINE_MESSAGES]>,
    /// `None` for unbounded channels. For bounded channels this starts as `bound` empty
    /// clocks. Each non-rendezvous send takes the oldest one and joins it into the sender's
    /// clock, and each receive on a positive-bound channel appends the receiver's clock.
    receiver_clock: Option<SmallVec<[VectorClock; MAX_INLINE_MESSAGES]>>,
    /// Number of sender handles that have not been dropped.
    known_senders: usize,
    /// Number of receiver handles that have not been dropped.
    known_receivers: usize,
    /// Tasks parked in a send, in the order they blocked.
    waiting_senders: SmallVec<[TaskId; DEFAULT_INLINE_TASKS]>,
    /// Tasks parked in a receive, in the order they blocked.
    waiting_receivers: SmallVec<[TaskId; DEFAULT_INLINE_TASKS]>,
}
impl<T> Debug for ChannelState<T> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "Channel {{ ")?;
write!(f, "num_messages: {} ", self.messages.len())?;
write!(
f,
"known_senders {} known_receivers {} ",
self.known_senders, self.known_receivers
)?;
write!(f, "waiting_senders: [{:?}] ", self.waiting_senders)?;
write!(f, "waiting_receivers: [{:?}] ", self.waiting_receivers)?;
write!(f, "}}")
}
}
impl<T> Channel<T> {
    /// Creates a channel with the given capacity bound (`None` = unbounded, `Some(0)` =
    /// rendezvous), counting one live sender and one live receiver.
    fn new(bound: Option<usize>) -> Self {
        // Bounded channels start with `bound` empty receiver clocks. Non-rendezvous sends
        // consume them and receives (bound > 0) replenish them.
        let receiver_clock = if let Some(bound) = bound {
            let mut s = SmallVec::with_capacity(bound);
            for _ in 0..bound {
                s.push(VectorClock::new());
            }
            Some(s)
        } else {
            None
        };
        Self {
            bound,
            state: Rc::new(RefCell::new(ChannelState {
                messages: SmallVec::new(),
                receiver_clock,
                known_senders: 1,
                known_receivers: 1,
                waiting_senders: SmallVec::new(),
                waiting_receivers: SmallVec::new(),
            })),
        }
    }

    /// Sends without blocking: returns `TrySendError::Full` instead of parking the task.
    fn try_send(&self, message: T) -> Result<(), TrySendError<T>> {
        self.send_internal(message, false)
    }

    /// Sends, parking the current task until the message can be enqueued.
    ///
    /// Returns `SendError` carrying the message if no receiver remains.
    fn send(&self, message: T) -> Result<(), SendError<T>> {
        self.send_internal(message, true).map_err(|e| match e {
            // A blocking send never reports `Full`.
            TrySendError::Full(_) => unreachable!(),
            TrySendError::Disconnected(m) => SendError(m),
        })
    }

    /// Shared implementation of `send` and `try_send`.
    ///
    /// When the send cannot complete immediately, `can_block` selects between parking the
    /// current task (`true`) and failing with `TrySendError::Full` (`false`). Returns
    /// `TrySendError::Disconnected` with the message when no receiver remains, whether that
    /// is detected up front or after waking from a blocked send.
    fn send_internal(&self, message: T, can_block: bool) -> Result<(), TrySendError<T>> {
        let me = ExecutionState::me();
        let mut state = self.state.borrow_mut();
        trace!(
            state = ?state,
            "sender {:?} starting send on channel {:p}",
            me,
            self,
        );
        if state.known_receivers == 0 {
            return Err(TrySendError::Disconnected(message));
        }
        // A rendezvous channel (bound 0) counts as full while one message is in flight, hence
        // `max(bound, 1)`. An unbounded channel is never full.
        let (is_rendezvous, is_full) = if let Some(bound) = self.bound {
            (bound == 0, state.messages.len() >= std::cmp::max(bound, 1))
        } else {
            (false, false)
        };
        // Block in three cases: there is no room, earlier senders are already queued (this
        // keeps sends FIFO), or a rendezvous send has no receiver to pair with.
        let sender_should_block =
            is_full || !state.waiting_senders.is_empty() || (is_rendezvous && state.waiting_receivers.is_empty());
        if sender_should_block {
            if !can_block {
                return Err(TrySendError::Full(message));
            }
            state.waiting_senders.push(me);
            trace!(
                state = ?state,
                "blocking sender {:?} on channel {:p}",
                me,
                self,
            );
            ExecutionState::with(|s| s.current_mut().block(false));
            // Release the borrow before yielding so other tasks can borrow the state.
            drop(state);
            thread::switch();
            state = self.state.borrow_mut();
            trace!(
                state = ?state,
                "unblocked sender {:?} on channel {:p}",
                me,
                self,
            );
            // The last receiver may have been dropped while we were parked.
            if state.known_receivers == 0 {
                state.waiting_senders.retain(|t| *t != me);
                return Err(TrySendError::Disconnected(message));
            }
            // Otherwise we were woken in turn and must be at the front of the queue.
            let head = state.waiting_senders.remove(0);
            assert_eq!(head, me);
        }
        ExecutionState::with(|s| {
            let clock = s.increment_clock();
            state.messages.push(TimestampedValue::new(message, clock.clone()));
        });
        // Wake the first waiting receiver so it can take a message.
        if let Some(&tid) = state.waiting_receivers.first() {
            ExecutionState::with(|s| {
                s.get_mut(tid).unblock();
                // On a rendezvous channel the sender also joins the waiting receiver's clock.
                // NOTE(review): presumably because that receiver is the one that takes this
                // message.
                if is_rendezvous {
                    let recv_clock = s.get_clock(tid).clone();
                    s.update_clock(&recv_clock);
                }
            });
        }
        // Wake the next waiting sender if there is still room for its message.
        if let Some(&tid) = state.waiting_senders.first() {
            let bound = self.bound.expect("can't have waiting senders on an unbounded channel");
            if state.messages.len() < bound {
                ExecutionState::with(|s| s.get_mut(tid).unblock());
            }
        }
        // Buffered bounded channels: consume the oldest receiver clock and join it into ours.
        if !is_rendezvous {
            if let Some(receiver_clock) = &mut state.receiver_clock {
                let recv_clock = receiver_clock.remove(0);
                ExecutionState::with(|s| s.update_clock(&recv_clock));
            }
        }
        Ok(())
    }

    /// Receives, parking the current task until a message arrives.
    ///
    /// Returns `RecvError` once the queue is empty and every sender has been dropped.
    fn recv(&self) -> Result<T, RecvError> {
        self.recv_internal(true).map_err(|e| match e {
            TryRecvError::Disconnected => RecvError,
            // A blocking receive never reports `Empty`.
            TryRecvError::Empty => unreachable!(),
        })
    }

    /// Attempts a receive, returning `TryRecvError::Empty` rather than waiting when no
    /// message can be taken.
    ///
    /// NOTE(review): on a rendezvous channel with a parked sender this still parks briefly,
    /// so that the sender can hand its message over.
    fn try_recv(&self) -> Result<T, TryRecvError> {
        self.recv_internal(false)
    }

    /// Shared implementation of `recv` and `try_recv`.
    ///
    /// Receivers are served in FIFO order of `waiting_receivers`. Returns
    /// `TryRecvError::Disconnected` when the queue is empty and no sender remains.
    fn recv_internal(&self, can_block: bool) -> Result<T, TryRecvError> {
        let me = ExecutionState::me();
        let mut state = self.state.borrow_mut();
        trace!(
            state = ?state,
            "starting recv on channel {:p}",
            self,
        );
        if state.messages.is_empty() && state.known_senders == 0 {
            return Err(TryRecvError::Disconnected);
        }
        let is_rendezvous = self.bound == Some(0);
        if is_rendezvous && state.messages.is_empty() {
            if let Some(&tid) = state.waiting_senders.first() {
                // Let the first parked sender proceed to hand over its message.
                ExecutionState::with(|s| s.get_mut(tid).unblock());
            } else if !can_block {
                return Err(TryRecvError::Empty);
            }
        }
        // A non-blocking receive on a buffered channel succeeds only when there are more
        // queued messages than receivers already waiting ahead of it.
        if !is_rendezvous && !can_block && state.waiting_receivers.len() >= state.messages.len() {
            return Err(TryRecvError::Empty);
        }
        ExecutionState::with(|s| {
            let _ = s.increment_clock();
        });
        let should_block = state.messages.is_empty() || !state.waiting_receivers.is_empty();
        if should_block {
            state.waiting_receivers.push(me);
            loop {
                trace!(
                    state = ?state,
                    "blocking receiver {:?} on channel {:p}",
                    me,
                    self,
                );
                ExecutionState::with(|s| s.current_mut().block(false));
                // Release the borrow before yielding so other tasks can borrow the state.
                drop(state);
                thread::switch();
                state = self.state.borrow_mut();
                trace!(
                    state = ?state,
                    "unblocked receiver {:?} on channel {:p}",
                    me,
                    self,
                );
                if state.messages.is_empty() && state.known_senders == 0 {
                    state.waiting_receivers.retain(|t| *t != me);
                    return Err(TryRecvError::Disconnected);
                }
                if state.waiting_receivers.first() == Some(&me) {
                    state.waiting_receivers.remove(0);
                    break;
                }
                // We were woken out of turn. Dropping the last sender wakes every waiting
                // receiver while earlier receivers may still be queued with messages left.
                // Park again until a predecessor hands us the head of the queue; taking
                // `remove(0)` here would dequeue another task's slot.
            }
        }
        let item = state.messages.remove(0);
        // Wake the first waiting sender now that a slot freed up. On a rendezvous channel,
        // only wake it when a receiver is waiting to pair with it.
        if let Some(&tid) = state.waiting_senders.first() {
            let bound = self.bound.expect("can't have waiting senders on an unbounded channel");
            if bound > 0 || !state.waiting_receivers.is_empty() {
                ExecutionState::with(|s| s.get_mut(tid).unblock());
            }
        }
        if state.messages.is_empty() && state.known_senders == 0 {
            // Disconnected and drained: every remaining receiver must wake to observe the
            // disconnection, including any that queued after the last sender was dropped.
            for &tid in state.waiting_receivers.iter() {
                ExecutionState::with(|s| s.get_mut(tid).unblock());
            }
        } else if let Some(&tid) = state.waiting_receivers.first() {
            // Hand the next queued message to the next receiver in line.
            if !state.messages.is_empty() {
                ExecutionState::with(|s| s.get_mut(tid).unblock());
            }
        }
        let TimestampedValue { value, clock } = item;
        ExecutionState::with(|s| {
            // Join the sender's clock into ours.
            s.get_clock_mut(me).update(&clock);
            // Positive-bound channels: return our clock as the one for the freed slot.
            if let Some(receiver_clock) = &mut state.receiver_clock {
                let bound = self.bound.expect("unexpected internal error");
                if bound > 0 {
                    assert!(receiver_clock.len() < bound);
                    receiver_clock.push(s.get_clock(me).clone());
                }
            }
        });
        Ok(value)
    }
}
// SAFETY: NOTE(review): `Channel` holds an `Rc<RefCell<..>>`, which is neither `Send` nor
// `Sync`, so these impls assert that cross-thread use is acceptable anyway. This presumably
// relies on the runtime running every task on a single OS thread and switching between them
// only at explicit points such as `thread::switch()`. Confirm before relying on it.
unsafe impl<T: Send> Send for Channel<T> {}
unsafe impl<T: Send> Sync for Channel<T> {}
/// The receiving half of a channel created by `channel` or `sync_channel`.
#[derive(Debug)]
pub struct Receiver<T> {
    /// Channel shared with the matching sender handles.
    inner: Arc<Channel<T>>,
}
impl<T> Receiver<T> {
    /// Waits for the next message. Returns `RecvError` once the channel is empty and every
    /// sender has been dropped.
    pub fn recv(&self) -> Result<T, RecvError> {
        let channel = &self.inner;
        channel.recv()
    }

    /// Attempts to take a message, reporting `TryRecvError::Empty` instead of waiting when
    /// none can be taken.
    pub fn try_recv(&self) -> Result<T, TryRecvError> {
        let channel = &self.inner;
        channel.try_recv()
    }

    /// Behaves like `recv`: the timeout is currently ignored, so this never reports
    /// `RecvTimeoutError::Timeout`, and every failure maps to `Disconnected`.
    pub fn recv_timeout(&self, _timeout: Duration) -> Result<T, RecvTimeoutError> {
        self.inner
            .recv()
            .map_err(|RecvError| RecvTimeoutError::Disconnected)
    }

    /// Returns an iterator that waits for each message and ends when the channel
    /// disconnects.
    pub fn iter(&self) -> Iter<'_, T> {
        Iter { rx: self }
    }

    /// Returns an iterator that yields messages until the first unsuccessful `try_recv`.
    pub fn try_iter(&self) -> TryIter<'_, T> {
        TryIter { rx: self }
    }
}
impl<T> Drop for Receiver<T> {
    /// Decrements the live-receiver count. When this was the last receiver, every parked
    /// sender is woken so it can observe the disconnection.
    fn drop(&mut self) {
        // NOTE(review): bookkeeping is skipped when `should_stop()` is set, presumably while
        // the execution is being torn down.
        if ExecutionState::should_stop() {
            return;
        }
        let mut state = self.inner.state.borrow_mut();
        assert!(state.known_receivers > 0);
        state.known_receivers -= 1;
        if state.known_receivers != 0 {
            return;
        }
        state.waiting_senders.iter().for_each(|&tid| {
            ExecutionState::with(|s| s.get_mut(tid).unblock());
        });
    }
}
/// Blocking iterator over a borrowed receiver; see `Receiver::iter`.
#[derive(Debug)]
pub struct Iter<'a, T: 'a> {
    /// Receiver the messages are drawn from.
    rx: &'a Receiver<T>,
}
/// Non-waiting iterator over a borrowed receiver; see `Receiver::try_iter`.
#[derive(Debug)]
pub struct TryIter<'a, T: 'a> {
    /// Receiver the messages are drawn from.
    rx: &'a Receiver<T>,
}
/// Blocking iterator that owns its receiver; produced by `Receiver::into_iter`.
#[derive(Debug)]
pub struct IntoIter<T> {
    /// Receiver the messages are drawn from.
    rx: Receiver<T>,
}
impl<T> Iterator for Iter<'_, T> {
    type Item = T;

    /// Waits for the next message and yields `None` once `recv` reports disconnection.
    fn next(&mut self) -> Option<T> {
        let received = self.rx.recv();
        received.ok()
    }
}
impl<T> Iterator for TryIter<'_, T> {
    type Item = T;

    /// Yields the next message if `try_recv` succeeds, otherwise `None`.
    fn next(&mut self) -> Option<T> {
        let attempt = self.rx.try_recv();
        attempt.ok()
    }
}
impl<'a, T> IntoIterator for &'a Receiver<T> {
type Item = T;
type IntoIter = Iter<'a, T>;
fn into_iter(self) -> Iter<'a, T> {
self.iter()
}
}
impl<T> Iterator for IntoIter<T> {
    type Item = T;

    /// Waits for the next message and yields `None` once `recv` reports disconnection.
    fn next(&mut self) -> Option<T> {
        let received = self.rx.recv();
        received.ok()
    }
}
impl<T> IntoIterator for Receiver<T> {
    type Item = T;
    type IntoIter = IntoIter<T>;

    /// Consumes the receiver, yielding messages until the channel disconnects.
    fn into_iter(self) -> IntoIter<T> {
        let rx = self;
        IntoIter { rx }
    }
}
/// The sending half of an unbounded channel created by `channel`.
#[derive(Debug)]
pub struct Sender<T> {
    /// Channel shared with the receiver and any cloned senders.
    inner: Arc<Channel<T>>,
}
impl<T> Sender<T> {
pub fn send(&self, t: T) -> Result<(), SendError<T>> {
self.inner.send(t)
}
}
impl<T> Clone for Sender<T> {
fn clone(&self) -> Self {
let mut state = self.inner.state.borrow_mut();
state.known_senders += 1;
drop(state);
Self {
inner: self.inner.clone(),
}
}
}
impl<T> Drop for Sender<T> {
    /// Decrements the live-sender count. When this was the last sender, every parked
    /// receiver is woken so it can observe the disconnection.
    fn drop(&mut self) {
        // NOTE(review): bookkeeping is skipped when `should_stop()` is set, presumably while
        // the execution is being torn down.
        if ExecutionState::should_stop() {
            return;
        }
        let mut state = self.inner.state.borrow_mut();
        assert!(state.known_senders > 0);
        state.known_senders -= 1;
        if state.known_senders != 0 {
            return;
        }
        state.waiting_receivers.iter().for_each(|&tid| {
            ExecutionState::with(|s| s.get_mut(tid).unblock());
        });
    }
}
/// The sending half of a bounded channel created by `sync_channel`.
#[derive(Debug)]
pub struct SyncSender<T> {
    /// Channel shared with the receiver and any cloned senders.
    inner: Arc<Channel<T>>,
}
impl<T> SyncSender<T> {
pub fn send(&self, t: T) -> Result<(), SendError<T>> {
self.inner.send(t)
}
pub fn try_send(&self, t: T) -> Result<(), TrySendError<T>> {
self.inner.try_send(t)
}
}
impl<T> Clone for SyncSender<T> {
fn clone(&self) -> Self {
let mut state = self.inner.state.borrow_mut();
state.known_senders += 1;
drop(state);
Self {
inner: self.inner.clone(),
}
}
}
impl<T> Drop for SyncSender<T> {
    /// Decrements the live-sender count. When this was the last sender, every parked
    /// receiver is woken so it can observe the disconnection.
    fn drop(&mut self) {
        // NOTE(review): bookkeeping is skipped when `should_stop()` is set, presumably while
        // the execution is being torn down.
        if ExecutionState::should_stop() {
            return;
        }
        let mut state = self.inner.state.borrow_mut();
        assert!(state.known_senders > 0);
        state.known_senders -= 1;
        if state.known_senders != 0 {
            return;
        }
        state.waiting_receivers.iter().for_each(|&tid| {
            ExecutionState::with(|s| s.get_mut(tid).unblock());
        });
    }
}