use std::{
fmt,
hash::{Hash, Hasher},
pin::Pin,
sync::{
atomic::{
AtomicBool, AtomicUsize,
Ordering::{Relaxed, SeqCst},
},
Arc, Weak,
},
task::{self, Poll},
thread,
};
use futures_core::{stream::Stream, task::__internal::AtomicWaker};
use parking_lot::Mutex;
use tokio::sync::oneshot::{channel as oneshot_channel, Receiver as OneshotReceiver};
use crate::actor::Actor;
use crate::handler::{Handler, Message};
use super::envelope::{Envelope, ToEnvelope};
use super::queue::Queue;
use super::SendError;
pub trait Sender<M>: Send
where
M::Result: Send,
M: Message + Send,
{
fn do_send(&self, msg: M) -> Result<(), SendError<M>>;
fn try_send(&self, msg: M) -> Result<(), SendError<M>>;
fn send(&self, msg: M) -> Result<OneshotReceiver<M::Result>, SendError<M>>;
fn boxed(&self) -> Box<dyn Sender<M> + Sync>;
fn hash(&self) -> usize;
fn connected(&self) -> bool;
}
/// Forwarding impl: a boxed (possibly unsized / type-erased) sender is itself a `Sender`.
///
/// Every method delegates to the boxed value's own implementation.
impl<S, M> Sender<M> for Box<S>
where
    S: Sender<M> + ?Sized,
    M::Result: Send,
    M: Message + Send,
{
    fn do_send(&self, msg: M) -> Result<(), SendError<M>> {
        <S as Sender<M>>::do_send(&**self, msg)
    }

    fn try_send(&self, msg: M) -> Result<(), SendError<M>> {
        <S as Sender<M>>::try_send(&**self, msg)
    }

    fn send(&self, msg: M) -> Result<OneshotReceiver<<M as Message>::Result>, SendError<M>> {
        <S as Sender<M>>::send(&**self, msg)
    }

    fn boxed(&self) -> Box<dyn Sender<M> + Sync> {
        <S as Sender<M>>::boxed(&**self)
    }

    fn hash(&self) -> usize {
        <S as Sender<M>>::hash(&**self)
    }

    fn connected(&self) -> bool {
        <S as Sender<M>>::connected(&**self)
    }
}
pub(crate) trait WeakSender<M>: Send
where
M::Result: Send,
M: Message + Send,
{
fn upgrade(&self) -> Option<Box<dyn Sender<M> + Sync>>;
fn boxed(&self) -> Box<dyn WeakSender<M> + Sync>;
}
/// Strong sending handle for an actor's mailbox channel.
///
/// Each handle (including each clone) has its own parking slot, so back-pressure is tracked
/// per handle while the channel state in `inner` is shared.
pub struct AddressSender<A: Actor> {
    /// Channel state shared with the receiver and all other senders.
    inner: Arc<Inner<A>>,
    /// This handle's parking slot; pushed onto `Inner::parked_queue` when the handle parks.
    sender_task: Arc<Mutex<SenderTask>>,
    /// Fast-path hint set by `park`; when `false`, `poll_unparked` skips taking the lock.
    maybe_parked: Arc<AtomicBool>,
}
/// Shows the per-handle parking state; the shared channel state is omitted.
impl<A: Actor> fmt::Debug for AddressSender<A> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Self {
            sender_task,
            maybe_parked,
            ..
        } = self;
        f.debug_struct("AddressSender")
            .field("sender_task", sender_task)
            .field("maybe_parked", maybe_parked)
            .finish()
    }
}
/// Weak (non-owning) handle to a channel; see `WeakAddressSender::upgrade`.
pub struct WeakAddressSender<A: Actor> {
    /// Weak reference to the shared channel state; does not keep it alive.
    inner: Weak<Inner<A>>,
}
impl<A: Actor> Clone for WeakAddressSender<A> {
fn clone(&self) -> WeakAddressSender<A> {
WeakAddressSender {
inner: self.inner.clone(),
}
}
}
/// Prints only the type name; the weak reference has no useful debug representation.
impl<A: Actor> fmt::Debug for WeakAddressSender<A> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // A field-less `debug_struct(..).finish()` writes exactly the name, so this is equivalent.
        f.write_str("WeakAddressSender")
    }
}
// NOTE(review): no impl or use of this trait is visible in this file; presumably a leftover
// compile-time marker for "Send + Sync + Clone" — confirm before removing.
trait AssertKinds: Send + Sync + Clone {}
/// Receiving half of the channel; yields queued envelopes as a `Stream`.
///
/// Dropping it closes the channel (see its `Drop` impl).
pub struct AddressReceiver<A: Actor> {
    /// Channel state shared with all senders.
    inner: Arc<Inner<A>>,
}
/// Handle that can mint new `AddressSender`s and adjust capacity without being a sender itself.
pub struct AddressSenderProducer<A: Actor> {
    /// Channel state shared with the receiver and all senders.
    inner: Arc<Inner<A>>,
}
/// Channel state shared by the receiver, senders and producers (weak senders hold a `Weak`).
struct Inner<A: Actor> {
    /// Queued-message count at which a sender parks after sending; `0` disables parking.
    buffer: AtomicUsize,
    /// Packed open flag (`OPEN_MASK` bit) plus queued-message count; see `decode_state`.
    state: AtomicUsize,
    /// Envelopes waiting to be delivered to the actor.
    message_queue: Queue<Envelope<A>>,
    /// Parking slots of senders that are waiting to be unparked.
    parked_queue: Queue<Arc<Mutex<SenderTask>>>,
    /// Number of live `AddressSender` handles.
    num_senders: AtomicUsize,
    /// Waker of the task polling the receiver.
    recv_task: AtomicWaker,
}
/// Unpacked form of `Inner::state`.
#[derive(Debug, Clone, Copy)]
struct State {
    /// `true` until the receiver closes the channel.
    is_open: bool,
    /// Messages counted by senders and not yet popped by the receiver.
    num_messages: usize,
}
impl State {
    /// A channel is fully closed only when it is no longer open *and* nothing is left to drain.
    fn is_closed(&self) -> bool {
        matches!((self.is_open, self.num_messages), (false, 0))
    }
}
/// Highest bit of `Inner::state`; set while the channel is open.
const OPEN_MASK: usize = !(usize::MAX >> 1);
/// Initial `Inner::state`: open, with zero queued messages.
const INIT_STATE: usize = OPEN_MASK;
/// Largest message count representable in the bits below `OPEN_MASK`.
const MAX_CAPACITY: usize = usize::MAX >> 1;
/// Upper bound (exclusive) on the buffer size accepted by `channel`.
const MAX_BUFFER: usize = MAX_CAPACITY >> 1;
/// Per-sender parking slot, shared between an `AddressSender` and `Inner::parked_queue`.
#[derive(Debug)]
struct SenderTask {
    /// Waker to call when the sender is unparked, if one was stored.
    task: Option<task::Waker>,
    /// Set by `AddressSender::park`; cleared by `notify`.
    is_parked: bool,
}

