use futures::task::AtomicWaker;
use futures_core::FusedFuture;
use pin_project::{pin_project, pinned_drop};
use std::{
any::Any,
pin::Pin,
sync::{Arc, atomic::AtomicBool},
task::{Context, Poll},
thread::JoinHandle,
};
/// A cooperative cancellation signal handed to a worker closure.
///
/// [`ThreadFuture`] keeps one copy and passes a clone to the worker, then calls
/// `cancel` on its own copy, so implementations are expected to make a
/// cancellation visible through every clone.
pub trait CancellationToken: Clone {
    /// Requests cancellation. The worker decides if and when to honour it.
    fn cancel(&self);
}

/// A [`CancellationToken`] backed by a shared [`AtomicBool`].
///
/// All clones share one flag, so cancelling any clone is observed by every
/// other clone. [`Default`] produces a fresh, un-cancelled token, exactly like
/// [`SimpleCancellationToken::new`].
#[derive(Debug, Clone, Default)]
pub struct SimpleCancellationToken {
    /// Shared flag; `true` once any clone has been cancelled.
    cancelled: Arc<AtomicBool>,
}

impl SimpleCancellationToken {
    /// Creates a new token that is not cancelled.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Marks this token, and every clone of it, as cancelled.
    pub fn cancel(&self) {
        self.cancelled
            .store(true, std::sync::atomic::Ordering::SeqCst);
    }

    /// Returns `true` once `cancel` has been called on this token or any clone.
    #[must_use]
    pub fn is_cancelled(&self) -> bool {
        self.cancelled.load(std::sync::atomic::Ordering::SeqCst)
    }
}

impl CancellationToken for SimpleCancellationToken {
    fn cancel(&self) {
        // Path resolution prefers the inherent method over this trait method,
        // so this delegates rather than recursing.
        SimpleCancellationToken::cancel(self);
    }
}
/// Lifecycle of the worker behind a [`ThreadFuture`].
// NOTE(review): `ThreadFuture` stores this in an un-pinned field and moves `F`
// out with `mem::replace` in `poll`, so the `#[pin]` on `F` (and the generated
// `ThreadFutureStateProj`) appear unused in this file. Confirm there are no
// other users before removing them.
#[pin_project(project = ThreadFutureStateProj)]
pub enum ThreadFutureState<T, F> {
    /// The closure has not been handed to a thread yet (lazy constructors).
    NotStarted(#[pin] F),
    /// The worker thread has been spawned; joining its handle yields the result.
    Running(JoinHandle<T>),
    /// `poll` has already returned the result; later polls return `Pending`.
    Completed,
    /// Placeholder stored while `poll` has temporarily moved the real state out.
    Polling,
}
/// A [`Future`] that runs a blocking closure on its own OS thread and resolves
/// to the closure's return value, or to the panic payload if it panicked.
///
/// The closure receives a clone of the future's [`CancellationToken`]. By
/// default, dropping the future calls `cancel` on that token (see
/// [`detach_on_drop`](Self::detach_on_drop)).
#[pin_project(PinnedDrop)]
pub struct ThreadFuture<T, F, C>
where
    C: CancellationToken + Send + 'static,
{
    /// Where the worker is in its lifecycle.
    state: ThreadFutureState<T, F>,
    /// Whether dropping the future calls `cancel` on `cancellation_token`.
    cancel_on_drop: bool,
    /// The future's copy of the token; the worker receives a clone.
    cancellation_token: C,
    /// Woken by the worker thread once `work` has returned or panicked.
    waker: Arc<AtomicWaker>,
    /// Set by the worker (also when `work` panics) *before* it wakes `waker`.
    /// `poll` checks this flag rather than `JoinHandle::is_finished`. That
    /// method only turns true after the worker's wake-up, so a poll triggered
    /// by the wake could still see "not finished" and then never be woken again.
    finished: Arc<AtomicBool>,
}

impl<T, F> ThreadFuture<T, F, SimpleCancellationToken> {
    /// Creates a lazy future with a fresh [`SimpleCancellationToken`]; the
    /// worker thread is spawned on the first `poll`.
    pub fn new(work: F) -> Self
    where
        F: (FnOnce(SimpleCancellationToken) -> T) + Send + 'static,
        T: Send + 'static,
    {
        Self::new_with_cancellation(work, SimpleCancellationToken::new())
    }

