use crate::util::lock_or_clear;
use futures_channel::oneshot;
use std::{
pin::Pin,
sync::{Arc, Mutex},
task::{Context, Poll},
};
/// Clones `arc` and leaks the clone as a raw pointer encoded in a `usize`.
///
/// The caller's `Arc` is left untouched; the strong count grows by one and
/// stays that way until the returned value is handed to
/// [`reclaim_context_ptr`]. Dropping the returned integer without reclaiming
/// it keeps the allocation alive forever.
pub fn leak_context_ptr<T>(arc: &Arc<T>) -> usize {
    // The extra strong reference is what the raw pointer "owns".
    let extra_ref: Arc<T> = Arc::clone(arc);
    let raw: *const T = Arc::into_raw(extra_ref);
    raw as usize
}
/// Reinterprets `ptr` as a `*const T` and returns a shared reference to the
/// pointee, without touching any reference count.
///
/// # Safety
///
/// * `ptr` must be the address of a live, properly aligned `T` — typically a
///   value produced by [`leak_context_ptr`] that has not yet been passed to
///   [`reclaim_context_ptr`].
/// * The lifetime `'a` is chosen by the caller and is not tied to anything;
///   the caller must ensure the pointee outlives every use of the returned
///   reference.
pub unsafe fn borrow_context_ptr<'a, T>(ptr: usize) -> &'a T {
    let raw = ptr as *const T;
    // SAFETY: the caller guarantees `raw` points at a live `T` for `'a`.
    unsafe { &*raw }
}
/// Rebuilds the `Arc<T>` behind `ptr` and drops it, releasing one strong
/// reference.
///
/// This is the counterpart of [`leak_context_ptr`].
///
/// # Safety
///
/// * `ptr` must have come from `Arc::into_raw` on an `Arc<T>` of the same `T`
///   (which is what [`leak_context_ptr`] returns).
/// * Each leaked pointer may be reclaimed at most once; references obtained
///   through [`borrow_context_ptr`] must not be used afterwards unless another
///   strong reference keeps the value alive.
pub unsafe fn reclaim_context_ptr<T>(ptr: usize) {
    let raw = ptr as *const T;
    // SAFETY: the caller guarantees `raw` originates from `Arc::into_raw` and
    // that this strong reference has not been released yet.
    let owned: Arc<T> = unsafe { Arc::from_raw(raw) };
    drop(owned);
}
/// Re-armable, one-shot completion notifier.
///
/// `listen` stores a fresh `oneshot` sender here and hands the matching
/// receiver to the caller; `signal` takes the stored sender (if any) and sends
/// a value through it.
pub struct CompletionSignal<T: Send> {
/// Sender for the currently registered listener; `None` when no listener is
/// registered or the last one has already been signalled.
sender: Mutex<Option<oneshot::Sender<T>>>,
}
/// Future returned by `CompletionSignal::listen`.
///
/// Wraps the receiving half of a `oneshot` channel; its `Future` impl yields
/// `Ok(value)` once a value is sent and `Err(SignalCancelled)` if the channel
/// is cancelled instead.
pub struct CompletionFuture<T>(oneshot::Receiver<T>);
/// Error produced when a completion future finishes without ever receiving a
/// value, i.e. its sender was dropped first.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct SignalCancelled;

impl std::fmt::Display for SignalCancelled {
    /// Writes a fixed, human-readable description; formatter flags such as
    /// width or alignment are not applied.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "completion signal cancelled (sender dropped)")
    }
}

/// No underlying source: the default `Error` methods are sufficient.
impl std::error::Error for SignalCancelled {}
impl<T: Send> CompletionSignal<T> {
    /// Creates a signal with no listener registered.
    pub fn new() -> Self {
        Self {
            sender: Mutex::new(None),
        }
    }

