#![cfg_attr(not(feature = "sync"), allow(dead_code, unreachable_pub))]
use crate::loom::cell::UnsafeCell;
use crate::loom::sync::atomic::AtomicUsize;
use crate::loom::sync::Arc;
#[cfg(all(tokio_unstable, feature = "tracing"))]
use crate::util::trace;
use std::fmt;
use std::future::Future;
use std::mem::MaybeUninit;
use std::pin::Pin;
use std::sync::atomic::Ordering::{self, AcqRel, Acquire};
use std::task::Poll::{Pending, Ready};
use std::task::{Context, Poll, Waker};
/// Sending half of a single-use channel created by [`channel`].
///
/// A value is delivered with [`Sender::send`], which consumes the sender.
/// Dropping the sender without sending also completes the channel; a receiver
/// that is still waiting then resolves to an error instead of a value.
#[derive(Debug)]
pub struct Sender<T> {
    // Shared channel state. `send` takes it (leaving `None`) so that the
    // `Drop` impl does not complete the channel a second time.
    inner: Option<Arc<Inner<T>>>,
    // Span under which this channel's state-update trace events are emitted.
    #[cfg(all(tokio_unstable, feature = "tracing"))]
    resource_span: tracing::Span,
}
/// Receiving half of a single-use channel created by [`channel`].
///
/// The receiver is itself a future (see its `Future` impl) that resolves to the
/// sent value, or to [`RecvError`] if the channel completed without a value or
/// was closed first. [`Receiver::try_recv`] checks without waiting.
#[derive(Debug)]
pub struct Receiver<T> {
    // Shared channel state. Set to `None` once a final result has been handed
    // out (by `try_recv` for anything but `Empty`, or by `poll` on `Ok`).
    inner: Option<Arc<Inner<T>>>,
    // Span for the channel resource itself.
    #[cfg(all(tokio_unstable, feature = "tracing"))]
    resource_span: tracing::Span,
    // Span describing the "Receiver::await" async operation.
    #[cfg(all(tokio_unstable, feature = "tracing"))]
    async_op_span: tracing::Span,
    // Child span entered on each poll of the async operation.
    #[cfg(all(tokio_unstable, feature = "tracing"))]
    async_op_poll_span: tracing::Span,
}
pub mod error {
    //! Error types produced by the oneshot receiver.
    use std::fmt;

    /// Error produced when awaiting a receiver (or calling `blocking_recv`)
    /// finds the channel closed with no value available: either the sender was
    /// dropped without sending, or the receiver closed before a value arrived.
    #[derive(Debug, Eq, PartialEq, Clone)]
    pub struct RecvError(pub(super) ());

    /// Error produced by `Receiver::try_recv`.
    #[derive(Debug, Eq, PartialEq, Clone)]
    pub enum TryRecvError {
        /// No value has been sent yet, but the channel is still open.
        Empty,
        /// No value can ever be received: the sender is gone without sending,
        /// the channel was closed, or the result was already taken.
        Closed,
    }

    impl fmt::Display for RecvError {
        fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
            f.write_str("channel closed")
        }
    }

    impl std::error::Error for RecvError {}

    impl fmt::Display for TryRecvError {
        fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
            // Pick the message first, then emit it in one place.
            let message = match self {
                TryRecvError::Empty => "channel empty",
                TryRecvError::Closed => "channel closed",
            };
            f.write_str(message)
        }
    }

    impl std::error::Error for TryRecvError {}
}
use self::error::*;
/// State shared by a `Sender` and its `Receiver` through an `Arc`.
struct Inner<T> {
    // Bit flags (`RX_TASK_SET`, `VALUE_SENT`, `CLOSED`, `TX_TASK_SET`)
    // describing the channel's progress; read/updated through `State`.
    state: AtomicUsize,
    // Slot for the sent value. Written by `Sender::send`; the receiver only
    // reads it after observing `VALUE_SENT`.
    value: UnsafeCell<Option<T>>,
    // Waker of the task in `Sender::poll_closed`. Holds an initialized waker
    // whenever `TX_TASK_SET` is set.
    tx_task: Task,
    // Waker of the task polling the `Receiver`. Holds an initialized waker
    // whenever `RX_TASK_SET` is set.
    rx_task: Task,
}
/// Storage for a single, possibly-uninitialized `Waker`.
///
/// The slot does not track whether it is initialized; callers track that with
/// the `RX_TASK_SET` / `TX_TASK_SET` bits of `Inner::state`.
struct Task(UnsafeCell<MaybeUninit<Waker>>);

