1use crate::executor::JobExecutor;
2use awa_model::JobRow;
3use sqlx::PgPool;
4use std::collections::HashMap;
5use std::sync::atomic::{AtomicBool, Ordering};
6use std::sync::Arc;
7use std::time::{Duration, Instant};
8use tokio::sync::{Mutex, RwLock, Semaphore};
9use tokio::task::JoinSet;
10use tokio_util::sync::CancellationToken;
11use tracing::{debug, error, info, warn};
12
#[derive(Debug, Clone)]
pub struct RateLimit {
    /// Sustained dispatch ceiling, in jobs per second.
    pub max_rate: f64,
    /// Maximum burst size; 0 means "derive from `max_rate`".
    pub burst: u32,
}

/// Classic token bucket: tokens accrue at `refill_rate` per second up to
/// `max_tokens`, and each dispatched job spends one token.
struct TokenBucket {
    tokens: f64,
    max_tokens: f64,
    refill_rate: f64,
    last_refill: Instant,
}

impl TokenBucket {
    /// Builds a bucket that starts full. A zero `burst` falls back to
    /// `ceil(max_rate)`, clamped to at least 1 so the bucket is never
    /// permanently empty.
    fn new(rate_limit: &RateLimit) -> Self {
        let capacity = match rate_limit.burst {
            0 => (rate_limit.max_rate.ceil() as u32).max(1),
            explicit => explicit,
        } as f64;
        Self {
            tokens: capacity,
            max_tokens: capacity,
            refill_rate: rate_limit.max_rate,
            last_refill: Instant::now(),
        }
    }

    /// Credits the bucket for the time elapsed since the last call, then
    /// reports how many whole tokens are currently spendable.
    fn available(&mut self) -> u32 {
        let now = Instant::now();
        let earned = now.duration_since(self.last_refill).as_secs_f64() * self.refill_rate;
        self.tokens = self.max_tokens.min(self.tokens + earned);
        self.last_refill = now;
        self.tokens.floor() as u32
    }

    /// Spends `n` tokens. Callers are expected to consume at most what
    /// `available()` just reported.
    fn consume(&mut self, n: u32) {
        self.tokens -= f64::from(n);
    }
}
59
/// Tuning knobs for a single queue's dispatcher.
#[derive(Debug, Clone)]
pub struct QueueConfig {
    /// Upper bound on concurrently running jobs; sizes the hard-reserved
    /// semaphore built in `Dispatcher::new`.
    pub max_workers: u32,
    /// Fallback polling period; the dispatcher also wakes on NOTIFY.
    pub poll_interval: Duration,
    /// How far in the future `deadline_at` is set when a job is claimed.
    pub deadline_duration: Duration,
    /// A waiting job's effective priority improves by one step per this
    /// interval (see the claim query's ORDER BY).
    pub priority_aging_interval: Duration,
    /// Optional dispatch throughput cap (token bucket).
    pub rate_limit: Option<RateLimit>,
    // Not read in this file; presumably sizes the local semaphore in
    // weighted mode — TODO confirm at the construction site.
    pub min_workers: u32,
    // Not read in this file; presumably fed into `OverflowPool::new` as
    // this queue's fair-share weight — TODO confirm at the caller.
    pub weight: u32,
}
74
75impl Default for QueueConfig {
76 fn default() -> Self {
77 Self {
78 max_workers: 50,
79 poll_interval: Duration::from_millis(200),
80 deadline_duration: Duration::from_secs(300), priority_aging_interval: Duration::from_secs(60),
82 rate_limit: None,
83 min_workers: 0,
84 weight: 1,
85 }
86 }
87}
88
/// RAII token for one execution slot. `Hard`/`Local` free their slot when
/// the inner `OwnedSemaphorePermit` drops; `Overflow` gives its slot back
/// to the pool in this type's `Drop` impl.
#[allow(dead_code)] // the semaphore permits are held purely for their Drop side effect
pub(crate) enum DispatchPermit {
    /// Slot from a queue's dedicated semaphore (`HardReserved` mode).
    Hard(tokio::sync::OwnedSemaphorePermit),
    /// Slot from a queue's own guaranteed-minimum semaphore (`Weighted` mode).
    Local(tokio::sync::OwnedSemaphorePermit),
    /// Slot borrowed from the shared [`OverflowPool`].
    Overflow {
        pool: Arc<OverflowPool>,
        queue: String,
    },
}
103
104impl Drop for DispatchPermit {
105 fn drop(&mut self) {
106 if let DispatchPermit::Overflow { pool, queue } = self {
107 pool.release(queue, 1);
108 }
109 }
111}
112
/// How a dispatcher obtains execution slots.
pub(crate) enum ConcurrencyMode {
    /// The queue owns a private semaphore; it never competes with other
    /// queues for capacity.
    HardReserved { semaphore: Arc<Semaphore> },
    /// The queue has a small private semaphore plus access to a shared
    /// [`OverflowPool`] it contends for with sibling queues by weight.
    Weighted {
        local_semaphore: Arc<Semaphore>,
        overflow_pool: Arc<OverflowPool>,
        queue_name: String,
    },
}
124
/// A pool of `total` execution slots shared between queues in weighted
/// mode. Each *contending* queue (one that currently wants or holds slots)
/// may hold at most `ceil(total * weight / sum_of_contending_weights)`.
pub(crate) struct OverflowPool {
    total: u32,
    // std Mutex, not tokio: critical sections are short and never await.
    state: std::sync::Mutex<OverflowState>,
}

