use std::{
fmt::{self, Debug, Formatter},
time::Duration,
};
use futures::{Future, FutureExt};
use graceful_shutdown::Shutdown;
use log::info;
use tower::Service;
use tracing::warn;
use crate::{
executor::{Executor, TokioExecutor},
job::Job,
request::JobRequest,
worker::{Worker, WorkerContext, WorkerId},
};
/// Holds the registered workers' ids, the executor they are spawned on, and
/// the shutdown handle that `run` / `run_with_signal` wait on.
pub struct Monitor<E> {
/// Shutdown handle; `register` passes a clone of it into every worker's context
/// and wraps each worker's future with it.
shutdown: Shutdown,
/// Ids of the workers registered so far, in registration order.
worker_handles: Vec<WorkerId>,
/// Upper bound on how long `run` waits for shutdown; `None` applies no bound.
timeout: Option<Duration>,
/// Executor that registered workers are spawned on.
executor: E,
}
/// Hand-written `Debug` for [`Monitor`].
///
/// No `Debug` bound is placed on `E`. The executor is therefore shown by its
/// type name, and the shutdown listener by a fixed placeholder string.
impl<E: Executor> Debug for Monitor<E> {
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        f.debug_struct("Monitor")
            .field("shutdown", &"[Graceful shutdown listener]")
            // Format the ids themselves. Passing `iter().cloned()` here would print
            // the iterator adapter (`Cloned { it: Iter([..]) }`) rather than the list.
            .field("worker_handles", &self.worker_handles)
            .field("timeout", &self.timeout)
            .field("executor", &std::any::type_name::<E>())
            .finish()
    }
}
impl<E: Executor + Send + Sync + 'static> Monitor<E> {
pub fn register<
Strm,
Serv: Service<JobRequest<J>>,
J: Job + 'static,
W: Worker<J, Service = Serv, Source = Strm> + 'static,
>(
mut self,
worker: W,
) -> Self
where
<Serv as Service<JobRequest<J>>>::Future: std::marker::Send,
{
let shutdown = self.shutdown.clone();
let worker_id = worker.id();
self.executor.spawn(
self.shutdown.graceful(
worker
.start(WorkerContext {
shutdown,
executor: self.executor.clone(),
worker_id: worker_id.clone(),
})
.map(|_| ()),
),
);
self.worker_handles.push(worker_id);
self
}
pub fn register_with_count<
Strm,
Serv: Service<JobRequest<J>>,
J: Job + 'static,
W: Worker<J, Service = Serv, Source = Strm> + 'static,
Call: Fn(u16) -> W,
>(
mut self,
count: u16,
caller: Call,
) -> Self
where
<Serv as Service<JobRequest<J>>>::Future: std::marker::Send,
{
for index in 0..count {
let worker = caller(index);
self = self.register(worker);
}
self
}
pub fn shutdown_timeout(mut self, duration: Duration) -> Self {
self.timeout = Some(duration);
self
}
pub async fn run_with_signal<S: Future<Output = std::io::Result<()>>>(
self,
signal: S,
) -> std::io::Result<()> {
self.shutdown.shutdown_after(signal).await?;
info!("Shutting down the system");
self.run().await?;
Ok(())
}
pub async fn run(self) -> std::io::Result<()> {
if let Some(timeout) = self.timeout {
if self.shutdown.with_timeout(timeout).await {
warn!("Shutdown timeout reached. Exiting forcefully");
return Err(std::io::Error::new(
std::io::ErrorKind::TimedOut,
"Shutdown timeout reached. Exiting forcefully",
));
}
} else {
self.shutdown.await;
}
info!("Successfully shutdown monitor and all workers");
Ok(())
}
}
impl<E: Executor + Default> Default for Monitor<E> {
fn default() -> Self {
Self {
executor: E::default(),
shutdown: Shutdown::new(),
timeout: None,
worker_handles: Vec::new(),
}
}
}
impl Monitor<TokioExecutor> {
    /// Creates a monitor backed by a fresh [`TokioExecutor`].
    ///
    /// It starts with a new shutdown handle, no registered workers and no
    /// shutdown timeout.
    pub fn new() -> Self {
        Self {
            shutdown: Shutdown::new(),
            worker_handles: Vec::new(),
            timeout: None,
            executor: TokioExecutor::new(),
        }
    }