impl Task {
    /// Returns `true` if the stored waker would wake the same task as
    /// `cx.waker()` (as reported by `Waker::will_wake`).
    ///
    /// # Safety
    ///
    /// The slot must hold an initialized waker, and no `set_task` or
    /// `drop_task` may run concurrently.
    unsafe fn will_wake(&self, cx: &mut Context<'_>) -> bool {
        self.with_task(|w| w.will_wake(cx.waker()))
    }

    /// Calls `f` with a shared reference to the stored waker and returns its
    /// result.
    ///
    /// # Safety
    ///
    /// The slot must hold an initialized waker, and no `set_task` or
    /// `drop_task` may run concurrently.
    unsafe fn with_task<F, R>(&self, f: F) -> R
    where
        F: FnOnce(&Waker) -> R,
    {
        self.0.with(|ptr| {
            let waker: *const Waker = (*ptr).as_ptr();
            f(&*waker)
        })
    }

    /// Drops the stored waker in place. Afterwards the slot is logically
    /// uninitialized again.
    ///
    /// # Safety
    ///
    /// The slot must hold an initialized waker, it must not be dropped twice,
    /// and the caller must have exclusive access (no concurrent reader).
    unsafe fn drop_task(&self) {
        self.0.with_mut(|ptr| {
            let ptr: *mut Waker = (*ptr).as_mut_ptr();
            ptr.drop_in_place();
        });
    }

