1use crate::restart::RestartPolicy;
4use crate::supervisor_common::run_worker;
5use crate::types::ChildId;
6use async_trait::async_trait;
7use std::fmt;
8use std::sync::Arc;
9use tokio::sync::mpsc;
10use tokio::task::JoinHandle;
11
12#[async_trait]
15pub trait Worker: Send + Sync + 'static {
16 type Error: std::error::Error + Send + Sync + 'static;
18
19 async fn run(&mut self) -> Result<(), Self::Error>;
21
22 async fn initialize(&mut self) -> Result<(), Self::Error> {
24 Ok(())
25 }
26
27 async fn shutdown(&mut self) -> Result<(), Self::Error> {
29 Ok(())
30 }
31}
32
33pub(crate) struct WorkerSpec<W: Worker> {
35 pub id: ChildId,
36 pub worker_factory: Arc<dyn Fn() -> W + Send + Sync>,
37 pub restart_policy: RestartPolicy,
38}
39
40impl<W: Worker> Clone for WorkerSpec<W> {
41 fn clone(&self) -> Self {
42 Self {
43 id: self.id.clone(),
44 worker_factory: Arc::clone(&self.worker_factory),
45 restart_policy: self.restart_policy,
46 }
47 }
48}
49
50impl<W: Worker> WorkerSpec<W> {
51 pub(crate) fn new(
52 id: impl Into<String>,
53 factory: impl Fn() -> W + Send + Sync + 'static,
54 restart_policy: RestartPolicy,
55 ) -> Self {
56 Self {
57 id: id.into(),
58 worker_factory: Arc::new(factory),
59 restart_policy,
60 }
61 }
62
63 pub(crate) fn create_worker(&self) -> W {
64 (self.worker_factory)()
65 }
66}
67
68impl<W: Worker> fmt::Debug for WorkerSpec<W> {
69 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
70 f.debug_struct("WorkerSpec")
71 .field("id", &self.id)
72 .field("restart_policy", &self.restart_policy)
73 .finish()
74 }
75}
76
77pub(crate) struct WorkerProcess<W: Worker> {
79 pub spec: WorkerSpec<W>,
80 pub handle: Option<JoinHandle<()>>,
81}
82
83impl<W: Worker> WorkerProcess<W> {
84 pub(crate) fn spawn<Cmd>(
85 spec: WorkerSpec<W>,
86 supervisor_name: String,
87 control_tx: mpsc::UnboundedSender<Cmd>,
88 ) -> Self
89 where
90 Cmd: From<WorkerTermination> + Send + 'static,
91 {
92 let worker = spec.create_worker();
93 let worker_id = spec.id.clone();
94 let handle = tokio::spawn(async move {
95 run_worker(supervisor_name, worker_id, worker, control_tx).await;
96 });
97
98 Self {
99 spec,
100 handle: Some(handle),
101 }
102 }
103
104 pub(crate) async fn stop(&mut self) {
105 if let Some(handle) = self.handle.take() {
106 handle.abort();
107 let _ = handle.await;
108 }
109 }
110}
111
112impl<W: Worker> Drop for WorkerProcess<W> {
113 fn drop(&mut self) {
114 if let Some(handle) = self.handle.take() {
115 handle.abort();
116 }
117 }
118}
119
120pub(crate) use crate::supervisor_common::WorkerTermination;
122
/// Errors produced by worker supervision.
///
/// Every variant carries one `String`. For the first two it names the
/// affected target; for `WorkerFailed` it is a free-form message.
#[derive(Debug)]
pub enum WorkerError {
    /// The command channel to the named target is closed.
    CommandChannelClosed(String),
    /// The named worker panicked.
    WorkerPanicked(String),
    /// A worker failed; the payload describes the failure.
    WorkerFailed(String),
}

impl fmt::Display for WorkerError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::CommandChannelClosed(target) => {
                write!(f, "command channel to {} is closed", target)
            }
            Self::WorkerPanicked(worker) => write!(f, "worker {} panicked", worker),
            Self::WorkerFailed(reason) => write!(f, "worker failed: {}", reason),
        }
    }
}

// No underlying cause is stored in any variant, so the default
// `source()` (which returns `None`) is the right behavior.
impl std::error::Error for WorkerError {}