1use std::{
135 fmt::{self, Debug, Formatter},
136 pin::Pin,
137 sync::{Arc, RwLock},
138 task::{Context, Poll},
139};
140
141use futures_util::{
142 Future, FutureExt, StreamExt,
143 future::{BoxFuture, Shared},
144};
145use tower_layer::Layer;
146use tower_service::Service;
147
148use crate::{
149 backend::Backend,
150 error::{BoxDynError, WorkerError},
151 monitor::shutdown::Shutdown,
152 task::Task,
153 worker::{
154 ReadinessService, TrackerService, Worker,
155 context::WorkerContext,
156 event::{Event, EventHandlerBuilder},
157 },
158};
159
160pub mod shutdown;
161
162#[pin_project::pin_project]
163struct MonitoredWorker {
165 factory: Box<
166 dyn Fn(usize) -> (WorkerContext, BoxFuture<'static, Result<(), WorkerError>>)
167 + 'static
168 + Send
169 + Sync,
170 >,
171 #[pin]
172 current: Option<(
173 WorkerContext,
174 Shared<BoxFuture<'static, Result<(), Arc<WorkerError>>>>,
175 )>,
176 attempt: usize,
177 should_restart: Arc<
178 RwLock<
179 Option<
180 Box<dyn Fn(&WorkerContext, &WorkerError, usize) -> bool + 'static + Send + Sync>,
181 >,
182 >,
183 >,
184}
185
186#[derive(Debug)]
188pub struct MonitoredWorkerError {
189 ctx: WorkerContext,
190 error: WorkerError,
191}
192
193impl Future for MonitoredWorker {
194 type Output = Result<(), MonitoredWorkerError>;
195
196 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
197 use std::panic::{AssertUnwindSafe, catch_unwind};
198 let mut this = self.project();
199
200 loop {
201 use futures_util::TryFutureExt;
202 if this.current.is_none() {
203 let (ctx, worker) = (this.factory)(*this.attempt);
204 this.current.set(Some((
205 ctx,
206 worker
207 .map_err(|e: WorkerError| Arc::new(e))
208 .boxed()
209 .shared(),
210 )));
211 }
212
213 let mut current = this.current.as_mut().as_pin_mut().unwrap();
214 if current.0.is_running() && current.0.is_shutting_down() {
215 return Poll::Ready(Ok(()));
216 }
217 let poll_result =
218 catch_unwind(AssertUnwindSafe(|| current.1.poll_unpin(cx))).map_err(|err| {
219 let err = if let Some(s) = err.downcast_ref::<&str>() {
220 s.to_string()
221 } else if let Some(s) = err.downcast_ref::<String>() {
222 s.clone()
223 } else {
224 "Unknown panic".to_string()
225 };
226 Arc::new(WorkerError::PanicError(err))
227 });
228
229 match poll_result {
230 Ok(Poll::Pending) => return Poll::Pending,
231 Ok(Poll::Ready(Ok(()))) => return Poll::Ready(Ok(())),
232 Ok(Poll::Ready(Err(e))) | Err(e) => {
233 let (ctx, _) = this.current.take().unwrap();
234 ctx.stop().unwrap();
235 let should_restart = this.should_restart.read();
236 match should_restart.as_ref().map(|s| s.as_ref()) {
237 Ok(Some(cb)) => {
238 if !(cb)(&ctx, &e, *this.attempt) {
239 return Poll::Ready(Err(MonitoredWorkerError {
240 ctx,
241 error: Arc::into_inner(e).unwrap(),
242 }));
243 }
244 *this.attempt += 1;
245 }
246 _ => {
247 return Poll::Ready(Err(MonitoredWorkerError {
248 ctx,
249 error: Arc::into_inner(e).unwrap(),
250 }));
251 }
252 }
253 }
254 }
255 }
256 }
257}
258
259type ShouldRestart = Arc<
260 RwLock<
261 Option<Box<dyn Fn(&WorkerContext, &WorkerError, usize) -> bool + Send + Sync + 'static>>,
262 >,
263>;
264
265#[derive(Default)]
267pub struct Monitor {
268 workers: Vec<MonitoredWorker>,
269 terminator: Option<Shared<BoxFuture<'static, ()>>>,
270 shutdown: Shutdown,
271 event_handler: EventHandlerBuilder,
272 should_restart: ShouldRestart,
273}
274
275impl Debug for Monitor {
276 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
277 f.debug_struct("Monitor")
278 .field("shutdown", &"[Graceful shutdown listener]")
279 .field("workers", &self.workers.len())
280 .finish()
281 }
282}
283
284impl Monitor {
285 fn run_worker<Args, S, B, M>(
286 mut ctx: WorkerContext,
287 worker: Worker<Args, B::Context, B, S, M>,
288 ) -> BoxFuture<'static, Result<(), WorkerError>>
289 where
290 S: Service<Task<Args, B::Context, B::IdType>> + Send + 'static,
291 S::Future: Send,
292 S::Error: Send + Sync + 'static + Into<BoxDynError>,
293 B: Backend<Args = Args> + Send + 'static,
294 B::Error: Into<BoxDynError> + Send + 'static,
295 B::Layer: Layer<ReadinessService<TrackerService<S>>> + 'static,
296 M: Layer<<<B as Backend>::Layer as Layer<ReadinessService<TrackerService<S>>>>::Service> + 'static,
297 <M as Layer<
298 <<B as Backend>::Layer as Layer<ReadinessService<TrackerService<S>>>>::Service,
299 >>::Service: Service<Task<Args, B::Context, B::IdType>> + Send + 'static,
300 <<M as Layer<<B::Layer as Layer<ReadinessService<TrackerService<S>>>>::Service>>::Service as Service<Task<Args, B::Context, B::IdType>>>::Future: Send,
301 <<M as Layer<<B::Layer as Layer<ReadinessService<TrackerService<S>>>>::Service>>::Service as Service<Task<Args, B::Context, B::IdType>>>::Error: Into<BoxDynError> + Send + Sync + 'static,
302 B::Stream: Unpin + Send + 'static,
303 B::Beat: Unpin + Send,
304 Args: Send + 'static,
305 B::Context: Send + 'static,
306 B::IdType: Sync + Send + 'static,
307 {
308 let mut stream = worker.stream_with_ctx(&mut ctx);
309 async move {
310 loop {
311 match stream.next().await {
312 Some(Err(e)) => return Err(e),
313 None => return Ok(()),
314 _ => (),
315 }
316 }
317 }
318 .boxed()
319 }
320 pub fn register<Args, S, B, M>(
341 mut self,
342 factory: impl Fn(usize) -> Worker<Args, B::Context, B, S, M> + 'static + Send + Sync,
343 ) -> Self
344 where
345 S: Service<Task<Args, B::Context, B::IdType>> + Send + 'static,
346 S::Future: Send,
347 S::Error: Send + Sync + 'static + Into<BoxDynError>,
348 B: Backend<Args = Args> + Send + 'static,
349 B::Error: Into<BoxDynError> + Send + 'static,
350 B::Stream: Unpin + Send + 'static,
351 B::Beat: Unpin + Send,
352 Args: Send + 'static,
353 B::Context: Send + 'static,
354 B::Layer: Layer<ReadinessService<TrackerService<S>>> + 'static,
355 M: Layer<<<B as Backend>::Layer as Layer<ReadinessService<TrackerService<S>>>>::Service> + 'static,
356 <M as Layer<
357 <<B as Backend>::Layer as Layer<ReadinessService<TrackerService<S>>>>::Service,
358 >>::Service: Service<Task<Args, B::Context, B::IdType>> + Send + 'static,
359 <<M as Layer<<B::Layer as Layer<ReadinessService<TrackerService<S>>>>::Service>>::Service as Service<Task<Args, B::Context, B::IdType>>>::Future: Send,
360 <<M as Layer<<B::Layer as Layer<ReadinessService<TrackerService<S>>>>::Service>>::Service as Service<Task<Args, B::Context, B::IdType>>>::Error:
361 Into<BoxDynError> + Send + Sync + 'static,
362 B::IdType: Send + Sync + 'static,
363 {
364 let shutdown = Some(self.shutdown.clone());
365 let handler = self.event_handler.clone();
366 let should_restart = self.should_restart.clone();
367 let worker = MonitoredWorker {
368 current: None,
369 factory: Box::new(move |attempt| {
370 let new_worker = factory(attempt);
371 let id = Arc::new(new_worker.name.clone());
372 let mut ctx = WorkerContext::new::<M::Service>(&id);
373 let handler = handler.clone();
374 ctx.wrap_listener(move |ctx, ev| {
375 let handlers = handler.read();
376 if let Ok(handlers) = handlers {
377 for h in handlers.iter() {
378 h(ctx, ev);
379 }
380 }
381 });
382 ctx.shutdown = shutdown.clone();
383 (ctx.clone(), Self::run_worker(ctx.clone(), new_worker))
384 }),
385 attempt: 0,
386 should_restart,
387 };
388 self.workers.push(worker);
389 self
390 }
391
392 pub async fn run_with_signal<S>(self, signal: S) -> Result<(), MonitorError>
408 where
409 S: Send + Future<Output = std::io::Result<()>>,
410 {
411 let shutdown = self.shutdown.clone();
412 let shutdown_after = self.shutdown.shutdown_after(signal);
413 if let Some(terminator) = self.terminator {
414 let _res = futures_util::future::join(
415 Self::run_all_workers(self.workers, shutdown),
416 async {
417 let res = shutdown_after.await;
418 terminator.await;
419 res.map_err(MonitorError::ShutdownSignal)
420 }
421 .boxed(),
422 )
423 .await;
424 } else {
425 let runner = self.run();
426 let res = futures_util::join!(shutdown_after, runner); match res {
428 (Ok(_), Ok(_)) => {
429 }
431 (Err(e), Ok(_)) => return Err(e.into()),
432 (Ok(_), Err(e)) => return Err(e),
433 (Err(e), Err(_)) => return Err(e.into()),
434 }
435 }
436 Ok(())
437 }
438
439 pub async fn run(self) -> Result<(), MonitorError> {
449 let shutdown = self.shutdown.clone();
450 let shutdown_future = self.shutdown.boxed().map(|_| ());
451 let (result, _) = futures_util::join!(
452 Self::run_all_workers(self.workers, shutdown),
453 shutdown_future,
454 );
455
456 result
457 }
458 async fn run_all_workers(
459 workers: Vec<MonitoredWorker>,
460 shutdown: Shutdown,
461 ) -> Result<(), MonitorError> {
462 let results = futures_util::future::join_all(workers).await;
463
464 shutdown.start_shutdown();
465
466 let mut errors = Vec::new();
467 for r in results {
469 match r {
470 Ok(_) => {}
471 Err(MonitoredWorkerError { ctx, error }) => match error {
472 WorkerError::GracefulExit => {}
473 _ => errors.push(MonitoredWorkerError { ctx, error }),
474 },
475 }
476 }
477 if !errors.is_empty() {
478 return Err(MonitorError::ExitError(ExitError(errors)));
479 }
480 Ok(())
481 }
482
483 pub fn on_event<F: Fn(&WorkerContext, &Event) + Send + Sync + 'static>(self, f: F) -> Self {
485 let _ = self.event_handler.write().map(|mut res| {
486 let _ = res.insert(Box::new(f));
487 });
488 self
489 }
490}
491
492impl Monitor {
493 pub fn new() -> Self {
499 Self::default()
500 }
501
502 #[cfg(feature = "sleep")]
512 pub fn shutdown_timeout(self, duration: std::time::Duration) -> Self {
513 self.with_terminator(crate::timer::sleep(duration))
514 }
515
516 pub fn with_terminator(mut self, fut: impl Future<Output = ()> + Send + 'static) -> Self {
524 self.terminator = Some(fut.boxed().shared());
525 self
526 }
527
528 pub fn should_restart<F>(self, cb: F) -> Self
530 where
531 F: Fn(&WorkerContext, &WorkerError, usize) -> bool + Send + Sync + 'static,
532 {
533 let _ = self.should_restart.write().map(|mut res| {
534 let _ = res.insert(Box::new(cb));
535 });
536 self
537 }
538}
539
540#[derive(Debug, thiserror::Error)]
542pub enum MonitorError {
543 #[error("Worker errors:\n{0}")]
545 ExitError(#[from] ExitError),
546
547 #[error("Shutdown signal error: {0}")]
549 ShutdownSignal(#[from] std::io::Error),
550}
551
552impl fmt::Debug for ExitError {
553 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> {
554 std::fmt::Display::fmt(&self, f)
555 }
556}
557
558#[derive(thiserror::Error)]
560pub struct ExitError(pub Vec<MonitoredWorkerError>);
561
562impl std::fmt::Display for ExitError {
563 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
564 writeln!(f, "MonitoredErrors:")?;
565 for worker in &self.0 {
566 writeln!(f, " - Worker `{}`: {}", worker.ctx.name(), worker.error)?;
567 }
568 Ok(())
569 }
570}
571
572#[cfg(test)]
573#[cfg(feature = "json")]
574mod tests {
575 use super::*;
576 use crate::{
577 backend::{TaskSink, json::JsonStorage},
578 task::task_id::TaskId,
579 worker::context::WorkerContext,
580 };
581 use core::panic;
582 use std::{io, time::Duration};
583
584 use tokio::time::sleep;
585 use tower::limit::ConcurrencyLimitLayer;
586
587 use crate::{monitor::Monitor, worker::builder::WorkerBuilder};
588
589 #[tokio::test]
590 async fn basic_with_workers() {
591 let mut backend = JsonStorage::new_temp().unwrap();
592
593 for i in 0..10 {
594 backend.push(i).await.unwrap();
595 }
596
597 let monitor: Monitor = Monitor::new();
598 let monitor = monitor.register(move |index| {
599 WorkerBuilder::new(format!("rango-tango-{index}"))
600 .backend(backend.clone())
601 .build(move |r: u32, id: TaskId, w: WorkerContext| async move {
602 println!("{id:?}, {}", w.name());
603 tokio::time::sleep(Duration::from_secs(index as u64)).await;
604 Ok::<_, io::Error>(r)
605 })
606 });
607 let shutdown = monitor.shutdown.clone();
608 tokio::spawn(async move {
609 sleep(Duration::from_millis(1500)).await;
610 shutdown.start_shutdown();
611 });
612 monitor.run().await.unwrap();
613 }
614 #[tokio::test]
615 async fn test_monitor_run() {
616 let mut backend = JsonStorage::new(
617 "/var/folders/h_/sd1_gb5x73bbcxz38dts7pj80000gp/T/apalis-json-store-girmm9e36pz",
618 )
619 .unwrap();
620
621 for i in 0..10 {
622 backend.push(i).await.unwrap();
623 }
624
625 let monitor: Monitor = Monitor::new()
626 .register(move |index| {
627 WorkerBuilder::new(format!("rango-tango-{index}"))
628 .backend(backend.clone())
629 .build(move |r| async move {
630 if r % 2 == 0 {
631 panic!("Brrr")
632 }
633 })
634 })
635 .should_restart(|ctx, e, index| {
636 println!(
637 "Encountered error in {} with {e:?} for attempt {index}",
638 ctx.name()
639 );
640 if index > 3 {
641 return false;
642 }
643 return true;
644 })
645 .on_event(|wrk, e| {
646 println!("{}: {e:?}", wrk.name());
647 });
648 assert_eq!(monitor.workers.len(), 1);
649 let shutdown = monitor.shutdown.clone();
650
651 tokio::spawn(async move {
652 sleep(Duration::from_millis(5000)).await;
653 shutdown.start_shutdown();
654 });
655
656 let result = monitor.run().await;
657 assert!(
658 result.is_err_and(|e| matches!(e, MonitorError::ExitError(_))),
659 "Monitor did not return an error as expected"
660 );
661 }
662
663 #[tokio::test]
664 async fn test_monitor_register_multiple() {
665 let mut backend = JsonStorage::new_temp().unwrap();
666
667 for i in 0..10 {
668 backend.push(i).await.unwrap();
669 }
670
671 let monitor: Monitor = Monitor::new();
672
673 assert_send_sync::<Monitor>();
674
675 let monitor = monitor.on_event(|wrk, e| {
676 println!("{:?}, {e:?}", wrk.name());
677 });
678 let b = backend.clone();
679 let monitor = monitor
680 .register(move |index| {
681 WorkerBuilder::new(format!("worker0-{}", index))
682 .backend(backend.clone())
683 .layer(ConcurrencyLimitLayer::new(1))
684 .build(
685 move |request: i32, id: TaskId, w: WorkerContext| async move {
686 println!("{id:?}, {}", w.name());
687 tokio::time::sleep(Duration::from_secs(index as u64)).await;
688 Ok::<_, io::Error>(request)
689 },
690 )
691 })
692 .register(move |index| {
693 WorkerBuilder::new(format!("worker1-{}", index))
694 .backend(b.clone())
695 .layer(ConcurrencyLimitLayer::new(1))
696 .build(
697 move |request: i32, id: TaskId, w: WorkerContext| async move {
698 println!("{id:?}, {}", w.name());
699 tokio::time::sleep(Duration::from_secs(index as u64)).await;
700 Ok::<_, io::Error>(request)
701 },
702 )
703 });
704 assert_eq!(monitor.workers.len(), 2);
705 let shutdown = monitor.shutdown.clone();
706
707 tokio::spawn(async move {
708 sleep(Duration::from_millis(5000)).await;
709 shutdown.start_shutdown();
710 });
711
712 let result = monitor.run().await;
713 assert!(result.is_ok());
714 }
715
716 fn assert_send_sync<T: Send + Sync>() {}
717}