    /// Registers a new listener and returns the future that resolves on the
    /// next call to [`signal`](Self::signal).
    ///
    /// If an earlier sender is still stored (its listener was never
    /// signalled), it is replaced and a debug message is logged. The replaced
    /// sender is dropped when this function returns, which cancels the future
    /// handed out by the earlier `listen` call.
    pub fn listen(&self) -> CompletionFuture<T> {
        let (tx, rx) = oneshot::channel();
        // The lock is only held for this one statement.
        let old = lock_or_clear(&self.sender).replace(tx);
        if old.is_some() {
            debug!("CompletionSignal::listen: replacing unconsumed sender");
        }
        CompletionFuture(rx)
    }

    /// Delivers `value` to the currently registered listener, if there is one.
    ///
    /// The stored sender is consumed, so a second `signal` without a new
    /// `listen` is a no-op. A send failure (the listening future was already
    /// dropped) is deliberately ignored.
    pub fn signal(&self, value: T) {
        // Take the sender in a statement of its own: a temporary in an
        // `if let` scrutinee lives until the end of the `if let` body, which
        // would keep the lock (assuming `lock_or_clear` returns a guard — its
        // definition is outside this file) held while `send` wakes the
        // receiver and while a rejected `value` is dropped. Releasing it first
        // means neither of those can re-enter this signal under the lock.
        let pending = lock_or_clear(&self.sender).take();
        if let Some(tx) = pending {
            let _ = tx.send(value);
        }
    }
}
impl<T: Send> Default for CompletionSignal<T> {
fn default() -> Self {
Self::new()
}
}
impl<T> Future for CompletionFuture<T> {
    /// The signalled value, or [`SignalCancelled`] if the channel was
    /// cancelled before a value arrived.
    type Output = Result<T, SignalCancelled>;

