use crate::SleepProvider;
use futures::channel::mpsc;
use futures::channel::mpsc::{UnboundedReceiver, UnboundedSender};
use futures::{Stream, StreamExt};
use std::future::Future;
use std::pin::Pin;
use std::task::{Context, Poll};
use web_time_compat::{Duration, Instant, SystemTime};
use pin_project::pin_project;
/// An error returned while asking a [`TaskSchedule`] to sleep.
#[derive(Clone, Debug, thiserror::Error)]
#[non_exhaustive]
pub enum SleepError {
/// The schedule's stream ended before the sleep completed.
///
/// [`TaskSchedule::sleep`] produces this when the schedule yields
/// end-of-stream, which happens once its command channel is closed.
#[error("All task handles dropped: task exiting.")]
ScheduleDropped,
}
/// A command sent from a [`TaskHandle`] to its [`TaskSchedule`].
#[derive(Copy, Clone)]
enum SchedulerCommand {
/// Make the schedule ready right away, discarding any pending timer.
Fire,
/// Make the schedule ready at the given instant, replacing any pending
/// timer and clearing any pending immediate firing.
FireAt(Instant),
/// Clear both the pending immediate firing and the pending timer.
Cancel,
/// Stop the schedule from yielding items until a `Resume` arrives.
Suspend,
/// Undo a previous `Suspend`.
Resume,
}
/// A stream of `()` items that tells a recurring task when to run next.
///
/// The stream yields an item when an immediate firing has been requested or
/// when the currently armed timer elapses. Its behavior can be changed
/// remotely through the paired [`TaskHandle`].
#[pin_project(project = TaskScheduleP)]
pub struct TaskSchedule<R: SleepProvider> {
/// The currently armed timer, if any; cleared once it has elapsed.
sleep: Option<Pin<Box<R::SleepFuture>>>,
/// Receiving end of the command channel fed by [`TaskHandle`]s.
rx: UnboundedReceiver<SchedulerCommand>,
/// Runtime used to create timers and read the current time.
rt: R,
/// If true, the next (non-suspended) poll yields an item immediately.
instant_fire: bool,
/// If true, polling never yields an item (commands are still processed).
suspended: bool,
}
/// A cloneable handle used to control a [`TaskSchedule`] from elsewhere.
///
/// Every method enqueues a command on an unbounded channel; the schedule
/// applies queued commands the next time it is polled.
#[derive(Clone)]
pub struct TaskHandle {
/// Sending end of the command channel read by the [`TaskSchedule`].
tx: UnboundedSender<SchedulerCommand>,
}
impl<R: SleepProvider> TaskSchedule<R> {
    /// Create a new schedule together with the handle that controls it.
    ///
    /// The returned schedule starts out with an immediate firing pending, so
    /// its first poll yields an item (unless a queued command says otherwise).
    pub fn new(rt: R) -> (Self, TaskHandle) {
        let (tx, rx) = mpsc::unbounded();
        let handle = TaskHandle { tx };
        let schedule = Self {
            sleep: None,
            rx,
            rt,
            instant_fire: true,
            suspended: false,
        };
        (schedule, handle)
    }

    /// Arrange for the schedule to fire once `dur` has elapsed.
    ///
    /// Any pending immediate firing is cleared and any previously armed timer
    /// is replaced.
    pub fn fire_in(&mut self, dur: Duration) {
        self.instant_fire = false;
        let timer = self.rt.sleep(dur);
        self.sleep = Some(Box::pin(timer));
    }

    /// Arrange for the schedule to fire on its next poll.
    ///
    /// Any previously armed timer is discarded.
    pub fn fire(&mut self) {
        self.instant_fire = true;
        drop(self.sleep.take());
    }

    /// Wait until `dur` has elapsed, honoring commands from the handle.
    ///
    /// Equivalent to calling [`fire_in`](Self::fire_in) and then awaiting the
    /// next item of the stream.
    ///
    /// # Errors
    ///
    /// Returns [`SleepError::ScheduleDropped`] if the stream ends instead of
    /// yielding an item.
    pub async fn sleep(&mut self, dur: Duration) -> Result<(), SleepError> {
        self.fire_in(dur);
        match self.next().await {
            Some(()) => Ok(()),
            None => Err(SleepError::ScheduleDropped),
        }
    }