impl SenderTask {
    /// Creates an unparked slot with no stored waker.
    fn new() -> Self {
        Self {
            is_parked: false,
            task: None,
        }
    }

    /// Unparks the sender and consumes its stored waker, waking it.
    ///
    /// Returns `true` if a waker was present (and has now been woken), `false` otherwise.
    fn notify(&mut self) -> bool {
        self.is_parked = false;
        self.task.take().map(task::Waker::wake).is_some()
    }
}
/// Creates a mailbox channel and returns its first sender together with the receiver.
///
/// `buffer` is the queued-message count at which a sender parks after sending (`0` disables
/// parking).
///
/// # Panics
/// Panics if `buffer >= MAX_BUFFER`.
pub fn channel<A: Actor>(buffer: usize) -> (AddressSender<A>, AddressReceiver<A>) {
    assert!(buffer < MAX_BUFFER, "requested buffer size too large");

    let shared = Arc::new(Inner {
        buffer: AtomicUsize::new(buffer),
        state: AtomicUsize::new(INIT_STATE),
        message_queue: Queue::new(),
        parked_queue: Queue::new(),
        // Accounts for the sender handed back below.
        num_senders: AtomicUsize::new(1),
        recv_task: AtomicWaker::new(),
    });

    let first_sender = AddressSender {
        inner: Arc::clone(&shared),
        sender_task: Arc::new(Mutex::new(SenderTask::new())),
        maybe_parked: Arc::new(AtomicBool::new(false)),
    };
    (first_sender, AddressReceiver { inner: shared })
}
impl<A: Actor> AddressSender<A> {
    /// Returns whether the channel's open bit is still set (the receiver has not closed it).
    pub fn connected(&self) -> bool {
        let curr = self.inner.state.load(SeqCst);
        let state = decode_state(curr);
        state.is_open
    }

