// runledger_runtime/supervisor.rs

1use std::borrow::Borrow;
2use std::future::Future;
3use std::pin::Pin;
4use std::sync::Arc;
5use std::sync::atomic::{AtomicBool, Ordering};
6use std::task::{Context, Poll};
7use std::time::Duration;
8
9use futures_util::stream::{FuturesUnordered, StreamExt};
10use tokio::runtime::Handle;
11use tokio::sync::watch;
12use tokio::task::{AbortHandle, JoinError, JoinHandle};
13use tokio::time::Instant;
14use tracing::{Instrument, debug, error, info_span, warn};
15
16use crate::catalog::JobCatalog;
17use crate::config::JobsConfig;
18use crate::reaper::run_reaper_loop;
19use crate::registry::JobRegistry;
20use crate::scheduler::run_scheduler_loop;
21use crate::worker::run_worker_loop;
22use crate::{Error, Result, RuntimeError, RuntimeLoopExit};
23
/// Task name attached to the worker loop's span and to errors reported for it.
const WORKER_TASK: &'static str = "worker";
/// Task name attached to the scheduler loop's span and to errors reported for it.
const SCHEDULER_TASK: &'static str = "scheduler";
/// Task name attached to the reaper loop's span and to errors reported for it.
const REAPER_TASK: &'static str = "reaper";
/// Upper bound on how long aborted tasks are drained after a shutdown timeout.
const MAX_ABORT_DRAIN_TIMEOUT: Duration = Duration::from_millis(1_000);
28
/// Supervises the Runledger runtime loops spawned for a worker process.
///
/// A supervisor owns the worker, scheduler, and reaper task handles selected by
/// [`SupervisorBuilder`]. Use [`Self::run_until_shutdown`] for a typical worker
/// process that should exit on either an external shutdown signal or an internal
/// runtime task failure.
///
/// Dropping a supervisor requests shutdown and detaches the task handles. Call
/// [`Self::shutdown`] or [`Self::join`] when the owning process needs to observe
/// panics or unexpected task exits.
#[must_use]
pub struct Supervisor {
    // Sending half of the watch channel whose receivers are handed to each
    // spawned loop; `true` is sent once when shutdown is requested.
    shutdown_tx: watch::Sender<bool>,
    // Flag shared with every `SupervisorShutdown` handle; set before the
    // channel is notified so status readers see the request immediately.
    shutdown_requested: Arc<AtomicBool>,
    // Handles of spawned loops that have not yet been taken for joining. The
    // consuming methods empty this vector, which silences the `Drop` warning.
    tasks: Vec<RuntimeTask>,
}
45
/// Builds a [`Supervisor`] with configurable runtime loops.
///
/// Worker, scheduler, and reaper loops are enabled by default. Call
/// [`Self::with_registry`] or [`Self::with_catalog`] before [`Self::build`] when
/// worker or reaper loops remain enabled.
#[must_use]
pub struct SupervisorBuilder<'a> {
    // Borrowed pool; cloned into each enabled loop when `build` spawns it.
    pool: &'a runledger_postgres::DbPool,
    // Tokio runtime handle captured by `Supervisor::builder`; tasks spawn on it.
    runtime: Handle,
    // Handlers supplied through `with_registry` or `with_catalog`, if any.
    registry: Option<JobRegistry>,
    // Which of the two registration methods most recently set `registry`.
    registry_source: Option<RegistrySource>,
    // Set once both registration methods have been called; `build` rejects it.
    mixed_registry_sources: bool,
    // Runtime configuration cloned into each spawned loop.
    config: JobsConfig,
    // Per-loop switches; all start as `true` and are cleared by `disable_*`.
    worker_enabled: bool,
    scheduler_enabled: bool,
    reaper_enabled: bool,
}
63
64/// Cloneable handle for requesting supervisor shutdown from another task.
65#[derive(Clone)]
66pub struct SupervisorShutdown {
67    shutdown_tx: watch::Sender<bool>,
68    shutdown_requested: Arc<AtomicBool>,
69}
70
/// A spawned runtime loop together with the name used when reporting on it.
struct RuntimeTask {
    // Static task name (one of the `*_TASK` constants in production code).
    name: &'static str,
    // Join handle of the spawned loop; resolves to how the loop exited.
    handle: JoinHandle<RuntimeTaskExit>,
}
75
/// Records which builder method supplied the job registry, so that mixing
/// `with_registry` and `with_catalog` on one builder can be detected.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum RegistrySource {
    // Set by `SupervisorBuilder::with_registry`.
    Registry,
    // Set by `SupervisorBuilder::with_catalog`.
    Catalog,
}
81
/// How a supervised task's future finished, mirrored from `RuntimeLoopExit`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum RuntimeTaskExit {
    // The loop returned on its own; `classify_task_result` reports this as
    // `RuntimeError::TaskExitedUnexpectedly`.
    Completed,
    // The loop exited in response to shutdown; treated as a clean exit.
    Shutdown,
}
87
/// Outcome of a bounded attempt to drain the remaining joined tasks.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum DrainResult {
    // Every remaining task was joined before the deadline or timeout.
    Drained,
    // The deadline or timeout elapsed while tasks were still outstanding.
    TimedOut,
}
93
/// Wrapper future that logs when a supervised task is first polled and when it
/// completes, then forwards the inner future's result unchanged.
struct RuntimeTaskFuture {
    // Task name included in the start/completion log events.
    name: &'static str,
    // The boxed loop future being driven.
    future: Pin<Box<dyn Future<Output = RuntimeTaskExit> + Send>>,
    // Whether the "started" event has already been emitted.
    started: bool,
}
99
/// Result of awaiting a supervised task's join handle.
type RuntimeTaskJoinResult = std::result::Result<RuntimeTaskExit, JoinError>;
/// A task name paired with the result of joining that task.
type JoinedRuntimeTask = (&'static str, RuntimeTaskJoinResult);
102
103impl Supervisor {
104    /// Returns a builder for a supervisor over a shared pool and runtime
105    /// configuration.
106    ///
107    /// This validates that the caller is inside the Tokio runtime that will own
108    /// spawned supervisor tasks.
109    pub fn builder(
110        pool: &runledger_postgres::DbPool,
111        config: JobsConfig,
112    ) -> std::result::Result<SupervisorBuilder<'_>, RuntimeError> {
113        let runtime =
114            Handle::try_current().map_err(|source| RuntimeError::MissingTokioRuntime { source })?;
115
116        Ok(SupervisorBuilder {
117            pool,
118            runtime,
119            registry: None,
120            registry_source: None,
121            mixed_registry_sources: false,
122            config,
123            worker_enabled: true,
124            scheduler_enabled: true,
125            reaper_enabled: true,
126        })
127    }
128
129    /// Returns a cloneable shutdown handle that can request shutdown without
130    /// owning the supervisor task joins.
131    #[must_use]
132    pub fn shutdown_handle(&self) -> SupervisorShutdown {
133        SupervisorShutdown {
134            shutdown_tx: self.shutdown_tx.clone(),
135            shutdown_requested: Arc::clone(&self.shutdown_requested),
136        }
137    }
138
139    /// Requests graceful shutdown of all supervised loops.
140    pub fn request_shutdown(&self) {
141        request_shutdown_signal(&self.shutdown_tx, self.shutdown_requested.as_ref());
142    }
143
144    /// Returns whether shutdown has been requested through this supervisor or a
145    /// clone of its shutdown handle.
146    #[must_use]
147    pub fn is_shutdown_requested(&self) -> bool {
148        self.shutdown_requested.load(Ordering::SeqCst)
149    }
150
151    /// Waits for all supervised loops to exit.
152    ///
153    /// With the default long-running loops, this method waits until shutdown is
154    /// requested through a [`SupervisorShutdown`] handle or until a task exits.
155    /// If a loop exits before shutdown was requested, the remaining loops are
156    /// asked to shut down and the first observed error is returned. Additional
157    /// task failures observed while draining are logged. This method does not
158    /// impose a deadline; use [`Self::shutdown_with_timeout`] when the caller
159    /// owns shutdown and needs a bounded wait.
160    pub async fn join(mut self) -> Result<()> {
161        let tasks = std::mem::take(&mut self.tasks);
162        self.join_tasks(tasks).await
163    }
164
165    /// Requests graceful shutdown and waits for all supervised loops to exit.
166    ///
167    /// If a loop exits before shutdown was requested, the remaining loops are
168    /// asked to shut down and the pre-existing task exit is reported, even when
169    /// that exit is only observed after shutdown begins. This method does not
170    /// impose a deadline. Use [`Self::shutdown_with_timeout`] when the owning
171    /// process needs a shutdown budget; externally timing out this consuming
172    /// future can detach still-running task handles.
173    pub async fn shutdown(mut self) -> Result<()> {
174        if let Some(error) = join_pre_shutdown_finished_tasks(&mut self.tasks).await {
175            self.request_shutdown();
176            let tasks = std::mem::take(&mut self.tasks);
177            drain_tasks(tasks).await;
178            return Err(Error::Runtime(error));
179        }
180
181        self.request_shutdown();
182        let tasks = std::mem::take(&mut self.tasks);
183        self.join_tasks(tasks).await
184    }
185
    /// Waits until `shutdown` resolves or a supervised task fails, then exits.
    ///
    /// If `shutdown` resolves first, graceful shutdown is requested and the
    /// supervisor waits up to `timeout` for all loops to exit. If a loop panics
    /// or exits unexpectedly before `shutdown` resolves, shutdown is requested
    /// for the remaining loops and the original task error is returned after
    /// those loops drain or a timeout is reported. If shutdown is requested
    /// through a [`SupervisorShutdown`] handle and every loop exits cleanly before
    /// `shutdown` resolves, this returns successfully.
    ///
    /// This is the preferred method for worker binaries because it observes
    /// internal task failures during normal operation while still applying a
    /// bounded shutdown budget to cooperative process termination.
    ///
    /// If `timeout` is too large to represent as a runtime deadline, this returns
    /// [`RuntimeError::ShutdownTimeoutTooLarge`] immediately. A zero timeout
    /// requests shutdown, aborts tasks without waiting for cooperative exits, and
    /// reports [`RuntimeError::ShutdownTimeout`].
    ///
    /// If the initial timeout validation fails before `shutdown` resolves, the
    /// supervisor is still dropped, so shutdown is requested, but task handles
    /// are not aborted or drained. If a deadline overflow is detected after
    /// shutdown begins, remaining tasks are aborted and drained before returning.
    ///
    /// # Errors
    ///
    /// - [`RuntimeError::ShutdownTimeoutTooLarge`] when `timeout` cannot be
    ///   added to the current instant.
    /// - [`RuntimeError::ShutdownTimeout`] when the loops do not exit within
    ///   `timeout` after `shutdown` resolves.
    /// - The task's own error when a loop fails and the remaining loops drain
    ///   in time, or [`RuntimeError::ShutdownTimeoutAfterTaskError`] wrapping
    ///   it when they do not.
    pub async fn run_until_shutdown<F>(mut self, shutdown: F, timeout: Duration) -> Result<()>
    where
        F: Future<Output = ()>,
    {
        // Validate the shutdown budget before waiting on a signal, then
        // recompute the deadline when shutdown actually begins.
        let _ = shutdown_deadline(timeout)?;
        let tasks = std::mem::take(&mut self.tasks);
        // With no supervised tasks there is nothing to join: just wait for the
        // external signal and record the shutdown request.
        if tasks.is_empty() {
            shutdown.await;
            self.request_shutdown();
            return Ok(());
        }

        // Abort handles are collected up front because `join_runtime_tasks`
        // consumes the join handles. The `Option` lets whichever select arm
        // finishes the method take ownership exactly once.
        let mut abort_handles = Some(task_abort_handles(&tasks));
        let mut joined = join_runtime_tasks(tasks);
        let mut shutdown = std::pin::pin!(shutdown);

        loop {
            tokio::select! {
                // External shutdown signal: signal the loops and wait for them
                // within the budget.
                _ = shutdown.as_mut() => {
                    self.request_shutdown();
                    let abort_handles = abort_handles.take().expect("abort handles are consumed on return");
                    // The budget was validated before waiting; recompute here so
                    // the timeout starts when shutdown begins, and abort instead
                    // of detaching if this late recompute somehow overflows.
                    let deadline = match shutdown_deadline(timeout) {
                        Ok(deadline) => deadline,
                        Err(error) => {
                            abort_and_drain_joined_tasks_or_log(
                                &mut joined,
                                abort_handles,
                                abort_drain_timeout(timeout),
                            )
                            .await;
                            return Err(error.into());
                        }
                    };
                    return self
                        .join_joined_tasks_with_timeout(
                            &mut joined,
                            abort_handles,
                            timeout,
                            deadline,
                        )
                        .await;
                }
                // A supervised task finished before the external signal.
                joined_result = joined.next() => {
                    // `None` means every task has been joined without error.
                    let Some((task, result)) = joined_result else {
                        return Ok(());
                    };
                    // A clean shutdown exit is not a failure; keep waiting on
                    // the remaining tasks and the external signal.
                    let Some(error) = classify_task_result(task, result) else {
                        continue;
                    };

                    self.request_shutdown();
                    let abort_handles = abort_handles.take().expect("abort handles are consumed on return");
                    // Start the drain budget when shutdown begins. If this
                    // pathological recompute overflows, abort already-running
                    // tasks rather than dropping their handles detached.
                    let deadline = match shutdown_deadline(timeout) {
                        Ok(deadline) => deadline,
                        Err(error) => {
                            abort_and_drain_joined_tasks_or_log(
                                &mut joined,
                                abort_handles,
                                abort_drain_timeout(timeout),
                            )
                            .await;
                            return Err(error.into());
                        }
                    };
                    return drain_after_task_error_with_timeout(
                        &mut joined,
                        abort_handles,
                        timeout,
                        deadline,
                        error,
                    )
                    .await;
                }
            }
        }
    }
293
294    /// Requests graceful shutdown and waits up to `timeout` for all supervised
295    /// loops to exit.
296    ///
297    /// If a loop had already exited before this method begins shutdown, that
298    /// failure is returned after the remaining loops have had the same shutdown
299    /// budget to exit cooperatively. If the timeout expires, remaining tasks are
300    /// aborted and drained with a bounded cleanup attempt before a timeout error
301    /// is returned. Abort cleanup can make total wall-clock time exceed `timeout`
302    /// by up to one second, or `timeout`, whichever is smaller. A zero timeout
303    /// requests shutdown, immediately aborts tasks that did not already finish,
304    /// and reports [`RuntimeError::ShutdownTimeout`].
305    ///
306    /// If `timeout` is too large to represent as a runtime deadline, this returns
307    /// [`RuntimeError::ShutdownTimeoutTooLarge`] immediately. The supervisor is
308    /// still dropped, so shutdown is requested, but task handles are not aborted
309    /// or drained.
310    pub async fn shutdown_with_timeout(mut self, timeout: Duration) -> Result<()> {
311        let deadline = shutdown_deadline(timeout)?;
312
313        if let Some(error) = join_pre_shutdown_finished_tasks(&mut self.tasks).await {
314            self.request_shutdown();
315            let tasks = std::mem::take(&mut self.tasks);
316            let abort_handles = task_abort_handles(&tasks);
317            let mut joined = join_runtime_tasks(tasks);
318
319            return drain_after_task_error_with_timeout(
320                &mut joined,
321                abort_handles,
322                timeout,
323                deadline,
324                error,
325            )
326            .await;
327        }
328
329        self.request_shutdown();
330        let tasks = std::mem::take(&mut self.tasks);
331        self.join_tasks_with_timeout(tasks, timeout, deadline).await
332    }
333
334    async fn join_tasks(&self, tasks: Vec<RuntimeTask>) -> Result<()> {
335        let mut joined = join_runtime_tasks(tasks);
336
337        while let Some((task, result)) = joined.next().await {
338            if let Some(error) = classify_task_result(task, result) {
339                self.request_shutdown();
340                drain_joined_tasks(&mut joined).await;
341                return Err(Error::Runtime(error));
342            }
343        }
344
345        Ok(())
346    }
347
348    async fn join_tasks_with_timeout(
349        &self,
350        tasks: Vec<RuntimeTask>,
351        timeout: Duration,
352        deadline: Instant,
353    ) -> Result<()> {
354        let abort_handles = task_abort_handles(&tasks);
355        let mut joined = join_runtime_tasks(tasks);
356
357        self.join_joined_tasks_with_timeout(&mut joined, abort_handles, timeout, deadline)
358            .await
359    }
360
361    async fn join_joined_tasks_with_timeout(
362        &self,
363        joined: &mut FuturesUnordered<impl Future<Output = JoinedRuntimeTask>>,
364        abort_handles: Vec<AbortHandle>,
365        timeout: Duration,
366        deadline: Instant,
367    ) -> Result<()> {
368        loop {
369            match tokio::time::timeout_at(deadline, joined.next()).await {
370                Ok(Some((task, result))) => {
371                    if let Some(error) = classify_task_result(task, result) {
372                        self.request_shutdown();
373                        return drain_after_task_error_with_timeout(
374                            joined,
375                            abort_handles,
376                            timeout,
377                            deadline,
378                            error,
379                        )
380                        .await;
381                    }
382                }
383                Ok(None) => return Ok(()),
384                Err(_) => {
385                    abort_and_drain_joined_tasks_or_log(
386                        joined,
387                        abort_handles,
388                        abort_drain_timeout(timeout),
389                    )
390                    .await;
391                    return Err(Error::Runtime(RuntimeError::ShutdownTimeout { timeout }));
392                }
393            }
394        }
395    }
396
397    #[cfg(test)]
398    fn from_tasks_for_tests(tasks: Vec<RuntimeTask>) -> Self {
399        // These synthetic tasks do not receive this channel; tests using this
400        // helper either finish without shutdown or exercise timeout abort paths.
401        let (shutdown_tx, _) = watch::channel(false);
402        Self {
403            shutdown_tx,
404            shutdown_requested: Arc::new(AtomicBool::new(false)),
405            tasks,
406        }
407    }
408}
409
410impl Drop for Supervisor {
411    fn drop(&mut self) {
412        if !self.tasks.is_empty() {
413            warn!(
414                task_count = self.tasks.len(),
415                "dropping jobs runtime supervisor before joining tasks; tasks may continue detached after shutdown is requested and later panics will not be observed"
416            );
417        }
418        // Drop cannot await task handles, so this only nudges loops to exit.
419        self.request_shutdown();
420    }
421}
422
423impl<'a> SupervisorBuilder<'a> {
424    /// Registers the handlers used by worker execution and reaper terminal hooks.
425    ///
426    /// A registry is required when worker or reaper loops are enabled. Scheduler-only
427    /// supervisors can be built without one.
428    #[must_use = "builder methods return an updated builder value"]
429    pub fn with_registry(mut self, registry: JobRegistry) -> Self {
430        self.mixed_registry_sources |= self.registry_source == Some(RegistrySource::Catalog);
431        self.registry_source = Some(RegistrySource::Registry);
432        self.registry = Some(registry);
433        self
434    }
435
436    /// Registers handlers from a [`JobCatalog`].
437    ///
438    /// This does not sync database job definitions. Call
439    /// [`JobCatalog::sync_definitions`] before starting the supervisor or
440    /// creating schedules. Pass `&catalog` when the caller will continue using
441    /// the catalog for schedule, enqueue, or workflow helpers after building the
442    /// supervisor.
443    ///
444    /// # Registry Source
445    ///
446    /// Calling this and [`Self::with_registry`] on the same builder is rejected
447    /// by [`Self::build`]. Choose one registration source per builder.
448    #[must_use = "builder methods return an updated builder value"]
449    pub fn with_catalog(mut self, catalog: impl Borrow<JobCatalog>) -> Self {
450        self.mixed_registry_sources |= self.registry_source == Some(RegistrySource::Registry);
451        self.registry_source = Some(RegistrySource::Catalog);
452        self.registry = Some(catalog.borrow().to_registry());
453        self
454    }
455
456    /// Disables worker job claiming and execution for this supervisor.
457    #[must_use = "builder methods return an updated builder value"]
458    pub fn disable_worker(mut self) -> Self {
459        self.worker_enabled = false;
460        self
461    }
462
463    /// Disables cron schedule materialization for this supervisor.
464    #[must_use = "builder methods return an updated builder value"]
465    pub fn disable_scheduler(mut self) -> Self {
466        self.scheduler_enabled = false;
467        self
468    }
469
470    /// Disables expired-lease reaping for this supervisor.
471    #[must_use = "builder methods return an updated builder value"]
472    pub fn disable_reaper(mut self) -> Self {
473        self.reaper_enabled = false;
474        self
475    }
476
477    /// Starts the enabled runtime loops and returns the owning supervisor.
478    ///
479    /// Returns an error when worker or reaper loops are enabled without a job
480    /// registry.
481    pub fn build(self) -> std::result::Result<Supervisor, RuntimeError> {
482        let Self {
483            pool,
484            runtime,
485            registry,
486            registry_source: _,
487            mixed_registry_sources,
488            config,
489            worker_enabled,
490            scheduler_enabled,
491            reaper_enabled,
492        } = self;
493
494        if mixed_registry_sources {
495            return Err(RuntimeError::MixedRegistrySources);
496        }
497
498        let registry = match registry {
499            Some(registry) => registry,
500            None if worker_enabled || reaper_enabled => {
501                return Err(RuntimeError::MissingRegistry {
502                    worker_enabled,
503                    reaper_enabled,
504                });
505            }
506            None => JobRegistry::new(),
507        };
508
509        let (shutdown_tx, shutdown_rx) = watch::channel(false);
510        let shutdown_requested = Arc::new(AtomicBool::new(false));
511        let mut tasks = Vec::new();
512
513        if worker_enabled {
514            tasks.push(RuntimeTask::spawn_on(&runtime, WORKER_TASK, {
515                let pool = pool.clone();
516                let registry = registry.clone();
517                let config = config.clone();
518                let shutdown_rx = shutdown_rx.clone();
519                async move { run_worker_loop(pool, registry, config, shutdown_rx).await }
520            }));
521        }
522
523        if scheduler_enabled {
524            tasks.push(RuntimeTask::spawn_on(&runtime, SCHEDULER_TASK, {
525                let pool = pool.clone();
526                let config = config.clone();
527                let shutdown_rx = shutdown_rx.clone();
528                async move { run_scheduler_loop(pool, config, shutdown_rx).await }
529            }));
530        }
531
532        if reaper_enabled {
533            let pool = pool.clone();
534            let registry = registry.clone();
535            let config = config.clone();
536            let shutdown_rx = shutdown_rx.clone();
537            tasks.push(RuntimeTask::spawn_on(&runtime, REAPER_TASK, async move {
538                run_reaper_loop(pool, registry, config, shutdown_rx).await
539            }));
540        }
541
542        Ok(Supervisor {
543            shutdown_tx,
544            shutdown_requested,
545            tasks,
546        })
547    }
548}
549
550impl SupervisorShutdown {
551    /// Requests graceful shutdown of all loops watched by the supervisor.
552    pub fn request_shutdown(&self) {
553        request_shutdown_signal(&self.shutdown_tx, self.shutdown_requested.as_ref());
554    }
555
556    /// Returns whether shutdown has been requested.
557    #[must_use]
558    pub fn is_shutdown_requested(&self) -> bool {
559        self.shutdown_requested.load(Ordering::SeqCst)
560    }
561}
562
563impl RuntimeTask {
564    fn spawn_on<F>(runtime: &Handle, name: &'static str, future: F) -> Self
565    where
566        F: Future<Output = RuntimeLoopExit> + Send + 'static,
567    {
568        let span = info_span!("runledger_runtime_supervisor_task", task = name);
569        Self {
570            name,
571            handle: runtime.spawn(
572                RuntimeTaskFuture::new(name, async move { future.await.into() }).instrument(span),
573            ),
574        }
575    }
576
577    #[cfg(test)]
578    fn spawn<F>(name: &'static str, future: F) -> Self
579    where
580        F: Future<Output = RuntimeTaskExit> + Send + 'static,
581    {
582        Self {
583            name,
584            handle: tokio::spawn(RuntimeTaskFuture::new(name, future)),
585        }
586    }
587
588    async fn await_result(self) -> RuntimeTaskJoinResult {
589        self.handle.await
590    }
591}
592
593impl RuntimeTaskFuture {
594    fn new<F>(name: &'static str, future: F) -> Self
595    where
596        F: Future<Output = RuntimeTaskExit> + Send + 'static,
597    {
598        Self {
599            name,
600            future: Box::pin(future),
601            started: false,
602        }
603    }
604}
605
606impl Future for RuntimeTaskFuture {
607    type Output = RuntimeTaskExit;
608
609    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
610        let task = self.as_mut().get_mut();
611        if !task.started {
612            task.started = true;
613            debug!(task = task.name, "supervised runtime task started");
614        }
615
616        match task.future.as_mut().poll(cx) {
617            Poll::Pending => Poll::Pending,
618            Poll::Ready(exit) => {
619                debug!(task = task.name, ?exit, "supervised runtime task completed");
620                Poll::Ready(exit)
621            }
622        }
623    }
624}
625
626impl From<RuntimeLoopExit> for RuntimeTaskExit {
627    fn from(exit: RuntimeLoopExit) -> Self {
628        match exit {
629            RuntimeLoopExit::Shutdown => Self::Shutdown,
630            RuntimeLoopExit::Completed => Self::Completed,
631        }
632    }
633}
634
635fn request_shutdown_signal(shutdown_tx: &watch::Sender<bool>, shutdown_requested: &AtomicBool) {
636    if !shutdown_requested.swap(true, Ordering::SeqCst) {
637        // Mark the observable flag before notifying receivers so status readers
638        // agree that shutdown has begun as soon as request_shutdown returns.
639        // Safe to ignore: send failure means every shutdown receiver was dropped,
640        // so there are no supervised loops left to notify.
641        let _ = shutdown_tx.send(true);
642    }
643}
644
645fn take_finished_tasks(tasks: &mut Vec<RuntimeTask>) -> Vec<RuntimeTask> {
646    let mut finished = Vec::new();
647    let mut index = 0;
648    while index < tasks.len() {
649        if tasks[index].handle.is_finished() {
650            // Preserving order is not required for draining, and swap_remove
651            // avoids shifting still-running task handles.
652            finished.push(tasks.swap_remove(index));
653        } else {
654            index += 1;
655        }
656    }
657    finished
658}
659
660async fn join_pre_shutdown_finished_tasks(tasks: &mut Vec<RuntimeTask>) -> Option<RuntimeError> {
661    let finished = take_finished_tasks(tasks);
662    let mut first_error = None;
663
664    for task in finished {
665        let task_name = task.name;
666        let result = task.await_result().await;
667        let Some(error) = classify_task_result(task_name, result) else {
668            continue;
669        };
670
671        if first_error.is_none() {
672            first_error = Some(error);
673        } else {
674            log_drained_task_error(error);
675        }
676    }
677
678    first_error
679}
680
681fn join_runtime_tasks(
682    tasks: Vec<RuntimeTask>,
683) -> FuturesUnordered<impl Future<Output = JoinedRuntimeTask>> {
684    tasks
685        .into_iter()
686        .map(|task| async move {
687            let name = task.name;
688            (name, task.await_result().await)
689        })
690        .collect()
691}
692
693async fn drain_tasks(tasks: Vec<RuntimeTask>) {
694    let mut joined = join_runtime_tasks(tasks);
695    drain_joined_tasks(&mut joined).await;
696}
697
698fn task_abort_handles(tasks: &[RuntimeTask]) -> Vec<AbortHandle> {
699    tasks
700        .iter()
701        .map(|task| task.handle.abort_handle())
702        .collect()
703}
704
705fn shutdown_deadline(timeout: Duration) -> std::result::Result<Instant, RuntimeError> {
706    Instant::now()
707        .checked_add(timeout)
708        .ok_or(RuntimeError::ShutdownTimeoutTooLarge { timeout })
709}
710
711fn abort_drain_timeout(timeout: Duration) -> Duration {
712    timeout.min(MAX_ABORT_DRAIN_TIMEOUT)
713}
714
715async fn drain_joined_tasks(
716    joined: &mut FuturesUnordered<impl Future<Output = JoinedRuntimeTask>>,
717) {
718    while let Some((task, result)) = joined.next().await {
719        if let Some(error) = classify_task_result(task, result) {
720            log_drained_task_error(error);
721        }
722    }
723}
724
725async fn drain_joined_tasks_until_deadline(
726    joined: &mut FuturesUnordered<impl Future<Output = JoinedRuntimeTask>>,
727    deadline: Instant,
728) -> DrainResult {
729    loop {
730        match tokio::time::timeout_at(deadline, joined.next()).await {
731            Ok(Some((task, result))) => {
732                if let Some(error) = classify_task_result(task, result) {
733                    log_drained_task_error(error);
734                }
735            }
736            Ok(None) => return DrainResult::Drained,
737            Err(_) => return DrainResult::TimedOut,
738        }
739    }
740}
741
742async fn drain_after_task_error_with_timeout(
743    joined: &mut FuturesUnordered<impl Future<Output = JoinedRuntimeTask>>,
744    abort_handles: Vec<AbortHandle>,
745    timeout: Duration,
746    deadline: Instant,
747    error: RuntimeError,
748) -> Result<()> {
749    if matches!(
750        drain_joined_tasks_until_deadline(joined, deadline).await,
751        DrainResult::Drained
752    ) {
753        return Err(error.into());
754    }
755
756    abort_and_drain_joined_tasks_or_log(joined, abort_handles, abort_drain_timeout(timeout)).await;
757    Err(RuntimeError::ShutdownTimeoutAfterTaskError {
758        timeout,
759        source: Box::new(error),
760    }
761    .into())
762}
763
764async fn abort_and_drain_joined_tasks_with_timeout(
765    joined: &mut FuturesUnordered<impl Future<Output = JoinedRuntimeTask>>,
766    abort_handles: Vec<AbortHandle>,
767    timeout: Duration,
768) -> DrainResult {
769    for abort_handle in abort_handles {
770        abort_handle.abort();
771    }
772
773    match tokio::time::timeout(timeout, drain_aborted_joined_tasks(joined)).await {
774        Ok(()) => DrainResult::Drained,
775        Err(_) => DrainResult::TimedOut,
776    }
777}
778
779async fn abort_and_drain_joined_tasks_or_log(
780    joined: &mut FuturesUnordered<impl Future<Output = JoinedRuntimeTask>>,
781    abort_handles: Vec<AbortHandle>,
782    timeout: Duration,
783) {
784    if matches!(
785        abort_and_drain_joined_tasks_with_timeout(joined, abort_handles, timeout).await,
786        DrainResult::TimedOut
787    ) {
788        log_abort_drain_timeout(timeout);
789    }
790}
791
792async fn drain_aborted_joined_tasks(
793    joined: &mut FuturesUnordered<impl Future<Output = JoinedRuntimeTask>>,
794) {
795    while let Some((task, result)) = joined.next().await {
796        match result {
797            Ok(_) => {}
798            Err(source) if source.is_cancelled() => {
799                // Cancellation is expected after this helper aborts the
800                // remaining tasks; the caller returns the timeout or earlier task
801                // failure that triggered the abort.
802            }
803            Err(source) => {
804                log_drained_task_error(RuntimeError::TaskJoin { task, source });
805            }
806        }
807    }
808}
809
810fn log_drained_task_error(error: RuntimeError) {
811    error!(
812        %error,
813        "supervised runtime task failed while draining after an earlier failure"
814    );
815}
816
817fn log_abort_drain_timeout(timeout: Duration) {
818    warn!(
819        ?timeout,
820        "timed out draining aborted supervisor tasks; later task failures may be unobserved"
821    );
822}
823
824fn classify_task_result(task: &'static str, result: RuntimeTaskJoinResult) -> Option<RuntimeError> {
825    match result {
826        Ok(RuntimeTaskExit::Shutdown) => {
827            debug!(task, "supervised runtime task joined after shutdown");
828            None
829        }
830        Ok(RuntimeTaskExit::Completed) => {
831            debug!(task, "supervised runtime task exited before shutdown");
832            Some(RuntimeError::TaskExitedUnexpectedly { task })
833        }
834        Err(source) => {
835            debug!(
836                task,
837                is_cancelled = source.is_cancelled(),
838                is_panic = source.is_panic(),
839                "supervised runtime task join failed"
840            );
841            Some(RuntimeError::TaskJoin { task, source })
842        }
843    }
844}
845
846#[cfg(test)]
847mod tests {
848    use std::sync::Arc;
849    use std::sync::atomic::Ordering;
850    use std::time::Duration;
851
852    use sqlx::postgres::PgPoolOptions;
853    use tokio::time::timeout;
854
855    use super::*;
856    use crate::Error;
857
858    const UNUSED_LAZY_POOL_URL: &str = "postgres://postgres:postgres@127.0.0.1:65535/runledger";
859
860    struct DropFlag(Arc<AtomicBool>);
861
862    impl Drop for DropFlag {
863        fn drop(&mut self) {
864            self.0.store(true, Ordering::SeqCst);
865        }
866    }
867
868    struct CompleteAfterPollSignal {
869        entered_tx: Option<std::sync::mpsc::Sender<()>>,
870        release_rx: std::sync::mpsc::Receiver<()>,
871        exit: RuntimeTaskExit,
872    }
873
874    impl Future for CompleteAfterPollSignal {
875        type Output = RuntimeTaskExit;
876
877        fn poll(mut self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Self::Output> {
878            let task = self.as_mut().get_mut();
879            if let Some(entered_tx) = task.entered_tx.take() {
880                entered_tx
881                    .send(())
882                    .expect("completion poll entry signal should be received");
883            }
884            task.release_rx
885                .recv()
886                .expect("completion poll should be released");
887            Poll::Ready(task.exit)
888        }
889    }
890
891    fn lazy_pool() -> runledger_postgres::DbPool {
892        PgPoolOptions::new()
893            // The disable-only tests never acquire this pool; this URL is only
894            // a valid PgPool value for supervisor wiring assertions.
895            .connect_lazy(UNUSED_LAZY_POOL_URL)
896            .expect("construct lazy pool")
897    }
898
899    fn test_config() -> JobsConfig {
900        JobsConfig {
901            worker_id: "supervisor-test-worker".to_string(),
902            poll_interval: Duration::from_millis(25),
903            claim_batch_size: 4,
904            lease_ttl_seconds: 10,
905            max_global_concurrency: 4,
906            reaper_interval: Duration::from_millis(50),
907            schedule_poll_interval: Duration::from_millis(50),
908            reaper_retry_delay_ms: 1_000,
909        }
910    }
911
912    fn empty_builder(pool: &runledger_postgres::DbPool) -> SupervisorBuilder<'_> {
913        Supervisor::builder(pool, test_config()).expect("supervisor builder has runtime")
914    }
915
916    fn missing_registry_flags(builder: SupervisorBuilder<'_>) -> (bool, bool) {
917        match builder.build() {
918            Err(RuntimeError::MissingRegistry {
919                worker_enabled,
920                reaper_enabled,
921            }) => (worker_enabled, reaper_enabled),
922            Ok(_) => panic!("missing registry should be a build error"),
923            Err(other) => panic!("expected missing registry error, got {other:?}"),
924        }
925    }
926
927    fn test_task<F>(name: &'static str, future: F) -> RuntimeTask
928    where
929        F: Future<Output = ()> + Send + 'static,
930    {
931        test_task_with_exit(name, RuntimeTaskExit::Completed, future)
932    }
933
934    fn test_shutdown_task<F>(name: &'static str, future: F) -> RuntimeTask
935    where
936        F: Future<Output = ()> + Send + 'static,
937    {
938        test_task_with_exit(name, RuntimeTaskExit::Shutdown, future)
939    }
940
941    fn test_task_with_exit<F>(name: &'static str, exit: RuntimeTaskExit, future: F) -> RuntimeTask
942    where
943        F: Future<Output = ()> + Send + 'static,
944    {
945        RuntimeTask::spawn(name, async move {
946            future.await;
947            exit
948        })
949    }
950
951    fn supervisor_from_shutdown_channel(
952        shutdown_tx: watch::Sender<bool>,
953        shutdown_requested: Arc<AtomicBool>,
954        tasks: Vec<RuntimeTask>,
955    ) -> Supervisor {
956        Supervisor {
957            shutdown_tx,
958            shutdown_requested,
959            tasks,
960        }
961    }
962
963    fn task_names(supervisor: &Supervisor) -> Vec<&'static str> {
964        supervisor.tasks.iter().map(|task| task.name).collect()
965    }
966
967    async fn abort_supervisor_tasks(mut supervisor: Supervisor) {
968        let tasks = std::mem::take(&mut supervisor.tasks);
969        for task in tasks {
970            task.handle.abort();
971            let _ = task.handle.await;
972        }
973    }
974
975    #[tokio::test]
976    async fn builder_defaults_enable_all_loops() {
977        let pool = lazy_pool();
978        let builder = empty_builder(&pool);
979
980        assert!(builder.worker_enabled);
981        assert!(builder.scheduler_enabled);
982        assert!(builder.reaper_enabled);
983        assert!(builder.registry.is_none());
984        assert_eq!(builder.registry_source, None);
985        assert!(!builder.mixed_registry_sources);
986    }
987
988    #[tokio::test]
989    async fn builder_accepts_registry_for_worker_and_reaper_loops() {
990        let pool = lazy_pool();
991        let builder = empty_builder(&pool).with_registry(JobRegistry::new());
992
993        assert!(builder.registry.is_some());
994        assert_eq!(builder.registry_source, Some(RegistrySource::Registry));
995        assert!(!builder.mixed_registry_sources);
996    }
997
998    #[tokio::test]
999    async fn builder_rejects_mixed_registry_sources() {
1000        let pool = lazy_pool();
1001        let registry_then_catalog = empty_builder(&pool)
1002            .with_registry(JobRegistry::new())
1003            .with_catalog(JobCatalog::new())
1004            .disable_worker()
1005            .disable_reaper()
1006            .build();
1007        let Err(registry_then_catalog) = registry_then_catalog else {
1008            panic!("mixed registry sources should be rejected");
1009        };
1010        assert!(matches!(
1011            registry_then_catalog,
1012            RuntimeError::MixedRegistrySources
1013        ));
1014
1015        let catalog_then_registry = empty_builder(&pool)
1016            .with_catalog(JobCatalog::new())
1017            .with_registry(JobRegistry::new())
1018            .disable_worker()
1019            .disable_reaper()
1020            .build();
1021        let Err(catalog_then_registry) = catalog_then_registry else {
1022            panic!("mixed registry sources should be rejected");
1023        };
1024        assert!(matches!(
1025            catalog_then_registry,
1026            RuntimeError::MixedRegistrySources
1027        ));
1028    }
1029
1030    #[tokio::test]
1031    async fn builder_requires_registry_when_worker_or_reaper_is_enabled() {
1032        let pool = lazy_pool();
1033
1034        assert_eq!(missing_registry_flags(empty_builder(&pool)), (true, true));
1035        assert_eq!(
1036            missing_registry_flags(empty_builder(&pool).disable_scheduler().disable_reaper()),
1037            (true, false)
1038        );
1039        assert_eq!(
1040            missing_registry_flags(empty_builder(&pool).disable_worker().disable_scheduler()),
1041            (false, true)
1042        );
1043    }
1044
1045    #[test]
1046    fn builder_requires_tokio_runtime_before_cloning_pool() {
1047        let runtime = tokio::runtime::Runtime::new().expect("construct Tokio runtime");
1048        let pool = runtime.block_on(async { lazy_pool() });
1049        let error = match Supervisor::builder(&pool, test_config()) {
1050            Err(error) => error,
1051            Ok(builder) => {
1052                drop(builder);
1053                runtime.block_on(async {
1054                    pool.close().await;
1055                });
1056                std::mem::forget(pool);
1057                panic!("missing Tokio runtime should be a builder error");
1058            }
1059        };
1060
1061        // The builder was intentionally called outside a runtime to exercise
1062        // the pre-clone runtime check. Close and drop the pool inside the
1063        // temporary runtime so sqlx's own drop precondition does not contaminate
1064        // this assertion.
1065        runtime.block_on(async {
1066            pool.close().await;
1067        });
1068        std::mem::forget(pool);
1069        match error {
1070            RuntimeError::MissingTokioRuntime { .. } => {}
1071            other => panic!("expected missing Tokio runtime error, got {other:?}"),
1072        }
1073    }
1074
1075    #[tokio::test]
1076    async fn builder_can_disable_each_loop() {
1077        let pool = lazy_pool();
1078        let builder = empty_builder(&pool)
1079            .disable_worker()
1080            .disable_scheduler()
1081            .disable_reaper();
1082
1083        assert!(!builder.worker_enabled);
1084        assert!(!builder.scheduler_enabled);
1085        assert!(!builder.reaper_enabled);
1086    }
1087
1088    #[tokio::test]
1089    async fn builder_spawns_only_enabled_tasks() {
1090        let pool = lazy_pool();
1091
1092        let all_disabled = empty_builder(&pool)
1093            .disable_worker()
1094            .disable_scheduler()
1095            .disable_reaper()
1096            .build()
1097            .expect("all-disabled supervisor should build");
1098        assert_eq!(task_names(&all_disabled), Vec::<&'static str>::new());
1099        abort_supervisor_tasks(all_disabled).await;
1100
1101        let scheduler_only = empty_builder(&pool)
1102            .disable_worker()
1103            .disable_reaper()
1104            .build()
1105            .expect("scheduler-only supervisor should not require registry");
1106        assert_eq!(task_names(&scheduler_only), vec![SCHEDULER_TASK]);
1107        abort_supervisor_tasks(scheduler_only).await;
1108
1109        let worker_only = empty_builder(&pool)
1110            .with_registry(JobRegistry::new())
1111            .disable_scheduler()
1112            .disable_reaper()
1113            .build()
1114            .expect("worker-only supervisor should build with registry");
1115        assert_eq!(task_names(&worker_only), vec![WORKER_TASK]);
1116        abort_supervisor_tasks(worker_only).await;
1117
1118        let reaper_only = empty_builder(&pool)
1119            .with_registry(JobRegistry::new())
1120            .disable_worker()
1121            .disable_scheduler()
1122            .build()
1123            .expect("reaper-only supervisor should build with registry");
1124        assert_eq!(task_names(&reaper_only), vec![REAPER_TASK]);
1125        abort_supervisor_tasks(reaper_only).await;
1126
1127        let all_enabled = empty_builder(&pool)
1128            .with_registry(JobRegistry::new())
1129            .build()
1130            .expect("all-enabled supervisor should build with registry");
1131        assert_eq!(
1132            task_names(&all_enabled),
1133            vec![WORKER_TASK, SCHEDULER_TASK, REAPER_TASK]
1134        );
1135        abort_supervisor_tasks(all_enabled).await;
1136    }
1137
1138    #[tokio::test]
1139    async fn all_disabled_supervisor_join_and_shutdown_succeed() {
1140        Supervisor::builder(&lazy_pool(), test_config())
1141            .expect("supervisor builder has runtime")
1142            .disable_worker()
1143            .disable_scheduler()
1144            .disable_reaper()
1145            .build()
1146            .expect("all-disabled supervisor should build")
1147            .join()
1148            .await
1149            .expect("all-disabled supervisor should join");
1150
1151        Supervisor::builder(&lazy_pool(), test_config())
1152            .expect("supervisor builder has runtime")
1153            .disable_worker()
1154            .disable_scheduler()
1155            .disable_reaper()
1156            .build()
1157            .expect("all-disabled supervisor should build")
1158            .shutdown()
1159            .await
1160            .expect("all-disabled supervisor should shut down");
1161    }
1162
1163    #[tokio::test]
1164    async fn shutdown_handle_can_request_shutdown_before_join() {
1165        let supervisor = Supervisor::builder(&lazy_pool(), test_config())
1166            .expect("supervisor builder has runtime")
1167            .disable_worker()
1168            .disable_scheduler()
1169            .disable_reaper()
1170            .build()
1171            .expect("all-disabled supervisor should build");
1172        let shutdown = supervisor.shutdown_handle();
1173        let cloned_shutdown = shutdown.clone();
1174
1175        cloned_shutdown.request_shutdown();
1176
1177        assert!(shutdown.is_shutdown_requested());
1178        assert!(supervisor.is_shutdown_requested());
1179        supervisor
1180            .join()
1181            .await
1182            .expect("supervisor should join after shutdown handle request");
1183    }
1184
1185    #[tokio::test]
1186    async fn shutdown_after_shutdown_handle_request_allows_clean_task_exit() {
1187        let (shutdown_tx, mut shutdown_rx) = watch::channel(false);
1188        let shutdown_requested = Arc::new(AtomicBool::new(false));
1189        let supervisor = supervisor_from_shutdown_channel(
1190            shutdown_tx,
1191            Arc::clone(&shutdown_requested),
1192            vec![test_shutdown_task("cooperative-loop", async move {
1193                while !*shutdown_rx.borrow() {
1194                    if shutdown_rx.changed().await.is_err() {
1195                        break;
1196                    }
1197                }
1198            })],
1199        );
1200        let shutdown = supervisor.shutdown_handle();
1201
1202        shutdown.request_shutdown();
1203
1204        supervisor
1205            .shutdown()
1206            .await
1207            .expect("clean exit after requested shutdown should succeed");
1208    }
1209
1210    #[tokio::test]
1211    async fn run_until_shutdown_requests_shutdown_when_signal_resolves() {
1212        let (shutdown_tx, mut shutdown_rx) = watch::channel(false);
1213        let shutdown_requested = Arc::new(AtomicBool::new(false));
1214        let (signal_tx, signal_rx) = tokio::sync::oneshot::channel();
1215        let supervisor = supervisor_from_shutdown_channel(
1216            shutdown_tx,
1217            Arc::clone(&shutdown_requested),
1218            vec![test_shutdown_task(
1219                "run-until-cooperative-loop",
1220                async move {
1221                    while !*shutdown_rx.borrow() {
1222                        if shutdown_rx.changed().await.is_err() {
1223                            break;
1224                        }
1225                    }
1226                },
1227            )],
1228        );
1229
1230        signal_tx.send(()).expect("signal receiver should be alive");
1231
1232        supervisor
1233            .run_until_shutdown(
1234                async move {
1235                    signal_rx.await.expect("shutdown signal should be sent");
1236                },
1237                Duration::from_secs(1),
1238            )
1239            .await
1240            .expect("resolved shutdown signal should shut down cleanly");
1241        assert!(shutdown_requested.load(Ordering::SeqCst));
1242    }
1243
1244    #[tokio::test]
1245    async fn run_until_shutdown_with_no_tasks_waits_for_signal() {
1246        let supervisor = Supervisor::from_tasks_for_tests(Vec::new());
1247        let (signal_tx, signal_rx) = tokio::sync::oneshot::channel();
1248        let mut run = tokio::spawn(supervisor.run_until_shutdown(
1249            async move {
1250                signal_rx.await.expect("shutdown signal should be sent");
1251            },
1252            Duration::from_secs(1),
1253        ));
1254
1255        assert!(
1256            timeout(Duration::from_millis(50), &mut run).await.is_err(),
1257            "all-disabled supervisor should wait for the shutdown signal"
1258        );
1259
1260        signal_tx.send(()).expect("signal receiver should be alive");
1261        run.await
1262            .expect("run-until-shutdown task should join")
1263            .expect("all-disabled supervisor should complete after signal");
1264    }
1265
1266    #[tokio::test]
1267    async fn run_until_shutdown_reports_task_exit_before_signal() {
1268        let supervisor =
1269            Supervisor::from_tasks_for_tests(vec![test_task("run-until-early-loop", async {})]);
1270
1271        let error = timeout(
1272            Duration::from_secs(1),
1273            supervisor.run_until_shutdown(std::future::pending::<()>(), Duration::from_secs(1)),
1274        )
1275        .await
1276        .expect("task exit should be reported before external signal")
1277        .expect_err("early task exit should fail run-until shutdown");
1278
1279        match error {
1280            Error::Runtime(RuntimeError::TaskExitedUnexpectedly { task }) => {
1281                assert_eq!(task, "run-until-early-loop");
1282            }
1283            other => panic!("expected unexpected task exit, got {other:?}"),
1284        }
1285    }
1286
1287    #[tokio::test]
1288    async fn run_until_shutdown_times_out_and_aborts_after_signal() {
1289        let dropped = Arc::new(AtomicBool::new(false));
1290        let drop_flag = DropFlag(Arc::clone(&dropped));
1291        let supervisor =
1292            Supervisor::from_tasks_for_tests(vec![test_task("run-until-stubborn-loop", async {
1293                let _drop_flag = drop_flag;
1294                std::future::pending::<()>().await;
1295            })]);
1296
1297        let error = supervisor
1298            .run_until_shutdown(async {}, Duration::from_millis(50))
1299            .await
1300            .expect_err("stubborn task should time out after shutdown signal");
1301
1302        match error {
1303            Error::Runtime(RuntimeError::ShutdownTimeout { timeout }) => {
1304                assert_eq!(timeout, Duration::from_millis(50));
1305            }
1306            other => panic!("expected shutdown timeout error, got {other:?}"),
1307        }
1308        assert!(dropped.load(Ordering::SeqCst));
1309    }
1310
1311    #[tokio::test]
1312    async fn run_until_shutdown_reports_task_exit_after_signal_before_deadline() {
1313        let (shutdown_tx, mut shutdown_rx) = watch::channel(false);
1314        let shutdown_requested = Arc::new(AtomicBool::new(false));
1315        let supervisor = supervisor_from_shutdown_channel(
1316            shutdown_tx,
1317            Arc::clone(&shutdown_requested),
1318            vec![test_task("run-until-bad-shutdown-loop", async move {
1319                while !*shutdown_rx.borrow() {
1320                    if shutdown_rx.changed().await.is_err() {
1321                        break;
1322                    }
1323                }
1324            })],
1325        );
1326
1327        let error = supervisor
1328            .run_until_shutdown(async {}, Duration::from_secs(1))
1329            .await
1330            .expect_err("task completion after signal should still be reported");
1331
1332        match error {
1333            Error::Runtime(RuntimeError::TaskExitedUnexpectedly { task }) => {
1334                assert_eq!(task, "run-until-bad-shutdown-loop");
1335            }
1336            other => panic!("expected unexpected task exit, got {other:?}"),
1337        }
1338    }
1339
1340    #[tokio::test]
1341    async fn dropping_supervisor_requests_shutdown_signal() {
1342        let (shutdown_tx, mut shutdown_rx) = watch::channel(false);
1343        let mut observed_shutdown = shutdown_rx.clone();
1344        let shutdown_requested = Arc::new(AtomicBool::new(false));
1345        let supervisor = supervisor_from_shutdown_channel(
1346            shutdown_tx,
1347            Arc::clone(&shutdown_requested),
1348            vec![test_shutdown_task("drop-shutdown-loop", async move {
1349                while !*shutdown_rx.borrow() {
1350                    if shutdown_rx.changed().await.is_err() {
1351                        break;
1352                    }
1353                }
1354            })],
1355        );
1356
1357        drop(supervisor);
1358
1359        timeout(Duration::from_secs(1), observed_shutdown.changed())
1360            .await
1361            .expect("drop should promptly notify shutdown")
1362            .expect("shutdown sender should notify before closing");
1363        assert!(*observed_shutdown.borrow());
1364        assert!(shutdown_requested.load(Ordering::SeqCst));
1365    }
1366
1367    #[tokio::test]
1368    async fn join_reports_task_that_exited_before_late_shutdown_request() {
1369        let (shutdown_tx, _) = watch::channel(false);
1370        let shutdown_requested = Arc::new(AtomicBool::new(false));
1371        let supervisor = supervisor_from_shutdown_channel(
1372            shutdown_tx,
1373            Arc::clone(&shutdown_requested),
1374            vec![test_task("early-before-late-signal", async {})],
1375        );
1376
1377        while !supervisor.tasks[0].handle.is_finished() {
1378            tokio::task::yield_now().await;
1379        }
1380
1381        let shutdown = supervisor.shutdown_handle();
1382        shutdown.request_shutdown();
1383
1384        let error = supervisor
1385            .join()
1386            .await
1387            .expect_err("task exit before shutdown request should still be reported");
1388
1389        match error {
1390            Error::Runtime(RuntimeError::TaskExitedUnexpectedly { task }) => {
1391                assert_eq!(task, "early-before-late-signal");
1392            }
1393            other => panic!("expected unexpected task exit, got {other:?}"),
1394        }
1395    }
1396
1397    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
1398    async fn join_reports_task_exit_when_shutdown_races_completion_poll() {
1399        let (entered_tx, entered_rx) = std::sync::mpsc::channel();
1400        let (release_tx, release_rx) = std::sync::mpsc::channel();
1401        let (shutdown_tx, _) = watch::channel(false);
1402        let shutdown_requested = Arc::new(AtomicBool::new(false));
1403        let supervisor = supervisor_from_shutdown_channel(
1404            shutdown_tx,
1405            Arc::clone(&shutdown_requested),
1406            vec![RuntimeTask::spawn(
1407                "race-completion",
1408                CompleteAfterPollSignal {
1409                    entered_tx: Some(entered_tx),
1410                    release_rx,
1411                    exit: RuntimeTaskExit::Completed,
1412                },
1413            )],
1414        );
1415        entered_rx
1416            .recv_timeout(Duration::from_secs(1))
1417            .expect("task should enter its completion poll");
1418        let shutdown = supervisor.shutdown_handle();
1419
1420        shutdown.request_shutdown();
1421        release_tx
1422            .send(())
1423            .expect("completion poll release should be received");
1424
1425        let error = supervisor
1426            .join()
1427            .await
1428            .expect_err("task exit that began before shutdown should be reported");
1429
1430        match error {
1431            Error::Runtime(RuntimeError::TaskExitedUnexpectedly { task }) => {
1432                assert_eq!(task, "race-completion");
1433            }
1434            other => panic!("expected unexpected task exit, got {other:?}"),
1435        }
1436    }
1437
1438    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
1439    async fn join_allows_shutdown_exit_when_shutdown_races_completion_poll() {
1440        let (entered_tx, entered_rx) = std::sync::mpsc::channel();
1441        let (release_tx, release_rx) = std::sync::mpsc::channel();
1442        let (shutdown_tx, _) = watch::channel(false);
1443        let shutdown_requested = Arc::new(AtomicBool::new(false));
1444        let supervisor = supervisor_from_shutdown_channel(
1445            shutdown_tx,
1446            Arc::clone(&shutdown_requested),
1447            vec![RuntimeTask::spawn(
1448                "shutdown-race-completion",
1449                CompleteAfterPollSignal {
1450                    entered_tx: Some(entered_tx),
1451                    release_rx,
1452                    exit: RuntimeTaskExit::Shutdown,
1453                },
1454            )],
1455        );
1456        entered_rx
1457            .recv_timeout(Duration::from_secs(1))
1458            .expect("task should enter its completion poll");
1459        let shutdown = supervisor.shutdown_handle();
1460
1461        shutdown.request_shutdown();
1462        release_tx
1463            .send(())
1464            .expect("completion poll release should be received");
1465
1466        supervisor
1467            .join()
1468            .await
1469            .expect("task that reports shutdown should join cleanly");
1470    }
1471
1472    #[tokio::test]
1473    async fn panic_after_shutdown_request_is_reported() {
1474        let (shutdown_tx, mut shutdown_rx) = watch::channel(false);
1475        let shutdown_requested = Arc::new(AtomicBool::new(false));
1476        let supervisor = supervisor_from_shutdown_channel(
1477            shutdown_tx,
1478            Arc::clone(&shutdown_requested),
1479            vec![test_shutdown_task("panic-after-shutdown", async move {
1480                while !*shutdown_rx.borrow() {
1481                    if shutdown_rx.changed().await.is_err() {
1482                        return;
1483                    }
1484                }
1485                panic!("forced post-shutdown panic");
1486            })],
1487        );
1488        let shutdown = supervisor.shutdown_handle();
1489
1490        shutdown.request_shutdown();
1491
1492        let error = supervisor
1493            .shutdown()
1494            .await
1495            .expect_err("panic after requested shutdown should fail");
1496
1497        match error {
1498            Error::Runtime(RuntimeError::TaskJoin { task, source }) => {
1499                assert_eq!(task, "panic-after-shutdown");
1500                assert!(source.is_panic());
1501            }
1502            other => panic!("expected task join error, got {other:?}"),
1503        }
1504    }
1505
1506    #[tokio::test]
1507    async fn early_normal_task_exit_is_unexpected() {
1508        let supervisor = Supervisor::from_tasks_for_tests(vec![test_task("test-loop", async {})]);
1509
1510        let error = supervisor
1511            .join()
1512            .await
1513            .expect_err("early normal exit should fail");
1514
1515        match error {
1516            Error::Runtime(RuntimeError::TaskExitedUnexpectedly { task }) => {
1517                assert_eq!(task, "test-loop");
1518            }
1519            other => panic!("expected unexpected task exit, got {other:?}"),
1520        }
1521    }
1522
1523    #[tokio::test]
1524    async fn shutdown_reports_task_that_exited_before_shutdown_request() {
1525        let supervisor = Supervisor::from_tasks_for_tests(vec![test_task("early-loop", async {})]);
1526
1527        while !supervisor.tasks[0].handle.is_finished() {
1528            tokio::task::yield_now().await;
1529        }
1530
1531        let error = supervisor
1532            .shutdown()
1533            .await
1534            .expect_err("pre-shutdown task exit should fail");
1535
1536        match error {
1537            Error::Runtime(RuntimeError::TaskExitedUnexpectedly { task }) => {
1538                assert_eq!(task, "early-loop");
1539            }
1540            other => panic!("expected unexpected task exit, got {other:?}"),
1541        }
1542    }
1543
1544    #[tokio::test]
1545    async fn pre_shutdown_sweep_consumes_all_already_finished_tasks() {
1546        let mut tasks = vec![
1547            test_task("finished-a", async {}),
1548            test_task("pending", async {
1549                std::future::pending::<()>().await;
1550            }),
1551            test_task("finished-b", async {}),
1552        ];
1553
1554        while tasks
1555            .iter()
1556            .filter(|task| task.name != "pending")
1557            .any(|task| !task.handle.is_finished())
1558        {
1559            tokio::task::yield_now().await;
1560        }
1561
1562        let error = join_pre_shutdown_finished_tasks(&mut tasks)
1563            .await
1564            .expect("finished tasks should produce a pre-shutdown error");
1565
1566        match error {
1567            RuntimeError::TaskExitedUnexpectedly { task } => {
1568                assert!(
1569                    matches!(task, "finished-a" | "finished-b"),
1570                    "unexpected first finished task: {task}"
1571                );
1572            }
1573            other => panic!("expected unexpected task exit, got {other:?}"),
1574        }
1575        assert_eq!(tasks.len(), 1);
1576        assert_eq!(tasks[0].name, "pending");
1577
1578        let pending = tasks.pop().expect("pending task remains");
1579        pending.handle.abort();
1580        // The task was deliberately aborted; this test only needs the join
1581        // handle drained before returning.
1582        let _ = pending.handle.await;
1583    }
1584
1585    #[tokio::test]
1586    async fn pre_shutdown_sweep_allows_explicit_shutdown_exit() {
1587        let mut tasks = vec![test_shutdown_task("finished-after-signal", async {})];
1588        while !tasks[0].handle.is_finished() {
1589            tokio::task::yield_now().await;
1590        }
1591
1592        let error = join_pre_shutdown_finished_tasks(&mut tasks).await;
1593
1594        assert!(error.is_none());
1595        assert!(tasks.is_empty());
1596    }
1597
1598    #[tokio::test]
1599    async fn shutdown_with_timeout_aborts_and_drains_stubborn_task() {
1600        let dropped = Arc::new(AtomicBool::new(false));
1601        let drop_flag = DropFlag(Arc::clone(&dropped));
1602        let supervisor =
1603            Supervisor::from_tasks_for_tests(vec![test_task("stubborn-loop", async move {
1604                let _drop_flag = drop_flag;
1605                std::future::pending::<()>().await;
1606            })]);
1607
1608        let error = supervisor
1609            .shutdown_with_timeout(Duration::from_millis(50))
1610            .await
1611            .expect_err("stubborn task should time out shutdown");
1612
1613        match error {
1614            Error::Runtime(RuntimeError::ShutdownTimeout { timeout }) => {
1615                assert_eq!(timeout, Duration::from_millis(50));
1616            }
1617            other => panic!("expected shutdown timeout error, got {other:?}"),
1618        }
1619        assert!(dropped.load(Ordering::SeqCst));
1620    }
1621
1622    #[tokio::test]
1623    async fn shutdown_with_timeout_rejects_unrepresentable_deadline() {
1624        let error = Supervisor::from_tasks_for_tests(Vec::new())
1625            .shutdown_with_timeout(Duration::MAX)
1626            .await
1627            .expect_err("unrepresentable timeout should fail instead of panicking");
1628
1629        match error {
1630            Error::Runtime(RuntimeError::ShutdownTimeoutTooLarge { timeout }) => {
1631                assert_eq!(timeout, Duration::MAX);
1632            }
1633            other => panic!("expected oversized timeout error, got {other:?}"),
1634        }
1635    }
1636
1637    #[tokio::test]
1638    async fn shutdown_with_zero_timeout_aborts_immediately() {
1639        let supervisor =
1640            Supervisor::from_tasks_for_tests(vec![test_task("zero-timeout-pending-loop", async {
1641                std::future::pending::<()>().await;
1642            })]);
1643
1644        let error = supervisor
1645            .shutdown_with_timeout(Duration::ZERO)
1646            .await
1647            .expect_err("zero timeout should report an immediate shutdown timeout");
1648
1649        match error {
1650            Error::Runtime(RuntimeError::ShutdownTimeout { timeout }) => {
1651                assert_eq!(timeout, Duration::ZERO);
1652            }
1653            other => panic!("expected shutdown timeout error, got {other:?}"),
1654        }
1655    }
1656
1657    #[tokio::test]
1658    async fn shutdown_with_timeout_succeeds_when_task_exits_cooperatively() {
1659        let (shutdown_tx, mut shutdown_rx) = watch::channel(false);
1660        let shutdown_requested = Arc::new(AtomicBool::new(false));
1661        let supervisor = supervisor_from_shutdown_channel(
1662            shutdown_tx,
1663            Arc::clone(&shutdown_requested),
1664            vec![test_shutdown_task("cooperative-timeout-loop", async move {
1665                while !*shutdown_rx.borrow() {
1666                    if shutdown_rx.changed().await.is_err() {
1667                        break;
1668                    }
1669                }
1670            })],
1671        );
1672
1673        supervisor
1674            .shutdown_with_timeout(Duration::from_secs(1))
1675            .await
1676            .expect("cooperative task should shut down before timeout");
1677    }
1678
1679    #[tokio::test]
1680    async fn shutdown_with_timeout_pre_shutdown_error_allows_remaining_task_to_exit() {
1681        let (shutdown_tx, mut shutdown_rx) = watch::channel(false);
1682        let shutdown_requested = Arc::new(AtomicBool::new(false));
1683        let dropped = Arc::new(AtomicBool::new(false));
1684        let drop_flag = DropFlag(Arc::clone(&dropped));
1685        let tasks = vec![
1686            test_task("finished-before-shutdown", async {}),
1687            test_shutdown_task("cooperative-after-error", async move {
1688                let _drop_flag = drop_flag;
1689                while !*shutdown_rx.borrow() {
1690                    if shutdown_rx.changed().await.is_err() {
1691                        break;
1692                    }
1693                }
1694            }),
1695        ];
1696
1697        while !tasks[0].handle.is_finished() {
1698            tokio::task::yield_now().await;
1699        }
1700
1701        let supervisor =
1702            supervisor_from_shutdown_channel(shutdown_tx, Arc::clone(&shutdown_requested), tasks);
1703        let error = supervisor
1704            .shutdown_with_timeout(Duration::from_secs(1))
1705            .await
1706            .expect_err("pre-shutdown task exit should fail");
1707
1708        match error {
1709            Error::Runtime(RuntimeError::TaskExitedUnexpectedly { task }) => {
1710                assert_eq!(task, "finished-before-shutdown");
1711            }
1712            other => panic!("expected pre-shutdown task exit, got {other:?}"),
1713        }
1714        assert!(dropped.load(Ordering::SeqCst));
1715    }
1716
1717    #[tokio::test]
1718    async fn shutdown_with_timeout_reports_timeout_after_pre_shutdown_error() {
1719        let tasks = vec![
1720            test_task("finished-before-shutdown", async {}),
1721            test_task("pending-after-error", async {
1722                std::future::pending::<()>().await;
1723            }),
1724        ];
1725
1726        while !tasks[0].handle.is_finished() {
1727            tokio::task::yield_now().await;
1728        }
1729
1730        let error = Supervisor::from_tasks_for_tests(tasks)
1731            .shutdown_with_timeout(Duration::from_millis(1))
1732            .await
1733            .expect_err("pre-shutdown task exit with stuck drain should time out");
1734
1735        match error {
1736            Error::Runtime(RuntimeError::ShutdownTimeoutAfterTaskError { timeout, source }) => {
1737                assert_eq!(timeout, Duration::from_millis(1));
1738                match *source {
1739                    RuntimeError::TaskExitedUnexpectedly { task } => {
1740                        assert_eq!(task, "finished-before-shutdown");
1741                    }
1742                    other => panic!("expected pre-shutdown task exit source, got {other:?}"),
1743                }
1744            }
1745            other => panic!("expected shutdown timeout after task error, got {other:?}"),
1746        }
1747    }
1748
1749    #[tokio::test]
1750    async fn shutdown_with_timeout_reports_task_error_when_remaining_task_misses_deadline() {
1751        let (shutdown_tx, mut shutdown_rx) = watch::channel(false);
1752        let shutdown_requested = Arc::new(AtomicBool::new(false));
1753        let supervisor = supervisor_from_shutdown_channel(
1754            shutdown_tx,
1755            Arc::clone(&shutdown_requested),
1756            vec![
1757                test_shutdown_task("panic-after-timeout-shutdown", async move {
1758                    while !*shutdown_rx.borrow() {
1759                        if shutdown_rx.changed().await.is_err() {
1760                            return;
1761                        }
1762                    }
1763                    panic!("forced live shutdown panic");
1764                }),
1765                test_shutdown_task("pending-after-timeout-panic", async {
1766                    std::future::pending::<()>().await;
1767                }),
1768            ],
1769        );
1770
1771        let error = supervisor
1772            .shutdown_with_timeout(Duration::from_millis(50))
1773            .await
1774            .expect_err("task failure with stuck drain should preserve task error source");
1775
1776        match error {
1777            Error::Runtime(RuntimeError::ShutdownTimeoutAfterTaskError { timeout, source }) => {
1778                assert_eq!(timeout, Duration::from_millis(50));
1779                match *source {
1780                    RuntimeError::TaskJoin { task, source } => {
1781                        assert_eq!(task, "panic-after-timeout-shutdown");
1782                        assert!(source.is_panic());
1783                    }
1784                    other => panic!("expected task join source, got {other:?}"),
1785                }
1786            }
1787            other => panic!("expected timeout after task join error, got {other:?}"),
1788        }
1789    }
1790
1791    #[tokio::test]
1792    async fn panicked_task_maps_to_task_join_error() {
1793        let supervisor = Supervisor::from_tasks_for_tests(vec![test_task("panic-loop", async {
1794            panic!("forced supervisor test panic");
1795        })]);
1796
1797        let error = supervisor
1798            .join()
1799            .await
1800            .expect_err("panicked task should fail");
1801
1802        match error {
1803            Error::Runtime(RuntimeError::TaskJoin { task, source }) => {
1804                assert_eq!(task, "panic-loop");
1805                assert!(source.is_panic());
1806            }
1807            other => panic!("expected task join error, got {other:?}"),
1808        }
1809    }
1810}