1use std::{
135 fmt::{self, Debug, Formatter},
136 pin::Pin,
137 sync::{Arc, RwLock},
138 task::{Context, Poll},
139};
140
141use futures_util::{
142 Future, FutureExt, StreamExt,
143 future::{BoxFuture, Shared},
144};
145use tower_layer::Layer;
146use tower_service::Service;
147
148use crate::{
149 backend::Backend,
150 error::{BoxDynError, WorkerError},
151 monitor::shutdown::Shutdown,
152 task::Task,
153 worker::{
154 ReadinessService, TrackerService, Worker,
155 context::WorkerContext,
156 event::{Event, EventHandlerBuilder},
157 },
158};
159
160pub mod shutdown;
161
162type WorkerFactory = Box<
163 dyn Fn(usize) -> (WorkerContext, BoxFuture<'static, Result<(), WorkerError>>)
164 + 'static
165 + Send
166 + Sync,
167>;
168
169type ShouldRestart = Arc<
170 RwLock<
171 Option<Box<dyn Fn(&WorkerContext, &WorkerError, usize) -> bool + 'static + Send + Sync>>,
172 >,
173>;
174
175type CurrentWorker = Option<(
176 WorkerContext,
177 Shared<BoxFuture<'static, Result<(), Arc<WorkerError>>>>,
178)>;
179
/// A restartable worker slot, driven as a single future by the [`Monitor`].
///
/// It holds the factory that builds worker instances and the instance currently being
/// polled (if any). It also holds the attempt counter passed to the factory and the shared
/// restart-policy callback.
#[pin_project::pin_project]
struct MonitoredWorker {
    // Builds a new (context, run future) pair for a given attempt number.
    factory: WorkerFactory,
    // The worker instance currently being driven. It is `None` before the first poll and
    // between a failure and the restart that replaces it.
    #[pin]
    current: CurrentWorker,
    // Number of restarts performed so far; passed to `factory` and to the restart callback.
    attempt: usize,
    // Arc shared with the owning `Monitor`, so a policy set after registration still applies.
    should_restart: ShouldRestart,
}
189
190#[derive(Debug, Clone)]
192pub struct MonitoredWorkerError {
193 ctx: WorkerContext,
194 error: Arc<WorkerError>,
195}
196
/// Drives the monitored worker, rebuilding it through its factory when the restart policy
/// allows.
///
/// Resolves with `Ok(())` when the current worker instance completes successfully. It
/// resolves with `Err(MonitoredWorkerError)` when an instance fails, or panics while being
/// polled, and no restart is granted. That covers three cases: no callback is set, the
/// callback lock is poisoned, or the callback returns `false`.
impl Future for MonitoredWorker {
    type Output = Result<(), MonitoredWorkerError>;

    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        use std::panic::{AssertUnwindSafe, catch_unwind};
        let mut this = self.project();

        // Each iteration either returns or, after an approved restart, builds the next worker.
        loop {
            use futures_util::TryFutureExt;
            // Lazily build a worker instance: on the first poll, or right after a restart.
            if this.current.is_none() {
                let (ctx, worker) = (this.factory)(*this.attempt);
                this.current.set(Some((
                    ctx,
                    worker
                        .map_err(|e: WorkerError| Arc::new(e))
                        .boxed()
                        .shared(),
                )));
            }

            let mut current = this.current.as_mut().as_pin_mut().unwrap();
            // Stop a worker that is still running once its context reports it is shutting
            // down.
            // NOTE(review): `unwrap()` panics inside `poll` if `stop()` returns `Err` —
            // confirm `stop()` cannot fail in this state.
            if current.0.is_running() && current.0.is_shutting_down() {
                let ctx = current.0.clone();
                ctx.stop().unwrap();
            }
            // Convert a panic raised while polling the worker into a `WorkerError::PanicError`
            // that carries the panic message (when it is a `&str` or `String`).
            let poll_result =
                catch_unwind(AssertUnwindSafe(|| current.1.poll_unpin(cx))).map_err(|err| {
                    let err = if let Some(s) = err.downcast_ref::<&str>() {
                        (*s).to_owned()
                    } else if let Some(s) = err.downcast_ref::<String>() {
                        s.clone()
                    } else {
                        "Unknown panic".to_owned()
                    };
                    Arc::new(WorkerError::PanicError(err))
                });

            match poll_result {
                Ok(Poll::Pending) => return Poll::Pending,
                Ok(Poll::Ready(Ok(()))) => return Poll::Ready(Ok(())),
                // A returned error and a caught panic are handled the same way.
                Ok(Poll::Ready(Err(e))) | Err(e) => {
                    // Drop the failed instance so the next loop iteration builds a fresh one.
                    let (ctx, _) = this.current.take().unwrap();
                    // NOTE(review): panics if `stop()` returns `Err`, e.g. if the worker
                    // already stopped itself — confirm against `WorkerContext::stop`.
                    ctx.stop().unwrap();
                    let should_restart = this.should_restart.read();
                    match should_restart.as_ref().map(|s| s.as_ref()) {
                        // A restart callback is set: ask it whether to try again.
                        Ok(Some(cb)) => {
                            if !(cb)(&ctx, &e, *this.attempt) {
                                return Poll::Ready(Err(MonitoredWorkerError {
                                    ctx,
                                    error: Arc::clone(&e),
                                }));
                            }
                            *this.attempt += 1;
                        }
                        // No callback, or the lock is poisoned: give up on this worker.
                        _ => {
                            return Poll::Ready(Err(MonitoredWorkerError {
                                ctx,
                                error: Arc::clone(&e),
                            }));
                        }
                    }
                }
            }
        }
    }
}
263
/// Supervises a set of registered workers.
///
/// It runs them concurrently, optionally restarts failed ones, dispatches their events to
/// user handlers and coordinates shutdown.
#[derive(Default)]
pub struct Monitor {
    // One restartable slot per `register` call.
    workers: Vec<MonitoredWorker>,
    // Optional future that, once shutdown has been signalled, bounds how long workers may
    // keep running (see `with_terminator` / `shutdown_timeout`).
    terminator: Option<Shared<BoxFuture<'static, ()>>>,
    // Shutdown handle, cloned into every worker context created by `register`.
    shutdown: Shutdown,
    // Event handler(s) set via `on_event`, read by every worker's event listener.
    event_handler: EventHandlerBuilder,
    // Restart policy set via `should_restart`, shared with every registered worker.
    should_restart: ShouldRestart,
}
273
274impl Debug for Monitor {
275 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
276 f.debug_struct("Monitor")
277 .field("shutdown", &"[Graceful shutdown listener]")
278 .field("workers", &self.workers.len())
279 .finish()
280 }
281}
282
283impl Monitor {
284 fn run_worker<Args, S, B, M>(
285 mut ctx: WorkerContext,
286 worker: Worker<Args, B::Context, B, S, M>,
287 ) -> BoxFuture<'static, Result<(), WorkerError>>
288 where
289 S: Service<Task<Args, B::Context, B::IdType>> + Send + 'static,
290 S::Future: Send,
291 S::Error: Send + Sync + 'static + Into<BoxDynError>,
292 B: Backend<Args = Args> + Send + 'static,
293 B::Error: Into<BoxDynError> + Send + 'static,
294 B::Layer: Layer<ReadinessService<TrackerService<S>>> + 'static,
295 M: Layer<<<B as Backend>::Layer as Layer<ReadinessService<TrackerService<S>>>>::Service> + 'static,
296 <M as Layer<
297 <<B as Backend>::Layer as Layer<ReadinessService<TrackerService<S>>>>::Service,
298 >>::Service: Service<Task<Args, B::Context, B::IdType>> + Send + 'static,
299 <<M as Layer<<B::Layer as Layer<ReadinessService<TrackerService<S>>>>::Service>>::Service as Service<Task<Args, B::Context, B::IdType>>>::Future: Send,
300 <<M as Layer<<B::Layer as Layer<ReadinessService<TrackerService<S>>>>::Service>>::Service as Service<Task<Args, B::Context, B::IdType>>>::Error: Into<BoxDynError> + Send + Sync + 'static,
301 B::Stream: Unpin + Send + 'static,
302 B::Beat: Unpin + Send,
303 Args: Send + 'static,
304 B::Context: Send + 'static,
305 B::IdType: Sync + Send + 'static,
306 {
307 let mut stream = worker.stream_with_ctx(&mut ctx);
308 async move {
309 loop {
310 match stream.next().await {
311 Some(Err(e)) => return Err(e),
312 None => return Ok(()),
313 _ => (),
314 }
315 }
316 }
317 .boxed()
318 }
319 #[must_use]
340 pub fn register<Args, S, B, M>(
341 mut self,
342 factory: impl Fn(usize) -> Worker<Args, B::Context, B, S, M> + 'static + Send + Sync,
343 ) -> Self
344 where
345 S: Service<Task<Args, B::Context, B::IdType>> + Send + 'static,
346 S::Future: Send,
347 S::Error: Send + Sync + 'static + Into<BoxDynError>,
348 B: Backend<Args = Args> + Send + 'static,
349 B::Error: Into<BoxDynError> + Send + 'static,
350 B::Stream: Unpin + Send + 'static,
351 B::Beat: Unpin + Send,
352 Args: Send + 'static,
353 B::Context: Send + 'static,
354 B::Layer: Layer<ReadinessService<TrackerService<S>>> + 'static,
355 M: Layer<<<B as Backend>::Layer as Layer<ReadinessService<TrackerService<S>>>>::Service> + 'static,
356 <M as Layer<
357 <<B as Backend>::Layer as Layer<ReadinessService<TrackerService<S>>>>::Service,
358 >>::Service: Service<Task<Args, B::Context, B::IdType>> + Send + 'static,
359 <<M as Layer<<B::Layer as Layer<ReadinessService<TrackerService<S>>>>::Service>>::Service as Service<Task<Args, B::Context, B::IdType>>>::Future: Send,
360 <<M as Layer<<B::Layer as Layer<ReadinessService<TrackerService<S>>>>::Service>>::Service as Service<Task<Args, B::Context, B::IdType>>>::Error:
361 Into<BoxDynError> + Send + Sync + 'static,
362 B::IdType: Send + Sync + 'static,
363 {
364 let shutdown = Some(self.shutdown.clone());
365 let handler = self.event_handler.clone();
366 let should_restart = self.should_restart.clone();
367 let worker = MonitoredWorker {
368 current: None,
369 factory: Box::new(move |attempt| {
370 let new_worker = factory(attempt);
371 let id = Arc::new(new_worker.name.clone());
372 let mut ctx = WorkerContext::new::<M::Service>(&id);
373 let handler = handler.clone();
374 ctx.wrap_listener(move |ctx, ev| {
375 let handlers = handler.read();
376 if let Ok(handlers) = handlers {
377 for h in handlers.iter() {
378 h(ctx, ev);
379 }
380 }
381 });
382 ctx.shutdown = shutdown.clone();
383 (ctx.clone(), Self::run_worker(ctx.clone(), new_worker))
384 }),
385 attempt: 0,
386 should_restart,
387 };
388 self.workers.push(worker);
389 self
390 }
391
392 pub async fn run_with_signal<S>(self, signal: S) -> Result<(), MonitorError>
408 where
409 S: Send + Future<Output = std::io::Result<()>>,
410 {
411 let shutdown = self.shutdown.clone();
412 let shutdown_after = self.shutdown.shutdown_after(signal);
413 if let Some(terminator) = self.terminator {
414 let _res = futures_util::future::select(
415 Self::run_all_workers(self.workers, shutdown).boxed(),
416 async {
417 let res = shutdown_after.await;
418 terminator.await;
419 res.map_err(MonitorError::ShutdownSignal)
420 }
421 .boxed(),
422 )
423 .await;
424 } else {
425 let runner = self.run();
426 let res = futures_util::join!(shutdown_after, runner); match res {
428 (Ok(_), Ok(_)) => {
429 }
431 (Err(e), Ok(_) | Err(_)) => return Err(e.into()),
432 (Ok(_), Err(e)) => return Err(e),
433 }
434 }
435 Ok(())
436 }
437
438 pub async fn run(self) -> Result<(), MonitorError> {
448 let shutdown = self.shutdown.clone();
449 let shutdown_future = self.shutdown.boxed().map(|_| ());
450 let (result, _) = futures_util::join!(
451 Self::run_all_workers(self.workers, shutdown),
452 shutdown_future,
453 );
454
455 result
456 }
457 async fn run_all_workers(
458 workers: Vec<MonitoredWorker>,
459 shutdown: Shutdown,
460 ) -> Result<(), MonitorError> {
461 let results = futures_util::future::join_all(workers).await;
462
463 shutdown.start_shutdown();
464
465 let mut errors = Vec::new();
466 for r in results {
468 match r {
469 Ok(_) => {}
470 Err(MonitoredWorkerError { ctx, error }) => match &*error {
471 WorkerError::GracefulExit => {}
472 _ => errors.push(MonitoredWorkerError {
473 ctx,
474 error: Arc::clone(&error),
475 }),
476 },
477 }
478 }
479 if !errors.is_empty() {
480 return Err(MonitorError::ExitError(ExitError(errors)));
481 }
482 Ok(())
483 }
484
485 #[must_use]
487 pub fn on_event<F: Fn(&WorkerContext, &Event) + Send + Sync + 'static>(self, f: F) -> Self {
488 let _ = self.event_handler.write().map(|mut res| {
489 let _ = res.insert(Box::new(f));
490 });
491 self
492 }
493}
494
495impl Monitor {
496 #[must_use]
502 pub fn new() -> Self {
503 Self::default()
504 }
505
506 #[cfg(feature = "sleep")]
516 #[must_use]
517 pub fn shutdown_timeout(self, duration: std::time::Duration) -> Self {
518 self.with_terminator(crate::timer::sleep(duration))
519 }
520
521 #[must_use]
529 pub fn with_terminator(mut self, fut: impl Future<Output = ()> + Send + 'static) -> Self {
530 self.terminator = Some(fut.boxed().shared());
531 self
532 }
533
534 #[must_use]
536 pub fn should_restart<F>(self, cb: F) -> Self
537 where
538 F: Fn(&WorkerContext, &WorkerError, usize) -> bool + Send + Sync + 'static,
539 {
540 let _ = self.should_restart.write().map(|mut res| {
541 let _ = res.insert(Box::new(cb));
542 });
543 self
544 }
545}
546
/// Errors returned by [`Monitor::run`] and [`Monitor::run_with_signal`].
#[derive(Debug, thiserror::Error)]
pub enum MonitorError {
    /// One or more workers stopped with a non-graceful error.
    #[error("Worker errors:\n{0}")]
    ExitError(#[from] ExitError),

    /// The shutdown signal future itself failed.
    #[error("Shutdown signal error: {0}")]
    ShutdownSignal(#[from] std::io::Error),
}
558
559impl fmt::Debug for ExitError {
560 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> {
561 std::fmt::Display::fmt(&self, f)
562 }
563}
564
/// The set of workers that stopped with non-graceful errors during a monitor run.
// `Display` and `Debug` are implemented manually below; no `#[error]` attribute is given.
#[derive(thiserror::Error)]
pub struct ExitError(pub Vec<MonitoredWorkerError>);
568
569impl std::fmt::Display for ExitError {
570 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
571 writeln!(f, "MonitoredErrors:")?;
572 for worker in &self.0 {
573 writeln!(f, " - Worker `{}`: {}", worker.ctx.name(), worker.error)?;
574 }
575 Ok(())
576 }
577}
578
#[cfg(test)]
#[cfg(feature = "json")]
mod tests {
    use super::*;
    use crate::{
        backend::{TaskSink, json::JsonStorage},
        task::task_id::TaskId,
        worker::context::WorkerContext,
    };
    use core::panic;
    use std::{io, time::Duration};

