#[cfg(not(target_arch = "wasm32"))]
pub use std::time::{Duration, Instant};
use std::{
any::Any,
future::Future,
panic::AssertUnwindSafe,
pin::Pin,
sync::{
Arc, Mutex,
atomic::{AtomicBool, Ordering},
},
task::{Context, Poll, Waker},
};
use pin_project_lite::pin_project;
#[cfg(target_arch = "wasm32")]
pub use web_time::{Duration, Instant};
use crate::subscription::Subscription;
/// Outcome of one [`Task::step`], telling whoever drives the task what to do next.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TaskState {
  /// The task is done; its driving future completes and it is not stepped again.
  Finished,
  /// The task wants to be stepped again as soon as possible, after the executor
  /// has had a chance to run other work.
  Yield,
  /// The task wants to be stepped again once the given duration has elapsed.
  Sleeping(Duration),
}
/// A unit of cooperative work: mutable `state` plus a plain function pointer
/// that advances that state by one step.
pub struct Task<S> {
  /// State handed to `handler` on every step.
  pub state: S,
  /// Invoked by [`Task::step`]; the returned [`TaskState`] says whether the task
  /// is finished, wants to yield, or wants to sleep before the next step.
  pub handler: fn(&mut S) -> TaskState,
}

impl<S> Task<S> {
  /// Bundles `state` with the `handler` that drives it.
  pub fn new(state: S, handler: fn(&mut S) -> TaskState) -> Self {
    Task { state, handler }
  }

  /// Runs `handler` once against the task's state and returns what it reported.
  pub fn step(&mut self) -> TaskState {
    let run = self.handler;
    run(&mut self.state)
  }
}
/// Something that can be turned into the future a scheduler of type `Sch` drives.
///
/// The scheduler is passed by reference while the future is built, so an
/// implementation may clone it for later use (as the `Task` implementation does).
pub trait Schedulable<Sch> {
  /// The future produced by [`Schedulable::into_future`]; it completes with `()`.
  type Future: Future<Output = ()>;

  /// Consumes `self` and builds the future to run on `scheduler`.
  fn into_future(self, scheduler: &Sch) -> Self::Future;
}
/// Every `()`-returning future is already schedulable: it is handed to the
/// scheduler unchanged and the scheduler reference is not used.
impl<Sch, F: Future<Output = ()>> Schedulable<Sch> for F {
  type Future = F;

  fn into_future(self, _: &Sch) -> Self::Future {
    self
  }
}
/// A source of timer futures, used to implement delays and [`TaskState::Sleeping`].
pub trait SleepProvider: Clone {
  /// Future that completes after the requested duration. It is not required to
  /// be `Unpin`, so callers must pin it before polling.
  type SleepFuture: Future<Output = ()> + 'static;