    /// Enqueues `msg` and returns a one-shot receiver for the handler's result.
    ///
    /// # Errors
    /// - `SendError::Full(msg)` if this handle is still parked from an earlier send.
    /// - `SendError::Closed(msg)` if the channel has been closed.
    ///
    /// When this send brings the queued count to `buffer` or more (and `buffer != 0`), the message
    /// is still queued, but this handle parks afterwards. Its next send then fails with `Full` until
    /// the receiver (or a capacity increase) unparks it.
    pub fn send<M>(&self, msg: M) -> Result<OneshotReceiver<M::Result>, SendError<M>>
    where
        A: Handler<M>,
        A::Context: ToEnvelope<A, M>,
        M::Result: Send,
        M: Message + Send,
    {
        if !self.poll_unparked(false, None).is_ready() {
            return Err(SendError::Full(msg));
        }
        // Reserve a slot in the message count first; this fails once the channel is closed.
        let park_self = match self.inc_num_messages() {
            Some(num_messages) => {
                let buffer = self.inner.buffer.load(Relaxed);
                buffer != 0 && num_messages >= buffer
            }
            None => return Err(SendError::Closed(msg)),
        };
        // Park *before* pushing so the receiver can find this task in `parked_queue`
        // when it pops the message.
        if park_self {
            self.park();
        }
        let (tx, rx) = oneshot_channel();
        let env = <A::Context as ToEnvelope<A, M>>::pack(msg, Some(tx));
        self.queue_push_and_signal(env);
        Ok(rx)
    }

    /// Enqueues `msg` without a result channel, honouring back-pressure.
    ///
    /// Same checks and errors as `send`. `park` controls whether this handle parks when the
    /// buffer threshold is reached (with `park == false` it never parks).
    pub fn try_send<M>(&self, msg: M, park: bool) -> Result<(), SendError<M>>
    where
        A: Handler<M>,
        <A as Actor>::Context: ToEnvelope<A, M>,
        M::Result: Send,
        M: Message + Send + 'static,
    {
        if !self.poll_unparked(false, None).is_ready() {
            return Err(SendError::Full(msg));
        }
        let park_self = match self.inc_num_messages() {
            Some(num_messages) => {
                let buffer = self.inner.buffer.load(Relaxed);
                buffer != 0 && num_messages >= buffer
            }
            None => return Err(SendError::Closed(msg)),
        };
        if park_self && park {
            self.park();
        }
        let env = <A::Context as ToEnvelope<A, M>>::pack(msg, None);
        self.queue_push_and_signal(env);
        Ok(())
    }

    /// Enqueues `msg` unconditionally, ignoring parking and the buffer threshold.
    ///
    /// # Errors
    /// `SendError::Closed(msg)` if the channel has been closed.
    pub fn do_send<M>(&self, msg: M) -> Result<(), SendError<M>>
    where
        A: Handler<M>,
        <A as Actor>::Context: ToEnvelope<A, M>,
        M::Result: Send,
        M: Message + Send,
    {
        if self.inc_num_messages().is_none() {
            Err(SendError::Closed(msg))
        } else {
            let env = <A::Context as ToEnvelope<A, M>>::pack(msg, None);
            self.queue_push_and_signal(env);
            Ok(())
        }
    }

