#![warn(missing_docs)]
#![doc = include_str!(concat!(env!("CARGO_MANIFEST_DIR"), "/", env!("CARGO_PKG_README")))]
use std::cell::UnsafeCell;
use std::mem::MaybeUninit;
use std::sync::atomic::{AtomicU8, Ordering};
use std::sync::{Arc, Mutex, Weak};
use std::task::{Context, Poll, Waker};
/// Creates a connected [`Resolve`] / [`Promise`] pair.
///
/// The returned [`Promise`] owns the shared state, while the [`Resolve`] half only keeps a
/// weak handle to it, so a live resolver never keeps the state alive on its own.
pub fn channel<T>() -> (Resolve<T>, Promise<T>) {
    let shared = Arc::new(Inner::new());
    let resolver = Resolve {
        inner: Some(Arc::downgrade(&shared)),
    };
    (resolver, Promise { inner: shared })
}
/// Receiving half of the pair returned by [`channel`].
///
/// Gives shared, read-only access to the value once the matching [`Resolve`] has either
/// sent one or been dropped.
pub struct Promise<T> {
// Strong handle to the state shared with the `Resolve` half (which only holds a `Weak`).
inner: Arc<Inner<T>>,
}
impl<T> Promise<T> {
    /// Waits until the channel has completed.
    ///
    /// Resolves to `Some(&value)` when a value was sent, or to `None` when the [`Resolve`]
    /// half was dropped without sending one.
    pub async fn wait(&self) -> Option<&T> {
        let shared = &*self.inner;
        std::future::poll_fn(move |cx| shared.poll_get(cx)).await
    }

    /// Returns the value without waiting.
    ///
    /// # Errors
    ///
    /// * [`PromiseError::Empty`] if the channel has not completed yet.
    /// * [`PromiseError::Dropped`] if the channel completed without a value.
    pub fn try_get(&self) -> Result<&T, PromiseError> {
        match self.inner.get() {
            None => Err(PromiseError::Empty),
            Some(None) => Err(PromiseError::Dropped),
            Some(Some(value)) => Ok(value),
        }
    }

    /// Returns `true` once the channel has completed, with or without a value.
    pub fn is_done(&self) -> bool {
        let state = self.inner.get();
        state.is_some()
    }
}
/// Reasons why [`Promise::try_get`] could not hand out a value.
#[derive(Debug, Eq, PartialEq, Clone, thiserror::Error)]
pub enum PromiseError {
/// The channel has not completed yet: nothing was sent so far.
#[error("value not yet sent")]
Empty,
/// The channel completed without a value because the [`Resolve`] half was dropped.
#[error("closed before a value was sent")]
Dropped,
}
/// Sending half of the pair returned by [`channel`].
///
/// Completes the matching [`Promise`] at most once, either with a value or, when dropped
/// unused, without one.
pub struct Resolve<T> {
// Weak handle to the shared state; `None` once this handle has been used.
inner: Option<Weak<Inner<T>>>,
}
impl<T> Resolve<T> {
    /// Consumes the handle and completes the promise with `value`.
    ///
    /// # Panics
    ///
    /// Panics if [`Resolve::resolve`] was already called on this handle.
    pub fn into_resolve(mut self, value: T) {
        if let Err(_rejected) = self.resolve(value) {
            panic!("already resolved");
        }
    }

    /// Completes the promise with `value`.
    ///
    /// When the matching [`Promise`] no longer exists the value is discarded and `Ok(())`
    /// is still returned.
    ///
    /// # Errors
    ///
    /// Hands `value` back as `Err(value)` if this handle was already used to resolve.
    pub fn resolve(&mut self, value: T) -> Result<(), T> {
        match self.inner.take() {
            // The handle was consumed by an earlier call.
            None => Err(value),
            Some(weak) => {
                if let Some(shared) = weak.upgrade() {
                    // SAFETY: the weak handle was just taken out of `self.inner`, so this is
                    // the only call into `Inner::resolve` this `Resolve` can ever make.
                    unsafe {
                        shared.resolve(Some(value));
                    }
                }
                Ok(())
            }
        }
    }
}
impl<T> Drop for Resolve<T> {
    /// Completes the promise without a value if this handle was never used to resolve.
    fn drop(&mut self) {
        // `self.inner` is `None` after a successful `resolve`, which makes this a no-op.
        let shared = match self.inner.take() {
            Some(weak) => weak.upgrade(),
            None => None,
        };
        if let Some(shared) = shared {
            // SAFETY: the weak handle was just taken out of `self.inner`, so this handle has
            // not called `Inner::resolve` before and cannot call it again.
            unsafe {
                shared.resolve(None);
            }
        }
    }
}
const BIT: u8 = 0b1;
/// Set once the channel has reached its final state (value sent or resolver dropped).
const FLAG_COMPLETED: u8 = BIT;
/// Set together with [`FLAG_COMPLETED`] when `Inner::value` holds an initialised `T`.
const FLAG_VALUE_SET: u8 = BIT << 1;

/// State shared between a `Promise` and its `Resolve`.
struct Inner<T> {
    /// Combination of `FLAG_COMPLETED` / `FLAG_VALUE_SET`; `0` while still pending.
    flag: AtomicU8,
    /// Slot for the value; initialised if and only if `FLAG_VALUE_SET` is set in `flag`.
    value: UnsafeCell<MaybeUninit<T>>,
    /// Wakers of tasks that polled while the channel was still pending.
    wakers: Mutex<Vec<Waker>>,
}

impl<T> Inner<T> {
    /// Creates a pending state with an uninitialised value slot and no waiters.
    const fn new() -> Self {
        Self {
            flag: AtomicU8::new(0),
            value: UnsafeCell::new(MaybeUninit::uninit()),
            wakers: Mutex::new(Vec::new()),
        }
    }

