1use std::{
49 any::type_name,
50 fmt,
51 future::Future,
52 pin::Pin,
53 sync::{
54 Arc, Mutex,
55 atomic::{AtomicBool, AtomicUsize, Ordering},
56 },
57 task::{Context, Poll, Waker},
58};
59
60use crate::{
61 error::{WorkerError, WorkerStateError},
62 monitor::shutdown::Shutdown,
63 task::{Task, data::MissingDataError},
64 task_fn::FromRequest,
65 worker::{
66 event::{Event, EventListener, RawEventListener},
67 state::{InnerWorkerState, WorkerState},
68 },
69};
70
71#[derive(Clone)]
78pub struct WorkerContext {
79 pub(super) name: Arc<String>,
80 task_count: Arc<AtomicUsize>,
81 waker: Arc<Mutex<Option<Waker>>>,
82 state: Arc<WorkerState>,
83 pub(crate) shutdown: Option<Shutdown>,
84 event_handler: EventListener,
85 pub(super) is_ready: Arc<AtomicBool>,
86 pub(super) service: &'static str,
87}
88
89impl fmt::Debug for WorkerContext {
90 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
91 f.debug_struct("WorkerContext")
92 .field("shutdown", &["Shutdown handle"])
93 .field("task_count", &self.task_count)
94 .field("state", &self.state.load(Ordering::SeqCst))
95 .field("service", &self.service)
96 .field("is_ready", &self.is_ready)
97 .finish()
98 }
99}
100
101#[pin_project::pin_project(PinnedDrop)]
103#[derive(Debug)]
104pub struct Tracked<F> {
105 ctx: WorkerContext,
106 #[pin]
107 task: F,
108}
109
110impl<F: Future> Future for Tracked<F> {
111 type Output = F::Output;
112
113 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<F::Output> {
114 let this = self.project();
115
116 match this.task.poll(cx) {
117 res @ Poll::Ready(_) => res,
118 Poll::Pending => Poll::Pending,
119 }
120 }
121}
122
123#[pin_project::pinned_drop]
124impl<F> PinnedDrop for Tracked<F> {
125 fn drop(self: Pin<&mut Self>) {
126 self.ctx.end_task();
127 }
128}
129
130impl WorkerContext {
131 #[must_use]
133 pub fn new<S>(name: &str) -> Self {
134 Self {
135 name: Arc::new(name.to_owned()),
136 service: type_name::<S>(),
137 task_count: Default::default(),
138 waker: Default::default(),
139 state: Default::default(),
140 shutdown: Default::default(),
141 event_handler: Arc::new(Box::new(|_, _| {
142 })),
144 is_ready: Default::default(),
145 }
146 }
147
148 #[must_use]
150 pub fn name(&self) -> &String {
151 &self.name
152 }
153
154 pub fn start(&mut self) -> Result<(), WorkerError> {
156 let current_state = self.state.load(Ordering::SeqCst);
157 if current_state != InnerWorkerState::Pending {
158 return Err(WorkerError::StateError(WorkerStateError::AlreadyStarted));
159 }
160 self.state
161 .store(InnerWorkerState::Running, Ordering::SeqCst);
162 self.is_ready.store(false, Ordering::SeqCst);
163 self.emit(&Event::Start);
164 info!("Worker {} started", self.name());
165 Ok(())
166 }
167
168 pub fn restart(&mut self) -> Result<(), WorkerError> {
170 self.state
171 .store(InnerWorkerState::Pending, Ordering::SeqCst);
172 self.is_ready.store(false, Ordering::SeqCst);
173 info!("Worker {} restarted", self.name());
174 Ok(())
175 }
176
177 pub fn track<F: Future>(&self, task: F) -> Tracked<F> {
179 self.start_task();
180 Tracked {
181 ctx: self.clone(),
182 task,
183 }
184 }
185 pub fn pause(&self) -> Result<(), WorkerError> {
187 if !self.is_running() {
188 return Err(WorkerError::StateError(WorkerStateError::NotRunning));
189 }
190 self.state.store(InnerWorkerState::Paused, Ordering::SeqCst);
191 info!("Worker {} paused", self.name());
192 Ok(())
193 }
194
195 pub fn resume(&self) -> Result<(), WorkerError> {
197 if !self.is_paused() {
198 return Err(WorkerError::StateError(WorkerStateError::NotPaused));
199 }
200 if self.is_shutting_down() {
201 return Err(WorkerError::StateError(WorkerStateError::ShuttingDown));
202 }
203 self.state
204 .store(InnerWorkerState::Running, Ordering::SeqCst);
205 self.wake();
206 info!("Worker {} resumed", self.name());
207 Ok(())
208 }
209
210 pub fn stop(&self) -> Result<(), WorkerError> {
212 let current_state = self.state.load(Ordering::SeqCst);
213 if current_state == InnerWorkerState::Pending {
214 return Err(WorkerError::StateError(WorkerStateError::NotStarted));
215 }
216 self.state
217 .store(InnerWorkerState::Stopped, Ordering::SeqCst);
218 self.wake();
219 self.emit_ref(&Event::Stop);
220 info!("Worker {} stopped", self.name());
221 Ok(())
222 }
223
224 #[must_use]
226 pub fn is_ready(&self) -> bool {
227 self.is_running() && !self.is_shutting_down() && self.is_ready.load(Ordering::SeqCst)
228 }
229
230 #[must_use]
232 pub fn get_service(&self) -> &str {
233 self.service
234 }
235
236 #[must_use]
238 pub fn is_running(&self) -> bool {
239 self.state.load(Ordering::SeqCst) == InnerWorkerState::Running
240 }
241
242 #[must_use]
244 pub fn is_pending(&self) -> bool {
245 self.state.load(Ordering::SeqCst) == InnerWorkerState::Pending
246 }
247
248 #[must_use]
250 pub fn is_paused(&self) -> bool {
251 self.state.load(Ordering::SeqCst) == InnerWorkerState::Paused
252 }
253
254 #[must_use]
256 pub fn is_stopped(&self) -> bool {
257 self.state.load(Ordering::SeqCst) == InnerWorkerState::Stopped
258 }
259
260 #[must_use]
263 pub fn task_count(&self) -> usize {
264 self.task_count.load(Ordering::Relaxed)
265 }
266
267 #[must_use]
269 pub fn has_pending_tasks(&self) -> bool {
270 self.task_count.load(Ordering::Relaxed) > 0
271 }
272
273 #[must_use]
275 pub fn is_shutting_down(&self) -> bool {
276 self.is_stopped() || self.shutdown.as_ref().is_some_and(|s| s.is_shutting_down())
277 }
278
279 pub fn emit(&mut self, event: &Event) {
281 self.emit_ref(event);
282 }
283
284 fn emit_ref(&self, event: &Event) {
285 let handler = self.event_handler.as_ref();
286 handler(self, event);
287 }
288
289 pub fn wrap_listener<F: Fn(&Self, &Event) + Send + Sync + 'static>(&mut self, f: F) {
291 let cur = self.event_handler.clone();
292 let new: RawEventListener = Box::new(move |ctx, ev| {
293 f(ctx, ev);
294 cur(ctx, ev);
295 });
296 self.event_handler = Arc::new(new);
297 }
298
299 pub(crate) fn add_waker(&self, cx: &Context<'_>) {
300 if let Ok(mut waker_guard) = self.waker.lock() {
301 if waker_guard
302 .as_ref()
303 .is_none_or(|stored_waker| !stored_waker.will_wake(cx.waker()))
304 {
305 *waker_guard = Some(cx.waker().clone());
306 }
307 }
308 }
309
310 fn has_recent_waker(&self, cx: &Context<'_>) -> bool {
312 if let Ok(waker_guard) = self.waker.lock() {
313 if let Some(stored_waker) = &*waker_guard {
314 return stored_waker.will_wake(cx.waker());
315 }
316 }
317 false
318 }
319
320 fn start_task(&self) {
321 self.task_count.fetch_add(1, Ordering::Relaxed);
322 }
323
324 fn end_task(&self) {
325 if self.task_count.fetch_sub(1, Ordering::Relaxed) == 1 {
326 self.wake();
327 }
328 }
329
330 pub(crate) fn wake(&self) {
331 if let Ok(waker) = self.waker.lock() {
332 if let Some(waker) = &*waker {
333 waker.wake_by_ref();
334 }
335 }
336 }
337}
338
339impl Future for WorkerContext {
340 type Output = Result<(), WorkerError>;
341
342 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
343 let task_count = self.task_count.load(Ordering::Relaxed);
344 let state = self.state.load(Ordering::SeqCst);
345 if state == InnerWorkerState::Pending {
346 return Poll::Ready(Err(WorkerError::StateError(WorkerStateError::NotStarted)));
347 }
348 if self.is_shutting_down() && task_count == 0 {
349 Poll::Ready(Ok(()))
350 } else {
351 if !self.has_recent_waker(cx) {
352 self.add_waker(cx);
353 }
354 Poll::Pending
355 }
356 }
357}
358
359impl<Args: Sync, Ctx: Sync, IdType: Sync + Send> FromRequest<Task<Args, Ctx, IdType>>
360 for WorkerContext
361{
362 type Error = MissingDataError;
363 async fn from_request(task: &Task<Args, Ctx, IdType>) -> Result<Self, Self::Error> {
364 task.parts.data.get_checked().cloned()
365 }
366}
367
368impl Drop for WorkerContext {
369 fn drop(&mut self) {
370 if Arc::strong_count(&self.state) > 1 {
371 return;
373 }
374 if self.is_running() && self.has_pending_tasks() {
375 error!(
376 "Worker '{}' is being dropped while running with `{}` tasks. Consider calling stop() before dropping.",
377 self.name(),
378 self.task_count()
379 );
380 }
381 }
382}
383
#[cfg(test)]
mod tests {
    use crate::{
        backend::memory::MemoryStorage, error::BoxDynError, worker::builder::WorkerBuilder,
    };
    use std::time::Duration;

