1use ahash::HashMap;
2use std::fmt::Debug;
3use std::sync::atomic::{AtomicU32, AtomicU64, Ordering};
4use std::sync::Arc;
5use tokio::sync::{oneshot, Notify};
6use tokio::task::JoinHandle;
7use tokio::time::Instant;
8use tracing::{event, instrument, Level, Span};
9
10use crate::db_writer::ready_jobs::{GetReadyJobsArgs, ReadyJob};
11use crate::db_writer::{DbOperation, DbOperationType};
12use crate::job_registry::{JobRegistry, JobRunner};
13use crate::shared_state::{SharedState, Time};
14use crate::worker_list::ListeningWorker;
15use crate::{Error, Queue, Result, SmartString};
16
/// Unique identifier assigned to each worker when it registers with the queue.
pub type WorkerId = u64;
19
/// A spawned background task paired with the channel used to ask it to stop.
struct CancellableTask {
    // Sending (or dropping) this signals the task to begin shutting down.
    close_tx: oneshot::Sender<()>,
    // Handle for awaiting the task's completion after it was signalled.
    join_handle: JoinHandle<()>,
}
24
/// Handle to a running worker. Call [`Worker::unregister`] for a graceful,
/// awaitable shutdown; dropping the handle signals shutdown in the background.
pub struct Worker {
    /// Unique id assigned when the worker registered with the queue.
    pub id: WorkerId,
    // Counters shared with the worker's background task; read by `counts()`.
    counts: Arc<RunningJobs>,
    // The background worker task. `None` once `unregister` has taken it.
    worker_list_task: Option<CancellableTask>,
}
32
/// Point-in-time snapshot of a worker's job counters (see [`Worker::counts`]).
pub struct WorkerCounts {
    /// Number of jobs started over the worker's lifetime.
    pub started: u64,
    /// Number of jobs finished over the worker's lifetime.
    pub finished: u64,
}
37
38impl Worker {
39 pub async fn unregister(mut self, timeout: Option<std::time::Duration>) -> Result<()> {
42 if let Some(task) = self.worker_list_task.take() {
43 task.close_tx.send(()).ok();
44 if let Some(timeout) = timeout {
45 tokio::time::timeout(timeout, task.join_handle)
46 .await
47 .map_err(|_| Error::Timeout)??;
48 } else {
49 task.join_handle.await?;
50 }
51 }
52 Ok(())
53 }
54
55 pub fn builder<CONTEXT>(queue: &Queue, context: CONTEXT) -> WorkerBuilder<CONTEXT>
57 where
58 CONTEXT: Send + Sync + Debug + Clone + 'static,
59 {
60 WorkerBuilder::new(queue, context)
61 }
62
63 pub fn counts(&self) -> WorkerCounts {
65 WorkerCounts {
66 started: self.counts.started.load(Ordering::Relaxed),
67 finished: self.counts.finished.load(Ordering::Relaxed),
68 }
69 }
70}
71
impl Drop for Worker {
    fn drop(&mut self) {
        // Drop cannot await, so signal the background task to close and detach
        // its join handle onto the runtime to let shutdown complete in the
        // background. `worker_list_task` is `None` if `unregister` already ran.
        // NOTE(review): `tokio::spawn` panics outside a runtime context — this
        // assumes the Worker is always dropped on a runtime thread; confirm.
        if let Some(task) = self.worker_list_task.take() {
            task.close_tx.send(()).ok();
            tokio::spawn(task.join_handle);
        }
    }
}
80
/// Builder for a [`Worker`]; obtain one via [`Worker::builder`] or
/// [`WorkerBuilder::new`], then call `build()`.
pub struct WorkerBuilder<'a, CONTEXT>
where
    CONTEXT: Send + Sync + Debug + Clone + 'static,
{
    // Source of job runners; mutually exclusive with `job_defs`.
    registry: Option<&'a JobRegistry<CONTEXT>>,
    // Directly supplied job runners; mutually exclusive with `registry`.
    job_defs: Option<Vec<JobRunner<CONTEXT>>>,
    // Queue the worker will register with.
    queue: &'a Queue,
    // User context cloned into every job runner invocation.
    context: CONTEXT,
    // If non-empty, restricts the worker to these job type names.
    jobs: Vec<SmartString>,
    // Weighted concurrency below which the worker actively fetches more jobs.
    min_concurrency: Option<u16>,
    // Upper bound on weighted concurrency; build() defaults this to 1.
    max_concurrency: Option<u16>,
}
101
102impl<'a, CONTEXT> WorkerBuilder<'a, CONTEXT>
103where
104 CONTEXT: Send + Sync + Debug + Clone + 'static,
105{
106 pub fn new(queue: &'a Queue, context: CONTEXT) -> Self {
108 Self {
109 registry: None,
110 job_defs: None,
111 queue,
112 context,
113 jobs: Vec::new(),
114 min_concurrency: None,
115 max_concurrency: None,
116 }
117 }
118
119 pub fn registry(mut self, registry: &'a JobRegistry<CONTEXT>) -> Self {
121 if self.job_defs.is_some() {
122 panic!("Cannot set both registry and job_defs");
123 }
124
125 self.registry = Some(registry);
126 self
127 }
128
129 pub fn jobs(mut self, jobs: impl Into<Vec<JobRunner<CONTEXT>>>) -> Self {
131 if self.job_defs.is_some() {
132 panic!("Cannot set both registry and job_defs");
133 }
134
135 self.job_defs = Some(jobs.into());
136 self
137 }
138
139 fn has_job_type(&self, job_type: &str) -> bool {
140 if let Some(job_defs) = self.job_defs.as_ref() {
141 job_defs.iter().any(|job_def| job_def.name == job_type)
142 } else if let Some(registry) = self.registry.as_ref() {
143 registry.jobs.contains_key(job_type)
144 } else {
145 panic!("Must set either registry or job_defs");
146 }
147 }
148
149 pub fn limit_job_types(mut self, job_types: &[impl AsRef<str>]) -> Self {
152 self.jobs = job_types
153 .iter()
154 .map(|s| {
155 assert!(
156 self.has_job_type(s.as_ref()),
157 "Job type {} not found in registry",
158 s.as_ref()
159 );
160
161 SmartString::from(s.as_ref())
162 })
163 .collect();
164 self
165 }
166
167 pub fn min_concurrency(mut self, min_concurrency: u16) -> Self {
171 assert!(min_concurrency > 0);
172 self.min_concurrency = Some(min_concurrency);
173 self
174 }
175
176 pub fn max_concurrency(mut self, max_concurrency: u16) -> Self {
178 assert!(max_concurrency > 0);
179 self.max_concurrency = Some(max_concurrency);
180 self
181 }
182
183 pub async fn build(self) -> Result<Worker> {
185 let job_defs: HashMap<SmartString, JobRunner<CONTEXT>> =
186 if let Some(job_defs) = self.job_defs {
187 job_defs
188 .into_iter()
189 .filter(|job| self.jobs.is_empty() || self.jobs.contains(&job.name))
190 .map(|job| (job.name.clone(), job))
191 .collect()
192 } else if let Some(registry) = self.registry {
193 let job_list = if self.jobs.is_empty() {
194 registry.jobs.keys().cloned().collect()
195 } else {
196 self.jobs
197 };
198
199 job_list
200 .iter()
201 .filter_map(|job| {
202 registry
203 .jobs
204 .get(job)
205 .map(|job_def| (job.clone(), job_def.clone()))
206 })
207 .collect()
208 } else {
209 panic!("Must set either registry or jobs");
210 };
211
212 let max_concurrency = self.max_concurrency.unwrap_or(1).max(1);
213 let min_concurrency = self.min_concurrency.unwrap_or(max_concurrency).max(1);
214
215 let job_list = job_defs.keys().cloned().collect::<Vec<_>>();
216
217 event!(
218 Level::INFO,
219 ?job_list,
220 min_concurrency,
221 max_concurrency,
222 "Starting worker",
223 );
224
225 let (close_tx, close_rx) = oneshot::channel();
226
227 let mut workers = self.queue.state.workers.write().await;
228 let listener = workers.add_worker(&job_list);
229 drop(workers);
230
231 let counts = Arc::new(RunningJobs {
232 started: AtomicU64::new(0),
233 finished: AtomicU64::new(0),
234 current_weighted: AtomicU32::new(0),
235 job_finished: Notify::new(),
236 });
237
238 let worker_id = listener.id;
239 let worker_internal = WorkerInternal {
240 listener,
241 running_jobs: counts.clone(),
242 job_list: job_list.into_iter().map(String::from).collect(),
243 job_defs: Arc::new(job_defs),
244 queue: self.queue.state.clone(),
245 context: self.context,
246 min_concurrency,
247 max_concurrency,
248 };
249
250 let join_handle = tokio::spawn(worker_internal.run(close_rx));
251
252 Ok(Worker {
253 id: worker_id,
254 counts,
255 worker_list_task: Some(CancellableTask {
256 close_tx,
257 join_handle,
258 }),
259 })
260 }
261}
262
/// Counters shared between a [`Worker`] handle and its background task.
pub(crate) struct RunningJobs {
    /// Total jobs started over the worker's lifetime.
    pub started: AtomicU64,
    /// Total jobs finished over the worker's lifetime.
    pub finished: AtomicU64,
    /// Sum of the weights of currently-running jobs (not a plain job count —
    /// each finished job subtracts its `weight`).
    pub current_weighted: AtomicU32,
    /// Notified each time a job finishes, waking the worker loop and shutdown.
    pub job_finished: Notify,
}
269
/// State owned by the spawned worker task; see [`WorkerInternal::run`].
struct WorkerInternal<CONTEXT>
where
    CONTEXT: Send + Sync + Debug + Clone + 'static,
{
    // Registration handle; carries the worker id and the task-ready Notify.
    listener: Arc<ListeningWorker>,
    // Shared queue state: db write channel, time source, global close signal.
    queue: SharedState,
    // Names of the job types this worker handles (sent with ready-job queries).
    job_list: Vec<String>,
    // Job runners keyed by job type name.
    job_defs: Arc<HashMap<SmartString, JobRunner<CONTEXT>>>,
    // Counters shared with the Worker handle.
    running_jobs: Arc<RunningJobs>,
    // User context cloned into every job runner invocation.
    context: CONTEXT,
    // Below this weighted concurrency the loop actively fetches more jobs.
    min_concurrency: u16,
    // Upper bound on weighted concurrency when requesting ready jobs.
    max_concurrency: u16,
}
283
284pub(crate) fn log_error<T, E>(result: Result<T, E>)
285where
286 E: std::error::Error,
287{
288 if let Err(e) = result {
289 event!(Level::ERROR, ?e);
290 }
291}
292
293impl<CONTEXT> WorkerInternal<CONTEXT>
294where
295 CONTEXT: Send + Sync + Debug + Clone + 'static,
296{
297 #[instrument(parent = None, name="worker_loop", skip_all, fields(worker_id = %self.listener.id))]
298 async fn run(self, mut close_rx: oneshot::Receiver<()>) {
299 let mut global_close_rx = self.queue.close.clone();
300 loop {
301 let mut running_jobs = self.running_jobs.current_weighted.load(Ordering::Relaxed);
302 let min_concurrency = self.min_concurrency as u32;
303 if running_jobs < min_concurrency {
304 log_error(self.run_ready_jobs().await);
305 running_jobs = self.running_jobs.current_weighted.load(Ordering::Relaxed);
306 }
307
308 let grab_new_jobs = running_jobs < min_concurrency;
309
310 tokio::select! {
311 biased;
312 _ = &mut close_rx => {
313 log_error(self.shutdown().await);
314 break;
315 }
316 _ = global_close_rx.changed() => {
317 log_error(self.shutdown().await);
318 break;
319 }
320 _ = self.listener.notify_task_ready.notified(), if grab_new_jobs => {
321 event!(Level::TRACE, "New task ready");
322 }
323 _ = self.running_jobs.job_finished.notified() => {
324 event!(Level::TRACE, "Job finished");
325 }
326 }
327 }
328 }
329
330 async fn shutdown(&self) -> Result<()> {
331 let mut running_jobs = self.running_jobs.current_weighted.load(Ordering::Relaxed);
332 while running_jobs > 0 {
333 self.running_jobs.job_finished.notified().await;
334 running_jobs = self.running_jobs.current_weighted.load(Ordering::Relaxed);
335 }
336
337 let mut workers = self.queue.workers.write().await;
338 workers.remove_worker(self.listener.id)
339 }
340
341 async fn run_ready_jobs(&self) -> Result<()> {
342 let running_count = self.running_jobs.current_weighted.load(Ordering::Relaxed);
343 let max_concurrency = self.max_concurrency as u32;
344 let max_jobs = max_concurrency - running_count;
345 let job_types = self
346 .job_list
347 .iter()
348 .map(|s| rusqlite::types::Value::from(s.clone()))
349 .collect::<Vec<_>>();
350
351 let running_jobs = self.running_jobs.clone();
352 let worker_id = self.listener.id;
353 let now = self.queue.time.now();
354 event!(Level::TRACE, %now, current_running = %running_count, %max_concurrency, "Checking ready jobs");
355
356 let (result_tx, result_rx) = oneshot::channel();
357 self.queue
358 .db_write_tx
359 .send(DbOperation {
360 worker_id,
361 span: Span::current(),
362 operation: DbOperationType::GetReadyJobs(GetReadyJobsArgs {
363 job_types,
364 max_jobs,
365 max_concurrency,
366 running_jobs,
367 now,
368 result_tx,
369 }),
370 })
371 .await
372 .map_err(|_| Error::QueueClosed)?;
373
374 let ready_jobs = result_rx.await.map_err(|_| Error::QueueClosed)??;
375
376 for job in ready_jobs {
377 self.run_job(job).await?;
378 }
379
380 Ok(())
381 }
382
383 #[instrument(level="debug", skip(self, done), fields(worker_id = %self.listener.id))]
384 async fn run_job(
385 &self,
386 ReadyJob {
387 job,
388 done_rx: mut done,
389 }: ReadyJob,
390 ) -> Result<()> {
391 let job_def = self
392 .job_defs
393 .get(job.job_type.as_str())
394 .expect("Got job for unsupported type");
395
396 let worker_id = self.listener.id;
397 let running = self.running_jobs.clone();
398 let autoheartbeat = job_def.autoheartbeat;
399 let time = job.queue.time.clone();
400
401 (job_def.runner)(job.clone(), self.context.clone());
402
403 tokio::spawn(async move {
404 let use_autohearbeat = autoheartbeat && job.heartbeat_increment > 0;
405 event!(Level::DEBUG, ?job, "Starting job monitor task");
406 loop {
407 let expires = job.expires.load(Ordering::Relaxed);
408 let expires_instant = time.instant_for_timestamp(expires);
409
410 tokio::select! {
411 _ = wait_for_next_autoheartbeat(&time, expires, job.heartbeat_increment), if use_autohearbeat => {
412 event!(Level::DEBUG, %job, "Sending autoheartbeat");
413 let new_time =
414 crate::job::send_heartbeat(job.job_id, worker_id, job.heartbeat_increment, &job.queue).await;
415
416 match new_time {
417 Ok(new_time) => job.expires.store(new_time.unix_timestamp(), Ordering::Relaxed),
418 Err(e) => event!(Level::ERROR, ?e),
419 }
420 }
421 _ = tokio::time::sleep_until(expires_instant) => {
422 event!(Level::DEBUG, %job, "Job expired");
423 let now_expires = job.expires.load(Ordering::Relaxed);
424 if now_expires == expires {
425 if !job.is_done().await {
426 log_error(job.fail("Job expired").await);
427 }
428 break;
429 }
430 }
431 _ = done.changed() => {
432 break;
433 }
434 }
435 }
436
437 running
440 .current_weighted
441 .fetch_sub(job.weight as u32, Ordering::Relaxed);
442 running.finished.fetch_add(1, Ordering::Relaxed);
443 running.job_finished.notify_one();
444 });
445
446 Ok(())
447 }
448}
449
450async fn wait_for_next_autoheartbeat(time: &Time, expires: i64, heartbeat_increment: i32) {
451 let now = time.now();
452 let before = (heartbeat_increment.min(30) / 2) as i64;
453 let next_heartbeat_time = expires - before;
454
455 let time_from_now = next_heartbeat_time - now.unix_timestamp();
456 let instant = Instant::now() + std::time::Duration::from_secs(time_from_now.max(0) as u64);
457
458 tokio::time::sleep_until(instant).await
459}
460
#[cfg(test)]
mod tests {
    use crate::test_util::TestEnvironment;

    use super::*;

    /// Building a worker without configuring either a registry or a job list
    /// must panic (see `WorkerBuilder::build`).
    #[tokio::test]
    #[should_panic]
    async fn worker_without_jobs_should_panic() {
        let test = TestEnvironment::new().await;
        Worker::builder(&test.queue, test.context.clone())
            .build()
            .await
            .unwrap();
    }
}
476}