Skip to main content

aurora_db/workers/
executor.rs

1use super::job::Job;
2use super::queue::JobQueue;
3use crate::error::Result;
4use std::collections::HashMap;
5use std::future::Future;
6use std::pin::Pin;
7use std::sync::Arc;
8use std::time::Duration;
9use tokio::sync::RwLock;
10use tokio::task::JoinHandle;
11use tokio::time::{interval, timeout};
12
13/// Job handler function type
14pub type JobHandler =
15    Arc<dyn Fn(Job) -> Pin<Box<dyn Future<Output = Result<()>> + Send>> + Send + Sync>;
16
17/// Worker configuration
18#[derive(Clone)]
19pub struct WorkerConfig {
20    pub storage_path: String,
21    pub concurrency: usize,
22    pub poll_interval_ms: u64,
23    pub cleanup_interval_seconds: u64,
24}
25
26impl Default for WorkerConfig {
27    fn default() -> Self {
28        Self {
29            storage_path: "./aurora_workers".to_string(),
30            concurrency: 4,
31            poll_interval_ms: 100,
32            cleanup_interval_seconds: 3600, // 1 hour
33        }
34    }
35}
36
37/// Worker executor that processes jobs
38pub struct WorkerExecutor {
39    queue: Arc<JobQueue>,
40    handlers: Arc<RwLock<HashMap<String, JobHandler>>>,
41    config: WorkerConfig,
42    running: Arc<RwLock<bool>>,
43    worker_handles: Arc<RwLock<Vec<JoinHandle<()>>>>,
44}
45
46impl WorkerExecutor {
47    pub fn new(queue: Arc<JobQueue>, config: WorkerConfig) -> Self {
48        Self {
49            queue,
50            handlers: Arc::new(RwLock::new(HashMap::new())),
51            config,
52            running: Arc::new(RwLock::new(false)),
53            worker_handles: Arc::new(RwLock::new(Vec::new())),
54        }
55    }
56
57    /// Register a job handler
58    pub async fn register_handler<F, Fut>(&self, job_type: impl Into<String>, handler: F)
59    where
60        F: Fn(Job) -> Fut + Send + Sync + 'static,
61        Fut: Future<Output = Result<()>> + Send + 'static,
62    {
63        let handler = Arc::new(
64            move |job: Job| -> Pin<Box<dyn Future<Output = Result<()>> + Send>> {
65                Box::pin(handler(job))
66            },
67        );
68
69        self.handlers.write().await.insert(job_type.into(), handler);
70    }
71
72    /// Start the worker executor
73    pub async fn start(&self) -> Result<()> {
74        let mut running = self.running.write().await;
75        if *running {
76            return Ok(());
77        }
78        *running = true;
79        drop(running);
80
81        // Spawn worker tasks
82        let mut handles = self.worker_handles.write().await;
83        for worker_id in 0..self.config.concurrency {
84            let handle = self.spawn_worker(worker_id);
85            handles.push(handle);
86        }
87
88        // Spawn cleanup task
89        let cleanup_handle = self.spawn_cleanup_task();
90        handles.push(cleanup_handle);
91
92        // Spawn reaper task for zombie job recovery
93        let reaper_handle = self.spawn_reaper();
94        handles.push(reaper_handle);
95
96        Ok(())
97    }
98
99    /// Stop the worker executor
100    pub async fn stop(&self) -> Result<()> {
101        let mut running = self.running.write().await;
102        *running = false;
103        drop(running);
104
105        // Notify all workers to wake up and check the `running` flag
106        self.queue.notify_all();
107
108        // Wait for all worker tasks to finish
109        let mut handles = self.worker_handles.write().await;
110        for handle in handles.drain(..) {
111            if let Err(e) = handle.await {
112                eprintln!("Worker panic during shutdown: {:?}", e);
113            }
114        }
115
116        Ok(())
117    }
118
119    /// Spawn a worker task
120    fn spawn_worker(&self, worker_id: usize) -> JoinHandle<()> {
121        let queue = Arc::clone(&self.queue);
122        let handlers = Arc::clone(&self.handlers);
123        let running = Arc::clone(&self.running);
124
125        tokio::spawn(async move {
126            loop {
127                // Check if we should stop
128                if !*running.read().await {
129                    break;
130                }
131
132                // Try to dequeue a job
133                match queue.dequeue().await {
134                    Ok(Some(mut job)) => {
135                        println!(
136                            "[Worker {}] Processing job: {} ({})",
137                            worker_id, job.id, job.job_type
138                        );
139
140                        // Get handler
141                        let handlers = handlers.read().await;
142                        let handler = handlers.get(&job.job_type);
143
144                        if let Some(handler) = handler {
145                            let handler = Arc::clone(handler);
146                            drop(handlers);
147
148                            // Clone job for heartbeat updates
149                            let job_id_for_heartbeat = job.id.clone();
150                            let queue_for_heartbeat = Arc::clone(&queue);
151                            let mut heartbeat_job = job.clone();
152
153                            // Heartbeat interval: pulse every 15 seconds
154                            let heartbeat_interval = Duration::from_secs(15);
155                            let mut heartbeat_tick = interval(heartbeat_interval);
156                            heartbeat_tick
157                                .set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
158
159                            // Execute job with timeout AND periodic heartbeat
160                            let job_future = async {
161                                if let Some(timeout_secs) = job.timeout_seconds {
162                                    timeout(Duration::from_secs(timeout_secs), handler(job.clone()))
163                                        .await
164                                } else {
165                                    Ok(handler(job.clone()).await)
166                                }
167                            };
168
169                            // Run job and heartbeat concurrently
170                            let result = {
171                                tokio::pin!(job_future);
172
173                                loop {
174                                    tokio::select! {
175                                        biased;
176
177                                        // Job completion takes priority
178                                        result = &mut job_future => {
179                                            break result;
180                                        }
181
182                                        // Heartbeat pulse every 15 seconds
183                                        _ = heartbeat_tick.tick() => {
184                                            heartbeat_job.touch();
185                                            let _ = queue_for_heartbeat
186                                                .update_job(&job_id_for_heartbeat, heartbeat_job.clone())
187                                                .await;
188                                        }
189                                    }
190                                }
191                            };
192
193                            match result {
194                                Ok(Ok(())) => {
195                                    job.mark_completed();
196                                }
197                                Ok(Err(e)) => {
198                                    job.mark_failed(e.to_string());
199                                }
200                                Err(_) => {
201                                    job.mark_failed("Timeout".to_string());
202                                }
203                            }
204
205                            // Update job status
206                            let job_id = job.id.clone();
207                            let _ = queue.update_job(&job_id, job).await;
208                        } else {
209                            let job_type = job.job_type.clone();
210                            job.mark_failed("No handler registered".to_string());
211                            let job_id = job.id.clone();
212                            let _ = queue.update_job(&job_id, job).await;
213                            println!(
214                                "[Worker {}] No handler for job type: {}",
215                                worker_id, job_type
216                            );
217                        }
218                    }
219                    Ok(None) => {
220                        // Wait for notification
221                        queue.notified().await;
222                    }
223                    Err(e) => {
224                        eprintln!("[Worker {}] Error dequeuing job: {}", worker_id, e);
225                    }
226                }
227            }
228
229            println!("[Worker {}] Stopped", worker_id);
230        })
231    }
232
233    /// Spawn cleanup task
234    fn spawn_cleanup_task(&self) -> JoinHandle<()> {
235        let queue = Arc::clone(&self.queue);
236        let running = Arc::clone(&self.running);
237        let cleanup_interval = self.config.cleanup_interval_seconds;
238
239        tokio::spawn(async move {
240            let mut interval = interval(Duration::from_secs(cleanup_interval));
241
242            loop {
243                interval.tick().await;
244
245                if !*running.read().await {
246                    break;
247                }
248
249                match queue.cleanup_completed().await {
250                    Ok(count) => {
251                        if count > 0 {
252                            println!("[Cleanup] Removed {} completed jobs", count);
253                        }
254                    }
255                    Err(e) => {
256                        eprintln!("[Cleanup] Error: {}", e);
257                    }
258                }
259            }
260
261            println!("[Cleanup] Stopped");
262        })
263    }
264
265    /// Spawn the Reaper task for zombie job recovery
266    ///
267    /// The Reaper runs every 60 seconds and:
268    /// 1. Scans for jobs with status=Running where heartbeat has expired
269    /// 2. Resets those "zombie" jobs to Pending so they can be re-processed
270    /// 3. Increments retry_count and notifies workers
271    fn spawn_reaper(&self) -> JoinHandle<()> {
272        let queue = Arc::clone(&self.queue);
273        let running = Arc::clone(&self.running);
274
275        tokio::spawn(async move {
276            // Run the Reaper every 60 seconds
277            let mut interval = interval(Duration::from_secs(60));
278
279            loop {
280                interval.tick().await;
281
282                if !*running.read().await {
283                    break;
284                }
285
286                // Get running jobs and check heartbeat (efficient via queue method)
287                let zombies = queue.find_zombie_jobs().await;
288
289                // Revive zombies
290                for job_id in zombies {
291                    if let Ok(Some(mut job)) = queue.get(&job_id).await {
292                        // Reset status so a new worker can pick it up
293                        job.status = super::job::JobStatus::Pending;
294                        job.retry_count += 1;
295                        job.touch(); // Update time so we don't reap it again immediately
296
297                        // Save to DB
298                        let _ = queue.update_job(&job_id, job).await;
299
300                        // Wake up workers
301                        queue.notify_all();
302                    }
303                }
304            }
305
306            println!("[Reaper] Stopped");
307        })
308    }
309}
310
311#[cfg(test)]
312mod tests {
313    use super::*;
314    use crate::workers::job::{Job, JobStatus};
315    use tempfile::TempDir;
316    use tokio::time::sleep;
317
318    #[tokio::test]
319    async fn test_worker_execution() {
320        let temp_dir = TempDir::new().unwrap();
321        let config = WorkerConfig {
322            storage_path: temp_dir.path().to_str().unwrap().to_string(),
323            concurrency: 2,
324            poll_interval_ms: 50,
325            cleanup_interval_seconds: 10, // Short interval for testing
326        };
327
328        let queue = Arc::new(JobQueue::new(config.storage_path.clone()).unwrap());
329        let executor = WorkerExecutor::new(Arc::clone(&queue), config);
330
331        // Register a test handler
332        executor
333            .register_handler("test", |_job| async { Ok(()) })
334            .await;
335
336        // Start executor
337        executor.start().await.unwrap();
338
339        // Enqueue a job
340        let job = Job::new("test");
341        let job_id = queue.enqueue(job).await.unwrap();
342
343        // Wait for job to complete
344        sleep(Duration::from_millis(300)).await;
345
346        // Check job status - it might be completed or already cleaned up
347        let status = queue.get_status(&job_id).await.unwrap();
348        // Either completed or None (cleaned up) is ok
349        assert!(matches!(status, Some(JobStatus::Completed) | None));
350
351        executor.stop().await.unwrap();
352    }
353
354    #[tokio::test]
355    async fn test_graceful_shutdown() {
356        let temp_dir = TempDir::new().unwrap();
357        let config = WorkerConfig {
358            storage_path: temp_dir.path().to_str().unwrap().to_string(),
359            concurrency: 1,
360            poll_interval_ms: 100,
361            cleanup_interval_seconds: 10,
362        };
363
364        let queue = Arc::new(JobQueue::new(config.storage_path.clone()).unwrap());
365        let executor = WorkerExecutor::new(Arc::clone(&queue), config);
366
367        // Register a handler that sleeps for 2 seconds
368        executor
369            .register_handler("long_task", |_job| async {
370                tokio::time::sleep(Duration::from_secs(2)).await;
371                Ok(())
372            })
373            .await;
374
375        executor.start().await.unwrap();
376
377        // Enqueue job
378        let job = Job::new("long_task");
379        let job_id = queue.enqueue(job).await.unwrap();
380
381        // Wait a bit to ensure worker picked it up
382        tokio::time::sleep(Duration::from_millis(100)).await;
383
384        // Verify it is running
385        let status = queue.get_status(&job_id).await.unwrap();
386        assert_eq!(status, Some(JobStatus::Running));
387
388        // Measure shutdown time
389        let start = std::time::Instant::now();
390        executor.stop().await.unwrap();
391        let duration = start.elapsed();
392
393        // Shutdown should wait for running job to complete (proves graceful shutdown)
394        assert!(
395            duration.as_millis() >= 1500,
396            "Shutdown was too fast ({:?}), didn't wait for job",
397            duration
398        );
399
400        // Verify job completed successfully
401        let status = queue.get_status(&job_id).await.unwrap();
402        assert_eq!(status, Some(JobStatus::Completed));
403    }
404
405    // NOTE: test_job_heartbeat_updates removed as it was too slow (starts full executor with reaper).
406    // The heartbeat logic is tested by test_is_heartbeat_expired below.
407
408    #[tokio::test]
409    async fn test_is_heartbeat_expired() {
410        // Test the heartbeat expiration logic
411        let mut job = Job::new("test");
412        job.lease_duration = 1; // 1 second lease
413
414        // Fresh job - not expired
415        job.touch();
416        assert!(!job.is_heartbeat_expired());
417
418        // Wait for lease to expire
419        tokio::time::sleep(Duration::from_secs(2)).await;
420        job.status = JobStatus::Running;
421        assert!(job.is_heartbeat_expired());
422
423        // Touch again - no longer expired
424        job.touch();
425        assert!(!job.is_heartbeat_expired());
426    }
427}