use std::pin::Pin;
use std::sync::{Arc, Condvar, Mutex};
use std::task::{Context, Poll, Waker};
use crate::error::Error;
use crate::runtime::thread_local::RUNTIME_THREAD_LOCAL;
use crate::runtime::work::Work;
use crate::stream::Stream;
/// Result type used throughout this module, with the crate's [`Error`] as the error type.
type Result<T> = std::result::Result<T, Error>;
/// Boxed, type-erased closure that a [`Future`] stores until it is scheduled; it may borrow
/// data that lives for `'closure`.
pub type Closure<'closure> = Box<dyn FnOnce() + Send + 'closure>;
/// A future that runs a closure on the runtime the first time it is polled and resolves to the
/// closure's return value.
///
/// The closure may borrow data for `'closure`. Nothing runs until the future is polled, so
/// dropping it before the first poll means the closure never executes. Dropping it while the
/// closure is running blocks the dropping thread until the closure has finished.
#[must_use = "futures do nothing unless you `.await` or poll them"]
pub struct Future<'closure, T> {
    /// State shared with the scheduled closure, guarded by a mutex.
    shared: Arc<Mutex<Shared<'closure, T>>>,
    /// Notified by the closure once it has stored its return value; `Drop` waits on it.
    completed: Arc<Condvar>,
    /// Ties the future to the `'closure` lifetime of the data the closure borrows.
    _phantom: std::marker::PhantomData<&'closure ()>,
}
impl<'closure, T> Future<'closure, T> {
    /// Creates a future that will run `call` on the runtime the first time it is polled.
    ///
    /// Nothing is scheduled here: the closure is only stored in the shared state. When it
    /// eventually runs, it stores `call`'s return value, notifies any thread blocked in `Drop`,
    /// and wakes the task that polled the future.
    ///
    /// # Panics
    ///
    /// The scheduled closure panics with "unexpected state" if, after `call` returns, the
    /// future is in any state other than `Running`.
    #[inline]
    pub fn new<F>(call: F) -> Self
    where
        F: FnOnce() -> T + Send + 'closure,
        T: Send + 'closure,
    {
        let shared = Arc::new(Mutex::new(Shared::new()));
        let completed = Arc::new(Condvar::new());
        let closure = Box::new({
            let shared = shared.clone();
            let completed = completed.clone();
            move || {
                let return_value = call();
                // Publish the result and grab the waker while holding the lock...
                let waker = {
                    let mut shared = shared.lock().unwrap();
                    match shared.state {
                        State::Running => {
                            shared.complete(return_value);
                            completed.notify_all();
                            shared.waker.take()
                        }
                        _ => {
                            panic!("unexpected state");
                        }
                    }
                };
                // ...but wake only after releasing it. `poll` must take the same mutex, so
                // waking under the lock makes the woken task contend for it immediately, and a
                // wake implementation that polls inline would deadlock on it.
                if let Some(waker) = waker {
                    waker.wake();
                }
            }
        });
        shared.lock().unwrap().initialize(closure);
        Self {
            shared,
            completed,
            _phantom: Default::default(),
        }
    }
}
impl<'closure, T> std::future::Future for Future<'closure, T> {
    type Output = T;

