1use crate::dispatcher::{ConcurrencyMode, Dispatcher, OverflowPool, QueueConfig};
2use crate::executor::{BoxedWorker, JobError, JobExecutor, JobResult, Worker};
3use crate::heartbeat::HeartbeatService;
4use crate::maintenance::MaintenanceService;
5use awa_model::{JobArgs, JobRow, PeriodicJob};
6use serde::de::DeserializeOwned;
7use sqlx::PgPool;
8use std::any::{Any, TypeId};
9use std::collections::HashMap;
10use std::sync::atomic::{AtomicBool, AtomicU32, Ordering};
11use std::sync::Arc;
12use std::time::Duration;
13use tokio::sync::{Mutex, RwLock};
14use tokio::task::JoinSet;
15use tokio_util::sync::CancellationToken;
16use tracing::{info, warn};
17
/// Configuration errors reported by [`ClientBuilder::build`].
///
/// All variants indicate a programming/configuration mistake detected before
/// any background task is started; nothing is partially initialized.
#[derive(Debug, Clone, PartialEq, Eq, thiserror::Error)]
pub enum BuildError {
    /// `ClientBuilder::queue` was never called.
    #[error("at least one queue must be configured")]
    NoQueuesConfigured,
    /// With `global_max_workers` set, the per-queue `min_workers` reservations
    /// must fit inside the global cap; the remainder becomes the overflow pool.
    #[error("sum of min_workers ({total_min}) exceeds global_max_workers ({global_max})")]
    MinWorkersExceedGlobal { total_min: u32, global_max: u32 },
    /// A queue declared a rate limit whose `max_rate` is not strictly positive.
    #[error("rate_limit max_rate must be > 0.0")]
    InvalidRateLimit,
    /// A queue declared `weight == 0`; weights drive overflow-pool sharing.
    #[error("queue weight must be > 0")]
    InvalidWeight,
}
30
/// Point-in-time snapshot of runtime health, produced by
/// [`Client::health_check`]; suitable for serving from a health endpoint.
#[derive(Debug, Clone)]
pub struct HealthCheck {
    /// Aggregate flag: Postgres reachable AND every dispatcher poll loop
    /// alive AND heartbeat alive AND not shutting down.
    pub healthy: bool,
    /// Result of a `SELECT 1` round-trip against the pool.
    pub postgres_connected: bool,
    /// True only if *all* per-queue dispatcher loops report alive.
    pub poll_loop_alive: bool,
    /// Whether the job-heartbeat service loop is running.
    pub heartbeat_alive: bool,
    /// True once shutdown has been initiated (dispatch cancel token fired).
    pub shutting_down: bool,
    /// Whether this process currently holds maintenance leadership.
    pub leader: bool,
    /// Per-queue detail, keyed by queue name.
    pub queues: HashMap<String, QueueHealth>,
}
42
/// Health detail for a single configured queue.
#[derive(Debug, Clone)]
pub struct QueueHealth {
    /// Jobs currently executing for this queue in this process.
    pub in_flight: u32,
    /// Count of jobs in state 'available' for this queue, as reported by
    /// Postgres (0 if the count query failed — see `health_check`).
    pub available: u64,
    /// How this queue's worker capacity is provisioned.
    pub capacity: QueueCapacity,
}
51
/// Concurrency-provisioning mode for a queue, mirroring whether the client
/// was built with `global_max_workers`.
#[derive(Debug, Clone)]
pub enum QueueCapacity {
    /// No global cap configured: the queue has a fixed, private worker limit.
    HardReserved { max_workers: u32 },
    /// Global cap configured: the queue owns `min_workers` guaranteed slots
    /// and competes by `weight` for extra permits from the shared overflow pool.
    Weighted {
        min_workers: u32,
        weight: u32,
        /// Overflow-pool permits this queue currently holds.
        overflow_held: u32,
    },
}
64
/// Fluent builder for [`Client`]; consumed by [`ClientBuilder::build`].
pub struct ClientBuilder {
    /// Postgres connection pool shared by all services and dispatchers.
    pool: PgPool,
    /// Registered queues in registration order (name, config). A Vec rather
    /// than a map so dispatcher startup order follows registration order.
    queues: Vec<(String, QueueConfig)>,
    /// Job-kind string -> worker implementation.
    workers: HashMap<String, BoxedWorker>,
    /// Type-erased shared application state, keyed by concrete type.
    state: HashMap<TypeId, Box<dyn Any + Send + Sync>>,
    /// Interval for the job heartbeat service (default 30s).
    heartbeat_interval: Duration,
    /// Cron-style jobs handed to the maintenance service.
    periodic_jobs: Vec<PeriodicJob>,
    /// Optional global worker cap; when set, queues share an overflow pool.
    global_max_workers: Option<u32>,
}
75
76impl ClientBuilder {
77 pub fn new(pool: PgPool) -> Self {
78 Self {
79 pool,
80 queues: Vec::new(),
81 workers: HashMap::new(),
82 state: HashMap::new(),
83 heartbeat_interval: Duration::from_secs(30),
84 periodic_jobs: Vec::new(),
85 global_max_workers: None,
86 }
87 }
88
89 pub fn queue(mut self, name: impl Into<String>, config: QueueConfig) -> Self {
91 self.queues.push((name.into(), config));
92 self
93 }
94
95 pub fn register<T, F, Fut>(mut self, handler: F) -> Self
100 where
101 T: JobArgs + DeserializeOwned + Send + Sync + 'static,
102 F: Fn(T, &crate::context::JobContext) -> Fut + Send + Sync + 'static,
103 Fut: std::future::Future<Output = Result<JobResult, JobError>> + Send + Sync + 'static,
104 {
105 let kind = T::kind().to_string();
106 let worker = TypedWorker {
107 kind: T::kind(),
108 handler: Arc::new(handler),
109 _phantom: std::marker::PhantomData,
110 };
111 self.workers.insert(kind, Box::new(worker));
112 self
113 }
114
115 pub fn register_worker(mut self, worker: impl Worker + 'static) -> Self {
117 let kind = worker.kind().to_string();
118 self.workers.insert(kind, Box::new(worker));
119 self
120 }
121
122 pub fn state<T: Any + Send + Sync + Clone>(mut self, value: T) -> Self {
124 self.state.insert(TypeId::of::<T>(), Box::new(value));
125 self
126 }
127
128 pub fn heartbeat_interval(mut self, interval: Duration) -> Self {
130 self.heartbeat_interval = interval;
131 self
132 }
133
134 pub fn global_max_workers(mut self, max: u32) -> Self {
139 self.global_max_workers = Some(max);
140 self
141 }
142
143 pub fn periodic(mut self, job: PeriodicJob) -> Self {
148 self.periodic_jobs.push(job);
149 self
150 }
151
152 pub fn build(self) -> Result<Client, BuildError> {
154 if self.queues.is_empty() {
155 return Err(BuildError::NoQueuesConfigured);
156 }
157
158 for (_, config) in &self.queues {
160 if let Some(rl) = &config.rate_limit {
161 if rl.max_rate <= 0.0 {
162 return Err(BuildError::InvalidRateLimit);
163 }
164 }
165 if config.weight == 0 {
166 return Err(BuildError::InvalidWeight);
167 }
168 }
169
170 let overflow_pool = if let Some(global_max) = self.global_max_workers {
172 let total_min: u32 = self.queues.iter().map(|(_, c)| c.min_workers).sum();
173 if total_min > global_max {
174 return Err(BuildError::MinWorkersExceedGlobal {
175 total_min,
176 global_max,
177 });
178 }
179 let overflow_capacity = global_max - total_min;
180 let weights: HashMap<String, u32> = self
181 .queues
182 .iter()
183 .map(|(name, c)| (name.clone(), c.weight.max(1)))
184 .collect();
185 Some(Arc::new(OverflowPool::new(overflow_capacity, weights)))
186 } else {
187 None
188 };
189
190 let metrics = crate::metrics::AwaMetrics::from_global();
191 let queue_in_flight = Arc::new(
192 self.queues
193 .iter()
194 .map(|(name, _)| (name.clone(), Arc::new(AtomicU32::new(0))))
195 .collect(),
196 );
197 let dispatcher_alive = Arc::new(
198 self.queues
199 .iter()
200 .map(|(name, _)| (name.clone(), Arc::new(AtomicBool::new(false))))
201 .collect(),
202 );
203
204 Ok(Client {
205 pool: self.pool,
206 queues: self.queues,
207 workers: Arc::new(self.workers),
208 state: Arc::new(self.state),
209 heartbeat_interval: self.heartbeat_interval,
210 periodic_jobs: Arc::new(self.periodic_jobs),
211 dispatch_cancel: CancellationToken::new(),
212 service_cancel: CancellationToken::new(),
213 dispatcher_handles: RwLock::new(Vec::new()),
214 service_handles: RwLock::new(Vec::new()),
215 job_set: Arc::new(Mutex::new(JoinSet::new())),
216 in_flight: Arc::new(RwLock::new(HashMap::new())),
217 queue_in_flight,
218 dispatcher_alive,
219 heartbeat_alive: Arc::new(AtomicBool::new(false)),
220 leader: Arc::new(AtomicBool::new(false)),
221 overflow_pool,
222 metrics,
223 })
224 }
225}
226
/// Adapter that wraps a typed closure registered via `ClientBuilder::register`
/// so it can be stored as a type-erased [`Worker`].
struct TypedWorker<T, F, Fut>
where
    T: JobArgs + DeserializeOwned + Send + Sync + 'static,
    F: Fn(T, &crate::context::JobContext) -> Fut + Send + Sync + 'static,
    Fut: std::future::Future<Output = Result<JobResult, JobError>> + Send + Sync + 'static,
{
    /// The job kind this worker handles (`T::kind()`).
    kind: &'static str,
    /// The user-supplied handler closure.
    handler: Arc<F>,
    // `fn() -> (T, Fut)` keeps the struct Send/Sync regardless of T/Fut
    // (it only *produces* them), while still tying the generics down.
    _phantom: std::marker::PhantomData<fn() -> (T, Fut)>,
}
238
239#[async_trait::async_trait]
240impl<T, F, Fut> Worker for TypedWorker<T, F, Fut>
241where
242 T: JobArgs + DeserializeOwned + Send + Sync + 'static,
243 F: Fn(T, &crate::context::JobContext) -> Fut + Send + Sync + 'static,
244 Fut: std::future::Future<Output = Result<JobResult, JobError>> + Send + Sync + 'static,
245{
246 fn kind(&self) -> &'static str {
247 self.kind
248 }
249
250 async fn perform(
251 &self,
252 job_row: &JobRow,
253 ctx: &crate::context::JobContext,
254 ) -> Result<JobResult, JobError> {
255 let args: T = serde_json::from_value(job_row.args.clone())
257 .map_err(|err| JobError::Terminal(format!("failed to deserialize args: {}", err)))?;
258
259 (self.handler)(args, ctx).await
260 }
261}
262
/// The running worker client: owns the connection pool, registered workers,
/// per-queue dispatchers, and background services. Built via [`ClientBuilder`].
pub struct Client {
    /// Postgres pool shared by services, dispatchers, and health checks.
    pool: PgPool,
    /// Configured queues (name, config) in registration order.
    queues: Vec<(String, QueueConfig)>,
    /// Job-kind string -> worker implementation.
    workers: Arc<HashMap<String, BoxedWorker>>,
    /// Type-erased shared application state, keyed by concrete type.
    state: Arc<HashMap<TypeId, Box<dyn Any + Send + Sync>>>,
    /// Interval for the job heartbeat service.
    heartbeat_interval: Duration,
    /// Cron-style jobs scheduled by the maintenance service.
    periodic_jobs: Arc<Vec<PeriodicJob>>,
    /// Cancelled first on shutdown: stops dispatchers claiming new jobs.
    dispatch_cancel: CancellationToken,
    /// Cancelled last on shutdown: stops heartbeat/maintenance after drain,
    /// so in-flight jobs stay heartbeated while they finish.
    service_cancel: CancellationToken,
    /// JoinHandles for per-queue dispatcher tasks; taken during shutdown.
    dispatcher_handles: RwLock<Vec<tokio::task::JoinHandle<()>>>,
    /// JoinHandles for heartbeat/maintenance tasks; taken during shutdown.
    service_handles: RwLock<Vec<tokio::task::JoinHandle<()>>>,
    /// Set of currently running job tasks; drained on shutdown.
    job_set: Arc<Mutex<JoinSet<()>>>,
    /// job id -> per-job cancellation flag; set true to ask a job to stop.
    in_flight: Arc<RwLock<HashMap<i64, Arc<AtomicBool>>>>,
    /// queue name -> count of jobs executing locally (for health reporting).
    queue_in_flight: Arc<HashMap<String, Arc<AtomicU32>>>,
    /// queue name -> liveness flag maintained by that queue's dispatcher loop.
    dispatcher_alive: Arc<HashMap<String, Arc<AtomicBool>>>,
    /// Liveness flag maintained by the heartbeat service.
    heartbeat_alive: Arc<AtomicBool>,
    /// Whether this process currently holds maintenance leadership.
    leader: Arc<AtomicBool>,
    /// Present only when `global_max_workers` was configured.
    overflow_pool: Option<Arc<OverflowPool>>,
    /// Metrics sink for the executor.
    metrics: crate::metrics::AwaMetrics,
}
290
impl Client {
    /// Convenience shortcut for [`ClientBuilder::new`].
    pub fn builder(pool: PgPool) -> ClientBuilder {
        ClientBuilder::new(pool)
    }

