1use crate::restart::RestartPolicy;
4use crate::supervisor_common::run_worker;
5use crate::types::ChildId;
6use async_trait::async_trait;
7use std::fmt;
8use std::sync::Arc;
9use tokio::sync::mpsc;
10use tokio::task::JoinHandle;
11
12#[async_trait]
15pub trait Worker: Send + Sync + 'static {
16 type Error: std::error::Error + Send + Sync + 'static;
18
19 async fn run(&mut self) -> Result<(), Self::Error>;
21
22 async fn initialize(&mut self) -> Result<(), Self::Error> {
24 Ok(())
25 }
26
27 async fn shutdown(&mut self) -> Result<(), Self::Error> {
29 Ok(())
30 }
31}
32
33pub(crate) struct WorkerSpec<W: Worker> {
35 pub id: ChildId,
36 pub worker_factory: Arc<dyn Fn() -> W + Send + Sync>,
37 pub restart_policy: RestartPolicy,
38}
39
40impl<W: Worker> Clone for WorkerSpec<W> {
41 fn clone(&self) -> Self {
42 Self {
43 id: self.id.clone(),
44 worker_factory: Arc::clone(&self.worker_factory),
45 restart_policy: self.restart_policy,
46 }
47 }
48}
49
50impl<W: Worker> WorkerSpec<W> {
51 pub(crate) fn new(
52 id: impl Into<String>,
53 factory: impl Fn() -> W + Send + Sync + 'static,
54 restart_policy: RestartPolicy,
55 ) -> Self {
56 Self {
57 id: id.into(),
58 worker_factory: Arc::new(factory),
59 restart_policy,
60 }
61 }
62
63 pub(crate) fn create_worker(&self) -> W {
64 (self.worker_factory)()
65 }
66}
67
68impl<W: Worker> fmt::Debug for WorkerSpec<W> {
69 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
70 f.debug_struct("WorkerSpec")
71 .field("id", &self.id)
72 .field("restart_policy", &self.restart_policy)
73 .finish()
74 }
75}
76
77pub(crate) struct WorkerProcess<W: Worker> {
79 pub spec: WorkerSpec<W>,
80 pub handle: Option<JoinHandle<()>>,
81}
82
83impl<W: Worker> WorkerProcess<W> {
84 pub(crate) fn spawn<Cmd>(
85 spec: WorkerSpec<W>,
86 supervisor_name: String,
87 control_tx: mpsc::UnboundedSender<Cmd>,
88 ) -> Self
89 where
90 Cmd: From<WorkerTermination> + Send + 'static,
91 {
92 let worker = spec.create_worker();
93 let worker_id = spec.id.clone();
94 let handle = tokio::spawn(async move {
95 run_worker(supervisor_name, worker_id, worker, control_tx, None).await;
96 });
97
98 Self {
99 spec,
100 handle: Some(handle),
101 }
102 }
103
104 pub(crate) fn spawn_with_link<Cmd>(
106 spec: WorkerSpec<W>,
107 supervisor_name: String,
108 control_tx: mpsc::UnboundedSender<Cmd>,
109 init_tx: tokio::sync::oneshot::Sender<Result<(), String>>,
110 ) -> Self
111 where
112 Cmd: From<WorkerTermination> + Send + 'static,
113 {
114 let worker = spec.create_worker();
115 let worker_id = spec.id.clone();
116 let handle = tokio::spawn(async move {
117 run_worker(
118 supervisor_name,
119 worker_id,
120 worker,
121 control_tx,
122 Some(init_tx),
123 )
124 .await;
125 });
126
127 Self {
128 spec,
129 handle: Some(handle),
130 }
131 }
132
133 pub(crate) async fn stop(&mut self) {
134 if let Some(handle) = self.handle.take() {
135 handle.abort();
136 let _ = handle.await;
137 }
138 }
139}
140
141impl<W: Worker> Drop for WorkerProcess<W> {
142 fn drop(&mut self) {
143 if let Some(handle) = self.handle.take() {
144 handle.abort();
145 }
146 }
147}
148
149pub(crate) use crate::supervisor_common::WorkerTermination;
151
/// Errors reported by the worker layer.
#[derive(Debug)]
pub enum WorkerError {
    /// The command channel to the named target is closed.
    CommandChannelClosed(String),
    /// The named worker panicked.
    WorkerPanicked(String),
    /// A worker failed; carries a human-readable description of the failure.
    WorkerFailed(String),
}

impl fmt::Display for WorkerError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::CommandChannelClosed(target) => {
                write!(f, "command channel to {} is closed", target)
            }
            Self::WorkerPanicked(worker) => write!(f, "worker {} panicked", worker),
            Self::WorkerFailed(reason) => write!(f, "worker failed: {}", reason),
        }
    }
}

// No underlying cause is attached, so the default `source()` (returning
// `None`) is kept.
impl std::error::Error for WorkerError {}