  /// Returns a future that completes once `duration` has elapsed.
  fn sleep(&self, duration: Duration) -> Self::SleepFuture;
}
pub struct TaskFuture<Sch, S>
where
Sch: SleepProvider,
{
scheduler: Sch,
task: Task<S>,
pending_sleep: Option<Sch::SleepFuture>,
}
impl<Sch, S> Future for TaskFuture<Sch, S>
where
Sch: SleepProvider,
{
type Output = ();
fn poll(self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<Self::Output> {
let this = unsafe { self.get_unchecked_mut() };
loop {
if let Some(ref mut sleep_fut) = this.pending_sleep {
let sleep_fut = unsafe { Pin::new_unchecked(sleep_fut) };
match sleep_fut.poll(ctx) {
Poll::Ready(()) => this.pending_sleep = None,
Poll::Pending => return Poll::Pending,
}
} else {
match this.task.step() {
TaskState::Finished => return Poll::Ready(()),
TaskState::Yield => {
ctx.waker().wake_by_ref();
return Poll::Pending;
}
TaskState::Sleeping(dur) => {
this.pending_sleep = Some(this.scheduler.sleep(dur));
}
}
}
}
}
}
impl<Sch, S> Schedulable<Sch> for Task<S>
where
  // NOTE(review): `into_future` itself does not use the `Scheduler<...>` bound;
  // confirm whether it is still required by callers before relaxing it.
  Sch: SleepProvider + Scheduler<<Sch as SleepProvider>::SleepFuture> + Clone,
  S: 'static,
{
  type Future = TaskFuture<Sch, S>;

  /// Wraps the task in a [`TaskFuture`] that owns its own copy of `scheduler`
  /// and starts with no sleep in progress.
  fn into_future(self, scheduler: &Sch) -> Self::Future {
    let owned_scheduler = <Sch as Clone>::clone(scheduler);
    TaskFuture { task: self, scheduler: owned_scheduler, pending_sleep: None }
  }
}
/// State shared between a [`TaskHandle`] (and its clones) and the spawned task wrapper.
struct SharedState {
  /// Cleared by `unsubscribe` to ask the spawned task to stop.
  keep_running: AtomicBool,
  /// Set once the spawned task has completed (normally or by panicking).
  finished: AtomicBool,
  /// Waker of the task currently awaiting the handle, if any.
  waker: Mutex<Option<Waker>>,
}

/// Handle to a task spawned by a [`Scheduler`].
///
/// Clones share the same state. Awaiting the handle resolves once the task has
/// finished or the handle has been unsubscribed.
#[derive(Clone)]
pub struct TaskHandle {
  inner: Arc<SharedState>,
}

impl TaskHandle {
  /// Creates a handle for a task that is running and not yet finished.
  pub(crate) fn new() -> Self {
    Self::with_flags(true, false)
  }

  /// Creates a handle that already reports itself finished and not running, so
  /// awaiting it completes immediately.
  pub fn finished() -> Self {
    Self::with_flags(false, true)
  }

  /// Builds a handle with the given initial flags and no registered waker.
  fn with_flags(keep_running: bool, finished: bool) -> Self {
    Self {
      inner: Arc::new(SharedState {
        keep_running: AtomicBool::new(keep_running),
        finished: AtomicBool::new(finished),
        waker: Mutex::new(None),
      }),
    }
  }

  /// Records that the task has completed and wakes the task awaiting the handle, if any.
  pub(crate) fn mark_finished(&self) {
    // Release: an Acquire load that observes `finished == true` also observes
    // every write the task made before completing.
    self.inner.finished.store(true, Ordering::Release);
    // Take the waker inside the lock but call it only after the guard is dropped,
    // so a waker that re-enters this handle cannot deadlock on the non-reentrant
    // mutex. A poisoned lock is recovered: an `Option<Waker>` cannot be left in an
    // inconsistent state, and skipping the wake would strand the awaiting task.
    let waker = self
      .inner
      .waker
      .lock()
      .unwrap_or_else(std::sync::PoisonError::into_inner)
      .take();
    if let Some(waker) = waker {
      waker.wake();
    }
  }
}
impl Subscription for TaskHandle {
  /// Asks the spawned task to stop and wakes the task awaiting this handle, if any.
  ///
  /// The spawned wrapper checks the flag at the start of each poll, so the task is
  /// abandoned the next time the executor polls it.
  fn unsubscribe(self) {
    // Release: an Acquire load that observes the cleared flag also observes every
    // write made before this call.
    self.inner.keep_running.store(false, Ordering::Release);
    // Take the waker inside the lock but call it only after the guard is dropped,
    // so a waker that re-enters this handle cannot deadlock on the non-reentrant
    // mutex. A poisoned lock is recovered: an `Option<Waker>` cannot be left in an
    // inconsistent state, and skipping the wake would strand the awaiting task.
    let waker = self
      .inner
      .waker
      .lock()
      .unwrap_or_else(std::sync::PoisonError::into_inner)
      .take();
    if let Some(waker) = waker {
      waker.wake();
    }
  }

  /// Returns `true` once the handle has been unsubscribed or the task has finished.
  fn is_closed(&self) -> bool {
    !self.inner.keep_running.load(Ordering::Acquire) || self.inner.finished.load(Ordering::Acquire)
  }
}
impl Future for TaskHandle {
  type Output = ();

  /// Resolves once the task has finished or the handle has been unsubscribed.
  ///
  /// While still pending, the current waker is stored so that `mark_finished` or
  /// `unsubscribe` can wake the awaiting task.
  fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
    let state = &*self.inner;
    // Acquire loads (rather than Relaxed) so that a Release store of either flag
    // also publishes the writes that preceded it.
    let closed = || {
      state.finished.load(Ordering::Acquire) || !state.keep_running.load(Ordering::Acquire)
    };
    // Fast path: no need to touch the lock once the handle is closed.
    if closed() {
      return Poll::Ready(());
    }
    let mut slot = state.waker.lock().unwrap_or_else(std::sync::PoisonError::into_inner);
    // Re-check while holding the lock. The closing side sets its flag before it
    // takes the waker under this same lock, so either it already ran (and the flag
    // is visible here) or it runs after this guard is released and finds our waker.
    // Without this check, a close landing between the fast-path check and the store
    // below would find an empty slot, and this future would never be woken.
    if closed() {
      return Poll::Ready(());
    }
    *slot = Some(cx.waker().clone());
    Poll::Pending
  }
}
pin_project! {
  /// Future adapter that runs every poll of `future` inside `catch_unwind`, so a
  /// panic surfaces as an `Err` output carrying the panic payload.
  struct CatchUnwind<Fut> {
    // The wrapped future; structurally pinned so it can be polled in place.
    #[pin]
    future: Fut,
  }
}
impl<Fut: Future> Future for CatchUnwind<Fut> {
  /// `Ok` with the inner future's output, or `Err` with the payload of a panic
  /// raised while polling it.
  type Output = Result<Fut::Output, Box<dyn Any + Send + 'static>>;

  fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
    let inner = self.project().future;
    // `AssertUnwindSafe` is needed because `&mut Context` and the pinned future are
    // not `UnwindSafe`; after a panic the wrapper never polls the future again.
    match std::panic::catch_unwind(AssertUnwindSafe(move || inner.poll(cx))) {
      Ok(progress) => progress.map(Ok),
      Err(payload) => Poll::Ready(Err(payload)),
    }
  }
}
pin_project! {
  /// The future a scheduler actually spawns. It drives `future` until it completes
  /// or panics, reports completion through `handle`, and stops early once the
  /// handle has been unsubscribed.
  struct Remote<Fut: Future> {
    // Shared state used to observe cancellation and to report completion.
    handle: TaskHandle,
    // The scheduled future, wrapped so that a panic ends it instead of propagating
    // out of `poll`.
    #[pin]
    future: CatchUnwind<AssertUnwindSafe<Fut>>,
  }
}
impl<Fut: Future<Output = ()>> Future for Remote<Fut> {
  type Output = ();

  /// Polls the wrapped future unless the handle has been unsubscribed. On
  /// completion, whether normal or by a caught panic, the handle is marked finished.
  fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> {
    let this = self.project();
    let cancelled = !this.handle.inner.keep_running.load(Ordering::Relaxed);
    if cancelled {
      // Unsubscribed: finish without polling the scheduled future again.
      return Poll::Ready(());
    }
    // The `Result` (including any panic payload) is discarded; either way the task is over.
    this.future.poll(cx).map(|_outcome| this.handle.mark_finished())
  }
}
/// Wraps `future` for spawning and returns it together with a [`TaskHandle`]
/// that shares its state, for awaiting or cancelling it.
fn remote_handle<Fut: Future<Output = ()>>(future: Fut) -> (Remote<Fut>, TaskHandle) {
  let handle = TaskHandle::new();
  let remote = Remote {
    handle: handle.clone(),
    future: CatchUnwind { future: AssertUnwindSafe(future) },
  };
  (remote, handle)
}
/// A scheduler that can run items of type `S`, optionally after a delay.
pub trait Scheduler<S>: Clone {
  /// Schedules `source` to run, waiting for `delay` first when one is given, and
  /// returns a [`TaskHandle`] that can be awaited or unsubscribed.
  fn schedule(&self, source: S, delay: Option<Duration>) -> TaskHandle;
}
#[cfg(feature = "scheduler")]
pub mod default_schedulers {
  //! Built-in schedulers. On `wasm32-unknown-unknown` they are backed by
  //! `wasm_bindgen_futures::spawn_local` and `gloo_timers`; on every other target
  //! they are backed by tokio.

