// backyard_core/worker.rs

use crate::{
    error::Result,
    job::{JobContext, RawJob},
    queue::Queue,
    registry::build_dispatch_table,
};
use std::collections::HashMap;
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use tokio::sync::mpsc;
use tokio_util::sync::CancellationToken;
use tracing::{error, info, warn, Instrument};

15pub type HandlerFn = fn(&[u8], JobContext) -> Pin<Box<dyn Future<Output = Result<()>> + Send>>;
16
/// Tuning knobs for a [`WorkerPool`].
#[derive(Debug, Clone)]
pub struct WorkerConfig {
    /// Names of the queues the fetcher polls; passed to `Queue::pop` in this order.
    pub queues: Vec<String>,
    /// How many worker tasks execute jobs concurrently.
    pub concurrency: usize,
    /// How long the fetcher backs off after finding no job or hitting a fetch error.
    pub poll_interval: std::time::Duration,
}

impl Default for WorkerConfig {
    /// A single `"default"` queue, ten workers, and a 500 ms poll interval.
    fn default() -> Self {
        let poll_interval = std::time::Duration::from_millis(500);
        WorkerConfig {
            queues: vec![String::from("default")],
            concurrency: 10,
            poll_interval,
        }
    }
}

34pub struct WorkerPool {
35    queue: Arc<dyn Queue>,
36    config: WorkerConfig,
37    dispatch: HashMap<&'static str, HandlerFn>,
38    shutdown: CancellationToken,
39}
40
41impl WorkerPool {
42    pub fn new(queue: Arc<dyn Queue>, config: WorkerConfig) -> Self {
43        Self {
44            queue,
45            config,
46            dispatch: build_dispatch_table(),
47            shutdown: CancellationToken::new(),
48        }
49    }
50
51    pub fn shutdown_token(&self) -> CancellationToken {
52        self.shutdown.child_token()
53    }
54
55    pub async fn run(self) -> Result<()> {
56        let (tx, rx): (mpsc::Sender<RawJob>, mpsc::Receiver<RawJob>) =
57            mpsc::channel(self.config.concurrency * 2);
58        let rx = Arc::new(tokio::sync::Mutex::new(rx));
59
60        let mut handles = vec![];
61        for worker_id in 0..self.config.concurrency {
62            let rx = rx.clone();
63            let queue = self.queue.clone();
64            let dispatch = self.dispatch.clone();
65            let shutdown = self.shutdown.clone();
66            let handle = tokio::spawn(async move {
67                Self::worker_loop(worker_id.to_string(), rx, queue, dispatch, shutdown).await
68            });
69            handles.push(handle);
70        }
71
72        let fetch_shutdown = self.shutdown.clone();
73        let queue = self.queue.clone();
74        let queues: Vec<String> = self.config.queues.clone();
75        let poll_interval = self.config.poll_interval;
76
77        tokio::spawn(async move {
78            Self::fetch_loop(queue, queues, tx, fetch_shutdown, poll_interval).await
79        });
80
81        futures::future::join_all(handles).await;
82        Ok(())
83    }
84
85    async fn fetch_loop(
86        queue: Arc<dyn Queue>,
87        queues: Vec<String>,
88        tx: mpsc::Sender<RawJob>,
89        shutdown: CancellationToken,
90        poll_interval: std::time::Duration,
91    ) {
92        let queue_refs: Vec<&str> = queues.iter().map(|s| s.as_str()).collect();
93        loop {
94            tokio::select! {
95                _ = shutdown.cancelled() => break,
96                result = queue.pop(&queue_refs) => {
97                    match result {
98                        Ok(Some(job)) => {
99                            let _ = tx.send(job).await;
100                        }
101                        Ok(None) => {
102                            tokio::time::sleep(poll_interval).await;
103                        }
104                        Err(e) => {
105                            error!("fetch error: {e}");
106                            tokio::time::sleep(poll_interval).await;
107                        }
108                    }
109                }
110            }
111        }
112    }
113
114    async fn worker_loop(
115        worker_id: String,
116        rx: Arc<tokio::sync::Mutex<mpsc::Receiver<RawJob>>>,
117        queue: Arc<dyn Queue>,
118        dispatch: HashMap<&'static str, HandlerFn>,
119        shutdown: CancellationToken,
120    ) {
121        loop {
122            let job: Option<RawJob> = {
123                let mut rx = rx.lock().await;
124                tokio::select! {
125                    _ = shutdown.cancelled() => break,
126                    job = rx.recv() => job
127                }
128            };
129
130            let job = match job {
131                Some(j) => j,
132                None => break,
133            };
134
135            let ctx = JobContext {
136                queue: queue.clone(),
137                worker_id: worker_id.clone(),
138            };
139
140            match dispatch.get(job.job_type.as_str()) {
141                None => {
142                    error!(job_type = %job.job_type, "no handler registered");
143                    let _ = queue.fail(job.id, "no handler registered").await;
144                }
145                Some(handler) => {
146                    let span = tracing::info_span!(
147                        "execute_job",
148                        job_id = %job.id,
149                        job_type = %job.job_type,
150                        queue = %job.queue,
151                        attempt = job.attempts,
152                    );
153                    let result = handler(&job.payload, ctx.clone()).instrument(span).await;
154                    match result {
155                        Ok(()) => {
156                            info!(job_id = %job.id, "job succeeded");
157                            let _ = queue.ack(job.id).await;
158                        }
159                        Err(e) => {
160                            warn!(job_id = %job.id, error = %e, "job failed");
161                            if job.attempts >= job.max_retries {
162                                let _ = queue.fail(job.id, &e.to_string()).await;
163                            } else {
164                                let retry_at = crate::retry::next_retry_at(job.attempts);
165                                let _ = queue.retry(job.id, retry_at).await;
166                            }
167                        }
168                    }
169                }
170            }
171        }
172    }
173}