1use crate::backend::Backend;
2use crate::error::{BoxDynError, Error};
3use crate::layers::extensions::Data;
4use crate::monitor::shutdown::Shutdown;
5use crate::request::Request;
6use crate::service_fn::FromRequest;
7use crate::task::task_id::TaskId;
8use call_all::CallAllUnordered;
9use futures::future::{join, select, BoxFuture};
10use futures::stream::BoxStream;
11use futures::{Future, FutureExt, Stream, StreamExt};
12use pin_project_lite::pin_project;
13use serde::{Deserialize, Serialize};
14use std::fmt::Debug;
15use std::fmt::{self, Display};
16use std::ops::{Deref, DerefMut};
17use std::pin::Pin;
18use std::str::FromStr;
19use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
20use std::sync::{Arc, Mutex, RwLock};
21use std::task::{Context as TaskCtx, Poll, Waker};
22use thiserror::Error;
23use tower::{Layer, Service, ServiceBuilder};
24
25mod call_all;
26
27#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
29pub struct WorkerId {
30 name: String,
31}
32
/// Shared, optionally-registered callback that receives every [`Event`] a worker emits.
///
/// The `RwLock<Option<..>>` lets the handler be installed after construction (by `on_event`,
/// which takes the write lock) while `emit` invokes it under the read lock.
pub type EventHandler = Arc<RwLock<Option<Box<dyn Fn(Worker<Event>) + Send + Sync>>>>;
35
36impl FromStr for WorkerId {
37 type Err = ();
38
39 fn from_str(s: &str) -> Result<Self, Self::Err> {
40 Ok(WorkerId { name: s.to_owned() })
41 }
42}
43
44impl Display for WorkerId {
45 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
46 f.write_str(self.name())?;
47 Ok(())
48 }
49}
50
51impl WorkerId {
52 pub fn new<T: AsRef<str>>(name: T) -> Self {
54 Self {
55 name: name.as_ref().to_string(),
56 }
57 }
58
59 pub fn name(&self) -> &str {
61 &self.name
62 }
63}
64
/// Lifecycle and activity notifications delivered to a worker's [`EventHandler`].
#[derive(Debug)]
pub enum Event {
    /// The worker started (emitted by `Worker::<Context>::start`).
    Start,
    /// The backend yielded a request with this task id; emitted before it is dispatched.
    Engage(TaskId),
    /// The backend stream yielded no request (`Ok(None)`).
    Idle,
    /// A custom event carrying a free-form message (not emitted within this module).
    Custom(String),
    /// An error from the backend stream or from the service.
    Error(BoxDynError),
    /// The worker context resolved: the worker is shutting down with no in-flight tasks.
    Stop,
    /// The worker's `Runnable` future completed.
    Exit,
}
83
84impl fmt::Display for Worker<Event> {
85 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
86 let event_description = match &self.state {
87 Event::Start => "Worker started".to_string(),
88 Event::Engage(task_id) => format!("Worker engaged with Task ID: {}", task_id),
89 Event::Idle => "Worker is idle".to_string(),
90 Event::Custom(msg) => format!("Custom event: {}", msg),
91 Event::Error(err) => format!("Worker encountered an error: {}", err),
92 Event::Stop => "Worker stopped".to_string(),
93 Event::Exit => "Worker completed all pending tasks and exited".to_string(),
94 };
95
96 write!(f, "Worker [{}]: {}", self.id.name, event_description)
97 }
98}
99
/// Errors describing worker failures.
///
/// NOTE(review): none of these variants are constructed in this module; confirm where
/// they are produced before relying on a specific variant.
#[derive(Error, Debug, Clone)]
pub enum WorkerError {
    /// A job failed while being processed.
    #[error("Failed to process job: {0}")]
    ProcessingError(String),
    /// The underlying service reported an error.
    #[error("Service error: {0}")]
    ServiceError(String),
    /// The worker could not be started.
    #[error("Failed to start worker: {0}")]
    StartError(String),
}
113
/// A worker that has been configured but not yet run.
///
/// Consumed by `Worker::<Ready<S, P>>::run`, which assembles it into a [`Runnable`].
pub struct Ready<S, P> {
    /// The service that handles requests.
    service: S,
    /// The backend polled for requests.
    backend: P,
    /// Optional shutdown signal; moved into the worker [`Context`] by `run`.
    pub(crate) shutdown: Option<Shutdown>,
    /// Event handler shared with the worker [`Context`]; installed via `on_event`.
    pub(crate) event_handler: EventHandler,
}
121
122impl<S, P> fmt::Debug for Ready<S, P>
123where
124 S: fmt::Debug,
125 P: fmt::Debug,
126{
127 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
128 f.debug_struct("Ready")
129 .field("service", &self.service)
130 .field("backend", &self.backend)
131 .field("shutdown", &self.shutdown)
132 .field("event_handler", &"...") .finish()
134 }
135}
136
137impl<S, P> Clone for Ready<S, P>
138where
139 S: Clone,
140 P: Clone,
141{
142 fn clone(&self) -> Self {
143 Ready {
144 service: self.service.clone(),
145 backend: self.backend.clone(),
146 shutdown: self.shutdown.clone(),
147 event_handler: self.event_handler.clone(),
148 }
149 }
150}
151
152impl<S, P> Ready<S, P> {
153 pub fn new(service: S, poller: P) -> Self {
155 Ready {
156 service,
157 backend: poller,
158 shutdown: None,
159 event_handler: EventHandler::default(),
160 }
161 }
162}
163
/// A worker identified by a [`WorkerId`] and carrying a state `T`.
///
/// In this module `T` is [`Ready`] before the worker runs, [`Context`] while it runs,
/// and [`Event`] when an event is delivered to the handler.
#[derive(Debug, Clone, Serialize)]
pub struct Worker<T> {
    /// The worker's identifier.
    pub(crate) id: WorkerId,
    /// The worker's state; also reachable through `Deref`/`DerefMut`.
    pub(crate) state: T,
}
170
171impl<T> Worker<T> {
172 pub fn new(id: WorkerId, state: T) -> Self {
174 Self { id, state }
175 }
176
177 pub fn inner(&self) -> &T {
179 &self.state
180 }
181
182 pub fn id(&self) -> &WorkerId {
184 &self.id
185 }
186}
187
188impl<T> Deref for Worker<T> {
189 type Target = T;
190 fn deref(&self) -> &Self::Target {
191 &self.state
192 }
193}
194
195impl<T> DerefMut for Worker<T> {
196 fn deref_mut(&mut self) -> &mut Self::Target {
197 &mut self.state
198 }
199}
200
201impl Worker<Context> {
202 pub fn emit(&self, event: Event) -> bool {
204 if let Some(handler) = self.state.event_handler.read().unwrap().as_ref() {
205 handler(Worker {
206 id: self.id().clone(),
207 state: event,
208 });
209 return true;
210 }
211 false
212 }
213 pub fn start(&self) {
215 self.state.running.store(true, Ordering::Relaxed);
216 self.state.is_ready.store(true, Ordering::Release);
217 self.emit(Event::Start);
218 }
219}
220
/// Lets handler functions take the running worker (`Worker<Context>`) as an argument.
///
/// The worker is looked up in the request's data and cloned. If it is absent, the error
/// from `get_checked` is returned (presumably `Error::MissingData`, e.g. when the `Data`
/// layer installed by `run` was not applied — confirm).
impl<Req, Ctx> FromRequest<Request<Req, Ctx>> for Worker<Context> {
    fn from_request(req: &Request<Req, Ctx>) -> Result<Self, Error> {
        req.parts.data.get_checked().cloned()
    }
}
226
impl<S, P> Worker<Ready<S, P>> {
    /// Registers `f` as this worker's event handler, replacing any handler set earlier.
    ///
    /// `f` receives a `Worker<Event>` for every event passed to `Worker::<Context>::emit`.
    /// If the handler lock is poisoned, registration is silently skipped.
    pub fn on_event<F: Fn(Worker<Event>) + Send + Sync + 'static>(self, f: F) -> Self {
        let _ = self.event_handler.write().map(|mut res| {
            let _ = res.insert(Box::new(f));
        });
        self
    }

    /// Turns the backend's request stream into a stream that drives `service`.
    ///
    /// For each backend item:
    /// - `Ok(Some(request))` emits [`Event::Engage`] and forwards the request;
    /// - `Ok(None)` emits [`Event::Idle`] and is dropped;
    /// - `Err(err)` emits [`Event::Error`] and is dropped.
    ///
    /// Forwarded requests are dispatched through `CallAllUnordered`. Every service error is
    /// emitted as [`Event::Error`]. If the error downcasts to `Error::MissingData`, the worker
    /// is stopped first (presumably because a missing dependency would fail every later job
    /// too — confirm). The returned stream yields `()` per finished call.
    fn poll_jobs<Svc, Stm, Req, Ctx>(
        worker: Worker<Context>,
        service: Svc,
        stream: Stm,
    ) -> BoxStream<'static, ()>
    where
        Svc: Service<Request<Req, Ctx>> + Send + 'static,
        Stm: Stream<Item = Result<Option<Request<Req, Ctx>>, Error>> + Send + Unpin + 'static,
        Req: Send + 'static,
        Svc::Future: Send,
        Svc::Error: Send + 'static + Into<BoxDynError>,
        Ctx: Send + 'static,
    {
        // Two handles are needed: `worker` moves into the `filter_map` closure,
        // `w` into the result-mapping closure below.
        let w = worker.clone();
        let stream = stream.filter_map(move |result| {
            // Each item's `async move` block needs its own owned handle.
            let worker = worker.clone();

            async move {
                match result {
                    Ok(Some(request)) => {
                        worker.emit(Event::Engage(request.parts.task_id.clone()));
                        Some(request)
                    }
                    Ok(None) => {
                        worker.emit(Event::Idle);
                        None
                    }
                    Err(err) => {
                        worker.emit(Event::Error(Box::new(err)));
                        None
                    }
                }
            }
        });
        let stream = CallAllUnordered::new(service, stream).map(move |res| {
            if let Err(error) = res {
                let error = error.into();
                if let Some(Error::MissingData(_)) = error.downcast_ref::<Error>() {
                    w.stop();
                }
                w.emit(Event::Error(error));
            }
        });
        stream.boxed()
    }

    /// Consumes the ready worker and assembles a [`Runnable`] that drives it.
    ///
    /// Steps:
    /// 1. Build a fresh worker [`Context`]; its running flag starts cleared.
    /// 2. Hand a `Worker<Context>` to the backend's `poll`.
    /// 3. Wrap the service, in this `ServiceBuilder` order, with `TrackerLayer` (in-flight
    ///    task counting), `ReadinessLayer` (readiness flag), `Data` (exposes the
    ///    `Worker<Context>` to handlers) and the backend's layer.
    ///
    /// Nothing is polled until the returned [`Runnable`] is awaited.
    pub fn run<Req, Ctx>(self) -> Runnable
    where
        S: Service<Request<Req, Ctx>> + 'static,
        P: Backend<Request<Req, Ctx>> + 'static,
        Req: Send + 'static,
        S::Error: Send + 'static + Into<BoxDynError>,
        P::Stream: Unpin + Send + 'static,
        P::Layer: Layer<S>,
        <P::Layer as Layer<S>>::Service: Service<Request<Req, Ctx>> + Send,
        <<P::Layer as Layer<S>>::Service as Service<Request<Req, Ctx>>>::Future: Send,
        <<P::Layer as Layer<S>>::Service as Service<Request<Req, Ctx>>>::Error:
            Send + Into<BoxDynError>,
        Ctx: Send + 'static,
    {
        // Returns the type name of the value's static type; recorded as the context's
        // service name.
        fn type_name_of_val<T>(_t: &T) -> &'static str {
            std::any::type_name::<T>()
        }
        let service = self.state.service;
        let worker_id = self.id;
        let ctx = Context {
            running: Arc::default(),
            task_count: Arc::default(),
            waker: Arc::default(),
            shutdown: self.state.shutdown,
            event_handler: self.state.event_handler.clone(),
            is_ready: Arc::default(),
            service: type_name_of_val(&service).to_owned(),
        };
        let worker = Worker {
            id: worker_id.clone(),
            state: ctx.clone(),
        };
        let backend = self.state.backend;

        let poller = backend.poll(&worker);
        let stream = poller.stream;
        let heartbeat = poller.heartbeat.boxed();
        let layer = poller.layer;
        let service = ServiceBuilder::new()
            .layer(TrackerLayer::new(worker.state.clone()))
            .layer(ReadinessLayer::new(worker.state.is_ready.clone()))
            .layer(Data::new(worker.clone()))
            .layer(layer)
            .service(service);

        Runnable {
            poller: Self::poll_jobs(worker.clone(), service, stream),
            heartbeat,
            worker,
            running: false,
        }
    }
}
335
336#[must_use = "A Runnable must be awaited of no jobs will be consumed"]
342pub struct Runnable {
343 poller: BoxStream<'static, ()>,
344 heartbeat: BoxFuture<'static, ()>,
345 worker: Worker<Context>,
346 running: bool,
347}
348
349impl Runnable {
350 pub fn get_handle(&self) -> Worker<Context> {
352 self.worker.clone()
353 }
354}
355
356impl fmt::Debug for Runnable {
357 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
358 f.debug_struct("Runnable")
359 .field("poller", &"<stream>")
360 .field("heartbeat", &"<future>")
361 .field("worker", &self.worker)
362 .field("running", &self.running)
363 .finish()
364 }
365}
366
impl Future for Runnable {
    type Output = ();

