ash_flare/
worker.rs

1//! Worker trait and related types
2
3use crate::restart::RestartPolicy;
4use crate::supervisor_common::run_worker;
5use crate::types::ChildId;
6use async_trait::async_trait;
7use std::fmt;
8use std::sync::Arc;
9use tokio::sync::mpsc;
10use tokio::task::JoinHandle;
11
12/// A trait that all workers must implement to work with the supervisor tree.
13/// This allows for generic workers that can handle any type of work.
14#[async_trait]
15pub trait Worker: Send + Sync + 'static {
16    /// The type of error this worker can return
17    type Error: std::error::Error + Send + Sync + 'static;
18
19    /// Run the worker's main loop - this should run until completion or error
20    async fn run(&mut self) -> Result<(), Self::Error>;
21
22    /// Called when the worker is initialized
23    async fn initialize(&mut self) -> Result<(), Self::Error> {
24        Ok(())
25    }
26
27    /// Called when the worker is being shut down
28    async fn shutdown(&mut self) -> Result<(), Self::Error> {
29        Ok(())
30    }
31}
32
33/// Specification for creating and restarting a worker
34pub(crate) struct WorkerSpec<W: Worker> {
35    pub id: ChildId,
36    pub worker_factory: Arc<dyn Fn() -> W + Send + Sync>,
37    pub restart_policy: RestartPolicy,
38}
39
40impl<W: Worker> Clone for WorkerSpec<W> {
41    fn clone(&self) -> Self {
42        Self {
43            id: self.id.clone(),
44            worker_factory: Arc::clone(&self.worker_factory),
45            restart_policy: self.restart_policy,
46        }
47    }
48}
49
50impl<W: Worker> WorkerSpec<W> {
51    pub(crate) fn new(
52        id: impl Into<String>,
53        factory: impl Fn() -> W + Send + Sync + 'static,
54        restart_policy: RestartPolicy,
55    ) -> Self {
56        Self {
57            id: id.into(),
58            worker_factory: Arc::new(factory),
59            restart_policy,
60        }
61    }
62
63    pub(crate) fn create_worker(&self) -> W {
64        (self.worker_factory)()
65    }
66}
67
68impl<W: Worker> fmt::Debug for WorkerSpec<W> {
69    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
70        f.debug_struct("WorkerSpec")
71            .field("id", &self.id)
72            .field("restart_policy", &self.restart_policy)
73            .finish()
74    }
75}
76
77/// Running worker process with its specification and task handle
78pub(crate) struct WorkerProcess<W: Worker> {
79    pub spec: WorkerSpec<W>,
80    pub handle: Option<JoinHandle<()>>,
81}
82
83impl<W: Worker> WorkerProcess<W> {
84    pub(crate) fn spawn<Cmd>(
85        spec: WorkerSpec<W>,
86        supervisor_name: String,
87        control_tx: mpsc::UnboundedSender<Cmd>,
88    ) -> Self
89    where
90        Cmd: From<WorkerTermination> + Send + 'static,
91    {
92        let worker = spec.create_worker();
93        let worker_id = spec.id.clone();
94        let handle = tokio::spawn(async move {
95            run_worker(supervisor_name, worker_id, worker, control_tx).await;
96        });
97
98        Self {
99            spec,
100            handle: Some(handle),
101        }
102    }
103
104    pub(crate) async fn stop(&mut self) {
105        if let Some(handle) = self.handle.take() {
106            handle.abort();
107            let _ = handle.await;
108        }
109    }
110}
111
112impl<W: Worker> Drop for WorkerProcess<W> {
113    fn drop(&mut self) {
114        if let Some(handle) = self.handle.take() {
115            handle.abort();
116        }
117    }
118}
119
120// Re-export WorkerTermination from supervisor_common
121pub(crate) use crate::supervisor_common::WorkerTermination;
122
/// Errors returned by worker operations.
///
/// Every variant carries a `String`. For `CommandChannelClosed` and
/// `WorkerPanicked` it is a name interpolated into the message. For
/// `WorkerFailed` it is a free-form description (see the `Display` impl below).
///
/// The type derives `Clone`, `PartialEq` and `Eq` so errors can be copied into
/// reports and compared directly, e.g. in tests.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum WorkerError {
    /// Command channel was closed unexpectedly; holds the name of its target.
    CommandChannelClosed(String),
    /// Worker panicked during execution; holds the worker's name.
    WorkerPanicked(String),
    /// Worker failed with an error; holds a description of the failure.
    WorkerFailed(String),
}

impl fmt::Display for WorkerError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            WorkerError::CommandChannelClosed(name) => {
                write!(f, "command channel to {} is closed", name)
            }
            WorkerError::WorkerPanicked(name) => {
                write!(f, "worker {} panicked", name)
            }
            WorkerError::WorkerFailed(msg) => {
                write!(f, "worker failed: {}", msg)
            }
        }
    }
}

// Uses the default `source()` (always `None`): the variants store plain
// strings, not an underlying error value.
impl std::error::Error for WorkerError {}