use crate::error::CancelPollingTaskTimeout;
use std::{cell::UnsafeCell, future::Future, sync::Arc, time::Duration};
use tokio::{select, sync::Notify};
use tokio_util::sync::CancellationToken;
/// Handle to a polling task spawned by [`PollingTaskBuilder`].
///
/// Dropping the handle cancels the task. Use [`PollingTaskHandle::cancel`] to cancel it
/// explicitly and, when a clean-exit timeout was configured, to learn whether the polling
/// loop exited in time.
#[must_use = "Dropping this handle will cancel the background task"]
#[derive(Debug)]
pub struct PollingTaskHandle {
    /// Notified by the polling loop once it has left its loop.
    signal: Arc<Notify>,
    /// Cancelling this token makes the loop exit at its next cancellation check,
    /// which happens after the current run of the task finishes.
    cancellation_token: CancellationToken,
    /// How long cancellation waits for `signal`; `None` means it does not wait.
    timeout: Option<Duration>,
}
impl PollingTaskHandle {
    /// Cancels the polling task.
    ///
    /// If a clean-exit timeout was configured with
    /// [`PollingTaskBuilder::track_for_clean_exit_within`], this waits up to that long for
    /// the polling loop to signal that it has exited. An in-flight run of the task is
    /// allowed to finish first.
    ///
    /// # Errors
    ///
    /// Returns [`CancelPollingTaskTimeout`] if the loop does not signal its exit within the
    /// configured timeout.
    pub async fn cancel(mut self) -> Result<(), CancelPollingTaskTimeout> {
        // Take the timeout out of `self` so that the `Drop` run when this method returns
        // sees `None` and only re-cancels the (already cancelled) token. Otherwise `Drop`
        // would spawn a second waiter for an exit notification that is consumed here.
        // That waiter would always time out and panic.
        let timeout = self.timeout.take();
        Self::cancel_impl(self.cancellation_token.clone(), self.signal.clone(), timeout).await
    }

    /// Cancels `cancellation_token`, then, if `timeout` is set, waits up to `timeout`
    /// for the polling loop to notify `signal`.
    async fn cancel_impl(
        cancellation_token: CancellationToken,
        signal: Arc<Notify>,
        timeout: Option<Duration>,
    ) -> Result<(), CancelPollingTaskTimeout> {
        cancellation_token.cancel();
        if let Some(timeout) = timeout {
            // The loop uses `notify_one`, which stores a permit when nobody is waiting yet.
            // So an exit that happened before this wait started is still observed here.
            tokio::time::timeout(timeout, signal.notified())
                .await
                .map_err(|_| CancelPollingTaskTimeout)?;
        }
        Ok(())
    }
}
impl Drop for PollingTaskHandle {
    /// Cancels the background task when the handle goes out of scope.
    ///
    /// Without a clean-exit timeout, the token is simply cancelled. With one, a detached
    /// task is spawned that cancels the token and panics (inside that detached task) if the
    /// polling loop does not signal its exit in time.
    fn drop(&mut self) {
        match self.timeout {
            Some(timeout) => match tokio::runtime::Handle::try_current() {
                Ok(runtime) => {
                    let cancellation_token = self.cancellation_token.clone();
                    let signal = self.signal.clone();
                    runtime.spawn(async move {
                        Self::cancel_impl(cancellation_token, signal, Some(timeout))
                            .await
                            .expect("Polling task didn't signal exit within timeout");
                    });
                }
                // `tokio::task::spawn` would panic here, and a panic inside `drop` can abort
                // the process. Still cancel the task; only the exit tracking is skipped.
                Err(_) => self.cancellation_token.cancel(),
            },
            None => self.cancellation_token.cancel(),
        }
    }
}
/// Shared slot holding the most recent interval returned by a self-updating task.
///
/// Wraps an `UnsafeCell` so the value can be written through a shared `Arc`.
struct IntervalCell(UnsafeCell<Duration>);
// SAFETY: `Sync` is asserted manually because `UnsafeCell` is not `Sync`.
// NOTE(review): this is only sound because, as used by
// `PollingTaskBuilder::self_updating_task_with_checker`, the single reader (the
// interval fetcher) and the single writer (the task wrapper) are driven
// sequentially by one spawned polling loop. They never touch the cell at the
// same time. Confirm that invariant before reusing this type anywhere else.
unsafe impl Sync for IntervalCell {}
/// Cheap, cloneable view of a polling task's cancellation state, handed to task bodies.
#[derive(Clone, Debug)]
pub struct TaskChecker {
    /// Shared with the owning [`PollingTaskHandle`]; cancelled when the task is stopped.
    cancellation_token: CancellationToken,
}
impl TaskChecker {
    /// Wraps the cancellation token shared with the owning [`PollingTaskHandle`].
    fn new(cancellation_token: CancellationToken) -> Self {
        Self { cancellation_token }
    }