  #[cfg(all(target_arch = "wasm32", target_vendor = "unknown", target_os = "unknown"))]
  use gloo_timers::future::sleep as platform_sleep;
  #[cfg(not(all(target_arch = "wasm32", target_vendor = "unknown", target_os = "unknown")))]
  use tokio::time::sleep as platform_sleep;
  #[cfg(not(all(target_arch = "wasm32", target_vendor = "unknown", target_os = "unknown")))]
  use tokio::{spawn, task::spawn_local};
  #[cfg(all(target_arch = "wasm32", target_vendor = "unknown", target_os = "unknown"))]
  use wasm_bindgen_futures::spawn_local;

  use super::*;

  // Future type returned by `platform_sleep` on the current target; shared by the
  // `SleepProvider` impls of both schedulers below.
  #[cfg(not(all(target_arch = "wasm32", target_vendor = "unknown", target_os = "unknown")))]
  type PlatformSleep = tokio::time::Sleep;
  #[cfg(all(target_arch = "wasm32", target_vendor = "unknown", target_os = "unknown"))]
  type PlatformSleep = gloo_timers::future::TimeoutFuture;

  /// Scheduler that hands its work to `spawn_local`, so scheduled items need not be `Send`.
  #[derive(Clone, Copy, Default)]
  pub struct LocalScheduler;

  impl SleepProvider for LocalScheduler {
    type SleepFuture = PlatformSleep;

    fn sleep(&self, duration: Duration) -> Self::SleepFuture {
      platform_sleep(duration)
    }
  }

  // Implements `Scheduler<S>` for `$sched` by spawning with `$spawn`; any extra
  // `$bound`s (e.g. `Send`) are required of both the item and its future.
  macro_rules! impl_scheduler {
    ($sched:ty, $spawn:expr $(, $bound:tt)*) => {
      impl<S> Scheduler<S> for $sched
      where
        S: Schedulable<Self> $(+ $bound)* + 'static,
        S::Future: $($bound +)* 'static,
      {
        fn schedule(&self, source: S, delay: Option<Duration>) -> TaskHandle {
          let sleeper = self.clone();
          let body = source.into_future(self);
          let (remote, handle) = remote_handle(async move {
            if let Some(wait) = delay {
              sleeper.sleep(wait).await;
            }
            body.await;
          });
          $spawn(remote);
          handle
        }
      }
    };
  }

  impl_scheduler!(LocalScheduler, spawn_local);

  /// Scheduler that spawns with `tokio::spawn` (so items and their futures must be
  /// `Send`) on every target except `wasm32-unknown-unknown`, where it uses `spawn_local`.
  #[derive(Clone, Copy, Default)]
  pub struct SharedScheduler;

  impl SleepProvider for SharedScheduler {
    type SleepFuture = PlatformSleep;

    fn sleep(&self, duration: Duration) -> Self::SleepFuture {
      platform_sleep(duration)
    }
  }

  #[cfg(all(target_arch = "wasm32", target_vendor = "unknown", target_os = "unknown"))]
  impl_scheduler!(SharedScheduler, spawn_local);
  #[cfg(not(all(target_arch = "wasm32", target_vendor = "unknown", target_os = "unknown")))]
  impl_scheduler!(SharedScheduler, spawn, Send);
}
// Re-export the built-in schedulers so they can be named without the
// `default_schedulers` path segment.
#[cfg(feature = "scheduler")]
pub use default_schedulers::{LocalScheduler, SharedScheduler};
// Re-export the tokio crate on non-wasm32 targets when the `scheduler` feature is
// enabled (presumably so callers can use the same tokio the schedulers run on — confirm).
#[cfg(all(feature = "scheduler", not(target_arch = "wasm32")))]
pub use tokio;
// Test-only scheduler module; `test_local_scheduler_with_delay` uses its
// `TestScheduler` to advance time manually via `advance_by`.
#[cfg(test)]
pub mod test_scheduler;
#[cfg(all(test, feature = "scheduler"))]
mod tests {
  use std::sync::{Arc, Mutex};

