use futures::future::FusedFuture;
use std::future::Future;
use std::pin::Pin;
use std::sync::{Arc, Mutex};
use std::task::{Context, Poll};
use crate::common::ready_future_state::ReadyFutureState;
use super::ready_future_state::ReadyFutureResult;
/// Reference-counted, mutex-protected handle to a [`ReadyFutureState`].
///
/// Every clone of a [`ReadyFuture`] holds a copy of this handle (see `clone_state`), so all
/// of them observe and mutate the same state.
pub type ReadyFutureStateSafe<T> = Arc<Mutex<ReadyFutureState<T>>>;
/// A future whose outcome is set from outside (e.g. via `complete` / `terminate`) rather
/// than computed by the future itself.
///
/// Cloning does not duplicate the state: all clones share one [`ReadyFutureStateSafe`].
/// The result is handed out to the first poll that observes it; later polls of any clone
/// yield `ReadyFutureResult::Terminated`.
pub struct ReadyFuture<T> {
    // Shared with every clone of this future and with anyone holding `clone_state()`.
    shared_state: ReadyFutureStateSafe<T>,
}
/// Constructors, completion controls and state queries for [`ReadyFuture`].
///
/// Every method works on the shared, mutex-protected state, so its effects are visible to
/// all clones of the same future.
impl<T> ReadyFuture<T> {
    /// Creates a pending future backed by a fresh state from `ReadyFutureState::new()`.
    pub fn new() -> Self {
        let fresh = ReadyFutureState::new();
        Self::with_shared_state(Arc::new(Mutex::new(fresh)))
    }

    /// Creates a future backed by a fresh state from
    /// `ReadyFutureState::new_completed(value)`.
    pub fn new_completed(value: T) -> Self {
        let fresh = ReadyFutureState::new_completed(value);
        Self::with_shared_state(Arc::new(Mutex::new(fresh)))
    }

    /// Wraps an existing shared state without copying it; the returned future observes and
    /// mutates that very state.
    pub fn with_shared_state(shared_state: ReadyFutureStateSafe<T>) -> Self {
        Self { shared_state }
    }

    /// Returns another reference-counted handle to this future's shared state.
    ///
    /// Only the `Arc` is cloned; the state itself is not copied.
    pub fn clone_state(&self) -> ReadyFutureStateSafe<T> {
        Arc::clone(&self.shared_state)
    }

    /// Locks the shared state and returns the guard.
    ///
    /// The lock is held until the guard is dropped, which blocks every other handle to the
    /// same state in the meantime.
    ///
    /// # Panics
    /// Panics if the mutex is poisoned.
    pub(crate) fn get_state(&self) -> std::sync::MutexGuard<'_, ReadyFutureState<T>> {
        self.shared_state.lock().unwrap()
    }

    /// Creates a pending future and immediately completes it with `val`.
    ///
    /// NOTE(review): overlaps with `new_completed`, which builds the completed state
    /// directly; confirm whether the two are interchangeable.
    pub fn new_resolved(val: T) -> Self {
        let future = Self::new();
        future.complete(val);
        future
    }

    /// Completes the shared state with `val` via `ReadyFutureState::complete`; every clone
    /// observes the completion.
    pub fn complete(&self, val: T) {
        self.get_state().complete(val);
    }

    /// Terminates the shared state via `ReadyFutureState::terminate`; a later poll of any
    /// clone resolves to `ReadyFutureResult::Terminated`.
    pub fn terminate(&self) {
        self.get_state().terminate();
    }

    /// Reports `ReadyFutureState::is_pending` for the shared state.
    pub fn is_pending(&self) -> bool {
        let state = self.get_state();
        state.is_pending()
    }

    /// Reports `ReadyFutureState::is_fulfilled`, the condition `poll` checks before
    /// resolving.
    pub fn is_fulfilled(&self) -> bool {
        let state = self.get_state();
        state.is_fulfilled()
    }

    /// Reports `ReadyFutureState::is_completed` for the shared state.
    pub fn is_completed(&self) -> bool {
        let state = self.get_state();
        state.is_completed()
    }

    /// Reports `ReadyFutureState::is_aborted` for the shared state.
    pub fn is_aborted(&self) -> bool {
        let state = self.get_state();
        state.is_aborted()
    }

    /// Reports `ReadyFutureState::is_timeouted` for the shared state.
    pub fn is_timeouted(&self) -> bool {
        let state = self.get_state();
        state.is_timeouted()
    }

    /// Reports `ReadyFutureState::is_terminated` for the shared state.
    ///
    /// As an inherent method this is what `future.is_terminated()` resolves to; it reads the
    /// same flag as the `FusedFuture` implementation.
    pub fn is_terminated(&self) -> bool {
        let state = self.get_state();
        state.is_terminated()
    }
}
impl<T> Future for ReadyFuture<T> {
    type Output = ReadyFutureResult<T>;