    /// Stores a clone of `cx.waker()` in the slot.
    ///
    /// `ptr::write` does not drop the previous contents, so any waker stored
    /// earlier must already have been released with `drop_task` (or it leaks).
    ///
    /// # Safety
    ///
    /// The caller must have exclusive access to the slot (no concurrent
    /// reader).
    unsafe fn set_task(&self, cx: &mut Context<'_>) {
        self.0.with_mut(|ptr| {
            let ptr: *mut Waker = (*ptr).as_mut_ptr();
            ptr.write(cx.waker().clone());
        });
    }
}
/// Copyable snapshot of the bit flags stored in `Inner::state`.
#[derive(Clone, Copy)]
struct State(usize);
/// Creates a new single-use channel and returns its `(Sender, Receiver)` halves.
///
/// Both halves share one `Inner` allocation behind an `Arc`, starting with no
/// state flags set and an empty value slot.
///
/// When built with `tokio_unstable` and the `tracing` feature, a
/// `runtime.resource` span is created with the caller's source location (which
/// is why the function is `#[track_caller]`), initial state-update events are
/// recorded, and the receiver gets async-op spans for instrumenting `.await`.
#[track_caller]
pub fn channel<T>() -> (Sender<T>, Receiver<T>) {
    #[cfg(all(tokio_unstable, feature = "tracing"))]
    let resource_span = {
        let location = std::panic::Location::caller();
        let resource_span = tracing::trace_span!(
            "runtime.resource",
            concrete_type = "Sender|Receiver",
            kind = "Sync",
            loc.file = location.file(),
            loc.line = location.line(),
            loc.col = location.column(),
        );
        // Record the initial state: neither half dropped, nothing sent or
        // received yet.
        resource_span.in_scope(|| {
            tracing::trace!(
                target: "runtime::resource::state_update",
                tx_dropped = false,
                tx_dropped.op = "override",
            )
        });
        resource_span.in_scope(|| {
            tracing::trace!(
                target: "runtime::resource::state_update",
                rx_dropped = false,
                rx_dropped.op = "override",
            )
        });
        resource_span.in_scope(|| {
            tracing::trace!(
                target: "runtime::resource::state_update",
                value_sent = false,
                value_sent.op = "override",
            )
        });
        resource_span.in_scope(|| {
            tracing::trace!(
                target: "runtime::resource::state_update",
                value_received = false,
                value_received.op = "override",
            )
        });
        resource_span
    };
    // Wakers start uninitialized; the `*_TASK_SET` bits (all clear here)
    // record when a slot holds one.
    let inner = Arc::new(Inner {
        state: AtomicUsize::new(State::new().as_usize()),
        value: UnsafeCell::new(None),
        tx_task: Task(UnsafeCell::new(MaybeUninit::uninit())),
        rx_task: Task(UnsafeCell::new(MaybeUninit::uninit())),
    });
    let tx = Sender {
        inner: Some(inner.clone()),
        #[cfg(all(tokio_unstable, feature = "tracing"))]
        resource_span: resource_span.clone(),
    };
    // The receive async-op span is nested under the resource span, and the
    // per-poll span under the async-op span.
    #[cfg(all(tokio_unstable, feature = "tracing"))]
    let async_op_span = resource_span
        .in_scope(|| tracing::trace_span!("runtime.resource.async_op", source = "Receiver::await"));
    #[cfg(all(tokio_unstable, feature = "tracing"))]
    let async_op_poll_span =
        async_op_span.in_scope(|| tracing::trace_span!("runtime.resource.async_op.poll"));
    let rx = Receiver {
        inner: Some(inner),
        #[cfg(all(tokio_unstable, feature = "tracing"))]
        resource_span,
        #[cfg(all(tokio_unstable, feature = "tracing"))]
        async_op_span,
        #[cfg(all(tokio_unstable, feature = "tracing"))]
        async_op_poll_span,
    };
    (tx, rx)
}
impl<T> Sender<T> {
    /// Attempts to send `t` to the receiver, consuming the sender.
    ///
    /// `Ok(())` means the value was stored and the channel marked complete
    /// before the receiver closed; nothing here waits for the value to be read.
    ///
    /// # Errors
    ///
    /// Returns `Err(t)`, handing the value back, if the receiver had already
    /// closed the channel (via `Receiver::close` or by being dropped).
    pub fn send(mut self, t: T) -> Result<(), T> {
        // Taking `inner` leaves `None` behind, so the `Drop` impl will not
        // call `complete()` again. Only `send` takes it and `send` consumes
        // `self`, so it is always `Some` here.
        let inner = self.inner.take().unwrap();
        inner.value.with_mut(|ptr| unsafe {
            // SAFETY: the receiver only touches `value` after observing
            // `VALUE_SENT`. Only this sender sets that bit (via `complete()`,
            // called below or from `Drop`), and `send` consumes `self`, so the
            // bit is still clear and the receiver is not accessing the cell.
            *ptr = Some(t);
        });
        if !inner.complete() {
            unsafe {
                // SAFETY: `complete()` returned false, meaning the channel was
                // already closed and `VALUE_SENT` was NOT set, so the receiver
                // will never read `value`; reclaiming it here is exclusive.
                // The slot was filled just above, so `unwrap` cannot fail.
                return Err(inner.consume_value().unwrap());
            }
        }
        // Only reached when the value was actually handed over.
        #[cfg(all(tokio_unstable, feature = "tracing"))]
        self.resource_span.in_scope(|| {
            tracing::trace!(
                target: "runtime::resource::state_update",
                value_sent = true,
                value_sent.op = "override",
            )
        });
        Ok(())
    }

