use std::hash::{Hash, Hasher};
use std::pin::Pin;
use std::sync::atomic::Ordering::{Relaxed, SeqCst};
use std::sync::atomic::{AtomicBool, AtomicUsize};
use std::sync::{Arc, Weak};
use std::task::Poll;
use std::{fmt, task};
use std::{thread, usize};
use futures_core::stream::Stream;
use parking_lot::Mutex;
use tokio::sync::oneshot::{channel as sync_channel, Receiver};
use crate::actor::Actor;
use crate::handler::{Handler, Message};
use super::envelope::{Envelope, ToEnvelope};
use super::queue::{PopResult, Queue};
use super::SendError;
/// Object-safe, actor-agnostic interface to a channel sender for one message
/// type `M`.
///
/// `AddressSender<A>` implements it for every `M` that `A` handles, so callers
/// can hold a `Box<dyn Sender<M> + Sync>` without naming the actor type.
pub trait Sender<M>: Send
where
M::Result: Send,
M: Message + Send,
{
/// Enqueue `msg` ignoring back-pressure; in the `AddressSender` impl this
/// only fails with `SendError::Closed` once the receiver is gone.
fn do_send(&self, msg: M) -> Result<(), SendError<M>>;
/// Enqueue `msg` without a result channel, honouring back-pressure
/// (`SendError::Full`) and closure (`SendError::Closed`).
fn try_send(&self, msg: M) -> Result<(), SendError<M>>;
/// Enqueue `msg` and return a oneshot `Receiver` for the handler's result.
fn send(&self, msg: M) -> Result<Receiver<M::Result>, SendError<M>>;
/// Clone this sender into a new boxed trait object.
fn boxed(&self) -> Box<dyn Sender<M> + Sync>;
/// Identity of the underlying channel (the `AddressSender` impl returns the
/// address of the shared channel state), usable to compare senders.
fn hash(&self) -> usize;
/// Whether the receiving side is still open.
fn connected(&self) -> bool;
}
/// Object-safe interface to a non-owning sender handle for message type `M`.
///
/// Implemented by `WeakAddressSender<A>`; it does not keep the channel's shared
/// state alive.
pub(crate) trait WeakSender<M>: Send
where
M::Result: Send,
M: Message + Send,
{
/// Try to obtain a strong boxed `Sender`; `None` once the channel state has
/// been dropped.
fn upgrade(&self) -> Option<Box<dyn Sender<M> + Sync>>;
/// Clone this weak sender into a new boxed trait object.
fn boxed(&self) -> Box<dyn WeakSender<M> + Sync>;
}
/// Sending half of an actor's mailbox channel.
///
/// Every handle (including clones) counts towards `Inner::num_senders` and has
/// its own parking state, so back-pressure on one handle does not block others.
pub struct AddressSender<A: Actor> {
// Channel state shared with the receiver and all other senders.
inner: Arc<Inner<A>>,
// This handle's parking record; pushed onto `Inner::parked_queue` when it parks.
sender_task: Arc<Mutex<SenderTask>>,
// Cheap hint that this handle may be parked; avoids locking `sender_task`
// on every send when it is `false`.
maybe_parked: Arc<AtomicBool>,
}
impl<A: Actor> fmt::Debug for AddressSender<A> {
    /// Shows this handle's parking state; the shared channel internals are omitted.
    fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut builder = fmt.debug_struct("AddressSender");
        builder.field("sender_task", &self.sender_task);
        builder.field("maybe_parked", &self.maybe_parked);
        builder.finish()
    }
}
/// Non-owning sender handle: holds only a `Weak` reference to the channel
/// state and does not count towards `Inner::num_senders` until upgraded.
pub struct WeakAddressSender<A: Actor> {
inner: Weak<Inner<A>>,
}
impl<A: Actor> Clone for WeakAddressSender<A> {
fn clone(&self) -> WeakAddressSender<A> {
WeakAddressSender {
inner: self.inner.clone(),
}
}
}
impl<A: Actor> fmt::Debug for WeakAddressSender<A> {
    /// Prints just the type name; no fields are exposed.
    fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut out = fmt.debug_struct("WeakAddressSender");
        out.finish()
    }
}
/// Marker trait bundling `Send + Sync + Clone`.
/// NOTE(review): nothing in this file implements or references it — possibly a
/// leftover compile-time kind assertion; confirm it is still needed.
trait AssertKinds: Send + Sync + Clone {}
/// Receiving half of an actor's mailbox; yields queued `Envelope`s as a `Stream`.
pub struct AddressReceiver<A: Actor> {
inner: Arc<Inner<A>>,
}
/// Shares the channel state so it can mint new `AddressSender`s or change the
/// capacity, without itself being counted in `Inner::num_senders`.
pub struct AddressSenderProducer<A: Actor> {
inner: Arc<Inner<A>>,
}
/// State shared by all senders, producers and the receiver of one channel.
struct Inner<A: Actor> {
// Queued-message count at which a sender parks itself; 0 disables parking.
buffer: AtomicUsize,
// Open flag plus queued-message count, packed by `encode_state`.
state: AtomicUsize,
// Messages waiting to be delivered to the actor.
message_queue: Queue<Envelope<A>>,
// Sender handles that parked because the buffer limit was reached.
parked_queue: Queue<Arc<Mutex<SenderTask>>>,
// Number of live `AddressSender` handles.
num_senders: AtomicUsize,
// Receiver's waker plus the "signalled since last poll" flag.
recv_task: Mutex<ReceiverTask>,
}
/// Decoded form of `Inner::state`.
#[derive(Debug, Clone, Copy)]
struct State {
// `false` once the receiver has been dropped.
is_open: bool,
// Messages counted by senders but not yet consumed by the receiver.
num_messages: usize,
}
/// Receiver-side wake-up bookkeeping.
#[derive(Debug)]
struct ReceiverTask {
// Set by `AddressSender::signal`, cleared by `AddressReceiver::try_park`;
// suppresses redundant wake-ups in between.
unparked: bool,
// Waker stored by the receiver when it finds the queue empty.
task: Option<task::Waker>,
}
/// Outcome of `AddressReceiver::try_park`.
enum TryPark {
// Waker stored; the receiver should return `Poll::Pending`.
Parked,
// A sender signalled since the last poll; the receiver should retry the queue.
NotEmpty,
}
// Highest bit of `usize`: marks the channel as open inside `Inner::state`.
const OPEN_MASK: usize = usize::MAX - (usize::MAX >> 1);
// A fresh channel is open with zero messages.
const INIT_STATE: usize = OPEN_MASK;
// Every bit except the open bit: the largest encodable message count.
const MAX_CAPACITY: usize = !(OPEN_MASK);
// Exclusive upper bound for the `buffer` argument of `channel`.
const MAX_BUFFER: usize = MAX_CAPACITY >> 1;
/// Per-sender parking record, shared between the sender and `parked_queue`.
#[derive(Debug)]
struct SenderTask {
// Waker to notify when unparked (only stored when `poll_unparked` is asked to park).
task: Option<task::Waker>,
// `true` from `AddressSender::park` until `SenderTask::notify`.
is_parked: bool,
}
impl SenderTask {
fn new() -> Self {
SenderTask {
task: None,
is_parked: false,
}
}
fn notify(&mut self) -> bool {
self.is_parked = false;
if let Some(task) = self.task.take() {
task.wake();
true
} else {
false
}
}
}
/// Creates a mailbox channel for actor `A`, returning the first sender and the receiver.
///
/// `buffer` is the queued-message count at which senders park themselves
/// (0 disables parking). The channel starts open, empty, with one sender.
///
/// # Panics
///
/// Panics if `buffer >= MAX_BUFFER`.
pub fn channel<A: Actor>(buffer: usize) -> (AddressSender<A>, AddressReceiver<A>) {
    assert!(buffer < MAX_BUFFER, "requested buffer size too large");

    let recv_task = Mutex::new(ReceiverTask {
        unparked: false,
        task: None,
    });
    let shared = Arc::new(Inner {
        buffer: AtomicUsize::new(buffer),
        state: AtomicUsize::new(INIT_STATE),
        message_queue: Queue::new(),
        parked_queue: Queue::new(),
        // The sender returned below is the first one.
        num_senders: AtomicUsize::new(1),
        recv_task,
    });

    let sender = AddressSender {
        inner: Arc::clone(&shared),
        sender_task: Arc::new(Mutex::new(SenderTask::new())),
        maybe_parked: Arc::new(AtomicBool::new(false)),
    };
    (sender, AddressReceiver { inner: shared })
}
/// Sender-side operations: enqueueing, back-pressure parking and wake-ups.
impl<A: Actor> AddressSender<A> {
/// Returns `true` while the channel's open bit is set; `AddressReceiver`'s
/// `Drop` clears it.
pub fn connected(&self) -> bool {
let curr = self.inner.state.load(SeqCst);
let state = decode_state(curr);
state.is_open
}
/// Sends `msg` and returns a oneshot `Receiver` for the handler's result.
///
/// # Errors
///
/// * `SendError::Full(msg)` — this handle is still parked from an earlier send.
/// * `SendError::Closed(msg)` — the receiver has closed the channel.
///
/// If this message brings the count to a non-zero buffer limit, the handle
/// parks itself. The message is still enqueued, but later `send`/`try_send`
/// calls on this handle return `Full` until the receiver notifies it.
pub fn send<M>(&self, msg: M) -> Result<Receiver<M::Result>, SendError<M>>
where
A: Handler<M>,
A::Context: ToEnvelope<A, M>,
M::Result: Send,
M: Message + Send,
{
// Refuse to send while this handle is still parked.
if !self.poll_unparked(false, None).is_ready() {
return Err(SendError::Full(msg));
}
let park_self = match self.inc_num_messages() {
Some(park_self) => park_self,
None => return Err(SendError::Closed(msg)),
};
if park_self {
self.park();
}
let (tx, rx) = sync_channel();
let env = <A::Context as ToEnvelope<A, M>>::pack(msg, Some(tx));
self.queue_push_and_signal(env);
Ok(rx)
}
/// Like `send` but without a result channel. The handle parks itself on
/// reaching the buffer limit only when `park` is `true`.
///
/// # Errors
///
/// Same as `send`: `Full` while parked, `Closed` after the receiver is gone.
pub fn try_send<M>(&self, msg: M, park: bool) -> Result<(), SendError<M>>
where
A: Handler<M>,
<A as Actor>::Context: ToEnvelope<A, M>,
M::Result: Send,
M: Message + Send + 'static,
{
if !self.poll_unparked(false, None).is_ready() {
return Err(SendError::Full(msg));
}
let park_self = match self.inc_num_messages() {
Some(park_self) => park_self,
None => return Err(SendError::Closed(msg)),
};
if park_self && park {
self.park();
}
let env = <A::Context as ToEnvelope<A, M>>::pack(msg, None);
self.queue_push_and_signal(env);
Ok(())
}
/// Enqueues `msg` unconditionally. It neither checks nor triggers parking,
/// so the buffer limit is ignored.
///
/// # Errors
///
/// `SendError::Closed(msg)` once the receiver is gone.
pub fn do_send<M>(&self, msg: M) -> Result<(), SendError<M>>
where
A: Handler<M>,
<A as Actor>::Context: ToEnvelope<A, M>,
M::Result: Send,
M: Message + Send,
{
if self.inc_num_messages().is_none() {
Err(SendError::Closed(msg))
} else {
let env = <A::Context as ToEnvelope<A, M>>::pack(msg, None);
self.queue_push_and_signal(env);
Ok(())
}
}
/// Returns a `WeakAddressSender` that does not keep the channel state alive.
pub fn downgrade(&self) -> WeakAddressSender<A> {
WeakAddressSender {
inner: Arc::downgrade(&self.inner),
}
}
/// Pushes an envelope onto the message queue and wakes the receiver.
fn queue_push_and_signal(&self, msg: Envelope<A>) {
self.inner.message_queue.push(msg);
self.signal();
}
/// Atomically increments the queued-message count.
///
/// Returns `None` if the channel is closed. Otherwise returns `Some(park_self)`,
/// where `park_self` is `true` when a non-zero buffer limit has been reached
/// or exceeded by this message.
fn inc_num_messages(&self) -> Option<bool> {
let mut curr = self.inner.state.load(SeqCst);
loop {
let mut state = decode_state(curr);
if !state.is_open {
return None;
}
state.num_messages += 1;
let next = encode_state(&state);
match self
.inner
.state
.compare_exchange(curr, next, SeqCst, SeqCst)
{
Ok(_) => {
let buffer = self.inner.buffer.load(Relaxed);
let park_self = buffer != 0 && state.num_messages >= buffer;
return Some(park_self);
}
// Lost a race with another sender or the receiver; retry with the fresh value.
Err(actual) => curr = actual,
}
}
}
/// Wakes the receiver task, at most once per receiver poll cycle.
///
/// The `unparked` flag stays set until `AddressReceiver::try_park` consumes
/// it, so repeated signals in between do not wake the receiver again.
fn signal(&self) {
let task = {
let mut recv_task = self.inner.recv_task.lock();
if recv_task.unparked {
return;
}
recv_task.unparked = true;
recv_task.task.take()
};
// Wake outside the lock.
if let Some(task) = task {
task.wake();
}
}
/// Marks this handle as parked and registers it on `parked_queue` so the
/// receiver can notify it later.
fn park(&self) {
{
let mut sender = self.sender_task.lock();
sender.task = None;
sender.is_parked = true;
}
self.inner.parked_queue.push(Arc::clone(&self.sender_task));
// Only raise the hint while the channel is open. Presumably this is because
// a closed receiver will never pop this task again, and a stuck hint would
// turn later sends into `Full` instead of `Closed`.
let state = decode_state(self.inner.state.load(SeqCst));
self.maybe_parked.store(state.is_open, Relaxed);
}
/// Returns `Ready` if this handle is not parked, `Pending` otherwise.
///
/// When the receiver has unparked the handle, this clears the
/// `maybe_parked` hint. While still parked, it stores the context's waker
/// if `do_park` is `true`; otherwise it clears any stored waker. All callers
/// in this file pass `(false, None)`.
fn poll_unparked(
&self,
do_park: bool,
cx: Option<&mut task::Context<'_>>,
) -> Poll<()> {
if self.maybe_parked.load(Relaxed) {
let mut task = self.sender_task.lock();
if !task.is_parked {
self.maybe_parked.store(false, Relaxed);
return Poll::Ready(());
}
task.task = if do_park {
cx.map(|cx| cx.waker().clone())
} else {
None
};
Poll::Pending
} else {
Poll::Ready(())
}
}
}
impl<A, M> Sender<M> for AddressSender<A>
where
A: Handler<M>,
A::Context: ToEnvelope<A, M>,
M::Result: Send,
M: Message + Send + 'static,
{
fn do_send(&self, msg: M) -> Result<(), SendError<M>> {
self.do_send(msg)
}
fn try_send(&self, msg: M) -> Result<(), SendError<M>> {
self.try_send(msg, true)
}
fn send(&self, msg: M) -> Result<Receiver<M::Result>, SendError<M>> {
self.send(msg)
}
fn boxed(&self) -> Box<dyn Sender<M> + Sync> {
Box::new(self.clone())
}
fn hash(&self) -> usize {
let hash: *const _ = self.inner.as_ref();
hash as usize
}
fn connected(&self) -> bool {
self.connected()
}
}
impl<A: Actor> Clone for AddressSender<A> {
    /// Registers and returns another sender handle for the same channel.
    ///
    /// The clone increments `Inner::num_senders` and gets its own, unparked
    /// parking state. Back-pressure on one handle therefore never blocks another.
    ///
    /// # Panics
    ///
    /// Panics if the number of outstanding senders has reached
    /// `Inner::max_senders()`.
    fn clone(&self) -> AddressSender<A> {
        let mut curr = self.inner.num_senders.load(SeqCst);
        loop {
            // `>=` rather than `==`: `set_capacity` can grow the buffer and so
            // shrink `max_senders()` below the current count.
            if curr >= self.inner.max_senders() {
                panic!("cannot clone `Sender` -- too many outstanding senders");
            }
            // `compare_exchange` replaces the deprecated `compare_and_swap`
            // with identical (strong, SeqCst) semantics.
            match self
                .inner
                .num_senders
                .compare_exchange(curr, curr + 1, SeqCst, SeqCst)
            {
                Ok(_) => {
                    return AddressSender {
                        inner: Arc::clone(&self.inner),
                        sender_task: Arc::new(Mutex::new(SenderTask::new())),
                        maybe_parked: Arc::new(AtomicBool::new(false)),
                    };
                }
                // Another thread changed the count; retry with the observed value.
                Err(actual) => curr = actual,
            }
        }
    }
}
impl<A: Actor> Drop for AddressSender<A> {
    /// Unregisters this handle. The last sender to go wakes the receiver,
    /// presumably so it can notice (via `connected()`) that no senders remain.
    fn drop(&mut self) {
        // `fetch_sub` yields the count before the decrement.
        let was_last = self.inner.num_senders.fetch_sub(1, SeqCst) == 1;
        if was_last {
            self.signal();
        }
    }
}
impl<A: Actor> PartialEq for AddressSender<A> {
    /// Senders are equal exactly when they share the same channel state.
    fn eq(&self, other: &Self) -> bool {
        let lhs: *const Inner<A> = &*self.inner;
        let rhs: *const Inner<A> = &*other.inner;
        lhs == rhs
    }
}
impl<A: Actor> Eq for AddressSender<A> {}
impl<A: Actor> Hash for AddressSender<A> {
    /// Hashes the address of the shared channel state, consistent with `eq`.
    fn hash<H: Hasher>(&self, state: &mut H) {
        let addr = &*self.inner as *const Inner<A>;
        addr.hash(state);
    }
}
impl<A: Actor> WeakAddressSender<A> {
    /// Attempts to obtain a strong [`AddressSender`] for this channel.
    ///
    /// Returns `None` if the shared channel state has already been dropped.
    /// On success a new sender handle is registered (the sender count is
    /// incremented) via `AddressSenderProducer::sender`.
    ///
    /// # Panics
    ///
    /// Panics under the same condition as `AddressSenderProducer::sender`
    /// (too many outstanding senders).
    pub fn upgrade(&self) -> Option<AddressSender<A>> {
        Weak::upgrade(&self.inner).map(|inner| AddressSenderProducer { inner }.sender())
    }
}
impl<A, M> WeakSender<M> for WeakAddressSender<A>
where
    A: Handler<M>,
    A::Context: ToEnvelope<A, M>,
    M::Result: Send,
    M: Message + Send + 'static,
{
    /// Upgrades via the inherent `WeakAddressSender::upgrade` and boxes the
    /// resulting sender; `None` if the channel state is gone.
    fn upgrade(&self) -> Option<Box<dyn Sender<M> + Sync>> {
        // Pass `self` directly; the original `&self` produced a needless `&&Self`.
        WeakAddressSender::upgrade(self)
            .map(|sender| Box::new(sender) as Box<dyn Sender<M> + Sync>)
    }

    /// Boxes a clone of this weak handle.
    fn boxed(&self) -> Box<dyn WeakSender<M> + Sync> {
        Box::new(self.clone())
    }
}
impl<A: Actor> AddressSenderProducer<A> {
    /// Returns `true` while at least one `AddressSender` for this channel is alive.
    pub fn connected(&self) -> bool {
        self.inner.num_senders.load(SeqCst) != 0
    }

