use crate::error::CancelPollingTaskTimeout;
use std::{
cell::UnsafeCell,
sync::{mpsc::Receiver, Arc, Condvar, Mutex},
thread,
time::Duration,
};
/// State shared between a `PollingTaskHandle` and its polling thread.
struct SharedState {
    /// `true` while the task should keep running; the handle sets it to `false` to
    /// request cancellation.
    active: Mutex<bool>,
    /// Notified by the handle right after it clears `active`, so the polling thread's
    /// timed wait on this condvar ends early instead of sleeping out the full interval.
    signal: Arc<Condvar>,
}
/// Owning handle to a background polling thread.
///
/// Cancelling the handle clears the shared `active` flag, wakes the polling thread and
/// then waits for the thread to report that it has left its loop. Cancellation happens
/// either explicitly via [`PollingTaskHandle::cancel`] or implicitly on drop.
#[must_use = "Dropping this handle will cancel the background task"]
pub struct PollingTaskHandle {
    /// Receives the `()` the polling thread sends once it has left its loop.
    receiver: Receiver<()>,
    /// Cancellation flag and wake-up signal shared with the polling thread.
    shared_state: Arc<SharedState>,
    /// How long cancellation waits for the thread to exit; `None` waits without limit.
    timeout: Option<Duration>,
}

impl PollingTaskHandle {
    /// Stops the background task and waits for the polling thread to exit.
    ///
    /// A task invocation that is already in progress is allowed to finish; only the
    /// sleep between invocations is interrupted.
    ///
    /// # Errors
    ///
    /// Returns [`CancelPollingTaskTimeout`] if a timeout was configured and the polling
    /// thread did not report its exit within that time.
    pub fn cancel(mut self) -> Result<(), CancelPollingTaskTimeout> {
        self.cancel_impl()
    }

    /// Requests cancellation (at most once) and waits for the polling thread to exit.
    ///
    /// A disconnected channel means the thread already terminated without reporting
    /// (for example because the task panicked). That counts as "exited" whether or not
    /// a timeout is configured.
    fn cancel_impl(&mut self) -> Result<(), CancelPollingTaskTimeout> {
        // Test and clear the flag under a single lock acquisition. A poisoned lock
        // still guards a valid `bool`, so recover it rather than panicking; this path
        // also runs from `Drop`.
        let was_active = std::mem::replace(
            &mut *self
                .shared_state
                .active
                .lock()
                .unwrap_or_else(std::sync::PoisonError::into_inner),
            false,
        );
        if !was_active {
            // Already cancelled, e.g. `cancel()` followed by the implicit drop.
            return Ok(());
        }
        // The guard was released at the end of the statement above, so the woken
        // thread can immediately re-acquire the lock and observe `false`.
        self.shared_state.signal.notify_one();
        match self.timeout {
            None => {
                // An `Err` here only means the thread is already gone.
                let _ = self.receiver.recv();
            }
            Some(timeout) => {
                if let Err(std::sync::mpsc::RecvTimeoutError::Timeout) =
                    self.receiver.recv_timeout(timeout)
                {
                    return Err(CancelPollingTaskTimeout);
                }
            }
        }
        Ok(())
    }
}

impl Drop for PollingTaskHandle {
    /// Cancels the task if that has not already happened.
    ///
    /// # Panics
    ///
    /// Panics if a timeout was configured and the polling thread did not report its
    /// exit in time. If the current thread is already panicking, the error is discarded
    /// instead, because a second panic during unwinding aborts the process.
    fn drop(&mut self) {
        let result = self.cancel_impl();
        if !std::thread::panicking() {
            result.expect("Polling thread didn't signal exit within timeout");
        }
    }
}
/// `Duration` slot written by the task closure and read back by the interval fetcher
/// built in `PollingTaskBuilder::self_updating_task_with_checker`.
///
/// `UnsafeCell` provides the interior mutability. The type is marked `Sync` only so
/// that an `Arc<IntervalCell>` is `Send` and can be moved into the polling thread.
struct IntervalCell(UnsafeCell<Duration>);
// SAFETY: this is NOT a general-purpose `Sync` type. Soundness relies on both closures
// that hold an `Arc<IntervalCell>` being moved into the same polling thread and invoked
// there one after the other, so the cell is never accessed from two threads at once.
// Any new use of this type must keep that single-thread, sequential access pattern.
unsafe impl Sync for IntervalCell {}
/// Cancellation probe handed to the `*_with_checker` tasks.
///
/// It borrows the polling thread's shared state, so a long-running task can call
/// [`TaskChecker::is_running`] and bail out early once the owning handle has
/// requested cancellation.
#[derive(Clone)]
pub struct TaskChecker<'a> {
    shared_state: &'a Arc<SharedState>,
}

impl<'a> TaskChecker<'a> {
    /// Wraps a borrow of the polling thread's shared state.
    fn new(state: &'a Arc<SharedState>) -> Self {
        TaskChecker {
            shared_state: state,
        }
    }