    /// Starts the runtime: spawns the heartbeat and maintenance services,
    /// then one dispatcher task per configured queue. Handles are retained
    /// so [`Client::shutdown`] can await them.
    ///
    /// NOTE(review): calling `start` twice would spawn a second set of tasks
    /// alongside the first — presumably callers invoke it once; confirm.
    pub async fn start(&self) -> Result<(), awa_model::AwaError> {
        info!(
            queues = self.queues.len(),
            workers = self.workers.len(),
            "Starting Awa worker runtime"
        );

        // One executor shared by every dispatcher to run claimed jobs.
        let executor = Arc::new(JobExecutor::new(
            self.pool.clone(),
            self.workers.clone(),
            self.in_flight.clone(),
            self.queue_in_flight.clone(),
            self.state.clone(),
            self.metrics.clone(),
        ));

        let mut service_handles = self.service_handles.write().await;

        // Heartbeat service: tied to `service_cancel` so it keeps running
        // through the job-drain phase of shutdown.
        let heartbeat = HeartbeatService::new(
            self.pool.clone(),
            self.in_flight.clone(),
            self.heartbeat_interval,
            self.heartbeat_alive.clone(),
            self.service_cancel.clone(),
        );
        service_handles.push(tokio::spawn(async move {
            heartbeat.run().await;
        }));

        // Maintenance service: leader election + periodic job scheduling;
        // also tied to `service_cancel`.
        let maintenance = MaintenanceService::new(
            self.pool.clone(),
            self.leader.clone(),
            self.service_cancel.clone(),
            self.periodic_jobs.clone(),
            self.in_flight.clone(),
        );
        service_handles.push(tokio::spawn(async move {
            maintenance.run().await;
        }));

        let mut dispatcher_handles = self.dispatcher_handles.write().await;
        for (queue_name, config) in &self.queues {
            // The alive flag should always exist (built from the same queue
            // list); the fallback flag is unreachable in practice but keeps
            // this infallible. A fallback flag would not be visible to
            // `health_check`, so the `unwrap_or_else` is purely defensive.
            let alive = self
                .dispatcher_alive
                .get(queue_name)
                .cloned()
                .unwrap_or_else(|| Arc::new(AtomicBool::new(false)));

            // Concurrency mode mirrors the builder: weighted overflow
            // sharing when a global cap was set, otherwise hard reservation.
            let dispatcher = if let Some(overflow_pool) = &self.overflow_pool {
                let concurrency = ConcurrencyMode::Weighted {
                    // Local semaphore guards the guaranteed min_workers slots.
                    local_semaphore: Arc::new(tokio::sync::Semaphore::new(
                        config.min_workers as usize,
                    )),
                    overflow_pool: overflow_pool.clone(),
                    queue_name: queue_name.clone(),
                };
                Dispatcher::with_concurrency(
                    queue_name.clone(),
                    config.clone(),
                    self.pool.clone(),
                    executor.clone(),
                    self.in_flight.clone(),
                    alive,
                    self.dispatch_cancel.clone(),
                    self.job_set.clone(),
                    concurrency,
                )
            } else {
                Dispatcher::new(
                    queue_name.clone(),
                    config.clone(),
                    self.pool.clone(),
                    executor.clone(),
                    self.in_flight.clone(),
                    alive,
                    self.dispatch_cancel.clone(),
                    self.job_set.clone(),
                )
            };
            dispatcher_handles.push(tokio::spawn(async move {
                dispatcher.run().await;
            }));
        }

        info!("Awa worker runtime started");
        Ok(())
    }

