1use crate::dispatcher::{Dispatcher, QueueConfig};
2use crate::executor::{BoxedWorker, JobError, JobExecutor, JobResult, Worker};
3use crate::heartbeat::HeartbeatService;
4use crate::maintenance::MaintenanceService;
5use awa_model::{JobArgs, JobRow};
6use serde::de::DeserializeOwned;
7use sqlx::PgPool;
8use std::any::{Any, TypeId};
9use std::collections::HashMap;
10use std::sync::atomic::{AtomicBool, AtomicU32, Ordering};
11use std::sync::Arc;
12use std::time::Duration;
13use tokio::sync::RwLock;
14use tokio_util::sync::CancellationToken;
15use tracing::{info, warn};
16
/// Errors that can occur while assembling a [`Client`] via [`ClientBuilder::build`].
#[derive(Debug, Clone, PartialEq, Eq, thiserror::Error)]
pub enum BuildError {
    /// Returned when `build` is called without any queue having been
    /// registered via [`ClientBuilder::queue`].
    #[error("at least one queue must be configured")]
    NoQueuesConfigured,
}
23
/// Point-in-time snapshot of the worker runtime's health, as produced by
/// [`Client::health_check`].
#[derive(Debug, Clone)]
pub struct HealthCheck {
    /// Overall verdict: Postgres reachable, every dispatcher poll loop and the
    /// heartbeat task alive, and no shutdown in progress.
    pub healthy: bool,
    /// Whether a trivial `SELECT 1` against the pool succeeded.
    pub postgres_connected: bool,
    /// Whether every queue dispatcher's liveness flag is currently set.
    pub poll_loop_alive: bool,
    /// Whether the heartbeat service's liveness flag is currently set.
    pub heartbeat_alive: bool,
    /// Whether graceful shutdown has been initiated (cancellation token fired).
    pub shutting_down: bool,
    /// Whether this instance's leader flag is set (maintained by the
    /// maintenance service).
    pub leader: bool,
    /// Per-queue statistics, keyed by queue name.
    pub queues: HashMap<String, QueueHealth>,
}
35
/// Health statistics for a single queue, carried in [`HealthCheck::queues`].
#[derive(Debug, Clone)]
pub struct QueueHealth {
    /// Number of jobs currently tracked as executing on this queue.
    pub in_flight: u32,
    /// Configured worker concurrency limit for this queue.
    pub max_workers: u32,
    /// Count of jobs in the `available` state waiting to be picked up
    /// (0 when the count query failed).
    pub available: u64,
}
43
/// Builder for [`Client`]: configure queues, job workers, shared state and the
/// heartbeat interval, then call [`ClientBuilder::build`].
pub struct ClientBuilder {
    // Postgres connection pool shared by all runtime services.
    pool: PgPool,
    // Queue name/config pairs, kept in registration order.
    queues: Vec<(String, QueueConfig)>,
    // Job workers keyed by job kind string.
    workers: HashMap<String, BoxedWorker>,
    // Type-erased shared state, keyed by concrete type id.
    state: HashMap<TypeId, Box<dyn Any + Send + Sync>>,
    // Interval passed to the heartbeat service (default 30s).
    heartbeat_interval: Duration,
}
52
53impl ClientBuilder {
54 pub fn new(pool: PgPool) -> Self {
55 Self {
56 pool,
57 queues: Vec::new(),
58 workers: HashMap::new(),
59 state: HashMap::new(),
60 heartbeat_interval: Duration::from_secs(30),
61 }
62 }
63
64 pub fn queue(mut self, name: impl Into<String>, config: QueueConfig) -> Self {
66 self.queues.push((name.into(), config));
67 self
68 }
69
70 pub fn register<T, F, Fut>(mut self, handler: F) -> Self
75 where
76 T: JobArgs + DeserializeOwned + Send + Sync + 'static,
77 F: Fn(T, &crate::context::JobContext) -> Fut + Send + Sync + 'static,
78 Fut: std::future::Future<Output = Result<JobResult, JobError>> + Send + Sync + 'static,
79 {
80 let kind = T::kind().to_string();
81 let worker = TypedWorker {
82 kind: T::kind(),
83 handler: Arc::new(handler),
84 _phantom: std::marker::PhantomData,
85 };
86 self.workers.insert(kind, Box::new(worker));
87 self
88 }
89
90 pub fn register_worker(mut self, worker: impl Worker + 'static) -> Self {
92 let kind = worker.kind().to_string();
93 self.workers.insert(kind, Box::new(worker));
94 self
95 }
96
97 pub fn state<T: Any + Send + Sync + Clone>(mut self, value: T) -> Self {
99 self.state.insert(TypeId::of::<T>(), Box::new(value));
100 self
101 }
102
103 pub fn heartbeat_interval(mut self, interval: Duration) -> Self {
105 self.heartbeat_interval = interval;
106 self
107 }
108
109 pub fn build(self) -> Result<Client, BuildError> {
111 if self.queues.is_empty() {
112 return Err(BuildError::NoQueuesConfigured);
113 }
114
115 let metrics = crate::metrics::AwaMetrics::from_global();
116 let queue_in_flight = Arc::new(
117 self.queues
118 .iter()
119 .map(|(name, _)| (name.clone(), Arc::new(AtomicU32::new(0))))
120 .collect(),
121 );
122 let dispatcher_alive = Arc::new(
123 self.queues
124 .iter()
125 .map(|(name, _)| (name.clone(), Arc::new(AtomicBool::new(false))))
126 .collect(),
127 );
128
129 Ok(Client {
130 pool: self.pool,
131 queues: self.queues,
132 workers: Arc::new(self.workers),
133 state: Arc::new(self.state),
134 heartbeat_interval: self.heartbeat_interval,
135 cancel: CancellationToken::new(),
136 handles: RwLock::new(Vec::new()),
137 in_flight: Arc::new(RwLock::new(HashMap::new())),
138 queue_in_flight,
139 dispatcher_alive,
140 heartbeat_alive: Arc::new(AtomicBool::new(false)),
141 leader: Arc::new(AtomicBool::new(false)),
142 metrics,
143 })
144 }
145}
146
/// Adapter that wraps a plain async closure as a [`Worker`] for a specific
/// job-argument type `T` (registered via [`ClientBuilder::register`]).
struct TypedWorker<T, F, Fut>
where
    T: JobArgs + DeserializeOwned + Send + Sync + 'static,
    F: Fn(T, &crate::context::JobContext) -> Fut + Send + Sync + 'static,
    Fut: std::future::Future<Output = Result<JobResult, JobError>> + Send + Sync + 'static,
{
    // Job kind string, captured from `T::kind()` at registration time.
    kind: &'static str,
    // User-supplied handler, shared behind an Arc.
    handler: Arc<F>,
    // Ties T/Fut to the struct without storing them; the `fn() -> _` wrapper
    // keeps the marker Send + Sync independent of T and Fut.
    _phantom: std::marker::PhantomData<fn() -> (T, Fut)>,
}
158
159#[async_trait::async_trait]
160impl<T, F, Fut> Worker for TypedWorker<T, F, Fut>
161where
162 T: JobArgs + DeserializeOwned + Send + Sync + 'static,
163 F: Fn(T, &crate::context::JobContext) -> Fut + Send + Sync + 'static,
164 Fut: std::future::Future<Output = Result<JobResult, JobError>> + Send + Sync + 'static,
165{
166 fn kind(&self) -> &'static str {
167 self.kind
168 }
169
170 async fn perform(
171 &self,
172 job_row: &JobRow,
173 ctx: &crate::context::JobContext,
174 ) -> Result<JobResult, JobError> {
175 let args: T = serde_json::from_value(job_row.args.clone())
177 .map_err(|err| JobError::Terminal(format!("failed to deserialize args: {}", err)))?;
178
179 (self.handler)(args, ctx).await
180 }
181}
182
/// The Awa worker runtime: owns the connection pool, registered workers, and
/// the background tasks (heartbeat, maintenance, per-queue dispatchers)
/// spawned by [`Client::start`].
pub struct Client {
    // Postgres connection pool shared by all services.
    pool: PgPool,
    // Configured queues, in registration order.
    queues: Vec<(String, QueueConfig)>,
    // Job workers keyed by kind, shared with the executor.
    workers: Arc<HashMap<String, BoxedWorker>>,
    // Type-erased shared state passed through to job execution.
    state: Arc<HashMap<TypeId, Box<dyn Any + Send + Sync>>>,
    // Interval handed to the heartbeat service.
    heartbeat_interval: Duration,
    // Fired once by `shutdown` to stop every background task.
    cancel: CancellationToken,
    // Join handles of spawned background tasks; drained during shutdown.
    handles: RwLock<Vec<tokio::task::JoinHandle<()>>>,
    // Per-job cancellation flags keyed by job id; flipped to true on shutdown.
    in_flight: Arc<RwLock<HashMap<i64, Arc<AtomicBool>>>>,
    // Per-queue count of currently executing jobs.
    queue_in_flight: Arc<HashMap<String, Arc<AtomicU32>>>,
    // Per-queue dispatcher liveness flags, read by `health_check`.
    dispatcher_alive: Arc<HashMap<String, Arc<AtomicBool>>>,
    // Heartbeat service liveness flag.
    heartbeat_alive: Arc<AtomicBool>,
    // Leadership flag maintained by the maintenance service.
    leader: Arc<AtomicBool>,
    // Metrics sink shared with the executor.
    metrics: crate::metrics::AwaMetrics,
}
199
200impl Client {
201 pub fn builder(pool: PgPool) -> ClientBuilder {
203 ClientBuilder::new(pool)
204 }
205
206 pub async fn start(&self) -> Result<(), awa_model::AwaError> {
208 info!(
209 queues = self.queues.len(),
210 workers = self.workers.len(),
211 "Starting Awa worker runtime"
212 );
213
214 let executor = Arc::new(JobExecutor::new(
216 self.pool.clone(),
217 self.workers.clone(),
218 self.in_flight.clone(),
219 self.queue_in_flight.clone(),
220 self.state.clone(),
221 self.metrics.clone(),
222 ));
223
224 let mut handles = self.handles.write().await;
225
226 let heartbeat = HeartbeatService::new(
228 self.pool.clone(),
229 self.in_flight.clone(),
230 self.heartbeat_interval,
231 self.heartbeat_alive.clone(),
232 self.cancel.clone(),
233 );
234 handles.push(tokio::spawn(async move {
235 heartbeat.run().await;
236 }));
237
238 let maintenance =
240 MaintenanceService::new(self.pool.clone(), self.leader.clone(), self.cancel.clone());
241 handles.push(tokio::spawn(async move {
242 maintenance.run().await;
243 }));
244
245 for (queue_name, config) in &self.queues {
247 let dispatcher = Dispatcher::new(
248 queue_name.clone(),
249 config.clone(),
250 self.pool.clone(),
251 executor.clone(),
252 self.in_flight.clone(),
253 self.dispatcher_alive
254 .get(queue_name)
255 .cloned()
256 .unwrap_or_else(|| Arc::new(AtomicBool::new(false))),
257 self.cancel.clone(),
258 );
259 handles.push(tokio::spawn(async move {
260 dispatcher.run().await;
261 }));
262 }
263
264 info!("Awa worker runtime started");
265 Ok(())
266 }
267
268 pub async fn shutdown(&self, timeout: Duration) {
270 info!("Initiating graceful shutdown");
271
272 self.cancel.cancel();
274 {
275 let guard = self.in_flight.read().await;
276 for flag in guard.values() {
277 flag.store(true, Ordering::SeqCst);
278 }
279 }
280
281 let handles: Vec<_> = {
283 let mut guard = self.handles.write().await;
284 std::mem::take(&mut *guard)
285 };
286
287 let shutdown_future = async {
288 for handle in handles {
289 let _ = handle.await;
290 }
291 };
292
293 if tokio::time::timeout(timeout, shutdown_future)
294 .await
295 .is_err()
296 {
297 warn!(
298 timeout_secs = timeout.as_secs(),
299 "Shutdown timeout exceeded, some tasks may not have completed"
300 );
301 }
302
303 info!("Awa worker runtime stopped");
304 }
305
306 pub fn pool(&self) -> &PgPool {
308 &self.pool
309 }
310
311 pub async fn health_check(&self) -> HealthCheck {
313 let postgres_connected = sqlx::query("SELECT 1").execute(&self.pool).await.is_ok();
314 let poll_loop_alive = self
315 .dispatcher_alive
316 .values()
317 .all(|alive| alive.load(Ordering::SeqCst));
318 let heartbeat_alive = self.heartbeat_alive.load(Ordering::SeqCst);
319 let shutting_down = self.cancel.is_cancelled();
320 let leader = self.leader.load(Ordering::SeqCst);
321 let available_rows = sqlx::query_as::<_, (String, i64)>(
322 r#"
323 SELECT queue, count(*)::bigint AS available
324 FROM awa.jobs
325 WHERE state = 'available'
326 GROUP BY queue
327 "#,
328 )
329 .fetch_all(&self.pool)
330 .await
331 .unwrap_or_default();
332 let available_by_queue: HashMap<_, _> = available_rows.into_iter().collect();
333 let queues = self
334 .queues
335 .iter()
336 .map(|(queue, config)| {
337 let in_flight = self
338 .queue_in_flight
339 .get(queue)
340 .map(|counter| counter.load(Ordering::SeqCst))
341 .unwrap_or(0);
342 let available = available_by_queue.get(queue).copied().unwrap_or(0).max(0) as u64;
343 (
344 queue.clone(),
345 QueueHealth {
346 in_flight,
347 max_workers: config.max_workers,
348 available,
349 },
350 )
351 })
352 .collect();
353
354 HealthCheck {
355 healthy: postgres_connected && poll_loop_alive && heartbeat_alive && !shutting_down,
356 postgres_connected,
357 poll_loop_alive,
358 heartbeat_alive,
359 shutting_down,
360 leader,
361 queues,
362 }
363 }
364}