    /// Polls for completion, registering the task's waker while still pending.
    ///
    /// Returns `Ready(Some(&value))` once resolved with a value and `Ready(None)` once
    /// completed without one.
    fn poll_get<'a>(&'a self, cx: &mut Context<'_>) -> Poll<Option<&'a T>> {
        // Fast path: already completed, no need to touch the mutex.
        if let Some(value_opt) = self.get() {
            return Poll::Ready(value_opt);
        }
        let mut wakers = self.wakers.lock().unwrap();
        // Re-check under the lock: `resolve` publishes the flag *before* it drains the
        // waker list, so either we observe completion here or our waker gets drained.
        if let Some(value_opt) = self.get() {
            return Poll::Ready(value_opt);
        }
        wakers.push(cx.waker().clone());
        Poll::Pending
    }

    /// Returns the current state without blocking.
    ///
    /// * `None` - still pending.
    /// * `Some(None)` - completed without a value.
    /// * `Some(Some(&value))` - completed with a value.
    fn get(&self) -> Option<Option<&T>> {
        let flag = self.flag.load(Ordering::Acquire);
        if flag & FLAG_COMPLETED == 0 {
            return None;
        }
        if flag & FLAG_VALUE_SET == 0 {
            return Some(None);
        }
        // SAFETY: `FLAG_VALUE_SET` is only stored (with `Release`) after the slot has been
        // written, and the `Acquire` load above makes that write visible. The slot is never
        // written again afterwards, so handing out a shared reference is sound.
        Some(Some(unsafe { (*self.value.get()).assume_init_ref() }))
    }

    /// Moves the channel into its final state and wakes every registered waiter.
    ///
    /// `Some(value)` stores the value; `None` marks the channel as completed without one.
    ///
    /// # Safety
    ///
    /// Must be called at most once per `Inner`, and never concurrently with another call:
    /// the value slot is written without synchronisation, so a second call would race with
    /// readers that already hold a reference obtained through [`Inner::get`].
    unsafe fn resolve(&self, value_or_dropped: Option<T>) {
        let flag = if let Some(value) = value_or_dropped {
            // SAFETY: per this function's contract this is the only write to the slot, and
            // no reader dereferences it before `FLAG_VALUE_SET` is published below.
            unsafe {
                self.value.get().write(MaybeUninit::new(value));
            }
            FLAG_COMPLETED | FLAG_VALUE_SET
        } else {
            FLAG_COMPLETED
        };
        self.flag.store(flag, Ordering::Release);
        // Drain under the lock but wake outside of it, so wakers never run with the mutex
        // held.
        let wakers = std::mem::take(&mut *self.wakers.lock().unwrap());
        for waker in wakers {
            waker.wake();
        }
    }
}

impl<T> Drop for Inner<T> {
    /// Drops the stored value, if any. `MaybeUninit` never drops its contents on its own,
    /// so without this every resolved value would be leaked.
    fn drop(&mut self) {
        if *self.flag.get_mut() & FLAG_VALUE_SET != 0 {
            // SAFETY: `FLAG_VALUE_SET` implies the slot was initialised by `resolve`, and
            // `&mut self` guarantees no outstanding references to the value remain.
            unsafe {
                self.value.get_mut().assume_init_drop();
            }
        }
    }
}

// SAFETY: `&Inner<T>` hands out `&T` (needs `T: Sync`) and lets `resolve` move a `T` in from
// another thread (needs `T: Send`); all other state is an atomic or behind a mutex.
unsafe impl<T: Sync + Send> Sync for Inner<T> {}
// SAFETY: an owned `Inner<T>` owns at most one `T`, so sending it is sending that `T`.
unsafe impl<T: Send> Send for Inner<T> {}
#[cfg(test)]
mod tests {
    use super::*;

    /// A value sent from another task is observed by the waiting promise.
    #[tokio::test]
    async fn test_basic() {
        let (resolver, promise) = channel::<i32>();
        let sender = tokio::task::spawn(async move {
            resolver.into_resolve(42);
        });
        assert_eq!(Some(&42), promise.wait().await);
        sender.await.unwrap();
    }

    /// Several tasks sharing one promise all observe the same value.
    #[tokio::test]
    async fn test_multiple() {
        let (resolver, promise) = channel::<i32>();
        let shared = Arc::new(promise);
        let handles = vec![Arc::clone(&shared), Arc::clone(&shared), shared];
        let readers: Vec<_> = handles
            .into_iter()
            .map(|reader| {
                tokio::task::spawn(async move {
                    let value = reader.wait().await;
                    assert_eq!(Some(&42), value);
                })
            })
            .collect();
        let sender = tokio::task::spawn(async move {
            resolver.into_resolve(42);
        });
        for reader in readers {
            reader.await.unwrap();
        }
        sender.await.unwrap();
    }

    /// `try_get` reports `Empty` before resolution and the value afterwards.
    #[tokio::test]
    async fn test_try_get() {
        let (resolver, promise) = channel::<i32>();
        assert_eq!(promise.try_get(), Err(PromiseError::Empty));
        resolver.into_resolve(42);
        assert_eq!(promise.try_get(), Ok(&42));
    }

    /// Dropping the resolver completes the promise without a value.
    #[tokio::test]
    async fn test_dropped() {
        let (resolver, promise) = channel::<i32>();
        assert!(!promise.is_done());
        drop(resolver);
        assert!(promise.is_done());
        assert_eq!(promise.try_get(), Err(PromiseError::Dropped));
        assert_eq!(promise.wait().await, None);
    }
}