1use crate::backend::Backend;
2use crate::error::{BoxDynError, Error};
3use crate::layers::extensions::Data;
4use crate::monitor::shutdown::Shutdown;
5use crate::request::Request;
6use crate::service_fn::FromRequest;
7use crate::task::task_id::TaskId;
8use call_all::CallAllUnordered;
9use futures::future::{join, select, BoxFuture};
10use futures::stream::BoxStream;
11use futures::{Future, FutureExt, Stream, StreamExt};
12use pin_project_lite::pin_project;
13use serde::{Deserialize, Serialize};
14use std::fmt::Debug;
15use std::fmt::{self, Display};
16use std::ops::{Deref, DerefMut};
17use std::pin::Pin;
18use std::str::FromStr;
19use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
20use std::sync::{Arc, Mutex, RwLock};
21use std::task::{Context as TaskCtx, Poll, Waker};
22use thiserror::Error;
23use tower::{Layer, Service, ServiceBuilder};
24
25mod call_all;
26
/// Identifier of a worker, wrapping the worker's name.
///
/// Equality is derived, so two ids are equal exactly when their names are equal.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct WorkerId {
    // The worker's name; read through `WorkerId::name`.
    name: String,
}
32
/// Shared slot holding an optional callback that is invoked with a [`Worker<Event>`].
///
/// The `Arc<RwLock<..>>` lets every clone observe the same slot; `None` means
/// that no handler has been registered.
pub type EventHandler = Arc<RwLock<Option<Box<dyn Fn(Worker<Event>) + Send + Sync>>>>;
35
36impl FromStr for WorkerId {
37 type Err = ();
38
39 fn from_str(s: &str) -> Result<Self, Self::Err> {
40 Ok(WorkerId { name: s.to_owned() })
41 }
42}
43
44impl Display for WorkerId {
45 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
46 f.write_str(self.name())?;
47 Ok(())
48 }
49}
50
51impl WorkerId {
52 pub fn new<T: AsRef<str>>(name: T) -> Self {
54 Self {
55 name: name.as_ref().to_string(),
56 }
57 }
58
59 pub fn name(&self) -> &str {
61 &self.name
62 }
63}
64
/// Events a worker reports to its registered event handler.
#[derive(Debug)]
pub enum Event {
    /// Emitted by `Worker::<Context>::start` once the worker is marked as running.
    Start,
    /// Emitted when a request is pulled from the backend stream; carries its task id.
    Engage(TaskId),
    /// Emitted when the backend stream yields `Ok(None)`.
    Idle,
    /// A free-form message. Nothing in this file emits it.
    Custom(String),
    /// Emitted when the backend stream yields an error or a service call fails.
    Error(BoxDynError),
    /// Emitted when the worker's `Context` future resolves while the `Runnable` is polled.
    Stop,
    /// Emitted right before the `Runnable` future completes.
    Exit,
}
83
84impl fmt::Display for Worker<Event> {
85 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
86 let event_description = match &self.state {
87 Event::Start => "Worker started".to_string(),
88 Event::Engage(task_id) => format!("Worker engaged with Task ID: {}", task_id),
89 Event::Idle => "Worker is idle".to_string(),
90 Event::Custom(msg) => format!("Custom event: {}", msg),
91 Event::Error(err) => format!("Worker encountered an error: {}", err),
92 Event::Stop => "Worker stopped".to_string(),
93 Event::Exit => "Worker completed all pending tasks and exited".to_string(),
94 };
95
96 write!(f, "Worker [{}]: {}", self.id.name, event_description)
97 }
98}
99
/// Errors a worker can report; each variant carries a description string.
#[derive(Error, Debug, Clone)]
pub enum WorkerError {
    /// Processing a job failed.
    #[error("Failed to process job: {0}")]
    ProcessingError(String),
    /// The underlying service reported an error.
    #[error("Service error: {0}")]
    ServiceError(String),
    /// The worker could not be started.
    #[error("Failed to start worker: {0}")]
    StartError(String),
}
113
/// Worker state holding everything needed to start running: a service and a backend.
pub struct Ready<S, P> {
    // The service that `Worker::run` wraps in its layers and calls for each request.
    service: S,
    // The backend that `Worker::run` polls for requests.
    backend: P,
    // Optional shutdown handle; copied into the worker `Context` by `Worker::run`.
    pub(crate) shutdown: Option<Shutdown>,
    // Shared slot for the event callback set through `Worker::on_event`.
    pub(crate) event_handler: EventHandler,
}
121
122impl<S, P> fmt::Debug for Ready<S, P>
123where
124 S: fmt::Debug,
125 P: fmt::Debug,
126{
127 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
128 f.debug_struct("Ready")
129 .field("service", &self.service)
130 .field("backend", &self.backend)
131 .field("shutdown", &self.shutdown)
132 .field("event_handler", &"...") .finish()
134 }
135}
136
137impl<S, P> Clone for Ready<S, P>
138where
139 S: Clone,
140 P: Clone,
141{
142 fn clone(&self) -> Self {
143 Ready {
144 service: self.service.clone(),
145 backend: self.backend.clone(),
146 shutdown: self.shutdown.clone(),
147 event_handler: self.event_handler.clone(),
148 }
149 }
150}
151
152impl<S, P> Ready<S, P> {
153 pub fn new(service: S, poller: P) -> Self {
155 Ready {
156 service,
157 backend: poller,
158 shutdown: None,
159 event_handler: EventHandler::default(),
160 }
161 }
162}
163
/// A worker identified by a [`WorkerId`] and carrying a state of type `T`.
///
/// In this file `T` is `Ready<S, P>` (before running), `Context` (a handle to a
/// running worker) or `Event` (an event attributed to a worker).
#[derive(Debug, Clone, Serialize)]
pub struct Worker<T> {
    // Identity of the worker.
    pub(crate) id: WorkerId,
    // The wrapped state; also reachable through `Deref`/`DerefMut`.
    pub(crate) state: T,
}
170
171impl<T> Worker<T> {
172 pub fn new(id: WorkerId, state: T) -> Self {
174 Self { id, state }
175 }
176
177 pub fn inner(&self) -> &T {
179 &self.state
180 }
181
182 pub fn id(&self) -> &WorkerId {
184 &self.id
185 }
186}
187
188impl<T> Deref for Worker<T> {
189 type Target = T;
190 fn deref(&self) -> &Self::Target {
191 &self.state
192 }
193}
194
195impl<T> DerefMut for Worker<T> {
196 fn deref_mut(&mut self) -> &mut Self::Target {
197 &mut self.state
198 }
199}
200
201impl Worker<Context> {
202 pub fn emit(&self, event: Event) -> bool {
204 if let Some(handler) = self.state.event_handler.read().unwrap().as_ref() {
205 handler(Worker {
206 id: self.id().clone(),
207 state: event,
208 });
209 return true;
210 }
211 false
212 }
213 pub fn start(&self) {
215 self.state.running.store(true, Ordering::Relaxed);
216 self.state.is_ready.store(true, Ordering::Release);
217 self.emit(Event::Start);
218 }
219}
220
221impl<Req, Ctx> FromRequest<Request<Req, Ctx>> for Worker<Context> {
222 fn from_request(req: &Request<Req, Ctx>) -> Result<Self, Error> {
223 req.parts.data.get_checked().cloned()
224 }
225}
226
impl<S, P> Worker<Ready<S, P>> {
    /// Registers `f` as the worker's event handler, replacing any previous handler.
    ///
    /// If the handler lock cannot be acquired for writing (poisoned), the
    /// registration is skipped silently and the worker is returned unchanged.
    pub fn on_event<F: Fn(Worker<Event>) + Send + Sync + 'static>(self, f: F) -> Self {
        let _ = self.event_handler.write().map(|mut res| {
            let _ = res.insert(Box::new(f));
        });
        self
    }

    /// Turns the backend's request stream into a stream that drives `service`.
    ///
    /// Every stream item is reported to the event handler first: `Ok(Some(_))`
    /// as `Event::Engage`, `Ok(None)` as `Event::Idle` and `Err(_)` as
    /// `Event::Error`. Only actual requests are passed on to the service.
    /// Service failures are reported as `Event::Error`; the resulting stream
    /// yields `()` per handled call.
    fn poll_jobs<Svc, Stm, Req, Res, Ctx>(
        worker: Worker<Context>,
        service: Svc,
        stream: Stm,
    ) -> BoxStream<'static, ()>
    where
        Svc: Service<Request<Req, Ctx>, Response = Res> + Send + 'static,
        Stm: Stream<Item = Result<Option<Request<Req, Ctx>>, Error>> + Send + Unpin + 'static,
        Req: Send + 'static + Sync,
        Svc::Future: Send,
        Svc::Response: 'static + Send + Sync + Serialize,
        Svc::Error: Send + Sync + 'static + Into<BoxDynError>,
        Ctx: Send + 'static + Sync,
        Res: 'static,
    {
        // `worker` is moved into the filter closure below; `w` is the handle
        // used by the result-mapping closure.
        let w = worker.clone();
        let stream = stream.filter_map(move |result| {
            // Each item gets its own handle because the async block takes ownership.
            let worker = worker.clone();

            async move {
                match result {
                    Ok(Some(request)) => {
                        worker.emit(Event::Engage(request.parts.task_id.clone()));
                        Some(request)
                    }
                    Ok(None) => {
                        worker.emit(Event::Idle);
                        None
                    }
                    Err(err) => {
                        worker.emit(Event::Error(Box::new(err)));
                        None
                    }
                }
            }
        });
        // NOTE(review): `CallAllUnordered` lives in the sibling `call_all` module;
        // presumably it calls the service for every request and yields results
        // as they complete — confirm there.
        let stream = CallAllUnordered::new(service, stream).map(move |res| {
            if let Err(error) = res {
                let error = error.into();
                // A `MissingData` error stops the worker before it is reported.
                if let Some(Error::MissingData(_)) = error.downcast_ref::<Error>() {
                    w.stop();
                }
                w.emit(Event::Error(error));
            }
        });
        stream.boxed()
    }
    /// Consumes the ready worker and builds the [`Runnable`] future that drives it.
    ///
    /// A fresh `Context` is created for the worker (sharing the event handler
    /// and taking over the shutdown handle), the backend is asked for its
    /// stream, heartbeat and layer, and the service is wrapped with the
    /// tracking, readiness and data layers plus the backend's own layer.
    /// Nothing is polled here.
    pub fn run<Req, Res, Ctx>(self) -> Runnable
    where
        S: Service<Request<Req, Ctx>, Response = Res> + Send + 'static,
        P: Backend<Request<Req, Ctx>, Res> + 'static,
        Req: Send + 'static + Sync,
        S::Future: Send,
        S::Response: 'static + Send + Sync + Serialize,
        S::Error: Send + Sync + 'static + Into<BoxDynError>,
        P::Stream: Unpin + Send + 'static,
        P::Layer: Layer<S>,
        <P::Layer as Layer<S>>::Service: Service<Request<Req, Ctx>, Response = Res> + Send,
        <<P::Layer as Layer<S>>::Service as Service<Request<Req, Ctx>>>::Future: Send,
        <<P::Layer as Layer<S>>::Service as Service<Request<Req, Ctx>>>::Error:
            Send + Into<BoxDynError> + Sync,
        Ctx: Send + 'static + Sync,
        Res: 'static,
    {
        let worker_id = self.id;
        // All flags and counters start at their defaults (not running, no tasks).
        let ctx = Context {
            running: Arc::default(),
            task_count: Arc::default(),
            waker: Arc::default(),
            shutdown: self.state.shutdown,
            event_handler: self.state.event_handler.clone(),
            is_ready: Arc::default(),
        };
        let worker = Worker {
            id: worker_id.clone(),
            state: ctx.clone(),
        };
        let backend = self.state.backend;
        let service = self.state.service;
        // The backend receives the worker handle when it is polled.
        let poller = backend.poll::<S>(&worker);
        let stream = poller.stream;
        let heartbeat = poller.heartbeat.boxed();
        let layer = poller.layer;
        // NOTE(review): per tower's `ServiceBuilder` docs, layers added first see
        // the request first, so tracking would be outermost — confirm.
        let service = ServiceBuilder::new()
            .layer(TrackerLayer::new(worker.state.clone()))
            .layer(ReadinessLayer::new(worker.state.is_ready.clone()))
            .layer(Data::new(worker.clone()))
            .layer(layer)
            .service(service);

        Runnable {
            poller: Self::poll_jobs(worker.clone(), service, stream),
            heartbeat,
            worker,
            // `Runnable::poll` flips this on its first poll, after starting the worker.
            running: false,
        }
    }
}
335
336#[must_use = "A Runnable must be awaited of no jobs will be consumed"]
342pub struct Runnable {
343 poller: BoxStream<'static, ()>,
344 heartbeat: BoxFuture<'static, ()>,
345 worker: Worker<Context>,
346 running: bool,
347}
348
349impl Runnable {
350 pub fn get_handle(&self) -> Worker<Context> {
352 self.worker.clone()
353 }
354}
355
356impl fmt::Debug for Runnable {
357 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
358 f.debug_struct("Runnable")
359 .field("poller", &"<stream>")
360 .field("heartbeat", &"<future>")
361 .field("worker", &self.worker)
362 .field("running", &self.running)
363 .finish()
364 }
365}
366
/// Drives the worker: the job stream and heartbeat race against the worker's stop signal.
impl Future for Runnable {
    type Output = ();