    /// Switches the executor used for workers registered after this call.
    ///
    /// The shutdown handle, the configured timeout and the ids of workers
    /// already registered are all carried over. Earlier workers were already
    /// spawned and wrapped with this same shutdown handle, so they remain part
    /// of the monitor. Keeping their ids keeps `worker_handles` accurate.
    pub fn executor<E: Executor>(self, executor: E) -> Monitor<E> {
        Monitor {
            shutdown: self.shutdown,
            worker_handles: self.worker_handles,
            timeout: self.timeout,
            executor,
        }
    }
}
// Unit tests for `Monitor`: registration bookkeeping and the `run` /
// `run_with_signal` shutdown paths, driven by a stub worker.
#[cfg(test)]
mod tests {
use std::{
pin::Pin,
task::{Context, Poll},
};
use crate::{context::JobContext, job_fn::job_fn, worker::WorkerError};
use super::*;
use futures::Stream;
use tokio::time::sleep;
use tower::ServiceBuilder;
// Minimal job type, used only to instantiate `Worker<TestJob>`.
struct TestJob {}
impl Job for TestJob {
const NAME: &'static str = "TestJob";
}
// Stub worker: the wrapped service is stored but never called.
struct TestWorker<S> {
_service: S,
}
// No-op job handler, wrapped with `job_fn` to build the stub's service.
async fn test_service(_req: TestJob, _ctx: JobContext) {}
// Every instance reports the same id ("test-worker"). `start` ignores its
// context, sleeps 100 ms and returns `Ok(())`, so the spawned task ends well
// before the 500 ms shutdown trigger used by the tests below.
#[async_trait::async_trait]
impl<S: Send> Worker<TestJob> for TestWorker<S> {
type Service = S;
type Source = TestSource;
fn id(&self) -> WorkerId {
WorkerId::new("test-worker")
}
async fn start<E: Executor + Send>(
self,
_ctx: WorkerContext<E>,
) -> Result<(), WorkerError> {
sleep(Duration::from_millis(100)).await;
Ok(())
}
}
// Stream that never yields an item (always `Pending`). It only fills the
// `Source` associated type.
struct TestSource {}
impl Stream for TestSource {
type Item = Result<TestJob, ()>;
fn poll_next(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
Poll::Pending
}
}
// Shutdown is triggered from a background task after 500 ms. `run` must
// return `Ok` (i.e. not report the 1 s timeout).
#[tokio::test]
async fn test_monitor_run() {
let monitor = Monitor::new()
.register(TestWorker {
_service: ServiceBuilder::new().service(job_fn(test_service)),
})
.shutdown_timeout(Duration::from_secs(1));
let shutdown = monitor.shutdown.clone();
tokio::spawn(async move {
sleep(Duration::from_millis(500)).await;
shutdown.shutdown();
});
let result = monitor.run().await;
assert!(result.is_ok());
}
// Same setup, but goes through `run_with_signal` with a signal future that
// resolves immediately to `Ok(())`. Shutdown is still also triggered from a
// background task after 500 ms.
#[tokio::test]
async fn test_monitor_run_with_signal() {
let monitor = Monitor::new()
.register(TestWorker {
_service: ServiceBuilder::new().service(job_fn(test_service)),
})
.shutdown_timeout(Duration::from_secs(1));
let shutdown = monitor.shutdown.clone();
tokio::spawn(async move {
sleep(Duration::from_millis(500)).await;
shutdown.shutdown();
});
let result = monitor.run_with_signal(async { Ok(()) }).await;
assert!(result.is_ok());
}
// `register` records exactly one handle per registered worker.
#[tokio::test]
async fn test_monitor_register() {
let monitor = Monitor::new();
assert_eq!(monitor.worker_handles.len(), 0);
let monitor = monitor.register(TestWorker {
_service: ServiceBuilder::new().service(job_fn(test_service)),
});
assert_eq!(monitor.worker_handles.len(), 1);
}
// `register_with_count(5, ..)` records five handles. All five stub workers
// share the id "test-worker", so this checks the count only, not uniqueness.
#[tokio::test]
async fn test_monitor_register_with_count() {
let monitor = Monitor::new();
assert_eq!(monitor.worker_handles.len(), 0);
let monitor = monitor.register_with_count(5, |_| TestWorker {
_service: ServiceBuilder::new().service(job_fn(test_service)),
});
assert_eq!(monitor.worker_handles.len(), 5);
}
}