use core::fmt::Display;
use core::mem;
use core::time::Duration;
use crossbeam_channel::{bounded, Sender};
use std::sync::{Arc, RwLock};
use std::thread;
use tracing::{debug, error, warn};
use crate::util::lock::LockExt;
/// Owner-side handle for a task started by `spawn_background_task`.
///
/// Dropping the handle first sends a shutdown request (see its `Drop` impl) and then,
/// when the `join_handle` field is dropped, waits for the task thread to exit.
// NOTE: field order matters. Rust drops fields in declaration order, so `join_handle`
// (which joins the thread on drop) must stay the last field.
pub struct TaskHandle {
    /// Sending `()` here asks the task loop to exit the next time it checks the channel.
    shutdown_sender: Sender<()>,
    /// Flag the task thread sets to `true` once its loop has finished; read by `is_stopped`.
    stopped: Arc<RwLock<bool>>,
    /// Owns the task thread's `JoinHandle`; joins the thread when dropped.
    join_handle: DropJoinHandle,
}
/// Wrapper around the task thread's `JoinHandle` that joins the thread when dropped.
/// Holds `None` once the handle has been taken out for an explicit `TaskHandle::join`.
struct DropJoinHandle(Option<thread::JoinHandle<()>>);
/// Error returned by one step of a background task.
///
/// The variant decides whether the task loop in `spawn_background_task` survives it.
/// The payload is boxed so that the step's `Result` stays small regardless of `E`.
#[derive(Debug)]
pub enum TaskError<E> {
    /// Logged at `warn` level; the task keeps running.
    Ignore(Box<E>),
    /// Logged at `error` level; the task stops.
    Fatal(Box<E>),
}
/// What the task loop should do after a step that completed without error.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Next {
    /// Run another step (after the optional pause between steps).
    Continue,
    /// Leave the task loop; the task thread then exits.
    Abort,
}
pub fn spawn_background_task<E: Display>(
span: tracing::Span,
interval_pause: Option<Duration>,
mut step_runner: impl FnMut() -> Result<Next, TaskError<E>> + Send + Sync + 'static,
) -> TaskHandle {
debug!(parent: &span, "spawning task");
let stopped = Arc::new(RwLock::new(false));
let write_stopped = stopped.clone();
let (shutdown_sender, receiver) = bounded(1);
let join_handle = thread::spawn(move || {
let _entered = span.enter();
loop {
match receiver.try_recv() {
Ok(()) => {
break;
}
_ => match step_runner() {
Ok(Next::Continue) => {}
Ok(Next::Abort) => {
debug!("aborting task");
break;
}
Err(TaskError::Ignore(e)) => {
warn!("task encountered ignorable error: {}", e);
}
Err(TaskError::Fatal(e)) => {
error!("task aborting after encountering fatal error: {}", e);
break;
}
},
}
if let Some(interval) = interval_pause {
thread::sleep(interval);
}
}
*write_stopped.acquire_write() = true;
debug!("task terminated");
});
TaskHandle {
shutdown_sender,
stopped,
join_handle: DropJoinHandle(Some(join_handle)),
}
}
impl TaskHandle {
    /// Blocks until the task thread exits, without requesting shutdown first.
    ///
    /// The thread must end on its own (`Next::Abort`, a fatal error, or a shutdown
    /// requested earlier); otherwise this blocks indefinitely. A panic from the task
    /// thread is swallowed.
    pub fn join(mut self) {
        if let Some(handle) = mem::take(&mut self.join_handle.0) {
            let _ = handle.join();
        }
    }

    /// Asks the task to stop, without waiting for it.
    ///
    /// Never blocks. If a request is already queued (`Full`), another one adds nothing.
    /// If the task thread has already exited (`Disconnected`), there is nobody to tell.
    /// Both outcomes are therefore ignored.
    pub fn shutdown(&self) {
        let _ = self.shutdown_sender.try_send(());
    }

    /// Asks the task to stop and blocks until its thread has exited.
    pub fn shutdown_and_wait(self) {
        self.shutdown();
        self.join();
    }

    /// Returns the shared `stopped` flag, which the task thread raises when it finishes.
    pub fn is_stopped(&self) -> bool {
        *self.stopped.acquire_read()
    }
}
impl Drop for DropJoinHandle {
fn drop(&mut self) {
if let Some(handle) = mem::take(&mut self.0) {
let _ = handle.join();
}
}
}
impl Drop for TaskHandle {
    /// Requests shutdown so the task thread leaves its loop. The `join_handle` field,
    /// dropped after this method returns, then waits for the thread to exit.
    fn drop(&mut self) {
        // `try_send` never blocks. `Full` means a request is already queued, and
        // `Disconnected` means the thread has already exited, so both are harmless.
        let _ = self.shutdown_sender.try_send(());
    }
}