1use std::{
2 fmt::Display,
3 sync::Arc,
4 time::{Duration, Instant},
5};
6use anyhow::anyhow;
7use moduforge_state::debug;
8use tokio::sync::{mpsc, oneshot};
9use async_trait::async_trait;
10use tokio::select;
11
12use crate::{metrics, ForgeResult};
13
/// Lifecycle state of a submitted task.
///
/// `Failed` carries the error message produced for the task.
#[derive(Clone, Debug, PartialEq)]
pub enum TaskStatus {
    /// Not yet picked up.
    Pending,
    /// Currently being executed.
    Processing,
    /// Finished successfully.
    Completed,
    /// Finished with an error; the payload is the error text.
    Failed(String),
    /// An attempt exceeded the configured timeout.
    Timeout,
    /// The task was cancelled.
    Cancelled,
}
30
31impl From<&TaskStatus> for &'static str {
32 fn from(status: &TaskStatus) -> Self {
33 match status {
34 TaskStatus::Pending => "pending",
35 TaskStatus::Processing => "processing",
36 TaskStatus::Completed => "completed",
37 TaskStatus::Failed(_) => "failed",
38 TaskStatus::Timeout => "timeout",
39 TaskStatus::Cancelled => "cancelled",
40 }
41 }
42}
43
/// Errors produced by task processors and by the processing machinery.
#[derive(Debug)]
pub enum ProcessorError {
    /// The task queue has no room left.
    QueueFull,
    /// A processor reported a failure; the payload is its message.
    TaskFailed(String),
    /// An internal failure, e.g. the shutdown signal could not be delivered.
    InternalError(String),
    /// A task did not finish within its time limit.
    TaskTimeout,
    /// A task was cancelled.
    TaskCancelled,
    /// All retry attempts were used up; the payload describes the last failure.
    RetryExhausted(String),
}
60
61impl Display for ProcessorError {
62 fn fmt(
63 &self,
64 f: &mut std::fmt::Formatter<'_>,
65 ) -> std::fmt::Result {
66 match self {
67 ProcessorError::QueueFull => write!(f, "任务队列已满"),
68 ProcessorError::TaskFailed(msg) => {
69 write!(f, "任务执行失败: {}", msg)
70 },
71 ProcessorError::InternalError(msg) => {
72 write!(f, "内部错误: {}", msg)
73 },
74 ProcessorError::TaskTimeout => {
75 write!(f, "任务执行超时")
76 },
77 ProcessorError::TaskCancelled => write!(f, "任务被取消"),
78 ProcessorError::RetryExhausted(msg) => {
79 write!(f, "重试次数耗尽: {}", msg)
80 },
81 }
82 }
83}
84
85impl std::error::Error for ProcessorError {}
86
/// Tuning knobs for [`AsyncProcessor`] and its [`TaskQueue`].
#[derive(Debug, Clone)]
pub struct ProcessorConfig {
    /// Capacity of the bounded task channel; submissions wait while it is full.
    pub max_queue_size: usize,
    /// Upper bound on the number of tasks executing at the same time.
    pub max_concurrent_tasks: usize,
    /// Time limit applied to each individual processing attempt.
    pub task_timeout: Duration,
    /// How many times a task is retried after the processor returns an error.
    pub max_retries: u32,
    /// Pause between a failed attempt and the next retry.
    pub retry_delay: Duration,
}
101
102impl Default for ProcessorConfig {
103 fn default() -> Self {
104 Self {
105 max_queue_size: 1000,
106 max_concurrent_tasks: 10,
107 task_timeout: Duration::from_secs(30),
108 max_retries: 3,
109 retry_delay: Duration::from_secs(1),
110 }
111 }
112}
113
/// Snapshot of the queue's counters, as returned by `get_stats`.
#[derive(Clone, Debug, Default)]
pub struct ProcessorStats {
    /// Tasks accepted into the queue since creation.
    pub total_tasks: u64,
    /// Tasks whose final status was `Completed`.
    pub completed_tasks: u64,
    /// Tasks whose final status was `Failed`.
    pub failed_tasks: u64,
    /// Tasks whose final status was `Timeout`.
    pub timeout_tasks: u64,
    /// Tasks whose final status was `Cancelled`.
    pub cancelled_tasks: u64,
    /// Tasks enqueued but not yet handed to a worker.
    pub current_queue_size: usize,
    /// Tasks handed to a worker whose outcome has not been recorded yet.
    pub current_processing_tasks: usize,
}
132
133#[derive(Debug)]
141pub struct TaskResult<T, O>
142where
143 T: Send + Sync,
144 O: Send + Sync,
145{
146 pub task_id: u64,
147 pub status: TaskStatus,
148 pub task: Option<T>,
149 pub output: Option<O>,
150 pub error: Option<String>,
151 pub processing_time: Option<Duration>,
152}
153
154struct QueuedTask<T, O>
161where
162 T: Send + Sync,
163 O: Send + Sync,
164{
165 task: T,
166 task_id: u64,
167 result_tx: mpsc::Sender<TaskResult<T, O>>,
168 priority: u32,
169 retry_count: u32,
170}
171
172pub struct TaskQueue<T, O>
178where
179 T: Send + Sync,
180 O: Send + Sync,
181{
182 queue: mpsc::Sender<QueuedTask<T, O>>,
183 queue_rx: Arc<tokio::sync::Mutex<Option<mpsc::Receiver<QueuedTask<T, O>>>>>,
184 next_task_id: Arc<tokio::sync::Mutex<u64>>,
185 stats: Arc<tokio::sync::Mutex<ProcessorStats>>,
186}
187
188impl<T: Clone + Send + Sync + 'static, O: Clone + Send + Sync + 'static>
189 TaskQueue<T, O>
190{
191 pub fn new(config: &ProcessorConfig) -> Self {
192 let (tx, rx) = mpsc::channel(config.max_queue_size);
193 Self {
194 queue: tx,
195 queue_rx: Arc::new(tokio::sync::Mutex::new(Some(rx))),
196 next_task_id: Arc::new(tokio::sync::Mutex::new(0)),
197 stats: Arc::new(tokio::sync::Mutex::new(ProcessorStats::default())),
198 }
199 }
200
201 pub async fn enqueue_task(
202 &self,
203 task: T,
204 priority: u32,
205 ) -> ForgeResult<(u64, mpsc::Receiver<TaskResult<T, O>>)> {
206 let mut task_id = self.next_task_id.lock().await;
207 *task_id += 1;
208 let current_id = *task_id;
209
210 let (result_tx, result_rx) = mpsc::channel(1);
211 let queued_task = QueuedTask {
212 task,
213 task_id: current_id,
214 result_tx,
215 priority,
216 retry_count: 0,
217 };
218
219 self.queue
220 .send(queued_task)
221 .await
222 .map_err(|_| anyhow!("队列已经满了"))?;
223
224 let mut stats = self.stats.lock().await;
225 stats.total_tasks += 1;
226 stats.current_queue_size += 1;
227
228 metrics::task_submitted();
229 metrics::set_queue_size(stats.current_queue_size);
230
231 Ok((current_id, result_rx))
232 }
233
234 pub async fn get_next_ready(
235 &self
236 ) -> Option<(T, u64, mpsc::Sender<TaskResult<T, O>>, u32, u32)> {
237 let mut rx_guard = self.queue_rx.lock().await;
238 if let Some(rx) = rx_guard.as_mut() {
239 if let Some(queued) = rx.recv().await {
240 let mut stats: tokio::sync::MutexGuard<'_, ProcessorStats> =
241 self.stats.lock().await;
242 stats.current_queue_size -= 1;
243 stats.current_processing_tasks += 1;
244 metrics::set_queue_size(stats.current_queue_size);
245 metrics::increment_processing_tasks();
246 return Some((
247 queued.task,
248 queued.task_id,
249 queued.result_tx,
250 queued.priority,
251 queued.retry_count,
252 ));
253 }
254 }
255 None
256 }
257
258 pub async fn get_stats(&self) -> ProcessorStats {
259 self.stats.lock().await.clone()
260 }
261
262 pub async fn update_stats(
263 &self,
264 result: &TaskResult<T, O>,
265 ) {
266 let mut stats = self.stats.lock().await;
267 stats.current_processing_tasks -= 1;
268 metrics::decrement_processing_tasks();
269
270 let status_str: &'static str = (&result.status).into();
271 metrics::task_processed(status_str);
272
273 if let Some(duration) = result.processing_time {
274 metrics::task_processing_duration(duration);
275 }
276
277 match result.status {
278 TaskStatus::Completed => {
279 stats.completed_tasks += 1;
280 }
281 TaskStatus::Failed(_) => stats.failed_tasks += 1,
282 TaskStatus::Timeout => stats.timeout_tasks += 1,
283 TaskStatus::Cancelled => stats.cancelled_tasks += 1,
284 _ => {}
285 }
286 }
287}
288
289#[async_trait]
292pub trait TaskProcessor<T, O>: Send + Sync + 'static
293where
294 T: Clone + Send + Sync + 'static,
295 O: Clone + Send + Sync + 'static,
296{
297 async fn process(
298 &self,
299 task: T,
300 ) -> Result<O, ProcessorError>;
301}
302
303pub struct AsyncProcessor<T, O, P>
309where
310 T: Clone + Send + Sync + 'static,
311 O: Clone + Send + Sync + 'static,
312 P: TaskProcessor<T, O>,
313{
314 task_queue: Arc<TaskQueue<T, O>>,
315 config: ProcessorConfig,
316 processor: Arc<P>,
317 shutdown_tx: Option<oneshot::Sender<()>>,
318 handle: Option<tokio::task::JoinHandle<()>>,
319}
320
321impl<T, O, P> AsyncProcessor<T, O, P>
322where
323 T: Clone + Send + Sync + 'static,
324 O: Clone + Send + Sync + 'static,
325 P: TaskProcessor<T, O>,
326{
327 pub fn new(
329 config: ProcessorConfig,
330 processor: P,
331 ) -> Self {
332 let task_queue = Arc::new(TaskQueue::new(&config));
333 Self {
334 task_queue,
335 config,
336 processor: Arc::new(processor),
337 shutdown_tx: None,
338 handle: None,
339 }
340 }
341
342 pub async fn submit_task(
345 &self,
346 task: T,
347 priority: u32,
348 ) -> ForgeResult<(u64, mpsc::Receiver<TaskResult<T, O>>)> {
349 self.task_queue.enqueue_task(task, priority).await
350 }
351
352 pub fn start(&mut self) {
355 let queue = self.task_queue.clone();
356 let processor = self.processor.clone();
357 let config = self.config.clone();
358 let (shutdown_tx, mut shutdown_rx) = oneshot::channel();
359
360 self.shutdown_tx = Some(shutdown_tx);
361
362 let handle = tokio::spawn(async move {
363 let mut join_set = tokio::task::JoinSet::new();
364
365 loop {
366 select! {
367 _ = &mut shutdown_rx => {
369 break;
370 }
371
372 Some(result) = join_set.join_next() => {
374 if let Err(e) = result {
375 debug!("任务执行失败: {}", e);
376 }
377 }
378
379 Some((task, task_id, result_tx, _priority, retry_count)) = queue.get_next_ready() => {
381 if join_set.len() < config.max_concurrent_tasks {
382 let processor = processor.clone();
383 let config = config.clone();
384 let queue = queue.clone();
385
386 join_set.spawn(async move {
387 let start_time = Instant::now();
388 let mut current_retry = retry_count;
389
390 loop {
391 let result = tokio::time::timeout(
392 config.task_timeout,
393 processor.process(task.clone())
394 ).await;
395
396 match result {
397 Ok(Ok(output)) => {
398 let processing_time = start_time.elapsed();
399 let task_result = TaskResult {
400 task_id,
401 status: TaskStatus::Completed,
402 task: Some(task),
403 output: Some(output),
404 error: None,
405 processing_time: Some(processing_time),
406 };
407 queue.update_stats(&task_result).await;
408 let _ = result_tx.send(task_result).await;
409 break;
410 }
411 Ok(Err(e)) => {
412 if current_retry < config.max_retries {
413 current_retry += 1;
414 tokio::time::sleep(config.retry_delay).await;
415 continue;
416 }
417 let task_result = TaskResult {
418 task_id,
419 status: TaskStatus::Failed(e.to_string()),
420 task: Some(task),
421 output: None,
422 error: Some(e.to_string()),
423 processing_time: Some(start_time.elapsed()),
424 };
425 queue.update_stats(&task_result).await;
426 let _ = result_tx.send(task_result).await;
427 break;
428 }
429 Err(_) => {
430 let task_result = TaskResult {
431 task_id,
432 status: TaskStatus::Timeout,
433 task: Some(task),
434 output: None,
435 error: Some("任务执行超时".to_string()),
436 processing_time: Some(start_time.elapsed()),
437 };
438 queue.update_stats(&task_result).await;
439 let _ = result_tx.send(task_result).await;
440 break;
441 }
442 }
443 }
444 });
445 }
446 }
447 }
448 }
449 });
450
451 self.handle = Some(handle);
452 }
453
454 pub fn shutdown(&mut self) -> Result<(), ProcessorError> {
457 if let Some(shutdown_tx) = self.shutdown_tx.take() {
458 shutdown_tx.send(()).map_err(|_| {
459 ProcessorError::InternalError(
460 "Failed to send shutdown signal".to_string(),
461 )
462 })?;
463 }
464 Ok(())
465 }
466
467 pub async fn get_stats(&self) -> ProcessorStats {
468 self.task_queue.get_stats().await
469 }
470}
471
472impl<T, O, P> Drop for AsyncProcessor<T, O, P>
474where
475 T: Clone + Send + Sync + 'static,
476 O: Clone + Send + Sync + 'static,
477 P: TaskProcessor<T, O>,
478{
479 fn drop(&mut self) {
480 if self.shutdown_tx.is_some() {
481 let _ = self.shutdown();
483 }
484 }
485}
486
#[cfg(test)]
mod tests {
    use super::*;

