1use std::{
2 fmt::Display,
3 sync::Arc,
4 time::{Duration, Instant},
5};
6use crate::{error::error_utils, config::ProcessorConfig, debug::debug};
7use tokio::sync::{mpsc, oneshot};
8use async_trait::async_trait;
9use tokio::select;
10
11use crate::{metrics, ForgeResult};
12
13type QueueReceiver<T, O> =
15 Arc<tokio::sync::Mutex<Option<mpsc::Receiver<QueuedTask<T, O>>>>>;
16
/// Lifecycle status reported for a task in its `TaskResult`.
///
/// `AsyncProcessor` in this file only produces `Completed`, `Failed`, `Timeout`
/// and `Cancelled`; `Pending` and `Processing` are defined but not assigned there.
#[derive(Clone, Debug, PartialEq)]
pub enum TaskStatus {
    /// Waiting to be processed.
    Pending,
    /// Currently being processed.
    Processing,
    /// Finished successfully.
    Completed,
    /// Failed after the allowed retries; carries the error message.
    Failed(String),
    /// An attempt exceeded the per-attempt timeout.
    Timeout,
    /// Cancelled, e.g. because the processor was shutting down.
    Cancelled,
}
33
34impl From<&TaskStatus> for &'static str {
35 fn from(status: &TaskStatus) -> Self {
36 match status {
37 TaskStatus::Pending => "pending",
38 TaskStatus::Processing => "processing",
39 TaskStatus::Completed => "completed",
40 TaskStatus::Failed(_) => "failed",
41 TaskStatus::Timeout => "timeout",
42 TaskStatus::Cancelled => "cancelled",
43 }
44 }
45}
46
/// Errors produced by a `TaskProcessor` implementation or by the processor's own
/// lifecycle methods (`start` / `shutdown`).
#[derive(Debug)]
pub enum ProcessorError {
    /// The task queue is full.
    QueueFull,
    /// Processing a task failed; carries a description.
    TaskFailed(String),
    /// Internal or lifecycle error (e.g. starting twice, shutting down before start).
    InternalError(String),
    /// A task exceeded its time limit.
    TaskTimeout,
    /// A task was cancelled.
    TaskCancelled,
    /// All retry attempts were used up; carries a description.
    RetryExhausted(String),
}
63
64impl Display for ProcessorError {
65 fn fmt(
66 &self,
67 f: &mut std::fmt::Formatter<'_>,
68 ) -> std::fmt::Result {
69 match self {
70 ProcessorError::QueueFull => write!(f, "任务队列已满"),
71 ProcessorError::TaskFailed(msg) => {
72 write!(f, "任务执行失败: {msg}")
73 },
74 ProcessorError::InternalError(msg) => {
75 write!(f, "内部错误: {msg}")
76 },
77 ProcessorError::TaskTimeout => {
78 write!(f, "任务执行超时")
79 },
80 ProcessorError::TaskCancelled => write!(f, "任务被取消"),
81 ProcessorError::RetryExhausted(msg) => {
82 write!(f, "重试次数耗尽: {msg}")
83 },
84 }
85 }
86}
87
88impl std::error::Error for ProcessorError {}
89
/// Snapshot of the queue/processing counters, returned by `get_stats`.
#[derive(Clone, Debug, Default)]
pub struct ProcessorStats {
    /// Tasks successfully enqueued.
    pub total_tasks: u64,
    /// Tasks whose result had status `Completed`.
    pub completed_tasks: u64,
    /// Tasks whose result had status `Failed`.
    pub failed_tasks: u64,
    /// Tasks whose result had status `Timeout`.
    pub timeout_tasks: u64,
    /// Tasks whose result had status `Cancelled`.
    pub cancelled_tasks: u64,
    /// Tasks enqueued but not yet taken off the queue.
    pub current_queue_size: usize,
    /// Tasks taken off the queue whose result has not been recorded yet.
    pub current_processing_tasks: usize,
}
110
111#[derive(Debug)]
119pub struct TaskResult<T, O>
120where
121 T: Send + Sync,
122 O: Send + Sync,
123{
124 pub task_id: u64,
125 pub status: TaskStatus,
126 pub task: Option<T>,
127 pub output: Option<O>,
128 pub error: Option<String>,
129 pub processing_time: Option<Duration>,
130}
131
132struct QueuedTask<T, O>
139where
140 T: Send + Sync,
141 O: Send + Sync,
142{
143 task: T,
144 task_id: u64,
145 result_tx: mpsc::Sender<TaskResult<T, O>>,
146 priority: u32,
147 retry_count: u32,
148}
149
150pub struct TaskQueue<T, O>
156where
157 T: Send + Sync,
158 O: Send + Sync,
159{
160 queue: mpsc::Sender<QueuedTask<T, O>>,
161 queue_rx: QueueReceiver<T, O>,
162 next_task_id: Arc<tokio::sync::Mutex<u64>>,
163 stats: Arc<tokio::sync::Mutex<ProcessorStats>>,
164}
165
166impl<T: Clone + Send + Sync + 'static, O: Clone + Send + Sync + 'static>
167 TaskQueue<T, O>
168{
169 pub fn new(config: &ProcessorConfig) -> Self {
170 let (tx, rx) = mpsc::channel(config.max_queue_size);
171 Self {
172 queue: tx,
173 queue_rx: Arc::new(tokio::sync::Mutex::new(Some(rx))),
174 next_task_id: Arc::new(tokio::sync::Mutex::new(0)),
175 stats: Arc::new(tokio::sync::Mutex::new(ProcessorStats::default())),
176 }
177 }
178
179 pub async fn enqueue_task(
180 &self,
181 task: T,
182 priority: u32,
183 ) -> ForgeResult<(u64, mpsc::Receiver<TaskResult<T, O>>)> {
184 let mut task_id = self.next_task_id.lock().await;
185 *task_id += 1;
186 let current_id = *task_id;
187
188 let (result_tx, result_rx) = mpsc::channel(1);
189 let queued_task = QueuedTask {
190 task,
191 task_id: current_id,
192 result_tx,
193 priority,
194 retry_count: 0,
195 };
196
197 self.queue
198 .send(queued_task)
199 .await
200 .map_err(|_| error_utils::resource_exhausted_error("任务队列"))?;
201
202 let mut stats = self.stats.lock().await;
203 stats.total_tasks += 1;
204 stats.current_queue_size += 1;
205
206 metrics::task_submitted();
207 metrics::set_queue_size(stats.current_queue_size);
208
209 Ok((current_id, result_rx))
210 }
211
212 pub async fn get_next_ready(
213 &self
214 ) -> Option<(T, u64, mpsc::Sender<TaskResult<T, O>>, u32, u32)> {
215 let mut rx_guard = self.queue_rx.lock().await;
216 if let Some(rx) = rx_guard.as_mut() {
217 if let Some(queued) = rx.recv().await {
218 let mut stats: tokio::sync::MutexGuard<'_, ProcessorStats> =
219 self.stats.lock().await;
220 stats.current_queue_size -= 1;
221 stats.current_processing_tasks += 1;
222 metrics::set_queue_size(stats.current_queue_size);
223 metrics::increment_processing_tasks();
224 return Some((
225 queued.task,
226 queued.task_id,
227 queued.result_tx,
228 queued.priority,
229 queued.retry_count,
230 ));
231 }
232 }
233 None
234 }
235
236 pub async fn get_stats(&self) -> ProcessorStats {
237 self.stats.lock().await.clone()
238 }
239
240 pub async fn update_stats(
241 &self,
242 result: &TaskResult<T, O>,
243 ) {
244 let mut stats = self.stats.lock().await;
245 stats.current_processing_tasks -= 1;
246 metrics::decrement_processing_tasks();
247
248 let status_str: &'static str = (&result.status).into();
249 metrics::task_processed(status_str);
250
251 if let Some(duration) = result.processing_time {
252 metrics::task_processing_duration(duration);
253 }
254
255 match result.status {
256 TaskStatus::Completed => {
257 stats.completed_tasks += 1;
258 },
259 TaskStatus::Failed(_) => stats.failed_tasks += 1,
260 TaskStatus::Timeout => stats.timeout_tasks += 1,
261 TaskStatus::Cancelled => stats.cancelled_tasks += 1,
262 _ => {},
263 }
264 }
265}
266
267#[async_trait]
270pub trait TaskProcessor<T, O>: Send + Sync + 'static
271where
272 T: Clone + Send + Sync + 'static,
273 O: Clone + Send + Sync + 'static,
274{
275 async fn process(
276 &self,
277 task: T,
278 ) -> Result<O, ProcessorError>;
279}
280
/// Lifecycle of an `AsyncProcessor`: `NotStarted` -> `Running` -> `Shutting` -> `Shutdown`.
#[derive(Clone, Debug, PartialEq)]
pub enum ProcessorState {
    /// Created, but `start` has not been called yet.
    NotStarted,
    /// The background worker is dispatching tasks.
    Running,
    /// Shutdown was requested; the worker is winding down.
    Shutting,
    /// The background worker has finished.
    Shutdown,
}
293
294pub struct AsyncProcessor<T, O, P>
300where
301 T: Clone + Send + Sync + 'static,
302 O: Clone + Send + Sync + 'static,
303 P: TaskProcessor<T, O>,
304{
305 task_queue: Arc<TaskQueue<T, O>>,
306 config: ProcessorConfig,
307 processor: Arc<P>,
308 shutdown_tx: Option<oneshot::Sender<()>>,
309 handle: Option<tokio::task::JoinHandle<()>>,
310 state: Arc<tokio::sync::Mutex<ProcessorState>>,
311}
312
313impl<T, O, P> AsyncProcessor<T, O, P>
314where
315 T: Clone + Send + Sync + 'static,
316 O: Clone + Send + Sync + 'static,
317 P: TaskProcessor<T, O>,
318{
319 pub fn new(
321 config: ProcessorConfig,
322 processor: P,
323 ) -> Self {
324 let task_queue = Arc::new(TaskQueue::new(&config));
325 Self {
326 task_queue,
327 config,
328 processor: Arc::new(processor),
329 shutdown_tx: None,
330 handle: None,
331 state: Arc::new(tokio::sync::Mutex::new(
332 ProcessorState::NotStarted,
333 )),
334 }
335 }
336
337 pub async fn submit_task(
340 &self,
341 task: T,
342 priority: u32,
343 ) -> ForgeResult<(u64, mpsc::Receiver<TaskResult<T, O>>)> {
344 self.task_queue.enqueue_task(task, priority).await
345 }
346
347 pub async fn start(&mut self) -> Result<(), ProcessorError> {
350 let mut state = self.state.lock().await;
351 if *state != ProcessorState::NotStarted {
352 return Err(ProcessorError::InternalError(
353 "处理器已经启动或正在关闭".to_string(),
354 ));
355 }
356 *state = ProcessorState::Running;
357 drop(state);
358
359 let queue = self.task_queue.clone();
360 let processor = self.processor.clone();
361 let config = self.config.clone();
362 let state_ref = self.state.clone();
363 let (shutdown_tx, mut shutdown_rx) = oneshot::channel();
364
365 self.shutdown_tx = Some(shutdown_tx);
366
367 let handle = tokio::spawn(async move {
368 let mut join_set = tokio::task::JoinSet::new();
369
370 async fn cleanup_tasks(
372 join_set: &mut tokio::task::JoinSet<()>,
373 timeout: Duration,
374 ) {
375 debug!("开始清理正在运行的任务...");
376
377 let cleanup_start = Instant::now();
379 while !join_set.is_empty() {
380 if cleanup_start.elapsed() > timeout {
381 debug!("清理超时,强制中止剩余任务");
382 join_set.abort_all();
383 break;
384 }
385
386 if let Some(Err(e)) = join_set.join_next().await {
387 if !e.is_cancelled() {
388 debug!("任务执行失败: {}", e);
389 }
390 }
391 }
392 debug!("任务清理完成");
393 }
394
395 loop {
396 select! {
397 _ = &mut shutdown_rx => {
399 debug!("收到关闭信号,开始优雅关闭");
400 {
402 let mut state = state_ref.lock().await;
403 *state = ProcessorState::Shutting;
404 }
405
406 cleanup_tasks(&mut join_set, Duration::from_secs(30)).await;
408 break;
409 }
410
411 Some(result) = join_set.join_next() => {
413 if let Err(e) = result {
414 if !e.is_cancelled() {
415 debug!("任务执行失败: {}", e);
416 }
417 }
418 }
419
420 Some((task, task_id, result_tx, _priority, retry_count)) = queue.get_next_ready() => {
422 {
424 let state = state_ref.lock().await;
425 if *state != ProcessorState::Running {
426 let task_result = TaskResult {
428 task_id,
429 status: TaskStatus::Cancelled,
430 task: Some(task),
431 output: None,
432 error: Some("处理器正在关闭".to_string()),
433 processing_time: Some(Duration::from_millis(0)),
434 };
435 queue.update_stats(&task_result).await;
436 let _ = result_tx.send(task_result).await;
437 continue;
438 }
439 }
440
441 if join_set.len() < config.max_concurrent_tasks {
442 let processor = processor.clone();
443 let config = config.clone();
444 let queue = queue.clone();
445
446 join_set.spawn(async move {
447 let start_time = Instant::now();
448 let mut current_retry = retry_count;
449
450 loop {
451 let result = tokio::time::timeout(
452 config.task_timeout,
453 processor.process(task.clone())
454 ).await;
455
456 match result {
457 Ok(Ok(output)) => {
458 let processing_time = start_time.elapsed();
459 let task_result = TaskResult {
460 task_id,
461 status: TaskStatus::Completed,
462 task: Some(task),
463 output: Some(output),
464 error: None,
465 processing_time: Some(processing_time),
466 };
467 queue.update_stats(&task_result).await;
468 let _ = result_tx.send(task_result).await;
469 break;
470 }
471 Ok(Err(e)) => {
472 if current_retry < config.max_retries {
473 current_retry += 1;
474 tokio::time::sleep(config.retry_delay).await;
475 continue;
476 }
477 let task_result = TaskResult {
478 task_id,
479 status: TaskStatus::Failed(e.to_string()),
480 task: Some(task),
481 output: None,
482 error: Some(e.to_string()),
483 processing_time: Some(start_time.elapsed()),
484 };
485 queue.update_stats(&task_result).await;
486 let _ = result_tx.send(task_result).await;
487 break;
488 }
489 Err(_) => {
490 let task_result = TaskResult {
491 task_id,
492 status: TaskStatus::Timeout,
493 task: Some(task),
494 output: None,
495 error: Some("任务执行超时".to_string()),
496 processing_time: Some(start_time.elapsed()),
497 };
498 queue.update_stats(&task_result).await;
499 let _ = result_tx.send(task_result).await;
500 break;
501 }
502 }
503 }
504 });
505 }
506 }
507 }
508 }
509
510 {
512 let mut state = state_ref.lock().await;
513 *state = ProcessorState::Shutdown;
514 }
515 debug!("异步处理器已完全关闭");
516 });
517
518 self.handle = Some(handle);
519 Ok(())
520 }
521
522 pub async fn shutdown(&mut self) -> Result<(), ProcessorError> {
525 {
527 let mut state = self.state.lock().await;
528 match *state {
529 ProcessorState::NotStarted => {
530 return Err(ProcessorError::InternalError(
531 "处理器尚未启动".to_string(),
532 ));
533 },
534 ProcessorState::Shutdown => {
535 return Ok(()); },
537 ProcessorState::Shutting => {
538 },
540 ProcessorState::Running => {
541 *state = ProcessorState::Shutting;
542 },
543 }
544 }
545
546 if let Some(shutdown_tx) = self.shutdown_tx.take() {
548 shutdown_tx.send(()).map_err(|_| {
549 ProcessorError::InternalError(
550 "Failed to send shutdown signal".to_string(),
551 )
552 })?;
553 }
554
555 if let Some(handle) = self.handle.take() {
557 if let Err(e) = handle.await {
558 return Err(ProcessorError::InternalError(format!(
559 "等待后台任务完成时出错: {e}"
560 )));
561 }
562 }
563
564 {
566 let state = self.state.lock().await;
567 if *state != ProcessorState::Shutdown {
568 return Err(ProcessorError::InternalError(
569 "关闭过程未正确完成".to_string(),
570 ));
571 }
572 }
573
574 debug!("异步处理器已成功关闭");
575 Ok(())
576 }
577
578 pub async fn get_state(&self) -> ProcessorState {
580 let state = self.state.lock().await;
581 state.clone()
582 }
583
584 pub async fn is_running(&self) -> bool {
586 let state = self.state.lock().await;
587 *state == ProcessorState::Running
588 }
589
590 pub async fn get_stats(&self) -> ProcessorStats {
591 self.task_queue.get_stats().await
592 }
593}
594
595impl<T, O, P> Drop for AsyncProcessor<T, O, P>
600where
601 T: Clone + Send + Sync + 'static,
602 O: Clone + Send + Sync + 'static,
603 P: TaskProcessor<T, O>,
604{
605 fn drop(&mut self) {
606 if let Some(shutdown_tx) = self.shutdown_tx.take() {
608 let _ = shutdown_tx.send(());
609 debug!("AsyncProcessor Drop: 已发送关闭信号");
610 }
611
612 if let Some(handle) = self.handle.take() {
614 handle.abort();
615 debug!("AsyncProcessor Drop: 已中止后台任务");
616 }
617 }
618}
619
#[cfg(test)]
mod tests {
    use super::*;

