use std::{
    fmt::Display,
    sync::Arc,
    time::{Duration, Instant},
};

use async_trait::async_trait;
use mf_state::debug;
use tokio::select;
use tokio::sync::{mpsc, oneshot};

use crate::{config::ProcessorConfig, error::error_utils, metrics, ForgeResult};

/// Lifecycle status of a task as it moves through the processor.
#[derive(Debug, Clone, PartialEq)]
pub enum TaskStatus {
    Pending,
    Processing,
    Completed,
    Failed(String),
    Timeout,
    Cancelled,
}

impl From<&TaskStatus> for &'static str {
    fn from(status: &TaskStatus) -> Self {
        match status {
            TaskStatus::Pending => "pending",
            TaskStatus::Processing => "processing",
            TaskStatus::Completed => "completed",
            TaskStatus::Failed(_) => "failed",
            TaskStatus::Timeout => "timeout",
            TaskStatus::Cancelled => "cancelled",
        }
    }
}

/// Errors surfaced by the processor itself, as opposed to task output.
#[derive(Debug)]
pub enum ProcessorError {
    QueueFull,
    TaskFailed(String),
    InternalError(String),
    TaskTimeout,
    TaskCancelled,
    RetryExhausted(String),
}

impl Display for ProcessorError {
    fn fmt(
        &self,
        f: &mut std::fmt::Formatter<'_>,
    ) -> std::fmt::Result {
        match self {
            ProcessorError::QueueFull => write!(f, "task queue is full"),
            ProcessorError::TaskFailed(msg) => {
                write!(f, "task execution failed: {}", msg)
            },
            ProcessorError::InternalError(msg) => {
                write!(f, "internal error: {}", msg)
            },
            ProcessorError::TaskTimeout => {
                write!(f, "task execution timed out")
            },
            ProcessorError::TaskCancelled => write!(f, "task was cancelled"),
            ProcessorError::RetryExhausted(msg) => {
                write!(f, "retries exhausted: {}", msg)
            },
        }
    }
}

impl std::error::Error for ProcessorError {}

/// Aggregate counters for tasks seen by the processor.
#[derive(Debug, Default, Clone)]
pub struct ProcessorStats {
    pub total_tasks: u64,
    pub completed_tasks: u64,
    pub failed_tasks: u64,
    pub timeout_tasks: u64,
    pub cancelled_tasks: u64,
    pub current_queue_size: usize,
    pub current_processing_tasks: usize,
}

/// Outcome of a single task, delivered to the submitter over a result channel.
#[derive(Debug)]
pub struct TaskResult<T, O>
where
    T: Send + Sync,
    O: Send + Sync,
{
    pub task_id: u64,
    pub status: TaskStatus,
    pub task: Option<T>,
    pub output: Option<O>,
    pub error: Option<String>,
    pub processing_time: Option<Duration>,
}

/// A task waiting in the queue, together with its result channel.
struct QueuedTask<T, O>
where
    T: Send + Sync,
    O: Send + Sync,
{
    task: T,
    task_id: u64,
    result_tx: mpsc::Sender<TaskResult<T, O>>,
    /// Currently informational only: the underlying mpsc channel is FIFO,
    /// so priority does not affect dispatch order.
    priority: u32,
    retry_count: u32,
}

/// Bounded FIFO task queue with shared statistics.
pub struct TaskQueue<T, O>
where
    T: Send + Sync,
    O: Send + Sync,
{
    queue: mpsc::Sender<QueuedTask<T, O>>,
    queue_rx: Arc<tokio::sync::Mutex<Option<mpsc::Receiver<QueuedTask<T, O>>>>>,
    next_task_id: Arc<tokio::sync::Mutex<u64>>,
    stats: Arc<tokio::sync::Mutex<ProcessorStats>>,
}

impl<T: Clone + Send + Sync + 'static, O: Clone + Send + Sync + 'static>
    TaskQueue<T, O>
{
    pub fn new(config: &ProcessorConfig) -> Self {
        let (tx, rx) = mpsc::channel(config.max_queue_size);
        Self {
            queue: tx,
            queue_rx: Arc::new(tokio::sync::Mutex::new(Some(rx))),
            next_task_id: Arc::new(tokio::sync::Mutex::new(0)),
            stats: Arc::new(tokio::sync::Mutex::new(ProcessorStats::default())),
        }
    }

    /// Enqueues a task and returns its id plus a receiver for its result.
    pub async fn enqueue_task(
        &self,
        task: T,
        priority: u32,
    ) -> ForgeResult<(u64, mpsc::Receiver<TaskResult<T, O>>)> {
        let mut task_id = self.next_task_id.lock().await;
        *task_id += 1;
        let current_id = *task_id;

        let (result_tx, result_rx) = mpsc::channel(1);
        let queued_task = QueuedTask {
            task,
            task_id: current_id,
            result_tx,
            priority,
            retry_count: 0,
        };

        self.queue
            .send(queued_task)
            .await
            .map_err(|_| error_utils::resource_exhausted_error("task queue"))?;

        let mut stats = self.stats.lock().await;
        stats.total_tasks += 1;
        stats.current_queue_size += 1;

        metrics::task_submitted();
        metrics::set_queue_size(stats.current_queue_size);

        Ok((current_id, result_rx))
    }

    pub async fn get_next_ready(
        &self,
    ) -> Option<(T, u64, mpsc::Sender<TaskResult<T, O>>, u32, u32)> {
        let mut rx_guard = self.queue_rx.lock().await;
        if let Some(rx) = rx_guard.as_mut() {
            if let Some(queued) = rx.recv().await {
                let mut stats = self.stats.lock().await;
                stats.current_queue_size -= 1;
                stats.current_processing_tasks += 1;
                metrics::set_queue_size(stats.current_queue_size);
                metrics::increment_processing_tasks();
                return Some((
                    queued.task,
                    queued.task_id,
                    queued.result_tx,
                    queued.priority,
                    queued.retry_count,
                ));
            }
        }
        None
    }

    pub async fn get_stats(&self) -> ProcessorStats {
        self.stats.lock().await.clone()
    }

    pub async fn update_stats(
        &self,
        result: &TaskResult<T, O>,
    ) {
        let mut stats = self.stats.lock().await;
        stats.current_processing_tasks -= 1;
        metrics::decrement_processing_tasks();

        let status_str: &'static str = (&result.status).into();
        metrics::task_processed(status_str);

        if let Some(duration) = result.processing_time {
            metrics::task_processing_duration(duration);
        }

        match result.status {
            TaskStatus::Completed => stats.completed_tasks += 1,
            TaskStatus::Failed(_) => stats.failed_tasks += 1,
            TaskStatus::Timeout => stats.timeout_tasks += 1,
            TaskStatus::Cancelled => stats.cancelled_tasks += 1,
            _ => {},
        }
    }
}

/// A unit of work executor: consumes a task of type `T` and produces an
/// output of type `O`, or a `ProcessorError` on failure.
#[async_trait]
pub trait TaskProcessor<T, O>: Send + Sync + 'static
where
    T: Clone + Send + Sync + 'static,
    O: Clone + Send + Sync + 'static,
{
    async fn process(
        &self,
        task: T,
    ) -> Result<O, ProcessorError>;
}

/// Lifecycle state of the background processing loop.
#[derive(Debug, Clone, PartialEq)]
pub enum ProcessorState {
    /// Created, but `start` has not been called.
    NotStarted,
    /// The background loop is accepting and executing tasks.
    Running,
    /// Shutdown has begun; new work is cancelled.
    Shutting,
    /// The background loop has exited.
    Shutdown,
}

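/// Generic asynchronous task processor. Tasks are submitted to a bounded
/// queue and executed concurrently by a background loop with per-task
/// timeout, retry, and graceful-shutdown handling.
///
/// A minimal usage sketch (assumes a `ProcessorConfig` built elsewhere and a
/// `TaskProcessor` impl named `MyProcessor`; both names are illustrative):
///
/// ```ignore
/// let mut processor = AsyncProcessor::new(config, MyProcessor);
/// processor.start().await?;
/// let (_task_id, mut rx) = processor.submit_task(42, 0).await?;
/// let result = rx.recv().await.expect("result channel closed");
/// processor.shutdown().await?;
/// ```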
pub struct AsyncProcessor<T, O, P>
where
    T: Clone + Send + Sync + 'static,
    O: Clone + Send + Sync + 'static,
    P: TaskProcessor<T, O>,
{
    task_queue: Arc<TaskQueue<T, O>>,
    config: ProcessorConfig,
    processor: Arc<P>,
    shutdown_tx: Option<oneshot::Sender<()>>,
    handle: Option<tokio::task::JoinHandle<()>>,
    state: Arc<tokio::sync::Mutex<ProcessorState>>,
}

impl<T, O, P> AsyncProcessor<T, O, P>
where
    T: Clone + Send + Sync + 'static,
    O: Clone + Send + Sync + 'static,
    P: TaskProcessor<T, O>,
{
    pub fn new(
        config: ProcessorConfig,
        processor: P,
    ) -> Self {
        let task_queue = Arc::new(TaskQueue::new(&config));
        Self {
            task_queue,
            config,
            processor: Arc::new(processor),
            shutdown_tx: None,
            handle: None,
            state: Arc::new(tokio::sync::Mutex::new(
                ProcessorState::NotStarted,
            )),
        }
    }

    pub async fn submit_task(
        &self,
        task: T,
        priority: u32,
    ) -> ForgeResult<(u64, mpsc::Receiver<TaskResult<T, O>>)> {
        self.task_queue.enqueue_task(task, priority).await
    }

    /// Starts the background processing loop. Fails if the processor has
    /// already been started or is shutting down.
    pub async fn start(&mut self) -> Result<(), ProcessorError> {
        let mut state = self.state.lock().await;
        if *state != ProcessorState::NotStarted {
            return Err(ProcessorError::InternalError(
                "processor already started or shutting down".to_string(),
            ));
        }
        *state = ProcessorState::Running;
        drop(state);

        let queue = self.task_queue.clone();
        let processor = self.processor.clone();
        let config = self.config.clone();
        let state_ref = self.state.clone();
        let (shutdown_tx, mut shutdown_rx) = oneshot::channel();

        self.shutdown_tx = Some(shutdown_tx);

        let handle = tokio::spawn(async move {
            let mut join_set = tokio::task::JoinSet::new();

            // Drain in-flight tasks on shutdown, aborting whatever is still
            // running once the timeout elapses.
            async fn cleanup_tasks(
                join_set: &mut tokio::task::JoinSet<()>,
                timeout: Duration,
            ) {
                debug!("starting cleanup of running tasks...");

                let cleanup_start = Instant::now();
                while !join_set.is_empty() {
                    if cleanup_start.elapsed() > timeout {
                        debug!("cleanup timed out; aborting remaining tasks");
                        join_set.abort_all();
                        break;
                    }

                    if let Some(result) = join_set.join_next().await {
                        if let Err(e) = result {
                            if !e.is_cancelled() {
                                debug!("task execution failed: {}", e);
                            }
                        }
                    }
                }
                debug!("task cleanup complete");
            }

            loop {
                select! {
                    _ = &mut shutdown_rx => {
                        debug!("shutdown signal received; starting graceful shutdown");
                        {
                            let mut state = state_ref.lock().await;
                            *state = ProcessorState::Shutting;
                        }

                        cleanup_tasks(&mut join_set, config.cleanup_timeout).await;
                        break;
                    }

                    Some(result) = join_set.join_next() => {
                        if let Err(e) = result {
                            if !e.is_cancelled() {
                                debug!("task execution failed: {}", e);
                            }
                        }
                    }

                    Some((task, task_id, result_tx, _priority, retry_count)) = queue.get_next_ready() => {
                        {
                            let state = state_ref.lock().await;
                            if *state != ProcessorState::Running {
                                let task_result = TaskResult {
                                    task_id,
                                    status: TaskStatus::Cancelled,
                                    task: Some(task),
                                    output: None,
                                    error: Some("processor is shutting down".to_string()),
                                    processing_time: Some(Duration::ZERO),
                                };
                                queue.update_stats(&task_result).await;
                                let _ = result_tx.send(task_result).await;
                                continue;
                            }
                        }

                        // At capacity, wait for an in-flight task to finish
                        // before spawning, so dequeued tasks are never
                        // silently dropped.
                        while join_set.len() >= config.max_concurrent_tasks {
                            match join_set.join_next().await {
                                Some(Err(e)) if !e.is_cancelled() => {
                                    debug!("task execution failed: {}", e);
                                },
                                Some(_) => {},
                                None => break,
                            }
                        }

                        let processor = processor.clone();
                        let config = config.clone();
                        let queue = queue.clone();

                        join_set.spawn(async move {
                            let start_time = Instant::now();
                            let mut current_retry = retry_count;

                            loop {
                                let result = tokio::time::timeout(
                                    config.task_timeout,
                                    processor.process(task.clone()),
                                )
                                .await;

                                match result {
                                    // Success: report the output and stop.
                                    Ok(Ok(output)) => {
                                        let task_result = TaskResult {
                                            task_id,
                                            status: TaskStatus::Completed,
                                            task: Some(task),
                                            output: Some(output),
                                            error: None,
                                            processing_time: Some(start_time.elapsed()),
                                        };
                                        queue.update_stats(&task_result).await;
                                        let _ = result_tx.send(task_result).await;
                                        break;
                                    },
                                    // Failure: retry after a delay until the
                                    // retry budget is exhausted.
                                    Ok(Err(e)) => {
                                        if current_retry < config.max_retries {
                                            current_retry += 1;
                                            tokio::time::sleep(config.retry_delay).await;
                                            continue;
                                        }
                                        let task_result = TaskResult {
                                            task_id,
                                            status: TaskStatus::Failed(e.to_string()),
                                            task: Some(task),
                                            output: None,
                                            error: Some(e.to_string()),
                                            processing_time: Some(start_time.elapsed()),
                                        };
                                        queue.update_stats(&task_result).await;
                                        let _ = result_tx.send(task_result).await;
                                        break;
                                    },
                                    // Timeout: report immediately, no retry.
                                    Err(_) => {
                                        let task_result = TaskResult {
                                            task_id,
                                            status: TaskStatus::Timeout,
                                            task: Some(task),
                                            output: None,
                                            error: Some("task execution timed out".to_string()),
                                            processing_time: Some(start_time.elapsed()),
                                        };
                                        queue.update_stats(&task_result).await;
                                        let _ = result_tx.send(task_result).await;
                                        break;
                                    },
                                }
                            }
                        });
                    }
                }
            }

            {
                let mut state = state_ref.lock().await;
                *state = ProcessorState::Shutdown;
            }
            debug!("async processor fully shut down");
        });

        self.handle = Some(handle);
        Ok(())
    }

    /// Gracefully shuts down the processor and waits for the background loop.
    pub async fn shutdown(&mut self) -> Result<(), ProcessorError> {
        {
            let mut state = self.state.lock().await;
            match *state {
                ProcessorState::NotStarted => {
                    return Err(ProcessorError::InternalError(
                        "processor has not been started".to_string(),
                    ));
                },
                ProcessorState::Shutdown => return Ok(()),
                ProcessorState::Shutting => {},
                ProcessorState::Running => {
                    *state = ProcessorState::Shutting;
                },
            }
        }

        if let Some(shutdown_tx) = self.shutdown_tx.take() {
            shutdown_tx.send(()).map_err(|_| {
                ProcessorError::InternalError(
                    "failed to send shutdown signal".to_string(),
                )
            })?;
        }

        if let Some(handle) = self.handle.take() {
            if let Err(e) = handle.await {
                return Err(ProcessorError::InternalError(format!(
                    "error while waiting for the background task: {}",
                    e
                )));
            }
        }

        {
            let state = self.state.lock().await;
            if *state != ProcessorState::Shutdown {
                return Err(ProcessorError::InternalError(
                    "shutdown did not complete correctly".to_string(),
                ));
            }
        }

        debug!("async processor shut down successfully");
        Ok(())
    }

    pub async fn get_state(&self) -> ProcessorState {
        self.state.lock().await.clone()
    }

    pub async fn is_running(&self) -> bool {
        *self.state.lock().await == ProcessorState::Running
    }

    pub async fn get_stats(&self) -> ProcessorStats {
        self.task_queue.get_stats().await
    }
}

impl<T, O, P> Drop for AsyncProcessor<T, O, P>
where
    T: Clone + Send + Sync + 'static,
    O: Clone + Send + Sync + 'static,
    P: TaskProcessor<T, O>,
{
    fn drop(&mut self) {
        if let Some(shutdown_tx) = self.shutdown_tx.take() {
            let _ = shutdown_tx.send(());
            debug!("AsyncProcessor Drop: shutdown signal sent");
        }

        if let Some(handle) = self.handle.take() {
            handle.abort();
            debug!("AsyncProcessor Drop: background task aborted");
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    struct TestProcessor;

    #[async_trait::async_trait]
    impl TaskProcessor<i32, String> for TestProcessor {
        async fn process(
            &self,
            task: i32,
        ) -> Result<String, ProcessorError> {
            tokio::time::sleep(Duration::from_millis(100)).await;
            Ok(format!("Processed: {}", task))
        }
    }
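
    // A processor that always fails, used below to check that the retry
    // budget is exhausted and the task is reported as `Failed`. This test is
    // a sketch added alongside the originals; `AlwaysFailProcessor` and the
    // concrete config values are illustrative.
    struct AlwaysFailProcessor;

    #[async_trait::async_trait]
    impl TaskProcessor<i32, String> for AlwaysFailProcessor {
        async fn process(
            &self,
            _task: i32,
        ) -> Result<String, ProcessorError> {
            Err(ProcessorError::TaskFailed("always fails".to_string()))
        }
    }

    #[tokio::test]
    async fn test_task_failure_after_retries() {
        let config = ProcessorConfig {
            max_queue_size: 10,
            max_concurrent_tasks: 1,
            task_timeout: Duration::from_secs(1),
            max_retries: 2,
            retry_delay: Duration::from_millis(10),
            cleanup_timeout: Duration::from_secs(1),
        };
        let mut processor = AsyncProcessor::new(config, AlwaysFailProcessor);
        processor.start().await.unwrap();

        let (_, mut rx) = processor.submit_task(1, 0).await.unwrap();
        let result = rx.recv().await.unwrap();
        assert!(matches!(result.status, TaskStatus::Failed(_)));
        assert!(result.error.is_some());

        processor.shutdown().await.unwrap();
    }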

    #[tokio::test]
    async fn test_async_processor() {
        let config = ProcessorConfig {
            max_queue_size: 100,
            max_concurrent_tasks: 5,
            task_timeout: Duration::from_secs(1),
            max_retries: 3,
            retry_delay: Duration::from_secs(1),
            cleanup_timeout: Duration::from_secs(10),
        };
        let mut processor = AsyncProcessor::new(config, TestProcessor);
        processor.start().await.unwrap();

        let mut receivers = Vec::new();
        for i in 0..10 {
            let (_, rx) = processor.submit_task(i, 0).await.unwrap();
            receivers.push(rx);
        }

        for mut rx in receivers {
            let result = rx.recv().await.unwrap();
            assert_eq!(result.status, TaskStatus::Completed);
            assert!(result.error.is_none());
            assert!(result.output.is_some());
        }

        processor.shutdown().await.unwrap();
    }
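
    use std::sync::atomic::{AtomicU32, Ordering};

    // A processor that fails a fixed number of times before succeeding, used
    // below to exercise the retry path. A sketch added alongside the original
    // tests; `FlakyProcessor` is illustrative, and it assumes `retry_delay`
    // is short enough for the retries to finish quickly.
    struct FlakyProcessor {
        remaining_failures: AtomicU32,
    }

    #[async_trait::async_trait]
    impl TaskProcessor<i32, String> for FlakyProcessor {
        async fn process(
            &self,
            task: i32,
        ) -> Result<String, ProcessorError> {
            // Fail while the failure budget lasts, then succeed.
            let had_budget = self
                .remaining_failures
                .fetch_update(Ordering::SeqCst, Ordering::SeqCst, |n| {
                    n.checked_sub(1)
                })
                .is_ok();
            if had_budget {
                return Err(ProcessorError::TaskFailed(
                    "transient failure".to_string(),
                ));
            }
            Ok(format!("Processed: {}", task))
        }
    }

    #[tokio::test]
    async fn test_retry_then_success() {
        let config = ProcessorConfig {
            max_queue_size: 10,
            max_concurrent_tasks: 1,
            task_timeout: Duration::from_secs(1),
            max_retries: 3,
            retry_delay: Duration::from_millis(10),
            cleanup_timeout: Duration::from_secs(1),
        };
        // Two failures fit inside the budget of three retries, so the task
        // should ultimately complete.
        let mut processor = AsyncProcessor::new(
            config,
            FlakyProcessor { remaining_failures: AtomicU32::new(2) },
        );
        processor.start().await.unwrap();

        let (_, mut rx) = processor.submit_task(7, 0).await.unwrap();
        let result = rx.recv().await.unwrap();
        assert_eq!(result.status, TaskStatus::Completed);
        assert_eq!(result.output.as_deref(), Some("Processed: 7"));

        processor.shutdown().await.unwrap();
    }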

    #[tokio::test]
    async fn test_processor_shutdown() {
        let config = ProcessorConfig {
            max_queue_size: 100,
            max_concurrent_tasks: 5,
            task_timeout: Duration::from_secs(1),
            max_retries: 3,
            retry_delay: Duration::from_secs(1),
            cleanup_timeout: Duration::from_secs(10),
        };
        let mut processor = AsyncProcessor::new(config, TestProcessor);
        processor.start().await.unwrap();

        // All five tasks fit within the concurrency limit, so they should be
        // in flight (and allowed to finish) by the time shutdown runs.
        let mut receivers = Vec::new();
        for i in 0..5 {
            let (_, rx) = processor.submit_task(i, 0).await.unwrap();
            receivers.push(rx);
        }

        processor.shutdown().await.unwrap();

        for mut rx in receivers {
            let result = rx.recv().await.unwrap();
            assert_eq!(result.status, TaskStatus::Completed);
        }
    }
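
    // A processor that sleeps far past the configured task timeout, used
    // below to exercise the timeout path. A sketch added alongside the
    // original tests; `SlowProcessor` and the concrete durations are
    // illustrative.
    struct SlowProcessor;

    #[async_trait::async_trait]
    impl TaskProcessor<i32, String> for SlowProcessor {
        async fn process(
            &self,
            task: i32,
        ) -> Result<String, ProcessorError> {
            tokio::time::sleep(Duration::from_secs(10)).await;
            Ok(format!("Processed: {}", task))
        }
    }

    #[tokio::test]
    async fn test_task_timeout() {
        let config = ProcessorConfig {
            max_queue_size: 10,
            max_concurrent_tasks: 1,
            task_timeout: Duration::from_millis(50),
            max_retries: 0,
            retry_delay: Duration::from_millis(10),
            cleanup_timeout: Duration::from_secs(1),
        };
        let mut processor = AsyncProcessor::new(config, SlowProcessor);
        processor.start().await.unwrap();

        let (_, mut rx) = processor.submit_task(1, 0).await.unwrap();
        let result = rx.recv().await.unwrap();
        // Timeouts are reported immediately and are not retried.
        assert_eq!(result.status, TaskStatus::Timeout);
        assert!(result.output.is_none());

        processor.shutdown().await.unwrap();
    }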

    #[tokio::test]
    async fn test_processor_auto_shutdown() {
        let config = ProcessorConfig {
            max_queue_size: 100,
            max_concurrent_tasks: 5,
            task_timeout: Duration::from_secs(1),
            max_retries: 3,
            retry_delay: Duration::from_secs(1),
            cleanup_timeout: Duration::from_secs(10),
        };
        let mut processor = AsyncProcessor::new(config, TestProcessor);
        processor.start().await.unwrap();

        let mut receivers = Vec::new();
        for i in 0..5 {
            let (_, rx) = processor.submit_task(i, 0).await.unwrap();
            receivers.push(rx);
        }

        drop(processor);

        for mut rx in receivers {
            // Dropping the processor aborts the background task, which may
            // drop result senders before anything is sent; a closed channel
            // therefore counts as a cancellation here.
            if let Some(result) = rx.recv().await {
                assert!(matches!(
                    result.status,
                    TaskStatus::Completed | TaskStatus::Cancelled
                ));
            }
        }
    }
}