html_translation_lib/pipeline/
concurrent_batch.rs

1//! 并发批处理模块
2//!
3//! 提供高性能的并行文本翻译处理能力
4
5use crate::error::{TranslationError, TranslationResult};
6use crate::storage::memory_pool::GlobalMemoryManager;
7use std::collections::VecDeque;
8use std::sync::{Arc, Mutex};
9use std::time::{Duration, Instant};
10use tokio::sync::{Semaphore, RwLock};
11use tokio::time::timeout;
12use futures::stream::{self, StreamExt};
13
/// Tunable parameters for batch processing.
#[derive(Debug, Clone)]
pub struct BatchConfig {
    /// Number of texts placed in one translation batch.
    pub batch_size: usize,
    /// Upper bound on concurrent work (worker count and semaphore permits).
    pub max_concurrency: usize,
    /// Time allowed for one translation call before it is reported as timed out.
    pub timeout_duration: Duration,
    /// Retry limit.
    pub max_retries: u32,
    /// Queue length above which new submissions are rejected.
    pub backpressure_threshold: usize,
}

impl Default for BatchConfig {
    /// Returns the stock settings: batches of 50, 10 concurrent workers,
    /// a 30 second timeout, 3 retries and a backpressure threshold of 1000.
    fn default() -> Self {
        let timeout_duration = Duration::from_secs(30);

        BatchConfig {
            backpressure_threshold: 1000,
            max_retries: 3,
            timeout_duration,
            max_concurrency: 10,
            batch_size: 50,
        }
    }
}
40
/// A unit of work waiting in the batch queue.
#[derive(Debug)]
pub struct BatchTask {
    /// Identifier handed back to the caller when the task is submitted.
    pub id: u64,
    /// Texts to translate.
    pub texts: Vec<String>,
    /// Scheduling priority; higher priorities sit closer to the queue front.
    pub priority: Priority,
    /// Moment the task was created.
    pub created_at: Instant,
    /// Retry counter; initialised to 0 on submission.
    pub retry_count: u32,
}
50
/// Task priority.
///
/// The derived ordering follows declaration order, so
/// `Low < Normal < High < Critical`. The explicit discriminants are also
/// used as indices into `QueueStatus::priority_distribution`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub enum Priority {
    Low = 0,
    Normal = 1,
    High = 2,
    Critical = 3,
}
59
/// Outcome of processing one `BatchTask`.
#[derive(Debug)]
pub struct BatchResult {
    /// Identifier of the task this result belongs to.
    pub task_id: u64,
    /// Translated texts collected for the task; may be partial when `error` is set.
    pub results: Vec<String>,
    /// Wall-clock time spent processing the task.
    pub processing_time: Duration,
    /// `true` when no error was recorded.
    pub success: bool,
    /// The error that made the task fail, if any.
    pub error: Option<TranslationError>,
}
69
/// Concurrent batch processor.
///
/// All mutable state lives behind `Arc`s so that worker clones share the
/// same queue, semaphore, statistics and id counter.
pub struct ConcurrentBatchProcessor {
    /// Configuration snapshot taken at construction time.
    config: BatchConfig,
    /// Shared handle to the global memory manager.
    memory_manager: Arc<GlobalMemoryManager>,
    /// Task queue (kept ordered by priority).
    task_queue: Arc<Mutex<VecDeque<BatchTask>>>,
    /// Semaphore limiting how many tasks are processed at once.
    concurrency_limiter: Arc<Semaphore>,
    /// Processing statistics.
    stats: Arc<RwLock<BatchProcessingStats>>,
    /// Task id counter.
    next_task_id: Arc<Mutex<u64>>,
}
83
84impl ConcurrentBatchProcessor {
85    /// 创建新的并发批处理器
86    pub fn new(config: BatchConfig, memory_manager: Arc<GlobalMemoryManager>) -> Self {
87        let concurrency_limiter = Arc::new(Semaphore::new(config.max_concurrency));
88        
89        Self {
90            config,
91            memory_manager,
92            task_queue: Arc::new(Mutex::new(VecDeque::new())),
93            concurrency_limiter,
94            stats: Arc::new(RwLock::new(BatchProcessingStats::default())),
95            next_task_id: Arc::new(Mutex::new(0)),
96        }
97    }
98    
99    /// 提交批处理任务
100    pub async fn submit_batch(
101        &self,
102        texts: Vec<String>,
103        priority: Priority,
104    ) -> TranslationResult<u64> {
105        // 检查背压
106        {
107            let queue = self.task_queue.lock().unwrap();
108            if queue.len() > self.config.backpressure_threshold {
109                return Err(TranslationError::InternalError(
110                    "任务队列已满,请稍后重试".to_string()
111                ));
112            }
113        }
114        
115        let task_id = {
116            let mut counter = self.next_task_id.lock().unwrap();
117            *counter += 1;
118            *counter
119        };
120        
121        let task = BatchTask {
122            id: task_id,
123            texts,
124            priority,
125            created_at: Instant::now(),
126            retry_count: 0,
127        };
128        
129        // 插入到优先级队列中
130        {
131            let mut queue = self.task_queue.lock().unwrap();
132            let insert_pos = queue
133                .iter()
134                .position(|t| t.priority < priority)
135                .unwrap_or(queue.len());
136            queue.insert(insert_pos, task);
137        }
138        
139        // 更新统计
140        {
141            let mut stats = self.stats.write().await;
142            stats.tasks_submitted += 1;
143        }
144        
145        Ok(task_id)
146    }
147    
148    /// 并行处理批次队列
149    pub async fn process_queue<F, Fut>(
150        &self,
151        translation_fn: F,
152    ) -> TranslationResult<Vec<BatchResult>>
153    where
154        F: Fn(Vec<String>) -> Fut + Send + Sync + Clone + 'static,
155        Fut: std::future::Future<Output = TranslationResult<Vec<String>>> + Send + 'static,
156    {
157        let mut results = Vec::new();
158        let mut handles = Vec::new();
159        
160        // 启动工作线程
161        for _ in 0..self.config.max_concurrency {
162            let processor = self.clone_for_worker();
163            let translation_fn = translation_fn.clone();
164            
165            let handle = tokio::spawn(async move {
166                processor.worker_loop(translation_fn).await
167            });
168            
169            handles.push(handle);
170        }
171        
172        // 等待所有工作线程完成或超时
173        let timeout_duration = self.config.timeout_duration * 2; // 给工作线程更多时间
174        
175        match timeout(timeout_duration, futures::future::join_all(handles)).await {
176            Ok(worker_results) => {
177                for worker_result in worker_results {
178                    match worker_result {
179                        Ok(batch_results) => results.extend(batch_results),
180                        Err(e) => {
181                            return Err(TranslationError::InternalError(
182                                format!("工作线程错误: {e}")
183                            ));
184                        }
185                    }
186                }
187            }
188            Err(_) => {
189                return Err(TranslationError::TimeoutError(
190                    "批处理队列处理超时".to_string()
191                ));
192            }
193        }
194        
195        Ok(results)
196    }
197    
198    /// 工作线程循环
199    async fn worker_loop<F, Fut>(&self, translation_fn: F) -> Vec<BatchResult>
200    where
201        F: Fn(Vec<String>) -> Fut + Send + Sync,
202        Fut: std::future::Future<Output = TranslationResult<Vec<String>>> + Send,
203    {
204        let mut results = Vec::new();
205        
206        loop {
207            // 获取下一个任务
208            let task = {
209                let mut queue = self.task_queue.lock().unwrap();
210                queue.pop_front()
211            };
212            
213            let task = match task {
214                Some(task) => task,
215                None => break, // 队列为空,退出
216            };
217            
218            // 获取并发许可
219            let _permit = match self.concurrency_limiter.try_acquire() {
220                Ok(permit) => permit,
221                Err(_) => {
222                    // 无法获取许可,重新放回队列
223                    let mut queue = self.task_queue.lock().unwrap();
224                    queue.push_front(task);
225                    break;
226                }
227            };
228            
229            // 处理任务
230            let result = self.process_single_task(task, &translation_fn).await;
231            results.push(result);
232        }
233        
234        results
235    }
236    
237    /// 处理单个任务
238    async fn process_single_task<F, Fut>(
239        &self,
240        task: BatchTask,
241        translation_fn: &F,
242    ) -> BatchResult
243    where
244        F: Fn(Vec<String>) -> Fut + Send + Sync,
245        Fut: std::future::Future<Output = TranslationResult<Vec<String>>> + Send,
246    {
247        let start_time = Instant::now();
248        let task_id = task.id;
249        
250        // 分割任务为更小的批次
251        let batches = self.split_into_batches(task.texts);
252        let mut all_results = Vec::new();
253        let mut has_error = None;
254        
255        // 并行处理所有批次
256        let batch_futures = batches.into_iter().map(|batch| {
257            let translation_fn = translation_fn;
258            async move {
259                let result = timeout(
260                    self.config.timeout_duration,
261                    translation_fn(batch)
262                ).await;
263                
264                match result {
265                    Ok(Ok(translations)) => Ok(translations),
266                    Ok(Err(e)) => Err(e),
267                    Err(_) => Err(TranslationError::TimeoutError(
268                        "批次处理超时".to_string()
269                    )),
270                }
271            }
272        });
273        
274        let batch_results: Vec<_> = stream::iter(batch_futures)
275            .buffer_unordered(self.config.max_concurrency.min(4))
276            .collect()
277            .await;
278        
279        // 收集结果
280        for result in batch_results {
281            match result {
282                Ok(translations) => all_results.extend(translations),
283                Err(e) => {
284                    has_error = Some(e);
285                    break;
286                }
287            }
288        }
289        
290        let processing_time = start_time.elapsed();
291        let success = has_error.is_none();
292        
293        // 更新统计
294        {
295            let mut stats = self.stats.write().await;
296            stats.tasks_processed += 1;
297            stats.total_processing_time += processing_time;
298            
299            if success {
300                stats.tasks_succeeded += 1;
301            } else {
302                stats.tasks_failed += 1;
303            }
304            
305            if processing_time > stats.max_processing_time {
306                stats.max_processing_time = processing_time;
307            }
308            
309            if processing_time < stats.min_processing_time {
310                stats.min_processing_time = processing_time;
311            }
312        }
313        
314        BatchResult {
315            task_id,
316            results: all_results,
317            processing_time,
318            success,
319            error: has_error,
320        }
321    }
322    
323    /// 将大任务分割为小批次
324    fn split_into_batches(&self, texts: Vec<String>) -> Vec<Vec<String>> {
325        texts
326            .chunks(self.config.batch_size)
327            .map(|chunk| chunk.to_vec())
328            .collect()
329    }
330    
331    /// 为工作线程创建克隆(共享状态)
332    fn clone_for_worker(&self) -> Self {
333        Self {
334            config: self.config.clone(),
335            memory_manager: Arc::clone(&self.memory_manager),
336            task_queue: Arc::clone(&self.task_queue),
337            concurrency_limiter: Arc::clone(&self.concurrency_limiter),
338            stats: Arc::clone(&self.stats),
339            next_task_id: Arc::clone(&self.next_task_id),
340        }
341    }
342    
343    /// 获取处理统计信息
344    pub async fn get_stats(&self) -> BatchProcessingStats {
345        self.stats.read().await.clone()
346    }
347    
348    /// 获取队列状态
349    pub fn get_queue_status(&self) -> QueueStatus {
350        let queue = self.task_queue.lock().unwrap();
351        let available_permits = self.concurrency_limiter.available_permits();
352        
353        let mut priority_counts = [0; 4];
354        for task in queue.iter() {
355            priority_counts[task.priority as usize] += 1;
356        }
357        
358        QueueStatus {
359            total_tasks: queue.len(),
360            priority_distribution: priority_counts,
361            available_workers: available_permits,
362            oldest_task_age: queue
363                .front()
364                .map(|task| task.created_at.elapsed())
365                .unwrap_or(Duration::ZERO),
366        }
367    }
368    
369    /// 清空队列
370    pub fn clear_queue(&self) -> usize {
371        let mut queue = self.task_queue.lock().unwrap();
372        let count = queue.len();
373        queue.clear();
374        count
375    }
376}
377
/// Counters and timings accumulated while processing tasks.
#[derive(Debug, Clone, Default)]
pub struct BatchProcessingStats {
    /// Number of tasks accepted for processing.
    pub tasks_submitted: u64,
    /// Number of tasks that finished processing (successfully or not).
    pub tasks_processed: u64,
    /// Number of tasks that finished without an error.
    pub tasks_succeeded: u64,
    /// Number of tasks that finished with an error.
    pub tasks_failed: u64,
    /// Sum of the processing times of all processed tasks.
    pub total_processing_time: Duration,
    /// Shortest processing time observed.
    pub min_processing_time: Duration,
    /// Longest processing time observed.
    pub max_processing_time: Duration,
}