    /// Returns a weak handle that does not count as a live sender.
    pub fn downgrade(&self) -> WeakAddressSender<A> {
        WeakAddressSender {
            inner: Arc::downgrade(&self.inner),
        }
    }

    /// Pushes an envelope onto the message queue, then wakes the receiver task.
    fn queue_push_and_signal(&self, msg: Envelope<A>) {
        self.inner.message_queue.push(msg);
        self.inner.recv_task.wake();
    }

    /// Atomically increments the queued-message count while the channel is open.
    ///
    /// Returns the new count, or `None` if the channel is closed (the count is left untouched).
    fn inc_num_messages(&self) -> Option<usize> {
        let mut curr = self.inner.state.load(SeqCst);
        loop {
            let mut state = decode_state(curr);
            if !state.is_open {
                return None;
            }
            state.num_messages += 1;
            let next = encode_state(&state);
            match self
                .inner
                .state
                .compare_exchange(curr, next, SeqCst, SeqCst)
            {
                Ok(_) => {
                    return Some(state.num_messages);
                }
                // Lost a race with another sender or the receiver; retry from the fresh value.
                Err(actual) => curr = actual,
            }
        }
    }

    /// Marks this handle as parked and enqueues its slot on `Inner::parked_queue`.
    fn park(&self) {
        {
            let mut sender = self.sender_task.lock();
            sender.task = None;
            sender.is_parked = true;
        }
        self.inner.parked_queue.push(Arc::clone(&self.sender_task));
        // Only raise the fast-path hint while the channel is open.
        // NOTE(review): presumably because nothing unparks senders after close — confirm.
        let state = decode_state(self.inner.state.load(SeqCst));
        self.maybe_parked.store(state.is_open, Relaxed);
    }

    /// Returns `Ready` if this handle may send, `Pending` while it is still parked.
    ///
    /// With `do_park` and a `cx`, the caller's waker is stored so `SenderTask::notify` wakes it.
    /// Otherwise any stored waker is cleared. (Every call in this file passes `false, None`.)
    fn poll_unparked(&self, do_park: bool, cx: Option<&mut task::Context<'_>>) -> Poll<()> {
        if self.maybe_parked.load(Relaxed) {
            let mut task = self.sender_task.lock();
            if !task.is_parked {
                // Unparked since the hint was set; clear the hint so later calls skip the lock.
                self.maybe_parked.store(false, Relaxed);
                return Poll::Ready(());
            }
            task.task = if do_park {
                cx.map(|cx| cx.waker().clone())
            } else {
                None
            };
            Poll::Pending
        } else {
            Poll::Ready(())
        }
    }
}
impl<A, M> Sender<M> for AddressSender<A>
where
A: Handler<M>,
A::Context: ToEnvelope<A, M>,
M::Result: Send,
M: Message + Send + 'static,
{
fn do_send(&self, msg: M) -> Result<(), SendError<M>> {
self.do_send(msg)
}
fn try_send(&self, msg: M) -> Result<(), SendError<M>> {
self.try_send(msg, true)
}
fn send(&self, msg: M) -> Result<OneshotReceiver<M::Result>, SendError<M>> {
self.send(msg)
}
fn boxed(&self) -> Box<dyn Sender<M> + Sync> {
Box::new(self.clone())
}
fn hash(&self) -> usize {
let hash: *const _ = self.inner.as_ref();
hash as usize
}
fn connected(&self) -> bool {
self.connected()
}
}
impl<A: Actor> Clone for AddressSender<A> {
    /// Registers an additional sender on the same channel.
    ///
    /// The clone shares the channel state but gets its own, unparked parking slot, so
    /// back-pressure is tracked per handle.
    ///
    /// # Panics
    /// Panics if the number of live senders has reached `Inner::max_senders`.
    fn clone(&self) -> AddressSender<A> {
        let mut curr = self.inner.num_senders.load(SeqCst);
        loop {
            // `>=` rather than `==`: growing the buffer lowers `max_senders`, which could leave
            // `curr` above the limit and let an equality check be skipped forever.
            if curr >= self.inner.max_senders() {
                panic!("cannot clone `Sender` -- too many outstanding senders");
            }
            // `compare_exchange(.., SeqCst, SeqCst)` is the documented replacement for the
            // deprecated `compare_and_swap(.., SeqCst)`.
            match self
                .inner
                .num_senders
                .compare_exchange(curr, curr + 1, SeqCst, SeqCst)
            {
                Ok(_) => {
                    return AddressSender {
                        inner: Arc::clone(&self.inner),
                        sender_task: Arc::new(Mutex::new(SenderTask::new())),
                        maybe_parked: Arc::new(AtomicBool::new(false)),
                    };
                }
                // Another thread changed the count first; retry from the value it left.
                Err(actual) => curr = actual,
            }
        }
    }
}
impl<A: Actor> Drop for AddressSender<A> {
    /// Unregisters this sender; the last one to go wakes the receiver so it can observe
    /// that no senders remain.
    fn drop(&mut self) {
        let was_last_sender = self.inner.num_senders.fetch_sub(1, SeqCst) == 1;
        if was_last_sender {
            self.inner.recv_task.wake();
        }
    }
}
/// Two senders are equal when they feed the same channel (same shared state allocation).
impl<A: Actor> PartialEq for AddressSender<A> {
    fn eq(&self, other: &Self) -> bool {
        Arc::as_ptr(&self.inner) == Arc::as_ptr(&other.inner)
    }
}