    use super::*;

    /// Drives a worker through running -> paused -> running -> stopped and
    /// checks the state predicates after each transition.
    #[tokio::test]
    async fn test_worker_state_transitions() {
        let storage = MemoryStorage::<u32>::new();

        let worker = WorkerBuilder::new("test-worker")
            .backend(storage)
            .build(|_task: u32| async { Ok::<_, BoxDynError>(()) });

        let mut owned_ctx = WorkerContext::new::<()>("test-worker");
        // A clone kept on the test side to observe and steer the worker.
        let handle = owned_ctx.clone();

        let run_handle = tokio::spawn(async move { worker.run_with_ctx(&mut owned_ctx).await });
        // Give the spawned worker time to start.
        tokio::time::sleep(Duration::from_millis(50)).await;

        // Running.
        assert!(handle.is_running());
        assert!(!handle.is_shutting_down());
        assert!(!handle.is_stopped());

        // Running -> paused.
        handle.pause().unwrap();
        assert!(handle.is_paused());
        assert!(
            !handle.is_shutting_down(),
            "Paused worker should NOT be considered shutting down"
        );

        // Paused -> running.
        handle.resume().unwrap();
        assert!(handle.is_running());
        assert!(!handle.is_paused());

        // Running -> stopped.
        handle.stop().unwrap();
        assert!(handle.is_stopped());
        assert!(handle.is_shutting_down());

        // A stopped worker is not paused, so it cannot be resumed.
        let resume_after_stop = handle.resume();
        assert!(
            matches!(
                resume_after_stop,
                Err(WorkerError::StateError(WorkerStateError::NotPaused))
            ),
            "Resuming a stopped worker should fail with NotPaused error"
        );

        run_handle.await.unwrap().unwrap();
    }
}