    /// Processor used by the tests: sleeps 100 ms, then returns
    /// `"Processed: {task}"`.
    struct TestProcessor;

    #[async_trait::async_trait]
    impl TaskProcessor<i32, String> for TestProcessor {
        async fn process(&self, task: i32) -> Result<String, ProcessorError> {
            tokio::time::sleep(Duration::from_millis(100)).await;
            Ok(format!("Processed: {}", task))
        }
    }

    type Receivers = Vec<mpsc::Receiver<TaskResult<i32, String>>>;

    /// Configuration shared by every test.
    fn test_config() -> ProcessorConfig {
        ProcessorConfig {
            max_queue_size: 100,
            max_concurrent_tasks: 5,
            task_timeout: Duration::from_secs(1),
            max_retries: 3,
            retry_delay: Duration::from_secs(1),
        }
    }

    /// Creates and starts a processor, then submits tasks `0..count` at
    /// priority 0. Returns the processor and the result receivers in
    /// submission order.
    async fn started_with_tasks(
        count: i32,
    ) -> (AsyncProcessor<i32, String, TestProcessor>, Receivers) {
        let mut processor = AsyncProcessor::new(test_config(), TestProcessor);
        processor.start();

        let mut receivers = Vec::new();
        for task in 0..count {
            let (_, rx) = processor.submit_task(task, 0).await.unwrap();
            receivers.push(rx);
        }
        (processor, receivers)
    }

    #[tokio::test]
    async fn test_async_processor() {
        // Keep the processor alive for the whole test (a bare `_` would drop it).
        let (_processor, receivers) = started_with_tasks(10).await;

        for mut rx in receivers {
            let result = rx.recv().await.unwrap();
            assert_eq!(result.status, TaskStatus::Completed);
            assert!(result.error.is_none());
            assert!(result.output.is_some());
        }
    }

    #[tokio::test]
    async fn test_processor_shutdown() {
        let (mut processor, receivers) = started_with_tasks(5).await;

        // Explicit shutdown: already-submitted tasks must still complete.
        processor.shutdown().unwrap();

        for mut rx in receivers {
            let result = rx.recv().await.unwrap();
            assert_eq!(result.status, TaskStatus::Completed);
        }
    }

    #[tokio::test]
    async fn test_processor_auto_shutdown() {
        let (processor, receivers) = started_with_tasks(5).await;

        // Dropping triggers shutdown through `Drop`; results must still arrive.
        drop(processor);

        for mut rx in receivers {
            let result = rx.recv().await.unwrap();
            assert_eq!(result.status, TaskStatus::Completed);
        }
    }
}