use super::JobId;
use parking_lot::RwLock;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::watch;
use tokio::time::timeout;
use tracing::{debug, info, warn};
/// A cloneable cancellation flag.
///
/// Every clone shares the same underlying state, so cancelling any clone is
/// observed by all of them.
#[derive(Debug, Clone)]
pub struct CancellationToken {
    /// Shared channel state; cloning the token only clones this `Arc`.
    state: Arc<CancellationState>,
}
/// Both ends of the `watch` channel backing a [`CancellationToken`].
///
/// The receiver is stored here (not just handed out) so that at least one
/// receiver always exists while the state is alive; `watch::Sender::send`
/// only fails when every receiver has been dropped.
#[derive(Debug)]
struct CancellationState {
    /// Publishes the cancelled flag (`true` once cancelled).
    tx: watch::Sender<bool>,
    /// Template receiver, cloned by waiters and read by `is_cancelled`.
    rx: watch::Receiver<bool>,
}
impl CancellationToken {
    /// Creates a token in the not-cancelled state.
    #[must_use]
    pub fn new() -> Self {
        let (sender, receiver) = watch::channel(false);
        let shared = CancellationState {
            tx: sender,
            rx: receiver,
        };
        Self {
            state: Arc::new(shared),
        }
    }

    /// Returns `true` once [`cancel`](Self::cancel) has been called on this
    /// token or any of its clones.
    #[must_use]
    pub fn is_cancelled(&self) -> bool {
        let current = self.state.rx.borrow();
        *current
    }

    /// Marks the token (and every clone) as cancelled and wakes all waiters.
    pub fn cancel(&self) {
        // `CancellationState` keeps a receiver alive, so this send cannot fail
        // for lack of receivers; the result is discarded on purpose.
        let _ = self.state.tx.send(true);
        debug!("Cancellation requested");
    }

    /// Resolves once the token is cancelled; resolves immediately if it
    /// already is.
    pub async fn cancelled(&self) {
        let mut watcher = self.state.rx.clone();
        loop {
            // Separate statements so the borrow guard is released before the
            // `.await` below (holding it would block the sender).
            if *watcher.borrow() {
                return;
            }
            if watcher.changed().await.is_err() {
                // Sender gone: nothing can ever cancel us, stop waiting.
                return;
            }
        }
    }

    /// Drives `future` to completion unless the token is cancelled first.
    ///
    /// Returns `Ok(output)` if the future finishes, or `Err(())` if
    /// cancellation wins. When both are ready at the same poll,
    /// `tokio::select!` picks a branch at random.
    pub async fn run_until_cancelled<F, T>(&self, future: F) -> Result<T, ()>
    where
        F: std::future::Future<Output = T>,
    {
        tokio::select! {
            output = future => Ok(output),
            () = self.cancelled() => Err(()),
        }
    }
}
impl Default for CancellationToken {
fn default() -> Self {
Self::new()
}
}
/// Registry of per-job cancellation tokens, keyed by [`JobId`].
///
/// Cloning the manager shares the same registry (the map lives behind an
/// `Arc`).
#[derive(Debug, Clone)]
pub struct JobCancellationManager {
    /// Tokens of currently registered jobs.
    tokens: Arc<RwLock<HashMap<JobId, CancellationToken>>>,
}
impl Default for JobCancellationManager {
fn default() -> Self {
Self::new()
}
}
impl JobCancellationManager {
    /// Creates a manager with no registered jobs.
    #[must_use]
    pub fn new() -> Self {
        Self {
            tokens: Arc::new(RwLock::new(HashMap::new())),
        }
    }

    /// Registers `token` for `job_id`, replacing any token previously stored
    /// under the same id.
    pub fn register(&self, job_id: JobId, token: CancellationToken) {
        self.tokens.write().insert(job_id, token);
        debug!("Registered cancellation token for job {}", job_id);
    }

    /// Removes the token for `job_id` (no-op if it is not registered).
    ///
    /// `wait_for_completion` only observes completion through this call;
    /// presumably each job unregisters itself on exit — confirm with callers.
    pub fn unregister(&self, job_id: &JobId) {
        self.tokens.write().remove(job_id);
        debug!("Unregistered cancellation token for job {}", job_id);
    }

    /// Cancels the token registered for `job_id`.
    ///
    /// Returns `true` if a token was found (even if it was already
    /// cancelled), `false` if the job is unknown.
    #[must_use]
    pub fn cancel_job(&self, job_id: &JobId) -> bool {
        self.tokens.read().get(job_id).map_or_else(
            || {
                warn!("Attempted to cancel unknown job {}", job_id);
                false
            },
            |token| {
                token.cancel();
                info!("Cancelled job {}", job_id);
                true
            },
        )
    }

