1use std::{
49 any::type_name,
50 fmt,
51 future::Future,
52 pin::Pin,
53 sync::{
54 atomic::{AtomicBool, AtomicUsize, Ordering},
55 Arc, Mutex,
56 },
57 task::{Context, Poll, Waker},
58};
59
60use crate::{
61 error::{WorkerError, WorkerStateError},
62 monitor::shutdown::Shutdown,
63 task::{data::MissingDataError, Task},
64 task_fn::FromRequest,
65 worker::{
66 event::{Event, EventHandler},
67 state::{InnerWorkerState, WorkerState},
68 },
69};
70
71#[derive(Clone)]
78pub struct WorkerContext {
79 pub(super) name: Arc<String>,
80 task_count: Arc<AtomicUsize>,
81 waker: Arc<Mutex<Option<Waker>>>,
82 state: Arc<WorkerState>,
83 pub(crate) shutdown: Option<Shutdown>,
84 event_handler: EventHandler,
85 pub(super) is_ready: Arc<AtomicBool>,
86 pub(super) service: &'static str,
87}
88
89impl fmt::Debug for WorkerContext {
90 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
91 f.debug_struct("WorkerContext")
92 .field("shutdown", &["Shutdown handle"])
93 .field("task_count", &self.task_count)
94 .field("state", &self.state.load(Ordering::SeqCst))
95 .field("service", &self.service)
96 .field("is_ready", &self.is_ready)
97 .finish()
98 }
99}
100
101#[pin_project::pin_project(PinnedDrop)]
103#[derive(Debug)]
104pub struct Tracked<F> {
105 ctx: WorkerContext,
106 #[pin]
107 task: F,
108}
109
110impl<F: Future> Future for Tracked<F> {
111 type Output = F::Output;
112
113 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<F::Output> {
114 let this = self.project();
115
116 match this.task.poll(cx) {
117 res @ Poll::Ready(_) => res,
118 Poll::Pending => Poll::Pending,
119 }
120 }
121}
122
123#[pin_project::pinned_drop]
124impl<F> PinnedDrop for Tracked<F> {
125 fn drop(self: Pin<&mut Self>) {
126 self.ctx.end_task();
127 }
128}
129
130impl WorkerContext {
131 pub fn new<S>(name: &str) -> Self {
133 Self {
134 name: Arc::new(name.to_owned()),
135 service: type_name::<S>(),
136 task_count: Default::default(),
137 waker: Default::default(),
138 state: Default::default(),
139 shutdown: Default::default(),
140 event_handler: Arc::new(Box::new(|_, _| {
141 })),
143 is_ready: Default::default(),
144 }
145 }
146
147 pub fn name(&self) -> &String {
149 &self.name
150 }
151
152 pub fn start(&mut self) -> Result<(), WorkerError> {
154 let current_state = self.state.load(Ordering::SeqCst);
155 if current_state != InnerWorkerState::Pending {
156 return Err(WorkerError::StateError(WorkerStateError::AlreadyStarted));
157 }
158 self.state
159 .store(InnerWorkerState::Running, Ordering::SeqCst);
160 self.is_ready.store(false, Ordering::SeqCst);
161 self.emit(&Event::Start);
162 info!("Worker {} started", self.name());
163 Ok(())
164 }
165
166 pub fn restart(&mut self) -> Result<(), WorkerError> {
168 self.state
169 .store(InnerWorkerState::Pending, Ordering::SeqCst);
170 self.is_ready.store(false, Ordering::SeqCst);
171 info!("Worker {} restarted", self.name());
172 Ok(())
173 }
174
175 pub fn track<F: Future>(&self, task: F) -> Tracked<F> {
177 self.start_task();
178 Tracked {
179 ctx: self.clone(),
180 task,
181 }
182 }
183 pub fn pause(&self) -> Result<(), WorkerError> {
185 if !self.is_running() {
186 return Err(WorkerError::StateError(WorkerStateError::NotRunning));
187 }
188 self.state.store(InnerWorkerState::Paused, Ordering::SeqCst);
189 info!("Worker {} paused", self.name());
190 Ok(())
191 }
192
193 pub fn resume(&self) -> Result<(), WorkerError> {
195 if !self.is_paused() {
196 return Err(WorkerError::StateError(WorkerStateError::NotPaused));
197 }
198 if self.is_shutting_down() {
199 return Err(WorkerError::StateError(WorkerStateError::ShuttingDown));
200 }
201 self.state
202 .store(InnerWorkerState::Running, Ordering::SeqCst);
203 info!("Worker {} resumed", self.name());
204 Ok(())
205 }
206
207 pub fn stop(&self) -> Result<(), WorkerError> {
209 let current_state = self.state.load(Ordering::SeqCst);
210 if current_state == InnerWorkerState::Pending {
211 return Err(WorkerError::StateError(WorkerStateError::NotStarted));
212 }
213 self.state
214 .store(InnerWorkerState::Stopped, Ordering::SeqCst);
215 self.wake();
216 self.emit_ref(&Event::Stop);
217 info!("Worker {} stopped", self.name());
218 Ok(())
219 }
220
221 pub fn is_ready(&self) -> bool {
223 self.is_running() && !self.is_shutting_down() && self.is_ready.load(Ordering::SeqCst)
224 }
225
226 pub fn get_service(&self) -> &str {
228 &self.service
229 }
230
231 pub fn is_running(&self) -> bool {
233 self.state.load(Ordering::SeqCst) == InnerWorkerState::Running
234 }
235
236 pub fn is_pending(&self) -> bool {
238 self.state.load(Ordering::SeqCst) == InnerWorkerState::Pending
239 }
240
241 pub fn is_paused(&self) -> bool {
243 self.state.load(Ordering::SeqCst) == InnerWorkerState::Paused
244 }
245
246 pub fn task_count(&self) -> usize {
249 self.task_count.load(Ordering::Relaxed)
250 }
251
252 pub fn has_pending_tasks(&self) -> bool {
254 self.task_count.load(Ordering::Relaxed) > 0
255 }
256
257 pub fn is_shutting_down(&self) -> bool {
259 self.shutdown
260 .as_ref()
261 .map(|s| !self.is_running() || s.is_shutting_down())
262 .unwrap_or(!self.is_running())
263 }
264
265 pub fn emit(&mut self, event: &Event) {
267 self.emit_ref(event);
268 }
269
270 fn emit_ref(&self, event: &Event) {
271 let handler = self.event_handler.as_ref();
272 handler(self, event);
273 }
274
275 pub fn wrap_listener<F: Fn(&WorkerContext, &Event) + Send + Sync + 'static>(&mut self, f: F) {
277 let cur = self.event_handler.clone();
278 let new: Box<dyn Fn(&WorkerContext, &Event) + Send + Sync + 'static> =
279 Box::new(move |ctx, ev| {
280 f(&ctx, &ev);
281 cur(&ctx, &ev);
282 });
283 self.event_handler = Arc::new(new);
284 }
285
286 pub(crate) fn add_waker(&self, cx: &mut Context<'_>) {
287 if let Ok(mut waker_guard) = self.waker.lock() {
288 if waker_guard
289 .as_ref()
290 .map_or(true, |stored_waker| !stored_waker.will_wake(cx.waker()))
291 {
292 *waker_guard = Some(cx.waker().clone());
293 }
294 }
295 }
296
297 fn has_recent_waker(&self, cx: &Context<'_>) -> bool {
299 if let Ok(waker_guard) = self.waker.lock() {
300 if let Some(stored_waker) = &*waker_guard {
301 return stored_waker.will_wake(cx.waker());
302 }
303 }
304 false
305 }
306
307 fn start_task(&self) {
308 self.task_count.fetch_add(1, Ordering::Relaxed);
309 }
310
311 fn end_task(&self) {
312 if self.task_count.fetch_sub(1, Ordering::Relaxed) == 1 {
313 self.wake();
314 }
315 }
316
317 pub(crate) fn wake(&self) {
318 if let Ok(waker) = self.waker.lock() {
319 if let Some(waker) = &*waker {
320 waker.wake_by_ref();
321 }
322 }
323 }
324}
325
326impl Future for WorkerContext {
327 type Output = Result<(), WorkerError>;
328
329 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
330 let task_count = self.task_count.load(Ordering::Relaxed);
331 let state = self.state.load(Ordering::SeqCst);
332 if state == InnerWorkerState::Pending {
333 return Poll::Ready(Err(WorkerError::StateError(WorkerStateError::NotStarted)));
334 }
335 if self.is_shutting_down() && task_count == 0 {
336 Poll::Ready(Ok(()))
337 } else {
338 if !self.has_recent_waker(cx) {
339 self.add_waker(cx);
340 }
341 Poll::Pending
342 }
343 }
344}
345
346impl<Args: Sync, Ctx: Sync, IdType: Sync + Send> FromRequest<Task<Args, Ctx, IdType>>
347 for WorkerContext
348{
349 type Error = MissingDataError;
350 async fn from_request(task: &Task<Args, Ctx, IdType>) -> Result<Self, Self::Error> {
351 task.parts.data.get_checked().cloned()
352 }
353}
354
355impl Drop for WorkerContext {
356 fn drop(&mut self) {
357 if Arc::strong_count(&self.state) > 1 {
358 return;
360 }
361 if self.is_running() {
362 eprintln!(
363 "Worker '{}' is being dropped while running with `{}` tasks. Consider calling stop() before dropping.",
364 self.name(),
365 self.task_count()
366 );
367 }
368 }
369}