impl<A: Actor> Eq for AddressSender<A> {}
impl<A: Actor> Hash for AddressSender<A> {
fn hash<H: Hasher>(&self, state: &mut H) {
let hash: *const Inner<A> = self.inner.as_ref();
hash.hash(state);
}
}
impl<A: Actor> WeakAddressSender<A> {
    /// Tries to turn this weak handle back into a strong sender.
    ///
    /// Returns `None` if the shared channel state has already been dropped. On success a new
    /// sender is registered (`num_senders` is incremented), exactly as with
    /// `AddressSenderProducer::sender`.
    ///
    /// NOTE(review): this does not check whether `num_senders` had already reached zero, so it
    /// can register a sender after the receiver saw "no senders" — confirm this is intended.
    ///
    /// # Panics
    /// Panics if the number of live senders has reached its limit.
    pub fn upgrade(&self) -> Option<AddressSender<A>> {
        Weak::upgrade(&self.inner).map(|inner| AddressSenderProducer { inner }.sender())
    }
}
impl<A, M> WeakSender<M> for WeakAddressSender<A>
where
    A: Handler<M>,
    A::Context: ToEnvelope<A, M>,
    M::Result: Send,
    M: Message + Send + 'static,
{
    /// Upgrades to a boxed, type-erased strong sender, or `None` if the channel state is gone.
    fn upgrade(&self) -> Option<Box<dyn Sender<M> + Sync>> {
        // The type path resolves to the inherent `WeakAddressSender::upgrade`, not this trait
        // method. `self` is already `&Self`, so no extra borrow is needed.
        WeakAddressSender::upgrade(self)
            .map(|sender| Box::new(sender) as Box<dyn Sender<M> + Sync>)
    }

    /// Clones this weak handle into a fresh type-erased box.
    fn boxed(&self) -> Box<dyn WeakSender<M> + Sync> {
        Box::new(self.clone())
    }
}
impl<A: Actor> AddressSenderProducer<A> {
    /// Returns `true` while at least one `AddressSender` for this channel is alive.
    pub fn connected(&self) -> bool {
        self.inner.num_senders.load(SeqCst) != 0
    }

    /// Returns the current buffer size: the queued-message count at which a sender parks after
    /// sending (`0` disables parking).
    pub fn capacity(&self) -> usize {
        self.inner.buffer.load(Relaxed)
    }