    /// Waits until the receiver closes the channel.
    ///
    /// Resolves once the `CLOSED` bit is set, i.e. after `Receiver::close` is
    /// called or the receiver is dropped. This is an `async` wrapper around
    /// [`Sender::poll_closed`]; under `tokio_unstable` + `tracing` the poll
    /// function is additionally wrapped with `trace::async_op` using this
    /// channel's resource span.
    pub async fn closed(&mut self) {
        use crate::future::poll_fn;
        #[cfg(all(tokio_unstable, feature = "tracing"))]
        let resource_span = self.resource_span.clone();
        #[cfg(all(tokio_unstable, feature = "tracing"))]
        let closed = trace::async_op(
            || poll_fn(|cx| self.poll_closed(cx)),
            resource_span,
            "Sender::closed",
            "poll_closed",
            false,
        );
        #[cfg(not(all(tokio_unstable, feature = "tracing")))]
        let closed = poll_fn(|cx| self.poll_closed(cx));
        closed.await
    }

    /// Returns `true` if the receiver has closed the channel (`CLOSED` bit
    /// set), either via `Receiver::close` or by being dropped.
    ///
    /// Panics if `inner` is `None`, which cannot happen through the public
    /// API: only `send` takes it, and `send` consumes the sender.
    pub fn is_closed(&self) -> bool {
        let inner = self.inner.as_ref().unwrap();
        let state = State::load(&inner.state, Acquire);
        state.is_closed()
    }

    /// Polls whether the receiver has closed the channel.
    ///
    /// Returns `Ready(())` once `CLOSED` is set. Otherwise it stores (or
    /// refreshes) `cx`'s waker in `tx_task`, so that `Inner::close` can wake
    /// this task, and returns `Pending`.
    ///
    /// Panics if `inner` is `None`, which cannot happen through the public
    /// API (see [`Sender::is_closed`]).
    pub fn poll_closed(&mut self, cx: &mut Context<'_>) -> Poll<()> {
        // Cooperative budget: `ready!` propagates `Pending` if `poll_proceed`
        // declines, and every `Ready` path below reports `made_progress()`.
        // NOTE(review): exact budget semantics live in `crate::runtime::coop`.
        let coop = ready!(crate::runtime::coop::poll_proceed(cx));
        let inner = self.inner.as_ref().unwrap();
        let mut state = State::load(&inner.state, Acquire);
        if state.is_closed() {
            coop.made_progress();
            return Poll::Ready(());
        }
        if state.is_tx_task_set() {
            // A waker is already registered; keep it if it wakes this task.
            let will_notify = unsafe { inner.tx_task.will_wake(cx) };
            if !will_notify {
                // Clear the flag before touching the waker. Whichever atomic
                // RMW lands first decides ownership: if the value returned here
                // already has `CLOSED`, `Inner::close` may be reading the
                // waker, so leave it alone. Otherwise `close` will see the
                // flag cleared and never touch it.
                state = State::unset_tx_task(&inner.state);
                if state.is_closed() {
                    // The old waker is still initialized; set the flag again
                    // so `Inner`'s `Drop` releases it.
                    State::set_tx_task(&inner.state);
                    coop.made_progress();
                    return Ready(());
                } else {
                    // Exclusive access: the flag is clear and the channel is
                    // not closed, so nobody else reads the slot.
                    unsafe { inner.tx_task.drop_task() };
                }
            }
        }
        if !state.is_tx_task_set() {
            // The flag is clear, so the slot is exclusively ours to fill.
            unsafe {
                inner.tx_task.set_task(cx);
            }
            // Publish the waker. If `close` ran before this, it saw no waker
            // and will not wake us, so report closure directly.
            state = State::set_tx_task(&inner.state);
            if state.is_closed() {
                coop.made_progress();
                return Ready(());
            }
        }
        Pending
    }
}
impl<T> Drop for Sender<T> {
    /// Completes the channel when a sender is dropped without having sent.
    ///
    /// `send` takes `inner`, so `complete()` only runs here for a sender that
    /// never sent. A receiver that is still waiting then sees `VALUE_SENT` with
    /// an empty value slot and resolves to an error.
    fn drop(&mut self) {
        let inner = match &self.inner {
            Some(inner) => inner,
            // `send` already completed the channel.
            None => return,
        };
        // The result only says whether the receiver had already closed; there
        // is nothing further to do either way.
        inner.complete();
        #[cfg(all(tokio_unstable, feature = "tracing"))]
        self.resource_span.in_scope(|| {
            tracing::trace!(
                target: "runtime::resource::state_update",
                tx_dropped = true,
                tx_dropped.op = "override",
            )
        });
    }
}
impl<T> Receiver<T> {
    /// Closes the receiving half without dropping it.
    ///
    /// Sets `CLOSED`, so a later `Sender::send` fails and returns its value. It
    /// also wakes a sender waiting in `Sender::poll_closed`, if one is
    /// registered and no value has been sent yet. A value sent before closing
    /// can still be taken with [`Receiver::try_recv`].
    ///
    /// `inner` is kept, so dropping the receiver later calls `Inner::close`
    /// again; setting the `CLOSED` bit twice is harmless.
    pub fn close(&mut self) {
        if let Some(inner) = self.inner.as_ref() {
            inner.close();
            // NOTE(review): this reports `rx_dropped = true` although the
            // receiver itself is not dropped here — confirm this is intended.
            #[cfg(all(tokio_unstable, feature = "tracing"))]
            self.resource_span.in_scope(|| {
                tracing::trace!(
                    target: "runtime::resource::state_update",
                    rx_dropped = true,
                    rx_dropped.op = "override",
                )
            });
        }
    }

