use crate::restart::RestartPolicy;
use crate::supervisor_common::run_worker;
use crate::types::ChildId;
use async_trait::async_trait;
use std::fmt;
use std::sync::Arc;
use tokio::sync::mpsc;
use tokio::task::JoinHandle;
#[async_trait]
pub trait Worker: Send + Sync + 'static {
type Error: std::error::Error + Send + Sync + 'static;
async fn run(&mut self) -> Result<(), Self::Error>;
async fn initialize(&mut self) -> Result<(), Self::Error> {
Ok(())
}
async fn shutdown(&mut self) -> Result<(), Self::Error> {
Ok(())
}
}
pub(crate) struct WorkerSpec<W: Worker> {
pub id: ChildId,
pub worker_factory: Arc<dyn Fn() -> W + Send + Sync>,
pub restart_policy: RestartPolicy,
}
impl<W: Worker> Clone for WorkerSpec<W> {
fn clone(&self) -> Self {
Self {
id: self.id.clone(),
worker_factory: Arc::clone(&self.worker_factory),
restart_policy: self.restart_policy,
}
}
}
impl<W: Worker> WorkerSpec<W> {
pub(crate) fn new(
id: impl Into<String>,
factory: impl Fn() -> W + Send + Sync + 'static,
restart_policy: RestartPolicy,
) -> Self {
Self {
id: id.into(),
worker_factory: Arc::new(factory),
restart_policy,
}
}
pub(crate) fn create_worker(&self) -> W {
(self.worker_factory)()
}
}
impl<W: Worker> fmt::Debug for WorkerSpec<W> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("WorkerSpec")
.field("id", &self.id)
.field("restart_policy", &self.restart_policy)
.finish_non_exhaustive()
}
}
pub(crate) struct WorkerProcess<W: Worker> {
pub spec: WorkerSpec<W>,
pub handle: Option<JoinHandle<()>>,
}
impl<W: Worker> WorkerProcess<W> {
pub(crate) fn spawn<Cmd>(
spec: WorkerSpec<W>,
supervisor_name: String,
control_tx: mpsc::UnboundedSender<Cmd>,
) -> Self
where
Cmd: From<WorkerTermination> + Send + 'static,
{
let worker = spec.create_worker();
let worker_id = spec.id.clone();
let handle = tokio::spawn(async move {
run_worker(supervisor_name, worker_id, worker, control_tx, None).await;
});
Self {
spec,
handle: Some(handle),
}
}
pub(crate) fn spawn_with_link<Cmd>(
spec: WorkerSpec<W>,
supervisor_name: String,
control_tx: mpsc::UnboundedSender<Cmd>,
init_tx: tokio::sync::oneshot::Sender<Result<(), String>>,
) -> Self
where
Cmd: From<WorkerTermination> + Send + 'static,
{
let worker = spec.create_worker();
let worker_id = spec.id.clone();
let handle = tokio::spawn(async move {
run_worker(
supervisor_name,
worker_id,
worker,
control_tx,
Some(init_tx),
)
.await;
});
Self {
spec,
handle: Some(handle),
}
}
pub(crate) async fn stop(&mut self) {
if let Some(handle) = self.handle.take() {
handle.abort();
let _join_result = handle.await;
}
}
}
impl<W: Worker> Drop for WorkerProcess<W> {
fn drop(&mut self) {
if let Some(handle) = self.handle.take() {
handle.abort();
}
}
}
pub(crate) use crate::supervisor_common::WorkerTermination;
#[derive(Debug)]
pub enum WorkerError {
CommandChannelClosed(String),
WorkerPanicked(String),
WorkerFailed(String),
}
impl fmt::Display for WorkerError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
WorkerError::CommandChannelClosed(name) => {
write!(f, "command channel to {name} is closed")
}
WorkerError::WorkerPanicked(name) => {
write!(f, "worker {name} panicked")
}
WorkerError::WorkerFailed(msg) => {
write!(f, "worker failed: {msg}")
}
}
}
}
impl std::error::Error for WorkerError {}