    /// Sets the buffer size. If it grows, every currently parked sender is unparked.
    pub fn set_capacity(&mut self, cap: usize) {
        // `swap` reads the old value and stores `cap` in one atomic step. A separate
        // load + store could interleave with another `set_capacity` call.
        let prev = self.inner.buffer.swap(cap, Relaxed);
        if cap > prev {
            // NOTE(review): `pop_spin` is `unsafe`; presumably its contract is a single consumer
            // (confirm in `queue.rs`). This handle pops `parked_queue` independently of the
            // `AddressReceiver`, which also pops it (`unpark_one`, `set_capacity`, `Drop`), so the
            // two could race if used from different threads.
            while let Some(task) = unsafe { self.inner.parked_queue.pop_spin() } {
                task.lock().notify();
            }
        }
    }

    /// Registers and returns a new sender for this channel.
    ///
    /// # Panics
    /// Panics if the number of live senders has reached `Inner::max_senders`.
    pub fn sender(&self) -> AddressSender<A> {
        let mut curr = self.inner.num_senders.load(SeqCst);
        loop {
            // `>=` rather than `==`: growing the buffer lowers `max_senders`, which could leave
            // `curr` above the limit and let an equality check be skipped forever.
            if curr >= self.inner.max_senders() {
                panic!("cannot clone `Sender` -- too many outstanding senders");
            }
            // `compare_exchange(.., SeqCst, SeqCst)` is the documented replacement for the
            // deprecated `compare_and_swap(.., SeqCst)`.
            match self
                .inner
                .num_senders
                .compare_exchange(curr, curr + 1, SeqCst, SeqCst)
            {
                Ok(_) => {
                    return AddressSender {
                        inner: Arc::clone(&self.inner),
                        sender_task: Arc::new(Mutex::new(SenderTask::new())),
                        maybe_parked: Arc::new(AtomicBool::new(false)),
                    };
                }
                // Another thread changed the count first; retry from the value it left.
                Err(actual) => curr = actual,
            }
        }
    }
}
impl<A: Actor> AddressReceiver<A> {
    /// Returns `true` while at least one `AddressSender` for this channel is alive.
    pub fn connected(&self) -> bool {
        self.inner.num_senders.load(SeqCst) != 0
    }

    /// Returns the current buffer size: the queued-message count at which a sender parks after
    /// sending (`0` disables parking).
    pub fn capacity(&self) -> usize {
        self.inner.buffer.load(Relaxed)
    }

    /// Sets the buffer size. If it grows, every currently parked sender is unparked.
    pub fn set_capacity(&mut self, cap: usize) {
        // `swap` reads the old value and stores `cap` in one atomic step. A separate
        // load + store could interleave with a producer's `set_capacity`.
        let prev = self.inner.buffer.swap(cap, Relaxed);
        if cap > prev {
            // NOTE(review): `pop_spin` is `unsafe`; presumably it requires a single consumer
            // (confirm in `queue.rs`). `&mut self` serialises this receiver's own pops, but
            // `AddressSenderProducer::set_capacity` also pops `parked_queue`.
            while let Some(task) = unsafe { self.inner.parked_queue.pop_spin() } {
                task.lock().notify();
            }
        }
    }

    /// Registers and returns a new sender for this channel.
    ///
    /// # Panics
    /// Panics if the number of live senders has reached `Inner::max_senders`.
    pub fn sender(&self) -> AddressSender<A> {
        let mut curr = self.inner.num_senders.load(SeqCst);
        loop {
            // `>=` rather than `==`: growing the buffer lowers `max_senders`, which could leave
            // `curr` above the limit and let an equality check be skipped forever.
            if curr >= self.inner.max_senders() {
                panic!("cannot clone `Sender` -- too many outstanding senders");
            }
            // `compare_exchange(.., SeqCst, SeqCst)` is the documented replacement for the
            // deprecated `compare_and_swap(.., SeqCst)`.
            match self
                .inner
                .num_senders
                .compare_exchange(curr, curr + 1, SeqCst, SeqCst)
            {
                Ok(_) => {
                    return AddressSender {
                        inner: Arc::clone(&self.inner),
                        sender_task: Arc::new(Mutex::new(SenderTask::new())),
                        maybe_parked: Arc::new(AtomicBool::new(false)),
                    };
                }
                // Another thread changed the count first; retry from the value it left.
                Err(actual) => curr = actual,
            }
        }
    }

