Skip to main content

graphrag_core/
async_graphrag.rs

1//! Async GraphRAG System
2//!
3//! This module provides a complete async implementation of the GraphRAG system
4//! that leverages all async traits for maximum performance and scalability.
5
6use crate::{
7    config::Config,
8    core::{
9        traits::BoxedAsyncLanguageModel, Document, DocumentId, Entity, EntityId, GraphRAGError,
10        KnowledgeGraph, Result, TextChunk,
11    },
12    generation::{AnswerContext, GeneratedAnswer, PromptTemplate},
13    retrieval::SearchResult,
14    summarization::{DocumentTree, HierarchicalConfig, LLMClient, QueryResult},
15};
16use std::collections::HashMap;
17use std::sync::Arc;
18use tokio::sync::RwLock;
19
20/// Adapter to connect BoxedAsyncLanguageModel to LLMClient trait
21pub struct AsyncLanguageModelAdapter {
22    model: Arc<BoxedAsyncLanguageModel>,
23}
24
25impl AsyncLanguageModelAdapter {
26    /// Creates a new adapter wrapping a BoxedAsyncLanguageModel.
27    ///
28    /// # Arguments
29    /// * `model` - The async language model to wrap in the adapter
30    ///
31    /// # Returns
32    /// A new AsyncLanguageModelAdapter instance
33    pub fn new(model: Arc<BoxedAsyncLanguageModel>) -> Self {
34        Self { model }
35    }
36}
37
38#[async_trait::async_trait]
39impl LLMClient for AsyncLanguageModelAdapter {
40    async fn generate_summary(
41        &self,
42        text: &str,
43        prompt: &str,
44        _max_tokens: usize,
45        _temperature: f32,
46    ) -> crate::Result<String> {
47        let full_prompt = format!("{}\n\nText: {}", prompt, text);
48
49        let response = self.model.complete(&full_prompt).await.map_err(|e| {
50            crate::core::GraphRAGError::Generation {
51                message: e.to_string(),
52            }
53        })?;
54
55        Ok(response)
56    }
57
58    fn model_name(&self) -> &str {
59        "async_language_model"
60    }
61}
62
63/// Async version of the main GraphRAG system
64pub struct AsyncGraphRAG {
65    #[allow(dead_code)]
66    config: Config,
67    knowledge_graph: Arc<RwLock<Option<KnowledgeGraph>>>,
68    document_trees: Arc<RwLock<HashMap<DocumentId, DocumentTree>>>,
69    hierarchical_config: HierarchicalConfig,
70    language_model: Option<Arc<BoxedAsyncLanguageModel>>,
71}
72
73impl AsyncGraphRAG {
74    /// Create a new async GraphRAG instance
75    pub async fn new(config: Config) -> Result<Self> {
76        let hierarchical_config = config.summarization.clone();
77        Ok(Self {
78            config,
79            knowledge_graph: Arc::new(RwLock::new(None)),
80            document_trees: Arc::new(RwLock::new(HashMap::new())),
81            hierarchical_config,
82            language_model: None,
83        })
84    }
85
86    /// Create with custom hierarchical configuration
87    pub async fn with_hierarchical_config(
88        config: Config,
89        hierarchical_config: HierarchicalConfig,
90    ) -> Result<Self> {
91        Ok(Self {
92            config,
93            knowledge_graph: Arc::new(RwLock::new(None)),
94            document_trees: Arc::new(RwLock::new(HashMap::new())),
95            hierarchical_config,
96            language_model: None,
97        })
98    }
99
100    /// Set the async language model
101    pub async fn set_language_model(&mut self, model: Arc<BoxedAsyncLanguageModel>) {
102        self.language_model = Some(model);
103    }
104
105    /// Initialize the async GraphRAG system
106    pub async fn initialize(&mut self) -> Result<()> {
107        tracing::info!("Initializing async GraphRAG system");
108
109        // Initialize knowledge graph
110        {
111            let mut graph_guard = self.knowledge_graph.write().await;
112            *graph_guard = Some(KnowledgeGraph::new());
113        }
114
115        // Initialize with default async mock LLM if none provided
116        if self.language_model.is_none() {
117            #[cfg(feature = "async-traits")]
118            {
119                let mock_llm = crate::generation::async_mock_llm::AsyncMockLLM::new().await?;
120                self.language_model = Some(Arc::new(mock_llm));
121            }
122            #[cfg(not(feature = "async-traits"))]
123            {
124                return Err(GraphRAGError::Config {
125                    message: "No async language model available and async-traits feature disabled"
126                        .to_string(),
127                });
128            }
129        }
130
131        tracing::info!("Async GraphRAG system initialized successfully");
132        Ok(())
133    }
134
135    /// Add a document to the system asynchronously
136    pub async fn add_document(&mut self, document: Document) -> Result<()> {
137        // Build hierarchical tree for the document first
138        self.build_document_tree(&document).await?;
139
140        let mut graph_guard = self.knowledge_graph.write().await;
141        let graph = graph_guard.as_mut().ok_or_else(|| GraphRAGError::Config {
142            message: "Knowledge graph not initialized".to_string(),
143        })?;
144
145        graph.add_document(document)
146    }
147
148    /// Build hierarchical tree for a document asynchronously
149    pub async fn build_document_tree(&mut self, document: &Document) -> Result<()> {
150        if document.chunks.is_empty() {
151            return Ok(());
152        }
153
154        tracing::debug!(document_id = %document.id, "Building hierarchical tree for document");
155
156        let tree = if self.hierarchical_config.llm_config.enabled {
157            // Use LLM-powered summarization if enabled in config
158            if let Some(ref lm) = self.language_model {
159                let llm_client = Arc::new(AsyncLanguageModelAdapter::new(Arc::clone(lm)));
160                DocumentTree::with_llm_client(
161                    document.id.clone(),
162                    self.hierarchical_config.clone(),
163                    llm_client,
164                )?
165            } else {
166                DocumentTree::new(document.id.clone(), self.hierarchical_config.clone())?
167            }
168        } else {
169            // Use extractive summarization
170            DocumentTree::new(document.id.clone(), self.hierarchical_config.clone())?
171        };
172        // Note: In a full async implementation, DocumentTree would also be async
173
174        {
175            let mut trees_guard = self.document_trees.write().await;
176            trees_guard.insert(document.id.clone(), tree);
177        }
178
179        Ok(())
180    }
181
182    /// Build the knowledge graph from documents asynchronously
183    pub async fn build_graph(&mut self) -> Result<()> {
184        let mut graph_guard = self.knowledge_graph.write().await;
185        let graph = graph_guard.as_mut().ok_or_else(|| GraphRAGError::Config {
186            message: "Knowledge graph not initialized".to_string(),
187        })?;
188
189        tracing::info!("Building knowledge graph asynchronously");
190
191        // Extract entities from all chunks asynchronously
192        let chunks: Vec<_> = graph.chunks().cloned().collect();
193        let mut total_entities = 0;
194
195        // For each chunk, extract entities (would use AsyncEntityExtractor in full implementation)
196        for chunk in &chunks {
197            // Simulate async entity extraction
198            let entities = self.extract_entities_async(chunk).await?;
199
200            // Add entities to the graph
201            let mut chunk_entity_ids = Vec::new();
202            for entity in entities {
203                chunk_entity_ids.push(entity.id.clone());
204                graph.add_entity(entity)?;
205                total_entities += 1;
206            }
207
208            // Update chunk with entity references
209            if let Some(existing_chunk) = graph.get_chunk_mut(&chunk.id) {
210                existing_chunk.entities = chunk_entity_ids;
211            }
212        }
213
214        tracing::info!(
215            entity_count = total_entities,
216            "Knowledge graph built asynchronously"
217        );
218        Ok(())
219    }
220
221    /// Simulate async entity extraction (would use actual AsyncEntityExtractor)
222    async fn extract_entities_async(&self, chunk: &TextChunk) -> Result<Vec<Entity>> {
223        // Simulate async processing delay
224        tokio::time::sleep(std::time::Duration::from_millis(1)).await;
225
226        // Simple entity extraction for demo (would use actual async implementation)
227        let content = chunk.content.to_lowercase();
228        let mut entities = Vec::new();
229
230        // Extract simple named entities
231        let names = ["tom", "huck", "polly", "sid", "mary", "jim"];
232        for (i, name) in names.iter().enumerate() {
233            if content.contains(name) {
234                let entity = Entity::new(
235                    EntityId::new(format!("{name}_{i}")),
236                    name.to_string(),
237                    "PERSON".to_string(),
238                    0.8,
239                );
240                entities.push(entity);
241            }
242        }
243
244        Ok(entities)
245    }
246
247    /// Query the system asynchronously
248    pub async fn query(&self, query: &str) -> Result<Vec<String>> {
249        // Simulate async retrieval (would use actual AsyncRetriever)
250        tokio::time::sleep(std::time::Duration::from_millis(10)).await;
251
252        // For demo, return simple response
253        Ok(vec![format!("Async result for: {}", query)])
254    }
255
256    /// Generate an answer to a question using async pipeline
257    pub async fn answer_question(&self, question: &str) -> Result<GeneratedAnswer> {
258        let graph_guard = self.knowledge_graph.read().await;
259        let graph = graph_guard
260            .as_ref()
261            .ok_or_else(|| GraphRAGError::Generation {
262                message: "Knowledge graph not initialized".to_string(),
263            })?;
264
265        let llm = self
266            .language_model
267            .as_ref()
268            .ok_or_else(|| GraphRAGError::Generation {
269                message: "Language model not initialized".to_string(),
270            })?;
271
272        // Perform async retrieval
273        let search_results = self.async_retrieval(question, graph).await?;
274
275        // Get hierarchical results
276        let hierarchical_results = self.hierarchical_query(question, 5).await?;
277
278        // Generate answer using async LLM
279        self.generate_answer_async(question, search_results, hierarchical_results, llm)
280            .await
281    }
282
283    /// Perform async retrieval
284    async fn async_retrieval(
285        &self,
286        query: &str,
287        graph: &KnowledgeGraph,
288    ) -> Result<Vec<SearchResult>> {
289        // Simulate async retrieval processing
290        tokio::time::sleep(std::time::Duration::from_millis(5)).await;
291
292        // Simple search simulation
293        let mut results = Vec::new();
294        for (i, chunk) in graph.chunks().enumerate().take(3) {
295            if chunk.content.to_lowercase().contains(&query.to_lowercase()) {
296                results.push(SearchResult {
297                    id: chunk.id.to_string(),
298                    content: chunk.content.clone(),
299                    score: 0.8 - (i as f32 * 0.1),
300                    result_type: crate::retrieval::ResultType::Chunk,
301                    entities: chunk.entities.iter().map(|e| e.to_string()).collect(),
302                    source_chunks: vec![chunk.id.to_string()],
303                });
304            }
305        }
306
307        Ok(results)
308    }
309
310    /// Query using hierarchical summarization asynchronously
311    pub async fn hierarchical_query(
312        &self,
313        query: &str,
314        max_results: usize,
315    ) -> Result<Vec<QueryResult>> {
316        let trees_guard = self.document_trees.read().await;
317        let mut all_results = Vec::new();
318
319        // Query all document trees
320        for tree in trees_guard.values() {
321            // In full implementation, DocumentTree would have async query method
322            let tree_results = tree.query(query, max_results)?;
323            all_results.extend(tree_results);
324        }
325
326        // Sort by score and limit results
327        all_results.sort_by(|a, b| {
328            b.score
329                .partial_cmp(&a.score)
330                .unwrap_or(std::cmp::Ordering::Equal)
331        });
332        all_results.truncate(max_results);
333
334        Ok(all_results)
335    }
336
337    /// Generate answer using async language model
338    async fn generate_answer_async(
339        &self,
340        question: &str,
341        search_results: Vec<SearchResult>,
342        hierarchical_results: Vec<QueryResult>,
343        llm: &BoxedAsyncLanguageModel,
344    ) -> Result<GeneratedAnswer> {
345        // Assemble context
346        let context = self
347            .assemble_context_async(search_results, hierarchical_results)
348            .await?;
349
350        // Create prompt
351        let prompt = self.create_qa_prompt(question, &context)?;
352
353        // Generate response using async LLM
354        let response = llm.complete(&prompt).await?;
355
356        // Create answer with metadata
357        Ok(GeneratedAnswer {
358            answer_text: response,
359            confidence_score: context.confidence_score,
360            sources: context.get_sources(),
361            entities_mentioned: context.entities,
362            mode_used: crate::generation::AnswerMode::Abstractive,
363            context_quality: context.confidence_score,
364        })
365    }
366
367    /// Assemble context asynchronously
368    async fn assemble_context_async(
369        &self,
370        search_results: Vec<SearchResult>,
371        hierarchical_results: Vec<QueryResult>,
372    ) -> Result<AnswerContext> {
373        // Simulate async context assembly
374        tokio::time::sleep(std::time::Duration::from_millis(1)).await;
375
376        let mut context = AnswerContext::new();
377
378        // Process search results
379        for result in search_results {
380            context.primary_chunks.push(result);
381        }
382
383        // Process hierarchical results
384        context.hierarchical_summaries = hierarchical_results;
385
386        // Calculate confidence score
387        let avg_score = if context.primary_chunks.is_empty() {
388            0.0
389        } else {
390            context.primary_chunks.iter().map(|r| r.score).sum::<f32>()
391                / context.primary_chunks.len() as f32
392        };
393
394        context.confidence_score = avg_score;
395        context.source_count = context.primary_chunks.len() + context.hierarchical_summaries.len();
396
397        Ok(context)
398    }
399
400    /// Create QA prompt from context
401    fn create_qa_prompt(&self, question: &str, context: &AnswerContext) -> Result<String> {
402        let combined_content = context.get_combined_content();
403
404        let mut values = HashMap::new();
405        values.insert("context".to_string(), combined_content);
406        values.insert("question".to_string(), question.to_string());
407
408        let template = PromptTemplate::new(
409            "Context:\n{context}\n\nQuestion: {question}\n\nBased on the provided context, please answer the question. If the context doesn't contain enough information, please say so.".to_string()
410        );
411
412        template.fill(&values)
413    }
414
415    /// Batch process multiple documents concurrently
416    pub async fn add_documents_batch(&mut self, documents: Vec<Document>) -> Result<()> {
417        tracing::info!(
418            document_count = documents.len(),
419            "Processing documents concurrently"
420        );
421
422        // Process documents sequentially for now to avoid borrowing issues
423        // In a production implementation, you'd use channels or other concurrency patterns
424        for document in documents {
425            self.add_document(document).await?;
426        }
427
428        tracing::info!("All documents processed successfully");
429        Ok(())
430    }
431
432    /// Batch answer multiple questions concurrently
433    pub async fn answer_questions_batch(&self, questions: &[&str]) -> Result<Vec<GeneratedAnswer>> {
434        use futures::stream::{FuturesUnordered, StreamExt};
435
436        let mut futures = FuturesUnordered::new();
437
438        for question in questions {
439            futures.push(self.answer_question(question));
440        }
441
442        let mut answers = Vec::with_capacity(questions.len());
443        while let Some(result) = futures.next().await {
444            answers.push(result?);
445        }
446
447        Ok(answers)
448    }
449
450    /// Get performance statistics
451    pub async fn get_performance_stats(&self) -> AsyncPerformanceStats {
452        let graph_guard = self.knowledge_graph.read().await;
453        let trees_guard = self.document_trees.read().await;
454
455        AsyncPerformanceStats {
456            total_documents: trees_guard.len(),
457            total_entities: graph_guard.as_ref().map(|g| g.entity_count()).unwrap_or(0),
458            total_chunks: graph_guard
459                .as_ref()
460                .map(|g| g.chunks().count())
461                .unwrap_or(0),
462            health_status: AsyncHealthStatus::Healthy,
463        }
464    }
465
466    /// Health check for all async components
467    pub async fn health_check(&self) -> Result<AsyncHealthStatus> {
468        // Check language model
469        if let Some(llm) = &self.language_model {
470            if !llm.health_check().await.unwrap_or(false) {
471                return Ok(AsyncHealthStatus::Degraded);
472            }
473        }
474
475        // Check if knowledge graph is initialized
476        let graph_guard = self.knowledge_graph.read().await;
477        if graph_guard.is_none() {
478            return Ok(AsyncHealthStatus::Degraded);
479        }
480
481        Ok(AsyncHealthStatus::Healthy)
482    }
483
484    /// Save state asynchronously
485    pub async fn save_state_async(&self, output_dir: &str) -> Result<()> {
486        use std::fs;
487
488        // Create output directory
489        fs::create_dir_all(output_dir)?;
490
491        // Save knowledge graph
492        let graph_guard = self.knowledge_graph.read().await;
493        if let Some(graph) = &*graph_guard {
494            graph.save_to_json(&format!("{output_dir}/async_knowledge_graph.json"))?;
495        }
496
497        // Save document trees
498        let trees_guard = self.document_trees.read().await;
499        for (doc_id, tree) in trees_guard.iter() {
500            let filename = format!("{output_dir}/{doc_id}_async_tree.json");
501            let json_content = tree.to_json()?;
502            fs::write(&filename, json_content)?;
503        }
504
505        tracing::info!(output_dir = %output_dir, "Async state saved");
506        Ok(())
507    }
508}
509
510/// Performance statistics for async GraphRAG
511#[derive(Debug)]
512pub struct AsyncPerformanceStats {
513    /// Total number of documents processed in the system
514    pub total_documents: usize,
515    /// Total number of entities extracted across all documents
516    pub total_entities: usize,
517    /// Total number of text chunks created from documents
518    pub total_chunks: usize,
519    /// Current health status of the async GraphRAG system
520    pub health_status: AsyncHealthStatus,
521}
522
/// Health status for async components.
///
/// A fieldless enum, so it is `Copy` and `Hash` and can be passed by value or used as a
/// map/set key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum AsyncHealthStatus {
    /// All async components are functioning normally with no issues detected
    Healthy,
    /// Some async components are experiencing issues but the system remains operational
    Degraded,
    /// Critical async components have failed and the system is not functioning properly
    Unhealthy,
}
533
534/// Builder for AsyncGraphRAG
535pub struct AsyncGraphRAGBuilder {
536    config: Config,
537    language_model: Option<Arc<BoxedAsyncLanguageModel>>,
538    hierarchical_config: Option<HierarchicalConfig>,
539}
540
541impl AsyncGraphRAGBuilder {
542    /// Create a new async builder
543    pub fn new() -> Self {
544        Self {
545            config: Config::default(),
546            language_model: None,
547            hierarchical_config: None,
548        }
549    }
550
551    /// Set configuration
552    pub fn config(mut self, config: Config) -> Self {
553        self.config = config;
554        self
555    }
556
557    /// Set async language model
558    pub fn language_model(mut self, model: BoxedAsyncLanguageModel) -> Self {
559        self.language_model = Some(Arc::new(model));
560        self
561    }
562
563    /// Set hierarchical configuration
564    pub fn hierarchical_config(mut self, config: HierarchicalConfig) -> Self {
565        self.hierarchical_config = Some(config);
566        self
567    }
568
569    /// Build with async mock LLM
570    #[cfg(feature = "async-traits")]
571    pub async fn with_async_mock_llm(mut self) -> Result<Self> {
572        let mock_llm = crate::generation::async_mock_llm::AsyncMockLLM::new().await?;
573        self.language_model = Some(Arc::new(mock_llm));
574        Ok(self)
575    }
576
577    /// Build with async Ollama LLM
578    #[cfg(all(feature = "ollama", feature = "async-traits"))]
579    pub async fn with_async_ollama(mut self, config: crate::ollama::OllamaConfig) -> Result<Self> {
580        let ollama_llm = crate::ollama::AsyncOllamaGenerator::new(config).await?;
581        self.language_model = Some(Arc::new(ollama_llm));
582        Ok(self)
583    }
584
585    /// Build the async GraphRAG instance
586    pub async fn build(self) -> Result<AsyncGraphRAG> {
587        let hierarchical_config = self.hierarchical_config.unwrap_or_default();
588
589        let mut graphrag =
590            AsyncGraphRAG::with_hierarchical_config(self.config, hierarchical_config).await?;
591
592        if let Some(llm) = self.language_model {
593            graphrag.set_language_model(llm).await;
594        }
595
596        graphrag.initialize().await?;
597
598        Ok(graphrag)
599    }
600}
601
602impl Default for AsyncGraphRAGBuilder {
603    fn default() -> Self {
604        Self::new()
605    }
606}