    /// Cancels every currently registered job. Tokens stay registered.
    pub fn cancel_all(&self) {
        let tokens = self.tokens.read();
        for (job_id, token) in tokens.iter() {
            token.cancel();
            info!("Cancelled job {}", job_id);
        }
    }

    /// Number of jobs currently registered.
    #[must_use]
    pub fn active_count(&self) -> usize {
        self.tokens.read().len()
    }

    /// Polls the registry until it is empty or `max_wait` has elapsed.
    ///
    /// Returns `true` if every job unregistered in time, `false` on timeout.
    /// The final sleep is clamped to the remaining budget, so the call does
    /// not overshoot `max_wait` by a whole poll interval.
    pub async fn wait_for_completion(&self, max_wait: Duration) -> bool {
        const POLL_INTERVAL: Duration = Duration::from_millis(100);
        let start = std::time::Instant::now();
        loop {
            // Single snapshot per iteration: the emptiness check and the
            // count reported on timeout come from the same read.
            let remaining = self.active_count();
            if remaining == 0 {
                info!("All jobs completed gracefully");
                return true;
            }
            let elapsed = start.elapsed();
            if elapsed >= max_wait {
                warn!(
                    "Timeout waiting for job completion: {} jobs still running",
                    remaining
                );
                return false;
            }
            // `elapsed < max_wait` here, so the subtraction cannot underflow.
            tokio::time::sleep(POLL_INTERVAL.min(max_wait - elapsed)).await;
        }
    }
}
/// Orchestrates job-system shutdown.
///
/// Combines a global shutdown token with the per-job cancellation registry.
#[derive(Debug, Clone)]
pub struct JobShutdownCoordinator {
    /// Registry of per-job tokens that are cancelled during shutdown.
    cancellation_manager: JobCancellationManager,
    /// Global token, cancelled first when shutdown begins.
    shutdown_token: CancellationToken,
}
impl JobShutdownCoordinator {
    /// Creates a coordinator with an empty registry and an un-cancelled
    /// shutdown token.
    #[must_use]
    pub fn new() -> Self {
        Self {
            cancellation_manager: JobCancellationManager::new(),
            shutdown_token: CancellationToken::new(),
        }
    }

    /// The per-job cancellation registry.
    #[must_use]
    pub const fn cancellation_manager(&self) -> &JobCancellationManager {
        &self.cancellation_manager
    }

    /// The global shutdown token.
    #[must_use]
    pub const fn shutdown_token(&self) -> &CancellationToken {
        &self.shutdown_token
    }

    /// Cancels the shutdown token and every registered job, then waits up to
    /// `graceful_timeout` for all jobs to unregister.
    ///
    /// Returns [`ShutdownResult::Graceful`] if no jobs remain. Otherwise
    /// returns [`ShutdownResult::Forced`] with the number still registered.
    pub async fn shutdown(&self, graceful_timeout: Duration) -> ShutdownResult {
        info!("Initiating job system shutdown");
        self.shutdown_token.cancel();
        self.cancellation_manager.cancel_all();
        let drained = self
            .cancellation_manager
            .wait_for_completion(graceful_timeout)
            .await;
        // Sample the count exactly once. Jobs may unregister between the
        // timeout and this point; reporting `Forced { jobs_remaining: 0 }`
        // would contradict itself, so a zero count is treated as graceful.
        let jobs_remaining = if drained {
            0
        } else {
            self.cancellation_manager.active_count()
        };
        if jobs_remaining == 0 {
            info!("Job system shutdown completed gracefully");
            ShutdownResult::Graceful
        } else {
            warn!("Job system forced shutdown after timeout");
            ShutdownResult::Forced { jobs_remaining }
        }
    }

    /// Runs [`shutdown`](Self::shutdown) bounded by `total_timeout`.
    ///
    /// # Errors
    ///
    /// Returns `Err(())` if the whole shutdown did not finish within
    /// `total_timeout`. Cancellation has already been requested by then.
    pub async fn shutdown_with_timeout(
        &self,
        graceful_timeout: Duration,
        total_timeout: Duration,
    ) -> Result<ShutdownResult, ()> {
        timeout(total_timeout, self.shutdown(graceful_timeout))
            .await
            .map_err(|_elapsed| {
                warn!("Shutdown timeout exceeded");
            })
    }
}
impl Default for JobShutdownCoordinator {
fn default() -> Self {
Self::new()
}
}
/// Outcome of a job-system shutdown.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum ShutdownResult {
    /// Every registered job unregistered within the graceful window.
    Graceful,
    /// The graceful window elapsed with jobs still registered.
    Forced { jobs_remaining: usize },
}
impl ShutdownResult {
    /// `true` for [`ShutdownResult::Graceful`], `false` for a forced shutdown.
    #[must_use]
    pub const fn is_graceful(&self) -> bool {
        !matches!(self, Self::Forced { .. })
    }