    /// Returns a handle that can mint senders and adjust capacity for this channel.
    pub fn sender_producer(&self) -> AddressSenderProducer<A> {
        AddressSenderProducer {
            inner: self.inner.clone(),
        }
    }

    /// Pops the next envelope if one is queued.
    ///
    /// On success, one parked sender is unparked and the message count is decremented.
    /// On an empty queue, returns `Ready(None)` once the channel is closed and fully drained,
    /// otherwise `Pending`.
    fn next_message(&mut self) -> Poll<Option<Envelope<A>>> {
        // NOTE(review): `pop_spin` is `unsafe`. In this file `message_queue` is popped only here,
        // and `&mut self` rules out concurrent calls from this receiver.
        match unsafe { self.inner.message_queue.pop_spin() } {
            Some(msg) => {
                self.unpark_one();
                self.dec_num_messages();
                Poll::Ready(Some(msg))
            }
            None => {
                let state = decode_state(self.inner.state.load(SeqCst));
                if state.is_closed() {
                    Poll::Ready(None)
                } else {
                    Poll::Pending
                }
            }
        }
    }

    /// Unparks (at most) one sender waiting in `parked_queue`.
    fn unpark_one(&mut self) {
        // NOTE(review): same single-consumer concern as in `set_capacity`.
        if let Some(task) = unsafe { self.inner.parked_queue.pop_spin() } {
            task.lock().notify();
        }
    }

    /// Decrements the queued-message count.
    fn dec_num_messages(&self) {
        // The count lives in the low bits of `state`, so subtracting 1 leaves the open bit alone.
        // Every push in this file follows an `inc_num_messages`, so the count is > 0 here.
        self.inner.state.fetch_sub(1, SeqCst);
    }
}
impl<A: Actor> Stream for AddressReceiver<A> {
    type Item = Envelope<A>;

    /// Yields the next envelope, or `None` once the channel is closed and drained.
    ///
    /// The queue is checked, then the waker is registered, then the queue is checked again.
    /// A message whose push and wake both happened before registration is therefore still seen.
    fn poll_next(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Option<Self::Item>> {
        let receiver = Pin::into_inner(self);
        if let Poll::Ready(item) = receiver.next_message() {
            return Poll::Ready(item);
        }
        receiver.inner.recv_task.register(cx.waker());
        receiver.next_message()
    }
}
impl<A: Actor> Drop for AddressReceiver<A> {
    /// Closes the channel, releases parked senders, and drops every message still in flight.
    fn drop(&mut self) {
        // Clear the open bit so `inc_num_messages` rejects any further sends.
        self.inner.set_closed();
        // Unpark every waiting sender (clears `is_parked` and wakes any stored waker).
        while let Some(task) = unsafe { self.inner.parked_queue.pop_spin() } {
            task.lock().notify();
        }
        // Drain until the state reports closed (not open and zero messages counted).
        // A sender may already have incremented `num_messages` without having pushed its
        // envelope yet (send increments before pushing). The queue then looks empty while the
        // state is not closed, so yield and retry until that envelope arrives.
        loop {
            match self.next_message() {
                Poll::Ready(Some(_)) => {}
                Poll::Ready(None) => break,
                Poll::Pending => {
                    let state = decode_state(self.inner.state.load(SeqCst));
                    if state.is_closed() {
                        break;
                    }
                    thread::yield_now();
                }
            }
        }
    }
}
impl<A: Actor> Inner<A> {
    /// Upper bound on live senders: the state word's capacity minus the current buffer size.
    fn max_senders(&self) -> usize {
        let buffer = self.buffer.load(Relaxed);
        MAX_CAPACITY - buffer
    }

