//! `awa_worker/client.rs` — worker client: builder, typed worker
//! registration, and runtime lifecycle (dispatchers, heartbeat, maintenance).
1use crate::dispatcher::{ConcurrencyMode, Dispatcher, OverflowPool, QueueConfig};
2use crate::executor::{BoxedWorker, JobError, JobExecutor, JobResult, Worker};
3use crate::heartbeat::HeartbeatService;
4use crate::maintenance::MaintenanceService;
5use awa_model::{JobArgs, JobRow, PeriodicJob};
6use serde::de::DeserializeOwned;
7use sqlx::PgPool;
8use std::any::{Any, TypeId};
9use std::collections::HashMap;
10use std::sync::atomic::{AtomicBool, AtomicU32, Ordering};
11use std::sync::Arc;
12use std::time::Duration;
13use tokio::sync::{Mutex, RwLock};
14use tokio::task::JoinSet;
15use tokio_util::sync::CancellationToken;
16use tracing::{info, warn};
17
18/// Errors returned when building a worker client.
/// Errors returned when building a worker client.
#[derive(Debug, Clone, PartialEq, Eq, thiserror::Error)]
pub enum BuildError {
    /// `build()` was called without any `queue(...)` registration.
    #[error("at least one queue must be configured")]
    NoQueuesConfigured,
    /// Weighted mode only: the guaranteed per-queue minimums already exceed
    /// the global cap, leaving no valid capacity split.
    #[error("sum of min_workers ({total_min}) exceeds global_max_workers ({global_max})")]
    MinWorkersExceedGlobal { total_min: u32, global_max: u32 },
    /// A queue's rate limit was configured with a non-positive `max_rate`.
    #[error("rate_limit max_rate must be > 0.0")]
    InvalidRateLimit,
    /// A queue was configured with `weight == 0`; weights must be >= 1.
    #[error("queue weight must be > 0")]
    InvalidWeight,
}
30
31/// Health check result.
/// Health check result.
///
/// Snapshot produced by `Client::health_check`; `healthy` is the conjunction
/// of the liveness flags below (and not shutting down).
#[derive(Debug, Clone)]
pub struct HealthCheck {
    /// Overall verdict: postgres reachable, all poll loops and heartbeat
    /// alive, and shutdown not initiated.
    pub healthy: bool,
    /// A `SELECT 1` against the pool succeeded.
    pub postgres_connected: bool,
    /// Every queue dispatcher's alive flag is currently set.
    pub poll_loop_alive: bool,
    /// The heartbeat service's alive flag is currently set.
    pub heartbeat_alive: bool,
    /// The dispatcher cancellation token has been triggered.
    pub shutting_down: bool,
    /// This process currently holds leadership.
    pub leader: bool,
    /// Per-queue health, keyed by queue name.
    pub queues: HashMap<String, QueueHealth>,
}
42
43/// Per-queue health.
/// Per-queue health.
#[derive(Debug, Clone)]
pub struct QueueHealth {
    /// Jobs currently executing for this queue in this process.
    pub in_flight: u32,
    /// Jobs in the `available` state for this queue, per the database
    /// (0 if the availability query failed).
    pub available: u64,
    /// Capacity interpretation depends on mode.
    pub capacity: QueueCapacity,
}
51
52/// Capacity information for a queue, mode-dependent.
/// Capacity information for a queue, mode-dependent.
///
/// Hard-reserved is the default; weighted mode is enabled when
/// `global_max_workers` is configured on the builder.
#[derive(Debug, Clone)]
pub enum QueueCapacity {
    /// Hard-reserved: fixed max.
    HardReserved { max_workers: u32 },
    /// Weighted: min guaranteed + current overflow.
    Weighted {
        /// Permits guaranteed to this queue regardless of other queues.
        min_workers: u32,
        /// Relative share of the overflow pool.
        weight: u32,
        /// Overflow permits this queue currently holds.
        overflow_held: u32,
    },
}
64
65/// Builder for the Awa worker client.
66pub struct ClientBuilder {
67    pool: PgPool,
68    queues: Vec<(String, QueueConfig)>,
69    workers: HashMap<String, BoxedWorker>,
70    state: HashMap<TypeId, Box<dyn Any + Send + Sync>>,
71    heartbeat_interval: Duration,
72    periodic_jobs: Vec<PeriodicJob>,
73    global_max_workers: Option<u32>,
74}
75
76impl ClientBuilder {
77    pub fn new(pool: PgPool) -> Self {
78        Self {
79            pool,
80            queues: Vec::new(),
81            workers: HashMap::new(),
82            state: HashMap::new(),
83            heartbeat_interval: Duration::from_secs(30),
84            periodic_jobs: Vec::new(),
85            global_max_workers: None,
86        }
87    }
88
89    /// Add a queue with its configuration.
90    pub fn queue(mut self, name: impl Into<String>, config: QueueConfig) -> Self {
91        self.queues.push((name.into(), config));
92        self
93    }
94
95    /// Register a typed worker.
96    ///
97    /// The worker handles jobs of type `T` where `T: JobArgs + DeserializeOwned`.
98    /// The handler function receives the deserialized args and job context.
99    pub fn register<T, F, Fut>(mut self, handler: F) -> Self
100    where
101        T: JobArgs + DeserializeOwned + Send + Sync + 'static,
102        F: Fn(T, &crate::context::JobContext) -> Fut + Send + Sync + 'static,
103        Fut: std::future::Future<Output = Result<JobResult, JobError>> + Send + Sync + 'static,
104    {
105        let kind = T::kind().to_string();
106        let worker = TypedWorker {
107            kind: T::kind(),
108            handler: Arc::new(handler),
109            _phantom: std::marker::PhantomData,
110        };
111        self.workers.insert(kind, Box::new(worker));
112        self
113    }
114
115    /// Register a raw worker implementation.
116    pub fn register_worker(mut self, worker: impl Worker + 'static) -> Self {
117        let kind = worker.kind().to_string();
118        self.workers.insert(kind, Box::new(worker));
119        self
120    }
121
122    /// Add shared state accessible via `ctx.extract::<T>()`.
123    pub fn state<T: Any + Send + Sync + Clone>(mut self, value: T) -> Self {
124        self.state.insert(TypeId::of::<T>(), Box::new(value));
125        self
126    }
127
128    /// Set the heartbeat interval (default: 30s).
129    pub fn heartbeat_interval(mut self, interval: Duration) -> Self {
130        self.heartbeat_interval = interval;
131        self
132    }
133
134    /// Set a global maximum worker count across all queues (enables weighted mode).
135    ///
136    /// When set, each queue gets `min_workers` guaranteed permits plus a share
137    /// of the remaining overflow capacity based on `weight`.
138    pub fn global_max_workers(mut self, max: u32) -> Self {
139        self.global_max_workers = Some(max);
140        self
141    }
142
143    /// Register a periodic (cron) job schedule.
144    ///
145    /// The schedule is synced to the database by the leader and evaluated
146    /// every second. When a fire is due, a job is atomically enqueued.
147    pub fn periodic(mut self, job: PeriodicJob) -> Self {
148        self.periodic_jobs.push(job);
149        self
150    }
151
152    /// Build the client.
153    pub fn build(self) -> Result<Client, BuildError> {
154        if self.queues.is_empty() {
155            return Err(BuildError::NoQueuesConfigured);
156        }
157
158        // Validate rate limits and weights
159        for (_, config) in &self.queues {
160            if let Some(rl) = &config.rate_limit {
161                if rl.max_rate <= 0.0 {
162                    return Err(BuildError::InvalidRateLimit);
163                }
164            }
165            if config.weight == 0 {
166                return Err(BuildError::InvalidWeight);
167            }
168        }
169
170        // Validate weighted mode constraints
171        let overflow_pool = if let Some(global_max) = self.global_max_workers {
172            let total_min: u32 = self.queues.iter().map(|(_, c)| c.min_workers).sum();
173            if total_min > global_max {
174                return Err(BuildError::MinWorkersExceedGlobal {
175                    total_min,
176                    global_max,
177                });
178            }
179            let overflow_capacity = global_max - total_min;
180            let weights: HashMap<String, u32> = self
181                .queues
182                .iter()
183                .map(|(name, c)| (name.clone(), c.weight.max(1)))
184                .collect();
185            Some(Arc::new(OverflowPool::new(overflow_capacity, weights)))
186        } else {
187            None
188        };
189
190        let metrics = crate::metrics::AwaMetrics::from_global();
191        let queue_in_flight = Arc::new(
192            self.queues
193                .iter()
194                .map(|(name, _)| (name.clone(), Arc::new(AtomicU32::new(0))))
195                .collect(),
196        );
197        let dispatcher_alive = Arc::new(
198            self.queues
199                .iter()
200                .map(|(name, _)| (name.clone(), Arc::new(AtomicBool::new(false))))
201                .collect(),
202        );
203
204        Ok(Client {
205            pool: self.pool,
206            queues: self.queues,
207            workers: Arc::new(self.workers),
208            state: Arc::new(self.state),
209            heartbeat_interval: self.heartbeat_interval,
210            periodic_jobs: Arc::new(self.periodic_jobs),
211            dispatch_cancel: CancellationToken::new(),
212            service_cancel: CancellationToken::new(),
213            dispatcher_handles: RwLock::new(Vec::new()),
214            service_handles: RwLock::new(Vec::new()),
215            job_set: Arc::new(Mutex::new(JoinSet::new())),
216            in_flight: Arc::new(RwLock::new(HashMap::new())),
217            queue_in_flight,
218            dispatcher_alive,
219            heartbeat_alive: Arc::new(AtomicBool::new(false)),
220            leader: Arc::new(AtomicBool::new(false)),
221            overflow_pool,
222            metrics,
223        })
224    }
225}
226
/// A typed worker that deserializes args and calls a handler function.
///
/// Bridges a plain async handler `F` into the [`Worker`] trait for the job
/// kind reported by `T::kind()`.
struct TypedWorker<T, F, Fut>
where
    T: JobArgs + DeserializeOwned + Send + Sync + 'static,
    F: Fn(T, &crate::context::JobContext) -> Fut + Send + Sync + 'static,
    Fut: std::future::Future<Output = Result<JobResult, JobError>> + Send + Sync + 'static,
{
    /// Job kind this worker handles; set from `T::kind()` at registration.
    kind: &'static str,
    /// The user-supplied handler (wrapped in `Arc` by `ClientBuilder::register`).
    handler: Arc<F>,
    /// `fn() -> (T, Fut)` records the unused type parameters without storing
    /// them; the function-pointer form keeps auto traits (Send/Sync) independent
    /// of `T` and `Fut` themselves.
    _phantom: std::marker::PhantomData<fn() -> (T, Fut)>,
}
238
239#[async_trait::async_trait]
240impl<T, F, Fut> Worker for TypedWorker<T, F, Fut>
241where
242    T: JobArgs + DeserializeOwned + Send + Sync + 'static,
243    F: Fn(T, &crate::context::JobContext) -> Fut + Send + Sync + 'static,
244    Fut: std::future::Future<Output = Result<JobResult, JobError>> + Send + Sync + 'static,
245{
246    fn kind(&self) -> &'static str {
247        self.kind
248    }
249
250    async fn perform(
251        &self,
252        job_row: &JobRow,
253        ctx: &crate::context::JobContext,
254    ) -> Result<JobResult, JobError> {
255        // Deserialize args
256        let args: T = serde_json::from_value(job_row.args.clone())
257            .map_err(|err| JobError::Terminal(format!("failed to deserialize args: {}", err)))?;
258
259        (self.handler)(args, ctx).await
260    }
261}
262
/// The Awa worker client — manages dispatchers, heartbeat, and maintenance.
pub struct Client {
    /// Postgres pool shared by all services and the executor.
    pool: PgPool,
    /// Queue name → config, in registration order.
    queues: Vec<(String, QueueConfig)>,
    /// Job kind → worker implementation, shared with the executor.
    workers: Arc<HashMap<String, BoxedWorker>>,
    /// Shared state values keyed by concrete type, for `ctx.extract::<T>()`.
    state: Arc<HashMap<TypeId, Box<dyn Any + Send + Sync>>>,
    /// Interval between heartbeats.
    heartbeat_interval: Duration,
    /// Cron-style schedules handed to the maintenance service.
    periodic_jobs: Arc<Vec<PeriodicJob>>,
    /// Cancellation token for dispatchers only — stops claiming new jobs.
    dispatch_cancel: CancellationToken,
    /// Cancellation token for heartbeat + maintenance — kept alive during drain.
    service_cancel: CancellationToken,
    /// Handles for dispatcher tasks.
    dispatcher_handles: RwLock<Vec<tokio::task::JoinHandle<()>>>,
    /// Handles for service tasks (heartbeat + maintenance).
    service_handles: RwLock<Vec<tokio::task::JoinHandle<()>>>,
    /// JoinSet tracking in-flight job tasks for graceful drain.
    job_set: Arc<Mutex<JoinSet<()>>>,
    /// Job id → cancellation flag for each job currently executing.
    in_flight: Arc<RwLock<HashMap<i64, Arc<AtomicBool>>>>,
    /// Per-queue count of jobs currently executing in this process.
    queue_in_flight: Arc<HashMap<String, Arc<AtomicU32>>>,
    /// Per-queue liveness flags maintained by the dispatchers.
    dispatcher_alive: Arc<HashMap<String, Arc<AtomicBool>>>,
    /// Liveness flag maintained by the heartbeat service.
    heartbeat_alive: Arc<AtomicBool>,
    /// Whether this process currently holds leadership.
    leader: Arc<AtomicBool>,
    /// Shared overflow pool for weighted mode (None in hard-reserved mode).
    overflow_pool: Option<Arc<OverflowPool>>,
    /// Metrics sink shared with the executor.
    metrics: crate::metrics::AwaMetrics,
}
290
291impl Client {
292    /// Create a new builder.
293    pub fn builder(pool: PgPool) -> ClientBuilder {
294        ClientBuilder::new(pool)
295    }
296
297    /// Start the worker runtime. Spawns dispatchers, heartbeat, and maintenance.
298    pub async fn start(&self) -> Result<(), awa_model::AwaError> {
299        info!(
300            queues = self.queues.len(),
301            workers = self.workers.len(),
302            "Starting Awa worker runtime"
303        );
304
305        // Create executor with metrics
306        let executor = Arc::new(JobExecutor::new(
307            self.pool.clone(),
308            self.workers.clone(),
309            self.in_flight.clone(),
310            self.queue_in_flight.clone(),
311            self.state.clone(),
312            self.metrics.clone(),
313        ));
314
315        let mut service_handles = self.service_handles.write().await;
316
317        // Start heartbeat service (uses service_cancel — stays alive during drain)
318        let heartbeat = HeartbeatService::new(
319            self.pool.clone(),
320            self.in_flight.clone(),
321            self.heartbeat_interval,
322            self.heartbeat_alive.clone(),
323            self.service_cancel.clone(),
324        );
325        service_handles.push(tokio::spawn(async move {
326            heartbeat.run().await;
327        }));
328
329        // Start maintenance service (uses service_cancel — stays alive during drain)
330        let maintenance = MaintenanceService::new(
331            self.pool.clone(),
332            self.leader.clone(),
333            self.service_cancel.clone(),
334            self.periodic_jobs.clone(),
335            self.in_flight.clone(),
336        );
337        service_handles.push(tokio::spawn(async move {
338            maintenance.run().await;
339        }));
340
341        // Start a dispatcher per queue (uses dispatch_cancel — stops claiming first)
342        let mut dispatcher_handles = self.dispatcher_handles.write().await;
343        for (queue_name, config) in &self.queues {
344            let alive = self
345                .dispatcher_alive
346                .get(queue_name)
347                .cloned()
348                .unwrap_or_else(|| Arc::new(AtomicBool::new(false)));
349
350            let dispatcher = if let Some(overflow_pool) = &self.overflow_pool {
351                // Weighted mode
352                let concurrency = ConcurrencyMode::Weighted {
353                    local_semaphore: Arc::new(tokio::sync::Semaphore::new(
354                        config.min_workers as usize,
355                    )),
356                    overflow_pool: overflow_pool.clone(),
357                    queue_name: queue_name.clone(),
358                };
359                Dispatcher::with_concurrency(
360                    queue_name.clone(),
361                    config.clone(),
362                    self.pool.clone(),
363                    executor.clone(),
364                    self.in_flight.clone(),
365                    alive,
366                    self.dispatch_cancel.clone(),
367                    self.job_set.clone(),
368                    concurrency,
369                )
370            } else {
371                // Hard-reserved mode (default)
372                Dispatcher::new(
373                    queue_name.clone(),
374                    config.clone(),
375                    self.pool.clone(),
376                    executor.clone(),
377                    self.in_flight.clone(),
378                    alive,
379                    self.dispatch_cancel.clone(),
380                    self.job_set.clone(),
381                )
382            };
383            dispatcher_handles.push(tokio::spawn(async move {
384                dispatcher.run().await;
385            }));
386        }
387
388        info!("Awa worker runtime started");
389        Ok(())
390    }
391
392    /// Graceful shutdown with drain timeout.
393    ///
394    /// Phased lifecycle:
395    /// 1. Stop dispatchers (no new jobs claimed)
396    /// 2. Signal in-flight jobs to cancel
397    /// 3. Wait for dispatchers to exit
398    /// 4. Drain in-flight jobs (heartbeat + maintenance still alive!)
399    /// 5. Stop heartbeat + maintenance
400    pub async fn shutdown(&self, timeout: Duration) {
401        info!("Initiating graceful shutdown");
402
403        // Phase 1: Stop claiming new jobs
404        self.dispatch_cancel.cancel();
405
406        // Phase 2: Signal in-flight cancellation flags
407        {
408            let guard = self.in_flight.read().await;
409            for flag in guard.values() {
410                flag.store(true, Ordering::SeqCst);
411            }
412        }
413
414        // Phase 3: Wait for dispatchers to exit their poll loops
415        let dispatcher_handles: Vec<_> = {
416            let mut guard = self.dispatcher_handles.write().await;
417            std::mem::take(&mut *guard)
418        };
419        for handle in dispatcher_handles {
420            let _ = handle.await;
421        }
422
423        // Phase 4: Drain in-flight jobs (heartbeat + maintenance still alive)
424        let drain = async {
425            let mut set = self.job_set.lock().await;
426            while set.join_next().await.is_some() {}
427        };
428        if tokio::time::timeout(timeout, drain).await.is_err() {
429            warn!(
430                timeout_secs = timeout.as_secs(),
431                "Shutdown drain timeout exceeded, some jobs may not have completed"
432            );
433        }
434
435        // Phase 5: Stop background services (heartbeat + maintenance)
436        self.service_cancel.cancel();
437        let service_handles: Vec<_> = {
438            let mut guard = self.service_handles.write().await;
439            std::mem::take(&mut *guard)
440        };
441        for handle in service_handles {
442            let _ = handle.await;
443        }
444
445        info!("Awa worker runtime stopped");
446    }
447
448    /// Get the pool reference.
449    pub fn pool(&self) -> &PgPool {
450        &self.pool
451    }
452
453    /// Health check.
454    pub async fn health_check(&self) -> HealthCheck {
455        let postgres_connected = sqlx::query("SELECT 1").execute(&self.pool).await.is_ok();
456        let poll_loop_alive = self
457            .dispatcher_alive
458            .values()
459            .all(|alive| alive.load(Ordering::SeqCst));
460        let heartbeat_alive = self.heartbeat_alive.load(Ordering::SeqCst);
461        let shutting_down = self.dispatch_cancel.is_cancelled();
462        let leader = self.leader.load(Ordering::SeqCst);
463        let available_rows = sqlx::query_as::<_, (String, i64)>(
464            r#"
465            SELECT queue, count(*)::bigint AS available
466            FROM awa.jobs
467            WHERE state = 'available'
468            GROUP BY queue
469            "#,
470        )
471        .fetch_all(&self.pool)
472        .await
473        .unwrap_or_default();
474        let available_by_queue: HashMap<_, _> = available_rows.into_iter().collect();
475        let queues = self
476            .queues
477            .iter()
478            .map(|(queue, config)| {
479                let in_flight = self
480                    .queue_in_flight
481                    .get(queue)
482                    .map(|counter| counter.load(Ordering::SeqCst))
483                    .unwrap_or(0);
484                let available = available_by_queue.get(queue).copied().unwrap_or(0).max(0) as u64;
485                let capacity = if let Some(overflow_pool) = &self.overflow_pool {
486                    QueueCapacity::Weighted {
487                        min_workers: config.min_workers,
488                        weight: config.weight,
489                        overflow_held: overflow_pool.held(queue),
490                    }
491                } else {
492                    QueueCapacity::HardReserved {
493                        max_workers: config.max_workers,
494                    }
495                };
496                (
497                    queue.clone(),
498                    QueueHealth {
499                        in_flight,
500                        available,
501                        capacity,
502                    },
503                )
504            })
505            .collect();
506
507        HealthCheck {
508            healthy: postgres_connected && poll_loop_alive && heartbeat_alive && !shutting_down,
509            postgres_connected,
510            poll_loop_alive,
511            heartbeat_alive,
512            shutting_down,
513            leader,
514            queues,
515        }
516    }
517}