use indexmap::IndexMap;
use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::sync::RwLock;
use crate::core::{Document, GraphRAGError, KnowledgeGraph};
pub mod concurrent_pipeline;
pub mod monitoring;
pub mod rate_limiting;
pub use concurrent_pipeline::ConcurrentProcessor;
pub use monitoring::ProcessingMetrics;
pub use rate_limiting::RateLimiter;
#[derive(Debug, Clone)]
pub struct ProcessingResult {
pub document_id: crate::core::DocumentId,
pub entities_extracted: usize,
pub chunks_processed: usize,
pub processing_time: Duration,
pub success: bool,
}
#[derive(Debug, Clone)]
pub struct AsyncConfig {
pub max_concurrent_llm_calls: usize,
pub max_concurrent_embeddings: usize,
pub max_concurrent_documents: usize,
pub llm_rate_limit_per_second: f64,
pub embedding_rate_limit_per_second: f64,
}
impl Default for AsyncConfig {
fn default() -> Self {
Self {
max_concurrent_llm_calls: 3,
max_concurrent_embeddings: 5,
max_concurrent_documents: 10,
llm_rate_limit_per_second: 2.0,
embedding_rate_limit_per_second: 10.0,
}
}
}
#[derive(Debug)]
pub struct AsyncGraphRAGCore {
graph: Arc<RwLock<KnowledgeGraph>>,
rate_limiter: Arc<RateLimiter>,
concurrent_processor: Arc<ConcurrentProcessor>,
metrics: Arc<ProcessingMetrics>,
config: AsyncConfig,
}
impl AsyncGraphRAGCore {
pub async fn new(graph: KnowledgeGraph, config: AsyncConfig) -> Result<Self, GraphRAGError> {
let rate_limiter = Arc::new(RateLimiter::new(&config));
let concurrent_processor =
Arc::new(ConcurrentProcessor::new(config.max_concurrent_documents));
let metrics = Arc::new(ProcessingMetrics::new());
Ok(Self {
graph: Arc::new(RwLock::new(graph)),
rate_limiter,
concurrent_processor,
metrics,
config,
})
}
pub async fn process_documents_async(
&self,
documents: Vec<Document>,
) -> Result<Vec<ProcessingResult>, GraphRAGError> {
let start_time = Instant::now();
self.metrics.increment_batch_processing_started();
tracing::info!(
document_count = documents.len(),
"Processing documents concurrently"
);
let results = self
.concurrent_processor
.process_batch(
documents,
Arc::clone(&self.graph),
Arc::clone(&self.rate_limiter),
Arc::clone(&self.metrics),
)
.await?;
let duration = start_time.elapsed();
self.metrics.record_batch_processing_duration(duration);
tracing::info!(
duration_ms = duration.as_millis(),
successes = results.len(),
"Batch processing completed"
);
Ok(results)
}
pub async fn process_single_document_async(
&self,
document: Document,
) -> Result<ProcessingResult, GraphRAGError> {
let start_time = Instant::now();
self.metrics.increment_document_processing_started();
let _llm_permit = self.rate_limiter.acquire_llm_permit().await?;
let result = {
let _graph = self.graph.read().await;
ProcessingResult {
document_id: document.id.clone(),
entities_extracted: 0,
chunks_processed: document.chunks.len(),
processing_time: start_time.elapsed(),
success: true,
}
};
let duration = start_time.elapsed();
if result.success {
self.metrics.increment_document_processing_success();
self.metrics.record_document_processing_duration(duration);
} else {
self.metrics.increment_document_processing_error();
tracing::warn!(document_id = %result.document_id, "Document processing failed");
}
Ok(result)
}
pub async fn query_async(&self, query: &str) -> Result<String, GraphRAGError> {
let start_time = Instant::now();
self.metrics.increment_query_started();
let _llm_permit = self.rate_limiter.acquire_llm_permit().await?;
let result = {
let graph = self.graph.read().await;
let entity_count = graph.entities().count();
if entity_count == 0 {
Err(GraphRAGError::Unsupported {
operation: "query processing".to_string(),
reason: "No entities in knowledge graph".to_string(),
})
} else {
Ok(format!(
"Query processed: '{query}'. Found {entity_count} entities in graph. This is a basic implementation."
))
}
};
let duration = start_time.elapsed();
match &result {
Ok(_) => {
self.metrics.increment_query_success();
self.metrics.record_query_duration(duration);
tracing::info!(
duration_ms = duration.as_millis(),
"Query completed successfully"
);
},
Err(e) => {
self.metrics.increment_query_error();
tracing::error!(error = %e, "Query processing error");
},
}
result
}
pub fn get_metrics(&self) -> &ProcessingMetrics {
&self.metrics
}
pub fn get_config(&self) -> &AsyncConfig {
&self.config
}
pub async fn health_check(&self) -> HealthStatus {
let graph_status = {
let graph = self.graph.read().await;
if graph.entities().count() > 0 {
ComponentStatus::Healthy
} else {
ComponentStatus::Warning("No entities in graph".to_string())
}
};
let rate_limiter_status = self.rate_limiter.health_check();
HealthStatus {
overall: if matches!(graph_status, ComponentStatus::Healthy)
&& matches!(rate_limiter_status, ComponentStatus::Healthy)
{
ComponentStatus::Healthy
} else {
ComponentStatus::Warning("Some components have issues".to_string())
},
components: indexmap::indexmap! {
"graph".to_string() => graph_status,
"rate_limiter".to_string() => rate_limiter_status,
},
}
}
pub async fn shutdown(&self) -> Result<(), GraphRAGError> {
tracing::info!("Shutting down async GraphRAG processor");
tracing::info!("Async processor shutdown complete");
Ok(())
}
}
#[derive(Debug, Clone)]
pub enum ComponentStatus {
Healthy,
Warning(String),
Error(String),
}
#[derive(Debug, Clone)]
pub struct HealthStatus {
pub overall: ComponentStatus,
pub components: IndexMap<String, ComponentStatus>,
}
#[derive(Debug)]
pub struct TaskScheduler {
max_concurrent_tasks: usize,
}
impl TaskScheduler {
pub fn new(max_concurrent_tasks: usize) -> Self {
Self {
max_concurrent_tasks,
}
}
pub async fn schedule_task<F, T>(&self, task: F) -> Result<T, GraphRAGError>
where
F: std::future::Future<Output = Result<T, GraphRAGError>>,
{
task.await
}
pub fn max_concurrent_tasks(&self) -> usize {
self.max_concurrent_tasks
}
}
#[derive(Debug, Default)]
pub struct PerformanceTracker {
total_operations: std::sync::atomic::AtomicU64,
total_duration: std::sync::Mutex<Duration>,
}
impl PerformanceTracker {
pub fn new() -> Self {
Self::default()
}
pub fn record_operation(&self, duration: Duration) {
self.total_operations
.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
let mut total_duration = self.total_duration.lock().expect("lock poisoned");
*total_duration += duration;
}
pub fn get_average_duration(&self) -> Duration {
let total_ops = self
.total_operations
.load(std::sync::atomic::Ordering::Relaxed);
if total_ops == 0 {
return Duration::from_secs(0);
}
let total_duration = *self.total_duration.lock().expect("lock poisoned");
total_duration / total_ops as u32
}
pub fn get_total_operations(&self) -> u64 {
self.total_operations
.load(std::sync::atomic::Ordering::Relaxed)
}
}