    /// Clears the open bit of `state`; does nothing if it is already cleared.
    fn set_closed(&self) {
        let was_open = decode_state(self.state.load(SeqCst)).is_open;
        if was_open {
            self.state.fetch_and(!OPEN_MASK, SeqCst);
        }
    }
}
// SAFETY: NOTE(review): these impls assert that `Inner<A>` is Send + Sync for *every* actor type
// `A`, overriding the compiler's auto-trait analysis. All fields are accessed through atomics,
// `AtomicWaker`, `parking_lot::Mutex`, or the `Queue` type (not visible here). Soundness therefore
// depends on `Queue`'s own cross-thread guarantees and on `Envelope<A>` being safe to move between
// threads. Confirm both before relying on this.
unsafe impl<A: Actor> Send for Inner<A> {}
unsafe impl<A: Actor> Sync for Inner<A> {}
/// Unpacks a raw `Inner::state` word into its open flag and message count.
fn decode_state(num: usize) -> State {
    let num_messages = num & MAX_CAPACITY;
    let is_open = (num & OPEN_MASK) == OPEN_MASK;
    State {
        is_open,
        num_messages,
    }
}
/// Packs a `State` back into the raw `Inner::state` word (inverse of `decode_state`).
fn encode_state(state: &State) -> usize {
    let open_bit = if state.is_open { OPEN_MASK } else { 0 };
    state.num_messages | open_bit
}
#[cfg(test)]
mod tests {
    use std::{thread, time};

    use super::*;
    use crate::address::queue::PopResult;
    use crate::prelude::*;

    // Minimal actor used only to give the channel a concrete `A`.
    struct Act;
    impl Actor for Act {
        type Context = Context<Act>;
    }

    // Message with a unit result; the handler ignores it.
    struct Ping;
    impl Message for Ping {
        type Result = ();
    }
    impl Handler<Ping> for Act {
        type Result = ();
        fn handle(&mut self, _: Ping, _: &mut Context<Act>) {}
    }

    /// Back-pressure with `buffer == 1`, then checks that growing the capacity unparks senders.
    #[test]
    fn test_cap() {
        System::new().block_on(async {
            let (s1, mut recv) = channel::<Act>(1);
            let s2 = recv.sender();
            // s1's send brings the count to 1 (>= buffer 1), so s1 parks after queueing.
            let arb = Arbiter::new();
            arb.spawn_fn(move || {
                let _ = s1.send(Ping);
            });
            thread::sleep(time::Duration::from_millis(100));
            // s2's first send is queued (count 2) and parks s2. Its second send is then
            // rejected as `Full` because s2 is parked, so the count stays at 2.
            let arb2 = Arbiter::new();
            arb2.spawn_fn(move || {
                let _ = s2.send(Ping);
                let _ = s2.send(Ping);
            });
            thread::sleep(time::Duration::from_millis(100));
            let state = decode_state(recv.inner.state.load(SeqCst));
            assert_eq!(state.num_messages, 2);
            // At least one parked sender slot must be queued; push it back afterwards.
            let p = loop {
                match unsafe { recv.inner.parked_queue.pop() } {
                    PopResult::Data(task) => break Some(task),
                    PopResult::Empty => break None,
                    PopResult::Inconsistent => thread::yield_now(),
                }
            };
            assert!(p.is_some());
            recv.inner.parked_queue.push(p.unwrap());
            // Growing the capacity drains `parked_queue`, notifying every parked slot.
            recv.set_capacity(10);
            thread::sleep(time::Duration::from_millis(100));
            // No messages were consumed, so the count is unchanged...
            let state = decode_state(recv.inner.state.load(SeqCst));
            assert_eq!(state.num_messages, 2);
            // ...but the parked queue is now empty.
            let p = loop {
                match unsafe { recv.inner.parked_queue.pop() } {
                    PopResult::Data(task) => break Some(task),
                    PopResult::Empty => break None,
                    PopResult::Inconsistent => thread::yield_now(),
                }
            };
            assert!(p.is_none());
            System::current().stop();
        });
    }
}