    use tokio::time::sleep;
    use tower::limit::ConcurrencyLimitLayer;

    use crate::{monitor::Monitor, worker::builder::WorkerBuilder};

    /// A single healthy worker runs until shutdown is triggered externally.
    #[tokio::test]
    async fn basic_with_workers() {
        let mut backend = JsonStorage::new_temp().unwrap();

        for i in 0..10 {
            backend.push(i).await.unwrap();
        }

        let monitor: Monitor = Monitor::new();
        let monitor = monitor.register(move |index| {
            WorkerBuilder::new(format!("rango-tango-{index}"))
                .backend(backend.clone())
                .build(move |r: u32, id: TaskId, w: WorkerContext| async move {
                    println!("{id:?}, {}", w.name());
                    tokio::time::sleep(Duration::from_secs(index as u64)).await;
                    Ok::<_, io::Error>(r)
                })
        });
        let shutdown = monitor.shutdown.clone();
        tokio::spawn(async move {
            sleep(Duration::from_millis(1500)).await;
            shutdown.start_shutdown();
        });
        monitor.run().await.unwrap();
    }

    /// A panicking worker is restarted until the policy refuses, and the run then fails.
    #[tokio::test]
    async fn test_monitor_run() {
        // Use a throwaway store. A hard-coded machine-specific temp path made this test fail
        // on every machine but the one it was written on.
        let mut backend = JsonStorage::new_temp().unwrap();

        for i in 0..10 {
            backend.push(i).await.unwrap();
        }

        let monitor: Monitor = Monitor::new()
            .register(move |index| {
                WorkerBuilder::new(format!("rango-tango-{index}"))
                    .backend(backend.clone())
                    .build(move |r| async move {
                        if r % 2 == 0 {
                            panic!("Brrr")
                        }
                    })
            })
            .should_restart(|ctx, e, index| {
                println!(
                    "Encountered error in {} with {e:?} for attempt {index}",
                    ctx.name()
                );
                if index > 3 {
                    return false;
                }
                true
            })
            .on_event(|wrk, e| {
                println!("{}: {e:?}", wrk.name());
            });
        assert_eq!(monitor.workers.len(), 1);
        let shutdown = monitor.shutdown.clone();

        tokio::spawn(async move {
            sleep(Duration::from_millis(5000)).await;
            shutdown.start_shutdown();
        });

        let result = monitor.run().await;
        assert!(
            result.is_err_and(|e| matches!(e, MonitorError::ExitError(_))),
            "Monitor did not return an error as expected"
        );
    }