/// Mutable bookkeeping behind the pool's mutex.
struct OverflowState {
    /// Slots currently held, per queue.
    held: HashMap<String, u32>,
    /// Most recently reported want, per queue (0 clears contention).
    demand: HashMap<String, u32>,
    /// Static fairness weights, per queue.
    weights: HashMap<String, u32>,
}

impl OverflowPool {
    /// Creates a pool of `total` shared slots, split between queues
    /// according to `weights`.
    pub fn new(total: u32, weights: HashMap<String, u32>) -> Self {
        OverflowPool {
            total,
            state: std::sync::Mutex::new(OverflowState {
                held: HashMap::new(),
                demand: HashMap::new(),
                weights,
            }),
        }
    }

    /// Records that `queue` wants `wanted` slots right now and grants up
    /// to that many, capped by both the pool's free capacity and the
    /// queue's weighted fair share. A call with `wanted == 0` only updates
    /// (clears) the recorded demand, so an idle queue stops counting
    /// toward contention once its held slots drain.
    pub fn try_acquire(&self, queue: &str, wanted: u32) -> u32 {
        let mut guard = self.state.lock().unwrap();

        guard.demand.insert(queue.to_string(), wanted);
        if wanted == 0 {
            return 0;
        }

        let in_use: u32 = guard.held.values().sum();
        let free = self.total.saturating_sub(in_use);
        if free == 0 {
            return 0;
        }

        // Unknown queues fall back to weight 1.
        let own_weight = guard.weights.get(queue).copied().unwrap_or(1);

        // A queue contends for the pool if it either wants slots right now
        // or is still holding some.
        let total_weight: u32 = guard
            .weights
            .iter()
            .filter_map(|(name, weight)| {
                let wants = guard.demand.get(name.as_str()).copied().unwrap_or(0) > 0;
                let holds = guard.held.get(name.as_str()).copied().unwrap_or(0) > 0;
                (wants || holds).then_some(*weight)
            })
            .sum();
        if total_weight == 0 {
            return 0;
        }

        let fair_share =
            (f64::from(self.total) * f64::from(own_weight) / f64::from(total_weight)).ceil() as u32;
        let already_held = guard.held.get(queue).copied().unwrap_or(0);
        let headroom = fair_share.saturating_sub(already_held);

        let grant = wanted.min(free).min(headroom);
        if grant > 0 {
            *guard.held.entry(queue.to_string()).or_insert(0) += grant;
        }
        grant
    }

    /// Hands `n` slots back from `queue`; saturates at zero.
    pub fn release(&self, queue: &str, n: u32) {
        let mut guard = self.state.lock().unwrap();
        if let Some(count) = guard.held.get_mut(queue) {
            *count = count.saturating_sub(n);
        }
    }