    /// Gracefully shuts down the runtime, in order:
    /// 1. cancel dispatchers so no new jobs are claimed;
    /// 2. flip every in-flight job's cancellation flag;
    /// 3. await dispatcher tasks, then drain running jobs for up to
    ///    `timeout` (on timeout, remaining jobs are abandoned with a warning);
    /// 4. cancel and await the heartbeat/maintenance services *last*, so
    ///    jobs stay heartbeated while draining.
    pub async fn shutdown(&self, timeout: Duration) {
        info!("Initiating graceful shutdown");

        self.dispatch_cancel.cancel();

        // Signal cancellation to every running job via its per-job flag.
        // Scoped block so the read lock is released before draining.
        {
            let guard = self.in_flight.read().await;
            for flag in guard.values() {
                flag.store(true, Ordering::SeqCst);
            }
        }

        // Take the handles out under the lock, await them outside it.
        let dispatcher_handles: Vec<_> = {
            let mut guard = self.dispatcher_handles.write().await;
            std::mem::take(&mut *guard)
        };
        for handle in dispatcher_handles {
            let _ = handle.await;
        }

        // Drain remaining job tasks; bounded by `timeout`. If the timeout
        // fires, the drain future (and its lock guard) is dropped and any
        // unfinished jobs are left to be recovered by another process.
        let drain = async {
            let mut set = self.job_set.lock().await;
            while set.join_next().await.is_some() {}
        };
        if tokio::time::timeout(timeout, drain).await.is_err() {
            warn!(
                timeout_secs = timeout.as_secs(),
                "Shutdown drain timeout exceeded, some jobs may not have completed"
            );
        }

        // Only now stop heartbeat/maintenance: they had to outlive the jobs.
        self.service_cancel.cancel();
        let service_handles: Vec<_> = {
            let mut guard = self.service_handles.write().await;
            std::mem::take(&mut *guard)
        };
        for handle in service_handles {
            let _ = handle.await;
        }

        info!("Awa worker runtime stopped");
    }

