1use std::{
49 any::type_name,
50 fmt,
51 future::Future,
52 pin::Pin,
53 sync::{
54 Arc, Mutex,
55 atomic::{AtomicBool, AtomicUsize, Ordering},
56 },
57 task::{Context, Poll, Waker},
58};
59
60use crate::{
61 error::{WorkerError, WorkerStateError},
62 monitor::shutdown::Shutdown,
63 task::{Task, data::MissingDataError},
64 task_fn::FromRequest,
65 worker::{
66 event::{Event, EventListener, RawEventListener},
67 state::{InnerWorkerState, WorkerState},
68 },
69};
70
/// Clonable handle describing a worker and its lifecycle.
///
/// The name, task counter, waker slot, lifecycle state and readiness flag are
/// each held behind an `Arc`, so every clone observes changes made through any
/// other clone. The event listener is replaced per handle by `wrap_listener`,
/// so wrapping one clone does not affect clones made earlier.
#[derive(Clone)]
pub struct WorkerContext {
    /// Worker name, returned by `name()` and used in log messages.
    pub(super) name: Arc<String>,
    /// Number of in-flight futures wrapped with `track` and not yet dropped.
    task_count: Arc<AtomicUsize>,
    /// Waker registered through `add_waker` (e.g. when this context is polled
    /// as a future); invoked by `wake`.
    waker: Arc<Mutex<Option<Waker>>>,
    /// Lifecycle state (`Pending`, `Running`, `Paused`, `Stopped`).
    state: Arc<WorkerState>,
    /// Optional external shutdown signal consulted by `is_shutting_down`.
    pub(crate) shutdown: Option<Shutdown>,
    /// Listener invoked with every emitted [`Event`].
    event_handler: EventListener,
    /// Readiness flag; cleared by `start`/`restart` and presumably set elsewhere
    /// in the worker module — TODO confirm.
    pub(super) is_ready: Arc<AtomicBool>,
    /// Type name of the service this worker runs (see `new`).
    pub(super) service: &'static str,
}
88
89impl fmt::Debug for WorkerContext {
90 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
91 f.debug_struct("WorkerContext")
92 .field("shutdown", &["Shutdown handle"])
93 .field("task_count", &self.task_count)
94 .field("state", &self.state.load(Ordering::SeqCst))
95 .field("service", &self.service)
96 .field("is_ready", &self.is_ready)
97 .finish()
98 }
99}
100
/// Future wrapper that counts as one in-flight task of a [`WorkerContext`].
///
/// Created by `WorkerContext::track`, which increments the task count; the
/// count is decremented when this wrapper is dropped (see its `PinnedDrop`).
#[pin_project::pin_project(PinnedDrop)]
#[derive(Debug)]
pub struct Tracked<F> {
    /// Clone of the context whose task count this wrapper holds a slot in.
    ctx: WorkerContext,
    /// The wrapped future; structurally pinned.
    #[pin]
    task: F,
}
109
110impl<F: Future> Future for Tracked<F> {
111 type Output = F::Output;
112
113 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<F::Output> {
114 let this = self.project();
115
116 match this.task.poll(cx) {
117 res @ Poll::Ready(_) => res,
118 Poll::Pending => Poll::Pending,
119 }
120 }
121}
122
123#[pin_project::pinned_drop]
124impl<F> PinnedDrop for Tracked<F> {
125 fn drop(self: Pin<&mut Self>) {
126 self.ctx.end_task();
127 }
128}
129
130impl WorkerContext {
131 #[must_use]
133 pub fn new<S>(name: &str) -> Self {
134 Self {
135 name: Arc::new(name.to_owned()),
136 service: type_name::<S>(),
137 task_count: Default::default(),
138 waker: Default::default(),
139 state: Default::default(),
140 shutdown: Default::default(),
141 event_handler: Arc::new(Box::new(|_, _| {
142 })),
144 is_ready: Default::default(),
145 }
146 }
147
148 #[must_use]
150 pub fn name(&self) -> &String {
151 &self.name
152 }
153
154 pub fn start(&mut self) -> Result<(), WorkerError> {
156 let current_state = self.state.load(Ordering::SeqCst);
157 if current_state != InnerWorkerState::Pending {
158 return Err(WorkerError::StateError(WorkerStateError::AlreadyStarted));
159 }
160 self.state
161 .store(InnerWorkerState::Running, Ordering::SeqCst);
162 self.is_ready.store(false, Ordering::SeqCst);
163 self.emit(&Event::Start);
164 info!("Worker {} started", self.name());
165 Ok(())
166 }
167
168 pub fn restart(&mut self) -> Result<(), WorkerError> {
170 self.state
171 .store(InnerWorkerState::Pending, Ordering::SeqCst);
172 self.is_ready.store(false, Ordering::SeqCst);
173 info!("Worker {} restarted", self.name());
174 Ok(())
175 }
176
177 pub fn track<F: Future>(&self, task: F) -> Tracked<F> {
179 self.start_task();
180 Tracked {
181 ctx: self.clone(),
182 task,
183 }
184 }
185 pub fn pause(&self) -> Result<(), WorkerError> {
187 if !self.is_running() {
188 return Err(WorkerError::StateError(WorkerStateError::NotRunning));
189 }
190 self.state.store(InnerWorkerState::Paused, Ordering::SeqCst);
191 info!("Worker {} paused", self.name());
192 Ok(())
193 }
194
195 pub fn resume(&self) -> Result<(), WorkerError> {
197 if !self.is_paused() {
198 return Err(WorkerError::StateError(WorkerStateError::NotPaused));
199 }
200 if self.is_shutting_down() {
201 return Err(WorkerError::StateError(WorkerStateError::ShuttingDown));
202 }
203 self.state
204 .store(InnerWorkerState::Running, Ordering::SeqCst);
205 info!("Worker {} resumed", self.name());
206 Ok(())
207 }
208
209 pub fn stop(&self) -> Result<(), WorkerError> {
211 let current_state = self.state.load(Ordering::SeqCst);
212 if current_state == InnerWorkerState::Pending {
213 return Err(WorkerError::StateError(WorkerStateError::NotStarted));
214 }
215 self.state
216 .store(InnerWorkerState::Stopped, Ordering::SeqCst);
217 self.wake();
218 self.emit_ref(&Event::Stop);
219 info!("Worker {} stopped", self.name());
220 Ok(())
221 }
222
223 #[must_use]
225 pub fn is_ready(&self) -> bool {
226 self.is_running() && !self.is_shutting_down() && self.is_ready.load(Ordering::SeqCst)
227 }
228
229 #[must_use]
231 pub fn get_service(&self) -> &str {
232 self.service
233 }
234
235 #[must_use]
237 pub fn is_running(&self) -> bool {
238 self.state.load(Ordering::SeqCst) == InnerWorkerState::Running
239 }
240
241 #[must_use]
243 pub fn is_pending(&self) -> bool {
244 self.state.load(Ordering::SeqCst) == InnerWorkerState::Pending
245 }
246
247 #[must_use]
249 pub fn is_paused(&self) -> bool {
250 self.state.load(Ordering::SeqCst) == InnerWorkerState::Paused
251 }
252
253 #[must_use]
256 pub fn task_count(&self) -> usize {
257 self.task_count.load(Ordering::Relaxed)
258 }
259
260 #[must_use]
262 pub fn has_pending_tasks(&self) -> bool {
263 self.task_count.load(Ordering::Relaxed) > 0
264 }
265
266 #[must_use]
268 pub fn is_shutting_down(&self) -> bool {
269 self.shutdown
270 .as_ref()
271 .map(|s| !self.is_running() || s.is_shutting_down())
272 .unwrap_or(!self.is_running())
273 }
274
275 pub fn emit(&mut self, event: &Event) {
277 self.emit_ref(event);
278 }
279
280 fn emit_ref(&self, event: &Event) {
281 let handler = self.event_handler.as_ref();
282 handler(self, event);
283 }
284
285 pub fn wrap_listener<F: Fn(&Self, &Event) + Send + Sync + 'static>(&mut self, f: F) {
287 let cur = self.event_handler.clone();
288 let new: RawEventListener = Box::new(move |ctx, ev| {
289 f(ctx, ev);
290 cur(ctx, ev);
291 });
292 self.event_handler = Arc::new(new);
293 }
294
295 pub(crate) fn add_waker(&self, cx: &Context<'_>) {
296 if let Ok(mut waker_guard) = self.waker.lock() {
297 if waker_guard
298 .as_ref()
299 .map_or(true, |stored_waker| !stored_waker.will_wake(cx.waker()))
300 {
301 *waker_guard = Some(cx.waker().clone());
302 }
303 }
304 }
305
306 fn has_recent_waker(&self, cx: &Context<'_>) -> bool {
308 if let Ok(waker_guard) = self.waker.lock() {
309 if let Some(stored_waker) = &*waker_guard {
310 return stored_waker.will_wake(cx.waker());
311 }
312 }
313 false
314 }
315
316 fn start_task(&self) {
317 self.task_count.fetch_add(1, Ordering::Relaxed);
318 }
319
320 fn end_task(&self) {
321 if self.task_count.fetch_sub(1, Ordering::Relaxed) == 1 {
322 self.wake();
323 }
324 }
325
326 pub(crate) fn wake(&self) {
327 if let Ok(waker) = self.waker.lock() {
328 if let Some(waker) = &*waker {
329 waker.wake_by_ref();
330 }
331 }
332 }
333}
334
335impl Future for WorkerContext {
336 type Output = Result<(), WorkerError>;
337
338 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
339 let task_count = self.task_count.load(Ordering::Relaxed);
340 let state = self.state.load(Ordering::SeqCst);
341 if state == InnerWorkerState::Pending {
342 return Poll::Ready(Err(WorkerError::StateError(WorkerStateError::NotStarted)));
343 }
344 if self.is_shutting_down() && task_count == 0 {
345 Poll::Ready(Ok(()))
346 } else {
347 if !self.has_recent_waker(cx) {
348 self.add_waker(cx);
349 }
350 Poll::Pending
351 }
352 }
353}
354
355impl<Args: Sync, Ctx: Sync, IdType: Sync + Send> FromRequest<Task<Args, Ctx, IdType>>
356 for WorkerContext
357{
358 type Error = MissingDataError;
359 async fn from_request(task: &Task<Args, Ctx, IdType>) -> Result<Self, Self::Error> {
360 task.parts.data.get_checked().cloned()
361 }
362}
363
364impl Drop for WorkerContext {
365 fn drop(&mut self) {
366 if Arc::strong_count(&self.state) > 1 {
367 return;
369 }
370 if self.is_running() {
371 error!(
372 "Worker '{}' is being dropped while running with `{}` tasks. Consider calling stop() before dropping.",
373 self.name(),
374 self.task_count()
375 );
376 }
377 }
378}