Skip to main content

awa_worker/
client.rs

1use crate::dispatcher::{Dispatcher, QueueConfig};
2use crate::executor::{BoxedWorker, JobError, JobExecutor, JobResult, Worker};
3use crate::heartbeat::HeartbeatService;
4use crate::maintenance::MaintenanceService;
5use awa_model::{JobArgs, JobRow};
6use serde::de::DeserializeOwned;
7use sqlx::PgPool;
8use std::any::{Any, TypeId};
9use std::collections::HashMap;
10use std::sync::atomic::{AtomicBool, AtomicU32, Ordering};
11use std::sync::Arc;
12use std::time::Duration;
13use tokio::sync::RwLock;
14use tokio_util::sync::CancellationToken;
15use tracing::{info, warn};
16
17/// Errors returned when building a worker client.
18#[derive(Debug, Clone, PartialEq, Eq, thiserror::Error)]
19pub enum BuildError {
20    #[error("at least one queue must be configured")]
21    NoQueuesConfigured,
22}
23
24/// Health check result.
25#[derive(Debug, Clone)]
26pub struct HealthCheck {
27    pub healthy: bool,
28    pub postgres_connected: bool,
29    pub poll_loop_alive: bool,
30    pub heartbeat_alive: bool,
31    pub shutting_down: bool,
32    pub leader: bool,
33    pub queues: HashMap<String, QueueHealth>,
34}
35
36/// Per-queue health.
37#[derive(Debug, Clone)]
38pub struct QueueHealth {
39    pub in_flight: u32,
40    pub max_workers: u32,
41    pub available: u64,
42}
43
44/// Builder for the Awa worker client.
45pub struct ClientBuilder {
46    pool: PgPool,
47    queues: Vec<(String, QueueConfig)>,
48    workers: HashMap<String, BoxedWorker>,
49    state: HashMap<TypeId, Box<dyn Any + Send + Sync>>,
50    heartbeat_interval: Duration,
51}
52
53impl ClientBuilder {
54    pub fn new(pool: PgPool) -> Self {
55        Self {
56            pool,
57            queues: Vec::new(),
58            workers: HashMap::new(),
59            state: HashMap::new(),
60            heartbeat_interval: Duration::from_secs(30),
61        }
62    }
63
64    /// Add a queue with its configuration.
65    pub fn queue(mut self, name: impl Into<String>, config: QueueConfig) -> Self {
66        self.queues.push((name.into(), config));
67        self
68    }
69
70    /// Register a typed worker.
71    ///
72    /// The worker handles jobs of type `T` where `T: JobArgs + DeserializeOwned`.
73    /// The handler function receives the deserialized args and job context.
74    pub fn register<T, F, Fut>(mut self, handler: F) -> Self
75    where
76        T: JobArgs + DeserializeOwned + Send + Sync + 'static,
77        F: Fn(T, &crate::context::JobContext) -> Fut + Send + Sync + 'static,
78        Fut: std::future::Future<Output = Result<JobResult, JobError>> + Send + Sync + 'static,
79    {
80        let kind = T::kind().to_string();
81        let worker = TypedWorker {
82            kind: T::kind(),
83            handler: Arc::new(handler),
84            _phantom: std::marker::PhantomData,
85        };
86        self.workers.insert(kind, Box::new(worker));
87        self
88    }
89
90    /// Register a raw worker implementation.
91    pub fn register_worker(mut self, worker: impl Worker + 'static) -> Self {
92        let kind = worker.kind().to_string();
93        self.workers.insert(kind, Box::new(worker));
94        self
95    }
96
97    /// Add shared state accessible via `ctx.extract::<T>()`.
98    pub fn state<T: Any + Send + Sync + Clone>(mut self, value: T) -> Self {
99        self.state.insert(TypeId::of::<T>(), Box::new(value));
100        self
101    }
102
103    /// Set the heartbeat interval (default: 30s).
104    pub fn heartbeat_interval(mut self, interval: Duration) -> Self {
105        self.heartbeat_interval = interval;
106        self
107    }
108
109    /// Build the client.
110    pub fn build(self) -> Result<Client, BuildError> {
111        if self.queues.is_empty() {
112            return Err(BuildError::NoQueuesConfigured);
113        }
114
115        let metrics = crate::metrics::AwaMetrics::from_global();
116        let queue_in_flight = Arc::new(
117            self.queues
118                .iter()
119                .map(|(name, _)| (name.clone(), Arc::new(AtomicU32::new(0))))
120                .collect(),
121        );
122        let dispatcher_alive = Arc::new(
123            self.queues
124                .iter()
125                .map(|(name, _)| (name.clone(), Arc::new(AtomicBool::new(false))))
126                .collect(),
127        );
128
129        Ok(Client {
130            pool: self.pool,
131            queues: self.queues,
132            workers: Arc::new(self.workers),
133            state: Arc::new(self.state),
134            heartbeat_interval: self.heartbeat_interval,
135            cancel: CancellationToken::new(),
136            handles: RwLock::new(Vec::new()),
137            in_flight: Arc::new(RwLock::new(HashMap::new())),
138            queue_in_flight,
139            dispatcher_alive,
140            heartbeat_alive: Arc::new(AtomicBool::new(false)),
141            leader: Arc::new(AtomicBool::new(false)),
142            metrics,
143        })
144    }
145}
146
147/// A typed worker that deserializes args and calls a handler function.
148struct TypedWorker<T, F, Fut>
149where
150    T: JobArgs + DeserializeOwned + Send + Sync + 'static,
151    F: Fn(T, &crate::context::JobContext) -> Fut + Send + Sync + 'static,
152    Fut: std::future::Future<Output = Result<JobResult, JobError>> + Send + Sync + 'static,
153{
154    kind: &'static str,
155    handler: Arc<F>,
156    _phantom: std::marker::PhantomData<fn() -> (T, Fut)>,
157}
158
159#[async_trait::async_trait]
160impl<T, F, Fut> Worker for TypedWorker<T, F, Fut>
161where
162    T: JobArgs + DeserializeOwned + Send + Sync + 'static,
163    F: Fn(T, &crate::context::JobContext) -> Fut + Send + Sync + 'static,
164    Fut: std::future::Future<Output = Result<JobResult, JobError>> + Send + Sync + 'static,
165{
166    fn kind(&self) -> &'static str {
167        self.kind
168    }
169
170    async fn perform(
171        &self,
172        job_row: &JobRow,
173        ctx: &crate::context::JobContext,
174    ) -> Result<JobResult, JobError> {
175        // Deserialize args
176        let args: T = serde_json::from_value(job_row.args.clone())
177            .map_err(|err| JobError::Terminal(format!("failed to deserialize args: {}", err)))?;
178
179        (self.handler)(args, ctx).await
180    }
181}
182
183/// The Awa worker client — manages dispatchers, heartbeat, and maintenance.
184pub struct Client {
185    pool: PgPool,
186    queues: Vec<(String, QueueConfig)>,
187    workers: Arc<HashMap<String, BoxedWorker>>,
188    state: Arc<HashMap<TypeId, Box<dyn Any + Send + Sync>>>,
189    heartbeat_interval: Duration,
190    cancel: CancellationToken,
191    handles: RwLock<Vec<tokio::task::JoinHandle<()>>>,
192    in_flight: Arc<RwLock<HashMap<i64, Arc<AtomicBool>>>>,
193    queue_in_flight: Arc<HashMap<String, Arc<AtomicU32>>>,
194    dispatcher_alive: Arc<HashMap<String, Arc<AtomicBool>>>,
195    heartbeat_alive: Arc<AtomicBool>,
196    leader: Arc<AtomicBool>,
197    metrics: crate::metrics::AwaMetrics,
198}
199
200impl Client {
201    /// Create a new builder.
202    pub fn builder(pool: PgPool) -> ClientBuilder {
203        ClientBuilder::new(pool)
204    }
205
206    /// Start the worker runtime. Spawns dispatchers, heartbeat, and maintenance.
207    pub async fn start(&self) -> Result<(), awa_model::AwaError> {
208        info!(
209            queues = self.queues.len(),
210            workers = self.workers.len(),
211            "Starting Awa worker runtime"
212        );
213
214        // Create executor with metrics
215        let executor = Arc::new(JobExecutor::new(
216            self.pool.clone(),
217            self.workers.clone(),
218            self.in_flight.clone(),
219            self.queue_in_flight.clone(),
220            self.state.clone(),
221            self.metrics.clone(),
222        ));
223
224        let mut handles = self.handles.write().await;
225
226        // Start heartbeat service
227        let heartbeat = HeartbeatService::new(
228            self.pool.clone(),
229            self.in_flight.clone(),
230            self.heartbeat_interval,
231            self.heartbeat_alive.clone(),
232            self.cancel.clone(),
233        );
234        handles.push(tokio::spawn(async move {
235            heartbeat.run().await;
236        }));
237
238        // Start maintenance service
239        let maintenance =
240            MaintenanceService::new(self.pool.clone(), self.leader.clone(), self.cancel.clone());
241        handles.push(tokio::spawn(async move {
242            maintenance.run().await;
243        }));
244
245        // Start a dispatcher per queue
246        for (queue_name, config) in &self.queues {
247            let dispatcher = Dispatcher::new(
248                queue_name.clone(),
249                config.clone(),
250                self.pool.clone(),
251                executor.clone(),
252                self.in_flight.clone(),
253                self.dispatcher_alive
254                    .get(queue_name)
255                    .cloned()
256                    .unwrap_or_else(|| Arc::new(AtomicBool::new(false))),
257                self.cancel.clone(),
258            );
259            handles.push(tokio::spawn(async move {
260                dispatcher.run().await;
261            }));
262        }
263
264        info!("Awa worker runtime started");
265        Ok(())
266    }
267
268    /// Graceful shutdown with drain timeout.
269    pub async fn shutdown(&self, timeout: Duration) {
270        info!("Initiating graceful shutdown");
271
272        // Signal all tasks to stop
273        self.cancel.cancel();
274        {
275            let guard = self.in_flight.read().await;
276            for flag in guard.values() {
277                flag.store(true, Ordering::SeqCst);
278            }
279        }
280
281        // Wait for handles with timeout
282        let handles: Vec<_> = {
283            let mut guard = self.handles.write().await;
284            std::mem::take(&mut *guard)
285        };
286
287        let shutdown_future = async {
288            for handle in handles {
289                let _ = handle.await;
290            }
291        };
292
293        if tokio::time::timeout(timeout, shutdown_future)
294            .await
295            .is_err()
296        {
297            warn!(
298                timeout_secs = timeout.as_secs(),
299                "Shutdown timeout exceeded, some tasks may not have completed"
300            );
301        }
302
303        info!("Awa worker runtime stopped");
304    }
305
306    /// Get the pool reference.
307    pub fn pool(&self) -> &PgPool {
308        &self.pool
309    }
310
311    /// Health check.
312    pub async fn health_check(&self) -> HealthCheck {
313        let postgres_connected = sqlx::query("SELECT 1").execute(&self.pool).await.is_ok();
314        let poll_loop_alive = self
315            .dispatcher_alive
316            .values()
317            .all(|alive| alive.load(Ordering::SeqCst));
318        let heartbeat_alive = self.heartbeat_alive.load(Ordering::SeqCst);
319        let shutting_down = self.cancel.is_cancelled();
320        let leader = self.leader.load(Ordering::SeqCst);
321        let available_rows = sqlx::query_as::<_, (String, i64)>(
322            r#"
323            SELECT queue, count(*)::bigint AS available
324            FROM awa.jobs
325            WHERE state = 'available'
326            GROUP BY queue
327            "#,
328        )
329        .fetch_all(&self.pool)
330        .await
331        .unwrap_or_default();
332        let available_by_queue: HashMap<_, _> = available_rows.into_iter().collect();
333        let queues = self
334            .queues
335            .iter()
336            .map(|(queue, config)| {
337                let in_flight = self
338                    .queue_in_flight
339                    .get(queue)
340                    .map(|counter| counter.load(Ordering::SeqCst))
341                    .unwrap_or(0);
342                let available = available_by_queue.get(queue).copied().unwrap_or(0).max(0) as u64;
343                (
344                    queue.clone(),
345                    QueueHealth {
346                        in_flight,
347                        max_workers: config.max_workers,
348                        available,
349                    },
350                )
351            })
352            .collect();
353
354        HealthCheck {
355            healthy: postgres_connected && poll_loop_alive && heartbeat_alive && !shutting_down,
356            postgres_connected,
357            poll_loop_alive,
358            heartbeat_alive,
359            shutting_down,
360            leader,
361            queues,
362        }
363    }
364}