// moduforge_runtime/async_processor.rs

1use std::{
2    fmt::Display,
3    sync::Arc,
4    time::{Duration, Instant},
5};
6use moduforge_state::debug;
7use tokio::sync::{mpsc, oneshot};
8use async_trait::async_trait;
9use tokio::select;
10
/// Lifecycle state of a task handled by the processor.
///
/// `Eq` is derived in addition to `PartialEq` because every variant payload
/// (`String`) supports total equality; this lets the status be used wherever
/// an `Eq` bound is required.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum TaskStatus {
    /// The task is waiting to be processed.
    Pending,
    /// The task is currently being processed.
    Processing,
    /// The task finished successfully.
    Completed,
    /// The task failed; the payload is the error message.
    Failed(String),
    /// The task exceeded its execution time limit.
    Timeout,
    /// The task was cancelled.
    Cancelled,
}
27
/// Errors reported by the task processor.
///
/// `Clone`, `PartialEq` and `Eq` are derived so callers and tests can copy
/// and compare errors directly (for example with `assert_eq!`).
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ProcessorError {
    /// The task queue is full.
    QueueFull,
    /// The task failed while executing; the payload is the failure message.
    TaskFailed(String),
    /// An internal error occurred; the payload describes it.
    InternalError(String),
    /// The task exceeded its execution time limit.
    TaskTimeout,
    /// The task was cancelled.
    TaskCancelled,
    /// All retry attempts were used up; the payload describes the last failure.
    RetryExhausted(String),
}
44
45impl Display for ProcessorError {
46    fn fmt(
47        &self,
48        f: &mut std::fmt::Formatter<'_>,
49    ) -> std::fmt::Result {
50        match self {
51            ProcessorError::QueueFull => write!(f, "Task queue is full"),
52            ProcessorError::TaskFailed(msg) => {
53                write!(f, "Task failed: {}", msg)
54            },
55            ProcessorError::InternalError(msg) => {
56                write!(f, "Internal error: {}", msg)
57            },
58            ProcessorError::TaskTimeout => {
59                write!(f, "Task execution timed out")
60            },
61            ProcessorError::TaskCancelled => write!(f, "Task was cancelled"),
62            ProcessorError::RetryExhausted(msg) => {
63                write!(f, "Retry attempts exhausted: {}", msg)
64            },
65        }
66    }
67}
68
// Marks `ProcessorError` as a standard error type; the default trait methods are used.
impl std::error::Error for ProcessorError {}
70
/// Configuration parameters for the task processor.
#[derive(Clone, Debug)]
pub struct ProcessorConfig {
    /// Maximum capacity of the task queue.
    pub max_queue_size: usize,
    /// Maximum number of tasks processed concurrently.
    pub max_concurrent_tasks: usize,
    /// Maximum execution time allowed for a single task.
    pub task_timeout: Duration,
    /// Maximum number of retries.
    pub max_retries: u32,
    /// Delay applied before a retry.
    pub retry_delay: Duration,
}
85
86impl Default for ProcessorConfig {
87    fn default() -> Self {
88        Self {
89            max_queue_size: 1000,
90            max_concurrent_tasks: 10,
91            task_timeout: Duration::from_secs(30),
92            max_retries: 3,
93            retry_delay: Duration::from_secs(1),
94        }
95    }
96}
97
/// Statistics collected by the task processor.
#[derive(Debug, Default, Clone)]
pub struct ProcessorStats {
    /// Total number of tasks accepted.
    pub total_tasks: u64,
    /// Number of tasks that completed successfully.
    pub completed_tasks: u64,
    /// Number of tasks that failed.
    pub failed_tasks: u64,
    /// Number of tasks that timed out.
    pub timeout_tasks: u64,
    /// Number of tasks that were cancelled.
    pub cancelled_tasks: u64,
    /// Average processing time, maintained by `TaskQueue::update_stats`.
    pub average_processing_time: Duration,
    /// Number of tasks currently waiting in the queue.
    pub current_queue_size: usize,
    /// Number of tasks currently being processed.
    pub current_processing_tasks: usize,
}
118
/// Outcome of processing a single task.
#[derive(Debug)]
pub struct TaskResult<T, O>
where
    T: Send + Sync,
    O: Send + Sync,
{
    /// Unique identifier of the task.
    pub task_id: u64,
    /// Final status of the task.
    pub status: TaskStatus,
    /// The original task payload.
    pub task: Option<T>,
    /// Output produced by processing, if any.
    pub output: Option<O>,
    /// Error message, if any.
    pub error: Option<String>,
    /// Time spent processing the task.
    pub processing_time: Option<Duration>,
}
139
/// A task as stored in the queue.
struct QueuedTask<T, O>
where
    T: Send + Sync,
    O: Send + Sync,
{
    /// The actual task payload.
    task: T,
    /// Unique identifier of the task.
    task_id: u64,
    /// Sending half of the channel used to deliver the processing result.
    result_tx: mpsc::Sender<TaskResult<T, O>>,
    /// Task priority.
    // NOTE(review): the queue in this file is a plain mpsc channel, so the
    // priority does not appear to influence ordering — confirm intent.
    priority: u32,
    /// Number of retries already performed (set to 0 on enqueue).
    retry_count: u32,
}
157
/// Task queue backed by a bounded mpsc channel, plus shared bookkeeping.
pub struct TaskQueue<T, O>
where
    T: Send + Sync,
    O: Send + Sync,
{
    /// Sending half of the task channel.
    queue: mpsc::Sender<QueuedTask<T, O>>,
    /// Receiving half of the task channel, wrapped in `Arc<Mutex<Option<_>>>`
    /// so it can be accessed through a shared reference.
    queue_rx: Arc<tokio::sync::Mutex<Option<mpsc::Receiver<QueuedTask<T, O>>>>>,
    /// Last task id handed out; incremented under the mutex for each enqueue.
    next_task_id: Arc<tokio::sync::Mutex<u64>>,
    /// Processor statistics.
    stats: Arc<tokio::sync::Mutex<ProcessorStats>>,
}
173
174impl<T: Clone + Send + Sync + 'static, O: Clone + Send + Sync + 'static>
175    TaskQueue<T, O>
176{
177    pub fn new(config: &ProcessorConfig) -> Self {
178        let (tx, rx) = mpsc::channel(config.max_queue_size);
179        Self {
180            queue: tx,
181            queue_rx: Arc::new(tokio::sync::Mutex::new(Some(rx))),
182            next_task_id: Arc::new(tokio::sync::Mutex::new(0)),
183            stats: Arc::new(tokio::sync::Mutex::new(ProcessorStats::default())),
184        }
185    }
186
187    pub async fn enqueue_task(
188        &self,
189        task: T,
190        priority: u32,
191    ) -> Result<(u64, mpsc::Receiver<TaskResult<T, O>>), ProcessorError> {
192        let mut task_id = self.next_task_id.lock().await;
193        *task_id += 1;
194        let current_id = *task_id;
195
196        let (result_tx, result_rx) = mpsc::channel(1);
197        let queued_task = QueuedTask {
198            task,
199            task_id: current_id,
200            result_tx,
201            priority,
202            retry_count: 0,
203        };
204
205        self.queue
206            .send(queued_task)
207            .await
208            .map_err(|_| ProcessorError::QueueFull)?;
209
210        let mut stats = self.stats.lock().await;
211        stats.total_tasks += 1;
212        stats.current_queue_size += 1;
213
214        Ok((current_id, result_rx))
215    }
216
217    pub async fn get_next_ready(
218        &self
219    ) -> Option<(T, u64, mpsc::Sender<TaskResult<T, O>>, u32, u32)> {
220        let mut rx_guard = self.queue_rx.lock().await;
221        if let Some(rx) = rx_guard.as_mut() {
222            if let Some(queued) = rx.recv().await {
223                let mut stats = self.stats.lock().await;
224                stats.current_queue_size -= 1;
225                stats.current_processing_tasks += 1;
226                return Some((
227                    queued.task,
228                    queued.task_id,
229                    queued.result_tx,
230                    queued.priority,
231                    queued.retry_count,
232                ));
233            }
234        }
235        None
236    }
237
238    pub async fn get_stats(&self) -> ProcessorStats {
239        self.stats.lock().await.clone()
240    }
241
242    pub async fn update_stats(
243        &self,
244        result: &TaskResult<T, O>,
245    ) {
246        let mut stats = self.stats.lock().await;
247        match result.status {
248            TaskStatus::Completed => {
249                stats.completed_tasks += 1;
250                if let Some(processing_time) = result.processing_time {
251                    stats.average_processing_time =
252                        (stats.average_processing_time + processing_time) / 2;
253                }
254            },
255            TaskStatus::Failed(_) => stats.failed_tasks += 1,
256            TaskStatus::Timeout => stats.timeout_tasks += 1,
257            TaskStatus::Cancelled => stats.cancelled_tasks += 1,
258            _ => {},
259        }
260        stats.current_processing_tasks -= 1;
261    }
262}
263
/// Task processor trait.
///
/// Defines the basic interface for processing a task of type `T` into an
/// output of type `O`.
#[async_trait]
pub trait TaskProcessor<T, O>: Send + Sync + 'static
where
    T: Clone + Send + Sync + 'static,
    O: Clone + Send + Sync + 'static,
{
    /// Processes one task, returning its output or a [`ProcessorError`].
    async fn process(
        &self,
        task: T,
    ) -> Result<O, ProcessorError>;
}
277
/// Asynchronous task processor.
///
/// Manages the task queue, concurrent processing, and the task lifecycle.
/// - `T`: task type
/// - `O`: task output type
/// - `P`: task processor implementation
pub struct AsyncProcessor<T, O, P>
where
    T: Clone + Send + Sync + 'static,
    O: Clone + Send + Sync + 'static,
    P: TaskProcessor<T, O>,
{
    /// Shared queue of submitted tasks.
    task_queue: Arc<TaskQueue<T, O>>,
    /// Configuration the processor was created with.
    config: ProcessorConfig,
    /// User-supplied processing logic.
    processor: Arc<P>,
    /// Sender used to signal the background loop to stop; set by `start`.
    shutdown_tx: Option<oneshot::Sender<()>>,
    /// Join handle of the background loop; set by `start`.
    handle: Option<tokio::task::JoinHandle<()>>,
}
295
296impl<T, O, P> AsyncProcessor<T, O, P>
297where
298    T: Clone + Send + Sync + 'static,
299    O: Clone + Send + Sync + 'static,
300    P: TaskProcessor<T, O>,
301{
302    /// 创建新的异步任务处理器
303    pub fn new(
304        config: ProcessorConfig,
305        processor: P,
306    ) -> Self {
307        let task_queue = Arc::new(TaskQueue::new(&config));
308        Self {
309            task_queue,
310            config,
311            processor: Arc::new(processor),
312            shutdown_tx: None,
313            handle: None,
314        }
315    }
316
317    /// 提交新任务到处理器
318    /// 返回任务ID和用于接收处理结果的通道
319    pub async fn submit_task(
320        &self,
321        task: T,
322        priority: u32,
323    ) -> Result<(u64, mpsc::Receiver<TaskResult<T, O>>), ProcessorError> {
324        self.task_queue.enqueue_task(task, priority).await
325    }
326
327    /// 启动任务处理器
328    /// 创建后台任务来处理队列中的任务
329    pub fn start(&mut self) {
330        let queue = self.task_queue.clone();
331        let processor = self.processor.clone();
332        let config = self.config.clone();
333        let (shutdown_tx, mut shutdown_rx) = oneshot::channel();
334
335        self.shutdown_tx = Some(shutdown_tx);
336
337        let handle = tokio::spawn(async move {
338            let mut join_set = tokio::task::JoinSet::new();
339
340            loop {
341                select! {
342                    // 处理关闭信号
343                    _ = &mut shutdown_rx => {
344                        break;
345                    }
346
347                    // 处理任务完成
348                    Some(result) = join_set.join_next() => {
349                        if let Err(e) = result {
350                            debug!("Task failed: {}", e);
351                        }
352                    }
353
354                    // 获取新任务并处理
355                    Some((task, task_id, result_tx, _priority, retry_count)) = queue.get_next_ready() => {
356                        if join_set.len() < config.max_concurrent_tasks {
357                            let processor = processor.clone();
358                            let config = config.clone();
359                            let queue = queue.clone();
360
361                            join_set.spawn(async move {
362                                let start_time = Instant::now();
363                                let mut current_retry = retry_count;
364
365                                loop {
366                                    let result = tokio::time::timeout(
367                                        config.task_timeout,
368                                        processor.process(task.clone())
369                                    ).await;
370
371                                    match result {
372                                        Ok(Ok(output)) => {
373                                            let processing_time = start_time.elapsed();
374                                            let task_result = TaskResult {
375                                                task_id,
376                                                status: TaskStatus::Completed,
377                                                task: Some(task),
378                                                output: Some(output),
379                                                error: None,
380                                                processing_time: Some(processing_time),
381                                            };
382                                            queue.update_stats(&task_result).await;
383                                            let _ = result_tx.send(task_result).await;
384                                            break;
385                                        }
386                                        Ok(Err(e)) => {
387                                            if current_retry < config.max_retries {
388                                                current_retry += 1;
389                                                tokio::time::sleep(config.retry_delay).await;
390                                                continue;
391                                            }
392                                            let task_result = TaskResult {
393                                                task_id,
394                                                status: TaskStatus::Failed(e.to_string()),
395                                                task: Some(task),
396                                                output: None,
397                                                error: Some(e.to_string()),
398                                                processing_time: Some(start_time.elapsed()),
399                                            };
400                                            queue.update_stats(&task_result).await;
401                                            let _ = result_tx.send(task_result).await;
402                                            break;
403                                        }
404                                        Err(_) => {
405                                            let task_result = TaskResult {
406                                                task_id,
407                                                status: TaskStatus::Timeout,
408                                                task: Some(task),
409                                                output: None,
410                                                error: Some("Task execution timed out".to_string()),
411                                                processing_time: Some(start_time.elapsed()),
412                                            };
413                                            queue.update_stats(&task_result).await;
414                                            let _ = result_tx.send(task_result).await;
415                                            break;
416                                        }
417                                    }
418                                }
419                            });
420                        }
421                    }
422                }
423            }
424        });
425
426        self.handle = Some(handle);
427    }
428
429    /// 优雅地关闭处理器
430    /// 等待所有正在处理的任务完成后再关闭
431    pub async fn shutdown(&mut self) -> Result<(), ProcessorError> {
432        if let Some(shutdown_tx) = self.shutdown_tx.take() {
433            shutdown_tx.send(()).map_err(|_| {
434                ProcessorError::InternalError(
435                    "Failed to send shutdown signal".to_string(),
436                )
437            })?;
438
439            if let Some(handle) = self.handle.take() {
440                handle.await.map_err(|e| {
441                    ProcessorError::InternalError(format!(
442                        "Failed to join processor task: {}",
443                        e
444                    ))
445                })?;
446            }
447        }
448        Ok(())
449    }
450
451    pub async fn get_stats(&self) -> ProcessorStats {
452        self.task_queue.get_stats().await
453    }
454}
455
456/// 实现Drop特征,确保处理器在销毁时能够优雅关闭
457impl<T, O, P> Drop for AsyncProcessor<T, O, P>
458where
459    T: Clone + Send + Sync + 'static,
460    O: Clone + Send + Sync + 'static,
461    P: TaskProcessor<T, O>,
462{
463    fn drop(&mut self) {
464        if self.shutdown_tx.is_some() {
465            // 创建一个新的运行时来处理异步关闭
466            let rt = tokio::runtime::Runtime::new().unwrap();
467            rt.block_on(self.shutdown()).unwrap();
468        }
469    }
470}
471
#[cfg(test)]
mod tests {
    use super::*;