    /// Advances the future through its state machine.
    ///
    /// * `Initialized`: registers the current waker, enqueues the closure on the thread-local
    ///   runtime and returns `Pending`.
    /// * `Running`: replaces the registered waker if it would not wake the current task, then
    ///   returns `Pending`.
    /// * `Completed`: moves to `Done` and returns the closure's return value.
    ///
    /// # Panics
    ///
    /// Panics if polled again after returning `Ready`, if the runtime rejects the work
    /// ("runtime broken"), or if the shared mutex is poisoned.
    fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
        let mut shared = self.shared.lock().unwrap();
        match shared.state {
            // Both constructors in this file call `initialize` before returning, so this arm is
            // not expected to be reached. Note that it registers no waker.
            State::New => Poll::Pending,
            State::Initialized => {
                shared.running(cx.waker().clone());
                let closure: Box<dyn FnOnce() + Send + 'closure> =
                    shared.closure.take().expect("initialized without function");
                // SAFETY: erasing `'closure` is only sound if the closure finishes before any
                // `'closure` borrow ends. `Drop for Future` blocks while the state is `Running`,
                // which covers every future whose destructor actually runs.
                // NOTE(review): `std::mem::forget` is safe and skips that destructor, after which
                // the runtime could run the closure against dangling borrows. `Future::new` looks
                // unsound under leaking; confirm and consider an `unsafe` constructor.
                let closure: Box<dyn FnOnce() + Send + 'static> =
                    unsafe { std::mem::transmute(closure) };
                RUNTIME_THREAD_LOCAL.with(|runtime| {
                    runtime.enqueue(Work::new(closure)).expect("runtime broken");
                });
                Poll::Pending
            }
            State::Running => {
                // The `Future::poll` contract only guarantees that the waker from the most
                // recent poll is woken. The task may have migrated or been rewrapped since the
                // closure was scheduled, so keep the registered waker current.
                let up_to_date =
                    matches!(&shared.waker, Some(waker) if waker.will_wake(cx.waker()));
                if !up_to_date {
                    shared.waker = Some(cx.waker().clone());
                }
                Poll::Pending
            }
            State::Completed => {
                shared.done();
                Poll::Ready(shared.return_value.take().unwrap())
            }
            State::Done => {
                panic!("future polled after completion");
            }
        }
    }
}
impl<'closure, T> Drop for Future<'closure, T> {
    /// Waits for a running closure to finish, then releases everything that may borrow
    /// `'closure` data.
    ///
    /// * While the closure is `Running`, this blocks on the `completed` condition variable until
    ///   the closure has stored its return value. No `'closure` borrow therefore ends while the
    ///   closure may still be using it.
    /// * An unscheduled closure (the future was never polled) is dropped here. It captures
    ///   clones of `shared` and is itself stored in `Shared`, so leaving it there forms an `Arc`
    ///   reference cycle that leaks the state and everything the closure captured.
    /// * A return value that `poll` never retrieved is dropped here too, while its borrows are
    ///   still valid. Otherwise the closure's own `Arc` clone, released later on the runtime,
    ///   could be the one that drops it.
    ///
    /// # Panics
    ///
    /// Panics if the shared mutex is poisoned.
    fn drop(&mut self) {
        let mut shared = self.shared.lock().unwrap();
        if let State::Running = shared.state {
            shared = self
                .completed
                .wait_while(shared, |inner| !matches!(inner.state, State::Completed))
                .unwrap();
        }
        let unscheduled_closure = shared.closure.take();
        let unclaimed_return_value = shared.return_value.take();
        // Release the lock before running any user-provided destructors.
        drop(shared);
        drop(unscheduled_closure);
        drop(unclaimed_return_value);
    }
}
/// Future that resolves once a callback registered on a [`Stream`] has fired, or to the error
/// returned while registering it.
///
/// It wraps a [`Future`] whose closure performs the registration when first polled.
#[must_use = "futures do nothing unless you `.await` or poll them"]
pub struct SynchronizeFuture<'closure>(Future<'closure, Result<()>>);
impl<'closure> SynchronizeFuture<'closure> {
    /// Creates a future that completes when the callback it registers on `stream` fires.
    ///
    /// The registration (`stream.inner().add_callback`) runs in the closure that the wrapped
    /// [`Future`] schedules on its first poll. If registration fails, the future resolves to
    /// that error. Presumably the callback fires once work previously queued on `stream` has
    /// finished; confirm against `add_callback`.
    #[inline]
    pub(crate) fn new(stream: &'closure Stream) -> Self {
        let shared = Arc::new(Mutex::new(Shared::new()));
        let completed = Arc::new(Condvar::new());
        let closure = Box::new({
            let shared = shared.clone();
            let completed = completed.clone();
            move || {
                let callback = {
                    let shared = shared.clone();
                    let completed = completed.clone();
                    move || Self::complete(shared, completed, Ok(()))
                };
                if let Err(err) = stream.inner().add_callback(callback) {
                    Self::complete(shared, completed, Err(err));
                }
            }
        });
        shared.lock().unwrap().initialize(closure);
        Self(Future {
            shared,
            completed,
            _phantom: Default::default(),
        })
    }

