//! awa_worker/dispatcher.rs — per-queue job dispatcher.

1use crate::executor::JobExecutor;
2use awa_model::JobRow;
3use sqlx::PgPool;
4use std::collections::HashMap;
5use std::sync::atomic::{AtomicBool, Ordering};
6use std::sync::Arc;
7use std::time::{Duration, Instant};
8use tokio::sync::{Mutex, RwLock, Semaphore};
9use tokio::task::JoinSet;
10use tokio_util::sync::CancellationToken;
11use tracing::{debug, error, info, warn};
12
/// Rate limit configuration for a queue.
#[derive(Debug, Clone)]
pub struct RateLimit {
    /// Maximum sustained dispatch rate (jobs per second).
    pub max_rate: f64,
    /// Maximum burst size. Defaults to ceil(max_rate) if 0.
    pub burst: u32,
}

/// Internal token bucket state for rate limiting.
struct TokenBucket {
    tokens: f64,
    max_tokens: f64,
    refill_rate: f64,
    last_refill: Instant,
}

impl TokenBucket {
    /// Build a bucket pre-filled to its burst capacity.
    fn new(rate_limit: &RateLimit) -> Self {
        // burst == 0 means "derive from the rate": ceil(max_rate), never zero.
        let capacity = match rate_limit.burst {
            0 => (rate_limit.max_rate.ceil() as u32).max(1),
            explicit => explicit,
        } as f64;
        TokenBucket {
            tokens: capacity,
            max_tokens: capacity,
            refill_rate: rate_limit.max_rate,
            last_refill: Instant::now(),
        }
    }

    /// Refill based on elapsed wall-clock time, then report whole tokens on hand.
    fn available(&mut self) -> u32 {
        let now = Instant::now();
        let earned = now.duration_since(self.last_refill).as_secs_f64() * self.refill_rate;
        self.last_refill = now;
        // Fractional tokens are kept; capacity is capped at the burst size.
        self.tokens = self.max_tokens.min(self.tokens + earned);
        self.tokens.floor() as u32
    }

    /// Spend `n` tokens (caller must ensure n <= available()).
    fn consume(&mut self, n: u32) {
        self.tokens -= f64::from(n);
    }
}
59
/// Configuration for a single queue.
#[derive(Debug, Clone)]
pub struct QueueConfig {
    /// Maximum concurrent jobs for this queue; sizes the per-queue semaphore
    /// in hard-reserved mode.
    pub max_workers: u32,
    /// How often the dispatcher polls when no NOTIFY wakeup arrives.
    pub poll_interval: Duration,
    /// How long a claimed job may run before its deadline expires
    /// (written to `deadline_at` when the job is claimed).
    pub deadline_duration: Duration,
    /// Priority-aging step: a waiting job's effective priority improves by one
    /// for each interval it has been runnable (used in the claim ORDER BY).
    pub priority_aging_interval: Duration,
    /// Optional rate limit for this queue. None means unlimited.
    pub rate_limit: Option<RateLimit>,
    /// Minimum guaranteed workers in weighted mode (default: 0).
    pub min_workers: u32,
    /// Weight for overflow allocation in weighted mode (default: 1).
    pub weight: u32,
}
74
75impl Default for QueueConfig {
76    fn default() -> Self {
77        Self {
78            max_workers: 50,
79            poll_interval: Duration::from_millis(200),
80            deadline_duration: Duration::from_secs(300), // 5 minutes
81            priority_aging_interval: Duration::from_secs(60),
82            rate_limit: None,
83            min_workers: 0,
84            weight: 1,
85        }
86    }
87}
88
/// Wraps permits so the correct resource is released on drop.
/// The OwnedSemaphorePermit fields are held purely for their Drop behavior.
#[allow(dead_code)]
pub(crate) enum DispatchPermit {
    /// Hard-reserved semaphore permit (current default behavior).
    /// Released automatically when the inner permit is dropped.
    Hard(tokio::sync::OwnedSemaphorePermit),
    /// Local (guaranteed minimum) semaphore permit in weighted mode.
    /// Released automatically when the inner permit is dropped.
    Local(tokio::sync::OwnedSemaphorePermit),
    /// Overflow permit from the shared OverflowPool. Has no RAII inner value,
    /// so this type's Drop impl returns it manually via `OverflowPool::release`
    /// (one permit per enum value).
    Overflow {
        pool: Arc<OverflowPool>,
        queue: String,
    },
}
103
104impl Drop for DispatchPermit {
105    fn drop(&mut self) {
106        if let DispatchPermit::Overflow { pool, queue } = self {
107            pool.release(queue, 1);
108        }
109        // OwnedSemaphorePermit auto-releases on drop for Hard/Local
110    }
111}
112
/// Concurrency mode for a dispatcher.
pub(crate) enum ConcurrencyMode {
    /// Each queue has its own semaphore (sized to the queue's max_workers).
    /// No sharing. Default behavior.
    HardReserved { semaphore: Arc<Semaphore> },
    /// Queues share a global overflow pool with per-queue minimum guarantees:
    /// `local_semaphore` grants the guaranteed minimum, and `overflow_pool`
    /// arbitrates the shared remainder by weight for `queue_name`.
    Weighted {
        local_semaphore: Arc<Semaphore>,
        overflow_pool: Arc<OverflowPool>,
        queue_name: String,
    },
}
124
/// Centralized overflow capacity allocator for weighted mode.
/// Thread-safe: called from multiple dispatcher poll loops via Mutex.
pub(crate) struct OverflowPool {
    /// Total shared permits available across all queues.
    total: u32,
    /// Mutable allocation state, guarded for cross-dispatcher access.
    state: std::sync::Mutex<OverflowState>,
}