    /// Jobs left running when shutdown completed (always 0 when graceful).
    #[must_use]
    pub const fn jobs_remaining(&self) -> usize {
        if let Self::Forced { jobs_remaining } = self {
            *jobs_remaining
        } else {
            0
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_cancellation_token_new() {
        let token = CancellationToken::new();
        assert!(!token.is_cancelled());
    }

    #[test]
    fn test_cancellation_token_cancel() {
        let token = CancellationToken::new();
        assert!(!token.is_cancelled());
        token.cancel();
        assert!(token.is_cancelled());
    }

    #[test]
    fn test_cancellation_token_clone() {
        let token1 = CancellationToken::new();
        let token2 = token1.clone();
        token1.cancel();
        assert!(token2.is_cancelled());
    }

    #[tokio::test]
    async fn test_cancellation_token_cancelled() {
        let token = CancellationToken::new();
        let token_clone = token.clone();
        tokio::spawn(async move {
            tokio::time::sleep(Duration::from_millis(50)).await;
            token_clone.cancel();
        });
        token.cancelled().await;
        assert!(token.is_cancelled());
    }

    #[tokio::test]
    async fn test_run_until_cancelled() {
        let token = CancellationToken::new();
        let token_clone = token.clone();
        let result = tokio::spawn(async move {
            token_clone
                .run_until_cancelled(async {
                    tokio::time::sleep(Duration::from_secs(1000)).await;
                    42
                })
                .await
        });
        token.cancel();
        let output = result.await.unwrap();
        assert_eq!(output, Err(()));
    }

    #[tokio::test]
    async fn test_run_until_cancelled_completes() {
        // Not cancelled: the cancellation branch stays pending, so the
        // future's output must win.
        let token = CancellationToken::new();
        let output = token.run_until_cancelled(async { 42 }).await;
        assert_eq!(output, Ok(42));
    }

    #[test]
    fn test_cancellation_manager_register() {
        let manager = JobCancellationManager::new();
        let job_id = JobId::new();
        let token = CancellationToken::new();
        assert_eq!(manager.active_count(), 0);
        manager.register(job_id, token);
        assert_eq!(manager.active_count(), 1);
        manager.unregister(&job_id);
        assert_eq!(manager.active_count(), 0);
    }

    #[test]
    fn test_cancellation_manager_cancel_job() {
        let manager = JobCancellationManager::new();
        let job_id = JobId::new();
        let token = CancellationToken::new();
        manager.register(job_id, token.clone());
        assert!(!token.is_cancelled());
        assert!(manager.cancel_job(&job_id));
        assert!(token.is_cancelled());
        assert!(manager.cancel_job(&job_id));
    }

    #[test]
    fn test_cancellation_manager_cancel_unknown_job() {
        let manager = JobCancellationManager::new();
        assert!(!manager.cancel_job(&JobId::new()));
    }

    #[test]
    fn test_cancellation_manager_cancel_all() {
        let manager = JobCancellationManager::new();
        let token1 = CancellationToken::new();
        let token2 = CancellationToken::new();
        manager.register(JobId::new(), token1.clone());
        manager.register(JobId::new(), token2.clone());
        manager.cancel_all();
        assert!(token1.is_cancelled());
        assert!(token2.is_cancelled());
    }

    #[test]
    fn test_shutdown_result() {
        let graceful = ShutdownResult::Graceful;
        assert!(graceful.is_graceful());
        assert_eq!(graceful.jobs_remaining(), 0);
        let forced = ShutdownResult::Forced { jobs_remaining: 5 };
        assert!(!forced.is_graceful());
        assert_eq!(forced.jobs_remaining(), 5);
    }

    #[tokio::test]
    async fn test_shutdown_coordinator_graceful() {
        let coordinator = JobShutdownCoordinator::new();
        let job_id = JobId::new();
        let token = CancellationToken::new();
        let manager = coordinator.cancellation_manager().clone();
        manager.register(job_id, token.clone());
        // Simulated job: it runs until cancelled, then deregisters itself.
        // That deregistration is what lets `shutdown` observe an empty
        // registry and report `Graceful`.
        let worker = tokio::spawn(async move {
            token.cancelled().await;
            manager.unregister(&job_id);
        });
        let result = coordinator.shutdown(Duration::from_secs(1)).await;
        assert!(result.is_graceful());
        assert!(coordinator.shutdown_token().is_cancelled());
        worker.await.unwrap();
    }

    #[tokio::test]
    async fn test_shutdown_coordinator_forced() {
        // A job that never deregisters must yield a forced shutdown.
        let coordinator = JobShutdownCoordinator::new();
        let token = CancellationToken::new();
        coordinator
            .cancellation_manager()
            .register(JobId::new(), token.clone());
        let result = coordinator.shutdown(Duration::from_millis(50)).await;
        assert_eq!(result, ShutdownResult::Forced { jobs_remaining: 1 });
        assert!(token.is_cancelled());
    }
}