1use std::{
2 fmt::Display,
3 sync::Arc,
4 time::{Duration, Instant},
5};
6use anyhow::anyhow;
7use moduforge_state::debug;
8use tokio::sync::{mpsc, oneshot};
9use async_trait::async_trait;
10use tokio::select;
11
12use crate::EditorResult;
13
/// Lifecycle state reported for a task.
///
/// The processor shown in this module produces `Completed`, `Failed` and `Timeout`;
/// the remaining variants are counted by the statistics code but not produced here.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum TaskStatus {
    /// Task has not started yet.
    Pending,
    /// Task is being executed.
    Processing,
    /// Task finished successfully.
    Completed,
    /// Task failed; carries the error message of the final attempt.
    Failed(String),
    /// An attempt exceeded the configured timeout.
    Timeout,
    /// Task was cancelled.
    Cancelled,
}
30
/// Errors raised by task processors and by the processing pipeline.
#[derive(Debug)]
pub enum ProcessorError {
    /// The task queue has no free capacity.
    QueueFull,
    /// The task itself reported an error; carries its message.
    TaskFailed(String),
    /// An internal operation failed (e.g. signalling the worker); carries details.
    InternalError(String),
    /// The task exceeded its time budget.
    TaskTimeout,
    /// The task was cancelled.
    TaskCancelled,
    /// Every retry attempt failed; carries the last error message.
    RetryExhausted(String),
}

impl Display for ProcessorError {
    /// Renders a short, user-facing (Chinese) description of the error.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            // Variants without payload are fixed messages.
            Self::QueueFull => f.write_str("任务队列已满"),
            Self::TaskTimeout => f.write_str("任务执行超时"),
            Self::TaskCancelled => f.write_str("任务被取消"),
            // Variants with a payload append the detail after a colon.
            Self::TaskFailed(msg) => write!(f, "任务执行失败: {}", msg),
            Self::InternalError(msg) => write!(f, "内部错误: {}", msg),
            Self::RetryExhausted(msg) => write!(f, "重试次数耗尽: {}", msg),
        }
    }
}

impl std::error::Error for ProcessorError {}
73
/// Tuning parameters for the task queue and the asynchronous processor.
#[derive(Clone, Debug)]
pub struct ProcessorConfig {
    /// Capacity of the bounded task channel.
    pub max_queue_size: usize,
    /// Upper bound on tasks executing at the same time.
    pub max_concurrent_tasks: usize,
    /// Time budget for a single processing attempt.
    pub task_timeout: Duration,
    /// How many times a task that returned an error is retried (timeouts are not retried).
    pub max_retries: u32,
    /// Pause before each retry.
    pub retry_delay: Duration,
}