    /// Starts the worker on the first poll, then polls the combined future once.
    ///
    /// Completes when either (a) the job stream is exhausted and the heartbeat
    /// has finished, or (b) the worker's `Context` future resolves, in which
    /// case `Event::Stop` is emitted. `Event::Exit` is emitted right before
    /// returning `Poll::Ready`.
    fn poll(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<Self::Output> {
        let this = self.get_mut();
        let poller = &mut this.poller;
        let heartbeat = &mut this.heartbeat;
        let worker = &mut this.worker;

        // Drains the job stream; the items themselves carry no data.
        let poller_future = async { while (poller.next().await).is_some() {} };

        if !this.running {
            worker.start();
            this.running = true;
        }
        // The combinators below are rebuilt on every call to `poll`; only the
        // underlying stream and heartbeat keep their progress between polls.
        let combined = Box::pin(join(poller_future, heartbeat.as_mut()));

        // Second branch: a clone of the worker context, used as the stop signal.
        let mut combined = select(
            combined,
            worker.state.clone().map(|_| worker.emit(Event::Stop)),
        )
        .boxed();
        match Pin::new(&mut combined).poll(cx) {
            Poll::Ready(_) => {
                worker.emit(Event::Exit);
                Poll::Ready(())
            }
            Poll::Pending => Poll::Pending,
        }
    }
}
398
/// Shared state of a running worker.
///
/// Fields other than `shutdown` are `Arc`s, so clones observe the same
/// counters, flags and handler slot. Awaiting a `Context` resolves once the
/// worker is shutting down and no tracked task remains.
#[derive(Clone, Default)]
pub struct Context {
    // Number of tracked tasks in flight (see `start_task` / `end_task`).
    task_count: Arc<AtomicUsize>,
    // Waker stored by the `Future` impl; woken by `wake`.
    waker: Arc<Mutex<Option<Waker>>>,
    // Set by `Worker::start`, cleared by `stop`.
    running: Arc<AtomicBool>,
    // Optional shutdown handle consulted by `is_shutting_down`.
    shutdown: Option<Shutdown>,
    // Slot holding the callback used by `Worker::emit`.
    event_handler: EventHandler,
    // Readiness flag set by `Worker::start` and updated by `ReadinessService`.
    is_ready: Arc<AtomicBool>,
}
409
410impl fmt::Debug for Context {
411 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
412 f.debug_struct("WorkerContext")
413 .field("shutdown", &["Shutdown handle"])
414 .field("task_count", &self.task_count)
415 .field("running", &self.running)
416 .finish()
417 }
418}
419
pin_project! {
    /// A future wrapped by `Context::track`; the context's task count is
    /// decremented when the inner future completes.
    pub struct Tracked<F> {
        // Context whose task counter this future is accounted against.
        ctx: Context,
        // The wrapped future; pinned structurally so it can be polled in place.
        #[pin]
        task: F,
    }
}
428
429impl<F: Future> Future for Tracked<F> {
430 type Output = F::Output;
431
432 fn poll(self: Pin<&mut Self>, cx: &mut TaskCtx<'_>) -> Poll<F::Output> {
433 let this = self.project();
434
435 match this.task.poll(cx) {
436 res @ Poll::Ready(_) => {
437 this.ctx.end_task();
438 res
439 }
440 Poll::Pending => Poll::Pending,
441 }
442 }
443}
444
445impl Context {
446 pub fn track<F: Future>(&self, task: F) -> Tracked<F> {
448 self.start_task();
449 Tracked {
450 ctx: self.clone(),
451 task,
452 }
453 }
454
455 pub fn stop(&self) {
457 self.running.store(false, Ordering::Relaxed);
458 self.wake()
459 }
460
461 fn start_task(&self) {
462 self.task_count.fetch_add(1, Ordering::Relaxed);
463 }
464
465 fn end_task(&self) {
466 if self.task_count.fetch_sub(1, Ordering::Relaxed) == 1 {
467 self.wake();
468 }
469 }
470
471 pub(crate) fn wake(&self) {
472 if let Ok(waker) = self.waker.lock() {
473 if let Some(waker) = &*waker {
474 waker.wake_by_ref();
475 }
476 }
477 }
478
479 pub fn is_running(&self) -> bool {
481 self.running.load(Ordering::Relaxed)
482 }
483
484 pub fn task_count(&self) -> usize {
487 self.task_count.load(Ordering::Relaxed)
488 }
489
490 pub fn has_pending_tasks(&self) -> bool {
492 self.task_count.load(Ordering::Relaxed) > 0
493 }
494
495 pub fn is_shutting_down(&self) -> bool {
497 self.shutdown
498 .as_ref()
499 .map(|s| !self.is_running() || s.is_shutting_down())
500 .unwrap_or(!self.is_running())
501 }
502
503 fn add_waker(&self, cx: &mut TaskCtx<'_>) {
504 if let Ok(mut waker_guard) = self.waker.lock() {
505 if waker_guard
506 .as_ref()
507 .map_or(true, |stored_waker| !stored_waker.will_wake(cx.waker()))
508 {
509 *waker_guard = Some(cx.waker().clone());
510 }
511 }
512 }
513
514 fn has_recent_waker(&self, cx: &TaskCtx<'_>) -> bool {
516 if let Ok(waker_guard) = self.waker.lock() {
517 if let Some(stored_waker) = &*waker_guard {
518 return stored_waker.will_wake(cx.waker());
519 }
520 }
521 false
522 }
523
524 pub fn is_ready(&self) -> bool {
526 self.is_ready.load(Ordering::Acquire) && !self.is_shutting_down()
527 }
528}
529
530impl Future for Context {
531 type Output = ();
532
533 fn poll(self: Pin<&mut Self>, cx: &mut TaskCtx<'_>) -> Poll<()> {
534 let task_count = self.task_count.load(Ordering::Relaxed);
535 if self.is_shutting_down() && task_count == 0 {
536 Poll::Ready(())
537 } else {
538 if !self.has_recent_waker(cx) {
539 self.add_waker(cx);
540 }
541 Poll::Pending
542 }
543 }
544}
545
546#[derive(Debug, Clone)]
547struct TrackerLayer {
548 ctx: Context,
549}
550
551impl TrackerLayer {
552 fn new(ctx: Context) -> Self {
553 Self { ctx }
554 }
555}
556
557impl<S> Layer<S> for TrackerLayer {
558 type Service = TrackerService<S>;
559
560 fn layer(&self, service: S) -> Self::Service {
561 TrackerService {
562 ctx: self.ctx.clone(),
563 service,
564 }
565 }
566}
/// Service wrapper that registers every call's future with the worker context.
#[derive(Debug, Clone)]
struct TrackerService<S> {
    // Context used to track in-flight calls.
    ctx: Context,
    // The wrapped service.
    service: S,
}
572
573impl<S, Req, Ctx> Service<Request<Req, Ctx>> for TrackerService<S>
574where
575 S: Service<Request<Req, Ctx>>,
576{
577 type Response = S::Response;
578 type Error = S::Error;
579 type Future = Tracked<S::Future>;
580
581 fn poll_ready(&mut self, cx: &mut std::task::Context<'_>) -> Poll<Result<(), Self::Error>> {
582 self.service.poll_ready(cx)
583 }
584
585 fn call(&mut self, request: Request<Req, Ctx>) -> Self::Future {
586 self.ctx.track(self.service.call(request))
587 }
588}
589
/// Layer that wraps a service in a [`ReadinessService`] sharing one readiness flag.
#[derive(Clone)]
struct ReadinessLayer {
    // Flag updated by the produced services whenever they are polled for readiness.
    is_ready: Arc<AtomicBool>,
}