    /// Processor that sleeps briefly and echoes its input.
    struct TestProcessor;

    #[async_trait::async_trait]
    impl TaskProcessor<i32, String> for TestProcessor {
        async fn process(&self, task: i32) -> Result<String, ProcessorError> {
            tokio::time::sleep(Duration::from_millis(100)).await;
            Ok(format!("Processed: {task}"))
        }
    }

    /// Shared configuration. It has 5 concurrency slots, so submitting 10 tasks
    /// exercises the capacity limit.
    fn test_config() -> ProcessorConfig {
        ProcessorConfig {
            max_queue_size: 100,
            max_concurrent_tasks: 5,
            task_timeout: Duration::from_secs(1),
            max_retries: 3,
            retry_delay: Duration::from_secs(1),
            cleanup_timeout: Duration::from_secs(10),
        }
    }

    #[tokio::test]
    async fn test_async_processor() {
        let mut processor = AsyncProcessor::new(test_config(), TestProcessor);
        processor.start().await.unwrap();

        let mut receivers = Vec::new();
        for i in 0..10 {
            let (_, rx) = processor.submit_task(i, 0).await.unwrap();
            receivers.push(rx);
        }

        // There are more tasks than concurrency slots, and every one must still complete.
        for mut rx in receivers {
            let result = rx.recv().await.unwrap();
            assert_eq!(result.status, TaskStatus::Completed);
            assert!(result.error.is_none());
            assert!(result.output.is_some());
        }

        // Stats are updated before each result is sent, so they are final here.
        let stats = processor.get_stats().await;
        assert_eq!(stats.total_tasks, 10);
        assert_eq!(stats.completed_tasks, 10);
        assert_eq!(stats.current_queue_size, 0);
        assert_eq!(stats.current_processing_tasks, 0);

        processor.shutdown().await.unwrap();
    }