    /// Returns `true` until cancellation has been requested, `false` afterwards.
    ///
    /// # Panics
    ///
    /// Panics if the `active` mutex is poisoned.
    pub fn is_running(&self) -> bool {
        let flag = self.shared_state.active.lock().unwrap();
        *flag
    }
}
#[derive(Default)]
pub struct PollingTaskBuilder {
wait_for_clean_exit: bool,
timeout: Option<Duration>,
}
impl PollingTaskBuilder {
pub fn new() -> Self {
Self {
wait_for_clean_exit: false,
timeout: None,
}
}
pub fn wait_for_clean_exit(mut self, timeout: Option<Duration>) -> Self {
self.timeout = timeout;
self.wait_for_clean_exit = true;
self
}
fn into_polling_task_handle<D, F>(self, interval_fetcher: D, task: F) -> PollingTaskHandle
where
D: Fn() -> Duration + Send + 'static,
F: Fn(&TaskChecker) + Send + 'static,
{
let (sender, receiver) = std::sync::mpsc::channel();
let signal = Arc::new(Condvar::new());
let shared_state = Arc::new(SharedState {
active: Mutex::new(true),
signal: signal.clone(),
});
let shared_state_clone = shared_state.clone();
let _thread_handle = thread::spawn(move || {
let checker = TaskChecker::new(&shared_state_clone);
loop {
task(&checker);
let interval = interval_fetcher();
let result = shared_state_clone
.signal
.wait_timeout_while(
shared_state_clone.active.lock().unwrap(),
interval,
|&mut active| active,
)
.unwrap();
if !result.1.timed_out() {
break;
}
}
let _ = sender.send(());
});
PollingTaskHandle {
receiver,
shared_state,
timeout: self.timeout,
}
}
pub fn task<F>(self, interval: Duration, task: F) -> PollingTaskHandle
where
F: Fn() + Send + 'static,
{
self.task_with_checker(interval, move |_checker| task())
}
pub fn task_with_checker<F>(self, interval: Duration, task: F) -> PollingTaskHandle
where
F: Fn(&TaskChecker) + Send + 'static,
{
let interval_fetcher = move || interval;
self.into_polling_task_handle(interval_fetcher, move |checker| task(checker))
}
pub fn self_updating_task<F>(self, task: F) -> PollingTaskHandle
where
F: Fn() -> Duration + Send + 'static,
{
self.self_updating_task_with_checker(move |_checker| task())
}
pub fn self_updating_task_with_checker<F>(self, task: F) -> PollingTaskHandle
where
F: Fn(&TaskChecker) -> Duration + Send + 'static,
{
let interval = Arc::new(IntervalCell(UnsafeCell::new(Default::default())));
let interval_clone = interval.clone();
let interval_fetcher = move || unsafe { (*interval_clone.0.get()).to_owned() };
self.into_polling_task_handle(interval_fetcher, move |checker| {
unsafe {
let interval_ref = &mut *interval.0.get();
*interval_ref = task(checker);
}
})
}
pub fn variable_task<D, F>(self, interval_fetcher: D, task: F) -> PollingTaskHandle
where
D: Fn() -> Duration + Send + 'static,
F: Fn() + Send + 'static,
{
self.into_polling_task_handle(interval_fetcher, move |_| task())
}
pub fn variable_task_with_checker<D, F>(self, interval_fetcher: D, task: F) -> PollingTaskHandle
where
D: Fn() -> Duration + Send + 'static,
F: Fn(&TaskChecker) + Send + 'static,
{
self.into_polling_task_handle(interval_fetcher, move |checker| task(checker))
}
}
/// Spawns a detached thread that calls `task` and then sleeps for `interval`, forever.
///
/// No handle is returned, so the thread cannot be stopped: it runs until the process
/// exits or `task` panics. The first call happens immediately.
pub fn fire_and_forget_polling_task<F>(interval: Duration, task: F)
where
    F: Fn() + Send + 'static,
{
    let worker = move || loop {
        task();
        thread::sleep(interval);
    };
    thread::spawn(worker);
}
/// Spawns a detached thread that repeatedly calls `task` and then sleeps for the
/// `Duration` that call returned.
///
/// The thread cannot be stopped; it runs until the process exits or `task` panics.
pub fn self_updating_fire_and_forget_polling_task<F>(task: F)
where
    F: Fn() -> Duration + Send + 'static,
{
    let worker = move || loop {
        let pause = task();
        thread::sleep(pause);
    };
    thread::spawn(worker);
}
/// Spawns a detached thread that repeatedly calls `task` and then sleeps for whatever
/// `interval_fetcher` returns at that moment.
///
/// The thread cannot be stopped; it runs until the process exits or one of the
/// closures panics.
pub fn variable_fire_and_forget_polling_task<D, F>(interval_fetcher: D, task: F)
where
    D: Fn() -> Duration + Send + 'static,
    F: Fn() + Send + 'static,
{
    let worker = move || loop {
        task();
        let pause = interval_fetcher();
        thread::sleep(pause);
    };
    thread::spawn(worker);
}
#[cfg(test)]
mod tests {
    // Each test submodule is loaded from its own file (`mod name;`). The names
    // correspond to the task constructors defined in this file.
    mod fire_and_forget;
    mod self_updating_task;
    mod task;
    mod variable_task;
}