    /// Drives the worker.
    ///
    /// The first poll starts the worker (`Worker::start`, which emits [`Event::Start`]).
    /// Every poll then races two futures:
    /// - the job stream drained to exhaustion, joined with the backend heartbeat;
    /// - the worker [`Context`], which resolves once the worker is shutting down with no
    ///   in-flight tasks, emitting [`Event::Stop`] when it does.
    ///
    /// Whichever finishes first ends the run: [`Event::Exit`] is emitted and this future
    /// completes.
    fn poll(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<Self::Output> {
        let this = self.get_mut();
        let poller = &mut this.poller;
        let heartbeat = &mut this.heartbeat;
        let worker = &mut this.worker;

        // Rebuilt on every poll. The async block holds only a borrow of `this.poller`,
        // so dropping the previous poll's instance loses no progress.
        let poller_future = async { while (poller.next().await).is_some() {} };

        if !this.running {
            worker.start();
            this.running = true;
        }
        // NOTE(review): `join`/`select` are recreated on every poll. If the heartbeat
        // completes while the poller is still pending (or vice versa), a later poll will
        // poll the already-finished future/stream again. That may panic (e.g. for an
        // `async` block). Consider fusing both or storing completion state — confirm
        // whether any backend returns a finite heartbeat.
        let combined = Box::pin(join(poller_future, heartbeat.as_mut()));

        let mut combined = select(
            combined,
            worker.state.clone().map(|_| worker.emit(Event::Stop)),
        )
        .boxed();
        match Pin::new(&mut combined).poll(cx) {
            Poll::Ready(_) => {
                worker.emit(Event::Exit);
                Poll::Ready(())
            }
            Poll::Pending => Poll::Pending,
        }
    }
}
398
/// Runtime state of a running worker.
///
/// Clones share the task counter, waker slot, flags and event handler (all behind `Arc`).
/// `shutdown` and `service` are cloned per copy.
#[derive(Clone, Default)]
pub struct Context {
    /// Number of in-flight tracked tasks (see `track` and [`Tracked`]).
    task_count: Arc<AtomicUsize>,
    /// Last waker registered by polling this context; woken by `stop` and when the task
    /// count drops to zero.
    waker: Arc<Mutex<Option<Waker>>>,
    /// Set by `Worker::start`, cleared by `stop`.
    running: Arc<AtomicBool>,
    /// Optional external shutdown signal consulted by `is_shutting_down`.
    shutdown: Option<Shutdown>,
    /// Handler invoked by `Worker::<Context>::emit`.
    event_handler: EventHandler,
    /// Readiness flag written by `ReadinessService::poll_ready` and `Worker::start`.
    is_ready: Arc<AtomicBool>,
    /// Type name of the worker's service, recorded by `run`.
    service: String,
}
410
411impl fmt::Debug for Context {
412 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
413 f.debug_struct("WorkerContext")
414 .field("shutdown", &["Shutdown handle"])
415 .field("task_count", &self.task_count)
416 .field("running", &self.running)
417 .field("service", &self.service)
418 .finish()
419 }
420}
421
pin_project! {
    /// A future whose completion is reported to a worker [`Context`].
    ///
    /// Created by `Context::track`, which increments the in-flight counter; the counter is
    /// decremented when the wrapped future resolves.
    ///
    /// NOTE(review): dropping a `Tracked` before its task completes never decrements the
    /// counter, so `has_pending_tasks` stays true and the context future may never resolve.
    /// Consider a `PinnedDrop` that releases the slot when unfinished.
    pub struct Tracked<F> {
        // Context whose in-flight counter this future occupies.
        ctx: Context,
        #[pin]
        task: F,
    }
}
430
431impl<F: Future> Future for Tracked<F> {
432 type Output = F::Output;
433
434 fn poll(self: Pin<&mut Self>, cx: &mut TaskCtx<'_>) -> Poll<F::Output> {
435 let this = self.project();
436
437 match this.task.poll(cx) {
438 res @ Poll::Ready(_) => {
439 this.ctx.end_task();
440 res
441 }
442 Poll::Pending => Poll::Pending,
443 }
444 }
445}
446
447impl Context {
448 pub fn track<F: Future>(&self, task: F) -> Tracked<F> {
450 self.start_task();
451 Tracked {
452 ctx: self.clone(),
453 task,
454 }
455 }
456
457 pub fn stop(&self) {
459 self.running.store(false, Ordering::Relaxed);
460 self.wake()
461 }
462
463 fn start_task(&self) {
464 self.task_count.fetch_add(1, Ordering::Relaxed);
465 }
466
467 fn end_task(&self) {
468 if self.task_count.fetch_sub(1, Ordering::Relaxed) == 1 {
469 self.wake();
470 }
471 }
472
473 pub(crate) fn wake(&self) {
474 if let Ok(waker) = self.waker.lock() {
475 if let Some(waker) = &*waker {
476 waker.wake_by_ref();
477 }
478 }
479 }
480
481 pub fn is_running(&self) -> bool {
483 self.running.load(Ordering::Relaxed)
484 }
485
486 pub fn task_count(&self) -> usize {
489 self.task_count.load(Ordering::Relaxed)
490 }
491
492 pub fn has_pending_tasks(&self) -> bool {
494 self.task_count.load(Ordering::Relaxed) > 0
495 }
496
497 pub fn is_shutting_down(&self) -> bool {
499 self.shutdown
500 .as_ref()
501 .map(|s| !self.is_running() || s.is_shutting_down())
502 .unwrap_or(!self.is_running())
503 }
504
505 fn add_waker(&self, cx: &mut TaskCtx<'_>) {
506 if let Ok(mut waker_guard) = self.waker.lock() {
507 if waker_guard
508 .as_ref()
509 .map_or(true, |stored_waker| !stored_waker.will_wake(cx.waker()))
510 {
511 *waker_guard = Some(cx.waker().clone());
512 }
513 }
514 }
515
516 fn has_recent_waker(&self, cx: &TaskCtx<'_>) -> bool {
518 if let Ok(waker_guard) = self.waker.lock() {
519 if let Some(stored_waker) = &*waker_guard {
520 return stored_waker.will_wake(cx.waker());
521 }
522 }
523 false
524 }
525
526 pub fn is_ready(&self) -> bool {
528 self.is_ready.load(Ordering::Acquire) && !self.is_shutting_down()
529 }
530
531 pub fn get_service(&self) -> &String {
533 &self.service
534 }
535}
536
537impl Future for Context {
538 type Output = ();
539
540 fn poll(self: Pin<&mut Self>, cx: &mut TaskCtx<'_>) -> Poll<()> {
541 let task_count = self.task_count.load(Ordering::Relaxed);
542 if self.is_shutting_down() && task_count == 0 {
543 Poll::Ready(())
544 } else {
545 if !self.has_recent_waker(cx) {
546 self.add_waker(cx);
547 }
548 Poll::Pending
549 }
550 }
551}
552
553#[derive(Debug, Clone)]
554struct TrackerLayer {
555 ctx: Context,
556}
557
558impl TrackerLayer {
559 fn new(ctx: Context) -> Self {
560 Self { ctx }
561 }
562}
563
564impl<S> Layer<S> for TrackerLayer {
565 type Service = TrackerService<S>;
566
567 fn layer(&self, service: S) -> Self::Service {
568 TrackerService {
569 ctx: self.ctx.clone(),
570 service,
571 }
572 }
573}
574#[derive(Debug, Clone)]
575struct TrackerService<S> {
576 ctx: Context,
577 service: S,
578}
579
580impl<S, Req, Ctx> Service<Request<Req, Ctx>> for TrackerService<S>
581where
582 S: Service<Request<Req, Ctx>>,
583{
584 type Response = S::Response;
585 type Error = S::Error;
586 type Future = Tracked<S::Future>;
587
588 fn poll_ready(&mut self, cx: &mut std::task::Context<'_>) -> Poll<Result<(), Self::Error>> {
589 self.service.poll_ready(cx)
590 }
591
592 fn call(&mut self, request: Request<Req, Ctx>) -> Self::Future {
593 request.parts.attempt.increment();
594 self.ctx.track(self.service.call(request))
595 }
596}
597
598#[derive(Clone)]
599struct ReadinessLayer {
600 is_ready: Arc<AtomicBool>,
601}
602
603impl ReadinessLayer {
604 fn new(is_ready: Arc<AtomicBool>) -> Self {
605 Self { is_ready }
606 }
607}
608
609impl<S> Layer<S> for ReadinessLayer {
610 type Service = ReadinessService<S>;
611
612 fn layer(&self, inner: S) -> Self::Service {
613 ReadinessService {
614 inner,
615 is_ready: self.is_ready.clone(),
616 }
617 }
618}
619
620struct ReadinessService<S> {
621 inner: S,
622 is_ready: Arc<AtomicBool>,
623}
624
625impl<S, Request> Service<Request> for ReadinessService<S>
626where
627 S: Service<Request>,
628{
629 type Response = S::Response;
630 type Error = S::Error;
631 type Future = S::Future;
632
633 fn poll_ready(&mut self, cx: &mut std::task::Context<'_>) -> Poll<Result<(), Self::Error>> {
634 let result = self.inner.poll_ready(cx);
636 match &result {
638 Poll::Ready(Ok(_)) => self.is_ready.store(true, Ordering::Release),
639 Poll::Pending | Poll::Ready(Err(_)) => self.is_ready.store(false, Ordering::Release),
640 }
641
642 result
643 }
644
645 fn call(&mut self, req: Request) -> Self::Future {
646 self.inner.call(req)
647 }
648}
649
#[cfg(test)]
mod tests {
    use std::{ops::Deref, sync::atomic::AtomicUsize};