    /// Wait until the wall clock reaches `when`, honoring commands from the
    /// handle.
    ///
    /// The wait is split into steps computed by
    /// `crate::timer::calc_next_delay`; after each step the wall clock is
    /// consulted again until that helper reports the final step.
    ///
    /// # Errors
    ///
    /// Propagates any error from [`sleep`](Self::sleep).
    pub async fn sleep_until_wallclock(&mut self, when: SystemTime) -> Result<(), SleepError> {
        loop {
            let now = self.rt.wallclock();
            let (finished, delay) = crate::timer::calc_next_delay(now, when);
            self.sleep(delay).await?;
            if finished {
                break Ok(());
            }
        }
    }
}
impl TaskHandle {
pub fn fire(&self) -> bool {
self.tx.unbounded_send(SchedulerCommand::Fire).is_ok()
}
pub fn fire_at(&self, instant: Instant) -> bool {
self.tx
.unbounded_send(SchedulerCommand::FireAt(instant))
.is_ok()
}
pub fn cancel(&self) -> bool {
self.tx.unbounded_send(SchedulerCommand::Cancel).is_ok()
}
pub fn suspend(&self) -> bool {
self.tx.unbounded_send(SchedulerCommand::Suspend).is_ok()
}
pub fn resume(&self) -> bool {
self.tx.unbounded_send(SchedulerCommand::Resume).is_ok()
}
}
impl<R: SleepProvider> TaskScheduleP<'_, R> {
    /// Apply one command received from a [`TaskHandle`] to the schedule state.
    ///
    /// `Fire`, `FireAt` and `Cancel` each overwrite both the immediate-fire
    /// flag and the armed timer. `Suspend` and `Resume` only toggle the
    /// suspension flag and leave the trigger state untouched.
    fn handle_command(&mut self, cmd: SchedulerCommand) {
        let (fire_now, timer) = match cmd {
            SchedulerCommand::Suspend => {
                *self.suspended = true;
                return;
            }
            SchedulerCommand::Resume => {
                *self.suspended = false;
                return;
            }
            SchedulerCommand::Fire => (true, None),
            SchedulerCommand::Cancel => (false, None),
            SchedulerCommand::FireAt(deadline) => {
                // A deadline already in the past saturates to a zero delay.
                let now = self.rt.now();
                let remaining = deadline.saturating_duration_since(now);
                (false, Some(Box::pin(self.rt.sleep(remaining))))
            }
        };
        *self.instant_fire = fire_now;
        *self.sleep = timer;
    }
}
impl<R: SleepProvider> Stream for TaskSchedule<R> {
    type Item = ();

    /// Poll the schedule.
    ///
    /// Steps, in order:
    /// 1. Drain and apply every queued command; if the command channel has
    ///    ended, end this stream too.
    /// 2. If suspended, report `Pending` without touching the timer.
    /// 3. If an immediate firing is pending, consume it and yield an item.
    /// 4. If a timer is armed and has elapsed, clear it and yield an item.
    /// 5. Otherwise report `Pending`.
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let mut this = self.project();

        // Draining until `Pending` also leaves our waker registered on the
        // command channel.
        loop {
            match this.rx.poll_next_unpin(cx) {
                Poll::Ready(Some(cmd)) => this.handle_command(cmd),
                // The command channel has ended: finish the stream.
                Poll::Ready(None) => return Poll::Ready(None),
                Poll::Pending => break,
            }
        }

        if *this.suspended {
            return Poll::Pending;
        }

        // Read the flag and reset it to `false` in one step.
        if std::mem::take(&mut *this.instant_fire) {
            return Poll::Ready(Some(()));
        }

        if let Some(timer) = this.sleep.as_mut() {
            if timer.as_mut().poll(cx).is_ready() {
                *this.sleep = None;
                return Poll::Ready(Some(()));
            }
        }

