use std::future::Future;
use tokio_util::{sync::CancellationToken, task::TaskTracker};
use super::AsyncRuntime;
/// Wraps an async runtime `AR` together with the shutdown machinery used for
/// background tasks spawned onto it: a root [`CancellationToken`] and a
/// [`TaskTracker`].
///
/// Calling `stop` (which `Drop` also does) closes the tracker and cancels the
/// root token; `wait` returns a future that resolves once the tracker is closed
/// and every tracked task has finished.
pub struct BackgroundRuntime<AR> {
/// Runtime that background futures are spawned onto via `AsyncRuntime::spawn`.
pub async_runtime: AR,
/// Root token. It is never handed out directly — callers receive child tokens
/// from `cancellation_token()`, which are cancelled when this one is.
cancellation_token: CancellationToken,
/// Tracks futures spawned with `spawn_tracked` so `wait` can await their completion.
watched_tasks: TaskTracker,
}
impl<AR: AsyncRuntime> BackgroundRuntime<AR> {
    /// Creates a background runtime that spawns its tasks onto `runtime`.
    ///
    /// The new value owns a fresh root [`CancellationToken`] and a fresh, open
    /// [`TaskTracker`]; neither is shared with any other `BackgroundRuntime`.
    #[must_use]
    pub fn new(runtime: AR) -> BackgroundRuntime<AR> {
        Self {
            async_runtime: runtime,
            cancellation_token: CancellationToken::new(),
            watched_tasks: TaskTracker::new(),
        }
    }

    /// Spawns `future` on the underlying runtime and registers it with the task
    /// tracker, so the future returned by `wait` does not resolve until it completes.
    ///
    /// The future is *not* cancelled automatically when the runtime is stopped:
    /// it is expected to observe a token obtained from `cancellation_token()` and
    /// finish on its own, otherwise `wait` will not resolve.
    pub(crate) fn spawn_tracked<F>(&self, future: F)
    where
        F: Future + Send + 'static,
        F::Output: Send + 'static,
    {
        self.async_runtime
            .spawn(self.watched_tasks.track_future(future));
    }

    /// Spawns `future` on the underlying runtime without tracking it.
    ///
    /// The future is raced against a child of the runtime's cancellation token:
    /// once the runtime is stopped, it is no longer polled and is dropped, and its
    /// output (if any) is discarded. `wait` never waits for untracked tasks.
    pub(crate) fn spawn_untracked<F>(&self, future: F)
    where
        F: Future + Send + 'static,
        F::Output: Send + 'static,
    {
        // `cancellation_token()` already returns an owned child token, so it can
        // be moved into the task directly — no extra clone is needed.
        let cancellation_token = self.cancellation_token();
        self.async_runtime.spawn(async move {
            cancellation_token.run_until_cancelled(future).await;
        });
    }
}
impl<AR> BackgroundRuntime<AR> {
    /// Returns a new child of this runtime's root cancellation token.
    ///
    /// The child is cancelled when `stop` is called (including via `Drop`).
    /// Cancelling the child itself does not cancel the root or any sibling token.
    #[must_use]
    pub(crate) fn cancellation_token(&self) -> CancellationToken {
        self.cancellation_token.child_token()
    }

    /// Signals all background work to shut down.
    ///
    /// Closes the task tracker, so `wait` can resolve once every tracked task has
    /// finished. Then cancels the root token, which cancels every child token and
    /// therefore every task spawned with `spawn_untracked`.
    ///
    /// Both operations are idempotent, so calling this more than once is harmless.
    /// Each call emits a debug log line.
    pub(crate) fn stop(&self) {
        log::debug!(target: "eppo", "stopping background runtime");
        self.watched_tasks.close();
        self.cancellation_token.cancel();
    }

    /// Returns a future that resolves once the task tracker has been closed (by
    /// `stop`) and every tracked task has finished.
    ///
    /// The future owns its own clone of the tracker rather than a reference to
    /// `self`, so it can still be awaited after this `BackgroundRuntime` is dropped.
    pub(super) fn wait(&self) -> impl Future<Output = ()> {
        let tracker = self.watched_tasks.clone();
        async move { tracker.wait().await }
    }
}
impl<AR> Drop for BackgroundRuntime<AR> {
    /// Stops the runtime so background work is told to shut down even when the
    /// owner never calls `stop` explicitly. Stopping an already-stopped runtime
    /// only repeats the debug log.
    fn drop(&mut self) {
        Self::stop(self);
    }
}
#[cfg(test)]
mod tests {
    use std::{
        sync::{
            atomic::{AtomicBool, Ordering},
            Arc,
        },
        time::Duration,
    };

    use super::*;

    /// Builds the single-threaded Tokio runtime (timers and I/O enabled) that
    /// every test in this module drives its `BackgroundRuntime` with.
    fn current_thread_runtime() -> tokio::runtime::Runtime {
        tokio::runtime::Builder::new_current_thread()
            .enable_all()
            .build()
            .unwrap()
    }

    /// Stopping a runtime that has no tasks lets `wait` resolve.
    #[test]
    fn test_start_stop() {
        let rt = current_thread_runtime();
        let background = BackgroundRuntime::new(rt.handle().clone());

        background.stop();
        rt.block_on(background.wait());
    }

    /// An untracked task that never checks for cancellation must not keep
    /// `wait` from resolving.
    #[test]
    fn test_stops_with_uncooperative_task() {
        let rt = current_thread_runtime();
        let background = BackgroundRuntime::new(rt.handle().clone());

        background.spawn_untracked(async {
            loop {
                tokio::time::sleep(Duration::from_secs(1)).await;
            }
        });

        background.stop();
        rt.block_on(background.wait());
    }

    /// `wait` resolves only after a tracked task has observed cancellation and
    /// run its shutdown path.
    #[test]
    fn test_waits_for_tracked_task_to_finish() {
        let rt = current_thread_runtime();
        let background = BackgroundRuntime::new(rt.handle().clone());
        let finished_cleanly = Arc::new(AtomicBool::new(false));

        let token = background.cancellation_token();
        let flag = Arc::clone(&finished_cleanly);
        background.spawn_tracked(async move {
            loop {
                tokio::select! {
                    () = token.cancelled() => {
                        flag.store(true, Ordering::Relaxed);
                        break;
                    }
                    () = tokio::time::sleep(Duration::from_secs(1)) => {}
                }
            }
        });

        background.stop();
        rt.block_on(background.wait());
        assert!(finished_cleanly.load(Ordering::Relaxed));
    }

    /// Dropping the `BackgroundRuntime` without calling `stop` must still
    /// release a `wait` future that another thread is blocked on.
    #[test]
    fn test_stops_by_dropping() {
        let rt = current_thread_runtime();
        let background = BackgroundRuntime::new(rt.handle().clone());

        let wait = background.wait();
        let waiter = std::thread::spawn(move || {
            rt.block_on(wait);
        });

        drop(background);
        waiter.join().unwrap();
    }
}