    use crate::{
        builder::{WorkerBuilder, WorkerFactoryFn},
        layers::extensions::Data,
        memory::MemoryStorage,
        mq::MessageQueue,
    };

    use super::*;

    /// Number of jobs enqueued by `it_works`.
    const ITEMS: u32 = 100;

    #[test]
    fn it_parses_worker_names() {
        assert_eq!(
            WorkerId::from_str("worker").unwrap(),
            WorkerId {
                name: "worker".to_string()
            }
        );
        assert_eq!(
            WorkerId::from_str("worker-0").unwrap(),
            WorkerId {
                name: "worker-0".to_string()
            }
        );
        assert_eq!(
            WorkerId::from_str("complex&*-worker-name-0").unwrap(),
            WorkerId {
                name: "complex&*-worker-name-0".to_string()
            }
        );
    }

    /// Enqueues `ITEMS` jobs and runs a worker that stops itself after the last one.
    /// Then checks that every job was processed before the worker exited.
    #[tokio::test]
    async fn it_works() {
        let in_memory = MemoryStorage::new();
        let mut handle = in_memory.clone();

        tokio::spawn(async move {
            for i in 0..ITEMS {
                handle.enqueue(i).await.unwrap();
            }
        });

        #[derive(Clone, Debug, Default)]
        struct Count(Arc<AtomicUsize>);

        impl Deref for Count {
            type Target = Arc<AtomicUsize>;
            fn deref(&self) -> &Self::Target {
                &self.0
            }
        }

        async fn task(job: u32, count: Data<Count>, worker: Worker<Context>) {
            count.fetch_add(1, Ordering::Relaxed);
            if job == ITEMS - 1 {
                worker.stop();
            }
        }
        // Keep our own handle to the shared counter so the result can be checked after exit.
        let count = Count::default();
        let worker = WorkerBuilder::new("rango-tango")
            .data(count.clone())
            .backend(in_memory);
        let worker = worker.build_fn(task);
        worker.run().await;
        // The worker only exits once it is stopped *and* has no in-flight tasks. Every
        // earlier job was dispatched before the last one, so all of them must have run.
        assert_eq!(count.load(Ordering::Relaxed), ITEMS as usize);
    }
}
720}