struct OverflowState {
    /// Per-queue: currently held overflow permits (decremented on release).
    held: HashMap<String, u32>,
    /// Per-queue: last-declared demand (updated every try_acquire call).
    demand: HashMap<String, u32>,
    /// Per-queue: configured weight (immutable after construction).
    weights: HashMap<String, u32>,
}

impl OverflowPool {
    /// Create a pool of `total` shared permits with per-queue `weights`.
    /// Queues absent from `weights` are treated as weight 1.
    pub fn new(total: u32, weights: HashMap<String, u32>) -> Self {
        Self {
            total,
            state: std::sync::Mutex::new(OverflowState {
                held: HashMap::new(),
                demand: HashMap::new(),
                weights,
            }),
        }
    }

    /// Try to acquire up to `wanted` overflow permits for `queue`.
    /// Returns the number actually granted (0..=wanted).
    ///
    /// Calling with wanted=0 is valid — it clears this queue's demand signal.
    pub fn try_acquire(&self, queue: &str, wanted: u32) -> u32 {
        let mut state = self.state.lock().unwrap();

        // Always update demand — this is the key signal for fairness
        state.demand.insert(queue.to_string(), wanted);

        if wanted == 0 {
            return 0;
        }

        let currently_used: u32 = state.held.values().sum();
        let available = self.total.saturating_sub(currently_used);
        if available == 0 {
            return 0;
        }

        // Unconfigured queues default to weight 1, matching the contention
        // accounting below.
        let my_weight = state.weights.get(queue).copied().unwrap_or(1);

        // Contending = queues with demand > 0 OR held > 0.
        //
        // Iterate over `demand` rather than `weights`: every queue that ever
        // called try_acquire has a demand entry (inserted above before any
        // grant), so every holder is covered — and queues without a configured
        // weight still count, using the same default weight of 1 as
        // `my_weight`. (Iterating `weights` would silently exclude an
        // unconfigured queue from contention, starving it forever when it is
        // the sole contender.)
        let contending_weight: u32 = state
            .demand
            .iter()
            .filter(|(q, d)| {
                **d > 0 || state.held.get(q.as_str()).copied().unwrap_or(0) > 0
            })
            .map(|(q, _)| state.weights.get(q.as_str()).copied().unwrap_or(1))
            .sum();

        // Unreachable in practice (this queue's demand > 0 was just recorded),
        // kept as a guard against division by zero below.
        if contending_weight == 0 {
            return 0;
        }

        // My fair share of the TOTAL pool (not just available)
        let my_fair_share =
            ((self.total as f64) * (my_weight as f64 / contending_weight as f64)).ceil() as u32;
        let my_held = state.held.get(queue).copied().unwrap_or(0);
        let room = my_fair_share.saturating_sub(my_held);

        let granted = wanted.min(available).min(room);
        if granted > 0 {
            *state.held.entry(queue.to_string()).or_insert(0) += granted;
        }
        granted
    }