    /// Returns `true` while the polling task has not been cancelled.
    ///
    /// Long-running task bodies can check this to stop early once the owning handle has
    /// been cancelled or dropped.
    pub fn is_running(&self) -> bool {
        // Previously this returned `is_cancelled()` directly, which reported "running"
        // exactly when the task had been stopped.
        !self.cancellation_token.is_cancelled()
    }
}
/// Configures and spawns polling tasks, returning a [`PollingTaskHandle`] for each.
///
/// By default, cancellation does not wait for the polling loop to exit; call
/// `track_for_clean_exit_within` to enable that. `Default` is equivalent to `new()`.
#[derive(Debug, Default)]
pub struct PollingTaskBuilder {
    /// How long cancellation waits for the polling loop to signal its exit;
    /// `None` disables the wait.
    timeout: Option<Duration>,
}
impl PollingTaskBuilder {
    /// Creates a builder whose handles do not wait for the task to exit on cancellation.
    pub fn new() -> Self {
        Self { timeout: None }
    }

    /// Makes cancellation wait up to `timeout` for the polling loop to signal its exit.
    ///
    /// With this set, [`PollingTaskHandle::cancel`] returns an error if the loop has not
    /// exited within `timeout`.
    pub fn track_for_clean_exit_within(mut self, timeout: Duration) -> Self {
        self.timeout = Some(timeout);
        self
    }

    /// Spawns the polling loop and returns the handle that controls it.
    ///
    /// Each iteration awaits one run of `task`, then awaits `interval_fetcher` for the delay
    /// before the next run. Cancellation is observed between runs, so an in-flight run is
    /// always allowed to finish. When the loop exits it notifies the handle's exit signal.
    fn into_polling_task_handle<D, F, Dfut, Ffut>(
        self,
        interval_fetcher: D,
        task: F,
    ) -> PollingTaskHandle
    where
        D: Fn() -> Dfut + Send + 'static,
        Dfut: Future<Output = Duration> + Send,
        F: Fn(TaskChecker) -> Ffut + Send + 'static,
        Ffut: Future<Output = ()> + Send,
    {
        let signal = Arc::new(Notify::new());
        let cancellation_token = CancellationToken::new();

        let exit_signal = signal.clone();
        let loop_token = cancellation_token.clone();
        // Detached on purpose: the loop is controlled through the token and reports its
        // exit through `exit_signal`.
        tokio::task::spawn(async move {
            let checker = TaskChecker::new(loop_token.clone());
            loop {
                task(checker.clone()).await;
                let interval = interval_fetcher().await;
                select! {
                    // `biased` polls branches in order, so a cancellation that arrived while
                    // `task` was running always wins over the sleep (even a zero-length one)
                    // instead of being chosen at random.
                    biased;
                    _ = loop_token.cancelled() => {
                        break;
                    }
                    _ = tokio::time::sleep(interval) => {}
                }
            }
            exit_signal.notify_one();
        });

        PollingTaskHandle {
            signal,
            timeout: self.timeout,
            cancellation_token,
        }
    }

    /// Runs `task` repeatedly, waiting a fixed `interval` after each run.
    pub fn task<F, Fut>(self, interval: Duration, task: F) -> PollingTaskHandle
    where
        F: Fn() -> Fut + Send + 'static,
        Fut: Future<Output = ()> + Send,
    {
        self.task_with_checker(interval, move |_checker| task())
    }

    /// Like [`PollingTaskBuilder::task`], but passes each run a [`TaskChecker`] that
    /// reports cancellation.
    pub fn task_with_checker<F, Fut>(self, interval: Duration, task: F) -> PollingTaskHandle
    where
        F: Fn(TaskChecker) -> Fut + Send + 'static,
        Fut: Future<Output = ()> + Send,
    {
        let interval_fetcher = move || async move { interval };
        self.into_polling_task_handle(interval_fetcher, task)
    }

    /// Runs `task` repeatedly, waiting the `Duration` returned by each run before the next.
    pub fn self_updating_task<F, Fut>(self, task: F) -> PollingTaskHandle
    where
        F: Fn() -> Fut + Send + 'static,
        Fut: Future<Output = Duration> + Send,
    {
        self.self_updating_task_with_checker(move |_checker| task())
    }

