use std::sync::{
atomic::{AtomicBool, AtomicUsize, Ordering},
Arc,
};
use tokio::sync::Notify;
use tokio_util::sync::CancellationToken;
/// Coordinates a "fallback" across independently running components.
///
/// Components obtain a [`FallbackHandler`] via `register`, observe the shared
/// [`CancellationToken`] returned by `token`, and report completion through
/// [`FallbackHandler::done`]. Cloning is cheap: every clone shares the same
/// token, pending counter and notifier.
#[derive(Clone, Debug)]
pub struct FallbackCoordinator {
    /// Cancelled when the fallback is triggered.
    signal: CancellationToken,
    /// Number of registered handlers that have not yet reported completion.
    pending_tasks: Arc<AtomicUsize>,
    /// Wakes tasks waiting for `pending_tasks` to reach zero.
    notify: Arc<Notify>,
}

impl Default for FallbackCoordinator {
    /// Equivalent to [`FallbackCoordinator::new`].
    fn default() -> Self {
        FallbackCoordinator::new()
    }
}
impl FallbackCoordinator {
pub fn new() -> Self {
Self {
signal: CancellationToken::new(),
pending_tasks: Arc::new(AtomicUsize::new(0)),
notify: Arc::new(Notify::new()),
}
}
#[must_use]
pub fn register(&self) -> FallbackHandler {
tracing::debug!("FallbackCoordinator: registering component");
self.pending_tasks.fetch_add(1, Ordering::Relaxed);
FallbackHandler {
coordinator: self.clone(),
done: AtomicBool::new(false),
}
}
pub fn token(&self) -> CancellationToken {
self.signal.clone()
}
pub async fn trigger_fallback_and_wait(&self) {
tracing::debug!("FallbackCoordinator: triggering fallback");
self.signal.cancel();
if self.pending_tasks.load(Ordering::Acquire) == 0 {
return; }
self.notify.notified().await;
tracing::debug!("FallbackCoordinator: finished waiting for components to complete cleanup");
}
}
pub struct FallbackHandler {
coordinator: FallbackCoordinator,
done: AtomicBool,
}
impl FallbackHandler {
pub fn done(self) {
tracing::debug!("FallbackHandler: done called");
self.done.store(true, Ordering::Release);
let prev = self
.coordinator
.pending_tasks
.fetch_sub(1, Ordering::Release);
if self.coordinator.signal.is_cancelled() && prev == 1 {
self.coordinator.notify.notify_one();
}
}
}
impl Drop for FallbackHandler {
fn drop(&mut self) {
if !self.done.load(Ordering::Acquire) {
panic!("FallbackHandler dropped without calling done()");
}
}
}