    /// Attempts to receive the value without waiting.
    ///
    /// # Errors
    ///
    /// * `TryRecvError::Empty`: nothing has been sent and the channel is still
    ///   open. The receiver stays usable and a later call may succeed.
    /// * `TryRecvError::Closed`: the sender was dropped without sending, the
    ///   channel was closed before a value was sent, or a final result was
    ///   already returned earlier.
    ///
    /// Any outcome other than `Empty` releases the shared state (`inner` is set
    /// to `None`), so every later call returns `Closed`.
    pub fn try_recv(&mut self) -> Result<T, TryRecvError> {
        let result = if let Some(inner) = self.inner.as_ref() {
            let state = State::load(&inner.state, Acquire);
            if state.is_complete() {
                // SAFETY: `VALUE_SENT` is set, so the sender no longer touches
                // `value` and this side has exclusive access to it.
                match unsafe { inner.consume_value() } {
                    Some(value) => {
                        #[cfg(all(tokio_unstable, feature = "tracing"))]
                        self.resource_span.in_scope(|| {
                            tracing::trace!(
                                target: "runtime::resource::state_update",
                                value_received = true,
                                value_received.op = "override",
                            )
                        });
                        Ok(value)
                    }
                    // Completed with an empty slot: the sender was dropped
                    // without sending.
                    None => Err(TryRecvError::Closed),
                }
            } else if state.is_closed() {
                Err(TryRecvError::Closed)
            } else {
                // Not ready yet. Return early so that `inner` is NOT cleared
                // and the receiver can be polled or tried again.
                return Err(TryRecvError::Empty);
            }
        } else {
            Err(TryRecvError::Closed)
        };
        // A final result was produced; drop our handle to the shared state.
        self.inner = None;
        result
    }