  use super::*;

  mod scheduler_tests {
    use super::*;

    #[rxrust_macro::test]
    fn test_task_handle_finished_is_closed() {
      let handle = TaskHandle::finished();
      assert!(handle.is_closed());
    }

    #[rxrust_macro::test]
    fn test_task_creation_and_execution() {
      let executed = Arc::new(Mutex::new(false));
      let executed_clone = executed.clone();
      let mut task = Task::new(executed_clone, |flag| {
        *flag.lock().unwrap() = true;
        TaskState::Finished
      });
      assert_eq!(task.step(), TaskState::Finished);
      assert!(*executed.lock().unwrap());
    }

    #[rxrust_macro::test(local)]
    async fn test_local_scheduler_basic() {
      let executed = Arc::new(Mutex::new(false));
      let executed_clone = executed.clone();
      let scheduler = LocalScheduler;
      let task = Task::new(executed_clone, |flag| {
        *flag.lock().unwrap() = true;
        TaskState::Finished
      });
      let handle = scheduler.schedule(task, None);
      handle.await;
      assert!(*executed.lock().unwrap());
    }

    #[rxrust_macro::test]
    fn test_local_scheduler_with_delay() {
      use std::{cell::Cell, rc::Rc};

      use crate::scheduler::test_scheduler::TestScheduler;
      TestScheduler::init();
      let executed = Rc::new(Cell::new(false));
      let executed_clone = executed.clone();
      let scheduler = TestScheduler;
      let task = Task::new(executed_clone, |flag| {
        flag.set(true);
        TaskState::Finished
      });
      let handle = scheduler.schedule(task, Some(Duration::from_millis(50)));
      assert!(!executed.get());
      TestScheduler::advance_by(Duration::from_millis(30));
      assert!(!executed.get());
      TestScheduler::advance_by(Duration::from_millis(20));
      assert!(executed.get());
      assert!(handle.is_closed());
    }

    #[rxrust_macro::test(local)]
    async fn test_local_scheduler_cancellation() {
      let executed = Arc::new(Mutex::new(false));
      let executed_clone = executed.clone();
      let scheduler = LocalScheduler;
      let task = Task::new(executed_clone, |flag| {
        *flag.lock().unwrap() = true;
        TaskState::Finished
      });
      let handle = scheduler.schedule(task, Some(Duration::from_millis(100)));
      let handle_check = handle.clone();
      handle.unsubscribe();
      assert!(handle_check.is_closed());
      // Outlive the scheduled delay so the cancelled task would have run by now.
      scheduler
        .sleep(Duration::from_millis(150))
        .await;
      assert!(!*executed.lock().unwrap());
    }

    #[cfg(not(target_arch = "wasm32"))]
    #[rxrust_macro::test]
    async fn test_shared_scheduler_basic() {
      let executed = Arc::new(Mutex::new(false));
      let executed_clone = executed.clone();
      let scheduler = SharedScheduler;
      let task = Task::new(executed_clone, |flag| {
        *flag.lock().unwrap() = true;
        TaskState::Finished
      });
      let handle = scheduler.schedule(task, None);
      handle.await;
      assert!(*executed.lock().unwrap());
    }

    #[cfg(not(target_arch = "wasm32"))]
    #[rxrust_macro::test]
    async fn test_shared_scheduler_with_delay() {
      let executed = Arc::new(Mutex::new(false));
      let executed_clone = executed.clone();
      let start = Instant::now();
      let scheduler = SharedScheduler;
      let task = Task::new(executed_clone, |flag| {
        *flag.lock().unwrap() = true;
        TaskState::Finished
      });
      let handle = scheduler.schedule(task, Some(Duration::from_millis(50)));
      handle.await;
      let elapsed = start.elapsed();
      assert!(*executed.lock().unwrap());
      assert!(elapsed >= Duration::from_millis(50));
    }