    /// Processor stub that sleeps briefly and echoes the task number.
    struct TestProcessor;

    #[async_trait::async_trait]
    impl TaskProcessor<i32, String> for TestProcessor {
        async fn process(
            &self,
            task: i32,
        ) -> Result<String, ProcessorError> {
            let simulated_work = Duration::from_millis(100);
            tokio::time::sleep(simulated_work).await;
            Ok(format!("Processed: {}", task))
        }
    }

    /// Configuration shared by every test in this module.
    fn test_config() -> ProcessorConfig {
        ProcessorConfig {
            max_queue_size: 100,
            max_concurrent_tasks: 5,
            task_timeout: Duration::from_secs(1),
            max_retries: 3,
            retry_delay: Duration::from_secs(1),
        }
    }

    /// Ten submitted tasks must all report a successful result.
    #[tokio::test]
    async fn test_async_processor() {
        let mut processor = AsyncProcessor::new(test_config(), TestProcessor);
        processor.start();

        let mut pending = Vec::new();
        for task in 0..10 {
            let (_, receiver) = processor.submit_task(task, 0).await.unwrap();
            pending.push(receiver);
        }

        for mut receiver in pending {
            let outcome = receiver.recv().await.unwrap();
            assert_eq!(outcome.status, TaskStatus::Completed);
            assert!(outcome.error.is_none());
            assert!(outcome.output.is_some());
        }
    }