    /// Creates a future with a fresh [`SimpleCancellationToken`] and spawns
    /// the worker thread immediately.
    pub fn new_eager(work: F) -> Self
    where
        F: (FnOnce(SimpleCancellationToken) -> T) + Send + 'static,
        T: Send + 'static,
    {
        Self::new_eager_with_cancellation(work, SimpleCancellationToken::new())
    }
}

impl<T, F, C> ThreadFuture<T, F, C>
where
    F: (FnOnce(C) -> T) + Send + 'static,
    T: Send + 'static,
    C: CancellationToken + Send + 'static,
{
    /// Creates a lazy future using `cancellation_token`; the worker thread is
    /// spawned on the first `poll`.
    pub fn new_with_cancellation(work: F, cancellation_token: C) -> Self {
        Self {
            state: ThreadFutureState::NotStarted(work),
            cancel_on_drop: true,
            cancellation_token,
            waker: Arc::new(AtomicWaker::new()),
            finished: Arc::new(AtomicBool::new(false)),
        }
    }

    /// Creates a future using `cancellation_token` and spawns the worker
    /// thread immediately.
    pub fn new_eager_with_cancellation(work: F, cancellation_token: C) -> Self {
        let waker = Arc::new(AtomicWaker::new());
        let finished = Arc::new(AtomicBool::new(false));
        let join_handle = Self::spawn_thread(
            work,
            cancellation_token.clone(),
            Arc::clone(&waker),
            Arc::clone(&finished),
        );
        Self {
            state: ThreadFutureState::Running(join_handle),
            cancel_on_drop: true,
            cancellation_token,
            waker,
            finished,
        }
    }

    /// Consumes `self` and returns it configured not to cancel the token when dropped.
    pub fn detach_on_drop(mut self) -> Self {
        self.cancel_on_drop = false;
        self
    }

    /// In-place variant of [`detach_on_drop`](Self::detach_on_drop).
    pub fn detach_on_drop_ref(&mut self) {
        self.cancel_on_drop = false;
    }

    /// Returns `true` if dropping this future will cancel its token.
    pub fn is_cancel_on_drop(&self) -> bool {
        self.cancel_on_drop
    }

    /// Returns the future's copy of the cancellation token.
    pub fn cancellation_token(&self) -> &C {
        &self.cancellation_token
    }

    /// Calls `cancel` on the future's cancellation token.
    pub fn cancel(&self) {
        self.cancellation_token.cancel();
    }

    /// Spawns the worker thread running `work(cancel_token)`.
    ///
    /// When `work` finishes, normally or by panicking, the thread first sets
    /// `finished` and then wakes `waker`, in that order.
    fn spawn_thread(
        work: F,
        cancel_token: C,
        waker: Arc<AtomicWaker>,
        finished: Arc<AtomicBool>,
    ) -> JoinHandle<T> {
        // Signals completion when dropped, so it also fires while unwinding
        // out of a panicking `work`.
        struct CompletionSignal {
            waker: Arc<AtomicWaker>,
            finished: Arc<AtomicBool>,
        }

        impl Drop for CompletionSignal {
            fn drop(&mut self) {
                // Publish the flag before waking. `poll` registers its waker
                // and then re-reads the flag, so it either sees `true` or
                // receives this wake-up.
                self.finished
                    .store(true, std::sync::atomic::Ordering::SeqCst);
                self.waker.wake();
            }
        }

        std::thread::spawn(move || {
            let _signal = CompletionSignal { waker, finished };
            work(cancel_token)
        })
    }
}

/// Payload of a panicked worker thread, as returned by [`JoinHandle::join`].
type JoinError = Box<dyn Any + Send + 'static>;

impl<T, F, C> Future for ThreadFuture<T, F, C>
where
    F: (FnOnce(C) -> T) + Send + 'static,
    T: Send + 'static,
    C: CancellationToken + Send + 'static,
{
    type Output = Result<T, JoinError>;

    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let this = self.project();
        // Move the state out so owned data (closure / join handle) can be
        // consumed; every non-panicking arm writes a real state back.
        match std::mem::replace(this.state, ThreadFutureState::Polling) {
            ThreadFutureState::NotStarted(work) => {
                // Register before spawning so even an instantly finishing
                // worker has a waker to wake.
                this.waker.register(cx.waker());
                let join_handle = Self::spawn_thread(
                    work,
                    this.cancellation_token.clone(),
                    Arc::clone(this.waker),
                    Arc::clone(this.finished),
                );
                *this.state = ThreadFutureState::Running(join_handle);
                Poll::Pending
            }
            ThreadFutureState::Running(join_handle) => {
                // AtomicWaker protocol: register, then re-check the flag the
                // worker stores before waking, so a completion racing with this
                // poll cannot be lost.
                if !this.finished.load(std::sync::atomic::Ordering::SeqCst) {
                    this.waker.register(cx.waker());
                }
                if this.finished.load(std::sync::atomic::Ordering::SeqCst) {
                    *this.state = ThreadFutureState::Completed;
                    // `work` has already returned or unwound, so `join` should
                    // only wait for the thread to exit. That wait is expected
                    // to be brief but does block this poll.
                    Poll::Ready(join_handle.join())
                } else {
                    *this.state = ThreadFutureState::Running(join_handle);
                    Poll::Pending
                }
            }
            ThreadFutureState::Completed => {
                // Polled after completion (see `FusedFuture::is_terminated`):
                // stay completed and never resolve again.
                *this.state = ThreadFutureState::Completed;
                Poll::Pending
            }
            ThreadFutureState::Polling => {
                unreachable!(
                    "Intermediate polling state reached, this should not be possible unless the poll function was interrupted during processing!"
                )
            }
        }
    }
}
#[pinned_drop]
impl<T, F, C> PinnedDrop for ThreadFuture<T, F, C>
where
    C: CancellationToken + Send + 'static,
{
    // Cancels the token on drop unless the future was detached. The worker
    // thread is not joined here.
    fn drop(self: Pin<&mut Self>) {
        let projected = self.project();
        let detached = !*projected.cancel_on_drop;
        if detached {
            return;
        }
        projected.cancellation_token.cancel();
    }
}
impl<T, F, C> FusedFuture for ThreadFuture<T, F, C>
where
    F: (FnOnce(C) -> T) + Send + 'static,
    T: Send + 'static,
    C: CancellationToken + Send + 'static,
{
    // The future is terminated only once `poll` has handed out its result.
    // Not-started, running and mid-poll states all count as live.
    fn is_terminated(&self) -> bool {
        let current = &self.state;
        matches!(current, ThreadFutureState::Completed)
    }
}