use std::{
    fmt::Display,
    sync::Arc,
    time::{Duration, Instant},
};

use async_trait::async_trait;
use mf_state::debug;
use tokio::select;
use tokio::sync::{mpsc, oneshot};

use crate::{
    config::ProcessorConfig,
    error::error_utils,
    metrics, ForgeResult,
};

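/// Lifecycle status of a task as it moves through the processor.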
#[derive(Debug, Clone, PartialEq)]
pub enum TaskStatus {
    Pending,
    Processing,
    Completed,
    Failed(String),
    Timeout,
    Cancelled,
}

impl From<&TaskStatus> for &'static str {
    fn from(status: &TaskStatus) -> Self {
        match status {
            TaskStatus::Pending => "pending",
            TaskStatus::Processing => "processing",
            TaskStatus::Completed => "completed",
            TaskStatus::Failed(_) => "failed",
            TaskStatus::Timeout => "timeout",
            TaskStatus::Cancelled => "cancelled",
        }
    }
}

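/// Errors that can occur while queueing or processing a task.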
#[derive(Debug)]
pub enum ProcessorError {
    QueueFull,
    TaskFailed(String),
    InternalError(String),
    TaskTimeout,
    TaskCancelled,
    RetryExhausted(String),
}

impl Display for ProcessorError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            ProcessorError::QueueFull => write!(f, "task queue is full"),
            ProcessorError::TaskFailed(msg) => {
                write!(f, "task execution failed: {}", msg)
            },
            ProcessorError::InternalError(msg) => {
                write!(f, "internal error: {}", msg)
            },
            ProcessorError::TaskTimeout => {
                write!(f, "task execution timed out")
            },
            ProcessorError::TaskCancelled => write!(f, "task was cancelled"),
            ProcessorError::RetryExhausted(msg) => {
                write!(f, "retries exhausted: {}", msg)
            },
        }
    }
}

impl std::error::Error for ProcessorError {}

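/// Aggregate counters maintained by the queue: totals per outcome plus the
/// current queue depth and number of in-flight tasks.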
#[derive(Debug, Default, Clone)]
pub struct ProcessorStats {
    pub total_tasks: u64,
    pub completed_tasks: u64,
    pub failed_tasks: u64,
    pub timeout_tasks: u64,
    pub cancelled_tasks: u64,
    pub current_queue_size: usize,
    pub current_processing_tasks: usize,
}

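/// Final outcome of a single task, delivered to the submitter over a
/// per-task result channel. `task` echoes the original input; `output` is
/// set only on success and `error` only on failure or timeout.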
#[derive(Debug)]
pub struct TaskResult<T, O>
where
    T: Send + Sync,
    O: Send + Sync,
{
    pub task_id: u64,
    pub status: TaskStatus,
    pub task: Option<T>,
    pub output: Option<O>,
    pub error: Option<String>,
    pub processing_time: Option<Duration>,
}

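/// Internal envelope for a task waiting in the queue. Note that `priority`
/// is carried along but the underlying FIFO channel does not reorder by it.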
struct QueuedTask<T, O>
where
    T: Send + Sync,
    O: Send + Sync,
{
    task: T,
    task_id: u64,
    result_tx: mpsc::Sender<TaskResult<T, O>>,
    priority: u32,
    retry_count: u32,
}

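/// Bounded, MPSC-backed task queue that assigns monotonically increasing
/// task IDs and tracks `ProcessorStats`. The receiver half lives behind a
/// mutex so the processing loop can take exclusive ownership of dequeueing.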
pub struct TaskQueue<T, O>
where
    T: Send + Sync,
    O: Send + Sync,
{
    queue: mpsc::Sender<QueuedTask<T, O>>,
    queue_rx: Arc<tokio::sync::Mutex<Option<mpsc::Receiver<QueuedTask<T, O>>>>>,
    next_task_id: Arc<tokio::sync::Mutex<u64>>,
    stats: Arc<tokio::sync::Mutex<ProcessorStats>>,
}

impl<T: Clone + Send + Sync + 'static, O: Clone + Send + Sync + 'static>
    TaskQueue<T, O>
{
    pub fn new(config: &ProcessorConfig) -> Self {
        let (tx, rx) = mpsc::channel(config.max_queue_size);
        Self {
            queue: tx,
            queue_rx: Arc::new(tokio::sync::Mutex::new(Some(rx))),
            next_task_id: Arc::new(tokio::sync::Mutex::new(0)),
            stats: Arc::new(tokio::sync::Mutex::new(ProcessorStats::default())),
        }
    }

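    /// Enqueues a task and returns its ID together with a receiver for the
    /// eventual `TaskResult`. Fails if the queue is full.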
    pub async fn enqueue_task(
        &self,
        task: T,
        priority: u32,
    ) -> ForgeResult<(u64, mpsc::Receiver<TaskResult<T, O>>)> {
        let mut task_id = self.next_task_id.lock().await;
        *task_id += 1;
        let current_id = *task_id;

        let (result_tx, result_rx) = mpsc::channel(1);
        let queued_task = QueuedTask {
            task,
            task_id: current_id,
            result_tx,
            priority,
            retry_count: 0,
        };

        self.queue
            .send(queued_task)
            .await
            .map_err(|_| error_utils::resource_exhausted_error("task queue"))?;

        let mut stats = self.stats.lock().await;
        stats.total_tasks += 1;
        stats.current_queue_size += 1;

        metrics::task_submitted();
        metrics::set_queue_size(stats.current_queue_size);

        Ok((current_id, result_rx))
    }

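    /// Awaits the next queued task and marks it as in-flight in the stats.
    /// Returns `None` once the queue is closed and drained.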
    pub async fn get_next_ready(
        &self,
    ) -> Option<(T, u64, mpsc::Sender<TaskResult<T, O>>, u32, u32)> {
        let mut rx_guard = self.queue_rx.lock().await;
        if let Some(rx) = rx_guard.as_mut() {
            if let Some(queued) = rx.recv().await {
                let mut stats = self.stats.lock().await;
                stats.current_queue_size -= 1;
                stats.current_processing_tasks += 1;
                metrics::set_queue_size(stats.current_queue_size);
                metrics::increment_processing_tasks();
                return Some((
                    queued.task,
                    queued.task_id,
                    queued.result_tx,
                    queued.priority,
                    queued.retry_count,
                ));
            }
        }
        None
    }

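    /// Returns a snapshot of the current statistics.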
    pub async fn get_stats(&self) -> ProcessorStats {
        self.stats.lock().await.clone()
    }

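    /// Records a finished task's outcome in the statistics and metrics.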
    pub async fn update_stats(&self, result: &TaskResult<T, O>) {
        let mut stats = self.stats.lock().await;
        stats.current_processing_tasks -= 1;
        metrics::decrement_processing_tasks();

        let status_str: &'static str = (&result.status).into();
        metrics::task_processed(status_str);

        if let Some(duration) = result.processing_time {
            metrics::task_processing_duration(duration);
        }

        match result.status {
            TaskStatus::Completed => stats.completed_tasks += 1,
            TaskStatus::Failed(_) => stats.failed_tasks += 1,
            TaskStatus::Timeout => stats.timeout_tasks += 1,
            TaskStatus::Cancelled => stats.cancelled_tasks += 1,
            _ => {},
        }
    }
}

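/// Processing logic plugged into an [`AsyncProcessor`]: turns an input task
/// `T` into an output `O`, or a `ProcessorError` on failure.
///
/// A minimal sketch of an implementor (the doubling logic is illustrative
/// only; `TestProcessor` in the tests below is a runnable version):
///
/// ```ignore
/// struct Doubler;
///
/// #[async_trait]
/// impl TaskProcessor<i32, i32> for Doubler {
///     async fn process(&self, task: i32) -> Result<i32, ProcessorError> {
///         Ok(task * 2)
///     }
/// }
/// ```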
#[async_trait]
pub trait TaskProcessor<T, O>: Send + Sync + 'static
where
    T: Clone + Send + Sync + 'static,
    O: Clone + Send + Sync + 'static,
{
    async fn process(&self, task: T) -> Result<O, ProcessorError>;
}

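/// Coarse lifecycle state of an `AsyncProcessor`.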
#[derive(Debug, Clone, PartialEq)]
pub enum ProcessorState {
    NotStarted,
    Running,
    Shutting,
    Shutdown,
}

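/// Asynchronous task processor that pulls tasks from a bounded queue and
/// runs them on a `TaskProcessor` with bounded concurrency, per-task
/// timeouts, and retries.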
pub struct AsyncProcessor<T, O, P>
where
    T: Clone + Send + Sync + 'static,
    O: Clone + Send + Sync + 'static,
    P: TaskProcessor<T, O>,
{
    task_queue: Arc<TaskQueue<T, O>>,
    config: ProcessorConfig,
    processor: Arc<P>,
    shutdown_tx: Option<oneshot::Sender<()>>,
    handle: Option<tokio::task::JoinHandle<()>>,
    state: Arc<tokio::sync::Mutex<ProcessorState>>,
}

impl<T, O, P> AsyncProcessor<T, O, P>
where
    T: Clone + Send + Sync + 'static,
    O: Clone + Send + Sync + 'static,
    P: TaskProcessor<T, O>,
{
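    /// Creates a processor with the given configuration. The background
    /// loop does not run until `start` is called.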
    pub fn new(config: ProcessorConfig, processor: P) -> Self {
        let task_queue = Arc::new(TaskQueue::new(&config));
        Self {
            task_queue,
            config,
            processor: Arc::new(processor),
            shutdown_tx: None,
            handle: None,
            state: Arc::new(tokio::sync::Mutex::new(ProcessorState::NotStarted)),
        }
    }

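    /// Submits a task and returns its ID plus a receiver for the result.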
    pub async fn submit_task(
        &self,
        task: T,
        priority: u32,
    ) -> ForgeResult<(u64, mpsc::Receiver<TaskResult<T, O>>)> {
        self.task_queue.enqueue_task(task, priority).await
    }

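    /// Spawns the background processing loop, which dequeues tasks, runs
    /// them with timeout and retry handling, and exits on shutdown. Errors
    /// if the processor has already been started.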
    pub async fn start(&mut self) -> Result<(), ProcessorError> {
        let mut state = self.state.lock().await;
        if *state != ProcessorState::NotStarted {
            return Err(ProcessorError::InternalError(
                "processor has already started or is shutting down".to_string(),
            ));
        }
        *state = ProcessorState::Running;
        drop(state);

        let queue = self.task_queue.clone();
        let processor = self.processor.clone();
        let config = self.config.clone();
        let state_ref = self.state.clone();
        let (shutdown_tx, mut shutdown_rx) = oneshot::channel();

        self.shutdown_tx = Some(shutdown_tx);

        let handle = tokio::spawn(async move {
            let mut join_set = tokio::task::JoinSet::new();

            // Drain (or, past the timeout, abort) all in-flight tasks.
            async fn cleanup_tasks(
                join_set: &mut tokio::task::JoinSet<()>,
                timeout: Duration,
            ) {
                debug!("starting cleanup of running tasks...");

                let cleanup_start = Instant::now();
                while !join_set.is_empty() {
                    if cleanup_start.elapsed() > timeout {
                        debug!("cleanup timed out; aborting remaining tasks");
                        join_set.abort_all();
                        break;
                    }

                    if let Some(result) = join_set.join_next().await {
                        if let Err(e) = result {
                            if !e.is_cancelled() {
                                debug!("task execution failed: {}", e);
                            }
                        }
                    }
                }
                debug!("task cleanup complete");
            }

            loop {
                select! {
                    _ = &mut shutdown_rx => {
                        debug!("shutdown signal received; starting graceful shutdown");
                        {
                            let mut state = state_ref.lock().await;
                            *state = ProcessorState::Shutting;
                        }

                        cleanup_tasks(&mut join_set, config.cleanup_timeout).await;
                        break;
                    }

                    Some(result) = join_set.join_next() => {
                        if let Err(e) = result {
                            if !e.is_cancelled() {
                                debug!("task execution failed: {}", e);
                            }
                        }
                    }

                    Some((task, task_id, result_tx, _priority, retry_count)) = queue.get_next_ready() => {
                        {
                            let state = state_ref.lock().await;
                            if *state != ProcessorState::Running {
                                let task_result = TaskResult {
                                    task_id,
                                    status: TaskStatus::Cancelled,
                                    task: Some(task),
                                    output: None,
                                    error: Some("processor is shutting down".to_string()),
                                    processing_time: Some(Duration::from_millis(0)),
                                };
                                queue.update_stats(&task_result).await;
                                let _ = result_tx.send(task_result).await;
                                continue;
                            }
                        }

                        // At the concurrency limit, wait for a running task to
                        // finish so the dequeued task is never silently dropped.
                        while join_set.len() >= config.max_concurrent_tasks {
                            if let Some(result) = join_set.join_next().await {
                                if let Err(e) = result {
                                    if !e.is_cancelled() {
                                        debug!("task execution failed: {}", e);
                                    }
                                }
                            } else {
                                break;
                            }
                        }

                        let processor = processor.clone();
                        let config = config.clone();
                        let queue = queue.clone();

                        join_set.spawn(async move {
                            let start_time = Instant::now();
                            let mut current_retry = retry_count;

                            loop {
                                let result = tokio::time::timeout(
                                    config.task_timeout,
                                    processor.process(task.clone()),
                                )
                                .await;

                                match result {
                                    Ok(Ok(output)) => {
                                        let processing_time = start_time.elapsed();
                                        let task_result = TaskResult {
                                            task_id,
                                            status: TaskStatus::Completed,
                                            task: Some(task),
                                            output: Some(output),
                                            error: None,
                                            processing_time: Some(processing_time),
                                        };
                                        queue.update_stats(&task_result).await;
                                        let _ = result_tx.send(task_result).await;
                                        break;
                                    }
                                    Ok(Err(e)) => {
                                        // Retry with a fixed delay until the
                                        // retry budget is exhausted.
                                        if current_retry < config.max_retries {
                                            current_retry += 1;
                                            tokio::time::sleep(config.retry_delay).await;
                                            continue;
                                        }
                                        let task_result = TaskResult {
                                            task_id,
                                            status: TaskStatus::Failed(e.to_string()),
                                            task: Some(task),
                                            output: None,
                                            error: Some(e.to_string()),
                                            processing_time: Some(start_time.elapsed()),
                                        };
                                        queue.update_stats(&task_result).await;
                                        let _ = result_tx.send(task_result).await;
                                        break;
                                    }
                                    Err(_) => {
                                        let task_result = TaskResult {
                                            task_id,
                                            status: TaskStatus::Timeout,
                                            task: Some(task),
                                            output: None,
                                            error: Some("task execution timed out".to_string()),
                                            processing_time: Some(start_time.elapsed()),
                                        };
                                        queue.update_stats(&task_result).await;
                                        let _ = result_tx.send(task_result).await;
                                        break;
                                    }
                                }
                            }
                        });
                    }
                }
            }

            {
                let mut state = state_ref.lock().await;
                *state = ProcessorState::Shutdown;
            }
            debug!("async processor fully shut down");
        });

        self.handle = Some(handle);
        Ok(())
    }

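    /// Gracefully shuts down the processor: signals the background loop,
    /// waits for it to finish cleaning up in-flight tasks, then verifies
    /// the final state. Returns immediately if already shut down.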
    pub async fn shutdown(&mut self) -> Result<(), ProcessorError> {
        {
            let mut state = self.state.lock().await;
            match *state {
                ProcessorState::NotStarted => {
                    return Err(ProcessorError::InternalError(
                        "processor has not been started".to_string(),
                    ));
                }
                ProcessorState::Shutdown => return Ok(()),
                ProcessorState::Shutting => {},
                ProcessorState::Running => {
                    *state = ProcessorState::Shutting;
                }
            }
        }

        if let Some(shutdown_tx) = self.shutdown_tx.take() {
            shutdown_tx.send(()).map_err(|_| {
                ProcessorError::InternalError(
                    "failed to send shutdown signal".to_string(),
                )
            })?;
        }

        if let Some(handle) = self.handle.take() {
            if let Err(e) = handle.await {
                return Err(ProcessorError::InternalError(format!(
                    "error while waiting for the background task to finish: {}",
                    e
                )));
            }
        }

        {
            let state = self.state.lock().await;
            if *state != ProcessorState::Shutdown {
                return Err(ProcessorError::InternalError(
                    "shutdown did not complete correctly".to_string(),
                ));
            }
        }

        debug!("async processor shut down successfully");
        Ok(())
    }

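    /// Returns the current lifecycle state.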
    pub async fn get_state(&self) -> ProcessorState {
        self.state.lock().await.clone()
    }

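    /// Returns `true` while the background loop is accepting tasks.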
    pub async fn is_running(&self) -> bool {
        *self.state.lock().await == ProcessorState::Running
    }

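    /// Returns a snapshot of the queue's statistics.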
    pub async fn get_stats(&self) -> ProcessorStats {
        self.task_queue.get_stats().await
    }
}

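// Best-effort cleanup on drop: signal shutdown and abort the background
// loop, since `Drop` cannot await it. Call `shutdown` for a graceful exit.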
impl<T, O, P> Drop for AsyncProcessor<T, O, P>
where
    T: Clone + Send + Sync + 'static,
    O: Clone + Send + Sync + 'static,
    P: TaskProcessor<T, O>,
{
    fn drop(&mut self) {
        if let Some(shutdown_tx) = self.shutdown_tx.take() {
            let _ = shutdown_tx.send(());
            debug!("AsyncProcessor drop: shutdown signal sent");
        }

        if let Some(handle) = self.handle.take() {
            handle.abort();
            debug!("AsyncProcessor drop: background task aborted");
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    struct TestProcessor;

    #[async_trait::async_trait]
    impl TaskProcessor<i32, String> for TestProcessor {
        async fn process(&self, task: i32) -> Result<String, ProcessorError> {
            tokio::time::sleep(Duration::from_millis(100)).await;
            Ok(format!("Processed: {}", task))
        }
    }

    #[tokio::test]
    async fn test_async_processor() {
        let config = ProcessorConfig {
            max_queue_size: 100,
            max_concurrent_tasks: 5,
            task_timeout: Duration::from_secs(1),
            max_retries: 3,
            retry_delay: Duration::from_secs(1),
            cleanup_timeout: Duration::from_secs(10),
        };
        let mut processor = AsyncProcessor::new(config, TestProcessor);
        processor.start().await.unwrap();

        let mut receivers = Vec::new();
        for i in 0..10 {
            let (_, rx) = processor.submit_task(i, 0).await.unwrap();
            receivers.push(rx);
        }

        for mut rx in receivers {
            let result = rx.recv().await.unwrap();
            assert_eq!(result.status, TaskStatus::Completed);
            assert!(result.error.is_none());
            assert!(result.output.is_some());
        }

        processor.shutdown().await.unwrap();
    }

    #[tokio::test]
    async fn test_processor_shutdown() {
        let config = ProcessorConfig {
            max_queue_size: 100,
            max_concurrent_tasks: 5,
            task_timeout: Duration::from_secs(1),
            max_retries: 3,
            retry_delay: Duration::from_secs(1),
            cleanup_timeout: Duration::from_secs(10),
        };
        let mut processor = AsyncProcessor::new(config, TestProcessor);
        processor.start().await.unwrap();

        let mut receivers = Vec::new();
        for i in 0..5 {
            let (_, rx) = processor.submit_task(i, 0).await.unwrap();
            receivers.push(rx);
        }

        // Graceful shutdown drains in-flight tasks, so every task should
        // report a result.
        processor.shutdown().await.unwrap();

        for mut rx in receivers {
            let result = rx.recv().await.unwrap();
            assert_eq!(result.status, TaskStatus::Completed);
        }
    }

    #[tokio::test]
    async fn test_processor_auto_shutdown() {
        let config = ProcessorConfig {
            max_queue_size: 100,
            max_concurrent_tasks: 5,
            task_timeout: Duration::from_secs(1),
            max_retries: 3,
            retry_delay: Duration::from_secs(1),
            cleanup_timeout: Duration::from_secs(10),
        };
        let mut processor = AsyncProcessor::new(config, TestProcessor);
        processor.start().await.unwrap();

        let mut receivers = Vec::new();
        for i in 0..5 {
            let (_, rx) = processor.submit_task(i, 0).await.unwrap();
            receivers.push(rx);
        }

        // Dropping the processor signals shutdown and aborts the loop, so
        // tasks may either complete or be cancelled depending on timing.
        drop(processor);

        for mut rx in receivers {
            let result = rx.recv().await.unwrap();
            assert!(matches!(
                result.status,
                TaskStatus::Completed | TaskStatus::Cancelled
            ));
        }
    }
}