use std::{
    fmt::Display,
    sync::Arc,
    time::{Duration, Instant},
};
use moduforge_state::debug;
use tokio::sync::{mpsc, oneshot};
use async_trait::async_trait;
use tokio::select;

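/// Lifecycle states a submitted task can be in.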
#[derive(Debug, Clone, PartialEq)]
pub enum TaskStatus {
    Pending,
    Processing,
    Completed,
    Failed(String),
    Timeout,
    Cancelled,
}

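/// Errors reported by the processor to task submitters.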
#[derive(Debug)]
pub enum ProcessorError {
    QueueFull,
    TaskFailed(String),
    InternalError(String),
    TaskTimeout,
    TaskCancelled,
    RetryExhausted(String),
}

impl Display for ProcessorError {
    fn fmt(
        &self,
        f: &mut std::fmt::Formatter<'_>,
    ) -> std::fmt::Result {
        match self {
            ProcessorError::QueueFull => write!(f, "task queue is full"),
            ProcessorError::TaskFailed(msg) => {
                write!(f, "task execution failed: {}", msg)
            },
            ProcessorError::InternalError(msg) => {
                write!(f, "internal error: {}", msg)
            },
            ProcessorError::TaskTimeout => {
                write!(f, "task execution timed out")
            },
            ProcessorError::TaskCancelled => write!(f, "task was cancelled"),
            ProcessorError::RetryExhausted(msg) => {
                write!(f, "retries exhausted: {}", msg)
            },
        }
    }
}

impl std::error::Error for ProcessorError {}

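/// Tuning parameters: queue capacity, concurrency limit, per-attempt timeout,
/// and retry policy.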
#[derive(Clone, Debug)]
pub struct ProcessorConfig {
    pub max_queue_size: usize,
    pub max_concurrent_tasks: usize,
    pub task_timeout: Duration,
    pub max_retries: u32,
    pub retry_delay: Duration,
}

impl Default for ProcessorConfig {
    fn default() -> Self {
        Self {
            max_queue_size: 1000,
            max_concurrent_tasks: 10,
            task_timeout: Duration::from_secs(30),
            max_retries: 3,
            retry_delay: Duration::from_secs(1),
        }
    }
}

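/// Counters and gauges describing the processor's workload.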
#[derive(Debug, Default, Clone)]
pub struct ProcessorStats {
    pub total_tasks: u64,
    pub completed_tasks: u64,
    pub failed_tasks: u64,
    pub timeout_tasks: u64,
    pub cancelled_tasks: u64,
    pub average_processing_time: Duration,
    pub current_queue_size: usize,
    pub current_processing_tasks: usize,
}

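/// Result delivered back to the submitter once a task reaches a terminal state.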
#[derive(Debug)]
pub struct TaskResult<T, O>
where
    T: Send + Sync,
    O: Send + Sync,
{
    pub task_id: u64,
    pub status: TaskStatus,
    pub task: Option<T>,
    pub output: Option<O>,
    pub error: Option<String>,
    pub processing_time: Option<Duration>,
}

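/// Internal envelope for a queued task, carrying the channel on which its
/// result will be delivered plus its priority and retry count.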
struct QueuedTask<T, O>
where
    T: Send + Sync,
    O: Send + Sync,
{
    task: T,
    task_id: u64,
    result_tx: mpsc::Sender<TaskResult<T, O>>,
    priority: u32,
    retry_count: u32,
}

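/// Bounded task queue shared between submitters and the processing loop,
/// together with the task id counter and the statistics it maintains.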
pub struct TaskQueue<T, O>
where
    T: Send + Sync,
    O: Send + Sync,
{
    queue: mpsc::Sender<QueuedTask<T, O>>,
    queue_rx: Arc<tokio::sync::Mutex<Option<mpsc::Receiver<QueuedTask<T, O>>>>>,
    next_task_id: Arc<tokio::sync::Mutex<u64>>,
    stats: Arc<tokio::sync::Mutex<ProcessorStats>>,
}

impl<T: Clone + Send + Sync + 'static, O: Clone + Send + Sync + 'static>
    TaskQueue<T, O>
{
    pub fn new(config: &ProcessorConfig) -> Self {
        let (tx, rx) = mpsc::channel(config.max_queue_size);
        Self {
            queue: tx,
            queue_rx: Arc::new(tokio::sync::Mutex::new(Some(rx))),
            next_task_id: Arc::new(tokio::sync::Mutex::new(0)),
            stats: Arc::new(tokio::sync::Mutex::new(ProcessorStats::default())),
        }
    }

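    /// Assigns a new task id, enqueues the task, and returns the id together
    /// with a receiver on which the final [`TaskResult`] will be delivered.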
    pub async fn enqueue_task(
        &self,
        task: T,
        priority: u32,
    ) -> Result<(u64, mpsc::Receiver<TaskResult<T, O>>), ProcessorError> {
        let mut task_id = self.next_task_id.lock().await;
        *task_id += 1;
        let current_id = *task_id;

        let (result_tx, result_rx) = mpsc::channel(1);
        let queued_task = QueuedTask {
            task,
            task_id: current_id,
            result_tx,
            priority,
            retry_count: 0,
        };

        // Use try_send so a full queue is reported as QueueFull immediately
        // instead of blocking the caller until capacity frees up.
        self.queue.try_send(queued_task).map_err(|e| match e {
            mpsc::error::TrySendError::Full(_) => ProcessorError::QueueFull,
            mpsc::error::TrySendError::Closed(_) => {
                ProcessorError::InternalError("task queue is closed".to_string())
            },
        })?;

        let mut stats = self.stats.lock().await;
        stats.total_tasks += 1;
        stats.current_queue_size += 1;

        Ok((current_id, result_rx))
    }

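    /// Dequeues the next task and updates the queue/processing gauges.
    /// Returns `None` once all senders are gone and the queue is drained.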
    pub async fn get_next_ready(
        &self
    ) -> Option<(T, u64, mpsc::Sender<TaskResult<T, O>>, u32, u32)> {
        let mut rx_guard = self.queue_rx.lock().await;
        if let Some(rx) = rx_guard.as_mut() {
            if let Some(queued) = rx.recv().await {
                let mut stats = self.stats.lock().await;
                stats.current_queue_size -= 1;
                stats.current_processing_tasks += 1;
                return Some((
                    queued.task,
                    queued.task_id,
                    queued.result_tx,
                    queued.priority,
                    queued.retry_count,
                ));
            }
        }
        None
    }

    pub async fn get_stats(&self) -> ProcessorStats {
        self.stats.lock().await.clone()
    }

    pub async fn update_stats(
        &self,
        result: &TaskResult<T, O>,
    ) {
        let mut stats = self.stats.lock().await;
        match result.status {
            TaskStatus::Completed => {
                stats.completed_tasks += 1;
                if let Some(processing_time) = result.processing_time {
                    // Crude running average: midpoint of the previous average
                    // and the latest sample, so recent tasks weigh heavily.
                    stats.average_processing_time =
                        (stats.average_processing_time + processing_time) / 2;
                }
            },
            TaskStatus::Failed(_) => stats.failed_tasks += 1,
            TaskStatus::Timeout => stats.timeout_tasks += 1,
            TaskStatus::Cancelled => stats.cancelled_tasks += 1,
            _ => {},
        }
        stats.current_processing_tasks -= 1;
    }
}

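/// Processing logic supplied by the user: turns a task of type `T` into an
/// output of type `O`, or a [`ProcessorError`].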
#[async_trait]
pub trait TaskProcessor<T, O>: Send + Sync + 'static
where
    T: Clone + Send + Sync + 'static,
    O: Clone + Send + Sync + 'static,
{
    async fn process(
        &self,
        task: T,
    ) -> Result<O, ProcessorError>;
}

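/// Asynchronous task processor that owns the bounded queue, the configuration,
/// the user-supplied [`TaskProcessor`], and the background processing loop
/// started by [`AsyncProcessor::start`].
///
/// A minimal usage sketch (assuming some `MyProcessor` type implementing
/// `TaskProcessor<i32, String>`; the name is illustrative only):
///
/// ```ignore
/// let mut processor = AsyncProcessor::new(ProcessorConfig::default(), MyProcessor);
/// processor.start();
/// let (_id, mut rx) = processor.submit_task(42, 0).await?;
/// let result = rx.recv().await.expect("result channel closed");
/// processor.shutdown().await?;
/// ```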
pub struct AsyncProcessor<T, O, P>
where
    T: Clone + Send + Sync + 'static,
    O: Clone + Send + Sync + 'static,
    P: TaskProcessor<T, O>,
{
    task_queue: Arc<TaskQueue<T, O>>,
    config: ProcessorConfig,
    processor: Arc<P>,
    shutdown_tx: Option<oneshot::Sender<()>>,
    handle: Option<tokio::task::JoinHandle<()>>,
}

impl<T, O, P> AsyncProcessor<T, O, P>
where
    T: Clone + Send + Sync + 'static,
    O: Clone + Send + Sync + 'static,
    P: TaskProcessor<T, O>,
{
    pub fn new(
        config: ProcessorConfig,
        processor: P,
    ) -> Self {
        let task_queue = Arc::new(TaskQueue::new(&config));
        Self {
            task_queue,
            config,
            processor: Arc::new(processor),
            shutdown_tx: None,
            handle: None,
        }
    }

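    /// Submits a task with the given priority and returns its id together with
    /// the receiver for its eventual [`TaskResult`]. The underlying queue is
    /// FIFO; the priority value is carried along but not yet used for ordering.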
    pub async fn submit_task(
        &self,
        task: T,
        priority: u32,
    ) -> Result<(u64, mpsc::Receiver<TaskResult<T, O>>), ProcessorError> {
        self.task_queue.enqueue_task(task, priority).await
    }

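    /// Spawns the background processing loop: it pulls queued tasks, runs them
    /// through the [`TaskProcessor`] under the configured timeout, retry, and
    /// concurrency limits, and sends a [`TaskResult`] back for each task.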
    pub fn start(&mut self) {
        let queue = self.task_queue.clone();
        let processor = self.processor.clone();
        let config = self.config.clone();
        let (shutdown_tx, mut shutdown_rx) = oneshot::channel();

        self.shutdown_tx = Some(shutdown_tx);

        let handle = tokio::spawn(async move {
            let mut join_set = tokio::task::JoinSet::new();

            loop {
                select! {
                    // Poll branches in a fixed order: finished workers and queued
                    // tasks are handled before the shutdown signal, so work that
                    // was already accepted is still dispatched during shutdown.
                    biased;

                    Some(result) = join_set.join_next() => {
                        if let Err(e) = result {
                            debug!("task execution failed: {}", e);
                        }
                    }

                    // The guard disables this branch while the concurrency limit
                    // is reached, so dequeued tasks are never silently dropped.
                    Some((task, task_id, result_tx, _priority, retry_count)) = queue.get_next_ready(),
                        if join_set.len() < config.max_concurrent_tasks => {
                        let processor = processor.clone();
                        let config = config.clone();
                        let queue = queue.clone();

                        join_set.spawn(async move {
                            let start_time = Instant::now();
                            let mut current_retry = retry_count;

                            loop {
                                // Each attempt is bounded by the configured timeout.
                                let result = tokio::time::timeout(
                                    config.task_timeout,
                                    processor.process(task.clone())
                                ).await;

                                match result {
                                    Ok(Ok(output)) => {
                                        let processing_time = start_time.elapsed();
                                        let task_result = TaskResult {
                                            task_id,
                                            status: TaskStatus::Completed,
                                            task: Some(task),
                                            output: Some(output),
                                            error: None,
                                            processing_time: Some(processing_time),
                                        };
                                        queue.update_stats(&task_result).await;
                                        let _ = result_tx.send(task_result).await;
                                        break;
                                    }
                                    Ok(Err(e)) => {
                                        // Retry failed attempts until the budget is spent.
                                        if current_retry < config.max_retries {
                                            current_retry += 1;
                                            tokio::time::sleep(config.retry_delay).await;
                                            continue;
                                        }
                                        let task_result = TaskResult {
                                            task_id,
                                            status: TaskStatus::Failed(e.to_string()),
                                            task: Some(task),
                                            output: None,
                                            error: Some(e.to_string()),
                                            processing_time: Some(start_time.elapsed()),
                                        };
                                        queue.update_stats(&task_result).await;
                                        let _ = result_tx.send(task_result).await;
                                        break;
                                    }
                                    Err(_) => {
                                        let task_result = TaskResult {
                                            task_id,
                                            status: TaskStatus::Timeout,
                                            task: Some(task),
                                            output: None,
                                            error: Some("task execution timed out".to_string()),
                                            processing_time: Some(start_time.elapsed()),
                                        };
                                        queue.update_stats(&task_result).await;
                                        let _ = result_tx.send(task_result).await;
                                        break;
                                    }
                                }
                            }
                        });
                    }

                    _ = &mut shutdown_rx => {
                        break;
                    }
                }
            }

            // Wait for in-flight workers so their results are still delivered
            // after shutdown has been requested.
            while join_set.join_next().await.is_some() {}
        });

        self.handle = Some(handle);
    }

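    /// Signals the background loop to stop and waits for it to exit. Tasks that
    /// were already queued or in flight are still processed before the loop
    /// finishes; submissions made after shutdown will not be picked up.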
    pub async fn shutdown(&mut self) -> Result<(), ProcessorError> {
        if let Some(shutdown_tx) = self.shutdown_tx.take() {
            shutdown_tx.send(()).map_err(|_| {
                ProcessorError::InternalError(
                    "Failed to send shutdown signal".to_string(),
                )
            })?;

            if let Some(handle) = self.handle.take() {
                handle.await.map_err(|e| {
                    ProcessorError::InternalError(format!(
                        "Failed to join processor task: {}",
                        e
                    ))
                })?;
            }
        }
        Ok(())
    }

    pub async fn get_stats(&self) -> ProcessorStats {
        self.task_queue.get_stats().await
    }
}

impl<T, O, P> Drop for AsyncProcessor<T, O, P>
where
    T: Clone + Send + Sync + 'static,
    O: Clone + Send + Sync + 'static,
    P: TaskProcessor<T, O>,
{
    fn drop(&mut self) {
        // Best-effort shutdown: only signal the background loop. Building a new
        // runtime and calling block_on here would panic whenever the processor
        // is dropped from inside an async context, as in the tests below.
        if let Some(shutdown_tx) = self.shutdown_tx.take() {
            let _ = shutdown_tx.send(());
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    struct TestProcessor;

    #[async_trait::async_trait]
    impl TaskProcessor<i32, String> for TestProcessor {
        async fn process(
            &self,
            task: i32,
        ) -> Result<String, ProcessorError> {
            tokio::time::sleep(Duration::from_millis(100)).await;
            Ok(format!("Processed: {}", task))
        }
    }

    #[tokio::test]
    async fn test_async_processor() {
        let config = ProcessorConfig {
            max_queue_size: 100,
            max_concurrent_tasks: 5,
            task_timeout: Duration::from_secs(1),
            max_retries: 3,
            retry_delay: Duration::from_secs(1),
        };
        let mut processor = AsyncProcessor::new(config, TestProcessor);
        processor.start();

        let mut receivers = Vec::new();
        for i in 0..10 {
            let (_, rx) = processor.submit_task(i, 0).await.unwrap();
            receivers.push(rx);
        }

        for mut rx in receivers {
            let result = rx.recv().await.unwrap();
            assert_eq!(result.status, TaskStatus::Completed);
            assert!(result.error.is_none());
            assert!(result.output.is_some());
        }
    }

    #[tokio::test]
    async fn test_processor_shutdown() {
        let config = ProcessorConfig {
            max_queue_size: 100,
            max_concurrent_tasks: 5,
            task_timeout: Duration::from_secs(1),
            max_retries: 3,
            retry_delay: Duration::from_secs(1),
        };
        let mut processor = AsyncProcessor::new(config, TestProcessor);
        processor.start();

        let mut receivers = Vec::new();
        for i in 0..5 {
            let (_, rx) = processor.submit_task(i, 0).await.unwrap();
            receivers.push(rx);
        }

        processor.shutdown().await.unwrap();

        for mut rx in receivers {
            let result = rx.recv().await.unwrap();
            assert_eq!(result.status, TaskStatus::Completed);
        }
    }

    #[tokio::test]
    async fn test_processor_auto_shutdown() {
        let config = ProcessorConfig {
            max_queue_size: 100,
            max_concurrent_tasks: 5,
            task_timeout: Duration::from_secs(1),
            max_retries: 3,
            retry_delay: Duration::from_secs(1),
        };
        let mut processor = AsyncProcessor::new(config, TestProcessor);
        processor.start();

        let mut receivers = Vec::new();
        for i in 0..5 {
            let (_, rx) = processor.submit_task(i, 0).await.unwrap();
            receivers.push(rx);
        }

        drop(processor);

        for mut rx in receivers {
            let result = rx.recv().await.unwrap();
            assert_eq!(result.status, TaskStatus::Completed);
        }
    }
}