use std::collections::HashMap;
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::sync::{mpsc, oneshot};
use tokio::task::AbortHandle;
use tokio_util::sync::CancellationToken;
use tracing::Instrument as _;
use zeph_common::BlockingSpawner;
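/// Restart behaviour for a supervised task. Tasks are restarted only after a
/// panic; normal completion and cancellation never trigger a restart.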
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RestartPolicy {
RunOnce,
Restart { max: u32, base_delay: Duration },
}
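/// Ceiling applied to the exponential restart backoff delay.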
pub const MAX_RESTART_DELAY: Duration = Duration::from_secs(60);
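/// Describes a task to spawn: a stable name, a restart policy, and a factory
/// that produces a fresh future for each (re)start.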
pub struct TaskDescriptor<F> {
pub name: &'static str,
pub restart: RestartPolicy,
pub factory: F,
}
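/// Cloneable handle to a spawned task, used to abort it and query its name.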
#[derive(Debug, Clone)]
pub struct TaskHandle {
name: &'static str,
abort: AbortHandle,
}
impl TaskHandle {
pub fn abort(&self) {
tracing::debug!(task.name = self.name, "task aborted via handle");
self.abort.abort();
}
#[must_use]
pub fn name(&self) -> &'static str {
self.name
}
}
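/// Error returned when joining a supervised blocking or oneshot task.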
#[derive(Debug, PartialEq, Eq)]
pub enum BlockingError {
Panicked,
SupervisorDropped,
}
impl std::fmt::Display for BlockingError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Self::Panicked => write!(f, "supervised blocking task panicked"),
Self::SupervisorDropped => write!(f, "supervisor dropped before task completed"),
}
}
}
impl std::error::Error for BlockingError {}
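/// Handle to a supervised blocking or oneshot task: await [`BlockingHandle::join`]
/// for the result, or call [`BlockingHandle::abort`] to cancel it.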
pub struct BlockingHandle<R> {
rx: oneshot::Receiver<Result<R, BlockingError>>,
abort: AbortHandle,
}
impl<R> BlockingHandle<R> {
pub async fn join(self) -> Result<R, BlockingError> {
self.rx
.await
.unwrap_or(Err(BlockingError::SupervisorDropped))
}
pub fn abort(&self) {
self.abort.abort();
}
}
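/// Lifecycle state of a supervised task, as reported in [`TaskSnapshot`].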
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum TaskStatus {
Running,
Restarting { attempt: u32, max: u32 },
Completed,
Aborted,
Failed { reason: String },
}
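/// Point-in-time view of a registry entry, returned by [`TaskSupervisor::snapshot`].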
#[derive(Debug, Clone)]
pub struct TaskSnapshot {
pub name: Arc<str>,
pub status: TaskStatus,
pub started_at: Instant,
pub restart_count: u32,
}
type BoxFuture = Pin<Box<dyn Future<Output = ()> + Send>>;
type BoxFactory = Box<dyn Fn() -> BoxFuture + Send + Sync>;
struct TaskEntry {
name: Arc<str>,
status: TaskStatus,
started_at: Instant,
restart_count: u32,
restart_policy: RestartPolicy,
abort_handle: AbortHandle,
factory: Option<BoxFactory>,
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum CompletionKind {
Normal,
Panicked,
Cancelled,
}
struct Completion {
name: Arc<str>,
kind: CompletionKind,
}
struct SupervisorState {
tasks: HashMap<Arc<str>, TaskEntry>,
}
struct Inner {
state: parking_lot::Mutex<SupervisorState>,
completion_tx: mpsc::UnboundedSender<Completion>,
cancel: CancellationToken,
blocking_semaphore: Arc<tokio::sync::Semaphore>,
}
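/// Supervises async and blocking tasks. Every task is registered under a name,
/// tracked through its lifecycle, optionally restarted on panic with exponential
/// backoff, and cancelled cooperatively on shutdown.
///
/// A minimal usage sketch (doctest ignored; assumes a tokio runtime):
///
/// ```ignore
/// let cancel = CancellationToken::new();
/// let sup = TaskSupervisor::new(cancel.clone());
/// sup.spawn(TaskDescriptor {
///     name: "heartbeat",
///     restart: RestartPolicy::Restart {
///         max: 3,
///         base_delay: Duration::from_millis(100),
///     },
///     factory: || async { /* periodic work */ },
/// });
/// // Give tasks one second to wind down before force-aborting.
/// sup.shutdown_all(Duration::from_secs(1)).await;
/// ```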
#[derive(Clone)]
pub struct TaskSupervisor {
inner: Arc<Inner>,
}
impl TaskSupervisor {
#[must_use]
pub fn new(cancel: CancellationToken) -> Self {
let (completion_tx, completion_rx) = mpsc::unbounded_channel();
let inner = Arc::new(Inner {
state: parking_lot::Mutex::new(SupervisorState {
tasks: HashMap::new(),
}),
completion_tx,
cancel: cancel.clone(),
blocking_semaphore: Arc::new(tokio::sync::Semaphore::new(8)),
});
Self::start_reap_driver(Arc::clone(&inner), completion_rx, cancel);
Self { inner }
}
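    /// Spawns a supervised async task. A task already registered under the same
    /// name is aborted and replaced. The factory is retained only for
    /// [`RestartPolicy::Restart`] tasks so they can be respawned after a panic.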
pub fn spawn<F, Fut>(&self, desc: TaskDescriptor<F>) -> TaskHandle
where
F: Fn() -> Fut + Send + Sync + 'static,
Fut: Future<Output = ()> + Send + 'static,
{
let factory: BoxFactory = Box::new(move || Box::pin((desc.factory)()));
let cancel = self.inner.cancel.clone();
let completion_tx = self.inner.completion_tx.clone();
let name: Arc<str> = Arc::from(desc.name);
let (abort_handle, jh) = Self::do_spawn(desc.name, &factory, cancel);
Self::wire_completion_reporter(Arc::clone(&name), jh, completion_tx);
let entry = TaskEntry {
name: Arc::clone(&name),
status: TaskStatus::Running,
started_at: Instant::now(),
restart_count: 0,
restart_policy: desc.restart,
abort_handle: abort_handle.clone(),
factory: match desc.restart {
RestartPolicy::RunOnce => None,
RestartPolicy::Restart { .. } => Some(factory),
},
};
{
let mut state = self.inner.state.lock();
if let Some(old) = state.tasks.remove(&name) {
old.abort_handle.abort();
}
state.tasks.insert(Arc::clone(&name), entry);
}
TaskHandle {
name: desc.name,
abort: abort_handle,
}
}
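    /// Runs `f` on the blocking thread pool, gated by a semaphore that caps
    /// concurrent supervised blocking work at 8 permits. The task is registered
    /// under `name` and its result is delivered through the returned
    /// [`BlockingHandle`].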
    #[allow(clippy::needless_pass_by_value)]
    pub fn spawn_blocking<F, R>(&self, name: Arc<str>, f: F) -> BlockingHandle<R>
where
F: FnOnce() -> R + Send + 'static,
R: Send + 'static,
{
let (tx, rx) = oneshot::channel::<Result<R, BlockingError>>();
#[cfg(feature = "task-metrics")]
let span = tracing::info_span!(
"supervised_blocking_task",
task.name = %name,
task.wall_time_ms = tracing::field::Empty,
task.cpu_time_ms = tracing::field::Empty,
);
#[cfg(not(feature = "task-metrics"))]
let span = tracing::info_span!("supervised_blocking_task", task.name = %name);
let semaphore = Arc::clone(&self.inner.blocking_semaphore);
let inner = Arc::clone(&self.inner);
let name_clone = Arc::clone(&name);
let completion_tx = self.inner.completion_tx.clone();
let outer = tokio::spawn(async move {
let _permit = semaphore
.acquire_owned()
.await
.expect("blocking semaphore closed");
let name_for_measure = Arc::clone(&name_clone);
let join_handle = tokio::task::spawn_blocking(move || {
let _enter = span.enter();
measure_blocking(&name_for_measure, f)
});
let abort = join_handle.abort_handle();
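            // Re-point the registry's abort handle at the inner blocking task;
            // note that `spawn_blocking` work cannot be interrupted once the
            // closure has started running.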
{
let mut state = inner.state.lock();
if let Some(entry) = state.tasks.get_mut(&name_clone) {
entry.abort_handle = abort;
}
}
let kind = match join_handle.await {
Ok(val) => {
let _ = tx.send(Ok(val));
CompletionKind::Normal
}
Err(e) if e.is_panic() => {
let _ = tx.send(Err(BlockingError::Panicked));
CompletionKind::Panicked
}
                Err(_) => CompletionKind::Cancelled,
};
let _ = completion_tx.send(Completion {
name: name_clone,
kind,
});
});
let abort = outer.abort_handle();
{
let mut state = self.inner.state.lock();
if let Some(old) = state.tasks.remove(&name) {
old.abort_handle.abort();
}
state.tasks.insert(
Arc::clone(&name),
TaskEntry {
name: Arc::clone(&name),
status: TaskStatus::Running,
started_at: Instant::now(),
restart_count: 0,
restart_policy: RestartPolicy::RunOnce,
abort_handle: abort.clone(),
factory: None,
},
);
}
BlockingHandle { rx, abort }
}
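    /// Spawns a run-once async task whose result is delivered through a
    /// [`BlockingHandle`]. The future races the supervisor's cancellation
    /// token; if cancelled first, joining the handle yields
    /// [`BlockingError::SupervisorDropped`].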
pub fn spawn_oneshot<F, Fut, R>(&self, name: Arc<str>, factory: F) -> BlockingHandle<R>
where
F: FnOnce() -> Fut + Send + 'static,
Fut: Future<Output = R> + Send + 'static,
R: Send + 'static,
{
let (tx, rx) = oneshot::channel::<Result<R, BlockingError>>();
let cancel = self.inner.cancel.clone();
let span = tracing::info_span!("supervised_task", task.name = %name);
let join_handle: tokio::task::JoinHandle<Option<R>> = tokio::spawn(
async move {
let fut = factory();
tokio::select! {
result = fut => Some(result),
() = cancel.cancelled() => None,
}
}
.instrument(span),
);
let abort = join_handle.abort_handle();
{
let mut state = self.inner.state.lock();
if let Some(old) = state.tasks.remove(&name) {
old.abort_handle.abort();
}
state.tasks.insert(
Arc::clone(&name),
TaskEntry {
name: Arc::clone(&name),
status: TaskStatus::Running,
started_at: Instant::now(),
restart_count: 0,
restart_policy: RestartPolicy::RunOnce,
abort_handle: abort.clone(),
factory: None,
},
);
}
let completion_tx = self.inner.completion_tx.clone();
tokio::spawn(async move {
let kind = match join_handle.await {
Ok(Some(val)) => {
let _ = tx.send(Ok(val));
CompletionKind::Normal
}
Err(e) if e.is_panic() => {
let _ = tx.send(Err(BlockingError::Panicked));
CompletionKind::Panicked
}
_ => CompletionKind::Cancelled,
};
let _ = completion_tx.send(Completion { name, kind });
});
BlockingHandle { rx, abort }
}
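    /// Aborts the named task if it is registered; unknown names are a no-op.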
    pub fn abort(&self, name: &'static str) {
        let state = self.inner.state.lock();
        if let Some(entry) = state.tasks.get(name) {
            entry.abort_handle.abort();
            tracing::debug!(task.name = name, "task aborted via supervisor");
        }
    }
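    /// Cancels the shared token, then polls until every task settles or the
    /// timeout elapses, at which point any still-active tasks are force-aborted
    /// and marked [`TaskStatus::Aborted`].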
pub async fn shutdown_all(&self, timeout: Duration) {
self.inner.cancel.cancel();
let deadline = tokio::time::Instant::now() + timeout;
loop {
let active = self.active_count();
if active == 0 {
break;
}
if tokio::time::Instant::now() >= deadline {
tracing::warn!(
remaining = active,
"shutdown timeout — aborting remaining tasks"
);
let mut state = self.inner.state.lock();
for entry in state.tasks.values_mut() {
if matches!(
entry.status,
TaskStatus::Running | TaskStatus::Restarting { .. }
) {
entry.abort_handle.abort();
entry.status = TaskStatus::Aborted;
}
}
break;
}
tokio::time::sleep(Duration::from_millis(50)).await;
}
}
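    /// Returns a snapshot of all registered tasks, ordered by start time.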
#[must_use]
pub fn snapshot(&self) -> Vec<TaskSnapshot> {
let state = self.inner.state.lock();
let mut snaps: Vec<TaskSnapshot> = state
.tasks
.values()
.map(|e| TaskSnapshot {
name: Arc::clone(&e.name),
status: e.status.clone(),
started_at: e.started_at,
restart_count: e.restart_count,
})
.collect();
snaps.sort_by_key(|s| s.started_at);
snaps
}
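    /// Number of tasks currently `Running` or `Restarting`.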
#[must_use]
pub fn active_count(&self) -> usize {
let state = self.inner.state.lock();
state
.tasks
.values()
.filter(|e| {
matches!(
e.status,
TaskStatus::Running | TaskStatus::Restarting { .. }
)
})
.count()
}
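    /// Returns a clone of the supervisor's shared cancellation token.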
#[must_use]
pub fn cancellation_token(&self) -> CancellationToken {
self.inner.cancel.clone()
}
fn do_spawn(
name: &'static str,
factory: &BoxFactory,
cancel: CancellationToken,
) -> (AbortHandle, tokio::task::JoinHandle<()>) {
let fut = factory();
let span = tracing::info_span!("supervised_task", task.name = name);
let jh = tokio::spawn(
async move {
tokio::select! {
() = fut => {},
() = cancel.cancelled() => {},
}
}
.instrument(span),
);
let abort = jh.abort_handle();
(abort, jh)
}
fn wire_completion_reporter(
name: Arc<str>,
jh: tokio::task::JoinHandle<()>,
completion_tx: mpsc::UnboundedSender<Completion>,
) {
tokio::spawn(async move {
let kind = match jh.await {
Ok(()) => CompletionKind::Normal,
Err(e) if e.is_panic() => CompletionKind::Panicked,
Err(_) => CompletionKind::Cancelled,
};
let _ = completion_tx.send(Completion { name, kind });
});
}
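    /// Drives completion handling: applies restart policies as completions
    /// arrive, and on cancellation drains any queued completions before exiting.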
fn start_reap_driver(
inner: Arc<Inner>,
mut completion_rx: mpsc::UnboundedReceiver<Completion>,
cancel: CancellationToken,
) {
tokio::spawn(async move {
loop {
tokio::select! {
biased;
Some(completion) = completion_rx.recv() => {
Self::handle_completion(&inner, completion).await;
}
() = cancel.cancelled() => {
while let Ok(completion) = completion_rx.try_recv() {
Self::handle_completion(&inner, completion).await;
}
break;
}
}
}
});
}
async fn handle_completion(inner: &Arc<Inner>, completion: Completion) {
let Some((attempt, max, delay)) = Self::classify_completion(inner, &completion) else {
return;
};
tracing::warn!(
task.name = %completion.name,
attempt,
max,
delay_ms = delay.as_millis(),
"restarting supervised task"
);
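        // NOTE: sleeping here serializes the reap driver, so a long backoff for
        // one task delays handling of other completions until it elapses.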
if !delay.is_zero() {
tokio::time::sleep(delay).await;
}
Self::do_restart(inner, &completion.name, attempt);
}
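    /// Decides whether a completed task should restart. Returns
    /// `Some((attempt, max, delay))` when a panicked `Restart` task has attempts
    /// left; otherwise records the entry's terminal status and returns `None`.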
fn classify_completion(
inner: &Arc<Inner>,
completion: &Completion,
) -> Option<(u32, u32, Duration)> {
let mut state = inner.state.lock();
let entry = state.tasks.get_mut(&completion.name)?;
match completion.kind {
CompletionKind::Panicked => {
tracing::warn!(task.name = %completion.name, "supervised task panicked");
}
CompletionKind::Normal => {
tracing::info!(task.name = %completion.name, "supervised task completed");
}
CompletionKind::Cancelled => {
tracing::debug!(task.name = %completion.name, "supervised task cancelled");
}
}
match entry.restart_policy {
RestartPolicy::RunOnce => {
entry.status = TaskStatus::Completed;
state.tasks.remove(&completion.name);
None
}
RestartPolicy::Restart { max, base_delay } => {
if completion.kind != CompletionKind::Panicked {
entry.status = TaskStatus::Completed;
state.tasks.remove(&completion.name);
return None;
}
if entry.restart_count >= max {
let reason = format!("panicked after {max} restart(s)");
tracing::error!(
task.name = %completion.name,
attempts = max,
"task failed permanently"
);
entry.status = TaskStatus::Failed { reason };
None
} else {
let attempt = entry.restart_count + 1;
entry.status = TaskStatus::Restarting { attempt, max };
let multiplier = 1_u32
.checked_shl(attempt.saturating_sub(1))
.unwrap_or(u32::MAX);
let delay = base_delay.saturating_mul(multiplier).min(MAX_RESTART_DELAY);
Some((attempt, max, delay))
}
}
}
}
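    /// Re-invokes the stored factory and respawns the task, guarding against
    /// the factory itself panicking and against the entry having been removed
    /// or superseded during the backoff delay.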
fn do_restart(inner: &Arc<Inner>, name: &Arc<str>, attempt: u32) {
let spawn_params = {
let mut state = inner.state.lock();
let Some(entry) = state.tasks.get_mut(name.as_ref()) else {
tracing::debug!(
task.name = %name,
"task removed during restart delay — skipping"
);
return;
};
if !matches!(entry.status, TaskStatus::Restarting { .. }) {
return;
}
let Some(factory) = &entry.factory else {
return;
};
match std::panic::catch_unwind(std::panic::AssertUnwindSafe(factory)) {
Err(_) => {
let reason = format!("factory panicked on restart attempt {attempt}");
tracing::error!(task.name = %name, attempt, "factory panicked during restart");
entry.status = TaskStatus::Failed { reason };
None
}
Ok(fut) => Some((
fut,
inner.cancel.clone(),
inner.completion_tx.clone(),
name.clone(),
)),
}
};
let Some((fut, cancel, completion_tx, name)) = spawn_params else {
return;
};
let span = tracing::info_span!("supervised_task", task.name = %name);
let jh = tokio::spawn(
async move {
tokio::select! {
() = fut => {},
() = cancel.cancelled() => {},
}
}
.instrument(span),
);
let new_abort = jh.abort_handle();
{
let mut state = inner.state.lock();
if let Some(entry) = state.tasks.get_mut(name.as_ref()) {
entry.restart_count = attempt;
entry.status = TaskStatus::Running;
entry.abort_handle = new_abort;
}
}
Self::wire_completion_reporter(name.clone(), jh, completion_tx);
}
}
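/// Runs `f`, recording wall-clock and CPU time as `metrics` histograms and
/// tracing span fields (only when the `task-metrics` feature is enabled).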
#[cfg(feature = "task-metrics")]
#[inline]
fn measure_blocking<F, R>(name: &str, f: F) -> R
where
F: FnOnce() -> R,
{
use cpu_time::ThreadTime;
let wall_start = std::time::Instant::now();
let cpu_start = ThreadTime::now();
let result = f();
let wall_ms = wall_start.elapsed().as_secs_f64() * 1000.0;
let cpu_ms = cpu_start.elapsed().as_secs_f64() * 1000.0;
metrics::histogram!("zeph.task.wall_time_ms", "task" => name.to_owned()).record(wall_ms);
metrics::histogram!("zeph.task.cpu_time_ms", "task" => name.to_owned()).record(cpu_ms);
tracing::Span::current().record("task.wall_time_ms", wall_ms);
tracing::Span::current().record("task.cpu_time_ms", cpu_ms);
result
}
#[cfg(not(feature = "task-metrics"))]
#[inline]
fn measure_blocking<F, R>(_name: &str, f: F) -> R
where
F: FnOnce() -> R,
{
f()
}
impl BlockingSpawner for TaskSupervisor {
fn spawn_blocking_named(
&self,
name: Arc<str>,
f: Box<dyn FnOnce() + Send + 'static>,
) -> tokio::task::JoinHandle<()> {
let handle = self.spawn_blocking(Arc::clone(&name), f);
tokio::spawn(async move {
if let Err(e) = handle.join().await {
tracing::error!(task.name = %name, error = %e, "supervised blocking task failed");
}
})
}
}
#[cfg(test)]
mod tests {
use std::sync::Arc;
use std::sync::atomic::{AtomicU32, Ordering};
use std::time::Duration;
use tokio_util::sync::CancellationToken;
use super::*;
fn make_supervisor() -> (TaskSupervisor, CancellationToken) {
let cancel = CancellationToken::new();
let sup = TaskSupervisor::new(cancel.clone());
(sup, cancel)
}
#[tokio::test]
async fn test_spawn_and_complete() {
let (sup, _cancel) = make_supervisor();
let done = Arc::new(tokio::sync::Notify::new());
let done2 = Arc::clone(&done);
sup.spawn(TaskDescriptor {
name: "simple",
restart: RestartPolicy::RunOnce,
factory: move || {
let d = Arc::clone(&done2);
async move {
d.notify_one();
}
},
});
tokio::time::timeout(Duration::from_secs(2), done.notified())
.await
.expect("task should complete");
tokio::time::sleep(Duration::from_millis(50)).await;
assert_eq!(
sup.active_count(),
0,
"RunOnce task should be removed after completion"
);
}
#[tokio::test]
async fn test_panic_capture() {
let (sup, _cancel) = make_supervisor();
sup.spawn(TaskDescriptor {
name: "panicking",
restart: RestartPolicy::RunOnce,
factory: || async { panic!("intentional test panic") },
});
tokio::time::sleep(Duration::from_millis(200)).await;
let snaps = sup.snapshot();
assert!(
snaps.iter().all(|s| s.name.as_ref() != "panicking"),
"entry should be reaped"
);
assert_eq!(
sup.active_count(),
0,
"active count must be 0 after RunOnce panic"
);
}
#[tokio::test]
async fn test_restart_only_on_panic() {
let (sup, _cancel) = make_supervisor();
let normal_counter = Arc::new(AtomicU32::new(0));
let nc = Arc::clone(&normal_counter);
sup.spawn(TaskDescriptor {
name: "normal-exit",
restart: RestartPolicy::Restart {
max: 3,
base_delay: Duration::from_millis(10),
},
factory: move || {
let c = Arc::clone(&nc);
async move {
c.fetch_add(1, Ordering::SeqCst);
}
},
});
tokio::time::sleep(Duration::from_millis(300)).await;
assert_eq!(
normal_counter.load(Ordering::SeqCst),
1,
"normal exit must not restart"
);
assert!(
sup.snapshot()
.iter()
.all(|s| s.name.as_ref() != "normal-exit"),
"entry removed after normal exit"
);
let panic_counter = Arc::new(AtomicU32::new(0));
let pc = Arc::clone(&panic_counter);
sup.spawn(TaskDescriptor {
name: "panic-exit",
restart: RestartPolicy::Restart {
max: 2,
base_delay: Duration::from_millis(10),
},
factory: move || {
let c = Arc::clone(&pc);
async move {
c.fetch_add(1, Ordering::SeqCst);
panic!("test panic");
}
},
});
tokio::time::sleep(Duration::from_millis(500)).await;
assert!(
panic_counter.load(Ordering::SeqCst) >= 3,
"panicking task must restart max times"
);
let snap = sup
.snapshot()
.into_iter()
.find(|s| s.name.as_ref() == "panic-exit");
assert!(
matches!(snap.unwrap().status, TaskStatus::Failed { .. }),
"task must be Failed after exhausting restarts"
);
}
#[tokio::test]
async fn test_restart_policy() {
let (sup, _cancel) = make_supervisor();
let counter = Arc::new(AtomicU32::new(0));
let counter2 = Arc::clone(&counter);
sup.spawn(TaskDescriptor {
name: "restartable",
restart: RestartPolicy::Restart {
max: 2,
base_delay: Duration::from_millis(10),
},
factory: move || {
let c = Arc::clone(&counter2);
async move {
c.fetch_add(1, Ordering::SeqCst);
panic!("always panic");
}
},
});
tokio::time::sleep(Duration::from_millis(500)).await;
let runs = counter.load(Ordering::SeqCst);
assert!(
runs >= 3,
"expected at least 3 invocations (initial + 2 restarts), got {runs}"
);
let snaps = sup.snapshot();
let snap = snaps.iter().find(|s| s.name.as_ref() == "restartable");
assert!(snap.is_some(), "failed task should remain in registry");
assert!(
matches!(snap.unwrap().status, TaskStatus::Failed { .. }),
"task should be Failed after exhausting retries"
);
}
#[tokio::test]
async fn test_exponential_backoff() {
let (sup, _cancel) = make_supervisor();
let timestamps = Arc::new(parking_lot::Mutex::new(Vec::<std::time::Instant>::new()));
        let ts = Arc::clone(&timestamps);
sup.spawn(TaskDescriptor {
name: "backoff-task",
restart: RestartPolicy::Restart {
max: 3,
base_delay: Duration::from_millis(50),
},
factory: move || {
let t = Arc::clone(&ts);
async move {
t.lock().push(std::time::Instant::now());
panic!("always panic");
}
},
});
tokio::time::sleep(Duration::from_millis(800)).await;
let ts = timestamps.lock();
assert!(
ts.len() >= 3,
"expected at least 3 invocations, got {}",
ts.len()
);
if ts.len() >= 3 {
let d1 = ts[1].duration_since(ts[0]);
let d2 = ts[2].duration_since(ts[1]);
assert!(
d2 >= d1.mul_f64(1.5),
"expected exponential backoff: d1={d1:?} d2={d2:?}"
);
}
}
#[tokio::test]
async fn test_graceful_shutdown() {
let (sup, _cancel) = make_supervisor();
for name in ["svc-a", "svc-b", "svc-c"] {
sup.spawn(TaskDescriptor {
name,
restart: RestartPolicy::RunOnce,
factory: || async {
tokio::time::sleep(Duration::from_secs(60)).await;
},
});
}
assert_eq!(sup.active_count(), 3);
tokio::time::timeout(
Duration::from_secs(2),
sup.shutdown_all(Duration::from_secs(1)),
)
.await
.expect("shutdown should complete within timeout");
}
#[tokio::test]
async fn test_force_abort_marks_aborted() {
let cancel = CancellationToken::new();
let sup = TaskSupervisor::new(cancel.clone());
sup.spawn(TaskDescriptor {
name: "stubborn-for-abort",
restart: RestartPolicy::RunOnce,
factory: || async {
std::future::pending::<()>().await;
},
});
sup.shutdown_all(Duration::from_millis(1)).await;
let snaps = sup.snapshot();
if let Some(snap) = snaps
.iter()
.find(|s| s.name.as_ref() == "stubborn-for-abort")
{
assert_eq!(
snap.status,
TaskStatus::Aborted,
"force-aborted task must have Aborted status"
);
}
}
#[tokio::test]
async fn test_registry_snapshot() {
let (sup, _cancel) = make_supervisor();
for name in ["alpha", "beta"] {
sup.spawn(TaskDescriptor {
name,
restart: RestartPolicy::RunOnce,
factory: || async {
tokio::time::sleep(Duration::from_secs(10)).await;
},
});
}
let snaps = sup.snapshot();
assert_eq!(snaps.len(), 2);
let names: Vec<&str> = snaps.iter().map(|s| s.name.as_ref()).collect();
assert!(names.contains(&"alpha"));
assert!(names.contains(&"beta"));
assert!(snaps.iter().all(|s| s.status == TaskStatus::Running));
}
#[tokio::test]
async fn test_blocking_returns_value() {
let (sup, cancel) = make_supervisor();
let handle: BlockingHandle<u32> = sup.spawn_blocking(Arc::from("compute"), || 42_u32);
let result = handle.join().await.expect("should return value");
assert_eq!(result, 42);
cancel.cancel();
}
#[tokio::test]
async fn test_blocking_panic() {
let (sup, _cancel) = make_supervisor();
let handle: BlockingHandle<u32> =
sup.spawn_blocking(Arc::from("panicking-compute"), || panic!("intentional"));
let err = handle
.join()
.await
.expect_err("should return error on panic");
assert_eq!(err, BlockingError::Panicked);
}
#[tokio::test]
async fn test_blocking_registered_in_registry() {
let (sup, cancel) = make_supervisor();
let (tx, rx) = std::sync::mpsc::channel::<()>();
let _handle: BlockingHandle<()> =
sup.spawn_blocking(Arc::from("blocking-task"), move || {
let _ = rx.recv();
});
tokio::time::sleep(Duration::from_millis(10)).await;
assert_eq!(
sup.active_count(),
1,
"blocking task must appear in active_count"
);
let _ = tx.send(());
tokio::time::sleep(Duration::from_millis(100)).await;
assert_eq!(
sup.active_count(),
0,
"blocking task must be removed after completion"
);
cancel.cancel();
}
#[tokio::test]
async fn test_oneshot_registered_in_registry() {
let (sup, cancel) = make_supervisor();
let (tx, rx) = tokio::sync::oneshot::channel::<()>();
let _handle: BlockingHandle<()> =
sup.spawn_oneshot(Arc::from("oneshot-task"), move || async move {
let _ = rx.await;
});
tokio::time::sleep(Duration::from_millis(10)).await;
assert_eq!(
sup.active_count(),
1,
"oneshot task must appear in active_count"
);
let _ = tx.send(());
tokio::time::sleep(Duration::from_millis(50)).await;
assert_eq!(
sup.active_count(),
0,
"oneshot task must be removed after completion"
);
cancel.cancel();
}
#[tokio::test]
async fn test_restart_max_zero() {
let (sup, _cancel) = make_supervisor();
let counter = Arc::new(AtomicU32::new(0));
let counter2 = Arc::clone(&counter);
sup.spawn(TaskDescriptor {
name: "zero-max",
restart: RestartPolicy::Restart {
max: 0,
base_delay: Duration::from_millis(10),
},
factory: move || {
let c = Arc::clone(&counter2);
async move {
c.fetch_add(1, Ordering::SeqCst);
panic!("always panic");
}
},
});
tokio::time::sleep(Duration::from_millis(200)).await;
assert_eq!(
counter.load(Ordering::SeqCst),
1,
"max=0 should not restart"
);
let snaps = sup.snapshot();
let snap = snaps.iter().find(|s| s.name.as_ref() == "zero-max");
assert!(snap.is_some(), "entry should remain as Failed");
assert!(
matches!(snap.unwrap().status, TaskStatus::Failed { .. }),
"status should be Failed"
);
}
#[tokio::test]
async fn test_concurrent_spawns() {
let (sup, cancel) = make_supervisor();
static NAMES: [&str; 50] = [
"t00", "t01", "t02", "t03", "t04", "t05", "t06", "t07", "t08", "t09", "t10", "t11",
"t12", "t13", "t14", "t15", "t16", "t17", "t18", "t19", "t20", "t21", "t22", "t23",
"t24", "t25", "t26", "t27", "t28", "t29", "t30", "t31", "t32", "t33", "t34", "t35",
"t36", "t37", "t38", "t39", "t40", "t41", "t42", "t43", "t44", "t45", "t46", "t47",
"t48", "t49",
];
let completed = Arc::new(AtomicU32::new(0));
for name in &NAMES {
let c = Arc::clone(&completed);
sup.spawn(TaskDescriptor {
name,
restart: RestartPolicy::RunOnce,
factory: move || {
let c = Arc::clone(&c);
async move {
c.fetch_add(1, Ordering::SeqCst);
}
},
});
}
tokio::time::timeout(Duration::from_secs(5), async {
loop {
if completed.load(Ordering::SeqCst) == 50 {
break;
}
tokio::time::sleep(Duration::from_millis(10)).await;
}
})
.await
.expect("all 50 tasks should complete");
tokio::time::sleep(Duration::from_millis(100)).await;
assert_eq!(sup.active_count(), 0, "all tasks must be reaped");
cancel.cancel();
}
#[tokio::test]
async fn test_shutdown_timeout_expiry() {
let cancel = CancellationToken::new();
let sup = TaskSupervisor::new(cancel.clone());
sup.spawn(TaskDescriptor {
name: "stubborn",
restart: RestartPolicy::RunOnce,
factory: || async {
tokio::time::sleep(Duration::from_secs(60)).await;
},
});
assert_eq!(sup.active_count(), 1);
tokio::time::timeout(
Duration::from_secs(2),
sup.shutdown_all(Duration::from_millis(50)),
)
.await
.expect("shutdown_all should return even on timeout expiry");
assert!(
cancel.is_cancelled(),
"cancel token must be cancelled after shutdown"
);
}
#[tokio::test]
async fn test_cancellation_token() {
let cancel = CancellationToken::new();
let sup = TaskSupervisor::new(cancel.clone());
assert!(!sup.cancellation_token().is_cancelled());
sup.shutdown_all(Duration::from_millis(100)).await;
assert!(
sup.cancellation_token().is_cancelled(),
"token must be cancelled after shutdown"
);
}
#[tokio::test]
async fn test_blocking_spawner_task_appears_in_snapshot() {
use zeph_common::BlockingSpawner;
let cancel = CancellationToken::new();
let sup = TaskSupervisor::new(cancel);
let (ready_tx, ready_rx) = tokio::sync::oneshot::channel::<()>();
let (release_tx, release_rx) = tokio::sync::oneshot::channel::<()>();
let handle = sup.spawn_blocking_named(
Arc::from("chunk_file"),
Box::new(move || {
let _ = ready_tx.send(());
let _ = release_rx.blocking_recv();
}),
);
ready_rx.await.expect("task should start");
let snapshot = sup.snapshot();
assert!(
snapshot.iter().any(|t| t.name.as_ref() == "chunk_file"),
"chunk_file task must appear in supervisor snapshot"
);
let _ = release_tx.send(());
handle.await.expect("task should complete");
}
#[cfg(feature = "task-metrics")]
#[test]
fn test_measure_blocking_emits_metrics() {
use metrics_util::debugging::DebuggingRecorder;
let recorder = DebuggingRecorder::new();
let snapshotter = recorder.snapshotter();
metrics::with_local_recorder(&recorder, || {
measure_blocking("test_task", || std::hint::black_box(42_u64));
});
let snapshot = snapshotter.snapshot();
let metric_names: Vec<String> = snapshot
.into_vec()
.into_iter()
.map(|(k, _, _, _)| k.key().name().to_owned())
.collect();
assert!(
metric_names.iter().any(|n| n == "zeph.task.wall_time_ms"),
"expected zeph.task.wall_time_ms histogram; got: {metric_names:?}"
);
assert!(
metric_names.iter().any(|n| n == "zeph.task.cpu_time_ms"),
"expected zeph.task.cpu_time_ms histogram; got: {metric_names:?}"
);
}
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
async fn test_spawn_blocking_semaphore_cap() {
let (sup, _cancel) = make_supervisor();
let concurrent = Arc::new(AtomicU32::new(0));
let max_concurrent = Arc::new(AtomicU32::new(0));
let mut handles = Vec::new();
for i in 0u32..16 {
let c = Arc::clone(&concurrent);
let m = Arc::clone(&max_concurrent);
            let name: Arc<str> = Arc::from(format!("blocking-{i}"));
let h = sup.spawn_blocking(name, move || {
            let prev = c.fetch_add(1, Ordering::SeqCst);
            m.fetch_max(prev + 1, Ordering::SeqCst);
std::thread::sleep(std::time::Duration::from_millis(20));
c.fetch_sub(1, Ordering::SeqCst);
});
handles.push(h);
}
for h in handles {
h.join().await.expect("blocking task should succeed");
}
let observed = max_concurrent.load(Ordering::SeqCst);
assert!(
observed <= 8,
"observed {observed} concurrent blocking tasks; expected ≤ 8 (semaphore cap)"
);
}
}