aurora_db/workers/
executor.rs1use super::job::Job;
2use super::queue::JobQueue;
3use crate::error::Result;
4use std::collections::HashMap;
5use std::future::Future;
6use std::pin::Pin;
7use std::sync::Arc;
8use std::time::Duration;
9use tokio::sync::RwLock;
10use tokio::task::JoinHandle;
11use tokio::time::{interval, timeout};
12
13pub type JobHandler =
15 Arc<dyn Fn(Job) -> Pin<Box<dyn Future<Output = Result<()>> + Send>> + Send + Sync>;
16
/// Tuning knobs for a [`WorkerExecutor`].
///
/// Derives `Debug`, `PartialEq` and `Eq` in addition to `Clone` so that
/// configurations can be logged and compared.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct WorkerConfig {
    /// Location of the queue's storage. The executor in this file never reads
    /// it; callers pass it to `JobQueue::new` when building the queue.
    pub storage_path: String,
    /// Number of worker tasks spawned by [`WorkerExecutor::start`].
    pub concurrency: usize,
    /// Queue polling interval, in milliseconds.
    pub poll_interval_ms: u64,
    /// Period, in seconds, between runs of the completed-job cleanup task.
    pub cleanup_interval_seconds: u64,
}
25
26impl Default for WorkerConfig {
27 fn default() -> Self {
28 Self {
29 storage_path: "./aurora_workers".to_string(),
30 concurrency: 4,
31 poll_interval_ms: 100,
32 cleanup_interval_seconds: 3600, }
34 }
35}
36
37pub struct WorkerExecutor {
39 queue: Arc<JobQueue>,
40 handlers: Arc<RwLock<HashMap<String, JobHandler>>>,
41 config: WorkerConfig,
42 running: Arc<RwLock<bool>>,
43 worker_handles: Arc<RwLock<Vec<JoinHandle<()>>>>,
44}
45
46impl WorkerExecutor {
47 pub fn new(queue: Arc<JobQueue>, config: WorkerConfig) -> Self {
48 Self {
49 queue,
50 handlers: Arc::new(RwLock::new(HashMap::new())),
51 config,
52 running: Arc::new(RwLock::new(false)),
53 worker_handles: Arc::new(RwLock::new(Vec::new())),
54 }
55 }
56
57 pub async fn register_handler<F, Fut>(&self, job_type: impl Into<String>, handler: F)
59 where
60 F: Fn(Job) -> Fut + Send + Sync + 'static,
61 Fut: Future<Output = Result<()>> + Send + 'static,
62 {
63 let handler = Arc::new(
64 move |job: Job| -> Pin<Box<dyn Future<Output = Result<()>> + Send>> {
65 Box::pin(handler(job))
66 },
67 );
68
69 self.handlers.write().await.insert(job_type.into(), handler);
70 }
71
72 pub async fn start(&self) -> Result<()> {
74 let mut running = self.running.write().await;
75 if *running {
76 return Ok(());
77 }
78 *running = true;
79 drop(running);
80
81 let mut handles = self.worker_handles.write().await;
83 for worker_id in 0..self.config.concurrency {
84 let handle = self.spawn_worker(worker_id);
85 handles.push(handle);
86 }
87
88 let cleanup_handle = self.spawn_cleanup_task();
90 handles.push(cleanup_handle);
91
92 let reaper_handle = self.spawn_reaper();
94 handles.push(reaper_handle);
95
96 Ok(())
97 }
98
99 pub async fn stop(&self) -> Result<()> {
101 let mut running = self.running.write().await;
102 *running = false;
103 drop(running);
104
105 self.queue.notify_all();
107
108 let mut handles = self.worker_handles.write().await;
110 for handle in handles.drain(..) {
111 if let Err(e) = handle.await {
112 eprintln!("Worker panic during shutdown: {:?}", e);
113 }
114 }
115
116 Ok(())
117 }
118
119 fn spawn_worker(&self, worker_id: usize) -> JoinHandle<()> {
121 let queue = Arc::clone(&self.queue);
122 let handlers = Arc::clone(&self.handlers);
123 let running = Arc::clone(&self.running);
124
125 tokio::spawn(async move {
126 loop {
127 if !*running.read().await {
129 break;
130 }
131
132 match queue.dequeue().await {
134 Ok(Some(mut job)) => {
135 println!(
136 "[Worker {}] Processing job: {} ({})",
137 worker_id, job.id, job.job_type
138 );
139
140 let handlers = handlers.read().await;
142 let handler = handlers.get(&job.job_type);
143
144 if let Some(handler) = handler {
145 let handler = Arc::clone(handler);
146 drop(handlers);
147
148 let job_id_for_heartbeat = job.id.clone();
150 let queue_for_heartbeat = Arc::clone(&queue);
151 let mut heartbeat_job = job.clone();
152
153 let heartbeat_interval = Duration::from_secs(15);
155 let mut heartbeat_tick = interval(heartbeat_interval);
156 heartbeat_tick
157 .set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
158
159 let job_future = async {
161 if let Some(timeout_secs) = job.timeout_seconds {
162 timeout(Duration::from_secs(timeout_secs), handler(job.clone()))
163 .await
164 } else {
165 Ok(handler(job.clone()).await)
166 }
167 };
168
169 let result = {
171 tokio::pin!(job_future);
172
173 loop {
174 tokio::select! {
175 biased;
176
177 result = &mut job_future => {
179 break result;
180 }
181
182 _ = heartbeat_tick.tick() => {
184 heartbeat_job.touch();
185 let _ = queue_for_heartbeat
186 .update_job(&job_id_for_heartbeat, heartbeat_job.clone())
187 .await;
188 }
189 }
190 }
191 };
192
193 match result {
194 Ok(Ok(())) => {
195 job.mark_completed();
196 }
197 Ok(Err(e)) => {
198 job.mark_failed(e.to_string());
199 }
200 Err(_) => {
201 job.mark_failed("Timeout".to_string());
202 }
203 }
204
205 let job_id = job.id.clone();
207 let _ = queue.update_job(&job_id, job).await;
208 } else {
209 let job_type = job.job_type.clone();
210 job.mark_failed("No handler registered".to_string());
211 let job_id = job.id.clone();
212 let _ = queue.update_job(&job_id, job).await;
213 println!(
214 "[Worker {}] No handler for job type: {}",
215 worker_id, job_type
216 );
217 }
218 }
219 Ok(None) => {
220 queue.notified().await;
222 }
223 Err(e) => {
224 eprintln!("[Worker {}] Error dequeuing job: {}", worker_id, e);
225 }
226 }
227 }
228
229 println!("[Worker {}] Stopped", worker_id);
230 })
231 }
232
233 fn spawn_cleanup_task(&self) -> JoinHandle<()> {
235 let queue = Arc::clone(&self.queue);
236 let running = Arc::clone(&self.running);
237 let cleanup_interval = self.config.cleanup_interval_seconds;
238
239 tokio::spawn(async move {
240 let mut interval = interval(Duration::from_secs(cleanup_interval));
241
242 loop {
243 interval.tick().await;
244
245 if !*running.read().await {
246 break;
247 }
248
249 match queue.cleanup_completed().await {
250 Ok(count) => {
251 if count > 0 {
252 println!("[Cleanup] Removed {} completed jobs", count);
253 }
254 }
255 Err(e) => {
256 eprintln!("[Cleanup] Error: {}", e);
257 }
258 }
259 }
260
261 println!("[Cleanup] Stopped");
262 })
263 }
264
265 fn spawn_reaper(&self) -> JoinHandle<()> {
272 let queue = Arc::clone(&self.queue);
273 let running = Arc::clone(&self.running);
274
275 tokio::spawn(async move {
276 let mut interval = interval(Duration::from_secs(60));
278
279 loop {
280 interval.tick().await;
281
282 if !*running.read().await {
283 break;
284 }
285
286 let zombies = queue.find_zombie_jobs().await;
288
289 for job_id in zombies {
291 if let Ok(Some(mut job)) = queue.get(&job_id).await {
292 job.status = super::job::JobStatus::Pending;
294 job.retry_count += 1;
295 job.touch(); let _ = queue.update_job(&job_id, job).await;
299
300 queue.notify_all();
302 }
303 }
304 }
305
306 println!("[Reaper] Stopped");
307 })
308 }
309}
310
311#[cfg(test)]
312mod tests {
313 use super::*;
314 use crate::workers::job::{Job, JobStatus};
315 use tempfile::TempDir;
316 use tokio::time::sleep;
317
318 #[tokio::test]
319 async fn test_worker_execution() {
320 let temp_dir = TempDir::new().unwrap();
321 let config = WorkerConfig {
322 storage_path: temp_dir.path().to_str().unwrap().to_string(),
323 concurrency: 2,
324 poll_interval_ms: 50,
325 cleanup_interval_seconds: 10, };
327
328 let queue = Arc::new(JobQueue::new(config.storage_path.clone()).unwrap());
329 let executor = WorkerExecutor::new(Arc::clone(&queue), config);
330
331 executor
333 .register_handler("test", |_job| async { Ok(()) })
334 .await;
335
336 executor.start().await.unwrap();
338
339 let job = Job::new("test");
341 let job_id = queue.enqueue(job).await.unwrap();
342
343 sleep(Duration::from_millis(300)).await;
345
346 let status = queue.get_status(&job_id).await.unwrap();
348 assert!(matches!(status, Some(JobStatus::Completed) | None));
350
351 executor.stop().await.unwrap();
352 }
353
354 #[tokio::test]
355 async fn test_graceful_shutdown() {
356 let temp_dir = TempDir::new().unwrap();
357 let config = WorkerConfig {
358 storage_path: temp_dir.path().to_str().unwrap().to_string(),
359 concurrency: 1,
360 poll_interval_ms: 100,
361 cleanup_interval_seconds: 10,
362 };
363
364 let queue = Arc::new(JobQueue::new(config.storage_path.clone()).unwrap());
365 let executor = WorkerExecutor::new(Arc::clone(&queue), config);
366
367 executor
369 .register_handler("long_task", |_job| async {
370 tokio::time::sleep(Duration::from_secs(2)).await;
371 Ok(())
372 })
373 .await;
374
375 executor.start().await.unwrap();
376
377 let job = Job::new("long_task");
379 let job_id = queue.enqueue(job).await.unwrap();
380
381 tokio::time::sleep(Duration::from_millis(100)).await;
383
384 let status = queue.get_status(&job_id).await.unwrap();
386 assert_eq!(status, Some(JobStatus::Running));
387
388 let start = std::time::Instant::now();
390 executor.stop().await.unwrap();
391 let duration = start.elapsed();
392
393 assert!(
395 duration.as_millis() >= 1500,
396 "Shutdown was too fast ({:?}), didn't wait for job",
397 duration
398 );
399
400 let status = queue.get_status(&job_id).await.unwrap();
402 assert_eq!(status, Some(JobStatus::Completed));
403 }
404
405 #[tokio::test]
409 async fn test_is_heartbeat_expired() {
410 let mut job = Job::new("test");
412 job.lease_duration = 1; job.touch();
416 assert!(!job.is_heartbeat_expired());
417
418 tokio::time::sleep(Duration::from_secs(2)).await;
420 job.status = JobStatus::Running;
421 assert!(job.is_heartbeat_expired());
422
423 job.touch();
425 assert!(!job.is_heartbeat_expired());
426 }
427}