    #[cfg(not(target_arch = "wasm32"))]
    #[rxrust_macro::test]
    async fn test_shared_scheduler_cancellation() {
      let executed = Arc::new(Mutex::new(false));
      let executed_clone = executed.clone();
      let scheduler = SharedScheduler;
      let task = Task::new(executed_clone, |flag| {
        *flag.lock().unwrap() = true;
        TaskState::Finished
      });
      let handle = scheduler.schedule(task, Some(Duration::from_millis(100)));
      let handle_check = handle.clone();
      handle.unsubscribe();
      assert!(handle_check.is_closed());
      // Outlive the scheduled delay so the cancelled task would have run by now.
      scheduler
        .sleep(Duration::from_millis(150))
        .await;
      assert!(!*executed.lock().unwrap());
    }

    #[cfg(not(target_arch = "wasm32"))]
    #[rxrust_macro::test]
    async fn test_shared_scheduler_concurrent_tasks() {
      let counter = Arc::new(Mutex::new(0));
      let scheduler = SharedScheduler;
      let mut handles = vec![];
      for _ in 0..10 {
        let counter_clone = counter.clone();
        let task = Task::new(counter_clone, |cnt| {
          *cnt.lock().unwrap() += 1;
          TaskState::Finished
        });
        let handle = scheduler.schedule(task, None);
        handles.push(handle);
      }
      for handle in handles {
        handle.await;
      }
      assert_eq!(*counter.lock().unwrap(), 10);
    }

    #[rxrust_macro::test]
    fn test_scheduler_zero_size_types() {
      use std::mem;
      assert_eq!(mem::size_of::<LocalScheduler>(), 0);
      assert_eq!(mem::size_of::<SharedScheduler>(), 0);
      // Both schedulers are `Copy`: the source stays usable after the assignment.
      let scheduler1 = LocalScheduler;
      let scheduler2 = scheduler1;
      let _ = (scheduler1, scheduler2);
      let scheduler3 = SharedScheduler;
      let scheduler4 = scheduler3;
      let _ = (scheduler3, scheduler4);
    }

    #[rxrust_macro::test(local)]
    async fn test_task_handle_as_future() {
      let executed = Arc::new(Mutex::new(false));
      let executed_clone = executed.clone();
      let scheduler = LocalScheduler;
      let task = Task::new(executed_clone, |flag| {
        *flag.lock().unwrap() = true;
        TaskState::Finished
      });
      let handle = scheduler.schedule(task, None);
      let observer = handle.clone();
      handle.await;
      assert!(observer.is_closed());
      assert!(*executed.lock().unwrap());
    }

    #[cfg(not(target_arch = "wasm32"))]
    #[rxrust_macro::test(local)]
    async fn test_panic_handling() {
      let scheduler = LocalScheduler;
      let task = Task::new((), |_| {
        panic!("Intentional panic for testing");
      });
      let handle = scheduler.schedule(task, None);
      let observer = handle.clone();
      // The spawned wrapper catches the panic and marks the task finished, so the
      // handle still resolves instead of hanging.
      handle.await;
      assert!(observer.is_closed());
    }

    #[cfg(not(target_arch = "wasm32"))]
    #[rxrust_macro::test]
    async fn test_shared_scheduler_send_requirement() {
      let data = Arc::new(Mutex::new(42));
      let data_clone = data.clone();
      let scheduler = SharedScheduler;
      let task = Task::new(data_clone, |d| {
        *d.lock().unwrap() = 100;
        TaskState::Finished
      });
      let handle = scheduler.schedule(task, None);
      handle.await;
      assert_eq!(*data.lock().unwrap(), 100);
    }

    #[rxrust_macro::test]
    fn test_unsubscribe_consumes_handle() {
      let handle = TaskHandle::finished();
      assert!(handle.is_closed());
      handle.unsubscribe();
    }
  }
}