impl BatchProcessingStats {
    /// Fraction of processed tasks that succeeded, in `0.0..=1.0`.
    ///
    /// Returns `0.0` when nothing has been processed yet.
    pub fn success_rate(&self) -> f64 {
        if self.tasks_processed == 0 {
            return 0.0;
        }
        self.tasks_succeeded as f64 / self.tasks_processed as f64
    }

    /// Mean processing time per processed task.
    ///
    /// Returns `Duration::ZERO` when nothing has been processed yet. The
    /// division is done on nanoseconds in `u128`, so task counts above
    /// `u32::MAX` neither truncate nor cause a division by zero.
    pub fn average_processing_time(&self) -> Duration {
        if self.tasks_processed == 0 {
            return Duration::ZERO;
        }
        let avg_nanos =
            self.total_processing_time.as_nanos() / u128::from(self.tasks_processed);
        // Saturate in the (practically unreachable) case of an average above
        // what `from_nanos` can represent.
        Duration::from_nanos(u64::try_from(avg_nanos).unwrap_or(u64::MAX))
    }

    /// Throughput in tasks per second of accumulated processing time.
    ///
    /// Returns `0.0` when no processing time has been recorded.
    pub fn throughput(&self) -> f64 {
        if self.total_processing_time.is_zero() {
            return 0.0;
        }
        self.tasks_processed as f64 / self.total_processing_time.as_secs_f64()
    }
}
415
/// Snapshot of the queue state as returned by `get_queue_status`.
#[derive(Debug, Clone)]
pub struct QueueStatus {
    /// Number of tasks currently queued.
    pub total_tasks: usize,
    /// Queued tasks per priority, indexed by the `Priority` discriminant.
    pub priority_distribution: [usize; 4], // [Low, Normal, High, Critical]
    /// Number of concurrency permits available at the time of the snapshot.
    pub available_workers: usize,
    /// Age of a queued task as reported by `get_queue_status`;
    /// `Duration::ZERO` when the queue is empty.
    pub oldest_task_age: Duration,
}
424
/// Adaptive batch processor (adjusts its parameters according to load).
pub struct AdaptiveBatchProcessor {
    /// The wrapped processor that owns the queue and statistics.
    inner: ConcurrentBatchProcessor,
    /// Configuration that `adapt_configuration` tunes at runtime.
    config: Arc<RwLock<BatchConfig>>,
    /// Tracks how long this processor has existed.
    performance_monitor: PerformanceMonitor,
}
431
432impl AdaptiveBatchProcessor {
433    /// 创建自适应批处理器
434    pub fn new(
435        initial_config: BatchConfig,
436        memory_manager: Arc<GlobalMemoryManager>,
437    ) -> Self {
438        let config = Arc::new(RwLock::new(initial_config.clone()));
439        let inner = ConcurrentBatchProcessor::new(initial_config, memory_manager);
440        let performance_monitor = PerformanceMonitor::new();
441        
442        Self {
443            inner,
444            config,
445            performance_monitor,
446        }
447    }
448    
449    /// 自适应调整配置
450    pub async fn adapt_configuration(&self) {
451        let stats = self.inner.get_stats().await;
452        let queue_status = self.inner.get_queue_status();
453        
454        let mut config = self.config.write().await;
455        
456        // 根据成功率调整重试次数
457        if stats.success_rate() < 0.8 && config.max_retries < 5 {
458            config.max_retries += 1;
459        } else if stats.success_rate() > 0.95 && config.max_retries > 1 {
460            config.max_retries = (config.max_retries - 1).max(1);
461        }
462        
463        // 根据队列长度调整并发数
464        if queue_status.total_tasks > config.backpressure_threshold / 2 {
465            config.max_concurrency = (config.max_concurrency + 2).min(20);
466        } else if queue_status.total_tasks < 10 && config.max_concurrency > 5 {
467            config.max_concurrency = (config.max_concurrency - 1).max(5);
468        }
469        
470        // 根据平均处理时间调整超时
471        let avg_time = stats.average_processing_time();
472        if avg_time > config.timeout_duration {
473            config.timeout_duration = avg_time + Duration::from_secs(5);
474        }
475    }
476}
477
/// Performance monitor; remembers when it was created.
#[derive(Debug)]
pub struct PerformanceMonitor {
    /// Creation instant, used as the reference point for `uptime`.
    start_time: Instant,
}

