1use std::borrow::Borrow;
2use std::future::Future;
3use std::pin::Pin;
4use std::sync::Arc;
5use std::sync::atomic::{AtomicBool, Ordering};
6use std::task::{Context, Poll};
7use std::time::Duration;
8
9use futures_util::stream::{FuturesUnordered, StreamExt};
10use tokio::runtime::Handle;
11use tokio::sync::watch;
12use tokio::task::{AbortHandle, JoinError, JoinHandle};
13use tokio::time::Instant;
14use tracing::{Instrument, debug, error, info_span, warn};
15
16use crate::catalog::JobCatalog;
17use crate::config::JobsConfig;
18use crate::reaper::run_reaper_loop;
19use crate::registry::JobRegistry;
20use crate::scheduler::run_scheduler_loop;
21use crate::worker::run_worker_loop;
22use crate::{Error, Result, RuntimeError, RuntimeLoopExit};
23
/// Name identifying the worker loop in tracing spans and runtime errors.
const WORKER_TASK: &str = "worker";
/// Name identifying the scheduler loop in tracing spans and runtime errors.
const SCHEDULER_TASK: &str = "scheduler";
/// Name identifying the reaper loop in tracing spans and runtime errors.
const REAPER_TASK: &str = "reaper";
/// Upper bound on how long aborted tasks are awaited after a shutdown deadline
/// has passed, regardless of how large the caller's timeout was.
const MAX_ABORT_DRAIN_TIMEOUT: Duration = Duration::from_millis(1_000);
28
/// Supervises the runtime loops (worker, scheduler, reaper) spawned by
/// [`SupervisorBuilder::build`] and owns the shutdown signal they observe.
///
/// Consume it with [`Supervisor::join`], [`Supervisor::shutdown`],
/// [`Supervisor::shutdown_with_timeout`] or [`Supervisor::run_until_shutdown`].
/// If it is dropped while it still owns tasks, `Drop` only requests shutdown
/// and logs a warning; the tasks are left detached.
#[must_use]
pub struct Supervisor {
    /// Sending side of the shutdown signal; each spawned loop gets a clone of
    /// the matching receiver.
    shutdown_tx: watch::Sender<bool>,
    /// Flips to `true` on the first shutdown request; shared with every
    /// `SupervisorShutdown` handle.
    shutdown_requested: Arc<AtomicBool>,
    /// Tasks not yet handed to a join/drain routine. The consuming methods take
    /// this vector, so a non-empty value at drop time means tasks were never joined.
    tasks: Vec<RuntimeTask>,
}
45
/// Configures which runtime loops a [`Supervisor`] spawns and which job
/// registry they use. Created by [`Supervisor::builder`].
#[must_use]
pub struct SupervisorBuilder<'a> {
    /// Database pool; `build` clones it into each spawned loop.
    pool: &'a runledger_postgres::DbPool,
    /// Tokio runtime captured when the builder was created; loops are spawned on it.
    runtime: Handle,
    /// Registry handed to the worker and reaper loops, if one was supplied.
    registry: Option<JobRegistry>,
    /// Which builder method supplied `registry` most recently.
    registry_source: Option<RegistrySource>,
    /// Set once both `with_registry` and `with_catalog` have been used on this
    /// builder; `build` rejects it in that case.
    mixed_registry_sources: bool,
    /// Loop configuration; `build` clones it into each spawned loop.
    config: JobsConfig,
    /// Whether `build` spawns the worker loop (`true` by default).
    worker_enabled: bool,
    /// Whether `build` spawns the scheduler loop (`true` by default).
    scheduler_enabled: bool,
    /// Whether `build` spawns the reaper loop (`true` by default).
    reaper_enabled: bool,
}
63
/// Cloneable handle that can request shutdown of a [`Supervisor`] without
/// owning it. It shares the supervisor's signal sender and request flag.
#[derive(Clone)]
pub struct SupervisorShutdown {
    shutdown_tx: watch::Sender<bool>,
    shutdown_requested: Arc<AtomicBool>,
}
70
/// A spawned runtime loop paired with the name used when reporting it.
struct RuntimeTask {
    name: &'static str,
    handle: JoinHandle<RuntimeTaskExit>,
}
75
/// Which builder method supplied the job registry. Used to detect when
/// `with_registry` and `with_catalog` are mixed on one builder.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum RegistrySource {
    Registry,
    Catalog,
}
81
/// How a supervised task's future finished; converted from `RuntimeLoopExit`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum RuntimeTaskExit {
    /// The loop returned on its own; the supervisor treats this as a failure.
    Completed,
    /// The loop reported a shutdown exit; the supervisor treats this as a clean join.
    Shutdown,
}
87
/// Outcome of waiting for the remaining tasks under a time limit.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum DrainResult {
    /// Every remaining task was joined before the limit.
    Drained,
    /// The limit elapsed while tasks were still pending.
    TimedOut,
}
93
/// Future wrapper that emits debug events when a supervised task is first
/// polled and when it completes.
struct RuntimeTaskFuture {
    name: &'static str,
    future: Pin<Box<dyn Future<Output = RuntimeTaskExit> + Send>>,
    /// Whether the "started" event has already been emitted.
    started: bool,
}
99
/// Result of awaiting a supervised task's `JoinHandle`.
type RuntimeTaskJoinResult = std::result::Result<RuntimeTaskExit, JoinError>;
/// A joined task's name paired with its join result.
type JoinedRuntimeTask = (&'static str, RuntimeTaskJoinResult);
102
103impl Supervisor {
104 pub fn builder(
110 pool: &runledger_postgres::DbPool,
111 config: JobsConfig,
112 ) -> std::result::Result<SupervisorBuilder<'_>, RuntimeError> {
113 let runtime =
114 Handle::try_current().map_err(|source| RuntimeError::MissingTokioRuntime { source })?;
115
116 Ok(SupervisorBuilder {
117 pool,
118 runtime,
119 registry: None,
120 registry_source: None,
121 mixed_registry_sources: false,
122 config,
123 worker_enabled: true,
124 scheduler_enabled: true,
125 reaper_enabled: true,
126 })
127 }
128
129 #[must_use]
132 pub fn shutdown_handle(&self) -> SupervisorShutdown {
133 SupervisorShutdown {
134 shutdown_tx: self.shutdown_tx.clone(),
135 shutdown_requested: Arc::clone(&self.shutdown_requested),
136 }
137 }
138
139 pub fn request_shutdown(&self) {
141 request_shutdown_signal(&self.shutdown_tx, self.shutdown_requested.as_ref());
142 }
143
144 #[must_use]
147 pub fn is_shutdown_requested(&self) -> bool {
148 self.shutdown_requested.load(Ordering::SeqCst)
149 }
150
151 pub async fn join(mut self) -> Result<()> {
161 let tasks = std::mem::take(&mut self.tasks);
162 self.join_tasks(tasks).await
163 }
164
165 pub async fn shutdown(mut self) -> Result<()> {
174 if let Some(error) = join_pre_shutdown_finished_tasks(&mut self.tasks).await {
175 self.request_shutdown();
176 let tasks = std::mem::take(&mut self.tasks);
177 drain_tasks(tasks).await;
178 return Err(Error::Runtime(error));
179 }
180
181 self.request_shutdown();
182 let tasks = std::mem::take(&mut self.tasks);
183 self.join_tasks(tasks).await
184 }
185
186 pub async fn run_until_shutdown<F>(mut self, shutdown: F, timeout: Duration) -> Result<()>
210 where
211 F: Future<Output = ()>,
212 {
213 let _ = shutdown_deadline(timeout)?;
216 let tasks = std::mem::take(&mut self.tasks);
217 if tasks.is_empty() {
218 shutdown.await;
219 self.request_shutdown();
220 return Ok(());
221 }
222
223 let mut abort_handles = Some(task_abort_handles(&tasks));
224 let mut joined = join_runtime_tasks(tasks);
225 let mut shutdown = std::pin::pin!(shutdown);
226
227 loop {
228 tokio::select! {
229 _ = shutdown.as_mut() => {
230 self.request_shutdown();
231 let abort_handles = abort_handles.take().expect("abort handles are consumed on return");
232 let deadline = match shutdown_deadline(timeout) {
236 Ok(deadline) => deadline,
237 Err(error) => {
238 abort_and_drain_joined_tasks_or_log(
239 &mut joined,
240 abort_handles,
241 abort_drain_timeout(timeout),
242 )
243 .await;
244 return Err(error.into());
245 }
246 };
247 return self
248 .join_joined_tasks_with_timeout(
249 &mut joined,
250 abort_handles,
251 timeout,
252 deadline,
253 )
254 .await;
255 }
256 joined_result = joined.next() => {
257 let Some((task, result)) = joined_result else {
258 return Ok(());
259 };
260 let Some(error) = classify_task_result(task, result) else {
261 continue;
262 };
263
264 self.request_shutdown();
265 let abort_handles = abort_handles.take().expect("abort handles are consumed on return");
266 let deadline = match shutdown_deadline(timeout) {
270 Ok(deadline) => deadline,
271 Err(error) => {
272 abort_and_drain_joined_tasks_or_log(
273 &mut joined,
274 abort_handles,
275 abort_drain_timeout(timeout),
276 )
277 .await;
278 return Err(error.into());
279 }
280 };
281 return drain_after_task_error_with_timeout(
282 &mut joined,
283 abort_handles,
284 timeout,
285 deadline,
286 error,
287 )
288 .await;
289 }
290 }
291 }
292 }
293
294 pub async fn shutdown_with_timeout(mut self, timeout: Duration) -> Result<()> {
311 let deadline = shutdown_deadline(timeout)?;
312
313 if let Some(error) = join_pre_shutdown_finished_tasks(&mut self.tasks).await {
314 self.request_shutdown();
315 let tasks = std::mem::take(&mut self.tasks);
316 let abort_handles = task_abort_handles(&tasks);
317 let mut joined = join_runtime_tasks(tasks);
318
319 return drain_after_task_error_with_timeout(
320 &mut joined,
321 abort_handles,
322 timeout,
323 deadline,
324 error,
325 )
326 .await;
327 }
328
329 self.request_shutdown();
330 let tasks = std::mem::take(&mut self.tasks);
331 self.join_tasks_with_timeout(tasks, timeout, deadline).await
332 }
333
334 async fn join_tasks(&self, tasks: Vec<RuntimeTask>) -> Result<()> {
335 let mut joined = join_runtime_tasks(tasks);
336
337 while let Some((task, result)) = joined.next().await {
338 if let Some(error) = classify_task_result(task, result) {
339 self.request_shutdown();
340 drain_joined_tasks(&mut joined).await;
341 return Err(Error::Runtime(error));
342 }
343 }
344
345 Ok(())
346 }
347
348 async fn join_tasks_with_timeout(
349 &self,
350 tasks: Vec<RuntimeTask>,
351 timeout: Duration,
352 deadline: Instant,
353 ) -> Result<()> {
354 let abort_handles = task_abort_handles(&tasks);
355 let mut joined = join_runtime_tasks(tasks);
356
357 self.join_joined_tasks_with_timeout(&mut joined, abort_handles, timeout, deadline)
358 .await
359 }
360
361 async fn join_joined_tasks_with_timeout(
362 &self,
363 joined: &mut FuturesUnordered<impl Future<Output = JoinedRuntimeTask>>,
364 abort_handles: Vec<AbortHandle>,
365 timeout: Duration,
366 deadline: Instant,
367 ) -> Result<()> {
368 loop {
369 match tokio::time::timeout_at(deadline, joined.next()).await {
370 Ok(Some((task, result))) => {
371 if let Some(error) = classify_task_result(task, result) {
372 self.request_shutdown();
373 return drain_after_task_error_with_timeout(
374 joined,
375 abort_handles,
376 timeout,
377 deadline,
378 error,
379 )
380 .await;
381 }
382 }
383 Ok(None) => return Ok(()),
384 Err(_) => {
385 abort_and_drain_joined_tasks_or_log(
386 joined,
387 abort_handles,
388 abort_drain_timeout(timeout),
389 )
390 .await;
391 return Err(Error::Runtime(RuntimeError::ShutdownTimeout { timeout }));
392 }
393 }
394 }
395 }
396
397 #[cfg(test)]
398 fn from_tasks_for_tests(tasks: Vec<RuntimeTask>) -> Self {
399 let (shutdown_tx, _) = watch::channel(false);
402 Self {
403 shutdown_tx,
404 shutdown_requested: Arc::new(AtomicBool::new(false)),
405 tasks,
406 }
407 }
408}
409
410impl Drop for Supervisor {
411 fn drop(&mut self) {
412 if !self.tasks.is_empty() {
413 warn!(
414 task_count = self.tasks.len(),
415 "dropping jobs runtime supervisor before joining tasks; tasks may continue detached after shutdown is requested and later panics will not be observed"
416 );
417 }
418 self.request_shutdown();
420 }
421}
422
423impl<'a> SupervisorBuilder<'a> {
424 #[must_use = "builder methods return an updated builder value"]
429 pub fn with_registry(mut self, registry: JobRegistry) -> Self {
430 self.mixed_registry_sources |= self.registry_source == Some(RegistrySource::Catalog);
431 self.registry_source = Some(RegistrySource::Registry);
432 self.registry = Some(registry);
433 self
434 }
435
436 #[must_use = "builder methods return an updated builder value"]
449 pub fn with_catalog(mut self, catalog: impl Borrow<JobCatalog>) -> Self {
450 self.mixed_registry_sources |= self.registry_source == Some(RegistrySource::Registry);
451 self.registry_source = Some(RegistrySource::Catalog);
452 self.registry = Some(catalog.borrow().to_registry());
453 self
454 }
455
456 #[must_use = "builder methods return an updated builder value"]
458 pub fn disable_worker(mut self) -> Self {
459 self.worker_enabled = false;
460 self
461 }
462
463 #[must_use = "builder methods return an updated builder value"]
465 pub fn disable_scheduler(mut self) -> Self {
466 self.scheduler_enabled = false;
467 self
468 }
469
470 #[must_use = "builder methods return an updated builder value"]
472 pub fn disable_reaper(mut self) -> Self {
473 self.reaper_enabled = false;
474 self
475 }
476
477 pub fn build(self) -> std::result::Result<Supervisor, RuntimeError> {
482 let Self {
483 pool,
484 runtime,
485 registry,
486 registry_source: _,
487 mixed_registry_sources,
488 config,
489 worker_enabled,
490 scheduler_enabled,
491 reaper_enabled,
492 } = self;
493
494 if mixed_registry_sources {
495 return Err(RuntimeError::MixedRegistrySources);
496 }
497
498 let registry = match registry {
499 Some(registry) => registry,
500 None if worker_enabled || reaper_enabled => {
501 return Err(RuntimeError::MissingRegistry {
502 worker_enabled,
503 reaper_enabled,
504 });
505 }
506 None => JobRegistry::new(),
507 };
508
509 let (shutdown_tx, shutdown_rx) = watch::channel(false);
510 let shutdown_requested = Arc::new(AtomicBool::new(false));
511 let mut tasks = Vec::new();
512
513 if worker_enabled {
514 tasks.push(RuntimeTask::spawn_on(&runtime, WORKER_TASK, {
515 let pool = pool.clone();
516 let registry = registry.clone();
517 let config = config.clone();
518 let shutdown_rx = shutdown_rx.clone();
519 async move { run_worker_loop(pool, registry, config, shutdown_rx).await }
520 }));
521 }
522
523 if scheduler_enabled {
524 tasks.push(RuntimeTask::spawn_on(&runtime, SCHEDULER_TASK, {
525 let pool = pool.clone();
526 let config = config.clone();
527 let shutdown_rx = shutdown_rx.clone();
528 async move { run_scheduler_loop(pool, config, shutdown_rx).await }
529 }));
530 }
531
532 if reaper_enabled {
533 let pool = pool.clone();
534 let registry = registry.clone();
535 let config = config.clone();
536 let shutdown_rx = shutdown_rx.clone();
537 tasks.push(RuntimeTask::spawn_on(&runtime, REAPER_TASK, async move {
538 run_reaper_loop(pool, registry, config, shutdown_rx).await
539 }));
540 }
541
542 Ok(Supervisor {
543 shutdown_tx,
544 shutdown_requested,
545 tasks,
546 })
547 }
548}
549
550impl SupervisorShutdown {
551 pub fn request_shutdown(&self) {
553 request_shutdown_signal(&self.shutdown_tx, self.shutdown_requested.as_ref());
554 }
555
556 #[must_use]
558 pub fn is_shutdown_requested(&self) -> bool {
559 self.shutdown_requested.load(Ordering::SeqCst)
560 }
561}
562
563impl RuntimeTask {
564 fn spawn_on<F>(runtime: &Handle, name: &'static str, future: F) -> Self
565 where
566 F: Future<Output = RuntimeLoopExit> + Send + 'static,
567 {
568 let span = info_span!("runledger_runtime_supervisor_task", task = name);
569 Self {
570 name,
571 handle: runtime.spawn(
572 RuntimeTaskFuture::new(name, async move { future.await.into() }).instrument(span),
573 ),
574 }
575 }
576
577 #[cfg(test)]
578 fn spawn<F>(name: &'static str, future: F) -> Self
579 where
580 F: Future<Output = RuntimeTaskExit> + Send + 'static,
581 {
582 Self {
583 name,
584 handle: tokio::spawn(RuntimeTaskFuture::new(name, future)),
585 }
586 }
587
588 async fn await_result(self) -> RuntimeTaskJoinResult {
589 self.handle.await
590 }
591}
592
593impl RuntimeTaskFuture {
594 fn new<F>(name: &'static str, future: F) -> Self
595 where
596 F: Future<Output = RuntimeTaskExit> + Send + 'static,
597 {
598 Self {
599 name,
600 future: Box::pin(future),
601 started: false,
602 }
603 }
604}
605
606impl Future for RuntimeTaskFuture {
607 type Output = RuntimeTaskExit;
608
609 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
610 let task = self.as_mut().get_mut();
611 if !task.started {
612 task.started = true;
613 debug!(task = task.name, "supervised runtime task started");
614 }
615
616 match task.future.as_mut().poll(cx) {
617 Poll::Pending => Poll::Pending,
618 Poll::Ready(exit) => {
619 debug!(task = task.name, ?exit, "supervised runtime task completed");
620 Poll::Ready(exit)
621 }
622 }
623 }
624}
625
626impl From<RuntimeLoopExit> for RuntimeTaskExit {
627 fn from(exit: RuntimeLoopExit) -> Self {
628 match exit {
629 RuntimeLoopExit::Shutdown => Self::Shutdown,
630 RuntimeLoopExit::Completed => Self::Completed,
631 }
632 }
633}
634
635fn request_shutdown_signal(shutdown_tx: &watch::Sender<bool>, shutdown_requested: &AtomicBool) {
636 if !shutdown_requested.swap(true, Ordering::SeqCst) {
637 let _ = shutdown_tx.send(true);
642 }
643}
644
645fn take_finished_tasks(tasks: &mut Vec<RuntimeTask>) -> Vec<RuntimeTask> {
646 let mut finished = Vec::new();
647 let mut index = 0;
648 while index < tasks.len() {
649 if tasks[index].handle.is_finished() {
650 finished.push(tasks.swap_remove(index));
653 } else {
654 index += 1;
655 }
656 }
657 finished
658}
659
660async fn join_pre_shutdown_finished_tasks(tasks: &mut Vec<RuntimeTask>) -> Option<RuntimeError> {
661 let finished = take_finished_tasks(tasks);
662 let mut first_error = None;
663
664 for task in finished {
665 let task_name = task.name;
666 let result = task.await_result().await;
667 let Some(error) = classify_task_result(task_name, result) else {
668 continue;
669 };
670
671 if first_error.is_none() {
672 first_error = Some(error);
673 } else {
674 log_drained_task_error(error);
675 }
676 }
677
678 first_error
679}
680
681fn join_runtime_tasks(
682 tasks: Vec<RuntimeTask>,
683) -> FuturesUnordered<impl Future<Output = JoinedRuntimeTask>> {
684 tasks
685 .into_iter()
686 .map(|task| async move {
687 let name = task.name;
688 (name, task.await_result().await)
689 })
690 .collect()
691}
692
693async fn drain_tasks(tasks: Vec<RuntimeTask>) {
694 let mut joined = join_runtime_tasks(tasks);
695 drain_joined_tasks(&mut joined).await;
696}
697
698fn task_abort_handles(tasks: &[RuntimeTask]) -> Vec<AbortHandle> {
699 tasks
700 .iter()
701 .map(|task| task.handle.abort_handle())
702 .collect()
703}
704
705fn shutdown_deadline(timeout: Duration) -> std::result::Result<Instant, RuntimeError> {
706 Instant::now()
707 .checked_add(timeout)
708 .ok_or(RuntimeError::ShutdownTimeoutTooLarge { timeout })
709}
710
711fn abort_drain_timeout(timeout: Duration) -> Duration {
712 timeout.min(MAX_ABORT_DRAIN_TIMEOUT)
713}
714
715async fn drain_joined_tasks(
716 joined: &mut FuturesUnordered<impl Future<Output = JoinedRuntimeTask>>,
717) {
718 while let Some((task, result)) = joined.next().await {
719 if let Some(error) = classify_task_result(task, result) {
720 log_drained_task_error(error);
721 }
722 }
723}
724
725async fn drain_joined_tasks_until_deadline(
726 joined: &mut FuturesUnordered<impl Future<Output = JoinedRuntimeTask>>,
727 deadline: Instant,
728) -> DrainResult {
729 loop {
730 match tokio::time::timeout_at(deadline, joined.next()).await {
731 Ok(Some((task, result))) => {
732 if let Some(error) = classify_task_result(task, result) {
733 log_drained_task_error(error);
734 }
735 }
736 Ok(None) => return DrainResult::Drained,
737 Err(_) => return DrainResult::TimedOut,
738 }
739 }
740}
741
742async fn drain_after_task_error_with_timeout(
743 joined: &mut FuturesUnordered<impl Future<Output = JoinedRuntimeTask>>,
744 abort_handles: Vec<AbortHandle>,
745 timeout: Duration,
746 deadline: Instant,
747 error: RuntimeError,
748) -> Result<()> {
749 if matches!(
750 drain_joined_tasks_until_deadline(joined, deadline).await,
751 DrainResult::Drained
752 ) {
753 return Err(error.into());
754 }
755
756 abort_and_drain_joined_tasks_or_log(joined, abort_handles, abort_drain_timeout(timeout)).await;
757 Err(RuntimeError::ShutdownTimeoutAfterTaskError {
758 timeout,
759 source: Box::new(error),
760 }
761 .into())
762}
763
764async fn abort_and_drain_joined_tasks_with_timeout(
765 joined: &mut FuturesUnordered<impl Future<Output = JoinedRuntimeTask>>,
766 abort_handles: Vec<AbortHandle>,
767 timeout: Duration,
768) -> DrainResult {
769 for abort_handle in abort_handles {
770 abort_handle.abort();
771 }
772
773 match tokio::time::timeout(timeout, drain_aborted_joined_tasks(joined)).await {
774 Ok(()) => DrainResult::Drained,
775 Err(_) => DrainResult::TimedOut,
776 }
777}
778
779async fn abort_and_drain_joined_tasks_or_log(
780 joined: &mut FuturesUnordered<impl Future<Output = JoinedRuntimeTask>>,
781 abort_handles: Vec<AbortHandle>,
782 timeout: Duration,
783) {
784 if matches!(
785 abort_and_drain_joined_tasks_with_timeout(joined, abort_handles, timeout).await,
786 DrainResult::TimedOut
787 ) {
788 log_abort_drain_timeout(timeout);
789 }
790}
791
792async fn drain_aborted_joined_tasks(
793 joined: &mut FuturesUnordered<impl Future<Output = JoinedRuntimeTask>>,
794) {
795 while let Some((task, result)) = joined.next().await {
796 match result {
797 Ok(_) => {}
798 Err(source) if source.is_cancelled() => {
799 }
803 Err(source) => {
804 log_drained_task_error(RuntimeError::TaskJoin { task, source });
805 }
806 }
807 }
808}
809
810fn log_drained_task_error(error: RuntimeError) {
811 error!(
812 %error,
813 "supervised runtime task failed while draining after an earlier failure"
814 );
815}
816
817fn log_abort_drain_timeout(timeout: Duration) {
818 warn!(
819 ?timeout,
820 "timed out draining aborted supervisor tasks; later task failures may be unobserved"
821 );
822}
823
824fn classify_task_result(task: &'static str, result: RuntimeTaskJoinResult) -> Option<RuntimeError> {
825 match result {
826 Ok(RuntimeTaskExit::Shutdown) => {
827 debug!(task, "supervised runtime task joined after shutdown");
828 None
829 }
830 Ok(RuntimeTaskExit::Completed) => {
831 debug!(task, "supervised runtime task exited before shutdown");
832 Some(RuntimeError::TaskExitedUnexpectedly { task })
833 }
834 Err(source) => {
835 debug!(
836 task,
837 is_cancelled = source.is_cancelled(),
838 is_panic = source.is_panic(),
839 "supervised runtime task join failed"
840 );
841 Some(RuntimeError::TaskJoin { task, source })
842 }
843 }
844}
845
846#[cfg(test)]
847mod tests {
848 use std::sync::Arc;
849 use std::sync::atomic::Ordering;
850 use std::time::Duration;
851
852 use sqlx::postgres::PgPoolOptions;
853 use tokio::time::timeout;
854
855 use super::*;
856 use crate::Error;
857
858 const UNUSED_LAZY_POOL_URL: &str = "postgres://postgres:postgres@127.0.0.1:65535/runledger";
859
860 struct DropFlag(Arc<AtomicBool>);
861
862 impl Drop for DropFlag {
863 fn drop(&mut self) {
864 self.0.store(true, Ordering::SeqCst);
865 }
866 }
867
868 struct CompleteAfterPollSignal {
869 entered_tx: Option<std::sync::mpsc::Sender<()>>,
870 release_rx: std::sync::mpsc::Receiver<()>,
871 exit: RuntimeTaskExit,
872 }
873
874 impl Future for CompleteAfterPollSignal {
875 type Output = RuntimeTaskExit;
876
877 fn poll(mut self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Self::Output> {
878 let task = self.as_mut().get_mut();
879 if let Some(entered_tx) = task.entered_tx.take() {
880 entered_tx
881 .send(())
882 .expect("completion poll entry signal should be received");
883 }
884 task.release_rx
885 .recv()
886 .expect("completion poll should be released");
887 Poll::Ready(task.exit)
888 }
889 }
890
891 fn lazy_pool() -> runledger_postgres::DbPool {
892 PgPoolOptions::new()
893 .connect_lazy(UNUSED_LAZY_POOL_URL)
896 .expect("construct lazy pool")
897 }
898
899 fn test_config() -> JobsConfig {
900 JobsConfig {
901 worker_id: "supervisor-test-worker".to_string(),
902 poll_interval: Duration::from_millis(25),
903 claim_batch_size: 4,
904 lease_ttl_seconds: 10,
905 max_global_concurrency: 4,
906 reaper_interval: Duration::from_millis(50),
907 schedule_poll_interval: Duration::from_millis(50),
908 reaper_retry_delay_ms: 1_000,
909 }
910 }
911
912 fn empty_builder(pool: &runledger_postgres::DbPool) -> SupervisorBuilder<'_> {
913 Supervisor::builder(pool, test_config()).expect("supervisor builder has runtime")
914 }
915
916 fn missing_registry_flags(builder: SupervisorBuilder<'_>) -> (bool, bool) {
917 match builder.build() {
918 Err(RuntimeError::MissingRegistry {
919 worker_enabled,
920 reaper_enabled,
921 }) => (worker_enabled, reaper_enabled),
922 Ok(_) => panic!("missing registry should be a build error"),
923 Err(other) => panic!("expected missing registry error, got {other:?}"),
924 }
925 }
926
927 fn test_task<F>(name: &'static str, future: F) -> RuntimeTask
928 where
929 F: Future<Output = ()> + Send + 'static,
930 {
931 test_task_with_exit(name, RuntimeTaskExit::Completed, future)
932 }
933
934 fn test_shutdown_task<F>(name: &'static str, future: F) -> RuntimeTask
935 where
936 F: Future<Output = ()> + Send + 'static,
937 {
938 test_task_with_exit(name, RuntimeTaskExit::Shutdown, future)
939 }
940
941 fn test_task_with_exit<F>(name: &'static str, exit: RuntimeTaskExit, future: F) -> RuntimeTask
942 where
943 F: Future<Output = ()> + Send + 'static,
944 {
945 RuntimeTask::spawn(name, async move {
946 future.await;
947 exit
948 })
949 }
950
951 fn supervisor_from_shutdown_channel(
952 shutdown_tx: watch::Sender<bool>,
953 shutdown_requested: Arc<AtomicBool>,
954 tasks: Vec<RuntimeTask>,
955 ) -> Supervisor {
956 Supervisor {
957 shutdown_tx,
958 shutdown_requested,
959 tasks,
960 }
961 }
962
963 fn task_names(supervisor: &Supervisor) -> Vec<&'static str> {
964 supervisor.tasks.iter().map(|task| task.name).collect()
965 }
966
967 async fn abort_supervisor_tasks(mut supervisor: Supervisor) {
968 let tasks = std::mem::take(&mut supervisor.tasks);
969 for task in tasks {
970 task.handle.abort();
971 let _ = task.handle.await;
972 }
973 }
974
975 #[tokio::test]
976 async fn builder_defaults_enable_all_loops() {
977 let pool = lazy_pool();
978 let builder = empty_builder(&pool);
979
980 assert!(builder.worker_enabled);
981 assert!(builder.scheduler_enabled);
982 assert!(builder.reaper_enabled);
983 assert!(builder.registry.is_none());
984 assert_eq!(builder.registry_source, None);
985 assert!(!builder.mixed_registry_sources);
986 }
987
988 #[tokio::test]
989 async fn builder_accepts_registry_for_worker_and_reaper_loops() {
990 let pool = lazy_pool();
991 let builder = empty_builder(&pool).with_registry(JobRegistry::new());
992
993 assert!(builder.registry.is_some());
994 assert_eq!(builder.registry_source, Some(RegistrySource::Registry));
995 assert!(!builder.mixed_registry_sources);
996 }
997
998 #[tokio::test]
999 async fn builder_rejects_mixed_registry_sources() {
1000 let pool = lazy_pool();
1001 let registry_then_catalog = empty_builder(&pool)
1002 .with_registry(JobRegistry::new())
1003 .with_catalog(JobCatalog::new())
1004 .disable_worker()
1005 .disable_reaper()
1006 .build();
1007 let Err(registry_then_catalog) = registry_then_catalog else {
1008 panic!("mixed registry sources should be rejected");
1009 };
1010 assert!(matches!(
1011 registry_then_catalog,
1012 RuntimeError::MixedRegistrySources
1013 ));
1014
1015 let catalog_then_registry = empty_builder(&pool)
1016 .with_catalog(JobCatalog::new())
1017 .with_registry(JobRegistry::new())
1018 .disable_worker()
1019 .disable_reaper()
1020 .build();
1021 let Err(catalog_then_registry) = catalog_then_registry else {
1022 panic!("mixed registry sources should be rejected");
1023 };
1024 assert!(matches!(
1025 catalog_then_registry,
1026 RuntimeError::MixedRegistrySources
1027 ));
1028 }
1029
1030 #[tokio::test]
1031 async fn builder_requires_registry_when_worker_or_reaper_is_enabled() {
1032 let pool = lazy_pool();
1033
1034 assert_eq!(missing_registry_flags(empty_builder(&pool)), (true, true));
1035 assert_eq!(
1036 missing_registry_flags(empty_builder(&pool).disable_scheduler().disable_reaper()),
1037 (true, false)
1038 );
1039 assert_eq!(
1040 missing_registry_flags(empty_builder(&pool).disable_worker().disable_scheduler()),
1041 (false, true)
1042 );
1043 }
1044
1045 #[test]
1046 fn builder_requires_tokio_runtime_before_cloning_pool() {
1047 let runtime = tokio::runtime::Runtime::new().expect("construct Tokio runtime");
1048 let pool = runtime.block_on(async { lazy_pool() });
1049 let error = match Supervisor::builder(&pool, test_config()) {
1050 Err(error) => error,
1051 Ok(builder) => {
1052 drop(builder);
1053 runtime.block_on(async {
1054 pool.close().await;
1055 });
1056 std::mem::forget(pool);
1057 panic!("missing Tokio runtime should be a builder error");
1058 }
1059 };
1060
1061 runtime.block_on(async {
1066 pool.close().await;
1067 });
1068 std::mem::forget(pool);
1069 match error {
1070 RuntimeError::MissingTokioRuntime { .. } => {}
1071 other => panic!("expected missing Tokio runtime error, got {other:?}"),
1072 }
1073 }
1074
1075 #[tokio::test]
1076 async fn builder_can_disable_each_loop() {
1077 let pool = lazy_pool();
1078 let builder = empty_builder(&pool)
1079 .disable_worker()
1080 .disable_scheduler()
1081 .disable_reaper();
1082
1083 assert!(!builder.worker_enabled);
1084 assert!(!builder.scheduler_enabled);
1085 assert!(!builder.reaper_enabled);
1086 }
1087
1088 #[tokio::test]
1089 async fn builder_spawns_only_enabled_tasks() {
1090 let pool = lazy_pool();
1091
1092 let all_disabled = empty_builder(&pool)
1093 .disable_worker()
1094 .disable_scheduler()
1095 .disable_reaper()
1096 .build()
1097 .expect("all-disabled supervisor should build");
1098 assert_eq!(task_names(&all_disabled), Vec::<&'static str>::new());
1099 abort_supervisor_tasks(all_disabled).await;
1100
1101 let scheduler_only = empty_builder(&pool)
1102 .disable_worker()
1103 .disable_reaper()
1104 .build()
1105 .expect("scheduler-only supervisor should not require registry");
1106 assert_eq!(task_names(&scheduler_only), vec![SCHEDULER_TASK]);
1107 abort_supervisor_tasks(scheduler_only).await;
1108
1109 let worker_only = empty_builder(&pool)
1110 .with_registry(JobRegistry::new())
1111 .disable_scheduler()
1112 .disable_reaper()
1113 .build()
1114 .expect("worker-only supervisor should build with registry");
1115 assert_eq!(task_names(&worker_only), vec![WORKER_TASK]);
1116 abort_supervisor_tasks(worker_only).await;
1117
1118 let reaper_only = empty_builder(&pool)
1119 .with_registry(JobRegistry::new())
1120 .disable_worker()
1121 .disable_scheduler()
1122 .build()
1123 .expect("reaper-only supervisor should build with registry");
1124 assert_eq!(task_names(&reaper_only), vec![REAPER_TASK]);
1125 abort_supervisor_tasks(reaper_only).await;
1126
1127 let all_enabled = empty_builder(&pool)
1128 .with_registry(JobRegistry::new())
1129 .build()
1130 .expect("all-enabled supervisor should build with registry");
1131 assert_eq!(
1132 task_names(&all_enabled),
1133 vec![WORKER_TASK, SCHEDULER_TASK, REAPER_TASK]
1134 );
1135 abort_supervisor_tasks(all_enabled).await;
1136 }
1137
1138 #[tokio::test]
1139 async fn all_disabled_supervisor_join_and_shutdown_succeed() {
1140 Supervisor::builder(&lazy_pool(), test_config())
1141 .expect("supervisor builder has runtime")
1142 .disable_worker()
1143 .disable_scheduler()
1144 .disable_reaper()
1145 .build()
1146 .expect("all-disabled supervisor should build")
1147 .join()
1148 .await
1149 .expect("all-disabled supervisor should join");
1150
1151 Supervisor::builder(&lazy_pool(), test_config())
1152 .expect("supervisor builder has runtime")
1153 .disable_worker()
1154 .disable_scheduler()
1155 .disable_reaper()
1156 .build()
1157 .expect("all-disabled supervisor should build")
1158 .shutdown()
1159 .await
1160 .expect("all-disabled supervisor should shut down");
1161 }
1162
1163 #[tokio::test]
1164 async fn shutdown_handle_can_request_shutdown_before_join() {
1165 let supervisor = Supervisor::builder(&lazy_pool(), test_config())
1166 .expect("supervisor builder has runtime")
1167 .disable_worker()
1168 .disable_scheduler()
1169 .disable_reaper()
1170 .build()
1171 .expect("all-disabled supervisor should build");
1172 let shutdown = supervisor.shutdown_handle();
1173 let cloned_shutdown = shutdown.clone();
1174
1175 cloned_shutdown.request_shutdown();
1176
1177 assert!(shutdown.is_shutdown_requested());
1178 assert!(supervisor.is_shutdown_requested());
1179 supervisor
1180 .join()
1181 .await
1182 .expect("supervisor should join after shutdown handle request");
1183 }
1184
1185 #[tokio::test]
1186 async fn shutdown_after_shutdown_handle_request_allows_clean_task_exit() {
1187 let (shutdown_tx, mut shutdown_rx) = watch::channel(false);
1188 let shutdown_requested = Arc::new(AtomicBool::new(false));
1189 let supervisor = supervisor_from_shutdown_channel(
1190 shutdown_tx,
1191 Arc::clone(&shutdown_requested),
1192 vec![test_shutdown_task("cooperative-loop", async move {
1193 while !*shutdown_rx.borrow() {
1194 if shutdown_rx.changed().await.is_err() {
1195 break;
1196 }
1197 }
1198 })],
1199 );
1200 let shutdown = supervisor.shutdown_handle();
1201
1202 shutdown.request_shutdown();
1203
1204 supervisor
1205 .shutdown()
1206 .await
1207 .expect("clean exit after requested shutdown should succeed");
1208 }
1209
1210 #[tokio::test]
1211 async fn run_until_shutdown_requests_shutdown_when_signal_resolves() {
1212 let (shutdown_tx, mut shutdown_rx) = watch::channel(false);
1213 let shutdown_requested = Arc::new(AtomicBool::new(false));
1214 let (signal_tx, signal_rx) = tokio::sync::oneshot::channel();
1215 let supervisor = supervisor_from_shutdown_channel(
1216 shutdown_tx,
1217 Arc::clone(&shutdown_requested),
1218 vec![test_shutdown_task(
1219 "run-until-cooperative-loop",
1220 async move {
1221 while !*shutdown_rx.borrow() {
1222 if shutdown_rx.changed().await.is_err() {
1223 break;
1224 }
1225 }
1226 },
1227 )],
1228 );
1229
1230 signal_tx.send(()).expect("signal receiver should be alive");
1231
1232 supervisor
1233 .run_until_shutdown(
1234 async move {
1235 signal_rx.await.expect("shutdown signal should be sent");
1236 },
1237 Duration::from_secs(1),
1238 )
1239 .await
1240 .expect("resolved shutdown signal should shut down cleanly");
1241 assert!(shutdown_requested.load(Ordering::SeqCst));
1242 }
1243
1244 #[tokio::test]
1245 async fn run_until_shutdown_with_no_tasks_waits_for_signal() {
1246 let supervisor = Supervisor::from_tasks_for_tests(Vec::new());
1247 let (signal_tx, signal_rx) = tokio::sync::oneshot::channel();
1248 let mut run = tokio::spawn(supervisor.run_until_shutdown(
1249 async move {
1250 signal_rx.await.expect("shutdown signal should be sent");
1251 },
1252 Duration::from_secs(1),
1253 ));
1254
1255 assert!(
1256 timeout(Duration::from_millis(50), &mut run).await.is_err(),
1257 "all-disabled supervisor should wait for the shutdown signal"
1258 );
1259
1260 signal_tx.send(()).expect("signal receiver should be alive");
1261 run.await
1262 .expect("run-until-shutdown task should join")
1263 .expect("all-disabled supervisor should complete after signal");
1264 }
1265
1266 #[tokio::test]
1267 async fn run_until_shutdown_reports_task_exit_before_signal() {
1268 let supervisor =
1269 Supervisor::from_tasks_for_tests(vec![test_task("run-until-early-loop", async {})]);
1270
1271 let error = timeout(
1272 Duration::from_secs(1),
1273 supervisor.run_until_shutdown(std::future::pending::<()>(), Duration::from_secs(1)),
1274 )
1275 .await
1276 .expect("task exit should be reported before external signal")
1277 .expect_err("early task exit should fail run-until shutdown");
1278
1279 match error {
1280 Error::Runtime(RuntimeError::TaskExitedUnexpectedly { task }) => {
1281 assert_eq!(task, "run-until-early-loop");
1282 }
1283 other => panic!("expected unexpected task exit, got {other:?}"),
1284 }
1285 }
1286
1287 #[tokio::test]
1288 async fn run_until_shutdown_times_out_and_aborts_after_signal() {
1289 let dropped = Arc::new(AtomicBool::new(false));
1290 let drop_flag = DropFlag(Arc::clone(&dropped));
1291 let supervisor =
1292 Supervisor::from_tasks_for_tests(vec![test_task("run-until-stubborn-loop", async {
1293 let _drop_flag = drop_flag;
1294 std::future::pending::<()>().await;
1295 })]);
1296
1297 let error = supervisor
1298 .run_until_shutdown(async {}, Duration::from_millis(50))
1299 .await
1300 .expect_err("stubborn task should time out after shutdown signal");
1301
1302 match error {
1303 Error::Runtime(RuntimeError::ShutdownTimeout { timeout }) => {
1304 assert_eq!(timeout, Duration::from_millis(50));
1305 }
1306 other => panic!("expected shutdown timeout error, got {other:?}"),
1307 }
1308 assert!(dropped.load(Ordering::SeqCst));
1309 }
1310
1311 #[tokio::test]
1312 async fn run_until_shutdown_reports_task_exit_after_signal_before_deadline() {
1313 let (shutdown_tx, mut shutdown_rx) = watch::channel(false);
1314 let shutdown_requested = Arc::new(AtomicBool::new(false));
1315 let supervisor = supervisor_from_shutdown_channel(
1316 shutdown_tx,
1317 Arc::clone(&shutdown_requested),
1318 vec![test_task("run-until-bad-shutdown-loop", async move {
1319 while !*shutdown_rx.borrow() {
1320 if shutdown_rx.changed().await.is_err() {
1321 break;
1322 }
1323 }
1324 })],
1325 );
1326
1327 let error = supervisor
1328 .run_until_shutdown(async {}, Duration::from_secs(1))
1329 .await
1330 .expect_err("task completion after signal should still be reported");
1331
1332 match error {
1333 Error::Runtime(RuntimeError::TaskExitedUnexpectedly { task }) => {
1334 assert_eq!(task, "run-until-bad-shutdown-loop");
1335 }
1336 other => panic!("expected unexpected task exit, got {other:?}"),
1337 }
1338 }
1339
1340 #[tokio::test]
1341 async fn dropping_supervisor_requests_shutdown_signal() {
1342 let (shutdown_tx, mut shutdown_rx) = watch::channel(false);
1343 let mut observed_shutdown = shutdown_rx.clone();
1344 let shutdown_requested = Arc::new(AtomicBool::new(false));
1345 let supervisor = supervisor_from_shutdown_channel(
1346 shutdown_tx,
1347 Arc::clone(&shutdown_requested),
1348 vec![test_shutdown_task("drop-shutdown-loop", async move {
1349 while !*shutdown_rx.borrow() {
1350 if shutdown_rx.changed().await.is_err() {
1351 break;
1352 }
1353 }
1354 })],
1355 );
1356
1357 drop(supervisor);
1358
1359 timeout(Duration::from_secs(1), observed_shutdown.changed())
1360 .await
1361 .expect("drop should promptly notify shutdown")
1362 .expect("shutdown sender should notify before closing");
1363 assert!(*observed_shutdown.borrow());
1364 assert!(shutdown_requested.load(Ordering::SeqCst));
1365 }
1366
1367 #[tokio::test]
1368 async fn join_reports_task_that_exited_before_late_shutdown_request() {
1369 let (shutdown_tx, _) = watch::channel(false);
1370 let shutdown_requested = Arc::new(AtomicBool::new(false));
1371 let supervisor = supervisor_from_shutdown_channel(
1372 shutdown_tx,
1373 Arc::clone(&shutdown_requested),
1374 vec![test_task("early-before-late-signal", async {})],
1375 );
1376
1377 while !supervisor.tasks[0].handle.is_finished() {
1378 tokio::task::yield_now().await;
1379 }
1380
1381 let shutdown = supervisor.shutdown_handle();
1382 shutdown.request_shutdown();
1383
1384 let error = supervisor
1385 .join()
1386 .await
1387 .expect_err("task exit before shutdown request should still be reported");
1388
1389 match error {
1390 Error::Runtime(RuntimeError::TaskExitedUnexpectedly { task }) => {
1391 assert_eq!(task, "early-before-late-signal");
1392 }
1393 other => panic!("expected unexpected task exit, got {other:?}"),
1394 }
1395 }
1396
1397 #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
1398 async fn join_reports_task_exit_when_shutdown_races_completion_poll() {
1399 let (entered_tx, entered_rx) = std::sync::mpsc::channel();
1400 let (release_tx, release_rx) = std::sync::mpsc::channel();
1401 let (shutdown_tx, _) = watch::channel(false);
1402 let shutdown_requested = Arc::new(AtomicBool::new(false));
1403 let supervisor = supervisor_from_shutdown_channel(
1404 shutdown_tx,
1405 Arc::clone(&shutdown_requested),
1406 vec![RuntimeTask::spawn(
1407 "race-completion",
1408 CompleteAfterPollSignal {
1409 entered_tx: Some(entered_tx),
1410 release_rx,
1411 exit: RuntimeTaskExit::Completed,
1412 },
1413 )],
1414 );
1415 entered_rx
1416 .recv_timeout(Duration::from_secs(1))
1417 .expect("task should enter its completion poll");
1418 let shutdown = supervisor.shutdown_handle();
1419
1420 shutdown.request_shutdown();
1421 release_tx
1422 .send(())
1423 .expect("completion poll release should be received");
1424
1425 let error = supervisor
1426 .join()
1427 .await
1428 .expect_err("task exit that began before shutdown should be reported");
1429
1430 match error {
1431 Error::Runtime(RuntimeError::TaskExitedUnexpectedly { task }) => {
1432 assert_eq!(task, "race-completion");
1433 }
1434 other => panic!("expected unexpected task exit, got {other:?}"),
1435 }
1436 }
1437
1438 #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
1439 async fn join_allows_shutdown_exit_when_shutdown_races_completion_poll() {
1440 let (entered_tx, entered_rx) = std::sync::mpsc::channel();
1441 let (release_tx, release_rx) = std::sync::mpsc::channel();
1442 let (shutdown_tx, _) = watch::channel(false);
1443 let shutdown_requested = Arc::new(AtomicBool::new(false));
1444 let supervisor = supervisor_from_shutdown_channel(
1445 shutdown_tx,
1446 Arc::clone(&shutdown_requested),
1447 vec![RuntimeTask::spawn(
1448 "shutdown-race-completion",
1449 CompleteAfterPollSignal {
1450 entered_tx: Some(entered_tx),
1451 release_rx,
1452 exit: RuntimeTaskExit::Shutdown,
1453 },
1454 )],
1455 );
1456 entered_rx
1457 .recv_timeout(Duration::from_secs(1))
1458 .expect("task should enter its completion poll");
1459 let shutdown = supervisor.shutdown_handle();
1460
1461 shutdown.request_shutdown();
1462 release_tx
1463 .send(())
1464 .expect("completion poll release should be received");
1465
1466 supervisor
1467 .join()
1468 .await
1469 .expect("task that reports shutdown should join cleanly");
1470 }
1471
1472 #[tokio::test]
1473 async fn panic_after_shutdown_request_is_reported() {
1474 let (shutdown_tx, mut shutdown_rx) = watch::channel(false);
1475 let shutdown_requested = Arc::new(AtomicBool::new(false));
1476 let supervisor = supervisor_from_shutdown_channel(
1477 shutdown_tx,
1478 Arc::clone(&shutdown_requested),
1479 vec![test_shutdown_task("panic-after-shutdown", async move {
1480 while !*shutdown_rx.borrow() {
1481 if shutdown_rx.changed().await.is_err() {
1482 return;
1483 }
1484 }
1485 panic!("forced post-shutdown panic");
1486 })],
1487 );
1488 let shutdown = supervisor.shutdown_handle();
1489
1490 shutdown.request_shutdown();
1491
1492 let error = supervisor
1493 .shutdown()
1494 .await
1495 .expect_err("panic after requested shutdown should fail");
1496
1497 match error {
1498 Error::Runtime(RuntimeError::TaskJoin { task, source }) => {
1499 assert_eq!(task, "panic-after-shutdown");
1500 assert!(source.is_panic());
1501 }
1502 other => panic!("expected task join error, got {other:?}"),
1503 }
1504 }
1505
1506 #[tokio::test]
1507 async fn early_normal_task_exit_is_unexpected() {
1508 let supervisor = Supervisor::from_tasks_for_tests(vec![test_task("test-loop", async {})]);
1509
1510 let error = supervisor
1511 .join()
1512 .await
1513 .expect_err("early normal exit should fail");
1514
1515 match error {
1516 Error::Runtime(RuntimeError::TaskExitedUnexpectedly { task }) => {
1517 assert_eq!(task, "test-loop");
1518 }
1519 other => panic!("expected unexpected task exit, got {other:?}"),
1520 }
1521 }
1522
1523 #[tokio::test]
1524 async fn shutdown_reports_task_that_exited_before_shutdown_request() {
1525 let supervisor = Supervisor::from_tasks_for_tests(vec![test_task("early-loop", async {})]);
1526
1527 while !supervisor.tasks[0].handle.is_finished() {
1528 tokio::task::yield_now().await;
1529 }
1530
1531 let error = supervisor
1532 .shutdown()
1533 .await
1534 .expect_err("pre-shutdown task exit should fail");
1535
1536 match error {
1537 Error::Runtime(RuntimeError::TaskExitedUnexpectedly { task }) => {
1538 assert_eq!(task, "early-loop");
1539 }
1540 other => panic!("expected unexpected task exit, got {other:?}"),
1541 }
1542 }
1543
1544 #[tokio::test]
1545 async fn pre_shutdown_sweep_consumes_all_already_finished_tasks() {
1546 let mut tasks = vec![
1547 test_task("finished-a", async {}),
1548 test_task("pending", async {
1549 std::future::pending::<()>().await;
1550 }),
1551 test_task("finished-b", async {}),
1552 ];
1553
1554 while tasks
1555 .iter()
1556 .filter(|task| task.name != "pending")
1557 .any(|task| !task.handle.is_finished())
1558 {
1559 tokio::task::yield_now().await;
1560 }
1561
1562 let error = join_pre_shutdown_finished_tasks(&mut tasks)
1563 .await
1564 .expect("finished tasks should produce a pre-shutdown error");
1565
1566 match error {
1567 RuntimeError::TaskExitedUnexpectedly { task } => {
1568 assert!(
1569 matches!(task, "finished-a" | "finished-b"),
1570 "unexpected first finished task: {task}"
1571 );
1572 }
1573 other => panic!("expected unexpected task exit, got {other:?}"),
1574 }
1575 assert_eq!(tasks.len(), 1);
1576 assert_eq!(tasks[0].name, "pending");
1577
1578 let pending = tasks.pop().expect("pending task remains");
1579 pending.handle.abort();
1580 let _ = pending.handle.await;
1583 }
1584
1585 #[tokio::test]
1586 async fn pre_shutdown_sweep_allows_explicit_shutdown_exit() {
1587 let mut tasks = vec![test_shutdown_task("finished-after-signal", async {})];
1588 while !tasks[0].handle.is_finished() {
1589 tokio::task::yield_now().await;
1590 }
1591
1592 let error = join_pre_shutdown_finished_tasks(&mut tasks).await;
1593
1594 assert!(error.is_none());
1595 assert!(tasks.is_empty());
1596 }
1597
1598 #[tokio::test]
1599 async fn shutdown_with_timeout_aborts_and_drains_stubborn_task() {
1600 let dropped = Arc::new(AtomicBool::new(false));
1601 let drop_flag = DropFlag(Arc::clone(&dropped));
1602 let supervisor =
1603 Supervisor::from_tasks_for_tests(vec![test_task("stubborn-loop", async move {
1604 let _drop_flag = drop_flag;
1605 std::future::pending::<()>().await;
1606 })]);
1607
1608 let error = supervisor
1609 .shutdown_with_timeout(Duration::from_millis(50))
1610 .await
1611 .expect_err("stubborn task should time out shutdown");
1612
1613 match error {
1614 Error::Runtime(RuntimeError::ShutdownTimeout { timeout }) => {
1615 assert_eq!(timeout, Duration::from_millis(50));
1616 }
1617 other => panic!("expected shutdown timeout error, got {other:?}"),
1618 }
1619 assert!(dropped.load(Ordering::SeqCst));
1620 }
1621
1622 #[tokio::test]
1623 async fn shutdown_with_timeout_rejects_unrepresentable_deadline() {
1624 let error = Supervisor::from_tasks_for_tests(Vec::new())
1625 .shutdown_with_timeout(Duration::MAX)
1626 .await
1627 .expect_err("unrepresentable timeout should fail instead of panicking");
1628
1629 match error {
1630 Error::Runtime(RuntimeError::ShutdownTimeoutTooLarge { timeout }) => {
1631 assert_eq!(timeout, Duration::MAX);
1632 }
1633 other => panic!("expected oversized timeout error, got {other:?}"),
1634 }
1635 }
1636
1637 #[tokio::test]
1638 async fn shutdown_with_zero_timeout_aborts_immediately() {
1639 let supervisor =
1640 Supervisor::from_tasks_for_tests(vec![test_task("zero-timeout-pending-loop", async {
1641 std::future::pending::<()>().await;
1642 })]);
1643
1644 let error = supervisor
1645 .shutdown_with_timeout(Duration::ZERO)
1646 .await
1647 .expect_err("zero timeout should report an immediate shutdown timeout");
1648
1649 match error {
1650 Error::Runtime(RuntimeError::ShutdownTimeout { timeout }) => {
1651 assert_eq!(timeout, Duration::ZERO);
1652 }
1653 other => panic!("expected shutdown timeout error, got {other:?}"),
1654 }
1655 }
1656
1657 #[tokio::test]
1658 async fn shutdown_with_timeout_succeeds_when_task_exits_cooperatively() {
1659 let (shutdown_tx, mut shutdown_rx) = watch::channel(false);
1660 let shutdown_requested = Arc::new(AtomicBool::new(false));
1661 let supervisor = supervisor_from_shutdown_channel(
1662 shutdown_tx,
1663 Arc::clone(&shutdown_requested),
1664 vec![test_shutdown_task("cooperative-timeout-loop", async move {
1665 while !*shutdown_rx.borrow() {
1666 if shutdown_rx.changed().await.is_err() {
1667 break;
1668 }
1669 }
1670 })],
1671 );
1672
1673 supervisor
1674 .shutdown_with_timeout(Duration::from_secs(1))
1675 .await
1676 .expect("cooperative task should shut down before timeout");
1677 }
1678
1679 #[tokio::test]
1680 async fn shutdown_with_timeout_pre_shutdown_error_allows_remaining_task_to_exit() {
1681 let (shutdown_tx, mut shutdown_rx) = watch::channel(false);
1682 let shutdown_requested = Arc::new(AtomicBool::new(false));
1683 let dropped = Arc::new(AtomicBool::new(false));
1684 let drop_flag = DropFlag(Arc::clone(&dropped));
1685 let tasks = vec![
1686 test_task("finished-before-shutdown", async {}),
1687 test_shutdown_task("cooperative-after-error", async move {
1688 let _drop_flag = drop_flag;
1689 while !*shutdown_rx.borrow() {
1690 if shutdown_rx.changed().await.is_err() {
1691 break;
1692 }
1693 }
1694 }),
1695 ];
1696
1697 while !tasks[0].handle.is_finished() {
1698 tokio::task::yield_now().await;
1699 }
1700
1701 let supervisor =
1702 supervisor_from_shutdown_channel(shutdown_tx, Arc::clone(&shutdown_requested), tasks);
1703 let error = supervisor
1704 .shutdown_with_timeout(Duration::from_secs(1))
1705 .await
1706 .expect_err("pre-shutdown task exit should fail");
1707
1708 match error {
1709 Error::Runtime(RuntimeError::TaskExitedUnexpectedly { task }) => {
1710 assert_eq!(task, "finished-before-shutdown");
1711 }
1712 other => panic!("expected pre-shutdown task exit, got {other:?}"),
1713 }
1714 assert!(dropped.load(Ordering::SeqCst));
1715 }
1716
1717 #[tokio::test]
1718 async fn shutdown_with_timeout_reports_timeout_after_pre_shutdown_error() {
1719 let tasks = vec![
1720 test_task("finished-before-shutdown", async {}),
1721 test_task("pending-after-error", async {
1722 std::future::pending::<()>().await;
1723 }),
1724 ];
1725
1726 while !tasks[0].handle.is_finished() {
1727 tokio::task::yield_now().await;
1728 }
1729
1730 let error = Supervisor::from_tasks_for_tests(tasks)
1731 .shutdown_with_timeout(Duration::from_millis(1))
1732 .await
1733 .expect_err("pre-shutdown task exit with stuck drain should time out");
1734
1735 match error {
1736 Error::Runtime(RuntimeError::ShutdownTimeoutAfterTaskError { timeout, source }) => {
1737 assert_eq!(timeout, Duration::from_millis(1));
1738 match *source {
1739 RuntimeError::TaskExitedUnexpectedly { task } => {
1740 assert_eq!(task, "finished-before-shutdown");
1741 }
1742 other => panic!("expected pre-shutdown task exit source, got {other:?}"),
1743 }
1744 }
1745 other => panic!("expected shutdown timeout after task error, got {other:?}"),
1746 }
1747 }
1748
1749 #[tokio::test]
1750 async fn shutdown_with_timeout_reports_task_error_when_remaining_task_misses_deadline() {
1751 let (shutdown_tx, mut shutdown_rx) = watch::channel(false);
1752 let shutdown_requested = Arc::new(AtomicBool::new(false));
1753 let supervisor = supervisor_from_shutdown_channel(
1754 shutdown_tx,
1755 Arc::clone(&shutdown_requested),
1756 vec![
1757 test_shutdown_task("panic-after-timeout-shutdown", async move {
1758 while !*shutdown_rx.borrow() {
1759 if shutdown_rx.changed().await.is_err() {
1760 return;
1761 }
1762 }
1763 panic!("forced live shutdown panic");
1764 }),
1765 test_shutdown_task("pending-after-timeout-panic", async {
1766 std::future::pending::<()>().await;
1767 }),
1768 ],
1769 );
1770
1771 let error = supervisor
1772 .shutdown_with_timeout(Duration::from_millis(50))
1773 .await
1774 .expect_err("task failure with stuck drain should preserve task error source");
1775
1776 match error {
1777 Error::Runtime(RuntimeError::ShutdownTimeoutAfterTaskError { timeout, source }) => {
1778 assert_eq!(timeout, Duration::from_millis(50));
1779 match *source {
1780 RuntimeError::TaskJoin { task, source } => {
1781 assert_eq!(task, "panic-after-timeout-shutdown");
1782 assert!(source.is_panic());
1783 }
1784 other => panic!("expected task join source, got {other:?}"),
1785 }
1786 }
1787 other => panic!("expected timeout after task join error, got {other:?}"),
1788 }
1789 }
1790
1791 #[tokio::test]
1792 async fn panicked_task_maps_to_task_join_error() {
1793 let supervisor = Supervisor::from_tasks_for_tests(vec![test_task("panic-loop", async {
1794 panic!("forced supervisor test panic");
1795 })]);
1796
1797 let error = supervisor
1798 .join()
1799 .await
1800 .expect_err("panicked task should fail");
1801
1802 match error {
1803 Error::Runtime(RuntimeError::TaskJoin { task, source }) => {
1804 assert_eq!(task, "panic-loop");
1805 assert!(source.is_panic());
1806 }
1807 other => panic!("expected task join error, got {other:?}"),
1808 }
1809 }
1810}