use futures::future::join_all;
use std::sync::Arc;
use std::time::Instant;
use tokio::sync::RwLock;
use super::ProcessingResult;
use super::{ProcessingMetrics, RateLimiter};
use crate::core::{Document, GraphRAGError, KnowledgeGraph};
/// Processes batches of documents concurrently, a bounded number at a time.
///
/// The only state is the concurrency bound; the knowledge graph, rate limiter and metrics
/// are supplied per call to `process_batch`.
#[derive(Debug)]
pub struct ConcurrentProcessor {
/// Maximum number of documents handled at once; `process_batch` splits its input into
/// chunks based on this value.
max_concurrent_documents: usize,
}
impl ConcurrentProcessor {
/// Creates a processor that handles at most `max_concurrent_documents` documents at once.
///
/// The value is stored as given; no validation is performed here.
pub fn new(max_concurrent_documents: usize) -> Self {
    Self { max_concurrent_documents }
}
pub async fn process_batch(
&self,
documents: Vec<Document>,
graph: Arc<RwLock<KnowledgeGraph>>,
rate_limiter: Arc<RateLimiter>,
metrics: Arc<ProcessingMetrics>,
) -> Result<Vec<ProcessingResult>, GraphRAGError> {
if documents.is_empty() {
return Ok(Vec::new());
}
tracing::info!(
document_count = documents.len(),
max_concurrency = self.max_concurrent_documents,
"Processing documents batch"
);
let chunk_size = self.max_concurrent_documents;
let mut all_results = Vec::new();
let mut total_errors = 0;
for (chunk_idx, chunk) in documents.chunks(chunk_size).enumerate() {
tracing::debug!(
chunk_number = chunk_idx + 1,
chunk_size = chunk.len(),
"Processing chunk"
);
let chunk_start = Instant::now();
let tasks: Vec<_> = chunk
.iter()
.cloned()
.map(|document| {
let graph = Arc::clone(&graph);
let rate_limiter = Arc::clone(&rate_limiter);
let metrics = Arc::clone(&metrics);
let doc_id = document.id.clone();
tokio::spawn(async move {
let doc_start = Instant::now();
let _permit = match rate_limiter.acquire_llm_permit().await {
Ok(permit) => permit,
Err(e) => {
metrics.increment_rate_limit_errors();
return Err(e);
}
};
let result =
Self::process_single_document_internal(&graph, document, &metrics)
.await;
let duration = doc_start.elapsed();
match &result {
Ok(_) => {
tracing::debug!(document_id = %doc_id, duration_ms = duration.as_millis(), "Document completed");
metrics.record_document_processing_duration(duration);
}
Err(e) => {
tracing::warn!(document_id = %doc_id, duration_ms = duration.as_millis(), error = %e, "Document failed");
metrics.increment_document_processing_error();
}
}
result
})
})
.collect();
let chunk_results = join_all(tasks).await;
for (task_idx, task_result) in chunk_results.into_iter().enumerate() {
match task_result {
Ok(Ok(processing_result)) => {
all_results.push(processing_result);
metrics.increment_document_processing_success();
},
Ok(Err(processing_error)) => {
total_errors += 1;
tracing::error!(
chunk_number = chunk_idx + 1,
task_number = task_idx + 1,
error = %processing_error,
"Processing error"
);
},
Err(join_error) => {
total_errors += 1;
tracing::error!(
chunk_number = chunk_idx + 1,
task_number = task_idx + 1,
error = %join_error,
"Task join error"
);
},
}
}
let chunk_duration = chunk_start.elapsed();
tracing::debug!(
chunk_number = chunk_idx + 1,
duration_ms = chunk_duration.as_millis(),
successes = chunk.len() - total_errors.min(chunk.len()),
errors = total_errors.min(chunk.len()),
"Chunk completed"
);
if chunk_idx + 1 < documents.chunks(chunk_size).len() {
tokio::time::sleep(std::time::Duration::from_millis(100)).await;
}
}
if total_errors > 0 {
tracing::warn!(
total_errors = total_errors,
total_documents = documents.len(),
"Batch processing completed with errors"
);
}
Ok(all_results)
}
/// Builds the [`ProcessingResult`] for a single document.
///
/// A read lock on `graph` is acquired and held while the result is assembled; the graph
/// itself is not inspected or modified. The result reports the document's id, its number
/// of chunks, zero extracted entities, and the time elapsed since this function was
/// entered. `_metrics` is currently unused.
///
/// # Errors
///
/// The current implementation always returns `Ok`.
async fn process_single_document_internal(
    graph: &Arc<RwLock<KnowledgeGraph>>,
    document: Document,
    _metrics: &ProcessingMetrics,
) -> Result<ProcessingResult, GraphRAGError> {
    let started = Instant::now();

    // Held until the function returns, i.e. for the whole time the result is built.
    let _read_guard = graph.read().await;

    let document_id = document.id.clone();
    let chunks_processed = document.chunks.len();

    Ok(ProcessingResult {
        document_id,
        entities_extracted: 0,
        chunks_processed,
        // Measured after the lock was obtained, so lock wait time is included.
        processing_time: started.elapsed(),
        success: true,
    })
}
}