    /// Number of overflow slots `queue` holds right now.
    pub fn held(&self, queue: &str) -> u32 {
        self.state.lock().unwrap().held.get(queue).copied().unwrap_or(0)
    }
}
217
/// Per-queue dispatcher: claims runnable jobs from Postgres and hands them
/// to the [`JobExecutor`], within the queue's configured concurrency and
/// rate limits.
pub struct Dispatcher {
    /// Queue name; also used to derive the NOTIFY channel (`awa:<queue>`).
    queue: String,
    config: QueueConfig,
    /// Pool used both for the claim query and for LISTEN.
    pool: PgPool,
    executor: Arc<JobExecutor>,
    // Shared in-flight cancel-flag map owned by the caller; never read in
    // this file (hence the leading underscore) — presumably kept so its
    // lifetime tracks the dispatcher's. TODO confirm intent with the owner.
    _in_flight: Arc<RwLock<HashMap<i64, Arc<AtomicBool>>>>,
    /// How execution slots are obtained (dedicated vs. weighted shared).
    concurrency: ConcurrencyMode,
    /// Set true for the duration of `run`; observers can treat it as a
    /// liveness flag.
    alive: Arc<AtomicBool>,
    cancel: CancellationToken,
    /// Shared set of spawned job tasks, so the owner can await/drain them.
    job_set: Arc<Mutex<JoinSet<()>>>,
    /// Present only when `config.rate_limit` is set.
    rate_limiter: Option<TokenBucket>,
}
231
impl Dispatcher {
    /// Creates a dispatcher with a dedicated ("hard reserved") semaphore
    /// of `config.max_workers` slots, so this queue never competes with
    /// others for execution capacity.
    #[allow(clippy::too_many_arguments)] // mirrored by `with_concurrency`; a params struct would churn every caller
    pub fn new(
        queue: String,
        config: QueueConfig,
        pool: PgPool,
        executor: Arc<JobExecutor>,
        in_flight: Arc<RwLock<HashMap<i64, Arc<AtomicBool>>>>,
        alive: Arc<AtomicBool>,
        cancel: CancellationToken,
        job_set: Arc<Mutex<JoinSet<()>>>,
    ) -> Self {
        let concurrency = ConcurrencyMode::HardReserved {
            semaphore: Arc::new(Semaphore::new(config.max_workers as usize)),
        };
        // Token bucket only exists when the queue opts into rate limiting.
        let rate_limiter = config.rate_limit.as_ref().map(TokenBucket::new);
        Self {
            queue,
            config,
            pool,
            executor,
            _in_flight: in_flight,
            concurrency,
            alive,
            cancel,
            job_set,
            rate_limiter,
        }
    }

    /// Like [`Dispatcher::new`], but with an explicit [`ConcurrencyMode`]
    /// (e.g. `Weighted`, sharing an [`OverflowPool`] with sibling queues).
    #[allow(clippy::too_many_arguments)] // see `new`
    pub(crate) fn with_concurrency(
        queue: String,
        config: QueueConfig,
        pool: PgPool,
        executor: Arc<JobExecutor>,
        in_flight: Arc<RwLock<HashMap<i64, Arc<AtomicBool>>>>,
        alive: Arc<AtomicBool>,
        cancel: CancellationToken,
        job_set: Arc<Mutex<JoinSet<()>>>,
        concurrency: ConcurrencyMode,
    ) -> Self {
        let rate_limiter = config.rate_limit.as_ref().map(TokenBucket::new);
        Self {
            queue,
            config,
            pool,
            executor,
            _in_flight: in_flight,
            concurrency,
            alive,
            cancel,
            job_set,
            rate_limiter,
        }
    }