impl Default for ProcessorConfig {
    /// 1000-slot queue, 10 concurrent tasks, 30 s timeout, 3 retries spaced 1 s apart.
    fn default() -> Self {
        let second = Duration::from_secs(1);
        Self {
            max_queue_size: 1_000,
            max_concurrent_tasks: 10,
            task_timeout: second * 30,
            max_retries: 3,
            retry_delay: second,
        }
    }
}
100
/// Snapshot of the processing counters, as returned by `get_stats`.
#[derive(Clone, Debug, Default)]
pub struct ProcessorStats {
    /// Tasks accepted into the queue.
    pub total_tasks: u64,
    /// Tasks that finished successfully.
    pub completed_tasks: u64,
    /// Tasks reported as failed.
    pub failed_tasks: u64,
    /// Tasks reported as timed out.
    pub timeout_tasks: u64,
    /// Tasks reported as cancelled.
    pub cancelled_tasks: u64,
    /// Processing-time figure for completed tasks; see `TaskQueue::update_stats`
    /// for how it is maintained.
    pub average_processing_time: Duration,
    /// Tasks waiting in the queue.
    pub current_queue_size: usize,
    /// Tasks handed out for execution whose outcome has not been recorded yet.
    pub current_processing_tasks: usize,
}
121
122#[derive(Debug)]
130pub struct TaskResult<T, O>
131where
132 T: Send + Sync,
133 O: Send + Sync,
134{
135 pub task_id: u64,
136 pub status: TaskStatus,
137 pub task: Option<T>,
138 pub output: Option<O>,
139 pub error: Option<String>,
140 pub processing_time: Option<Duration>,
141}
142
143struct QueuedTask<T, O>
150where
151 T: Send + Sync,
152 O: Send + Sync,
153{
154 task: T,
155 task_id: u64,
156 result_tx: mpsc::Sender<TaskResult<T, O>>,
157 priority: u32,
158 retry_count: u32,
159}
160
161pub struct TaskQueue<T, O>
167where
168 T: Send + Sync,
169 O: Send + Sync,
170{
171 queue: mpsc::Sender<QueuedTask<T, O>>,
172 queue_rx: Arc<tokio::sync::Mutex<Option<mpsc::Receiver<QueuedTask<T, O>>>>>,
173 next_task_id: Arc<tokio::sync::Mutex<u64>>,
174 stats: Arc<tokio::sync::Mutex<ProcessorStats>>,
175}
176
177impl<T: Clone + Send + Sync + 'static, O: Clone + Send + Sync + 'static>
178 TaskQueue<T, O>
179{
180 pub fn new(config: &ProcessorConfig) -> Self {
181 let (tx, rx) = mpsc::channel(config.max_queue_size);
182 Self {
183 queue: tx,
184 queue_rx: Arc::new(tokio::sync::Mutex::new(Some(rx))),
185 next_task_id: Arc::new(tokio::sync::Mutex::new(0)),
186 stats: Arc::new(tokio::sync::Mutex::new(ProcessorStats::default())),
187 }
188 }
189
190 pub async fn enqueue_task(
191 &self,
192 task: T,
193 priority: u32,
194 ) -> EditorResult<(u64, mpsc::Receiver<TaskResult<T, O>>)> {
195 let mut task_id = self.next_task_id.lock().await;
196 *task_id += 1;
197 let current_id = *task_id;
198
199 let (result_tx, result_rx) = mpsc::channel(1);
200 let queued_task = QueuedTask {
201 task,
202 task_id: current_id,
203 result_tx,
204 priority,
205 retry_count: 0,
206 };
207
208 self.queue
209 .send(queued_task)
210 .await
211 .map_err(|_| anyhow!("队列已经满了"))?;
212
213 let mut stats = self.stats.lock().await;
214 stats.total_tasks += 1;
215 stats.current_queue_size += 1;
216
217 Ok((current_id, result_rx))
218 }
219
220 pub async fn get_next_ready(
221 &self
222 ) -> Option<(T, u64, mpsc::Sender<TaskResult<T, O>>, u32, u32)> {
223 let mut rx_guard = self.queue_rx.lock().await;
224 if let Some(rx) = rx_guard.as_mut() {
225 if let Some(queued) = rx.recv().await {
226 let mut stats: tokio::sync::MutexGuard<'_, ProcessorStats> =
227 self.stats.lock().await;
228 stats.current_queue_size -= 1;
229 stats.current_processing_tasks += 1;
230 return Some((
231 queued.task,
232 queued.task_id,
233 queued.result_tx,
234 queued.priority,
235 queued.retry_count,
236 ));
237 }
238 }
239 None
240 }
241
242 pub async fn get_stats(&self) -> ProcessorStats {
243 self.stats.lock().await.clone()
244 }
245
246 pub async fn update_stats(
247 &self,
248 result: &TaskResult<T, O>,
249 ) {
250 let mut stats = self.stats.lock().await;
251 match result.status {
252 TaskStatus::Completed => {
253 stats.completed_tasks += 1;
254 if let Some(processing_time) = result.processing_time {
255 stats.average_processing_time =
256 (stats.average_processing_time + processing_time) / 2;
257 }
258 },
259 TaskStatus::Failed(_) => stats.failed_tasks += 1,
260 TaskStatus::Timeout => stats.timeout_tasks += 1,
261 TaskStatus::Cancelled => stats.cancelled_tasks += 1,
262 _ => {},
263 }
264 stats.current_processing_tasks -= 1;
265 }
266}
267
268#[async_trait]
271pub trait TaskProcessor<T, O>: Send + Sync + 'static
272where
273 T: Clone + Send + Sync + 'static,
274 O: Clone + Send + Sync + 'static,
275{
276 async fn process(
277 &self,
278 task: T,
279 ) -> Result<O, ProcessorError>;
280}
281
282pub struct AsyncProcessor<T, O, P>
288where
289 T: Clone + Send + Sync + 'static,
290 O: Clone + Send + Sync + 'static,
291 P: TaskProcessor<T, O>,
292{
293 task_queue: Arc<TaskQueue<T, O>>,
294 config: ProcessorConfig,
295 processor: Arc<P>,
296 shutdown_tx: Option<oneshot::Sender<()>>,
297 handle: Option<tokio::task::JoinHandle<()>>,
298}
299
300impl<T, O, P> AsyncProcessor<T, O, P>
301where
302 T: Clone + Send + Sync + 'static,
303 O: Clone + Send + Sync + 'static,
304 P: TaskProcessor<T, O>,
305{
306 pub fn new(
308 config: ProcessorConfig,
309 processor: P,
310 ) -> Self {
311 let task_queue = Arc::new(TaskQueue::new(&config));
312 Self {
313 task_queue,
314 config,
315 processor: Arc::new(processor),
316 shutdown_tx: None,
317 handle: None,
318 }
319 }
320
321 pub async fn submit_task(
324 &self,
325 task: T,
326 priority: u32,
327 ) -> EditorResult<(u64, mpsc::Receiver<TaskResult<T, O>>)> {
328 self.task_queue.enqueue_task(task, priority).await
329 }
330
331 pub fn start(&mut self) {
334 let queue = self.task_queue.clone();
335 let processor = self.processor.clone();
336 let config = self.config.clone();
337 let (shutdown_tx, mut shutdown_rx) = oneshot::channel();
338
339 self.shutdown_tx = Some(shutdown_tx);
340
341 let handle = tokio::spawn(async move {
342 let mut join_set = tokio::task::JoinSet::new();
343
344 loop {
345 select! {
346 _ = &mut shutdown_rx => {
348 break;
349 }
350
351 Some(result) = join_set.join_next() => {
353 if let Err(e) = result {
354 debug!("任务执行失败: {}", e);
355 }
356 }
357
358 Some((task, task_id, result_tx, _priority, retry_count)) = queue.get_next_ready() => {
360 if join_set.len() < config.max_concurrent_tasks {
361 let processor = processor.clone();
362 let config = config.clone();
363 let queue = queue.clone();
364
365 join_set.spawn(async move {
366 let start_time = Instant::now();
367 let mut current_retry = retry_count;
368
369 loop {
370 let result = tokio::time::timeout(
371 config.task_timeout,
372 processor.process(task.clone())
373 ).await;
374
375 match result {
376 Ok(Ok(output)) => {
377 let processing_time = start_time.elapsed();
378 let task_result = TaskResult {
379 task_id,
380 status: TaskStatus::Completed,
381 task: Some(task),
382 output: Some(output),
383 error: None,
384 processing_time: Some(processing_time),
385 };
386 queue.update_stats(&task_result).await;
387 let _ = result_tx.send(task_result).await;
388 break;
389 }
390 Ok(Err(e)) => {
391 if current_retry < config.max_retries {
392 current_retry += 1;
393 tokio::time::sleep(config.retry_delay).await;
394 continue;
395 }
396 let task_result = TaskResult {
397 task_id,
398 status: TaskStatus::Failed(e.to_string()),
399 task: Some(task),
400 output: None,
401 error: Some(e.to_string()),
402 processing_time: Some(start_time.elapsed()),
403 };
404 queue.update_stats(&task_result).await;
405 let _ = result_tx.send(task_result).await;
406 break;
407 }
408 Err(_) => {
409 let task_result = TaskResult {
410 task_id,
411 status: TaskStatus::Timeout,
412 task: Some(task),
413 output: None,
414 error: Some("任务执行超时".to_string()),
415 processing_time: Some(start_time.elapsed()),
416 };
417 queue.update_stats(&task_result).await;
418 let _ = result_tx.send(task_result).await;
419 break;
420 }
421 }
422 }
423 });
424 }
425 }
426 }
427 }
428 });
429
430 self.handle = Some(handle);
431 }
432
433 pub fn shutdown(&mut self) -> Result<(), ProcessorError> {
436 if let Some(shutdown_tx) = self.shutdown_tx.take() {
437 shutdown_tx.send(()).map_err(|_| {
438 ProcessorError::InternalError(
439 "Failed to send shutdown signal".to_string(),
440 )
441 })?;
442 }
443 Ok(())
444 }
445
446 pub async fn get_stats(&self) -> ProcessorStats {
447 self.task_queue.get_stats().await
448 }
449}
450
451impl<T, O, P> Drop for AsyncProcessor<T, O, P>
453where
454 T: Clone + Send + Sync + 'static,
455 O: Clone + Send + Sync + 'static,
456 P: TaskProcessor<T, O>,
457{
458 fn drop(&mut self) {
459 if self.shutdown_tx.is_some() {
460 let _ = self.shutdown();
462 }
463 }
464}
465
#[cfg(test)]
mod tests {
    use super::*;