    /// Resolves once the shared state reports itself fulfilled. Otherwise, registers the
    /// caller's waker and returns `Poll::Pending`.
    ///
    /// The first poll that sees a fulfilled state moves the result out and leaves
    /// `ReadyFutureResult::Terminated` in its place. The value is therefore delivered at most
    /// once: any later poll of this future, or of a clone sharing the state, yields
    /// `Terminated`.
    ///
    /// # Panics
    /// Panics if the shared state's mutex is poisoned.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let mut shared_state = self.get_state();
        if shared_state.is_fulfilled() {
            let result =
                std::mem::replace(&mut shared_state.result, ReadyFutureResult::Terminated);
            return Poll::Ready(result);
        }

        // `Future::poll` requires waking the waker passed to the *most recent* poll. Keeping
        // only the first waker ever stored would strand this future if it moves to another
        // task or is polled with a new context. A stored waker is kept only when `will_wake`
        // says it already targets the same task, which also avoids a redundant clone.
        //
        // NOTE(review): the state has a single waker slot, so when several clones are pending
        // in different tasks, only the most recently polled task is registered.
        let stale = match &shared_state.waker {
            Some(existing) => !existing.will_wake(cx.waker()),
            None => true,
        };
        if stale {
            shared_state.waker = Some(cx.waker().clone());
        }
        Poll::Pending
    }
}
impl<T> Clone for ReadyFuture<T> {
fn clone(&self) -> Self {
ReadyFuture {
shared_state: self.clone_state(),
}
}
}
/// Fused-future support: termination is read straight from the shared state.
///
/// Once the state is terminated, `futures::select!` stops polling this future (and every
/// clone sharing the state).
impl<T> FusedFuture for ReadyFuture<T> {
    fn is_terminated(&self) -> bool {
        let state = self.get_state();
        state.is_terminated()
    }
}
#[cfg(test)]
mod test {
    use super::*;
    use futures::{executor::block_on, select};
    use std::pin::Pin;
    use std::sync::Arc;
    use std::sync::atomic::{AtomicBool, Ordering};
    use std::task::{Context, Poll, Wake, Waker};
    use super::{ReadyFuture, ReadyFutureResult};

    /// Waker that records in a shared flag whether it has been woken.
    struct TestWaker {
        woken: Arc<AtomicBool>,
    }

    impl Wake for TestWaker {
        fn wake(self: Arc<Self>) {
            self.woken.store(true, Ordering::SeqCst);
        }
    }

    /// Creates a [`TestWaker`]-backed `Waker` and returns it together with its "woken" flag.
    ///
    /// The caller builds the `Context` itself (`Context::from_waker(&waker)`) because the
    /// context borrows the waker.
    fn flagged_waker() -> (Arc<AtomicBool>, Waker) {
        let woken = Arc::new(AtomicBool::new(false));
        let waker = Waker::from(Arc::new(TestWaker {
            woken: Arc::clone(&woken),
        }));
        (woken, waker)
    }

    /// A pre-resolved future yields its value once; afterwards the shared state is `Terminated`.
    #[test]
    fn test_new_fulfilled() {
        let f = ReadyFuture::new_resolved(42_usize);
        let result = block_on(f.clone());
        assert!(
            matches!(result, ReadyFutureResult::Completed(42)),
            "first poll should yield the completed value"
        );
        let result = block_on(f);
        assert!(
            matches!(result, ReadyFutureResult::Terminated),
            "the value must only be handed out once"
        );
    }

    #[test]
    fn test_pending_state() {
        let f = ReadyFuture::<usize>::new();
        assert!(f.is_pending(), "Future should be pending initially");
        let state = f.get_state();
        assert!(!state.is_fulfilled(), "State should not be fulfilled");
        assert!(!state.is_aborted(), "State should not be aborted");
        assert!(!state.is_timeouted(), "State should not be timed out");
        assert!(!state.is_terminated(), "State should not be terminated");
    }