    /// Main dispatch loop: runs until `cancel` fires. Wakes on Postgres
    /// NOTIFY for low latency and also polls every `poll_interval` as a
    /// safety net; if LISTEN cannot be established it degrades to pure
    /// polling. `alive` is true for the lifetime of the loop.
    #[tracing::instrument(skip(self), fields(queue = %self.queue))]
    pub async fn run(mut self) {
        self.alive.store(true, Ordering::SeqCst);
        info!(
            queue = %self.queue,
            poll_interval_ms = self.config.poll_interval.as_millis(),
            "Dispatcher started"
        );

        // NOTE(review): assumes producers NOTIFY on "awa:<queue>" —
        // confirm against the enqueue path.
        let notify_channel = format!("awa:{}", self.queue);
        let mut listener = match sqlx::postgres::PgListener::connect_with(&self.pool).await {
            Ok(listener) => listener,
            Err(err) => {
                error!(error = %err, "Failed to create PG listener, falling back to polling only");
                // Degraded mode: interval polling alone still makes progress.
                self.poll_loop_only().await;
                self.alive.store(false, Ordering::SeqCst);
                return;
            }
        };

        if let Err(err) = listener.listen(&notify_channel).await {
            warn!(error = %err, channel = %notify_channel, "Failed to LISTEN, falling back to polling");
            self.poll_loop_only().await;
            self.alive.store(false, Ordering::SeqCst);
            return;
        }

        debug!(channel = %notify_channel, "Listening for job notifications");

        loop {
            tokio::select! {
                _ = self.cancel.cancelled() => {
                    debug!(queue = %self.queue, "Dispatcher shutting down");
                    break;
                }
                notification = listener.recv() => {
                    match notification {
                        Ok(_) => {
                            // Payload is ignored; any NOTIFY just triggers a poll.
                            debug!(queue = %self.queue, "Woken by NOTIFY");
                            self.poll_once().await;
                        }
                        Err(err) => {
                            // recv() reconnects internally; back off briefly
                            // rather than spinning on a dead connection.
                            warn!(error = %err, "PG listener error, will retry");
                            tokio::time::sleep(Duration::from_secs(1)).await;
                        }
                    }
                }
                // Safety-net poll in case a NOTIFY is missed or delayed.
                _ = tokio::time::sleep(self.config.poll_interval) => {
                    self.poll_once().await;
                }
            }
        }

        self.alive.store(false, Ordering::SeqCst);
    }

    /// Fallback loop used when LISTEN/NOTIFY could not be set up: pure
    /// interval polling until cancellation.
    async fn poll_loop_only(&mut self) {
        loop {
            tokio::select! {
                _ = self.cancel.cancelled() => {
                    debug!(queue = %self.queue, "Dispatcher (poll-only) shutting down");
                    break;
                }
                _ = tokio::time::sleep(self.config.poll_interval) => {
                    self.poll_once().await;
                }
            }
        }
    }

    /// Non-blocking acquisition of up to 10 execution permits; the result
    /// caps the claim batch of one poll. In `Weighted` mode local
    /// (guaranteed) slots are preferred, and only the shortfall is
    /// requested from the shared overflow pool.
    fn acquire_permits(&mut self) -> Vec<DispatchPermit> {
        let mut permits = Vec::new();
        match &self.concurrency {
            ConcurrencyMode::HardReserved { semaphore } => {
                for _ in 0..10 {
                    match semaphore.clone().try_acquire_owned() {
                        Ok(p) => permits.push(DispatchPermit::Hard(p)),
                        Err(_) => break,
                    }
                }
            }
            ConcurrencyMode::Weighted {
                local_semaphore,
                overflow_pool,
                queue_name,
            } => {
                for _ in 0..10 {
                    match local_semaphore.clone().try_acquire_owned() {
                        Ok(p) => permits.push(DispatchPermit::Local(p)),
                        Err(_) => break,
                    }
                }
                // Top up from the shared pool only for what the local
                // semaphore couldn't cover.
                let overflow_wanted = (10usize.saturating_sub(permits.len())) as u32;
                let granted = overflow_pool.try_acquire(queue_name, overflow_wanted);
                for _ in 0..granted {
                    permits.push(DispatchPermit::Overflow {
                        pool: overflow_pool.clone(),
                        queue: queue_name.clone(),
                    });
                }
            }
        }
        permits
    }

