use std::future::Future;
use std::pin::Pin;
use std::sync::{Arc, Mutex};
use std::task::{Context, Poll, Waker};
use crate::error::Result;
/// State shared between a `CompletionSender` and its `CompletionFuture`.
///
/// Both halves hold it as `Arc<Mutex<Shared<T>>>`, so every access goes
/// through the mutex.
struct Shared<T> {
    /// The delivered result: set by `send`, moved out by `poll` / `block_on`.
    value: Option<Result<T>>,
    /// Waker registered by the receiving side (`poll` or `block_on`);
    /// taken and woken by `send`.
    waker: Option<Waker>,
}
/// Receiving half of a completion channel created by `completion_channel`.
///
/// Resolves to the `Result<T>` passed to `CompletionSender::send`. It can be
/// driven as a [`Future`] (polled / `.await`ed) or waited on synchronously
/// with [`CompletionFuture::block_on`].
pub struct CompletionFuture<T> {
    shared: Arc<Mutex<Shared<T>>>,
}
// SAFETY: NOTE(review): `CompletionFuture<T>` holds only an
// `Arc<Mutex<Shared<T>>>`. The compiler already makes that `Send` whenever
// `Shared<T>: Send`, i.e. when `T: Send` and the crate `Error` type is `Send`
// (`Waker` is always `Send + Sync`). If `crate::error::Error` is `Send`, this
// impl is redundant. If it is not, this impl asserts a property the compiler
// could not verify. Confirm, and remove the impl if it is redundant.
unsafe impl<T: Send> Send for CompletionFuture<T> {}
impl<T: Send> Future for CompletionFuture<T> {
    type Output = Result<T>;

