Skip to main content

ash_flare/
worker.rs

1//! Worker trait and related types
2
3use crate::restart::RestartPolicy;
4use crate::supervisor_common::run_worker;
5use crate::types::ChildId;
6use async_trait::async_trait;
7use std::fmt;
8use std::sync::Arc;
9use tokio::sync::mpsc;
10use tokio::task::JoinHandle;
11
12/// A trait that all workers must implement to work with the supervisor tree.
13/// This allows for generic workers that can handle any type of work.
14#[async_trait]
15pub trait Worker: Send + Sync + 'static {
16    /// The type of error this worker can return
17    type Error: std::error::Error + Send + Sync + 'static;
18
19    /// Run the worker's main loop - this should run until completion or error
20    async fn run(&mut self) -> Result<(), Self::Error>;
21
22    /// Called when the worker is initialized
23    async fn initialize(&mut self) -> Result<(), Self::Error> {
24        Ok(())
25    }
26
27    /// Called when the worker is being shut down
28    async fn shutdown(&mut self) -> Result<(), Self::Error> {
29        Ok(())
30    }
31}
32
33/// Specification for creating and restarting a worker
34pub(crate) struct WorkerSpec<W: Worker> {
35    pub id: ChildId,
36    pub worker_factory: Arc<dyn Fn() -> W + Send + Sync>,
37    pub restart_policy: RestartPolicy,
38}
39
40impl<W: Worker> Clone for WorkerSpec<W> {
41    fn clone(&self) -> Self {
42        Self {
43            id: self.id.clone(),
44            worker_factory: Arc::clone(&self.worker_factory),
45            restart_policy: self.restart_policy,
46        }
47    }
48}
49
50impl<W: Worker> WorkerSpec<W> {
51    pub(crate) fn new(
52        id: impl Into<String>,
53        factory: impl Fn() -> W + Send + Sync + 'static,
54        restart_policy: RestartPolicy,
55    ) -> Self {
56        Self {
57            id: id.into(),
58            worker_factory: Arc::new(factory),
59            restart_policy,
60        }
61    }
62
63    pub(crate) fn create_worker(&self) -> W {
64        (self.worker_factory)()
65    }
66}
67
68impl<W: Worker> fmt::Debug for WorkerSpec<W> {
69    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
70        f.debug_struct("WorkerSpec")
71            .field("id", &self.id)
72            .field("restart_policy", &self.restart_policy)
73            .finish()
74    }
75}
76
77/// Running worker process with its specification and task handle
78pub(crate) struct WorkerProcess<W: Worker> {
79    pub spec: WorkerSpec<W>,
80    pub handle: Option<JoinHandle<()>>,
81}
82
83impl<W: Worker> WorkerProcess<W> {
84    pub(crate) fn spawn<Cmd>(
85        spec: WorkerSpec<W>,
86        supervisor_name: String,
87        control_tx: mpsc::UnboundedSender<Cmd>,
88    ) -> Self
89    where
90        Cmd: From<WorkerTermination> + Send + 'static,
91    {
92        let worker = spec.create_worker();
93        let worker_id = spec.id.clone();
94        let handle = tokio::spawn(async move {
95            run_worker(supervisor_name, worker_id, worker, control_tx, None).await;
96        });
97
98        Self {
99            spec,
100            handle: Some(handle),
101        }
102    }
103
104    /// Spawns a worker with linked initialization handshake
105    pub(crate) fn spawn_with_link<Cmd>(
106        spec: WorkerSpec<W>,
107        supervisor_name: String,
108        control_tx: mpsc::UnboundedSender<Cmd>,
109        init_tx: tokio::sync::oneshot::Sender<Result<(), String>>,
110    ) -> Self
111    where
112        Cmd: From<WorkerTermination> + Send + 'static,
113    {
114        let worker = spec.create_worker();
115        let worker_id = spec.id.clone();
116        let handle = tokio::spawn(async move {
117            run_worker(
118                supervisor_name,
119                worker_id,
120                worker,
121                control_tx,
122                Some(init_tx),
123            )
124            .await;
125        });
126
127        Self {
128            spec,
129            handle: Some(handle),
130        }
131    }
132
133    pub(crate) async fn stop(&mut self) {
134        if let Some(handle) = self.handle.take() {
135            handle.abort();
136            let _ = handle.await;
137        }
138    }
139}
140
141impl<W: Worker> Drop for WorkerProcess<W> {
142    fn drop(&mut self) {
143        if let Some(handle) = self.handle.take() {
144            handle.abort();
145        }
146    }
147}
148
149// Re-export WorkerTermination from supervisor_common
150pub(crate) use crate::supervisor_common::WorkerTermination;
151
/// Errors returned by worker operations.
///
/// Derives `Clone`, `PartialEq` and `Eq` in addition to `Debug`, so callers
/// can duplicate errors and compare them (e.g. in tests) without matching on
/// each variant by hand.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum WorkerError {
    /// Command channel was closed unexpectedly; carries the name of the
    /// channel's target (displayed as "command channel to {name} is closed").
    CommandChannelClosed(String),
    /// Worker panicked during execution; carries the worker's name.
    WorkerPanicked(String),
    /// Worker failed with an error; carries the failure message.
    WorkerFailed(String),
}

impl fmt::Display for WorkerError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::CommandChannelClosed(name) => {
                write!(f, "command channel to {} is closed", name)
            }
            Self::WorkerPanicked(name) => write!(f, "worker {} panicked", name),
            Self::WorkerFailed(msg) => write!(f, "worker failed: {}", msg),
        }
    }
}

// No variant wraps an underlying error, so the default `source()` (`None`)
// is the correct behavior.
impl std::error::Error for WorkerError {}