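//! `aurora_db/workers/executor.rs`
//!
//! Worker executor: registers per-job-type handlers and runs worker, cleanup,
//! and reaper tasks against a shared `JobQueue`.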

use super::job::Job;
use super::queue::JobQueue;
use crate::error::Result;
use std::collections::HashMap;
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::{RwLock, broadcast};
use tokio::task::JoinHandle;
use tokio::time::{interval, sleep, timeout};
12
/// Job handler function type.
///
/// A type-erased, shareable async callback: it receives a [`Job`] by value and
/// returns a boxed, pinned, `Send` future that resolves to `Result<()>`.
/// The `Arc` lets one handler be cloned cheaply and invoked from several
/// tasks; `Send + Sync` on the closure allows it to cross task boundaries.
pub type JobHandler =
    Arc<dyn Fn(Job) -> Pin<Box<dyn Future<Output = Result<()>> + Send>> + Send + Sync>;
16
/// Worker configuration.
///
/// Plain data; every field is public so callers can build a value with struct
/// update syntax on top of `Default::default()`. `Debug`, `PartialEq` and `Eq`
/// are derived so a configuration can be logged and compared in assertions.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct WorkerConfig {
    /// Filesystem location associated with the workers.
    /// NOTE(review): not read anywhere in this file — confirm who consumes it.
    pub storage_path: String,
    /// Number of worker tasks the executor spawns on start.
    pub concurrency: usize,
    /// Polling interval in milliseconds.
    /// NOTE(review): not read anywhere in this file — presumably consumed by
    /// the queue as a polling fallback; confirm.
    pub poll_interval_ms: u64,
    /// Seconds between runs of the completed-job cleanup task.
    pub cleanup_interval_seconds: u64,
}
25
26impl Default for WorkerConfig {
27    fn default() -> Self {
28        Self {
29            storage_path: "./aurora_workers".to_string(),
30            concurrency: 4,
31            poll_interval_ms: 10,           // Faster polling fallback
32            cleanup_interval_seconds: 3600, // 1 hour
33        }
34    }
35}
36
/// Worker executor that processes jobs.
///
/// Owns the handler registry and the background tasks (workers, cleanup,
/// reaper) that operate on a shared [`JobQueue`].
pub struct WorkerExecutor {
    /// Queue that jobs are dequeued from and written back to.
    queue: Arc<JobQueue>,
    /// Registered handlers, keyed by job type.
    handlers: Arc<RwLock<HashMap<String, JobHandler>>>,
    /// Settings captured at construction time.
    config: WorkerConfig,
    /// `true` between a successful `start` and the next `stop`; shared with
    /// worker tasks, which check it before each dequeue.
    running: Arc<RwLock<bool>>,
    /// Join handles of every spawned background task, drained by `stop`.
    worker_handles: Arc<RwLock<Vec<JoinHandle<()>>>>,
    /// Sender half of the shutdown broadcast; `None` while stopped.
    shutdown_tx: Option<broadcast::Sender<()>>,
}
46
47impl WorkerExecutor {
48    pub fn new(queue: Arc<JobQueue>, config: WorkerConfig) -> Self {
49        Self {
50            queue,
51            handlers: Arc::new(RwLock::new(HashMap::new())),
52            config,
53            running: Arc::new(RwLock::new(false)),
54            worker_handles: Arc::new(RwLock::new(Vec::new())),
55            shutdown_tx: None,
56        }
57    }
58
59    /// Register a job handler
60    pub async fn register_handler<F, Fut>(&self, job_type: impl Into<String>, handler: F)
61    where
62        F: Fn(Job) -> Fut + Send + Sync + 'static,
63        Fut: Future<Output = Result<()>> + Send + 'static,
64    {
65        let handler = Arc::new(
66            move |job: Job| -> Pin<Box<dyn Future<Output = Result<()>> + Send>> {
67                Box::pin(handler(job))
68            },
69        );
70
71        self.handlers.write().await.insert(job_type.into(), handler);
72    }
73
74    /// Start the worker executor
75    pub async fn start(&mut self) -> Result<()> {
76        let mut running = self.running.write().await;
77        if *running {
78            return Ok(());
79        }
80        *running = true;
81        drop(running);
82
83        let (tx, _) = broadcast::channel(1);
84        self.shutdown_tx = Some(tx.clone());
85
86        // Spawn worker tasks
87        let mut handles = self.worker_handles.write().await;
88        for worker_id in 0..self.config.concurrency {
89            let handle = self.spawn_worker(worker_id, tx.subscribe());
90            handles.push(handle);
91        }
92
93        // Spawn cleanup task
94        let cleanup_handle = self.spawn_cleanup_task(tx.subscribe());
95        handles.push(cleanup_handle);
96
97        // Spawn reaper task for zombie job recovery
98        let reaper_handle = self.spawn_reaper(tx.subscribe());
99        handles.push(reaper_handle);
100
101        Ok(())
102    }
103
104    /// Stop the worker executor
105    pub async fn stop(&mut self) -> Result<()> {
106        let mut running = self.running.write().await;
107        if !*running {
108            return Ok(());
109        }
110        *running = false;
111        drop(running);
112
113        // Send shutdown signal to all tasks
114        if let Some(tx) = self.shutdown_tx.take() {
115            let _ = tx.send(());
116        }
117
118        // Also notify queue to wake up any workers stuck in dequeue
119        self.queue.shutdown().await;
120
121        // Wait for all worker tasks to finish with a timeout
122        let mut handles = self.worker_handles.write().await;
123        for handle in handles.drain(..) {
124            // We give each task a small window to exit gracefully
125            let _ = timeout(Duration::from_millis(500), handle).await;
126        }
127
128        Ok(())
129    }
130
131    /// Spawn a worker task
132    fn spawn_worker(
133        &self,
134        worker_id: usize,
135        mut shutdown_rx: broadcast::Receiver<()>,
136    ) -> JoinHandle<()> {
137        let queue = Arc::clone(&self.queue);
138        let handlers = Arc::clone(&self.handlers);
139        let running = Arc::clone(&self.running);
140
141        tokio::spawn(async move {
142            loop {
143                // Check if we should stop
144                if !*running.read().await {
145                    break;
146                }
147
148                // HIGH-PERFORMANCE DEQUEUE (Channel-based, O(1))
149                // Use select to allow immediate interruption during dequeue
150                let job_opt = tokio::select! {
151                    res = queue.dequeue() => {
152                        match res {
153                            Ok(Some(job)) => Some(job),
154                            Ok(None) => return, // Channel closed
155                            Err(e) => {
156                                eprintln!("[Worker {}] Dequeue Error: {}", worker_id, e);
157                                tokio::time::sleep(Duration::from_millis(100)).await;
158                                None
159                            }
160                        }
161                    }
162                    _ = shutdown_rx.recv() => break,
163                };
164
165                if let Some(mut job) = job_opt {
166                    // Get handler
167                    let handlers_guard = handlers.read().await;
168                    let handler = handlers_guard.get(&job.job_type).cloned();
169                    drop(handlers_guard);
170
171                    if let Some(handler) = handler {
172                        let result = if let Some(timeout_secs) = job.timeout_seconds {
173                            timeout(Duration::from_secs(timeout_secs), handler(job.clone())).await
174                        } else {
175                            Ok(handler(job.clone()).await)
176                        };
177
178                        match result {
179                            Ok(Ok(())) => {
180                                job.mark_completed();
181                            }
182                            Ok(Err(e)) => {
183                                job.mark_failed(e.to_string());
184                            }
185                            Err(_) => {
186                                job.mark_failed("Timeout".to_string());
187                            }
188                        }
189
190                        let job_id = job.id.clone();
191                        let _ = queue.update_job(&job_id, job).await;
192                    } else {
193                        job.mark_failed("No handler registered".to_string());
194                        let job_id = job.id.clone();
195                        let _ = queue.update_job(&job_id, job).await;
196                    }
197                }
198            }
199        })
200    }
201
202    /// Spawn cleanup task
203    fn spawn_cleanup_task(&self, mut shutdown_rx: broadcast::Receiver<()>) -> JoinHandle<()> {
204        let queue = Arc::clone(&self.queue);
205        let cleanup_interval = self.config.cleanup_interval_seconds;
206
207        tokio::spawn(async move {
208            let mut interval = interval(Duration::from_secs(cleanup_interval));
209            loop {
210                tokio::select! {
211                    _ = interval.tick() => {
212                        let _ = queue.cleanup_completed().await;
213                    }
214                    _ = shutdown_rx.recv() => break,
215                }
216            }
217        })
218    }
219
220    /// Spawn the Reaper task for zombie job recovery
221    fn spawn_reaper(&self, mut shutdown_rx: broadcast::Receiver<()>) -> JoinHandle<()> {
222        let queue = Arc::clone(&self.queue);
223
224        tokio::spawn(async move {
225            let mut interval = interval(Duration::from_secs(60));
226            loop {
227                tokio::select! {
228                    _ = interval.tick() => {
229                        let zombies = queue.find_zombie_jobs().await;
230                        for job_id in zombies {
231                            if let Ok(Some(mut job)) = queue.get(&job_id).await {
232                                job.status = super::job::JobStatus::Pending;
233                                job.retry_count += 1;
234                                job.touch();
235                                let _ = queue.update_job(&job_id, job).await;
236                            }
237                        }
238                    }
239                    _ = shutdown_rx.recv() => break,
240                }
241            }
242        })
243    }
244}