    /// Returns `Ready` with the delivered result once the sender has called
    /// `send`. Otherwise it records `cx`'s waker and returns `Pending`.
    ///
    /// The value is moved out on the first `Ready`. Polling again afterwards
    /// returns `Pending` indefinitely, because `send` consumes its sender and
    /// can deliver only once.
    ///
    /// # Panics
    ///
    /// Panics if the shared mutex is poisoned.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let mut shared = self.shared.lock().unwrap();
        if let Some(value) = shared.value.take() {
            return Poll::Ready(value);
        }
        // Only the waker from the most recent poll must be notified. Skip the
        // clone when the stored waker already wakes the same task, which is
        // the common re-poll case.
        match &mut shared.waker {
            Some(existing) if existing.will_wake(cx.waker()) => {}
            slot => *slot = Some(cx.waker().clone()),
        }
        Poll::Pending
    }
}
impl<T: Send> CompletionFuture<T> {
pub fn block_on(self) -> Result<T> {
{
let mut shared = self.shared.lock().unwrap();
if let Some(value) = shared.value.take() {
return value;
}
}
let pair = Arc::new((std::sync::Mutex::new(false), std::sync::Condvar::new()));
let pair_for_waker = pair.clone();
{
let mut shared = self.shared.lock().unwrap();
if let Some(value) = shared.value.take() {
return value;
}
let waker = condvar_waker(pair_for_waker);
shared.waker = Some(waker);
}
let (lock, cvar) = &*pair;
let mut ready = lock.lock().unwrap();
while !*ready {
ready = cvar.wait(ready).unwrap();
}
let mut shared = self.shared.lock().unwrap();
shared.value.take().expect("waker fired but no value was set")
}
}
/// Creates a linked sender/future pair that share one initially empty slot.
///
/// The [`CompletionSender`] delivers exactly one `Result<T>`. The
/// [`CompletionFuture`] yields it, either as a future or via `block_on`.
pub(crate) fn completion_channel<T: Send>() -> (CompletionSender<T>, CompletionFuture<T>) {
    let state = Arc::new(Mutex::new(Shared {
        value: None,
        waker: None,
    }));
    (
        CompletionSender {
            shared: Arc::clone(&state),
        },
        CompletionFuture { shared: state },
    )
}
/// Sending half of a completion channel created by `completion_channel`.
///
/// It delivers a single result via [`CompletionSender::send`], which consumes
/// the sender.
pub(crate) struct CompletionSender<T> {
    shared: Arc<Mutex<Shared<T>>>,
}
// SAFETY: NOTE(review): the sender holds only an `Arc<Mutex<Shared<T>>>`.
// That is already `Send + Sync` whenever `Shared<T>: Send`, i.e. when
// `T: Send` and the crate `Error` type is `Send`. In that case these manual
// impls are redundant. If `crate::error::Error` is not `Send`, they assert
// something the compiler could not verify. Confirm, and remove if redundant.
unsafe impl<T: Send> Send for CompletionSender<T> {}
unsafe impl<T: Send> Sync for CompletionSender<T> {}
impl<T: Send> CompletionSender<T> {
    /// Stores `value` for the paired future and wakes the waker it registered,
    /// if any.
    ///
    /// Consumes the sender, so at most one value is ever delivered.
    ///
    /// # Panics
    ///
    /// Panics if the shared mutex is poisoned.
    pub fn send(self, value: Result<T>) {
        let waker = {
            let mut shared = self.shared.lock().unwrap();
            shared.value = Some(value);
            shared.waker.take()
        };
        // Wake only after the lock is released. A waker may run code that locks
        // the shared state again (e.g. an executor that polls the future
        // inline), which would deadlock on the non-reentrant `Mutex`. The value
        // is already stored, so a poll that runs in between sees it.
        if let Some(waker) = waker {
            waker.wake();
        }
    }
}
/// Builds a [`Waker`] that, when woken, sets the `bool` in `pair` to `true`
/// under its lock and calls `notify_one` on its condvar.
///
/// Uses the safe [`std::task::Wake`] trait rather than a hand-written
/// `RawWakerVTable`. `Arc` handles the waker's reference counting instead of
/// manual `from_raw` / `forget` pairs, so no `unsafe` is needed.
fn condvar_waker(pair: Arc<(std::sync::Mutex<bool>, std::sync::Condvar)>) -> Waker {
    use std::task::Wake;

    /// Waker payload: owns one reference to the flag/condvar pair.
    struct CondvarWake(Arc<(std::sync::Mutex<bool>, std::sync::Condvar)>);

    impl CondvarWake {
        /// Sets the ready flag while holding its lock, then wakes one waiter.
        fn signal(&self) {
            let (lock, cvar) = &*self.0;
            let mut ready = lock.lock().unwrap();
            *ready = true;
            cvar.notify_one();
        }
    }

    impl Wake for CondvarWake {
        fn wake(self: Arc<Self>) {
            self.signal();
        }

        fn wake_by_ref(self: &Arc<Self>) {
            self.signal();
        }
    }

    Waker::from(Arc::new(CondvarWake(pair)))
}
// Unit tests for the completion channel: blocking receive, error propagation,
// the already-sent fast path, manual `Future::poll`, and many concurrent
// channels.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::error::{Error, ErrorKind};

    // The value arrives from another thread after a delay, so `block_on`
    // usually has to register its waker and wait (slow path).
    #[test]
    fn send_then_block_on() {
        let (sender, future) = completion_channel::<String>();
        std::thread::spawn(move || {
            std::thread::sleep(std::time::Duration::from_millis(10));
            sender.send(Ok("hello".to_string()));
        });
        let result = future.block_on().unwrap();
        assert_eq!(result, "hello");
    }

    // An `Err` passed to `send` is surfaced by `block_on` with its kind intact.
    #[test]
    fn error_propagation() {
        let (sender, future) = completion_channel::<String>();
        std::thread::spawn(move || {
            sender.send(Err(Error::new(ErrorKind::ModelLoad, "test error")));
        });
        let err = future.block_on().unwrap_err();
        assert_eq!(err.kind(), &ErrorKind::ModelLoad);
    }

    // `send` before `block_on`: the value is already present (fast path).
    #[test]
    fn immediate_value() {
        let (sender, future) = completion_channel::<i32>();
        sender.send(Ok(42));
        assert_eq!(future.block_on().unwrap(), 42);
    }

    // Drives the future by hand: `Pending` before the send, then
    // `Ready(Ok(99))` after it. The order of poll / send / poll matters.
    #[test]
    fn poll_via_future_trait() {
        use std::task::{RawWaker, RawWakerVTable};
        // Waker whose clone/wake/drop do nothing; the test re-polls manually
        // instead of relying on a wakeup.
        fn noop_waker() -> Waker {
            unsafe fn clone(_: *const ()) -> RawWaker {
                RawWaker::new(std::ptr::null(), &NOOP_VTABLE)
            }
            unsafe fn noop(_: *const ()) {}
            static NOOP_VTABLE: RawWakerVTable =
                RawWakerVTable::new(clone, noop, noop, noop);
            unsafe {
                Waker::from_raw(RawWaker::new(std::ptr::null(), &NOOP_VTABLE))
            }
        }
        let (sender, mut future) = completion_channel::<u64>();
        let waker = noop_waker();
        let mut cx = Context::from_waker(&waker);
        let pinned = Pin::new(&mut future);
        assert!(pinned.poll(&mut cx).is_pending());
        sender.send(Ok(99));
        let pinned = Pin::new(&mut future);
        match pinned.poll(&mut cx) {
            Poll::Ready(Ok(v)) => assert_eq!(v, 99),
            other => panic!("expected Ready(Ok(99)), got {other:?}"),
        }
    }

    // 50 independent channels, each completed by its own thread after a
    // staggered delay (i * 10 µs). Results are received in channel order.
    #[test]
    fn concurrent_stress() {
        let handles: Vec<_> = (0..50)
            .map(|i| {
                let (sender, future) = completion_channel::<i32>();
                let h = std::thread::spawn(move || {
                    std::thread::sleep(std::time::Duration::from_micros(i * 10));
                    sender.send(Ok(i as i32));
                });
                (h, future)
            })
            .collect();
        for (i, (handle, future)) in handles.into_iter().enumerate() {
            let val = future.block_on().unwrap();
            assert_eq!(val, i as i32);
            handle.join().unwrap();
        }
    }
}