impl ReadinessLayer {
    /// Creates a layer that reports readiness through `is_ready`.
    fn new(is_ready: Arc<AtomicBool>) -> Self {
        ReadinessLayer { is_ready }
    }
}
600
601impl<S> Layer<S> for ReadinessLayer {
602 type Service = ReadinessService<S>;
603
604 fn layer(&self, inner: S) -> Self::Service {
605 ReadinessService {
606 inner,
607 is_ready: self.is_ready.clone(),
608 }
609 }
610}
611
/// Service wrapper that mirrors the inner service's readiness into a shared flag.
struct ReadinessService<S> {
    // The wrapped service.
    inner: S,
    // Set to `true` only when the last `poll_ready` returned `Ready(Ok(_))`.
    is_ready: Arc<AtomicBool>,
}
616
617impl<S, Request> Service<Request> for ReadinessService<S>
618where
619 S: Service<Request>,
620{
621 type Response = S::Response;
622 type Error = S::Error;
623 type Future = S::Future;
624
625 fn poll_ready(&mut self, cx: &mut std::task::Context<'_>) -> Poll<Result<(), Self::Error>> {
626 let result = self.inner.poll_ready(cx);
628 match &result {
630 Poll::Ready(Ok(_)) => self.is_ready.store(true, Ordering::Release),
631 Poll::Pending | Poll::Ready(Err(_)) => self.is_ready.store(false, Ordering::Release),
632 }
633
634 result
635 }
636
637 fn call(&mut self, req: Request) -> Self::Future {
638 self.inner.call(req)
639 }
640}
641
// Unit tests for worker id parsing and an end-to-end run against the in-memory backend.
#[cfg(test)]
mod tests {
    use std::{ops::Deref, sync::atomic::AtomicUsize};