    #[tokio::test]
    async fn test_processor_shutdown() {
        let mut processor = AsyncProcessor::new(test_config(), TestProcessor);
        processor.start().await.unwrap();

        let mut receivers = Vec::new();
        for i in 0..5 {
            let (_, rx) = processor.submit_task(i, 0).await.unwrap();
            receivers.push(rx);
        }

        processor.shutdown().await.unwrap();
        assert_eq!(processor.get_state().await, ProcessorState::Shutdown);

        // `shutdown` moves the state to `Shutting` before signalling the worker. Tasks
        // that were not dispatched yet are therefore reported as cancelled instead of
        // being run. Either way, every submitter must receive a result.
        for mut rx in receivers {
            let result = rx.recv().await.unwrap();
            assert!(matches!(
                result.status,
                TaskStatus::Completed | TaskStatus::Cancelled
            ));
        }
    }

    #[tokio::test]
    async fn test_processor_auto_shutdown() {
        let mut processor = AsyncProcessor::new(test_config(), TestProcessor);
        processor.start().await.unwrap();

        let mut receivers = Vec::new();
        for i in 0..5 {
            let (_, rx) = processor.submit_task(i, 0).await.unwrap();
            receivers.push(rx);
        }

        drop(processor);

        // Dropping must still deliver a result for every submitted task.
        for mut rx in receivers {
            let result = rx.recv().await.unwrap();
            assert!(matches!(
                result.status,
                TaskStatus::Completed | TaskStatus::Cancelled
            ));
        }
    }
}