    /// Processor that sleeps briefly and echoes its input, so tasks overlap in time.
    struct TestProcessor;

    #[async_trait::async_trait]
    impl TaskProcessor<i32, String> for TestProcessor {
        async fn process(
            &self,
            task: i32,
        ) -> Result<String, ProcessorError> {
            tokio::time::sleep(Duration::from_millis(100)).await;
            Ok(format!("Processed: {}", task))
        }
    }

    /// Configuration shared by all tests.
    fn test_config() -> ProcessorConfig {
        ProcessorConfig {
            max_queue_size: 100,
            max_concurrent_tasks: 5,
            task_timeout: Duration::from_secs(1),
            max_retries: 3,
            retry_delay: Duration::from_secs(1),
        }
    }

    /// Starts a processor and submits tasks `0..count`, returning their receivers.
    async fn start_and_submit(
        count: i32,
    ) -> (
        AsyncProcessor<i32, String, TestProcessor>,
        Vec<mpsc::Receiver<TaskResult<i32, String>>>,
    ) {
        let mut processor = AsyncProcessor::new(test_config(), TestProcessor);
        processor.start();

        let mut receivers = Vec::new();
        for i in 0..count {
            let (_, rx) = processor.submit_task(i, 0).await.unwrap();
            receivers.push(rx);
        }
        (processor, receivers)
    }

    #[tokio::test]
    async fn test_async_processor() {
        let (_processor, receivers) = start_and_submit(10).await;

        for mut rx in receivers {
            let result = rx.recv().await.unwrap();
            assert_eq!(result.status, TaskStatus::Completed);
            assert!(result.error.is_none());
            assert!(result.output.is_some());
        }
    }

    #[tokio::test]
    async fn test_processor_shutdown() {
        let (mut processor, receivers) = start_and_submit(5).await;

        // Explicit shutdown must still deliver results for submitted tasks.
        processor.shutdown().unwrap();

        for mut rx in receivers {
            let result = rx.recv().await.unwrap();
            assert_eq!(result.status, TaskStatus::Completed);
        }
    }

    #[tokio::test]
    async fn test_processor_auto_shutdown() {
        let (processor, receivers) = start_and_submit(5).await;

        // Dropping the processor triggers shutdown through `Drop`.
        drop(processor);

        for mut rx in receivers {
            let result = rx.recv().await.unwrap();
            assert_eq!(result.status, TaskStatus::Completed);
        }
    }
}