    /// Returns the queued-message count at which senders park (0 means never).
    pub fn capacity(&self) -> usize {
        self.inner.buffer.load(Relaxed)
    }

    /// Sets the buffer size.
    ///
    /// When the capacity grows, every sender waiting in the parked queue is
    /// popped and notified so it may resume sending.
    pub fn set_capacity(&mut self, cap: usize) {
        let buffer = self.inner.buffer.load(Relaxed);
        self.inner.buffer.store(cap, Relaxed);
        if cap > buffer {
            loop {
                // NOTE(review): `Queue::pop` is `unsafe`, presumably because it
                // requires a single consumer. The receiver also pops
                // `parked_queue` (`unpark_one`, `set_capacity`, `Drop`). Confirm
                // that popping from a producer on another thread is sound.
                match unsafe { self.inner.parked_queue.pop() } {
                    PopResult::Data(task) => {
                        task.lock().notify();
                    }
                    PopResult::Empty => return,
                    // A push is half-complete; back off and retry.
                    PopResult::Inconsistent => thread::yield_now(),
                }
            }
        }
    }

    /// Registers and returns a new sender handle for this channel.
    ///
    /// # Panics
    ///
    /// Panics if the number of outstanding senders has reached
    /// `Inner::max_senders()`.
    pub fn sender(&self) -> AddressSender<A> {
        let mut curr = self.inner.num_senders.load(SeqCst);
        loop {
            // `>=` guards against `max_senders()` shrinking after `set_capacity`.
            if curr >= self.inner.max_senders() {
                panic!("cannot clone `Sender` -- too many outstanding senders");
            }
            // `compare_exchange` replaces the deprecated `compare_and_swap`
            // with identical (strong, SeqCst) semantics.
            match self
                .inner
                .num_senders
                .compare_exchange(curr, curr + 1, SeqCst, SeqCst)
            {
                Ok(_) => {
                    return AddressSender {
                        inner: Arc::clone(&self.inner),
                        sender_task: Arc::new(Mutex::new(SenderTask::new())),
                        maybe_parked: Arc::new(AtomicBool::new(false)),
                    };
                }
                Err(actual) => curr = actual,
            }
        }
    }
}
impl<A: Actor> AddressReceiver<A> {
    /// Returns `true` while at least one `AddressSender` for this channel is alive.
    pub fn connected(&self) -> bool {
        self.inner.num_senders.load(SeqCst) != 0
    }