impl PerformanceMonitor {
    /// Starts a monitor whose clock begins now.
    fn new() -> Self {
        let start_time = Instant::now();
        PerformanceMonitor { start_time }
    }

    /// Time elapsed since this monitor was created.
    pub fn uptime(&self) -> Duration {
        Instant::now().duration_since(self.start_time)
    }
}
496
#[cfg(test)]
mod tests {
    use super::*;

    /// Fake translator: waits briefly, then prefixes every text with `translated_`.
    async fn mock_translation_fn(texts: Vec<String>) -> TranslationResult<Vec<String>> {
        // Simulate translation latency.
        tokio::time::sleep(Duration::from_millis(10)).await;

        let mut translations = Vec::with_capacity(texts.len());
        for text in &texts {
            translations.push(format!("translated_{text}"));
        }

        Ok(translations)
    }

    /// Builds a processor with the default configuration.
    fn default_processor() -> ConcurrentBatchProcessor {
        ConcurrentBatchProcessor::new(
            BatchConfig::default(),
            Arc::new(GlobalMemoryManager::new()),
        )
    }

    #[tokio::test]
    async fn test_concurrent_batch_processor() {
        let processor = default_processor();

        // Submit two tasks with different priorities.
        let id1 = processor
            .submit_batch(vec!["Hello".to_string(), "World".to_string()], Priority::Normal)
            .await
            .unwrap();
        let id2 = processor
            .submit_batch(vec!["Foo".to_string(), "Bar".to_string()], Priority::High)
            .await
            .unwrap();

        assert!(id1 != id2);

        // Drain the queue.
        let results = processor.process_queue(mock_translation_fn).await.unwrap();
        assert_eq!(results.len(), 2);

        // Both tasks must show up in the statistics.
        let stats = processor.get_stats().await;
        assert_eq!(stats.tasks_processed, 2);
    }

    #[tokio::test]
    async fn test_priority_ordering() {
        let processor = default_processor();

        // Submit in an order that differs from the expected queue order.
        let submissions = [
            ("low", Priority::Low),
            ("high", Priority::High),
            ("normal", Priority::Normal),
        ];
        for (text, priority) in submissions {
            processor
                .submit_batch(vec![text.to_string()], priority)
                .await
                .unwrap();
        }

        let queue_status = processor.get_queue_status();
        assert_eq!(queue_status.total_tasks, 3);

        // Higher priorities must sit closer to the front.
        let queue = processor.task_queue.lock().unwrap();
        let order: Vec<Priority> = queue.iter().map(|task| task.priority).collect();
        assert_eq!(order, vec![Priority::High, Priority::Normal, Priority::Low]);
    }

    #[test]
    fn test_batch_processing_stats() {
        let stats = BatchProcessingStats {
            tasks_processed: 100,
            tasks_succeeded: 95,
            total_processing_time: Duration::from_secs(50),
            ..BatchProcessingStats::default()
        };

        assert_eq!(stats.success_rate(), 0.95);
        assert_eq!(stats.average_processing_time(), Duration::from_millis(500));
        assert_eq!(stats.throughput(), 2.0);
    }
}