html_translation_lib/pipeline/concurrent_batch.rs

use crate::error::{TranslationError, TranslationResult};
use crate::storage::memory_pool::GlobalMemoryManager;
use std::collections::VecDeque;
use std::sync::{Arc, Mutex};
use std::time::{Duration, Instant};
use tokio::sync::{Semaphore, RwLock};
use tokio::time::timeout;
use futures::stream::{self, StreamExt};

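/// Tunable settings for the concurrent batch pipeline; see the `Default`
/// impl below for the baseline values.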
#[derive(Debug, Clone)]
pub struct BatchConfig {
    /// Maximum number of texts per translation batch.
    pub batch_size: usize,
    /// Maximum number of batches processed concurrently.
    pub max_concurrency: usize,
    /// Timeout applied to each batch translation call.
    pub timeout_duration: Duration,
    /// Maximum retry attempts for a failed task.
    pub max_retries: u32,
    /// Queue length above which new submissions are rejected.
    pub backpressure_threshold: usize,
}

impl Default for BatchConfig {
    fn default() -> Self {
        Self {
            batch_size: 50,
            max_concurrency: 10,
            timeout_duration: Duration::from_secs(30),
            max_retries: 3,
            backpressure_threshold: 1000,
        }
    }
}

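/// One queued unit of work: a set of texts to translate at a given priority.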
#[derive(Debug)]
pub struct BatchTask {
    pub id: u64,
    pub texts: Vec<String>,
    pub priority: Priority,
    pub created_at: Instant,
    pub retry_count: u32,
}

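/// Task priority. Higher values are dequeued first; the numeric discriminants
/// also index `QueueStatus::priority_distribution`.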
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub enum Priority {
    Low = 0,
    Normal = 1,
    High = 2,
    Critical = 3,
}

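/// Outcome of one processed `BatchTask`, including timing and any error.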
#[derive(Debug)]
pub struct BatchResult {
    pub task_id: u64,
    pub results: Vec<String>,
    pub processing_time: Duration,
    pub success: bool,
    pub error: Option<TranslationError>,
}

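/// Priority-queued batch processor: `submit_batch` enqueues work and
/// `process_queue` fans it out across Tokio tasks, bounded by a semaphore.
///
/// A minimal usage sketch (`translate` is a placeholder async fn returning
/// `TranslationResult<Vec<String>>`, not part of this module):
///
/// ```ignore
/// let processor = ConcurrentBatchProcessor::new(BatchConfig::default(), memory_manager);
/// let id = processor.submit_batch(texts, Priority::Normal).await?;
/// let results = processor.process_queue(translate).await?;
/// ```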
pub struct ConcurrentBatchProcessor {
    config: BatchConfig,
    memory_manager: Arc<GlobalMemoryManager>,
    /// Pending tasks, kept ordered so higher priorities sit at the front.
    task_queue: Arc<Mutex<VecDeque<BatchTask>>>,
    /// Bounds how many batches are in flight at once.
    concurrency_limiter: Arc<Semaphore>,
    stats: Arc<RwLock<BatchProcessingStats>>,
    /// Monotonic counter used to assign task IDs.
    next_task_id: Arc<Mutex<u64>>,
}

impl ConcurrentBatchProcessor {
    pub fn new(config: BatchConfig, memory_manager: Arc<GlobalMemoryManager>) -> Self {
        let concurrency_limiter = Arc::new(Semaphore::new(config.max_concurrency));

        Self {
            config,
            memory_manager,
            task_queue: Arc::new(Mutex::new(VecDeque::new())),
            concurrency_limiter,
            stats: Arc::new(RwLock::new(BatchProcessingStats::default())),
            next_task_id: Arc::new(Mutex::new(0)),
        }
    }

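    /// Enqueues `texts` as a single task and returns its ID. Fails fast with
    /// `InternalError` when the queue exceeds `backpressure_threshold`.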
    pub async fn submit_batch(
        &self,
        texts: Vec<String>,
        priority: Priority,
    ) -> TranslationResult<u64> {
        // Backpressure: reject new work when the queue is too deep.
        {
            let queue = self.task_queue.lock().unwrap();
            if queue.len() > self.config.backpressure_threshold {
                return Err(TranslationError::InternalError(
                    "task queue is full, please retry later".to_string()
                ));
            }
        }

        let task_id = {
            let mut counter = self.next_task_id.lock().unwrap();
            *counter += 1;
            *counter
        };

        let task = BatchTask {
            id: task_id,
            texts,
            priority,
            created_at: Instant::now(),
            retry_count: 0,
        };

        // Insert in priority order: just before the first task of lower priority.
        {
            let mut queue = self.task_queue.lock().unwrap();
            let insert_pos = queue
                .iter()
                .position(|t| t.priority < priority)
                .unwrap_or(queue.len());
            queue.insert(insert_pos, task);
        }

        {
            let mut stats = self.stats.write().await;
            stats.tasks_submitted += 1;
        }

        Ok(task_id)
    }

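    /// Drains the queue by spawning `max_concurrency` workers and collecting
    /// their results; `translation_fn` is invoked once per batch.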
    pub async fn process_queue<F, Fut>(
        &self,
        translation_fn: F,
    ) -> TranslationResult<Vec<BatchResult>>
    where
        F: Fn(Vec<String>) -> Fut + Send + Sync + Clone + 'static,
        Fut: std::future::Future<Output = TranslationResult<Vec<String>>> + Send + 'static,
    {
        let mut results = Vec::new();
        let mut handles = Vec::new();

        // Spawn one worker per concurrency slot.
        for _ in 0..self.config.max_concurrency {
            let processor = self.clone_for_worker();
            let translation_fn = translation_fn.clone();

            let handle = tokio::spawn(async move {
                processor.worker_loop(translation_fn).await
            });

            handles.push(handle);
        }

        // Wait for all workers, with an overall budget of twice the per-batch timeout.
        let timeout_duration = self.config.timeout_duration * 2;
        match timeout(timeout_duration, futures::future::join_all(handles)).await {
            Ok(worker_results) => {
                for worker_result in worker_results {
                    match worker_result {
                        Ok(batch_results) => results.extend(batch_results),
                        Err(e) => {
                            return Err(TranslationError::InternalError(
                                format!("worker task error: {e}")
                            ));
                        }
                    }
                }
            }
            Err(_) => {
                return Err(TranslationError::TimeoutError(
                    "batch queue processing timed out".to_string()
                ));
            }
        }

        Ok(results)
    }

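    /// Worker loop: pops tasks until the queue is empty or no semaphore
    /// permit is immediately available.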
    async fn worker_loop<F, Fut>(&self, translation_fn: F) -> Vec<BatchResult>
    where
        F: Fn(Vec<String>) -> Fut + Send + Sync,
        Fut: std::future::Future<Output = TranslationResult<Vec<String>>> + Send,
    {
        let mut results = Vec::new();

        loop {
            // Pop the highest-priority task, holding the lock only briefly.
            let task = {
                let mut queue = self.task_queue.lock().unwrap();
                queue.pop_front()
            };

            let task = match task {
                Some(task) => task,
                None => break, // queue drained; this worker is done
            };

            // Respect the global concurrency limit without blocking.
            let _permit = match self.concurrency_limiter.try_acquire() {
                Ok(permit) => permit,
                Err(_) => {
                    // No permit available: put the task back and stop this worker.
                    let mut queue = self.task_queue.lock().unwrap();
                    queue.push_front(task);
                    break;
                }
            };

            let result = self.process_single_task(task, &translation_fn).await;
            results.push(result);
        }

        results
    }

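    /// Splits one task into batches, translates them concurrently, and
    /// records timing statistics.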
    async fn process_single_task<F, Fut>(
        &self,
        task: BatchTask,
        translation_fn: &F,
    ) -> BatchResult
    where
        F: Fn(Vec<String>) -> Fut + Send + Sync,
        Fut: std::future::Future<Output = TranslationResult<Vec<String>>> + Send,
    {
        let start_time = Instant::now();
        let task_id = task.id;

        let batches = self.split_into_batches(task.texts);
        let mut all_results = Vec::new();
        let mut has_error = None;

        // Translate each batch under its own timeout.
        let batch_futures = batches.into_iter().map(|batch| {
            // Rebind the reference so each future captures it by value.
            let translation_fn = translation_fn;
            async move {
                let result = timeout(
                    self.config.timeout_duration,
                    translation_fn(batch)
                ).await;

                match result {
                    Ok(Ok(translations)) => Ok(translations),
                    Ok(Err(e)) => Err(e),
                    Err(_) => Err(TranslationError::TimeoutError(
                        "batch processing timed out".to_string()
                    )),
                }
            }
        });

        // Run a few batches concurrently. `buffered` (rather than
        // `buffer_unordered`) yields results in input order, which the flat
        // `results` vector relies on to stay aligned with the input texts.
        let batch_results: Vec<_> = stream::iter(batch_futures)
            .buffered(self.config.max_concurrency.min(4))
            .collect()
            .await;

        // Stop at the first failed batch; earlier results are kept.
        for result in batch_results {
            match result {
                Ok(translations) => all_results.extend(translations),
                Err(e) => {
                    has_error = Some(e);
                    break;
                }
            }
        }

        let processing_time = start_time.elapsed();
        let success = has_error.is_none();

        // Update shared statistics.
        {
            let mut stats = self.stats.write().await;
            stats.tasks_processed += 1;
            stats.total_processing_time += processing_time;

            if success {
                stats.tasks_succeeded += 1;
            } else {
                stats.tasks_failed += 1;
            }

            if processing_time > stats.max_processing_time {
                stats.max_processing_time = processing_time;
            }

            // `Duration::ZERO` is the "unset" default; without this check the
            // minimum would never move off zero.
            if stats.min_processing_time.is_zero() || processing_time < stats.min_processing_time {
                stats.min_processing_time = processing_time;
            }
        }

        BatchResult {
            task_id,
            results: all_results,
            processing_time,
            success,
            error: has_error,
        }
    }

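    /// Splits `texts` into chunks of at most `batch_size` items.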
    fn split_into_batches(&self, texts: Vec<String>) -> Vec<Vec<String>> {
        texts
            .chunks(self.config.batch_size)
            .map(|chunk| chunk.to_vec())
            .collect()
    }

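    /// Cheap handle for worker tasks; all shared state sits behind `Arc`s.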
    fn clone_for_worker(&self) -> Self {
        Self {
            config: self.config.clone(),
            memory_manager: Arc::clone(&self.memory_manager),
            task_queue: Arc::clone(&self.task_queue),
            concurrency_limiter: Arc::clone(&self.concurrency_limiter),
            stats: Arc::clone(&self.stats),
            next_task_id: Arc::clone(&self.next_task_id),
        }
    }

    /// Returns a snapshot of the accumulated statistics.
    pub async fn get_stats(&self) -> BatchProcessingStats {
        self.stats.read().await.clone()
    }

    /// Returns a snapshot of the queue: depth, per-priority counts, free
    /// worker permits, and the age of the oldest waiting task.
    pub fn get_queue_status(&self) -> QueueStatus {
        let queue = self.task_queue.lock().unwrap();
        let available_permits = self.concurrency_limiter.available_permits();

        // Count tasks per priority; the enum discriminant doubles as the index.
        let mut priority_counts = [0; 4];
        for task in queue.iter() {
            priority_counts[task.priority as usize] += 1;
        }

        QueueStatus {
            total_tasks: queue.len(),
            priority_distribution: priority_counts,
            available_workers: available_permits,
            oldest_task_age: queue
                .front()
                .map(|task| task.created_at.elapsed())
                .unwrap_or(Duration::ZERO),
        }
    }

    /// Drops all pending tasks and returns how many were removed.
    pub fn clear_queue(&self) -> usize {
        let mut queue = self.task_queue.lock().unwrap();
        let count = queue.len();
        queue.clear();
        count
    }
}

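/// Aggregate counters over processed tasks. A default-initialized
/// `min_processing_time` of `Duration::ZERO` means "not yet recorded".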
#[derive(Debug, Clone, Default)]
pub struct BatchProcessingStats {
    pub tasks_submitted: u64,
    pub tasks_processed: u64,
    pub tasks_succeeded: u64,
    pub tasks_failed: u64,
    pub total_processing_time: Duration,
    pub min_processing_time: Duration,
    pub max_processing_time: Duration,
}

impl BatchProcessingStats {
    /// Fraction of processed tasks that succeeded, in `[0.0, 1.0]`.
    pub fn success_rate(&self) -> f64 {
        if self.tasks_processed == 0 {
            return 0.0;
        }
        self.tasks_succeeded as f64 / self.tasks_processed as f64
    }

    /// Mean wall-clock time per processed task.
    pub fn average_processing_time(&self) -> Duration {
        if self.tasks_processed == 0 {
            return Duration::ZERO;
        }
        self.total_processing_time / self.tasks_processed as u32
    }

    /// Tasks per second of accumulated processing time.
    pub fn throughput(&self) -> f64 {
        if self.total_processing_time.is_zero() {
            return 0.0;
        }
        self.tasks_processed as f64 / self.total_processing_time.as_secs_f64()
    }
}

#[derive(Debug, Clone)]
pub struct QueueStatus {
    pub total_tasks: usize,
    /// Task counts indexed by `Priority` discriminant (Low..=Critical).
    pub priority_distribution: [usize; 4],
    pub available_workers: usize,
    pub oldest_task_age: Duration,
}

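/// Wraps a `ConcurrentBatchProcessor` with a mutable copy of its configuration
/// that `adapt_configuration` retunes from observed stats. Note that `inner`
/// keeps the config it was constructed with; the adapted values live only in
/// the shared `config` handle.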
pub struct AdaptiveBatchProcessor {
    inner: ConcurrentBatchProcessor,
    config: Arc<RwLock<BatchConfig>>,
    performance_monitor: PerformanceMonitor,
}

impl AdaptiveBatchProcessor {
    pub fn new(
        initial_config: BatchConfig,
        memory_manager: Arc<GlobalMemoryManager>,
    ) -> Self {
        let config = Arc::new(RwLock::new(initial_config.clone()));
        let inner = ConcurrentBatchProcessor::new(initial_config, memory_manager);
        let performance_monitor = PerformanceMonitor::new();

        Self {
            inner,
            config,
            performance_monitor,
        }
    }

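    /// Heuristically retunes retries, concurrency, and timeout from the
    /// current stats and queue depth. One plausible way to drive this, shown
    /// as a sketch rather than anything this module provides, is a periodic
    /// Tokio task:
    ///
    /// ```ignore
    /// let adaptive = Arc::new(adaptive_processor);
    /// tokio::spawn({
    ///     let adaptive = Arc::clone(&adaptive);
    ///     async move {
    ///         let mut tick = tokio::time::interval(Duration::from_secs(30));
    ///         loop {
    ///             tick.tick().await;
    ///             adaptive.adapt_configuration().await;
    ///         }
    ///     }
    /// });
    /// ```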
    pub async fn adapt_configuration(&self) {
        let stats = self.inner.get_stats().await;
        let queue_status = self.inner.get_queue_status();

        let mut config = self.config.write().await;

        // Retry budget: loosen when failing often, tighten when reliably succeeding.
        if stats.success_rate() < 0.8 && config.max_retries < 5 {
            config.max_retries += 1;
        } else if stats.success_rate() > 0.95 && config.max_retries > 1 {
            config.max_retries = (config.max_retries - 1).max(1);
        }

        // Concurrency: scale up under queue pressure, down when nearly idle.
        if queue_status.total_tasks > config.backpressure_threshold / 2 {
            config.max_concurrency = (config.max_concurrency + 2).min(20);
        } else if queue_status.total_tasks < 10 && config.max_concurrency > 5 {
            config.max_concurrency = (config.max_concurrency - 1).max(5);
        }

        // Timeout: keep headroom above the observed average processing time.
        let avg_time = stats.average_processing_time();
        if avg_time > config.timeout_duration {
            config.timeout_duration = avg_time + Duration::from_secs(5);
        }
    }
}

#[derive(Debug)]
pub struct PerformanceMonitor {
    start_time: Instant,
}

impl PerformanceMonitor {
    fn new() -> Self {
        Self {
            start_time: Instant::now(),
        }
    }

    /// Time elapsed since this monitor was created.
    pub fn uptime(&self) -> Duration {
        self.start_time.elapsed()
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Stand-in translator: waits briefly, then prefixes each text.
    async fn mock_translation_fn(texts: Vec<String>) -> TranslationResult<Vec<String>> {
        // Simulate network latency.
        tokio::time::sleep(Duration::from_millis(10)).await;

        let translations = texts.iter()
            .map(|text| format!("translated_{text}"))
            .collect();

        Ok(translations)
    }

    #[tokio::test]
    async fn test_concurrent_batch_processor() {
        let config = BatchConfig::default();
        let memory_manager = Arc::new(GlobalMemoryManager::new());
        let processor = ConcurrentBatchProcessor::new(config, memory_manager);

        let task1 = vec!["Hello".to_string(), "World".to_string()];
        let task2 = vec!["Foo".to_string(), "Bar".to_string()];

        let id1 = processor.submit_batch(task1, Priority::Normal).await.unwrap();
        let id2 = processor.submit_batch(task2, Priority::High).await.unwrap();

        assert!(id1 != id2);

        let results = processor.process_queue(mock_translation_fn).await.unwrap();

        assert_eq!(results.len(), 2);

        let stats = processor.get_stats().await;
        assert_eq!(stats.tasks_processed, 2);
    }

    #[tokio::test]
    async fn test_priority_ordering() {
        let config = BatchConfig::default();
        let memory_manager = Arc::new(GlobalMemoryManager::new());
        let processor = ConcurrentBatchProcessor::new(config, memory_manager);

        processor.submit_batch(vec!["low".to_string()], Priority::Low).await.unwrap();
        processor.submit_batch(vec!["high".to_string()], Priority::High).await.unwrap();
        processor.submit_batch(vec!["normal".to_string()], Priority::Normal).await.unwrap();

        let queue_status = processor.get_queue_status();
        assert_eq!(queue_status.total_tasks, 3);

        // Inspect the internal queue directly to verify priority ordering.
        let queue = processor.task_queue.lock().unwrap();
        assert_eq!(queue[0].priority, Priority::High);
        assert_eq!(queue[1].priority, Priority::Normal);
        assert_eq!(queue[2].priority, Priority::Low);
    }

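    // A sketch of a backpressure test, under the same assumptions as the
    // tests above (GlobalMemoryManager::new(), otherwise default config).
    // With a zero threshold the first submission fills the queue and the
    // second is rejected.
    #[tokio::test]
    async fn test_backpressure_rejection() {
        let config = BatchConfig {
            backpressure_threshold: 0,
            ..BatchConfig::default()
        };
        let memory_manager = Arc::new(GlobalMemoryManager::new());
        let processor = ConcurrentBatchProcessor::new(config, memory_manager);

        // First submit succeeds: the queue is empty, so the check passes.
        processor.submit_batch(vec!["first".to_string()], Priority::Normal).await.unwrap();

        // Second submit sees queue.len() == 1 > 0 and is rejected.
        let result = processor.submit_batch(vec!["second".to_string()], Priority::Normal).await;
        assert!(result.is_err());

        // Clearing the queue makes room again.
        assert_eq!(processor.clear_queue(), 1);
    }
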
    #[test]
    fn test_batch_processing_stats() {
        let mut stats = BatchProcessingStats::default();
        stats.tasks_processed = 100;
        stats.tasks_succeeded = 95;
        stats.total_processing_time = Duration::from_secs(50);

        assert_eq!(stats.success_rate(), 0.95);
        assert_eq!(stats.average_processing_time(), Duration::from_millis(500));
        assert_eq!(stats.throughput(), 2.0);
    }
}