    #[test]
    fn test_abort() {
        let f = ReadyFuture::<usize>::new();
        {
            let mut state = f.get_state();
            state.abort();
        }
        let result = block_on(f);
        assert!(matches!(result, ReadyFutureResult::Aborted));
    }

    #[test]
    fn test_timeout() {
        let f = ReadyFuture::<usize>::new();
        {
            let mut state = f.get_state();
            state.timeout();
        }
        let result = block_on(f);
        assert!(matches!(result, ReadyFutureResult::Timeout));
    }

    /// `select!` resolves a completed clone, then skips the future once it is terminated.
    #[test]
    fn test_terminated() {
        let mut f = ReadyFuture::new_resolved(1_usize);
        let mut f_clone = f.clone();
        block_on(async {
            let result = select! {
                _ = f_clone => 0,
                complete => 100_usize,
            };
            assert_eq!(result, 0, "Should resolve immediately");
        });
        f.terminate();
        let result = block_on(async {
            select! {
                _ = f => { 100 },
                complete => 200_usize,
            }
        });
        assert_eq!(result, 200, "Terminated future should not resolve");
        assert!(f.is_terminated(), "Future should be terminated");
    }

    /// Only the first clone to be polled after completion receives the value.
    #[test]
    fn test_clone_concurrent_access() {
        let f = ReadyFuture::new();
        let f_clone1 = f.clone();
        let f_clone2 = f.clone();
        f.get_state().complete(99_usize);
        let result1 = block_on(f_clone1);
        assert!(matches!(result1, ReadyFutureResult::Completed(99)));
        let result2 = block_on(f_clone2);
        assert!(matches!(result2, ReadyFutureResult::Terminated));
        let result3 = block_on(f);
        assert!(matches!(result3, ReadyFutureResult::Terminated));
    }

    #[test]
    fn test_waker_invocation() {
        let f = ReadyFuture::<usize>::new();
        let (woken, waker) = flagged_waker();
        let mut cx = Context::from_waker(&waker);
        let mut f_clone = f.clone();
        let result = Pin::new(&mut f_clone).poll(&mut cx);
        assert!(matches!(result, Poll::Pending));
        assert!(
            !woken.load(Ordering::SeqCst),
            "Waker should not be invoked yet"
        );
        f.get_state().complete(42);
        assert!(
            woken.load(Ordering::SeqCst),
            "Waker should be invoked after fulfill"
        );
        let result = block_on(f);
        assert!(matches!(result, ReadyFutureResult::Completed(42)));
    }

    #[test]
    fn test_multiple_polls_pending() {
        let f = ReadyFuture::<usize>::new();
        let (woken, waker) = flagged_waker();
        let mut cx = Context::from_waker(&waker);
        // Poll two distinct clones with the same context while the state is still pending.
        for _ in 0..2 {
            let mut f_clone = f.clone();
            assert!(matches!(Pin::new(&mut f_clone).poll(&mut cx), Poll::Pending));
        }
        assert!(
            !woken.load(Ordering::SeqCst),
            "Waker should not be invoked during pending polls"
        );
        f.get_state().complete(42);
        assert!(
            woken.load(Ordering::SeqCst),
            "Waker should be invoked after fulfill"
        );
        let result = block_on(f);
        assert!(matches!(result, ReadyFutureResult::Completed(42)));
    }

    #[test]
    fn test_terminated_no_waker() {
        let f = ReadyFuture::<usize>::new();
        f.terminate();
        assert!(f.is_terminated(), "Future should be terminated");
        let (woken, waker) = flagged_waker();
        let mut cx = Context::from_waker(&waker);
        let mut f_clone = f.clone();
        assert!(matches!(
            Pin::new(&mut f_clone).poll(&mut cx),
            Poll::Ready(ReadyFutureResult::Terminated)
        ));
        assert!(
            !woken.load(Ordering::SeqCst),
            "Waker should not be invoked after termination"
        );
        f.get_state().complete(42);
        assert!(
            !woken.load(Ordering::SeqCst),
            "Waker should not be invoked after termination"
        );
    }

    /// A future completed before its first poll resolves to the value without needing a waker.
    #[test]
    fn test_completed_no_waker() {
        let f = ReadyFuture::<usize>::new();
        f.get_state().complete(1);
        assert!(f.is_completed(), "Future should be completed");
        assert!(
            matches!(block_on(f), ReadyFutureResult::Completed(1)),
            "completed future should yield its value"
        );
    }
}