prefect/
worker.rs

1use ahash::HashMap;
2use std::fmt::Debug;
3use std::sync::atomic::{AtomicU32, AtomicU64, Ordering};
4use std::sync::Arc;
5use tokio::sync::{oneshot, Notify};
6use tokio::task::JoinHandle;
7use tokio::time::Instant;
8use tracing::{event, instrument, Level, Span};
9
10use crate::db_writer::ready_jobs::{GetReadyJobsArgs, ReadyJob};
11use crate::db_writer::{DbOperation, DbOperationType};
12use crate::job_registry::{JobRegistry, JobRunner};
13use crate::shared_state::{SharedState, Time};
14use crate::worker_list::ListeningWorker;
15use crate::{Error, Queue, Result, SmartString};
16
17/// The internal ID for a worker.
18pub type WorkerId = u64;
19
/// A handle to a spawned background task that can be asked to shut down and then
/// awaited for completion.
struct CancellableTask {
    // Sending (or dropping) on this channel signals the task to stop.
    close_tx: oneshot::Sender<()>,
    // Await this to know the task has fully exited.
    join_handle: JoinHandle<()>,
}
24
/// A worker that runs jobs from the queue.
pub struct Worker {
    /// The worker's internal ID.
    pub id: WorkerId,
    /// Counters shared with the worker's background tasks (started/finished jobs,
    /// current weighted concurrency).
    counts: Arc<RunningJobs>,
    /// The background loop that fetches and runs jobs. `None` once the worker has
    /// been unregistered (or after `Drop` hands the task off).
    worker_list_task: Option<CancellableTask>,
}
32
/// A snapshot of how many jobs a [Worker] has processed.
pub struct WorkerCounts {
    /// Number of jobs this worker has started.
    pub started: u64,
    /// Number of jobs this worker has finished.
    pub finished: u64,
}
37
38impl Worker {
39    /// Unregister a worker from the queue. It will still finish any jobs it is currently running,
40    /// but will no longer take new jobs.
41    pub async fn unregister(mut self, timeout: Option<std::time::Duration>) -> Result<()> {
42        if let Some(task) = self.worker_list_task.take() {
43            task.close_tx.send(()).ok();
44            if let Some(timeout) = timeout {
45                tokio::time::timeout(timeout, task.join_handle)
46                    .await
47                    .map_err(|_| Error::Timeout)??;
48            } else {
49                task.join_handle.await?;
50            }
51        }
52        Ok(())
53    }
54
55    /// Create a [WorkerBuilder] to build a new worker.
56    pub fn builder<CONTEXT>(queue: &Queue, context: CONTEXT) -> WorkerBuilder<CONTEXT>
57    where
58        CONTEXT: Send + Sync + Debug + Clone + 'static,
59    {
60        WorkerBuilder::new(queue, context)
61    }
62
63    /// Return some counts about the number of jobs this worker has processed.
64    pub fn counts(&self) -> WorkerCounts {
65        WorkerCounts {
66            started: self.counts.started.load(Ordering::Relaxed),
67            finished: self.counts.finished.load(Ordering::Relaxed),
68        }
69    }
70}
71
72impl Drop for Worker {
73    fn drop(&mut self) {
74        if let Some(task) = self.worker_list_task.take() {
75            task.close_tx.send(()).ok();
76            tokio::spawn(task.join_handle);
77        }
78    }
79}
80
/// A builder object for a [Worker].
pub struct WorkerBuilder<'a, CONTEXT>
where
    CONTEXT: Send + Sync + Debug + Clone + 'static,
{
    /// The job registry from which this worker should take its job functions.
    registry: Option<&'a JobRegistry<CONTEXT>>,
    /// An explicit list of job runners, as an alternative to `registry`.
    job_defs: Option<Vec<JobRunner<CONTEXT>>>,
    /// The queue this worker will pull jobs from.
    queue: &'a Queue,
    /// The context value to send to the worker's jobs.
    context: CONTEXT,
    /// Limit the job types this worker will run. Defaults to all job types in the registry.
    jobs: Vec<SmartString>,
    /// Fetch new jobs when the number of running jobs drops to this number. Defaults to
    /// the same as max_concurrency.
    min_concurrency: Option<u16>,
    /// The maximum number of jobs that can be run concurrently. Defaults to 1, but you will
    /// usually want to set this to a higher number.
    max_concurrency: Option<u16>,
}
101
102impl<'a, CONTEXT> WorkerBuilder<'a, CONTEXT>
103where
104    CONTEXT: Send + Sync + Debug + Clone + 'static,
105{
106    /// Create a new [WorkerBuilder] for a particular [Queue].
107    pub fn new(queue: &'a Queue, context: CONTEXT) -> Self {
108        Self {
109            registry: None,
110            job_defs: None,
111            queue,
112            context,
113            jobs: Vec::new(),
114            min_concurrency: None,
115            max_concurrency: None,
116        }
117    }
118
119    /// Get the job definitions from this [JobRegistry].
120    pub fn registry(mut self, registry: &'a JobRegistry<CONTEXT>) -> Self {
121        if self.job_defs.is_some() {
122            panic!("Cannot set both registry and job_defs");
123        }
124
125        self.registry = Some(registry);
126        self
127    }
128
129    /// Get the job definitions from this list of [JobRunners](JobRunner).
130    pub fn jobs(mut self, jobs: impl Into<Vec<JobRunner<CONTEXT>>>) -> Self {
131        if self.job_defs.is_some() {
132            panic!("Cannot set both registry and job_defs");
133        }
134
135        self.job_defs = Some(jobs.into());
136        self
137    }
138
139    fn has_job_type(&self, job_type: &str) -> bool {
140        if let Some(job_defs) = self.job_defs.as_ref() {
141            job_defs.iter().any(|job_def| job_def.name == job_type)
142        } else if let Some(registry) = self.registry.as_ref() {
143            registry.jobs.contains_key(job_type)
144        } else {
145            panic!("Must set either registry or job_defs");
146        }
147    }
148
149    /// Limit this worker to only running these job types, even if the registry contains more
150    /// types.
151    pub fn limit_job_types(mut self, job_types: &[impl AsRef<str>]) -> Self {
152        self.jobs = job_types
153            .iter()
154            .map(|s| {
155                assert!(
156                    self.has_job_type(s.as_ref()),
157                    "Job type {} not found in registry",
158                    s.as_ref()
159                );
160
161                SmartString::from(s.as_ref())
162            })
163            .collect();
164        self
165    }
166
167    /// Set the minimum concurrency for this worker. When the number of running jobs falls below
168    /// this number, the worker will try to fetch more jobs, up to `max_concurrency`.
169    /// Defaults to the same as max_concurrency.
170    pub fn min_concurrency(mut self, min_concurrency: u16) -> Self {
171        assert!(min_concurrency > 0);
172        self.min_concurrency = Some(min_concurrency);
173        self
174    }
175
176    /// The maximum number of jobs that the worker will run concurrently. Defaults to 1.
177    pub fn max_concurrency(mut self, max_concurrency: u16) -> Self {
178        assert!(max_concurrency > 0);
179        self.max_concurrency = Some(max_concurrency);
180        self
181    }
182
183    /// Consume this [WorkerBuilder] and create a new [Worker].
184    pub async fn build(self) -> Result<Worker> {
185        let job_defs: HashMap<SmartString, JobRunner<CONTEXT>> =
186            if let Some(job_defs) = self.job_defs {
187                job_defs
188                    .into_iter()
189                    .filter(|job| self.jobs.is_empty() || self.jobs.contains(&job.name))
190                    .map(|job| (job.name.clone(), job))
191                    .collect()
192            } else if let Some(registry) = self.registry {
193                let job_list = if self.jobs.is_empty() {
194                    registry.jobs.keys().cloned().collect()
195                } else {
196                    self.jobs
197                };
198
199                job_list
200                    .iter()
201                    .filter_map(|job| {
202                        registry
203                            .jobs
204                            .get(job)
205                            .map(|job_def| (job.clone(), job_def.clone()))
206                    })
207                    .collect()
208            } else {
209                panic!("Must set either registry or jobs");
210            };
211
212        let max_concurrency = self.max_concurrency.unwrap_or(1).max(1);
213        let min_concurrency = self.min_concurrency.unwrap_or(max_concurrency).max(1);
214
215        let job_list = job_defs.keys().cloned().collect::<Vec<_>>();
216
217        event!(
218            Level::INFO,
219            ?job_list,
220            min_concurrency,
221            max_concurrency,
222            "Starting worker",
223        );
224
225        let (close_tx, close_rx) = oneshot::channel();
226
227        let mut workers = self.queue.state.workers.write().await;
228        let listener = workers.add_worker(&job_list);
229        drop(workers);
230
231        let counts = Arc::new(RunningJobs {
232            started: AtomicU64::new(0),
233            finished: AtomicU64::new(0),
234            current_weighted: AtomicU32::new(0),
235            job_finished: Notify::new(),
236        });
237
238        let worker_id = listener.id;
239        let worker_internal = WorkerInternal {
240            listener,
241            running_jobs: counts.clone(),
242            job_list: job_list.into_iter().map(String::from).collect(),
243            job_defs: Arc::new(job_defs),
244            queue: self.queue.state.clone(),
245            context: self.context,
246            min_concurrency,
247            max_concurrency,
248        };
249
250        let join_handle = tokio::spawn(worker_internal.run(close_rx));
251
252        Ok(Worker {
253            id: worker_id,
254            counts,
255            worker_list_task: Some(CancellableTask {
256                close_tx,
257                join_handle,
258            }),
259        })
260    }
261}
262
/// Counters shared between a [Worker] handle and its background tasks.
pub(crate) struct RunningJobs {
    /// Total number of jobs this worker has started.
    pub started: AtomicU64,
    /// Total number of jobs this worker has finished.
    pub finished: AtomicU64,
    /// Sum of the weights of currently-running jobs, used for concurrency accounting.
    pub current_weighted: AtomicU32,
    /// Notified each time a job finishes so the worker loop can look for more work.
    pub job_finished: Notify,
}
269
/// The state owned by a worker's background loop task (see [WorkerInternal::run]).
struct WorkerInternal<CONTEXT>
where
    CONTEXT: Send + Sync + Debug + Clone + 'static,
{
    /// This worker's registration in the queue's worker list.
    listener: Arc<ListeningWorker>,
    /// Shared queue state (db writer channel, close signal, clock, worker list).
    queue: SharedState,
    /// The job type names this worker runs, as sent to the database writer.
    job_list: Vec<String>,
    /// Job runner definitions, keyed by job type name.
    job_defs: Arc<HashMap<SmartString, JobRunner<CONTEXT>>>,
    /// Shared counters for started/finished jobs and current weighted concurrency.
    running_jobs: Arc<RunningJobs>,
    /// The context value cloned into each job invocation.
    context: CONTEXT,
    /// Fetch new jobs when weighted running jobs drop below this value.
    min_concurrency: u16,
    /// Maximum weighted concurrency for this worker.
    max_concurrency: u16,
}
283
284pub(crate) fn log_error<T, E>(result: Result<T, E>)
285where
286    E: std::error::Error,
287{
288    if let Err(e) = result {
289        event!(Level::ERROR, ?e);
290    }
291}
292
293impl<CONTEXT> WorkerInternal<CONTEXT>
294where
295    CONTEXT: Send + Sync + Debug + Clone + 'static,
296{
297    #[instrument(parent = None, name="worker_loop", skip_all, fields(worker_id = %self.listener.id))]
298    async fn run(self, mut close_rx: oneshot::Receiver<()>) {
299        let mut global_close_rx = self.queue.close.clone();
300        loop {
301            let mut running_jobs = self.running_jobs.current_weighted.load(Ordering::Relaxed);
302            let min_concurrency = self.min_concurrency as u32;
303            if running_jobs < min_concurrency {
304                log_error(self.run_ready_jobs().await);
305                running_jobs = self.running_jobs.current_weighted.load(Ordering::Relaxed);
306            }
307
308            let grab_new_jobs = running_jobs < min_concurrency;
309
310            tokio::select! {
311                biased;
312                _ = &mut close_rx => {
313                    log_error(self.shutdown().await);
314                    break;
315                }
316                _ = global_close_rx.changed() => {
317                    log_error(self.shutdown().await);
318                    break;
319                }
320                _ = self.listener.notify_task_ready.notified(), if grab_new_jobs  => {
321                    event!(Level::TRACE, "New task ready");
322                }
323                _ = self.running_jobs.job_finished.notified() => {
324                    event!(Level::TRACE, "Job finished");
325                }
326            }
327        }
328    }
329
330    async fn shutdown(&self) -> Result<()> {
331        let mut running_jobs = self.running_jobs.current_weighted.load(Ordering::Relaxed);
332        while running_jobs > 0 {
333            self.running_jobs.job_finished.notified().await;
334            running_jobs = self.running_jobs.current_weighted.load(Ordering::Relaxed);
335        }
336
337        let mut workers = self.queue.workers.write().await;
338        workers.remove_worker(self.listener.id)
339    }
340
341    async fn run_ready_jobs(&self) -> Result<()> {
342        let running_count = self.running_jobs.current_weighted.load(Ordering::Relaxed);
343        let max_concurrency = self.max_concurrency as u32;
344        let max_jobs = max_concurrency - running_count;
345        let job_types = self
346            .job_list
347            .iter()
348            .map(|s| rusqlite::types::Value::from(s.clone()))
349            .collect::<Vec<_>>();
350
351        let running_jobs = self.running_jobs.clone();
352        let worker_id = self.listener.id;
353        let now = self.queue.time.now();
354        event!(Level::TRACE, %now, current_running = %running_count, %max_concurrency, "Checking ready jobs");
355
356        let (result_tx, result_rx) = oneshot::channel();
357        self.queue
358            .db_write_tx
359            .send(DbOperation {
360                worker_id,
361                span: Span::current(),
362                operation: DbOperationType::GetReadyJobs(GetReadyJobsArgs {
363                    job_types,
364                    max_jobs,
365                    max_concurrency,
366                    running_jobs,
367                    now,
368                    result_tx,
369                }),
370            })
371            .await
372            .map_err(|_| Error::QueueClosed)?;
373
374        let ready_jobs = result_rx.await.map_err(|_| Error::QueueClosed)??;
375
376        for job in ready_jobs {
377            self.run_job(job).await?;
378        }
379
380        Ok(())
381    }
382
383    #[instrument(level="debug", skip(self, done), fields(worker_id = %self.listener.id))]
384    async fn run_job(
385        &self,
386        ReadyJob {
387            job,
388            done_rx: mut done,
389        }: ReadyJob,
390    ) -> Result<()> {
391        let job_def = self
392            .job_defs
393            .get(job.job_type.as_str())
394            .expect("Got job for unsupported type");
395
396        let worker_id = self.listener.id;
397        let running = self.running_jobs.clone();
398        let autoheartbeat = job_def.autoheartbeat;
399        let time = job.queue.time.clone();
400
401        (job_def.runner)(job.clone(), self.context.clone());
402
403        tokio::spawn(async move {
404            let use_autohearbeat = autoheartbeat && job.heartbeat_increment > 0;
405            event!(Level::DEBUG, ?job, "Starting job monitor task");
406            loop {
407                let expires = job.expires.load(Ordering::Relaxed);
408                let expires_instant = time.instant_for_timestamp(expires);
409
410                tokio::select! {
411                    _ = wait_for_next_autoheartbeat(&time, expires, job.heartbeat_increment), if use_autohearbeat => {
412                        event!(Level::DEBUG, %job, "Sending autoheartbeat");
413                        let new_time =
414                            crate::job::send_heartbeat(job.job_id, worker_id, job.heartbeat_increment, &job.queue).await;
415
416                        match new_time {
417                            Ok(new_time) => job.expires.store(new_time.unix_timestamp(), Ordering::Relaxed),
418                            Err(e) => event!(Level::ERROR, ?e),
419                        }
420                    }
421                    _ = tokio::time::sleep_until(expires_instant) => {
422                        event!(Level::DEBUG, %job, "Job expired");
423                        let now_expires = job.expires.load(Ordering::Relaxed);
424                        if now_expires == expires {
425                            if !job.is_done().await {
426                                log_error(job.fail("Job expired").await);
427                            }
428                            break;
429                        }
430                    }
431                    _ = done.changed() => {
432                        break;
433                    }
434                }
435            }
436
437            // Do this in a separate task from the job runner so that even if something goes horribly wrong
438            // we'll still be able to update the internal counts.
439            running
440                .current_weighted
441                .fetch_sub(job.weight as u32, Ordering::Relaxed);
442            running.finished.fetch_add(1, Ordering::Relaxed);
443            running.job_finished.notify_one();
444        });
445
446        Ok(())
447    }
448}
449
450async fn wait_for_next_autoheartbeat(time: &Time, expires: i64, heartbeat_increment: i32) {
451    let now = time.now();
452    let before = (heartbeat_increment.min(30) / 2) as i64;
453    let next_heartbeat_time = expires - before;
454
455    let time_from_now = next_heartbeat_time - now.unix_timestamp();
456    let instant = Instant::now() + std::time::Duration::from_secs(time_from_now.max(0) as u64);
457
458    tokio::time::sleep_until(instant).await
459}
460
#[cfg(test)]
mod tests {
    use crate::test_util::TestEnvironment;

    use super::*;

    /// Building a worker without calling either `registry` or `jobs` should hit the
    /// `panic!("Must set either registry or jobs")` branch in `build`.
    #[tokio::test]
    #[should_panic]
    async fn worker_without_jobs_should_panic() {
        let test = TestEnvironment::new().await;
        Worker::builder(&test.queue, test.context.clone())
            .build()
            .await
            .unwrap();
    }
}
476}