    /// Returns the queued-message count at which senders park (0 means never).
    pub fn capacity(&self) -> usize {
        self.inner.buffer.load(Relaxed)
    }

    /// Sets the buffer size.
    ///
    /// When the capacity grows, every sender waiting in the parked queue is
    /// popped and notified so it may resume sending.
    pub fn set_capacity(&mut self, cap: usize) {
        let buffer = self.inner.buffer.load(Relaxed);
        self.inner.buffer.store(cap, Relaxed);
        if cap > buffer {
            loop {
                // NOTE(review): relies on `Queue::pop`'s (unseen) safety contract.
                match unsafe { self.inner.parked_queue.pop() } {
                    PopResult::Data(task) => {
                        task.lock().notify();
                    }
                    PopResult::Empty => return,
                    // A push is half-complete; back off and retry.
                    PopResult::Inconsistent => thread::yield_now(),
                }
            }
        }
    }

    /// Registers and returns a new sender handle for this channel.
    ///
    /// # Panics
    ///
    /// Panics if the number of outstanding senders has reached
    /// `Inner::max_senders()`.
    pub fn sender(&self) -> AddressSender<A> {
        let mut curr = self.inner.num_senders.load(SeqCst);
        loop {
            // `>=` guards against `max_senders()` shrinking after `set_capacity`.
            if curr >= self.inner.max_senders() {
                panic!("cannot clone `Sender` -- too many outstanding senders");
            }
            // `compare_exchange` replaces the deprecated `compare_and_swap`
            // with identical (strong, SeqCst) semantics.
            match self
                .inner
                .num_senders
                .compare_exchange(curr, curr + 1, SeqCst, SeqCst)
            {
                Ok(_) => {
                    return AddressSender {
                        inner: Arc::clone(&self.inner),
                        sender_task: Arc::new(Mutex::new(SenderTask::new())),
                        maybe_parked: Arc::new(AtomicBool::new(false)),
                    };
                }
                Err(actual) => curr = actual,
            }
        }
    }

