use crate::{
config::Config,
core::{
traits::BoxedAsyncLanguageModel, Document, DocumentId, Entity, EntityId, GraphRAGError,
KnowledgeGraph, Result, TextChunk,
},
generation::{AnswerContext, GeneratedAnswer, PromptTemplate},
retrieval::SearchResult,
summarization::{DocumentTree, HierarchicalConfig, LLMClient, QueryResult},
};
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::RwLock;
pub struct AsyncLanguageModelAdapter {
model: Arc<BoxedAsyncLanguageModel>,
}
impl AsyncLanguageModelAdapter {
pub fn new(model: Arc<BoxedAsyncLanguageModel>) -> Self {
Self { model }
}
}
#[async_trait::async_trait]
impl LLMClient for AsyncLanguageModelAdapter {
async fn generate_summary(
&self,
text: &str,
prompt: &str,
_max_tokens: usize,
_temperature: f32,
) -> crate::Result<String> {
let full_prompt = format!("{}\n\nText: {}", prompt, text);
let response = self.model.complete(&full_prompt).await.map_err(|e| {
crate::core::GraphRAGError::Generation {
message: e.to_string(),
}
})?;
Ok(response)
}
fn model_name(&self) -> &str {
"async_language_model"
}
}
pub struct AsyncGraphRAG {
#[allow(dead_code)]
config: Config,
knowledge_graph: Arc<RwLock<Option<KnowledgeGraph>>>,
document_trees: Arc<RwLock<HashMap<DocumentId, DocumentTree>>>,
hierarchical_config: HierarchicalConfig,
language_model: Option<Arc<BoxedAsyncLanguageModel>>,
}
impl AsyncGraphRAG {
pub async fn new(config: Config) -> Result<Self> {
let hierarchical_config = config.summarization.clone();
Ok(Self {
config,
knowledge_graph: Arc::new(RwLock::new(None)),
document_trees: Arc::new(RwLock::new(HashMap::new())),
hierarchical_config,
language_model: None,
})
}
pub async fn with_hierarchical_config(
config: Config,
hierarchical_config: HierarchicalConfig,
) -> Result<Self> {
Ok(Self {
config,
knowledge_graph: Arc::new(RwLock::new(None)),
document_trees: Arc::new(RwLock::new(HashMap::new())),
hierarchical_config,
language_model: None,
})
}
pub async fn set_language_model(&mut self, model: Arc<BoxedAsyncLanguageModel>) {
self.language_model = Some(model);
}
pub async fn initialize(&mut self) -> Result<()> {
tracing::info!("Initializing async GraphRAG system");
{
let mut graph_guard = self.knowledge_graph.write().await;
*graph_guard = Some(KnowledgeGraph::new());
}
if self.language_model.is_none() {
#[cfg(feature = "async-traits")]
{
let mock_llm = crate::generation::async_mock_llm::AsyncMockLLM::new().await?;
self.language_model = Some(Arc::new(mock_llm));
}
#[cfg(not(feature = "async-traits"))]
{
return Err(GraphRAGError::Config {
message: "No async language model available and async-traits feature disabled"
.to_string(),
});
}
}
tracing::info!("Async GraphRAG system initialized successfully");
Ok(())
}
pub async fn add_document(&mut self, document: Document) -> Result<()> {
self.build_document_tree(&document).await?;
let mut graph_guard = self.knowledge_graph.write().await;
let graph = graph_guard.as_mut().ok_or_else(|| GraphRAGError::Config {
message: "Knowledge graph not initialized".to_string(),
})?;
graph.add_document(document)
}
pub async fn build_document_tree(&mut self, document: &Document) -> Result<()> {
if document.chunks.is_empty() {
return Ok(());
}
tracing::debug!(document_id = %document.id, "Building hierarchical tree for document");
let tree = if self.hierarchical_config.llm_config.enabled {
if let Some(ref lm) = self.language_model {
let llm_client = Arc::new(AsyncLanguageModelAdapter::new(Arc::clone(lm)));
DocumentTree::with_llm_client(
document.id.clone(),
self.hierarchical_config.clone(),
llm_client,
)?
} else {
DocumentTree::new(document.id.clone(), self.hierarchical_config.clone())?
}
} else {
DocumentTree::new(document.id.clone(), self.hierarchical_config.clone())?
};
{
let mut trees_guard = self.document_trees.write().await;
trees_guard.insert(document.id.clone(), tree);
}
Ok(())
}
pub async fn build_graph(&mut self) -> Result<()> {
let mut graph_guard = self.knowledge_graph.write().await;
let graph = graph_guard.as_mut().ok_or_else(|| GraphRAGError::Config {
message: "Knowledge graph not initialized".to_string(),
})?;
tracing::info!("Building knowledge graph asynchronously");
let chunks: Vec<_> = graph.chunks().cloned().collect();
let mut total_entities = 0;
for chunk in &chunks {
let entities = self.extract_entities_async(chunk).await?;
let mut chunk_entity_ids = Vec::new();
for entity in entities {
chunk_entity_ids.push(entity.id.clone());
graph.add_entity(entity)?;
total_entities += 1;
}
if let Some(existing_chunk) = graph.get_chunk_mut(&chunk.id) {
existing_chunk.entities = chunk_entity_ids;
}
}
tracing::info!(
entity_count = total_entities,
"Knowledge graph built asynchronously"
);
Ok(())
}
async fn extract_entities_async(&self, chunk: &TextChunk) -> Result<Vec<Entity>> {
tokio::time::sleep(std::time::Duration::from_millis(1)).await;
let content = chunk.content.to_lowercase();
let mut entities = Vec::new();
let names = ["tom", "huck", "polly", "sid", "mary", "jim"];
for (i, name) in names.iter().enumerate() {
if content.contains(name) {
let entity = Entity::new(
EntityId::new(format!("{name}_{i}")),
name.to_string(),
"PERSON".to_string(),
0.8,
);
entities.push(entity);
}
}
Ok(entities)
}
pub async fn query(&self, query: &str) -> Result<Vec<String>> {
tokio::time::sleep(std::time::Duration::from_millis(10)).await;
Ok(vec![format!("Async result for: {}", query)])
}
pub async fn answer_question(&self, question: &str) -> Result<GeneratedAnswer> {
let graph_guard = self.knowledge_graph.read().await;
let graph = graph_guard
.as_ref()
.ok_or_else(|| GraphRAGError::Generation {
message: "Knowledge graph not initialized".to_string(),
})?;
let llm = self
.language_model
.as_ref()
.ok_or_else(|| GraphRAGError::Generation {
message: "Language model not initialized".to_string(),
})?;
let search_results = self.async_retrieval(question, graph).await?;
let hierarchical_results = self.hierarchical_query(question, 5).await?;
self.generate_answer_async(question, search_results, hierarchical_results, llm)
.await
}
async fn async_retrieval(
&self,
query: &str,
graph: &KnowledgeGraph,
) -> Result<Vec<SearchResult>> {
tokio::time::sleep(std::time::Duration::from_millis(5)).await;
let mut results = Vec::new();
for (i, chunk) in graph.chunks().enumerate().take(3) {
if chunk.content.to_lowercase().contains(&query.to_lowercase()) {
results.push(SearchResult {
id: chunk.id.to_string(),
content: chunk.content.clone(),
score: 0.8 - (i as f32 * 0.1),
result_type: crate::retrieval::ResultType::Chunk,
entities: chunk.entities.iter().map(|e| e.to_string()).collect(),
source_chunks: vec![chunk.id.to_string()],
});
}
}
Ok(results)
}
pub async fn hierarchical_query(
&self,
query: &str,
max_results: usize,
) -> Result<Vec<QueryResult>> {
let trees_guard = self.document_trees.read().await;
let mut all_results = Vec::new();
for tree in trees_guard.values() {
let tree_results = tree.query(query, max_results)?;
all_results.extend(tree_results);
}
all_results.sort_by(|a, b| {
b.score
.partial_cmp(&a.score)
.unwrap_or(std::cmp::Ordering::Equal)
});
all_results.truncate(max_results);
Ok(all_results)
}
async fn generate_answer_async(
&self,
question: &str,
search_results: Vec<SearchResult>,
hierarchical_results: Vec<QueryResult>,
llm: &BoxedAsyncLanguageModel,
) -> Result<GeneratedAnswer> {
let context = self
.assemble_context_async(search_results, hierarchical_results)
.await?;
let prompt = self.create_qa_prompt(question, &context)?;
let response = llm.complete(&prompt).await?;
Ok(GeneratedAnswer {
answer_text: response,
confidence_score: context.confidence_score,
sources: context.get_sources(),
entities_mentioned: context.entities,
mode_used: crate::generation::AnswerMode::Abstractive,
context_quality: context.confidence_score,
})
}
async fn assemble_context_async(
&self,
search_results: Vec<SearchResult>,
hierarchical_results: Vec<QueryResult>,
) -> Result<AnswerContext> {
tokio::time::sleep(std::time::Duration::from_millis(1)).await;
let mut context = AnswerContext::new();
for result in search_results {
context.primary_chunks.push(result);
}
context.hierarchical_summaries = hierarchical_results;
let avg_score = if context.primary_chunks.is_empty() {
0.0
} else {
context.primary_chunks.iter().map(|r| r.score).sum::<f32>()
/ context.primary_chunks.len() as f32
};
context.confidence_score = avg_score;
context.source_count = context.primary_chunks.len() + context.hierarchical_summaries.len();
Ok(context)
}
fn create_qa_prompt(&self, question: &str, context: &AnswerContext) -> Result<String> {
let combined_content = context.get_combined_content();
let mut values = HashMap::new();
values.insert("context".to_string(), combined_content);
values.insert("question".to_string(), question.to_string());
let template = PromptTemplate::new(
"Context:\n{context}\n\nQuestion: {question}\n\nBased on the provided context, please answer the question. If the context doesn't contain enough information, please say so.".to_string()
);
template.fill(&values)
}
pub async fn add_documents_batch(&mut self, documents: Vec<Document>) -> Result<()> {
tracing::info!(
document_count = documents.len(),
"Processing documents concurrently"
);
for document in documents {
self.add_document(document).await?;
}
tracing::info!("All documents processed successfully");
Ok(())
}
pub async fn answer_questions_batch(&self, questions: &[&str]) -> Result<Vec<GeneratedAnswer>> {
use futures::stream::{FuturesUnordered, StreamExt};
let mut futures = FuturesUnordered::new();
for question in questions {
futures.push(self.answer_question(question));
}
let mut answers = Vec::with_capacity(questions.len());
while let Some(result) = futures.next().await {
answers.push(result?);
}
Ok(answers)
}
pub async fn get_performance_stats(&self) -> AsyncPerformanceStats {
let graph_guard = self.knowledge_graph.read().await;
let trees_guard = self.document_trees.read().await;
AsyncPerformanceStats {
total_documents: trees_guard.len(),
total_entities: graph_guard.as_ref().map(|g| g.entity_count()).unwrap_or(0),
total_chunks: graph_guard
.as_ref()
.map(|g| g.chunks().count())
.unwrap_or(0),
health_status: AsyncHealthStatus::Healthy,
}
}
pub async fn health_check(&self) -> Result<AsyncHealthStatus> {
if let Some(llm) = &self.language_model {
if !llm.health_check().await.unwrap_or(false) {
return Ok(AsyncHealthStatus::Degraded);
}
}
let graph_guard = self.knowledge_graph.read().await;
if graph_guard.is_none() {
return Ok(AsyncHealthStatus::Degraded);
}
Ok(AsyncHealthStatus::Healthy)
}
pub async fn save_state_async(&self, output_dir: &str) -> Result<()> {
use std::fs;
fs::create_dir_all(output_dir)?;
let graph_guard = self.knowledge_graph.read().await;
if let Some(graph) = &*graph_guard {
graph.save_to_json(&format!("{output_dir}/async_knowledge_graph.json"))?;
}
let trees_guard = self.document_trees.read().await;
for (doc_id, tree) in trees_guard.iter() {
let filename = format!("{output_dir}/{doc_id}_async_tree.json");
let json_content = tree.to_json()?;
fs::write(&filename, json_content)?;
}
tracing::info!(output_dir = %output_dir, "Async state saved");
Ok(())
}
}
#[derive(Debug)]
pub struct AsyncPerformanceStats {
pub total_documents: usize,
pub total_entities: usize,
pub total_chunks: usize,
pub health_status: AsyncHealthStatus,
}
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum AsyncHealthStatus {
Healthy,
Degraded,
Unhealthy,
}
pub struct AsyncGraphRAGBuilder {
config: Config,
language_model: Option<Arc<BoxedAsyncLanguageModel>>,
hierarchical_config: Option<HierarchicalConfig>,
}
impl AsyncGraphRAGBuilder {
pub fn new() -> Self {
Self {
config: Config::default(),
language_model: None,
hierarchical_config: None,
}
}
pub fn config(mut self, config: Config) -> Self {
self.config = config;
self
}
pub fn language_model(mut self, model: BoxedAsyncLanguageModel) -> Self {
self.language_model = Some(Arc::new(model));
self
}
pub fn hierarchical_config(mut self, config: HierarchicalConfig) -> Self {
self.hierarchical_config = Some(config);
self
}
#[cfg(feature = "async-traits")]
pub async fn with_async_mock_llm(mut self) -> Result<Self> {
let mock_llm = crate::generation::async_mock_llm::AsyncMockLLM::new().await?;
self.language_model = Some(Arc::new(mock_llm));
Ok(self)
}
#[cfg(all(feature = "ollama", feature = "async-traits"))]
pub async fn with_async_ollama(mut self, config: crate::ollama::OllamaConfig) -> Result<Self> {
let ollama_llm = crate::ollama::AsyncOllamaGenerator::new(config).await?;
self.language_model = Some(Arc::new(ollama_llm));
Ok(self)
}
pub async fn build(self) -> Result<AsyncGraphRAG> {
let hierarchical_config = self.hierarchical_config.unwrap_or_default();
let mut graphrag =
AsyncGraphRAG::with_hierarchical_config(self.config, hierarchical_config).await?;
if let Some(llm) = self.language_model {
graphrag.set_language_model(llm).await;
}
graphrag.initialize().await?;
Ok(graphrag)
}
}
impl Default for AsyncGraphRAGBuilder {
fn default() -> Self {
Self::new()
}
}