    /// Blocks the current thread until the value is received, by driving this
    /// receiver (as a future) with `crate::future::block_on`.
    ///
    /// NOTE(review): behavior when called from inside an async runtime is
    /// decided by `crate::future::block_on` (presumably a panic); confirm there.
    #[track_caller]
    #[cfg(feature = "sync")]
    pub fn blocking_recv(self) -> Result<T, RecvError> {
        crate::future::block_on(self)
    }
}
impl<T> Drop for Receiver<T> {
    /// Closes the channel when a receiver that still holds the shared state is
    /// dropped.
    ///
    /// `inner` is already `None` after `try_recv` returned a final result or
    /// the future resolved to `Ok`; then there is nothing left to close.
    fn drop(&mut self) {
        let inner = match &self.inner {
            Some(inner) => inner,
            None => return,
        };
        // Sets `CLOSED` and, if a sender waker is registered and no value was
        // sent, wakes the sender waiting in `poll_closed`.
        inner.close();
        #[cfg(all(tokio_unstable, feature = "tracing"))]
        self.resource_span.in_scope(|| {
            tracing::trace!(
                target: "runtime::resource::state_update",
                rx_dropped = true,
                rx_dropped.op = "override",
            )
        });
    }
}
impl<T> Future for Receiver<T> {
    /// `Ok(value)` when a value was sent. `Err(RecvError)` when the sender was
    /// dropped without sending, or the channel was closed before a value arrived.
    type Output = Result<T, RecvError>;

    /// Polls for the value.
    ///
    /// On `Ready(Ok(_))` the shared state is released (`inner` set to `None`).
    /// On `Ready(Err(_))` the `?` returns early and `inner` is kept, so polling
    /// again yields the error again.
    ///
    /// # Panics
    ///
    /// Panics with "called after complete" if polled after it already resolved
    /// to `Ok`, or after `try_recv` released the shared state.
    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        // Enter the resource, async-op and per-poll spans for this poll.
        #[cfg(all(tokio_unstable, feature = "tracing"))]
        let _res_span = self.resource_span.clone().entered();
        #[cfg(all(tokio_unstable, feature = "tracing"))]
        let _ao_span = self.async_op_span.clone().entered();
        #[cfg(all(tokio_unstable, feature = "tracing"))]
        let _ao_poll_span = self.async_op_poll_span.clone().entered();
        // Shared borrow of `inner` through the pin. The borrow ends before
        // `self.inner` is reassigned below.
        let ret = if let Some(inner) = self.as_ref().get_ref().inner.as_ref() {
            #[cfg(all(tokio_unstable, feature = "tracing"))]
            let res = ready!(trace_poll_op!("poll_recv", inner.poll_recv(cx)))?;
            #[cfg(any(not(tokio_unstable), not(feature = "tracing")))]
            let res = ready!(inner.poll_recv(cx))?;
            res
        } else {
            panic!("called after complete");
        };
        // Value received: release the shared state.
        self.inner = None;
        Ready(Ok(ret))
    }
}
impl<T> Inner<T> {
    /// Marks the channel complete (`VALUE_SENT`) on behalf of the sender.
    ///
    /// Returns `false` without changing anything if the receiver had already
    /// closed the channel; in that case the receiver will never read `value`,
    /// so the caller may reclaim it. Otherwise it wakes a registered receiver
    /// and returns `true`.
    fn complete(&self) -> bool {
        let prev = State::set_complete(&self.state);
        if prev.is_closed() {
            return false;
        }
        if prev.is_rx_task_set() {
            // SAFETY: `RX_TASK_SET` was set, so `rx_task` holds an initialized
            // waker. The receiver must clear that bit (and observe whether
            // `VALUE_SENT` is set) before it drops or replaces the waker.
            unsafe {
                self.rx_task.with_task(Waker::wake_by_ref);
            }
        }
        true
    }