        Poll::Pending
    }
}
// Tests for `TaskSchedule` / `TaskHandle`.
//
// Throughout, `sch.next().now_or_never()` polls the schedule exactly once:
// `Some(_)` means the stream was ready, `None` means it was still pending.
#[cfg(all(
test,
any(feature = "native-tls", feature = "rustls"),
any(feature = "tokio", feature = "async-std", feature = "smol"),
not(miri), // Several of these use real SystemTime
))]
mod test {
use crate::scheduler::TaskSchedule;
use crate::{SleepProvider, test_with_all_runtimes};
use futures::FutureExt;
use futures::StreamExt;
use web_time_compat::{Duration, Instant, InstantExt};
/// A freshly created schedule is ready on its first poll.
#[test]
fn it_fires_immediately() {
test_with_all_runtimes!(|rt| async move {
let (mut sch, _hdl) = TaskSchedule::new(rt);
assert!(sch.next().now_or_never().is_some());
});
}
/// Once the only handle is dropped, the stream is ready and yields `None`.
#[test]
#[allow(clippy::unwrap_used)]
fn it_dies_if_dropped() {
test_with_all_runtimes!(|rt| async move {
let (mut sch, hdl) = TaskSchedule::new(rt);
drop(hdl);
assert!(sch.next().now_or_never().unwrap().is_none());
});
}
/// `TaskHandle::fire` makes a pending schedule ready exactly once.
#[test]
fn it_fires_on_demand() {
test_with_all_runtimes!(|rt| async move {
let (mut sch, hdl) = TaskSchedule::new(rt);
// Consume the initial firing; afterwards the schedule is pending.
assert!(sch.next().now_or_never().is_some());
assert!(sch.next().now_or_never().is_none());
assert!(hdl.fire());
assert!(sch.next().now_or_never().is_some());
assert!(sch.next().now_or_never().is_none());
});
}
/// A `cancel` sent after a `fire` (before the schedule is polled) wins.
#[test]
fn it_cancels_instant_firings() {
test_with_all_runtimes!(|rt| async move {
let (mut sch, hdl) = TaskSchedule::new(rt);
assert!(sch.next().now_or_never().is_some());
assert!(sch.next().now_or_never().is_none());
assert!(hdl.fire());
assert!(hdl.cancel());
assert!(sch.next().now_or_never().is_none());
});
}
/// `fire_in` on the schedule itself: pending at first, ready after waiting.
#[test]
fn it_fires_after_self_reschedule() {
test_with_all_runtimes!(|rt| async move {
let (mut sch, _hdl) = TaskSchedule::new(rt);
assert!(sch.next().now_or_never().is_some());
sch.fire_in(Duration::from_millis(100));
assert!(sch.next().now_or_never().is_none());
// Awaiting lets the timer elapse; afterwards nothing more is pending.
assert!(sch.next().await.is_some());
assert!(sch.next().now_or_never().is_none());
});
}
/// `fire_at` via the handle: pending at first, ready after waiting.
#[test]
fn it_fires_after_external_reschedule() {
test_with_all_runtimes!(|rt| async move {
let (mut sch, hdl) = TaskSchedule::new(rt);
assert!(sch.next().now_or_never().is_some());
hdl.fire_at(Instant::get() + Duration::from_millis(100));
assert!(sch.next().now_or_never().is_none());
assert!(sch.next().await.is_some());
assert!(sch.next().now_or_never().is_none());
});
}
/// A delayed firing that is cancelled before its deadline never fires.
// NOTE(review): this test is ignored by default and no reason is recorded
// here; presumably it is timing-sensitive -- confirm before re-enabling.
#[test]
#[ignore]
fn it_cancels_delayed_firings() {
test_with_all_runtimes!(|rt| async move {
let (mut sch, hdl) = TaskSchedule::new(rt.clone());
assert!(sch.next().now_or_never().is_some());
hdl.fire_at(Instant::get() + Duration::from_millis(100));
assert!(sch.next().now_or_never().is_none());
// Half-way to the deadline: still pending.
rt.sleep(Duration::from_millis(50)).await;
assert!(sch.next().now_or_never().is_none());
hdl.cancel();
assert!(sch.next().now_or_never().is_none());
// Well past the original deadline: still pending, since it was cancelled.
rt.sleep(Duration::from_millis(100)).await;
assert!(sch.next().now_or_never().is_none());
});
}
/// When `fire_at` is followed by `fire`, the later command replaces the
/// earlier one: the schedule fires once right away and not again later.
#[test]
fn last_fire_wins() {
test_with_all_runtimes!(|rt| async move {
let (mut sch, hdl) = TaskSchedule::new(rt.clone());
assert!(sch.next().now_or_never().is_some());
hdl.fire_at(Instant::get() + Duration::from_millis(100));
hdl.fire();
assert!(sch.next().now_or_never().is_some());
assert!(sch.next().now_or_never().is_none());
// Past the `fire_at` deadline: the replaced timer must not fire.
rt.sleep(Duration::from_millis(150)).await;
assert!(sch.next().now_or_never().is_none());
});
}
/// A pending immediate firing is held back while suspended and delivered
/// after `resume`.
#[test]
fn suspend_and_resume_with_fire() {
test_with_all_runtimes!(|rt| async move {
let (mut sch, hdl) = TaskSchedule::new(rt.clone());
hdl.fire();
hdl.suspend();
assert!(sch.next().now_or_never().is_none());
hdl.resume();
assert!(sch.next().now_or_never().is_some());
});
}
/// A pending timer survives a suspend/resume cycle and still fires later.
#[test]
fn suspend_and_resume_with_sleep() {
test_with_all_runtimes!(|rt| async move {
let (mut sch, hdl) = TaskSchedule::new(rt.clone());
sch.fire_in(Duration::from_millis(100));
hdl.suspend();
assert!(sch.next().now_or_never().is_none());
hdl.resume();
// Resumed, but the timer has not elapsed yet.
assert!(sch.next().now_or_never().is_none());
assert!(sch.next().await.is_some());
});
}
/// Suspending a schedule with nothing pending leaves it pending.
#[test]
fn suspend_and_resume_with_nothing() {
test_with_all_runtimes!(|rt| async move {
let (mut sch, hdl) = TaskSchedule::new(rt.clone());
assert!(sch.next().now_or_never().is_some());
hdl.suspend();
assert!(sch.next().now_or_never().is_none());
hdl.resume();
});
}
}