    use crate::{
        builder::{WorkerBuilder, WorkerFactoryFn},
        layers::extensions::Data,
        memory::MemoryStorage,
        mq::MessageQueue,
    };

    use super::*;

    // Number of jobs enqueued by `it_works`; the last one stops the worker.
    const ITEMS: u32 = 100;

    // `FromStr` accepts any string verbatim, including special characters.
    #[test]
    fn it_parses_worker_names() {
        assert_eq!(
            WorkerId::from_str("worker").unwrap(),
            WorkerId {
                name: "worker".to_string()
            }
        );
        assert_eq!(
            WorkerId::from_str("worker-0").unwrap(),
            WorkerId {
                name: "worker-0".to_string()
            }
        );
        assert_eq!(
            WorkerId::from_str("complex&*-worker-name-0").unwrap(),
            WorkerId {
                name: "complex&*-worker-name-0".to_string()
            }
        );
    }

    // Runs a worker over `ITEMS` jobs and checks that awaiting it terminates
    // once the handler for the last job calls `worker.stop()`.
    #[tokio::test]
    async fn it_works() {
        let in_memory = MemoryStorage::new();
        let mut handle = in_memory.clone();

        // Producer: enqueue the jobs from a separate task.
        tokio::spawn(async move {
            for i in 0..ITEMS {
                handle.enqueue(i).await.unwrap();
            }
        });

        // Shared counter injected into the handler through `Data<Count>`.
        #[derive(Clone, Debug, Default)]
        struct Count(Arc<AtomicUsize>);

        impl Deref for Count {
            type Target = Arc<AtomicUsize>;
            fn deref(&self) -> &Self::Target {
                &self.0
            }
        }

        // Job handler: counts the job and stops the worker on the final item.
        // NOTE(review): the counter value is never asserted by this test.
        async fn task(job: u32, count: Data<Count>, worker: Worker<Context>) {
            count.fetch_add(1, Ordering::Relaxed);
            if job == ITEMS - 1 {
                worker.stop();
            }
        }
        let worker = WorkerBuilder::new("rango-tango")
            .data(Count::default())
            .backend(in_memory);
        let worker = worker.build_fn(task);
        worker.run().await;
    }
}