use std::{
    fmt::Display,
    sync::Arc,
    time::{Duration, Instant},
};
use anyhow::anyhow;
use moduforge_state::debug;
use tokio::sync::{mpsc, oneshot};
use async_trait::async_trait;
use tokio::select;

use crate::EditorResult;

/// Lifecycle status of a submitted task as reported in `TaskResult`.
#[derive(Debug, Clone, PartialEq)]
pub enum TaskStatus {
    Pending,
    Processing,
    Completed,
    Failed(String),
    Timeout,
    Cancelled,
}

/// Errors surfaced by the processor and its workers.
#[derive(Debug)]
pub enum ProcessorError {
    QueueFull,
    TaskFailed(String),
    InternalError(String),
    TaskTimeout,
    TaskCancelled,
    RetryExhausted(String),
}

impl Display for ProcessorError {
    fn fmt(
        &self,
        f: &mut std::fmt::Formatter<'_>,
    ) -> std::fmt::Result {
        match self {
            ProcessorError::QueueFull => write!(f, "task queue is full"),
            ProcessorError::TaskFailed(msg) => {
                write!(f, "task execution failed: {}", msg)
            },
            ProcessorError::InternalError(msg) => {
                write!(f, "internal error: {}", msg)
            },
            ProcessorError::TaskTimeout => {
                write!(f, "task execution timed out")
            },
            ProcessorError::TaskCancelled => write!(f, "task was cancelled"),
            ProcessorError::RetryExhausted(msg) => {
                write!(f, "retries exhausted: {}", msg)
            },
        }
    }
}

impl std::error::Error for ProcessorError {}

/// Tunable limits for the processor.
#[derive(Clone, Debug)]
pub struct ProcessorConfig {
    pub max_queue_size: usize,
    pub max_concurrent_tasks: usize,
    pub task_timeout: Duration,
    pub max_retries: u32,
    pub retry_delay: Duration,
}

impl Default for ProcessorConfig {
    fn default() -> Self {
        Self {
            max_queue_size: 1000,
            max_concurrent_tasks: 10,
            task_timeout: Duration::from_secs(30),
            max_retries: 3,
            retry_delay: Duration::from_secs(1),
        }
    }
}
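
// A minimal sketch of overriding selected limits with struct-update syntax
// (the specific values are illustrative only, not recommendations):
//
//     let config = ProcessorConfig {
//         max_concurrent_tasks: 4,
//         task_timeout: Duration::from_secs(5),
//         ..ProcessorConfig::default()
//     };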

/// Aggregate runtime statistics, shared behind the queue's mutex.
#[derive(Debug, Default, Clone)]
pub struct ProcessorStats {
    pub total_tasks: u64,
    pub completed_tasks: u64,
    pub failed_tasks: u64,
    pub timeout_tasks: u64,
    pub cancelled_tasks: u64,
    pub average_processing_time: Duration,
    pub current_queue_size: usize,
    pub current_processing_tasks: usize,
}

/// Outcome delivered to the submitter once a task finishes (or fails).
#[derive(Debug)]
pub struct TaskResult<T, O>
where
    T: Send + Sync,
    O: Send + Sync,
{
    pub task_id: u64,
    pub status: TaskStatus,
    pub task: Option<T>,
    pub output: Option<O>,
    pub error: Option<String>,
    pub processing_time: Option<Duration>,
}

/// Internal envelope stored in the queue channel.
struct QueuedTask<T, O>
where
    T: Send + Sync,
    O: Send + Sync,
{
    task: T,
    task_id: u64,
    result_tx: mpsc::Sender<TaskResult<T, O>>,
    priority: u32,
    retry_count: u32,
}

/// Bounded task queue plus shared statistics.
pub struct TaskQueue<T, O>
where
    T: Send + Sync,
    O: Send + Sync,
{
    queue: mpsc::Sender<QueuedTask<T, O>>,
    queue_rx: Arc<tokio::sync::Mutex<Option<mpsc::Receiver<QueuedTask<T, O>>>>>,
    next_task_id: Arc<tokio::sync::Mutex<u64>>,
    stats: Arc<tokio::sync::Mutex<ProcessorStats>>,
}

impl<T: Clone + Send + Sync + 'static, O: Clone + Send + Sync + 'static>
    TaskQueue<T, O>
{
    pub fn new(config: &ProcessorConfig) -> Self {
        let (tx, rx) = mpsc::channel(config.max_queue_size);
        Self {
            queue: tx,
            queue_rx: Arc::new(tokio::sync::Mutex::new(Some(rx))),
            next_task_id: Arc::new(tokio::sync::Mutex::new(0)),
            stats: Arc::new(tokio::sync::Mutex::new(ProcessorStats::default())),
        }
    }

    /// Enqueue a task and return its id plus a receiver for the final result.
    pub async fn enqueue_task(
        &self,
        task: T,
        priority: u32,
    ) -> EditorResult<(u64, mpsc::Receiver<TaskResult<T, O>>)> {
        let mut task_id = self.next_task_id.lock().await;
        *task_id += 1;
        let current_id = *task_id;

        let (result_tx, result_rx) = mpsc::channel(1);
        let queued_task = QueuedTask {
            task,
            task_id: current_id,
            result_tx,
            priority,
            retry_count: 0,
        };

        // A bounded `mpsc::Sender` waits while the queue is full; the send only
        // fails once the receiving side has been dropped.
        self.queue
            .send(queued_task)
            .await
            .map_err(|_| anyhow!("task queue is closed"))?;

        let mut stats = self.stats.lock().await;
        stats.total_tasks += 1;
        stats.current_queue_size += 1;

        Ok((current_id, result_rx))
    }

    /// Take the next queued task, updating the queue/processing counters.
    pub async fn get_next_ready(
        &self
    ) -> Option<(T, u64, mpsc::Sender<TaskResult<T, O>>, u32, u32)> {
        let mut rx_guard = self.queue_rx.lock().await;
        if let Some(rx) = rx_guard.as_mut() {
            if let Some(queued) = rx.recv().await {
                let mut stats = self.stats.lock().await;
                stats.current_queue_size -= 1;
                stats.current_processing_tasks += 1;
                return Some((
                    queued.task,
                    queued.task_id,
                    queued.result_tx,
                    queued.priority,
                    queued.retry_count,
                ));
            }
        }
        None
    }

    pub async fn get_stats(&self) -> ProcessorStats {
        self.stats.lock().await.clone()
    }

    /// Fold a finished task into the aggregate statistics.
    pub async fn update_stats(
        &self,
        result: &TaskResult<T, O>,
    ) {
        let mut stats = self.stats.lock().await;
        match result.status {
            TaskStatus::Completed => {
                stats.completed_tasks += 1;
                if let Some(processing_time) = result.processing_time {
                    // Cheap running estimate: midpoint of the previous average
                    // and the latest sample, not an exact mean.
                    stats.average_processing_time =
                        (stats.average_processing_time + processing_time) / 2;
                }
            },
            TaskStatus::Failed(_) => stats.failed_tasks += 1,
            TaskStatus::Timeout => stats.timeout_tasks += 1,
            TaskStatus::Cancelled => stats.cancelled_tasks += 1,
            _ => {},
        }
        stats.current_processing_tasks -= 1;
    }
}

/// Implemented by consumers to define how a single task is processed.
#[async_trait]
pub trait TaskProcessor<T, O>: Send + Sync + 'static
where
    T: Clone + Send + Sync + 'static,
    O: Clone + Send + Sync + 'static,
{
    async fn process(
        &self,
        task: T,
    ) -> Result<O, ProcessorError>;
}
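
// A minimal sketch of an implementor; `EchoProcessor` and its `String` task type
// are illustrative (see `TestProcessor` in the tests below for a compiled example):
//
//     struct EchoProcessor;
//
//     #[async_trait]
//     impl TaskProcessor<String, String> for EchoProcessor {
//         async fn process(&self, task: String) -> Result<String, ProcessorError> {
//             Ok(format!("echo: {}", task))
//         }
//     }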

/// Async task processor: owns the queue, the user-supplied `TaskProcessor`,
/// and the background dispatch loop started by `start()`.
pub struct AsyncProcessor<T, O, P>
where
    T: Clone + Send + Sync + 'static,
    O: Clone + Send + Sync + 'static,
    P: TaskProcessor<T, O>,
{
    task_queue: Arc<TaskQueue<T, O>>,
    config: ProcessorConfig,
    processor: Arc<P>,
    shutdown_tx: Option<oneshot::Sender<()>>,
    handle: Option<tokio::task::JoinHandle<()>>,
}

impl<T, O, P> AsyncProcessor<T, O, P>
where
    T: Clone + Send + Sync + 'static,
    O: Clone + Send + Sync + 'static,
    P: TaskProcessor<T, O>,
{
    pub fn new(
        config: ProcessorConfig,
        processor: P,
    ) -> Self {
        let task_queue = Arc::new(TaskQueue::new(&config));
        Self {
            task_queue,
            config,
            processor: Arc::new(processor),
            shutdown_tx: None,
            handle: None,
        }
    }

    /// Submit a task; returns its id and a receiver for the final result.
    pub async fn submit_task(
        &self,
        task: T,
        priority: u32,
    ) -> EditorResult<(u64, mpsc::Receiver<TaskResult<T, O>>)> {
        self.task_queue.enqueue_task(task, priority).await
    }

    /// Start the background dispatch loop.
    pub fn start(&mut self) {
        let queue = self.task_queue.clone();
        let processor = self.processor.clone();
        let config = self.config.clone();
        let (shutdown_tx, mut shutdown_rx) = oneshot::channel();

        self.shutdown_tx = Some(shutdown_tx);

        let handle = tokio::spawn(async move {
            let mut join_set = tokio::task::JoinSet::new();

            loop {
                select! {
                    // Shutdown requested: stop pulling new work.
                    _ = &mut shutdown_rx => {
                        break;
                    }

                    // Reap finished workers so the JoinSet stays bounded.
                    Some(result) = join_set.join_next() => {
                        if let Err(e) = result {
                            debug!("task execution failed: {}", e);
                        }
                    }

                    // Only dequeue when there is spare concurrency; without this
                    // precondition a dequeued task would be silently dropped.
                    Some((task, task_id, result_tx, _priority, retry_count)) = queue.get_next_ready(),
                        if join_set.len() < config.max_concurrent_tasks =>
                    {
                        let processor = processor.clone();
                        let config = config.clone();
                        let queue = queue.clone();

                        join_set.spawn(async move {
                            let start_time = Instant::now();
                            let mut current_retry = retry_count;

                            loop {
                                let result = tokio::time::timeout(
                                    config.task_timeout,
                                    processor.process(task.clone()),
                                )
                                .await;

                                match result {
                                    Ok(Ok(output)) => {
                                        let processing_time = start_time.elapsed();
                                        let task_result = TaskResult {
                                            task_id,
                                            status: TaskStatus::Completed,
                                            task: Some(task),
                                            output: Some(output),
                                            error: None,
                                            processing_time: Some(processing_time),
                                        };
                                        queue.update_stats(&task_result).await;
                                        let _ = result_tx.send(task_result).await;
                                        break;
                                    },
                                    Ok(Err(e)) => {
                                        // Retry with a fixed delay until the budget runs out.
                                        if current_retry < config.max_retries {
                                            current_retry += 1;
                                            tokio::time::sleep(config.retry_delay).await;
                                            continue;
                                        }
                                        let task_result = TaskResult {
                                            task_id,
                                            status: TaskStatus::Failed(e.to_string()),
                                            task: Some(task),
                                            output: None,
                                            error: Some(e.to_string()),
                                            processing_time: Some(start_time.elapsed()),
                                        };
                                        queue.update_stats(&task_result).await;
                                        let _ = result_tx.send(task_result).await;
                                        break;
                                    },
                                    Err(_) => {
                                        let task_result = TaskResult {
                                            task_id,
                                            status: TaskStatus::Timeout,
                                            task: Some(task),
                                            output: None,
                                            error: Some("task execution timed out".to_string()),
                                            processing_time: Some(start_time.elapsed()),
                                        };
                                        queue.update_stats(&task_result).await;
                                        let _ = result_tx.send(task_result).await;
                                        break;
                                    },
                                }
                            }
                        });
                    }
                }
            }

            // Drain tasks that are already in flight so shutdown does not abort them.
            while let Some(result) = join_set.join_next().await {
                if let Err(e) = result {
                    debug!("task execution failed: {}", e);
                }
            }
        });

        self.handle = Some(handle);
    }

    /// Signal the dispatch loop to stop and wait for it to finish.
    pub async fn shutdown(&mut self) -> Result<(), ProcessorError> {
        if let Some(shutdown_tx) = self.shutdown_tx.take() {
            shutdown_tx.send(()).map_err(|_| {
                ProcessorError::InternalError(
                    "Failed to send shutdown signal".to_string(),
                )
            })?;

            if let Some(handle) = self.handle.take() {
                handle.await.map_err(|e| {
                    ProcessorError::InternalError(format!(
                        "Failed to join processor task: {}",
                        e
                    ))
                })?;
            }
        }
        Ok(())
    }

    pub async fn get_stats(&self) -> ProcessorStats {
        self.task_queue.get_stats().await
    }
}
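
// Usage sketch. `MyTask`, `MyOutput`, `MyProcessor`, and `my_task` are illustrative
// names for any `TaskProcessor<MyTask, MyOutput>` implementation and its input;
// call this from an async context (e.g. inside a Tokio runtime):
//
//     let mut processor = AsyncProcessor::new(ProcessorConfig::default(), MyProcessor);
//     processor.start();
//     let (_task_id, mut rx) = processor.submit_task(my_task, 0).await?;
//     if let Some(result) = rx.recv().await {
//         // Inspect `result.status`, `result.output`, and `result.error`.
//     }
//     processor.shutdown().await?;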

impl<T, O, P> Drop for AsyncProcessor<T, O, P>
where
    T: Clone + Send + Sync + 'static,
    O: Clone + Send + Sync + 'static,
    P: TaskProcessor<T, O>,
{
    fn drop(&mut self) {
        // Building a new runtime here would panic whenever the processor is
        // dropped inside an async context, so just signal shutdown and let the
        // background loop drain its in-flight tasks on its own.
        if let Some(shutdown_tx) = self.shutdown_tx.take() {
            let _ = shutdown_tx.send(());
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    struct TestProcessor;

    #[async_trait::async_trait]
    impl TaskProcessor<i32, String> for TestProcessor {
        async fn process(
            &self,
            task: i32,
        ) -> Result<String, ProcessorError> {
            tokio::time::sleep(Duration::from_millis(100)).await;
            Ok(format!("Processed: {}", task))
        }
    }

    #[tokio::test]
    async fn test_async_processor() {
        let config = ProcessorConfig {
            max_queue_size: 100,
            max_concurrent_tasks: 5,
            task_timeout: Duration::from_secs(1),
            max_retries: 3,
            retry_delay: Duration::from_secs(1),
        };
        let mut processor = AsyncProcessor::new(config, TestProcessor);
        processor.start();

        let mut receivers = Vec::new();
        for i in 0..10 {
            let (_, rx) = processor.submit_task(i, 0).await.unwrap();
            receivers.push(rx);
        }

        for mut rx in receivers {
            let result = rx.recv().await.unwrap();
            assert_eq!(result.status, TaskStatus::Completed);
            assert!(result.error.is_none());
            assert!(result.output.is_some());
        }
    }

    #[tokio::test]
    async fn test_processor_shutdown() {
        let config = ProcessorConfig {
            max_queue_size: 100,
            max_concurrent_tasks: 5,
            task_timeout: Duration::from_secs(1),
            max_retries: 3,
            retry_delay: Duration::from_secs(1),
        };
        let mut processor = AsyncProcessor::new(config, TestProcessor);
        processor.start();

        let mut receivers = Vec::new();
        for i in 0..5 {
            let (_, rx) = processor.submit_task(i, 0).await.unwrap();
            receivers.push(rx);
        }

        processor.shutdown().await.unwrap();

        for mut rx in receivers {
            let result = rx.recv().await.unwrap();
            assert_eq!(result.status, TaskStatus::Completed);
        }
    }

    #[tokio::test]
    async fn test_processor_auto_shutdown() {
        let config = ProcessorConfig {
            max_queue_size: 100,
            max_concurrent_tasks: 5,
            task_timeout: Duration::from_secs(1),
            max_retries: 3,
            retry_delay: Duration::from_secs(1),
        };
        let mut processor = AsyncProcessor::new(config, TestProcessor);
        processor.start();

        let mut receivers = Vec::new();
        for i in 0..5 {
            let (_, rx) = processor.submit_task(i, 0).await.unwrap();
            receivers.push(rx);
        }

        drop(processor);

        for mut rx in receivers {
            let result = rx.recv().await.unwrap();
            assert_eq!(result.status, TaskStatus::Completed);
        }
    }
}