use std::{
    fmt::Display,
    sync::Arc,
    time::{Duration, Instant},
};

use async_trait::async_trait;
use moduforge_state::debug;
use tokio::select;
use tokio::sync::{mpsc, oneshot};

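/// Lifecycle states a task can move through inside the processor.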
#[derive(Debug, Clone, PartialEq)]
pub enum TaskStatus {
    Pending,
    Processing,
    Completed,
    Failed(String),
    Timeout,
    Cancelled,
}

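/// Errors that the queue and the processor can report to callers.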
#[derive(Debug)]
pub enum ProcessorError {
    QueueFull,
    TaskFailed(String),
    InternalError(String),
    TaskTimeout,
    TaskCancelled,
    RetryExhausted(String),
}

impl Display for ProcessorError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            ProcessorError::QueueFull => write!(f, "Task queue is full"),
            ProcessorError::TaskFailed(msg) => write!(f, "Task failed: {}", msg),
            ProcessorError::InternalError(msg) => write!(f, "Internal error: {}", msg),
            ProcessorError::TaskTimeout => write!(f, "Task execution timed out"),
            ProcessorError::TaskCancelled => write!(f, "Task was cancelled"),
            ProcessorError::RetryExhausted(msg) => {
                write!(f, "Retry attempts exhausted: {}", msg)
            },
        }
    }
}

impl std::error::Error for ProcessorError {}

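/// Tuning knobs for the processor: queue capacity, concurrency limit,
/// per-task timeout, and retry policy.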
#[derive(Clone, Debug)]
pub struct ProcessorConfig {
    pub max_queue_size: usize,
    pub max_concurrent_tasks: usize,
    pub task_timeout: Duration,
    pub max_retries: u32,
    pub retry_delay: Duration,
}

impl Default for ProcessorConfig {
    fn default() -> Self {
        Self {
            max_queue_size: 1000,
            max_concurrent_tasks: 10,
            task_timeout: Duration::from_secs(30),
            max_retries: 3,
            retry_delay: Duration::from_secs(1),
        }
    }
}

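/// Aggregate counters and gauges maintained by the task queue.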
#[derive(Debug, Default, Clone)]
pub struct ProcessorStats {
    pub total_tasks: u64,
    pub completed_tasks: u64,
    pub failed_tasks: u64,
    pub timeout_tasks: u64,
    pub cancelled_tasks: u64,
    pub average_processing_time: Duration,
    pub current_queue_size: usize,
    pub current_processing_tasks: usize,
}

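/// Outcome of a single task, delivered to the submitter over its result channel.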
#[derive(Debug)]
pub struct TaskResult<T, O>
where
    T: Send + Sync,
    O: Send + Sync,
{
    pub task_id: u64,
    pub status: TaskStatus,
    pub task: Option<T>,
    pub output: Option<O>,
    pub error: Option<String>,
    pub processing_time: Option<Duration>,
}

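/// Internal envelope stored in the queue: the task payload plus the channel
/// used to hand its result back to the submitter.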
struct QueuedTask<T, O>
where
    T: Send + Sync,
    O: Send + Sync,
{
    task: T,
    task_id: u64,
    result_tx: mpsc::Sender<TaskResult<T, O>>,
    priority: u32,
    retry_count: u32,
}

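/// Bounded queue of pending tasks plus shared statistics.
///
/// Note: a `priority` value is carried with each task, but the underlying
/// channel is FIFO, so ordering is currently first-in, first-out.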
pub struct TaskQueue<T, O>
where
    T: Send + Sync,
    O: Send + Sync,
{
    queue: mpsc::Sender<QueuedTask<T, O>>,
    queue_rx: Arc<tokio::sync::Mutex<Option<mpsc::Receiver<QueuedTask<T, O>>>>>,
    next_task_id: Arc<tokio::sync::Mutex<u64>>,
    stats: Arc<tokio::sync::Mutex<ProcessorStats>>,
}

impl<T: Clone + Send + Sync + 'static, O: Clone + Send + Sync + 'static>
    TaskQueue<T, O>
{
    pub fn new(config: &ProcessorConfig) -> Self {
        let (tx, rx) = mpsc::channel(config.max_queue_size);
        Self {
            queue: tx,
            queue_rx: Arc::new(tokio::sync::Mutex::new(Some(rx))),
            next_task_id: Arc::new(tokio::sync::Mutex::new(0)),
            stats: Arc::new(tokio::sync::Mutex::new(ProcessorStats::default())),
        }
    }

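    /// Assigns the next task id, pushes the task onto the queue, and returns
    /// the id together with a receiver on which the result will be delivered.
    /// Fails with `ProcessorError::QueueFull` when the bounded queue has no
    /// free capacity.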
    pub async fn enqueue_task(
        &self,
        task: T,
        priority: u32,
    ) -> Result<(u64, mpsc::Receiver<TaskResult<T, O>>), ProcessorError> {
        let mut task_id = self.next_task_id.lock().await;
        *task_id += 1;
        let current_id = *task_id;

        let (result_tx, result_rx) = mpsc::channel(1);
        let queued_task = QueuedTask {
            task,
            task_id: current_id,
            result_tx,
            priority,
            retry_count: 0,
        };

        // `try_send` rejects the task immediately when the bounded queue is
        // full instead of blocking the caller; a closed channel means the
        // processor is no longer consuming tasks.
        self.queue.try_send(queued_task).map_err(|e| match e {
            mpsc::error::TrySendError::Full(_) => ProcessorError::QueueFull,
            mpsc::error::TrySendError::Closed(_) => ProcessorError::InternalError(
                "Task queue receiver has been dropped".to_string(),
            ),
        })?;

        let mut stats = self.stats.lock().await;
        stats.total_tasks += 1;
        stats.current_queue_size += 1;

        Ok((current_id, result_rx))
    }

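    /// Pops the next queued task and updates the queue/processing gauges.
    /// Returns `None` once the queue has been closed and drained.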
    pub async fn get_next_ready(
        &self,
    ) -> Option<(T, u64, mpsc::Sender<TaskResult<T, O>>, u32, u32)> {
        let mut rx_guard = self.queue_rx.lock().await;
        if let Some(rx) = rx_guard.as_mut() {
            if let Some(queued) = rx.recv().await {
                let mut stats = self.stats.lock().await;
                stats.current_queue_size -= 1;
                stats.current_processing_tasks += 1;
                return Some((
                    queued.task,
                    queued.task_id,
                    queued.result_tx,
                    queued.priority,
                    queued.retry_count,
                ));
            }
        }
        None
    }

    pub async fn get_stats(&self) -> ProcessorStats {
        self.stats.lock().await.clone()
    }

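    /// Records a finished task in the shared statistics and releases its slot
    /// in the `current_processing_tasks` gauge.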
    pub async fn update_stats(
        &self,
        result: &TaskResult<T, O>,
    ) {
        let mut stats = self.stats.lock().await;
        match result.status {
            TaskStatus::Completed => {
                stats.completed_tasks += 1;
                if let Some(processing_time) = result.processing_time {
                    // Incremental mean over all completed tasks, rather than
                    // averaging the previous mean with the newest sample.
                    let n = stats.completed_tasks as u32;
                    stats.average_processing_time =
                        (stats.average_processing_time * (n - 1) + processing_time) / n;
                }
            },
            TaskStatus::Failed(_) => stats.failed_tasks += 1,
            TaskStatus::Timeout => stats.timeout_tasks += 1,
            TaskStatus::Cancelled => stats.cancelled_tasks += 1,
            _ => {},
        }
        stats.current_processing_tasks -= 1;
    }
}

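/// The user-supplied processing logic: implement this trait for the type that
/// knows how to turn a task `T` into an output `O`.
///
/// A minimal sketch (the `Doubler` type below is purely illustrative):
///
/// ```ignore
/// struct Doubler;
///
/// #[async_trait::async_trait]
/// impl TaskProcessor<i32, i32> for Doubler {
///     async fn process(&self, task: i32) -> Result<i32, ProcessorError> {
///         Ok(task * 2)
///     }
/// }
/// ```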
#[async_trait]
pub trait TaskProcessor<T, O>: Send + Sync + 'static
where
    T: Clone + Send + Sync + 'static,
    O: Clone + Send + Sync + 'static,
{
    async fn process(&self, task: T) -> Result<O, ProcessorError>;
}

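/// Drives a `TaskQueue` with a background dispatch loop: tasks are pulled off
/// the queue, run through the `TaskProcessor` with a timeout and retry policy,
/// and their results are sent back to the submitter.
///
/// Rough usage sketch (imports, crate path, and the `MyProcessor`/`my_task`
/// names are assumptions for illustration):
///
/// ```ignore
/// let mut processor = AsyncProcessor::new(ProcessorConfig::default(), MyProcessor);
/// processor.start();
/// let (_id, mut rx) = processor.submit_task(my_task, 0).await?;
/// let result = rx.recv().await;
/// processor.shutdown().await?;
/// ```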
pub struct AsyncProcessor<T, O, P>
where
    T: Clone + Send + Sync + 'static,
    O: Clone + Send + Sync + 'static,
    P: TaskProcessor<T, O>,
{
    task_queue: Arc<TaskQueue<T, O>>,
    config: ProcessorConfig,
    processor: Arc<P>,
    shutdown_tx: Option<oneshot::Sender<()>>,
    handle: Option<tokio::task::JoinHandle<()>>,
}

impl<T, O, P> AsyncProcessor<T, O, P>
where
    T: Clone + Send + Sync + 'static,
    O: Clone + Send + Sync + 'static,
    P: TaskProcessor<T, O>,
{
    pub fn new(config: ProcessorConfig, processor: P) -> Self {
        let task_queue = Arc::new(TaskQueue::new(&config));
        Self {
            task_queue,
            config,
            processor: Arc::new(processor),
            shutdown_tx: None,
            handle: None,
        }
    }

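    /// Enqueues a task and returns its id plus the receiver for its result.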
    pub async fn submit_task(
        &self,
        task: T,
        priority: u32,
    ) -> Result<(u64, mpsc::Receiver<TaskResult<T, O>>), ProcessorError> {
        self.task_queue.enqueue_task(task, priority).await
    }

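    /// Spawns the background dispatch loop. The loop pulls tasks from the
    /// queue, runs each one with the configured timeout and retry policy on
    /// the Tokio runtime, and stops once the shutdown signal arrives.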
    pub fn start(&mut self) {
        let queue = self.task_queue.clone();
        let processor = self.processor.clone();
        let config = self.config.clone();
        let (shutdown_tx, mut shutdown_rx) = oneshot::channel();

        self.shutdown_tx = Some(shutdown_tx);

        let handle = tokio::spawn(async move {
            let mut join_set = tokio::task::JoinSet::new();

            loop {
                select! {
                    _ = &mut shutdown_rx => {
                        break;
                    }

                    Some(result) = join_set.join_next() => {
                        if let Err(e) = result {
                            debug!("Task failed: {}", e);
                        }
                    }

                    // Only pull a new task when a concurrency slot is free;
                    // otherwise the dequeued task would be silently dropped.
                    Some((task, task_id, result_tx, _priority, retry_count)) = queue.get_next_ready(),
                        if join_set.len() < config.max_concurrent_tasks =>
                    {
                        let processor = processor.clone();
                        let config = config.clone();
                        let queue = queue.clone();

                        join_set.spawn(async move {
                            let start_time = Instant::now();
                            let mut current_retry = retry_count;

                            loop {
                                let result = tokio::time::timeout(
                                    config.task_timeout,
                                    processor.process(task.clone()),
                                )
                                .await;

                                match result {
                                    Ok(Ok(output)) => {
                                        let task_result = TaskResult {
                                            task_id,
                                            status: TaskStatus::Completed,
                                            task: Some(task),
                                            output: Some(output),
                                            error: None,
                                            processing_time: Some(start_time.elapsed()),
                                        };
                                        queue.update_stats(&task_result).await;
                                        let _ = result_tx.send(task_result).await;
                                        break;
                                    }
                                    Ok(Err(e)) => {
                                        // Retry with a fixed delay until the budget is spent.
                                        if current_retry < config.max_retries {
                                            current_retry += 1;
                                            tokio::time::sleep(config.retry_delay).await;
                                            continue;
                                        }
                                        let task_result = TaskResult {
                                            task_id,
                                            status: TaskStatus::Failed(e.to_string()),
                                            task: Some(task),
                                            output: None,
                                            error: Some(e.to_string()),
                                            processing_time: Some(start_time.elapsed()),
                                        };
                                        queue.update_stats(&task_result).await;
                                        let _ = result_tx.send(task_result).await;
                                        break;
                                    }
                                    Err(_) => {
                                        let task_result = TaskResult {
                                            task_id,
                                            status: TaskStatus::Timeout,
                                            task: Some(task),
                                            output: None,
                                            error: Some("Task execution timed out".to_string()),
                                            processing_time: Some(start_time.elapsed()),
                                        };
                                        queue.update_stats(&task_result).await;
                                        let _ = result_tx.send(task_result).await;
                                        break;
                                    }
                                }
                            }
                        });
                    }
                }
            }

            // Let tasks that are already running finish before the JoinSet is
            // dropped; dropping it directly would abort them.
            while join_set.join_next().await.is_some() {}
        });

        self.handle = Some(handle);
    }

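    /// Signals the dispatch loop to stop and waits for it, and any in-flight
    /// tasks, to finish.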
    pub async fn shutdown(&mut self) -> Result<(), ProcessorError> {
        if let Some(shutdown_tx) = self.shutdown_tx.take() {
            shutdown_tx.send(()).map_err(|_| {
                ProcessorError::InternalError(
                    "Failed to send shutdown signal".to_string(),
                )
            })?;

            if let Some(handle) = self.handle.take() {
                handle.await.map_err(|e| {
                    ProcessorError::InternalError(format!(
                        "Failed to join processor task: {}",
                        e
                    ))
                })?;
            }
        }
        Ok(())
    }

    pub async fn get_stats(&self) -> ProcessorStats {
        self.task_queue.get_stats().await
    }
}

impl<T, O, P> Drop for AsyncProcessor<T, O, P>
where
    T: Clone + Send + Sync + 'static,
    O: Clone + Send + Sync + 'static,
    P: TaskProcessor<T, O>,
{
    fn drop(&mut self) {
        // Blocking on a freshly created runtime here panics when the
        // processor is dropped from within an async context, so Drop only
        // signals shutdown; the background loop then winds down on its own.
        if let Some(shutdown_tx) = self.shutdown_tx.take() {
            let _ = shutdown_tx.send(());
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    struct TestProcessor;

    #[async_trait::async_trait]
    impl TaskProcessor<i32, String> for TestProcessor {
        async fn process(&self, task: i32) -> Result<String, ProcessorError> {
            tokio::time::sleep(Duration::from_millis(100)).await;
            Ok(format!("Processed: {}", task))
        }
    }

    #[tokio::test]
    async fn test_async_processor() {
        let config = ProcessorConfig {
            max_queue_size: 100,
            max_concurrent_tasks: 5,
            task_timeout: Duration::from_secs(1),
            max_retries: 3,
            retry_delay: Duration::from_secs(1),
        };
        let mut processor = AsyncProcessor::new(config, TestProcessor);
        processor.start();

        let mut receivers = Vec::new();
        for i in 0..10 {
            let (_, rx) = processor.submit_task(i, 0).await.unwrap();
            receivers.push(rx);
        }

        for mut rx in receivers {
            let result = rx.recv().await.unwrap();
            assert_eq!(result.status, TaskStatus::Completed);
            assert!(result.error.is_none());
            assert!(result.output.is_some());
        }
    }

    #[tokio::test]
    async fn test_processor_shutdown() {
        let config = ProcessorConfig {
            max_queue_size: 100,
            max_concurrent_tasks: 5,
            task_timeout: Duration::from_secs(1),
            max_retries: 3,
            retry_delay: Duration::from_secs(1),
        };
        let mut processor = AsyncProcessor::new(config, TestProcessor);
        processor.start();

        let mut receivers = Vec::new();
        for i in 0..5 {
            let (_, rx) = processor.submit_task(i, 0).await.unwrap();
            receivers.push(rx);
        }

        processor.shutdown().await.unwrap();

        for mut rx in receivers {
            let result = rx.recv().await.unwrap();
            assert_eq!(result.status, TaskStatus::Completed);
        }
    }

    #[tokio::test]
    async fn test_processor_auto_shutdown() {
        let config = ProcessorConfig {
            max_queue_size: 100,
            max_concurrent_tasks: 5,
            task_timeout: Duration::from_secs(1),
            max_retries: 3,
            retry_delay: Duration::from_secs(1),
        };
        let mut processor = AsyncProcessor::new(config, TestProcessor);
        processor.start();

        let mut receivers = Vec::new();
        for i in 0..5 {
            let (_, rx) = processor.submit_task(i, 0).await.unwrap();
            receivers.push(rx);
        }

        drop(processor);

        for mut rx in receivers {
            let result = rx.recv().await.unwrap();
            assert_eq!(result.status, TaskStatus::Completed);
        }
    }
}