    /// Returns a producer sharing this channel. Creating it does not change
    /// the sender count.
    pub fn sender_producer(&self) -> AddressSenderProducer<A> {
        AddressSenderProducer {
            inner: Arc::clone(&self.inner),
        }
    }

    /// Pops the next envelope.
    ///
    /// Returns `Pending` when the queue is empty and never returns `Ready(None)`.
    /// Spins (yielding the thread) while a concurrent push is half-complete.
    fn next_message(&mut self) -> Poll<Option<Envelope<A>>> {
        loop {
            // NOTE(review): assumes this receiver is the queue's single consumer.
            match unsafe { self.inner.message_queue.pop() } {
                PopResult::Data(msg) => return Poll::Ready(Some(msg)),
                PopResult::Empty => return Poll::Pending,
                PopResult::Inconsistent => thread::yield_now(),
            }
        }
    }

    /// Unparks parked senders until one with a stored waker has been woken,
    /// or the parked queue is empty.
    ///
    /// Every popped sender has its parked flag cleared by `notify`, including
    /// those without a waker.
    fn unpark_one(&mut self) {
        loop {
            match unsafe { self.inner.parked_queue.pop() } {
                PopResult::Data(task) => {
                    if task.lock().notify() {
                        return;
                    }
                }
                PopResult::Empty => return,
                PopResult::Inconsistent => thread::yield_now(),
            }
        }
    }