    /// Stores `return_value`, marks the future `Completed`, notifies threads blocked in `Drop`
    /// and wakes the polling task.
    ///
    /// If the mutex is poisoned, this does nothing: no value is stored and no task is woken.
    ///
    /// # Panics
    ///
    /// Panics with "unexpected state" if the future is not `Running`.
    #[inline]
    fn complete(
        shared: Arc<Mutex<Shared<Result<()>>>>,
        completed: Arc<Condvar>,
        return_value: Result<()>,
    ) {
        let waker = match shared.lock() {
            Ok(mut guard) => match guard.state {
                State::Running => {
                    guard.complete(return_value);
                    completed.notify_all();
                    guard.waker.take()
                }
                _ => {
                    panic!("unexpected state");
                }
            },
            Err(_) => None,
        };
        // Wake only after the guard is released. `poll` must take the same mutex, so waking
        // under the lock makes the woken task contend for it, and an inline-polling wake
        // implementation would deadlock on it.
        if let Some(waker) = waker {
            waker.wake();
        }
    }
}
impl<'closure> std::future::Future for SynchronizeFuture<'closure> {
    type Output = Result<()>;

    /// Forwards the poll to the wrapped [`Future`]. Both types are `Unpin`, so the inner
    /// future can be re-pinned freely.
    #[inline]
    fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
        let wrapped = &mut self.get_mut().0;
        Pin::new(wrapped).poll(cx)
    }
}
/// State shared between a [`Future`] and the closure it schedules. In this file it is always
/// held as `Arc<Mutex<Shared<..>>>`.
struct Shared<'closure, T> {
/// Current lifecycle position; see [`State`].
state: State,
/// Closure to schedule. Set by `initialize` and taken by `poll` when the closure is enqueued.
closure: Option<Closure<'closure>>,
/// Waker registered by `poll`; taken by the closure when it completes so it can be woken.
waker: Option<Waker>,
/// Value produced by the closure. Stored on completion and taken by `poll` when it returns
/// `Ready`.
return_value: Option<T>,
}
/// Lifecycle of a [`Future`], stored in `Shared::state`.
#[derive(Debug, Copy, Clone, PartialEq)]
enum State {
/// Set by `Shared::new`; no closure stored yet.
New,
/// Closure stored by `Shared::initialize`; not yet scheduled.
Initialized,
/// Closure enqueued on the runtime by `poll`; a waker is registered.
Running,
/// Closure finished and its return value is stored.
Completed,
/// Return value handed out by `poll`; polling again panics.
Done,
}
impl<'closure, T> Shared<'closure, T> {
    /// Returns an empty record in the `New` state: no closure, waker or return value.
    fn new() -> Self {
        Self {
            state: State::New,
            closure: None,
            waker: None,
            return_value: None,
        }
    }

    /// Keeps `closure` until it is scheduled and moves to `Initialized`.
    #[inline]
    fn initialize(&mut self, closure: Closure<'closure>) {
        self.state = State::Initialized;
        self.closure = Some(closure);
    }

    /// Remembers the polling task's `waker` and moves to `Running`.
    #[inline]
    fn running(&mut self, waker: Waker) {
        self.state = State::Running;
        self.waker = Some(waker);
    }

    /// Records the closure's `return_value` and moves to `Completed`.
    #[inline]
    fn complete(&mut self, return_value: T) {
        self.state = State::Completed;
        self.return_value = Some(return_value);
    }