    /// Tasks submitted before an explicit shutdown must still complete.
    #[tokio::test]
    async fn test_processor_shutdown() {
        let mut processor = AsyncProcessor::new(test_config(), TestProcessor);
        processor.start();

        // Submit some tasks
        let mut pending = Vec::new();
        for task in 0..5 {
            let (_, receiver) = processor.submit_task(task, 0).await.unwrap();
            pending.push(receiver);
        }

        // Initiate shutdown
        processor.shutdown().await.unwrap();

        // Verify all tasks completed
        for mut receiver in pending {
            let outcome = receiver.recv().await.unwrap();
            assert_eq!(outcome.status, TaskStatus::Completed);
        }
    }

    /// Tasks submitted before the processor is dropped must still complete.
    #[tokio::test]
    async fn test_processor_auto_shutdown() {
        let mut processor = AsyncProcessor::new(test_config(), TestProcessor);
        processor.start();

        // Submit some tasks
        let mut pending = Vec::new();
        for task in 0..5 {
            let (_, receiver) = processor.submit_task(task, 0).await.unwrap();
            pending.push(receiver);
        }

        // Drop the processor, which should trigger shutdown
        drop(processor);

        // Verify all tasks completed
        for mut receiver in pending {
            let outcome = receiver.recv().await.unwrap();
            assert_eq!(outcome.status, TaskStatus::Completed);
        }
    }
}