Skip to main content

aurora_db/workers/
executor.rs

1use super::job::Job;
2use super::queue::JobQueue;
3use crate::error::Result;
4use std::collections::HashMap;
5use std::future::Future;
6use std::pin::Pin;
7use std::sync::Arc;
8use std::time::Duration;
9use tokio::sync::{broadcast, RwLock};
10use tokio::task::JoinHandle;
11use tokio::time::{interval, timeout};
12
13/// Job handler function type
14pub type JobHandler =
15    Arc<dyn Fn(Job) -> Pin<Box<dyn Future<Output = Result<()>> + Send>> + Send + Sync>;
16
/// Worker configuration
///
/// Plain data describing how a `WorkerExecutor` should run. Implements `Debug` so that a
/// configuration can be logged or inspected in tests.
#[derive(Clone, Debug)]
pub struct WorkerConfig {
    /// Storage location for worker data.
    ///
    /// NOTE(review): not read by the executor code in this file; presumably consumed by the
    /// queue/storage layer — confirm against the caller.
    pub storage_path: String,
    /// Number of worker tasks spawned when the executor is started.
    pub concurrency: usize,
    /// Polling interval in milliseconds.
    ///
    /// NOTE(review): not read by the executor code in this file (workers block on
    /// `JobQueue::dequeue`); confirm where this fallback is used.
    pub poll_interval_ms: u64,
    /// Seconds between invocations of the queue's `cleanup_completed` by the cleanup task.
    pub cleanup_interval_seconds: u64,
}
25
26impl Default for WorkerConfig {
27    fn default() -> Self {
28        Self {
29            storage_path: "./aurora_workers".to_string(),
30            concurrency: 4,
31            poll_interval_ms: 10, // Faster polling fallback
32            cleanup_interval_seconds: 3600, // 1 hour
33        }
34    }
35}
36
37/// Worker executor that processes jobs
38pub struct WorkerExecutor {
39    queue: Arc<JobQueue>,
40    handlers: Arc<RwLock<HashMap<String, JobHandler>>>,
41    config: WorkerConfig,
42    running: Arc<RwLock<bool>>,
43    worker_handles: Arc<RwLock<Vec<JoinHandle<()>>>>,
44    shutdown_tx: Option<broadcast::Sender<()>>,
45}
46
47impl WorkerExecutor {
48    pub fn new(queue: Arc<JobQueue>, config: WorkerConfig) -> Self {
49        Self {
50            queue,
51            handlers: Arc::new(RwLock::new(HashMap::new())),
52            config,
53            running: Arc::new(RwLock::new(false)),
54            worker_handles: Arc::new(RwLock::new(Vec::new())),
55            shutdown_tx: None,
56        }
57    }
58
59    /// Register a job handler
60    pub async fn register_handler<F, Fut>(&self, job_type: impl Into<String>, handler: F)
61    where
62        F: Fn(Job) -> Fut + Send + Sync + 'static,
63        Fut: Future<Output = Result<()>> + Send + 'static,
64    {
65        let handler = Arc::new(
66            move |job: Job| -> Pin<Box<dyn Future<Output = Result<()>> + Send>> {
67                Box::pin(handler(job))
68            },
69        );
70
71        self.handlers.write().await.insert(job_type.into(), handler);
72    }
73
74    /// Start the worker executor
75    pub async fn start(&mut self) -> Result<()> {
76        let mut running = self.running.write().await;
77        if *running {
78            return Ok(());
79        }
80        *running = true;
81        drop(running);
82
83        let (tx, _) = broadcast::channel(1);
84        self.shutdown_tx = Some(tx.clone());
85
86        // Spawn worker tasks
87        let mut handles = self.worker_handles.write().await;
88        for worker_id in 0..self.config.concurrency {
89            let handle = self.spawn_worker(worker_id, tx.subscribe());
90            handles.push(handle);
91        }
92
93        // Spawn cleanup task
94        let cleanup_handle = self.spawn_cleanup_task(tx.subscribe());
95        handles.push(cleanup_handle);
96
97        // Spawn reaper task for zombie job recovery
98        let reaper_handle = self.spawn_reaper(tx.subscribe());
99        handles.push(reaper_handle);
100
101        Ok(())
102    }
103
104    /// Stop the worker executor
105    pub async fn stop(&mut self) -> Result<()> {
106        let mut running = self.running.write().await;
107        if !*running {
108            return Ok(());
109        }
110        *running = false;
111        drop(running);
112
113        // Send shutdown signal to all tasks
114        if let Some(tx) = self.shutdown_tx.take() {
115            let _ = tx.send(());
116        }
117
118        // Also notify queue to wake up any workers stuck in dequeue
119        self.queue.shutdown().await;
120
121        // Wait for all worker tasks to finish with a timeout
122        let mut handles = self.worker_handles.write().await;
123        for handle in handles.drain(..) {
124            // We give each task a small window to exit gracefully
125            let _ = timeout(Duration::from_millis(500), handle).await;
126        }
127
128        Ok(())
129    }
130
131    /// Spawn a worker task
132    fn spawn_worker(&self, worker_id: usize, mut shutdown_rx: broadcast::Receiver<()>) -> JoinHandle<()> {
133        let queue = Arc::clone(&self.queue);
134        let handlers = Arc::clone(&self.handlers);
135        let running = Arc::clone(&self.running);
136
137        tokio::spawn(async move {
138            loop {
139                // Check if we should stop
140                if !*running.read().await {
141                    break;
142                }
143
144                // HIGH-PERFORMANCE DEQUEUE (Channel-based, O(1))
145                // Use select to allow immediate interruption during dequeue
146                let job_opt = tokio::select! {
147                    res = queue.dequeue() => {
148                        match res {
149                            Ok(Some(job)) => Some(job),
150                            Ok(None) => return, // Channel closed
151                            Err(e) => {
152                                eprintln!("[Worker {}] Dequeue Error: {}", worker_id, e);
153                                tokio::time::sleep(Duration::from_millis(100)).await;
154                                None
155                            }
156                        }
157                    }
158                    _ = shutdown_rx.recv() => break,
159                };
160
161                if let Some(mut job) = job_opt {
162                    // Get handler
163                    let handlers_guard = handlers.read().await;
164                    let handler = handlers_guard.get(&job.job_type).cloned();
165                    drop(handlers_guard);
166
167                    if let Some(handler) = handler {
168                        let result = if let Some(timeout_secs) = job.timeout_seconds {
169                            timeout(Duration::from_secs(timeout_secs), handler(job.clone())).await
170                        } else {
171                            Ok(handler(job.clone()).await)
172                        };
173
174                        match result {
175                            Ok(Ok(())) => { job.mark_completed(); }
176                            Ok(Err(e)) => { job.mark_failed(e.to_string()); }
177                            Err(_) => { job.mark_failed("Timeout".to_string()); }
178                        }
179
180                        let job_id = job.id.clone();
181                        let _ = queue.update_job(&job_id, job).await;
182                    } else {
183                        job.mark_failed("No handler registered".to_string());
184                        let job_id = job.id.clone();
185                        let _ = queue.update_job(&job_id, job).await;
186                    }
187                }
188            }
189        })
190    }
191
192    /// Spawn cleanup task
193    fn spawn_cleanup_task(&self, mut shutdown_rx: broadcast::Receiver<()>) -> JoinHandle<()> {
194        let queue = Arc::clone(&self.queue);
195        let cleanup_interval = self.config.cleanup_interval_seconds;
196
197        tokio::spawn(async move {
198            let mut interval = interval(Duration::from_secs(cleanup_interval));
199            loop {
200                tokio::select! {
201                    _ = interval.tick() => {
202                        let _ = queue.cleanup_completed().await;
203                    }
204                    _ = shutdown_rx.recv() => break,
205                }
206            }
207        })
208    }
209
210    /// Spawn the Reaper task for zombie job recovery
211    fn spawn_reaper(&self, mut shutdown_rx: broadcast::Receiver<()>) -> JoinHandle<()> {
212        let queue = Arc::clone(&self.queue);
213
214        tokio::spawn(async move {
215            let mut interval = interval(Duration::from_secs(60));
216            loop {
217                tokio::select! {
218                    _ = interval.tick() => {
219                        let zombies = queue.find_zombie_jobs().await;
220                        for job_id in zombies {
221                            if let Ok(Some(mut job)) = queue.get(&job_id).await {
222                                job.status = super::job::JobStatus::Pending;
223                                job.retry_count += 1;
224                                job.touch();
225                                let _ = queue.update_job(&job_id, job).await;
226                            }
227                        }
228                    }
229                    _ = shutdown_rx.recv() => break,
230                }
231            }
232        })
233    }
234}