    /// Polls for the value on behalf of the receiver.
    ///
    /// Returns `Ready(Ok(value))` once `VALUE_SENT` is set and a value is
    /// stored. Returns `Ready(Err(RecvError))` if the channel completed with an
    /// empty slot or was closed. Otherwise it registers `cx`'s waker in
    /// `rx_task`, so that `complete()` can wake this task, and returns
    /// `Pending`.
    fn poll_recv(&self, cx: &mut Context<'_>) -> Poll<Result<T, RecvError>> {
        // Cooperative budget; every `Ready` path reports `made_progress()`.
        let coop = ready!(crate::runtime::coop::poll_proceed(cx));
        let mut state = State::load(&self.state, Acquire);
        if state.is_complete() {
            coop.made_progress();
            // SAFETY: `VALUE_SENT` is set, so the sender no longer accesses
            // `value`.
            match unsafe { self.consume_value() } {
                Some(value) => Ready(Ok(value)),
                None => Ready(Err(RecvError(()))),
            }
        } else if state.is_closed() {
            coop.made_progress();
            Ready(Err(RecvError(())))
        } else {
            if state.is_rx_task_set() {
                // A waker is already registered; keep it if it wakes this task.
                let will_notify = unsafe { self.rx_task.will_wake(cx) };
                if !will_notify {
                    // Clear the flag before touching the waker. If the value
                    // returned here already has `VALUE_SENT`, `complete()` may
                    // be reading the waker, so it must not be dropped now.
                    state = State::unset_rx_task(&self.state);
                    if state.is_complete() {
                        // The old waker is still initialized; set the flag
                        // again so `Inner`'s `Drop` releases it.
                        State::set_rx_task(&self.state);
                        coop.made_progress();
                        // SAFETY: `VALUE_SENT` is set; the sender is done
                        // with `value`.
                        return match unsafe { self.consume_value() } {
                            Some(value) => Ready(Ok(value)),
                            None => Ready(Err(RecvError(()))),
                        };
                    } else {
                        // `complete()` will see the flag cleared and never
                        // read this waker, so dropping it is exclusive.
                        unsafe { self.rx_task.drop_task() };
                    }
                }
            }
            if !state.is_rx_task_set() {
                // The flag is clear, so the slot is exclusively ours to fill.
                unsafe {
                    self.rx_task.set_task(cx);
                }
                // Publish the waker. If `complete()` ran before this, it saw
                // no waker and will not wake us, so finish immediately.
                state = State::set_rx_task(&self.state);
                if state.is_complete() {
                    coop.made_progress();
                    match unsafe { self.consume_value() } {
                        Some(value) => Ready(Ok(value)),
                        None => Ready(Err(RecvError(()))),
                    }
                } else {
                    Pending
                }
            } else {
                Pending
            }
        }
    }

    /// Marks the channel closed on behalf of the receiver.
    ///
    /// Wakes the sender's registered waker (from `Sender::poll_closed`) only if
    /// one is registered and no value has been sent yet.
    fn close(&self) {
        let prev = State::set_closed(&self.state);
        if prev.is_tx_task_set() && !prev.is_complete() {
            // SAFETY: `TX_TASK_SET` was set, so `tx_task` holds an initialized
            // waker. The sender clears that bit (observing `CLOSED`) before it
            // drops or replaces the waker.
            unsafe {
                self.tx_task.with_task(Waker::wake_by_ref);
            }
        }
    }