    /// Polls the wrapped receiver and translates its cancellation error into
    /// [`SignalCancelled`].
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let this = self.get_mut();
        match Pin::new(&mut this.0).poll(cx) {
            Poll::Pending => Poll::Pending,
            Poll::Ready(Ok(value)) => Poll::Ready(Ok(value)),
            Poll::Ready(Err(_)) => Poll::Ready(Err(SignalCancelled)),
        }
    }
}
/// Registers a listener on `signal`, runs `start_op`, then waits for the
/// signalled value.
///
/// The listener is registered *before* `start_op` is invoked, so a completion
/// that is signalled while `start_op` is still running (or right after it
/// returns) is not lost.
///
/// # Errors
///
/// * Whatever `start_op` returns as `Err` is passed through unchanged, and the
///   function returns without awaiting anything.
/// * If the listener is cancelled, the [`SignalCancelled`] is converted with
///   `E::from`.
pub async fn await_win32<T, E, F>(signal: &CompletionSignal<T>, start_op: F) -> Result<T, E>
where
    T: Send,
    E: From<SignalCancelled>,
    F: FnOnce() -> Result<(), E>,
{
    let pending = signal.listen();

    // Synchronous failure: no completion is expected, bail out right away.
    if let Err(sync_err) = start_op() {
        return Err(sync_err);
    }

    match pending.await {
        Ok(value) => Ok(value),
        Err(cancelled) => Err(E::from(cancelled)),
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;

    /// A value signalled from a different thread reaches the listener.
    #[test]
    fn signal_from_another_thread() {
        let signal = Arc::new(CompletionSignal::<u32>::new());
        let signal2 = Arc::clone(&signal);
        let future = signal.listen();
        std::thread::spawn(move || {
            std::thread::sleep(Duration::from_millis(10));
            signal2.signal(42);
        });
        let result = futures_executor::block_on(future);
        assert_eq!(result, Ok(42));
    }

    /// Signalling before the future is first polled must not lose the value.
    #[test]
    fn signal_before_poll_still_works() {
        let signal: CompletionSignal<u32> = Default::default();
        let future = signal.listen();
        signal.signal(7);
        assert_eq!(futures_executor::block_on(future), Ok(7));
    }

    /// Dropping the signal (and with it the stored sender) cancels the future.
    #[test]
    fn dropped_sender_returns_cancelled() {
        let signal = CompletionSignal::<u32>::new();
        let future = signal.listen();
        drop(signal);
        assert!(futures_executor::block_on(future).is_err());
    }

    /// A second `listen` replaces the first; the newest listener gets the value.
    #[test]
    fn replaced_listener_does_not_panic() {
        let signal = CompletionSignal::<u32>::new();
        let _fut1 = signal.listen();
        let fut2 = signal.listen();
        signal.signal(99);
        assert_eq!(futures_executor::block_on(fut2), Ok(99));
    }

    /// The same signal can be armed and fired repeatedly.
    #[test]
    fn sequential_reuse() {
        let signal = CompletionSignal::<&str>::new();
        let f1 = signal.listen();
        signal.signal("step1");
        assert_eq!(futures_executor::block_on(f1), Ok("step1"));
        let f2 = signal.listen();
        signal.signal("step2");
        assert_eq!(futures_executor::block_on(f2), Ok("step2"));
    }

    /// leak -> borrow -> reclaim leaves the strong count where it started.
    #[test]
    fn context_ptr_round_trip() {
        let state = Arc::new(String::from("hello"));
        let raw = leak_context_ptr(&state);
        assert_eq!(Arc::strong_count(&state), 2);
        // SAFETY: `raw` was just produced by `leak_context_ptr::<String>` and
        // is reclaimed exactly once; `state` keeps the value alive throughout.
        unsafe {
            let s: &String = borrow_context_ptr(raw);
            assert_eq!(s, "hello");
            reclaim_context_ptr::<String>(raw);
        }
        assert_eq!(Arc::strong_count(&state), 1);
        assert_eq!(*state, "hello");
    }

    /// Error type used to observe which path `await_win32` took.
    #[derive(Debug, PartialEq)]
    enum TestError {
        Sync(&'static str),
        Cancelled,
    }

    impl From<SignalCancelled> for TestError {
        fn from(_: SignalCancelled) -> Self {
            TestError::Cancelled
        }
    }

    /// A synchronous `start_op` failure is returned as-is, and the signal
    /// remains usable afterwards.
    #[test]
    fn await_win32_sync_failure_skips_callback() {
        let signal = CompletionSignal::<u32>::new();
        let result = futures_executor::block_on(await_win32(&signal, || {
            Err::<(), _>(TestError::Sync("sync fail"))
        }));
        assert_eq!(result, Err(TestError::Sync("sync fail")));
        let _fut = signal.listen();
    }

    /// Goes through `await_win32` itself: the operation completes (signals)
    /// synchronously inside `start_op`, after the listener is registered.
    #[test]
    fn await_win32_success() {
        let signal = CompletionSignal::<u32>::new();
        let result = futures_executor::block_on(await_win32(&signal, || {
            signal.signal(42);
            Ok::<(), TestError>(())
        }));
        assert_eq!(result, Ok(42));
    }

    /// `start_op` succeeds and the completion arrives later from another thread.
    #[test]
    fn await_win32_threaded_success() {
        let signal = Arc::new(CompletionSignal::<u32>::new());
        let signal2 = Arc::clone(&signal);
        let result = futures_executor::block_on(await_win32(&signal, || {
            std::thread::spawn(move || {
                std::thread::sleep(Duration::from_millis(10));
                signal2.signal(99);
            });
            Ok::<(), TestError>(())
        }));
        assert_eq!(result, Ok(99));
    }

    /// Goes through `await_win32` itself: re-arming the signal inside
    /// `start_op` drops the sender `await_win32` registered, so the awaited
    /// future is cancelled and the error is converted via `From`.
    #[test]
    fn await_win32_cancelled_signal() {
        let signal = CompletionSignal::<u32>::new();
        let result = futures_executor::block_on(await_win32(&signal, || {
            drop(signal.listen());
            Ok::<(), TestError>(())
        }));
        assert_eq!(result, Err(TestError::Cancelled));
    }
}