    /// Two independently registered workers run side by side and shut down cleanly.
    #[tokio::test]
    async fn test_monitor_register_multiple() {
        let mut backend = JsonStorage::new_temp().unwrap();

        for i in 0..10 {
            backend.push(i).await.unwrap();
        }

        let monitor: Monitor = Monitor::new();

        assert_send_sync::<Monitor>();

        let monitor = monitor.on_event(|wrk, e| {
            println!("{:?}, {e:?}", wrk.name());
        });
        let b = backend.clone();
        let monitor = monitor
            .register(move |index| {
                WorkerBuilder::new(format!("worker0-{index}"))
                    .backend(backend.clone())
                    .layer(ConcurrencyLimitLayer::new(1))
                    .build(
                        move |request: i32, id: TaskId, w: WorkerContext| async move {
                            println!("{id:?}, {}", w.name());
                            tokio::time::sleep(Duration::from_secs(index as u64)).await;
                            Ok::<_, io::Error>(request)
                        },
                    )
            })
            .register(move |index| {
                WorkerBuilder::new(format!("worker1-{index}"))
                    .backend(b.clone())
                    .layer(ConcurrencyLimitLayer::new(1))
                    .build(
                        move |request: i32, id: TaskId, w: WorkerContext| async move {
                            println!("{id:?}, {}", w.name());
                            tokio::time::sleep(Duration::from_secs(index as u64)).await;
                            Ok::<_, io::Error>(request)
                        },
                    )
            });
        assert_eq!(monitor.workers.len(), 2);
        let shutdown = monitor.shutdown.clone();

        tokio::spawn(async move {
            sleep(Duration::from_millis(5000)).await;
            shutdown.start_shutdown();
        });

        let result = monitor.run().await;
        assert!(result.is_ok());
    }

    /// Compile-time check that `T` is `Send + Sync`.
    fn assert_send_sync<T: Send + Sync>() {}
}