    /// Takes the value out of the slot, leaving `None`.
    ///
    /// # Safety
    ///
    /// The caller must have exclusive access to `value`. That holds for the
    /// receiver after it has observed `VALUE_SENT`, and for the sender after
    /// `complete()` returned `false`.
    unsafe fn consume_value(&self) -> Option<T> {
        self.value.with_mut(|ptr| (*ptr).take())
    }
}
// SAFETY (review rationale): `Inner` is shared across threads through `Arc`.
// Access to the `UnsafeCell` fields (`value`, `tx_task`, `rx_task`) is
// coordinated by the bit flags in `state`, as described on `Inner::complete`,
// `Inner::close` and the poll functions. A `T` is only ever moved in
// (`Sender::send`) and moved out (`consume_value`). No `&T` is handed out in
// the visible code, so `T: Send` is required for both impls, but `T: Sync` is
// not.
unsafe impl<T: Send> Send for Inner<T> {}
unsafe impl<T: Send> Sync for Inner<T> {}
/// Reads an atomic's value through exclusive (`&mut`) access, so no atomic
/// ordering is involved.
///
/// NOTE(review): `with_mut` comes from the `crate::loom` atomic wrapper
/// (presumably so the loom model checker can track this access); confirm there.
fn mut_load(this: &mut AtomicUsize) -> usize {
    this.with_mut(|v| *v)
}
impl<T> Drop for Inner<T> {
    /// Releases any waker still stored in `rx_task` / `tx_task`.
    ///
    /// A slot holds an initialized waker exactly when its `*_TASK_SET` flag is
    /// set; the poll functions set the flag again when they abandon a waker
    /// they could not safely drop themselves. The `value` field needs no
    /// special handling and is dropped afterwards like any other field.
    fn drop(&mut self) {
        // `&mut self` means no other handle exists, so a plain load suffices.
        let state = State(mut_load(&mut self.state));
        if state.is_rx_task_set() {
            // SAFETY: flag set => initialized; we have exclusive access.
            unsafe {
                self.rx_task.drop_task();
            }
        }
        if state.is_tx_task_set() {
            // SAFETY: flag set => initialized; we have exclusive access.
            unsafe {
                self.tx_task.drop_task();
            }
        }
    }
}
impl<T: fmt::Debug> fmt::Debug for Inner<T> {
fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
use std::sync::atomic::Ordering::Relaxed;
fmt.debug_struct("Inner")
.field("state", &State::load(&self.state, Relaxed))
.finish()
}
}
const RX_TASK_SET: usize = 0b00001;
const VALUE_SENT: usize = 0b00010;
const CLOSED: usize = 0b00100;
const TX_TASK_SET: usize = 0b01000;
impl State {
fn new() -> State {
State(0)
}
fn is_complete(self) -> bool {
self.0 & VALUE_SENT == VALUE_SENT
}
fn set_complete(cell: &AtomicUsize) -> State {
let mut state = cell.load(Ordering::Relaxed);
loop {
if State(state).is_closed() {
break;
}
match cell.compare_exchange_weak(
state,
state | VALUE_SENT,
Ordering::AcqRel,
Ordering::Acquire,
) {
Ok(_) => break,
Err(actual) => state = actual,
}
}
State(state)
}
fn is_rx_task_set(self) -> bool {
self.0 & RX_TASK_SET == RX_TASK_SET
}
fn set_rx_task(cell: &AtomicUsize) -> State {
let val = cell.fetch_or(RX_TASK_SET, AcqRel);
State(val | RX_TASK_SET)
}
fn unset_rx_task(cell: &AtomicUsize) -> State {
let val = cell.fetch_and(!RX_TASK_SET, AcqRel);
State(val & !RX_TASK_SET)
}
fn is_closed(self) -> bool {
self.0 & CLOSED == CLOSED
}
fn set_closed(cell: &AtomicUsize) -> State {
let val = cell.fetch_or(CLOSED, Acquire);
State(val)
}
fn set_tx_task(cell: &AtomicUsize) -> State {
let val = cell.fetch_or(TX_TASK_SET, AcqRel);
State(val | TX_TASK_SET)
}
fn unset_tx_task(cell: &AtomicUsize) -> State {
let val = cell.fetch_and(!TX_TASK_SET, AcqRel);
State(val & !TX_TASK_SET)
}
fn is_tx_task_set(self) -> bool {
self.0 & TX_TASK_SET == TX_TASK_SET
}
fn as_usize(self) -> usize {
self.0
}
fn load(cell: &AtomicUsize, order: Ordering) -> State {
let val = cell.load(order);
State(val)
}
}
impl fmt::Debug for State {
fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
fmt.debug_struct("State")
.field("is_complete", &self.is_complete())
.field("is_closed", &self.is_closed())
.field("is_rx_task_set", &self.is_rx_task_set())
.field("is_tx_task_set", &self.is_tx_task_set())
.finish()
}
}