    /// One dispatch round: acquire permits, atomically claim up to that
    /// many runnable jobs in Postgres, then spawn one executor task per
    /// claimed job. Surplus permits are dropped, which releases their
    /// slots immediately (see `DispatchPermit`'s `Drop`).
    #[tracing::instrument(skip(self), fields(queue = %self.queue))]
    async fn poll_once(&mut self) {
        let mut permits = self.acquire_permits();
        if permits.is_empty() {
            return;
        }

        // usize::MAX when no rate limit is configured, i.e. "unbounded".
        let rate_available = self
            .rate_limiter
            .as_mut()
            .map(|rl| rl.available() as usize)
            .unwrap_or(usize::MAX);
        let batch_size = permits.len().min(rate_available).min(10);
        if batch_size == 0 {
            return;
        }
        // Shed permits beyond what the rate limiter allows this round.
        while permits.len() > batch_size {
            permits.pop();
        }

        let deadline_secs = self.config.deadline_duration.as_secs_f64();
        let aging_secs = self.config.priority_aging_interval.as_secs_f64();

        // Claim-and-start in one statement. FOR UPDATE SKIP LOCKED lets
        // concurrent dispatchers claim disjoint job sets without blocking;
        // the ORDER BY lowers (boosts) effective priority by one step per
        // `priority_aging_interval` of waiting; paused queues are skipped.
        let jobs: Vec<JobRow> = match sqlx::query_as::<_, JobRow>(
            r#"
            WITH claimed AS (
                SELECT id
                FROM awa.jobs
                WHERE state = 'available'
                  AND queue = $1
                  AND run_at <= now()
                  AND NOT EXISTS (
                      SELECT 1 FROM awa.queue_meta
                      WHERE queue = $1 AND paused = TRUE
                  )
                ORDER BY
                    GREATEST(1, priority - FLOOR(EXTRACT(EPOCH FROM (now() - run_at)) / $4)::int) ASC,
                    run_at ASC,
                    id ASC
                LIMIT $2
                FOR UPDATE SKIP LOCKED
            )
            UPDATE awa.jobs
            SET state = 'running',
                attempt = attempt + 1,
                attempted_at = now(),
                heartbeat_at = now(),
                deadline_at = now() + make_interval(secs => $3)
            FROM claimed
            WHERE awa.jobs.id = claimed.id
            RETURNING awa.jobs.*
            "#,
        )
        .bind(&self.queue)
        .bind(batch_size as i32)
        .bind(deadline_secs)
        .bind(aging_secs)
        .fetch_all(&self.pool)
        .await
        {
            Ok(jobs) => jobs,
            Err(err) => {
                warn!(queue = %self.queue, error = %err, "Failed to claim jobs");
                return;
            }
        };

        // Fewer jobs than permits: return the unused capacity right away.
        while permits.len() > jobs.len() {
            permits.pop();
        }

        if jobs.is_empty() {
            // Report zero demand so the overflow pool stops counting this
            // queue as contending once its held slots drain.
            if let ConcurrencyMode::Weighted {
                overflow_pool,
                queue_name,
                ..
            } = &self.concurrency
            {
                overflow_pool.try_acquire(queue_name, 0);
            }
            return;
        }

        debug!(queue = %self.queue, count = jobs.len(), "Claimed jobs");

        // Spend one token per claimed job; the claim was already capped by
        // `available()`, so this never overdraws in practice.
        if let Some(rl) = &mut self.rate_limiter {
            rl.consume(jobs.len() as u32);
        }

        // jobs.len() <= permits.len() after the trim above, so zip pairs
        // every job with a permit.
        for (job, permit) in jobs.into_iter().zip(permits) {
            let cancel_flag = Arc::new(AtomicBool::new(false));
            let handle = self.executor.execute(job, cancel_flag);

            let job_set = self.job_set.clone();
            let mut set = job_set.lock().await;
            set.spawn(async move {
                let _ = handle.await;
                // The permit lives until the job finishes; dropping it
                // frees the slot.
                drop(permit);
            });
        }
    }
}