    /// The underlying Postgres pool (e.g. for enqueueing jobs elsewhere).
    pub fn pool(&self) -> &PgPool {
        &self.pool
    }

    /// Builds a [`HealthCheck`] snapshot. Performs two queries: a `SELECT 1`
    /// liveness probe and a per-queue count of 'available' jobs (the latter
    /// degrades to empty/zero counts on failure rather than erroring).
    pub async fn health_check(&self) -> HealthCheck {
        let postgres_connected = sqlx::query("SELECT 1").execute(&self.pool).await.is_ok();
        // Healthy only if every queue's dispatcher loop reports alive.
        let poll_loop_alive = self
            .dispatcher_alive
            .values()
            .all(|alive| alive.load(Ordering::SeqCst));
        let heartbeat_alive = self.heartbeat_alive.load(Ordering::SeqCst);
        let shutting_down = self.dispatch_cancel.is_cancelled();
        let leader = self.leader.load(Ordering::SeqCst);
        let available_rows = sqlx::query_as::<_, (String, i64)>(
            r#"
            SELECT queue, count(*)::bigint AS available
            FROM awa.jobs
            WHERE state = 'available'
            GROUP BY queue
            "#,
        )
        .fetch_all(&self.pool)
        .await
        .unwrap_or_default();
        let available_by_queue: HashMap<_, _> = available_rows.into_iter().collect();
        let queues = self
            .queues
            .iter()
            .map(|(queue, config)| {
                let in_flight = self
                    .queue_in_flight
                    .get(queue)
                    .map(|counter| counter.load(Ordering::SeqCst))
                    .unwrap_or(0);
                // `.max(0)` guards the i64 -> u64 cast; count(*) is already
                // non-negative, so this is purely defensive.
                let available = available_by_queue.get(queue).copied().unwrap_or(0).max(0) as u64;
                // Capacity shape mirrors how the client was built.
                let capacity = if let Some(overflow_pool) = &self.overflow_pool {
                    QueueCapacity::Weighted {
                        min_workers: config.min_workers,
                        weight: config.weight,
                        overflow_held: overflow_pool.held(queue),
                    }
                } else {
                    QueueCapacity::HardReserved {
                        max_workers: config.max_workers,
                    }
                };
                (
                    queue.clone(),
                    QueueHealth {
                        in_flight,
                        available,
                        capacity,
                    },
                )
            })
            .collect();

        HealthCheck {
            healthy: postgres_connected && poll_loop_alive && heartbeat_alive && !shutting_down,
            postgres_connected,
            poll_loop_alive,
            heartbeat_alive,
            shutting_down,
            leader,
            queues,
        }
    }
}
517}