    /// Release `n` overflow permits back to the pool.
    /// Releasing for an unknown queue (or more than held) is a harmless no-op.
    pub fn release(&self, queue: &str, n: u32) {
        let mut state = self.state.lock().unwrap();
        if let Some(held) = state.held.get_mut(queue) {
            *held = held.saturating_sub(n);
        }
    }

    /// Get the number of overflow permits currently held by a queue.
    pub fn held(&self, queue: &str) -> u32 {
        let state = self.state.lock().unwrap();
        state.held.get(queue).copied().unwrap_or(0)
    }
}
217
/// Dispatcher polls a single queue for available jobs and dispatches them.
pub struct Dispatcher {
    /// Name of the queue this dispatcher serves.
    queue: String,
    /// Per-queue tuning: worker cap, poll interval, deadline, aging, rate limit.
    config: QueueConfig,
    /// Postgres pool used for claiming jobs and the LISTEN/NOTIFY listener.
    pool: PgPool,
    /// Executes claimed jobs (see poll_once).
    executor: Arc<JobExecutor>,
    // NOTE(review): unused here (leading underscore) — presumably a shared
    // job-id -> cancel-flag map owned by the surrounding worker; confirm
    // against callers before relying on it.
    _in_flight: Arc<RwLock<HashMap<i64, Arc<AtomicBool>>>>,
    /// Permit accounting: hard-reserved semaphore or weighted local+overflow.
    concurrency: ConcurrencyMode,
    /// Set on entry to `run` and cleared on every exit path.
    alive: Arc<AtomicBool>,
    /// Signals shutdown to the poll loops.
    cancel: CancellationToken,
    /// In-flight job tasks; jobs are spawned here so shutdown can drain them.
    job_set: Arc<Mutex<JoinSet<()>>>,
    /// Token bucket built from `config.rate_limit`; None means unlimited.
    rate_limiter: Option<TokenBucket>,
}
231
232impl Dispatcher {
233    #[allow(clippy::too_many_arguments)]
234    pub fn new(
235        queue: String,
236        config: QueueConfig,
237        pool: PgPool,
238        executor: Arc<JobExecutor>,
239        in_flight: Arc<RwLock<HashMap<i64, Arc<AtomicBool>>>>,
240        alive: Arc<AtomicBool>,
241        cancel: CancellationToken,
242        job_set: Arc<Mutex<JoinSet<()>>>,
243    ) -> Self {
244        let concurrency = ConcurrencyMode::HardReserved {
245            semaphore: Arc::new(Semaphore::new(config.max_workers as usize)),
246        };
247        let rate_limiter = config.rate_limit.as_ref().map(TokenBucket::new);
248        Self {
249            queue,
250            config,
251            pool,
252            executor,
253            _in_flight: in_flight,
254            concurrency,
255            alive,
256            cancel,
257            job_set,
258            rate_limiter,
259        }
260    }
261
262    /// Create a dispatcher with a specific concurrency mode (used for weighted mode).
263    #[allow(clippy::too_many_arguments)]
264    pub(crate) fn with_concurrency(
265        queue: String,
266        config: QueueConfig,
267        pool: PgPool,
268        executor: Arc<JobExecutor>,
269        in_flight: Arc<RwLock<HashMap<i64, Arc<AtomicBool>>>>,
270        alive: Arc<AtomicBool>,
271        cancel: CancellationToken,
272        job_set: Arc<Mutex<JoinSet<()>>>,
273        concurrency: ConcurrencyMode,
274    ) -> Self {
275        let rate_limiter = config.rate_limit.as_ref().map(TokenBucket::new);
276        Self {
277            queue,
278            config,
279            pool,
280            executor,
281            _in_flight: in_flight,
282            concurrency,
283            alive,
284            cancel,
285            job_set,
286            rate_limiter,
287        }
288    }
289
    /// Run the poll loop. Returns when cancelled.
    ///
    /// Hybrid wakeup strategy: subscribes to Postgres LISTEN/NOTIFY on the
    /// `awa:<queue>` channel for low-latency wakeups, with a periodic
    /// `poll_interval` timer as a safety net. If the listener cannot be
    /// created or LISTEN fails, degrades to timer-only polling via
    /// `poll_loop_only`. The `alive` flag is set on entry and cleared on
    /// every exit path.
    #[tracing::instrument(skip(self), fields(queue = %self.queue))]
    pub async fn run(mut self) {
        self.alive.store(true, Ordering::SeqCst);
        info!(
            queue = %self.queue,
            poll_interval_ms = self.config.poll_interval.as_millis(),
            "Dispatcher started"
        );

        // Set up LISTEN/NOTIFY for this queue
        let notify_channel = format!("awa:{}", self.queue);
        let mut listener = match sqlx::postgres::PgListener::connect_with(&self.pool).await {
            Ok(listener) => listener,
            Err(err) => {
                error!(error = %err, "Failed to create PG listener, falling back to polling only");
                // Fall back to poll-only mode
                self.poll_loop_only().await;
                self.alive.store(false, Ordering::SeqCst);
                return;
            }
        };

        if let Err(err) = listener.listen(&notify_channel).await {
            warn!(error = %err, channel = %notify_channel, "Failed to LISTEN, falling back to polling");
            self.poll_loop_only().await;
            self.alive.store(false, Ordering::SeqCst);
            return;
        }

        debug!(channel = %notify_channel, "Listening for job notifications");

        loop {
            tokio::select! {
                _ = self.cancel.cancelled() => {
                    debug!(queue = %self.queue, "Dispatcher shutting down");
                    break;
                }
                // Wait for either a notification or the poll interval
                notification = listener.recv() => {
                    match notification {
                        Ok(_) => {
                            debug!(queue = %self.queue, "Woken by NOTIFY");
                            self.poll_once().await;
                        }
                        Err(err) => {
                            // Back off briefly; recv() is retried on the next
                            // loop iteration.
                            warn!(error = %err, "PG listener error, will retry");
                            tokio::time::sleep(Duration::from_secs(1)).await;
                        }
                    }
                }
                // Safety-net poll in case a NOTIFY was missed.
                _ = tokio::time::sleep(self.config.poll_interval) => {
                    self.poll_once().await;
                }
            }
        }

        self.alive.store(false, Ordering::SeqCst);
    }
349
350    /// Poll-only fallback (no LISTEN/NOTIFY).
351    async fn poll_loop_only(&mut self) {
352        loop {
353            tokio::select! {
354                _ = self.cancel.cancelled() => {
355                    debug!(queue = %self.queue, "Dispatcher (poll-only) shutting down");
356                    break;
357                }
358                _ = tokio::time::sleep(self.config.poll_interval) => {
359                    self.poll_once().await;
360                }
361            }
362        }
363    }
364
365    /// Pre-acquire permits (non-blocking). Returns a vec of permits.
366    fn acquire_permits(&mut self) -> Vec<DispatchPermit> {
367        let mut permits = Vec::new();
368        match &self.concurrency {
369            ConcurrencyMode::HardReserved { semaphore } => {
370                for _ in 0..10 {
371                    match semaphore.clone().try_acquire_owned() {
372                        Ok(p) => permits.push(DispatchPermit::Hard(p)),
373                        Err(_) => break,
374                    }
375                }
376            }
377            ConcurrencyMode::Weighted {
378                local_semaphore,
379                overflow_pool,
380                queue_name,
381            } => {
382                // First: local (guaranteed) permits
383                for _ in 0..10 {
384                    match local_semaphore.clone().try_acquire_owned() {
385                        Ok(p) => permits.push(DispatchPermit::Local(p)),
386                        Err(_) => break,
387                    }
388                }
389                // Then: overflow permits (up to 10 total)
390                let overflow_wanted = (10usize.saturating_sub(permits.len())) as u32;
391                let granted = overflow_pool.try_acquire(queue_name, overflow_wanted);
392                for _ in 0..granted {
393                    permits.push(DispatchPermit::Overflow {
394                        pool: overflow_pool.clone(),
395                        queue: queue_name.clone(),
396                    });
397                }
398            }
399        }
400        permits
401    }
402
    /// Single poll iteration: pre-acquire permits, claim jobs, dispatch.
    ///
    /// Permits are acquired BEFORE the DB claim so we never claim more jobs
    /// than can run immediately; at each narrowing step, excess permits are
    /// popped and their Drop impl returns the capacity.
    #[tracing::instrument(skip(self), fields(queue = %self.queue))]
    async fn poll_once(&mut self) {
        // Phase 1: Pre-acquire permits (non-blocking)
        let mut permits = self.acquire_permits();
        if permits.is_empty() {
            return;
        }

        // Phase 2: Apply rate limit
        let rate_available = self
            .rate_limiter
            .as_mut()
            .map(|rl| rl.available() as usize)
            .unwrap_or(usize::MAX);
        let batch_size = permits.len().min(rate_available).min(10);
        if batch_size == 0 {
            // Drop all permits — rate limited
            return;
        }
        // Release excess permits beyond what rate limit allows
        while permits.len() > batch_size {
            permits.pop(); // Drop releases the permit
        }

        // Phase 3: Claim from DB
        let deadline_secs = self.config.deadline_duration.as_secs_f64();
        let aging_secs = self.config.priority_aging_interval.as_secs_f64();

        // Atomic claim: FOR UPDATE SKIP LOCKED lets concurrent dispatchers
        // claim disjoint sets; ORDER BY applies priority aging (effective
        // priority improves by one per $4 seconds waited, floored at 1);
        // claimed rows move to 'running' with a fresh deadline ($3 seconds).
        let jobs: Vec<JobRow> = match sqlx::query_as::<_, JobRow>(
            r#"
            WITH claimed AS (
                SELECT id
                FROM awa.jobs
                WHERE state = 'available'
                  AND queue = $1
                  AND run_at <= now()
                  AND NOT EXISTS (
                      SELECT 1 FROM awa.queue_meta
                      WHERE queue = $1 AND paused = TRUE
                  )
                ORDER BY
                  GREATEST(1, priority - FLOOR(EXTRACT(EPOCH FROM (now() - run_at)) / $4)::int) ASC,
                  run_at ASC,
                  id ASC
                LIMIT $2
                FOR UPDATE SKIP LOCKED
            )
            UPDATE awa.jobs
            SET state = 'running',
                attempt = attempt + 1,
                attempted_at = now(),
                heartbeat_at = now(),
                deadline_at = now() + make_interval(secs => $3)
            FROM claimed
            WHERE awa.jobs.id = claimed.id
            RETURNING awa.jobs.*
            "#,
        )
        .bind(&self.queue)
        .bind(batch_size as i32)
        .bind(deadline_secs)
        .bind(aging_secs)
        .fetch_all(&self.pool)
        .await
        {
            Ok(jobs) => jobs,
            Err(err) => {
                // Permits drop here, releasing their capacity.
                warn!(queue = %self.queue, error = %err, "Failed to claim jobs");
                return;
            }
        };

        // Phase 4: Release excess permits if DB had fewer jobs.
        // Afterwards permits.len() == jobs.len(), so the zip in Phase 7
        // pairs every job with exactly one permit.
        while permits.len() > jobs.len() {
            permits.pop();
        }

        // Phase 5: Clear overflow demand if no jobs found
        if jobs.is_empty() {
            if let ConcurrencyMode::Weighted {
                overflow_pool,
                queue_name,
                ..
            } = &self.concurrency
            {
                // wanted=0 only records "no demand" for fairness accounting.
                overflow_pool.try_acquire(queue_name, 0);
            }
            return;
        }

        debug!(queue = %self.queue, count = jobs.len(), "Claimed jobs");

        // Phase 6: Consume rate limit tokens
        if let Some(rl) = &mut self.rate_limiter {
            rl.consume(jobs.len() as u32);
        }

        // Phase 7: Dispatch (each job takes one pre-acquired permit)
        for (job, permit) in jobs.into_iter().zip(permits) {
            let cancel_flag = Arc::new(AtomicBool::new(false));
            let handle = self.executor.execute(job, cancel_flag);

            // Spawn into JoinSet so shutdown can drain in-flight jobs
            let job_set = self.job_set.clone();
            let mut set = job_set.lock().await;
            set.spawn(async move {
                let _ = handle.await;
                // Permit is released only after the job task finishes.
                drop(permit);
            });
        }
    }
515}