1use crate::{
2 error::Result,
3 job::{JobContext, RawJob},
4 queue::Queue,
5 registry::build_dispatch_table,
6};
7use std::collections::HashMap;
8use std::future::Future;
9use std::pin::Pin;
10use std::sync::Arc;
11use tokio::sync::mpsc;
12use tokio_util::sync::CancellationToken;
13use tracing::{error, info, warn, Instrument};
14
15pub type HandlerFn = fn(&[u8], JobContext) -> Pin<Box<dyn Future<Output = Result<()>> + Send>>;
16
/// Settings controlling which queues a worker pool reads and how hard it works.
#[derive(Debug, Clone)]
pub struct WorkerConfig {
    /// Names of the queues to pull from; the whole list is handed to each `Queue::pop` call.
    pub queues: Vec<String>,
    /// How many worker tasks execute jobs at the same time.
    pub concurrency: usize,
    /// How long the fetcher waits before polling again after an empty or failed fetch.
    pub poll_interval: std::time::Duration,
}

impl Default for WorkerConfig {
    /// A single queue named `"default"`, ten workers, and a 500 ms poll interval.
    fn default() -> Self {
        let queues = vec![String::from("default")];
        let concurrency = 10;
        let poll_interval = std::time::Duration::from_millis(500);
        WorkerConfig {
            queues,
            concurrency,
            poll_interval,
        }
    }
}
33
34pub struct WorkerPool {
35 queue: Arc<dyn Queue>,
36 config: WorkerConfig,
37 dispatch: HashMap<&'static str, HandlerFn>,
38 shutdown: CancellationToken,
39}
40
41impl WorkerPool {
42 pub fn new(queue: Arc<dyn Queue>, config: WorkerConfig) -> Self {
43 Self {
44 queue,
45 config,
46 dispatch: build_dispatch_table(),
47 shutdown: CancellationToken::new(),
48 }
49 }
50
51 pub fn shutdown_token(&self) -> CancellationToken {
52 self.shutdown.child_token()
53 }
54
55 pub async fn run(self) -> Result<()> {
56 let (tx, rx): (mpsc::Sender<RawJob>, mpsc::Receiver<RawJob>) =
57 mpsc::channel(self.config.concurrency * 2);
58 let rx = Arc::new(tokio::sync::Mutex::new(rx));
59
60 let mut handles = vec![];
61 for worker_id in 0..self.config.concurrency {
62 let rx = rx.clone();
63 let queue = self.queue.clone();
64 let dispatch = self.dispatch.clone();
65 let shutdown = self.shutdown.clone();
66 let handle = tokio::spawn(async move {
67 Self::worker_loop(worker_id.to_string(), rx, queue, dispatch, shutdown).await
68 });
69 handles.push(handle);
70 }
71
72 let fetch_shutdown = self.shutdown.clone();
73 let queue = self.queue.clone();
74 let queues: Vec<String> = self.config.queues.clone();
75 let poll_interval = self.config.poll_interval;
76
77 tokio::spawn(async move {
78 Self::fetch_loop(queue, queues, tx, fetch_shutdown, poll_interval).await
79 });
80
81 futures::future::join_all(handles).await;
82 Ok(())
83 }
84
85 async fn fetch_loop(
86 queue: Arc<dyn Queue>,
87 queues: Vec<String>,
88 tx: mpsc::Sender<RawJob>,
89 shutdown: CancellationToken,
90 poll_interval: std::time::Duration,
91 ) {
92 let queue_refs: Vec<&str> = queues.iter().map(|s| s.as_str()).collect();
93 loop {
94 tokio::select! {
95 _ = shutdown.cancelled() => break,
96 result = queue.pop(&queue_refs) => {
97 match result {
98 Ok(Some(job)) => {
99 let _ = tx.send(job).await;
100 }
101 Ok(None) => {
102 tokio::time::sleep(poll_interval).await;
103 }
104 Err(e) => {
105 error!("fetch error: {e}");
106 tokio::time::sleep(poll_interval).await;
107 }
108 }
109 }
110 }
111 }
112 }
113
114 async fn worker_loop(
115 worker_id: String,
116 rx: Arc<tokio::sync::Mutex<mpsc::Receiver<RawJob>>>,
117 queue: Arc<dyn Queue>,
118 dispatch: HashMap<&'static str, HandlerFn>,
119 shutdown: CancellationToken,
120 ) {
121 loop {
122 let job: Option<RawJob> = {
123 let mut rx = rx.lock().await;
124 tokio::select! {
125 _ = shutdown.cancelled() => break,
126 job = rx.recv() => job
127 }
128 };
129
130 let job = match job {
131 Some(j) => j,
132 None => break,
133 };
134
135 let ctx = JobContext {
136 queue: queue.clone(),
137 worker_id: worker_id.clone(),
138 };
139
140 match dispatch.get(job.job_type.as_str()) {
141 None => {
142 error!(job_type = %job.job_type, "no handler registered");
143 let _ = queue.fail(job.id, "no handler registered").await;
144 }
145 Some(handler) => {
146 let span = tracing::info_span!(
147 "execute_job",
148 job_id = %job.id,
149 job_type = %job.job_type,
150 queue = %job.queue,
151 attempt = job.attempts,
152 );
153 let result = handler(&job.payload, ctx.clone()).instrument(span).await;
154 match result {
155 Ok(()) => {
156 info!(job_id = %job.id, "job succeeded");
157 let _ = queue.ack(job.id).await;
158 }
159 Err(e) => {
160 warn!(job_id = %job.id, error = %e, "job failed");
161 if job.attempts >= job.max_retries {
162 let _ = queue.fail(job.id, &e.to_string()).await;
163 } else {
164 let retry_at = crate::retry::next_retry_at(job.attempts);
165 let _ = queue.retry(job.id, retry_at).await;
166 }
167 }
168 }
169 }
170 }
171 }
172 }
173}