//! awa_worker/dispatcher.rs — per-queue job dispatcher for the awa worker.
use std::collections::HashMap;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use std::time::Duration;

use sqlx::PgPool;
use tokio::sync::{RwLock, Semaphore};
use tokio_util::sync::CancellationToken;
use tracing::{debug, error, info, warn};

use awa_model::JobRow;

use crate::executor::JobExecutor;

/// Configuration for a single queue.
#[derive(Debug, Clone)]
pub struct QueueConfig {
    /// Maximum number of jobs this queue may execute concurrently.
    pub max_workers: u32,
    /// How often to poll for jobs when no NOTIFY wakeup arrives.
    pub poll_interval: Duration,
    /// How long a claimed job may run before its deadline expires.
    pub deadline_duration: Duration,
    /// Time slice used to age (promote) job priority while a job waits;
    /// the claim query subtracts one priority step per elapsed interval.
    pub priority_aging_interval: Duration,
}

21impl Default for QueueConfig {
22    fn default() -> Self {
23        Self {
24            max_workers: 50,
25            poll_interval: Duration::from_millis(200),
26            deadline_duration: Duration::from_secs(300), // 5 minutes
27            priority_aging_interval: Duration::from_secs(60),
28        }
29    }
30}
31
32/// Dispatcher polls a single queue for available jobs and dispatches them.
33pub struct Dispatcher {
34    queue: String,
35    config: QueueConfig,
36    pool: PgPool,
37    executor: Arc<JobExecutor>,
38    _in_flight: Arc<RwLock<HashMap<i64, Arc<AtomicBool>>>>,
39    semaphore: Arc<Semaphore>,
40    alive: Arc<AtomicBool>,
41    cancel: CancellationToken,
42}
43
impl Dispatcher {
    /// Create a dispatcher for `queue`.
    ///
    /// The internal semaphore is sized from `config.max_workers` at
    /// construction time, so the dispatcher never has more than that many
    /// jobs executing at once; changing `max_workers` requires building a
    /// new dispatcher.
    pub fn new(
        queue: String,
        config: QueueConfig,
        pool: PgPool,
        executor: Arc<JobExecutor>,
        in_flight: Arc<RwLock<HashMap<i64, Arc<AtomicBool>>>>,
        alive: Arc<AtomicBool>,
        cancel: CancellationToken,
    ) -> Self {
        let semaphore = Arc::new(Semaphore::new(config.max_workers as usize));
        Self {
            queue,
            config,
            pool,
            executor,
            _in_flight: in_flight,
            semaphore,
            alive,
            cancel,
        }
    }

67    /// Run the poll loop. Returns when cancelled.
68    #[tracing::instrument(skip(self), fields(queue = %self.queue, max_workers = self.config.max_workers))]
69    pub async fn run(&self) {
70        self.alive.store(true, Ordering::SeqCst);
71        info!(
72            queue = %self.queue,
73            max_workers = self.config.max_workers,
74            poll_interval_ms = self.config.poll_interval.as_millis(),
75            "Dispatcher started"
76        );
77
78        // Set up LISTEN/NOTIFY for this queue
79        let notify_channel = format!("awa:{}", self.queue);
80        let mut listener = match sqlx::postgres::PgListener::connect_with(&self.pool).await {
81            Ok(listener) => listener,
82            Err(err) => {
83                error!(error = %err, "Failed to create PG listener, falling back to polling only");
84                // Fall back to poll-only mode
85                self.poll_loop_only().await;
86                self.alive.store(false, Ordering::SeqCst);
87                return;
88            }
89        };
90
91        if let Err(err) = listener.listen(&notify_channel).await {
92            warn!(error = %err, channel = %notify_channel, "Failed to LISTEN, falling back to polling");
93            self.poll_loop_only().await;
94            self.alive.store(false, Ordering::SeqCst);
95            return;
96        }
97
98        debug!(channel = %notify_channel, "Listening for job notifications");
99
100        loop {
101            tokio::select! {
102                _ = self.cancel.cancelled() => {
103                    debug!(queue = %self.queue, "Dispatcher shutting down");
104                    break;
105                }
106                // Wait for either a notification or the poll interval
107                notification = listener.recv() => {
108                    match notification {
109                        Ok(_) => {
110                            debug!(queue = %self.queue, "Woken by NOTIFY");
111                            self.poll_once().await;
112                        }
113                        Err(err) => {
114                            warn!(error = %err, "PG listener error, will retry");
115                            tokio::time::sleep(Duration::from_secs(1)).await;
116                        }
117                    }
118                }
119                _ = tokio::time::sleep(self.config.poll_interval) => {
120                    self.poll_once().await;
121                }
122            }
123        }
124
125        self.alive.store(false, Ordering::SeqCst);
126    }
127
128    /// Poll-only fallback (no LISTEN/NOTIFY).
129    async fn poll_loop_only(&self) {
130        loop {
131            tokio::select! {
132                _ = self.cancel.cancelled() => {
133                    debug!(queue = %self.queue, "Dispatcher (poll-only) shutting down");
134                    break;
135                }
136                _ = tokio::time::sleep(self.config.poll_interval) => {
137                    self.poll_once().await;
138                }
139            }
140        }
141    }
142
143    /// Single poll iteration: claim available jobs up to the semaphore limit.
144    #[tracing::instrument(skip(self), fields(queue = %self.queue))]
145    async fn poll_once(&self) {
146        // How many workers are available?
147        let available = self.semaphore.available_permits();
148        if available == 0 {
149            return;
150        }
151
152        let batch_size = available.min(10) as i32; // Claim in small batches
153        let deadline_secs = self.config.deadline_duration.as_secs_f64();
154        let aging_secs = self.config.priority_aging_interval.as_secs_f64();
155
156        let jobs: Vec<JobRow> = match sqlx::query_as::<_, JobRow>(
157            r#"
158            WITH claimed AS (
159                SELECT id
160                FROM awa.jobs
161                WHERE state = 'available'
162                  AND queue = $1
163                  AND run_at <= now()
164                  AND NOT EXISTS (
165                      SELECT 1 FROM awa.queue_meta
166                      WHERE queue = $1 AND paused = TRUE
167                  )
168                ORDER BY
169                  GREATEST(1, priority - FLOOR(EXTRACT(EPOCH FROM (now() - run_at)) / $4)::int) ASC,
170                  run_at ASC,
171                  id ASC
172                LIMIT $2
173                FOR UPDATE SKIP LOCKED
174            )
175            UPDATE awa.jobs
176            SET state = 'running',
177                attempt = attempt + 1,
178                attempted_at = now(),
179                heartbeat_at = now(),
180                deadline_at = now() + make_interval(secs => $3)
181            FROM claimed
182            WHERE awa.jobs.id = claimed.id
183            RETURNING awa.jobs.*
184            "#,
185        )
186        .bind(&self.queue)
187        .bind(batch_size)
188        .bind(deadline_secs)
189        .bind(aging_secs)
190        .fetch_all(&self.pool)
191        .await
192        {
193            Ok(jobs) => jobs,
194            Err(err) => {
195                warn!(queue = %self.queue, error = %err, "Failed to claim jobs");
196                return;
197            }
198        };
199
200        if !jobs.is_empty() {
201            debug!(queue = %self.queue, count = jobs.len(), "Claimed jobs");
202        }
203
204        for job in jobs {
205            // Acquire a semaphore permit before dispatching
206            let permit = match self.semaphore.clone().acquire_owned().await {
207                Ok(permit) => permit,
208                Err(_) => {
209                    warn!("Semaphore closed");
210                    break;
211                }
212            };
213
214            let cancel_flag = Arc::new(AtomicBool::new(false));
215            let handle = self.executor.execute(job, cancel_flag);
216
217            // Spawn a task to release the permit when the job completes
218            tokio::spawn(async move {
219                let _ = handle.await;
220                drop(permit);
221            });
222        }
223    }
224
225    /// Get the number of in-flight jobs for this queue.
226    pub async fn in_flight_count(&self) -> usize {
227        self.config.max_workers as usize - self.semaphore.available_permits()
228    }
229}