    /// Returns `NotEmpty` if a sender has signalled since the last check,
    /// consuming that signal. Otherwise stores the context's waker and
    /// returns `Parked`.
    fn try_park(&self, cx: &mut task::Context<'_>) -> TryPark {
        let mut recv_task = self.inner.recv_task.lock();
        if recv_task.unparked {
            recv_task.unparked = false;
            return TryPark::NotEmpty;
        }
        recv_task.task = Some(cx.waker().clone());
        TryPark::Parked
    }

    /// Atomically decrements the queued-message count, preserving the open bit.
    ///
    /// Only called after a message was popped. Senders increment the count
    /// before pushing, so it is at least 1 here.
    fn dec_num_messages(&self) {
        let mut curr = self.inner.state.load(SeqCst);
        loop {
            let mut state = decode_state(curr);
            state.num_messages -= 1;
            let next = encode_state(&state);
            match self
                .inner
                .state
                .compare_exchange(curr, next, SeqCst, SeqCst)
            {
                Ok(_) => break,
                Err(actual) => curr = actual,
            }
        }
    }
}
impl<A: Actor> Stream for AddressReceiver<A> {
    type Item = Envelope<A>;

    /// Yields the next queued envelope.
    ///
    /// After taking a message it unparks one waiting sender and decrements the
    /// message count. On an empty queue it registers the waker, unless a
    /// sender signalled in the meantime, in which case it retries.
    fn poll_next(
        self: Pin<&mut Self>,
        cx: &mut task::Context<'_>,
    ) -> Poll<Option<Self::Item>> {
        let receiver = self.get_mut();
        loop {
            if let Poll::Ready(envelope) = receiver.next_message() {
                receiver.unpark_one();
                receiver.dec_num_messages();
                return Poll::Ready(envelope);
            }
            // Queue looked empty: park, or go round again if we were signalled.
            if let TryPark::Parked = receiver.try_park(cx) {
                return Poll::Pending;
            }
        }
    }
}
impl<A: Actor> Drop for AddressReceiver<A> {
fn drop(&mut self) {
let mut curr = self.inner.state.load(SeqCst);
loop {
let mut state = decode_state(curr);
if !state.is_open {
break;
}
state.is_open = false;
let next = encode_state(&state);
match self
.inner
.state
.compare_exchange(curr, next, SeqCst, SeqCst)
{
Ok(_) => break,
Err(actual) => curr = actual,
}
}
loop {
match unsafe { self.inner.parked_queue.pop() } {
PopResult::Data(task) => {
task.lock().notify();
}
PopResult::Empty => break,
PopResult::Inconsistent => thread::yield_now(),
}
}
while self.next_message().is_ready() {
}
}
}
impl<A: Actor> Inner<A> {
    /// Upper bound on live sender handles: the encodable capacity minus the
    /// current buffer size.
    fn max_senders(&self) -> usize {
        let buffer = self.buffer.load(Relaxed);
        MAX_CAPACITY - buffer
    }
}
// SAFETY: NOTE(review): these assert `Inner<A>` is `Send + Sync` for every actor
// `A`, without bounds on `Envelope<A>` or `Queue`. Soundness depends on the
// (unseen) `Queue` implementation and on `pop` being called under its contract.
// Confirm, especially since `AddressSenderProducer::set_capacity` and the
// receiver can both pop `parked_queue`.
unsafe impl<A: Actor> Send for Inner<A> {}
unsafe impl<A: Actor> Sync for Inner<A> {}
/// Unpacks `Inner::state`: the top bit is the open flag, the remaining bits
/// are the message count.
fn decode_state(num: usize) -> State {
    // `OPEN_MASK` is a single bit, so a non-zero AND means the flag is set.
    let is_open = (num & OPEN_MASK) != 0;
    let num_messages = num & MAX_CAPACITY;
    State {
        is_open,
        num_messages,
    }
}
/// Packs a `State` back into the `Inner::state` word (inverse of `decode_state`).
fn encode_state(state: &State) -> usize {
    let open_bit = if state.is_open { OPEN_MASK } else { 0 };
    state.num_messages | open_bit
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::prelude::*;
    use std::{thread, time};

    struct Act;

    impl Actor for Act {
        type Context = Context<Act>;
    }

    struct Ping;

    impl Message for Ping {
        type Result = ();
    }

    impl Handler<Ping> for Act {
        type Result = ();
        fn handle(&mut self, _: Ping, _: &mut Context<Act>) {}
    }

    /// Senders park once the one-slot buffer fills, and a parked sender's
    /// further sends are rejected. Growing the capacity notifies and removes
    /// every parked sender.
    #[test]
    fn test_cap() {
        System::run(|| {
            // Gives the arbiter threads time to run.
            let settle = || thread::sleep(time::Duration::from_millis(100));
            // Pops one parked sender task, spinning past half-complete pushes.
            let pop_parked = |rx: &AddressReceiver<Act>| loop {
                match unsafe { rx.inner.parked_queue.pop() } {
                    PopResult::Data(task) => break Some(task),
                    PopResult::Empty => break None,
                    PopResult::Inconsistent => thread::yield_now(),
                }
            };

            let (s1, mut recv) = channel::<Act>(1);
            let s2 = recv.sender();

            // First sender reaches the buffer limit and parks itself.
            let arb = Arbiter::new();
            arb.exec_fn(move || {
                let _ = s1.send(Ping);
            });
            settle();

            // Second sender: its first message is queued (and it parks); the
            // second send is refused while parked.
            let arb2 = Arbiter::new();
            arb2.exec_fn(move || {
                let _ = s2.send(Ping);
                let _ = s2.send(Ping);
            });
            settle();

            assert_eq!(decode_state(recv.inner.state.load(SeqCst)).num_messages, 2);
            let parked = pop_parked(&recv);
            assert!(parked.is_some());
            recv.inner.parked_queue.push(parked.unwrap());

            // Growing the capacity drains and notifies the parked queue.
            recv.set_capacity(10);
            settle();

            assert_eq!(decode_state(recv.inner.state.load(SeqCst)).num_messages, 2);
            assert!(pop_parked(&recv).is_none());

            System::current().stop();
        })
        .unwrap();
    }
}