    /// Marks the return value as handed out by moving to `Done`.
    #[inline]
    fn done(&mut self) {
        self.state = State::Done;
    }
}
// Tests for `Future` and `SynchronizeFuture`. Awaiting a `Future` enqueues its closure on
// `RUNTIME_THREAD_LOCAL`, and the last test constructs a real `crate::Stream`. These tests
// therefore presumably need a working crate runtime; confirm in CI setup.
#[cfg(test)]
mod tests {
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use super::*;
// A future resolves to the value returned by its closure.
#[tokio::test]
async fn test_future() {
assert!(Future::new(|| true).await);
}
// Side effects of an awaited future are visible to a future created afterwards
// (shared through an `Arc<AtomicBool>`).
#[tokio::test]
async fn test_future_order() {
let first_future_completed = Arc::new(AtomicBool::new(false));
Future::new({
let first_future_completed = first_future_completed.clone();
move || {
first_future_completed.store(true, Ordering::SeqCst);
}
})
.await;
assert!(
Future::new({
let first_future_completed = first_future_completed.clone();
move || first_future_completed.load(Ordering::SeqCst)
})
.await
);
}
// Same ordering check, but the closures borrow a plain local directly, which is only
// possible because the closure lifetime is `'closure` rather than `'static`.
#[tokio::test]
async fn test_future_order_simple() {
let mut first_future_completed = false;
Future::new(|| first_future_completed = true).await;
assert!(Future::new(|| first_future_completed).await);
}
// Repeatedly cancels a future through a zero-length timeout. Whatever the timing, the
// closure must be observed either fully run or not started at all: if it was already
// running when the future was dropped, `Drop` waits for it to finish. The counts are only
// printed, not asserted.
#[tokio::test]
async fn test_future_outlives_closure() {
let mut count_completed = 0;
let mut count_cancelled = 0;
for _ in 0..1_000 {
let mut start_of_closure = false;
let mut end_of_closure = false;
let future = Future::new(|| {
start_of_closure = true;
std::thread::sleep(std::time::Duration::from_millis(1));
end_of_closure = true;
});
let future_with_small_delay = async {
tokio::time::sleep(std::time::Duration::from_millis(1)).await;
future.await
};
let _ =
tokio::time::timeout(std::time::Duration::from_nanos(0), future_with_small_delay)
.await;
assert!((start_of_closure && end_of_closure) || (!start_of_closure && !end_of_closure));
if end_of_closure {
count_completed += 1;
} else {
count_cancelled += 1;
}
}
println!("num completed: {count_completed}");
println!("num cancelled: {count_cancelled}");
}
// The wrapper sleeps 10 ms before awaiting the future. The zero-length timeout is expected
// to fire first and drop the future before it is ever polled, so the closure must never
// start.
#[tokio::test]
async fn test_future_outlives_closure_manual() {
let mut start_of_closure = false;
let mut end_of_closure = false;
let future = Future::new(|| {
start_of_closure = true;
std::thread::sleep(std::time::Duration::from_nanos(1000));
end_of_closure = true;
});
let future_with_small_delay = async {
tokio::time::sleep(std::time::Duration::from_millis(10)).await;
future.await
};
let _ = tokio::time::timeout(std::time::Duration::ZERO, future_with_small_delay).await;
assert!((!start_of_closure && !end_of_closure))
}
// Dropping a future that was never polled must not run its closure.
#[tokio::test]
async fn test_future_does_not_run_if_cancelled_before_polling() {
let mut start_of_closure = false;
let mut end_of_closure = false;
let future = Future::new(|| {
start_of_closure = true;
std::thread::sleep(std::time::Duration::from_nanos(1000));
end_of_closure = true;
});
drop(future);
assert!((!start_of_closure && !end_of_closure))
}
// Synchronizing a freshly created stream resolves to `Ok(())`.
#[tokio::test]
async fn test_synchronization_future() {
let stream = crate::Stream::new().await.unwrap();
assert!(SynchronizeFuture::new(&stream).await.is_ok());
}
}