    /// Like [`PollingTaskBuilder::self_updating_task`], but passes each run a
    /// [`TaskChecker`] that reports cancellation.
    pub fn self_updating_task_with_checker<F, Fut>(self, task: F) -> PollingTaskHandle
    where
        F: Fn(TaskChecker) -> Fut + Send + 'static,
        Fut: Future<Output = Duration> + Send,
    {
        // Slot carrying the delay returned by the latest run to the interval fetcher.
        let interval = Arc::new(IntervalCell(UnsafeCell::new(Duration::default())));
        let read_slot = interval.clone();
        let interval_fetcher = move || {
            // SAFETY: the fetcher and the write below are both driven by the single loop
            // spawned in `into_polling_task_handle`. That loop calls the fetcher only after
            // the preceding task future has completed, so this read never overlaps the write.
            let current = unsafe { *read_slot.0.get() };
            async move { current }
        };
        self.into_polling_task_handle(interval_fetcher, move |checker| {
            let write_slot = interval.clone();
            let task_future = task(checker);
            async move {
                // Await first so no pointer into the cell is held across the `.await`.
                let next = task_future.await;
                // SAFETY: same invariant as the read above. The fetcher is not called again
                // until this future has completed.
                unsafe { *write_slot.0.get() = next };
            }
        })
    }

    /// Runs `task` repeatedly, awaiting `interval_fetcher` after each run for the next delay.
    pub fn variable_task<D, F, Dfut, Ffut>(self, interval_fetcher: D, task: F) -> PollingTaskHandle
    where
        D: Fn() -> Dfut + Send + 'static,
        Dfut: Future<Output = Duration> + Send,
        F: Fn() -> Ffut + Send + 'static,
        Ffut: Future<Output = ()> + Send,
    {
        self.into_polling_task_handle(interval_fetcher, move |_| task())
    }

    /// Like [`PollingTaskBuilder::variable_task`], but passes each run a [`TaskChecker`]
    /// that reports cancellation.
    pub fn variable_task_with_checker<D, F, Dfut, Ffut>(
        self,
        interval_fetcher: D,
        task: F,
    ) -> PollingTaskHandle
    where
        D: Fn() -> Dfut + Send + 'static,
        Dfut: Future<Output = Duration> + Send,
        F: Fn(TaskChecker) -> Ffut + Send + 'static,
        Ffut: Future<Output = ()> + Send,
    {
        self.into_polling_task_handle(interval_fetcher, task)
    }
}
/// Spawns a detached loop that awaits `task()` and then sleeps for `interval`, repeatedly.
///
/// No handle is returned and the loop has no exit condition, so the caller cannot stop it.
pub fn fire_and_forget_polling_task<F, Fut>(interval: Duration, task: F)
where
    F: Fn() -> Fut + Send + 'static,
    Fut: Future<Output = ()> + Send,
{
    let run_forever = async move {
        loop {
            task().await;
            tokio::time::sleep(interval).await;
        }
    };
    tokio::task::spawn(run_forever);
}
/// Spawns a detached loop that awaits `task()` and then sleeps for the `Duration` it
/// returned, repeatedly.
///
/// No handle is returned and the loop has no exit condition, so the caller cannot stop it.
pub fn self_updating_fire_and_forget_polling_task<F, Fut>(task: F)
where
    F: Fn() -> Fut + Send + 'static,
    Fut: Future<Output = Duration> + Send,
{
    let run_forever = async move {
        loop {
            let next_delay = task().await;
            tokio::time::sleep(next_delay).await;
        }
    };
    tokio::task::spawn(run_forever);
}
/// Spawns a detached loop that calls `task()` and then sleeps for `interval_fetcher()`,
/// repeatedly.
///
/// Both closures are synchronous and are called directly inside the spawned async task,
/// so they should not block. No handle is returned and the loop has no exit condition.
pub fn variable_fire_and_forget_polling_task<D, F>(interval_fetcher: D, task: F)
where
    D: Fn() -> Duration + Send + 'static,
    F: Fn() + Send + 'static,
{
    let run_forever = async move {
        loop {
            task();
            let next_delay = interval_fetcher();
            tokio::time::sleep(next_delay).await;
        }
    };
    tokio::task::spawn(run_forever);
}
// Test submodules, compiled only under `cfg(test)`; each is declared out-of-line,
// so its body lives in a separate file.
#[cfg(test)]
mod tests {
    mod fire_and_forget;
    mod self_updating_task;
    mod task;
    mod variable_task;
}