// graphrag_core/async_graphrag.rs

1//! Async GraphRAG System
2//!
3//! This module provides a complete async implementation of the GraphRAG system
4//! that leverages all async traits for maximum performance and scalability.
5
6use crate::{
7    config::Config,
8    core::{
9        traits::BoxedAsyncLanguageModel,
10        Document, DocumentId, Entity, EntityId, GraphRAGError, KnowledgeGraph, Result,
11        TextChunk,
12    },
13    generation::{AnswerContext, GeneratedAnswer, PromptTemplate},
14    retrieval::SearchResult,
15    summarization::{DocumentTree, HierarchicalConfig, QueryResult, LLMClient},
16};
17use std::collections::HashMap;
18use std::sync::Arc;
19use tokio::sync::RwLock;
20
/// Adapter to connect BoxedAsyncLanguageModel to LLMClient trait
pub struct AsyncLanguageModelAdapter {
    // Shared handle to the wrapped async model; `Arc` keeps the adapter cheap
    // to construct while the model stays owned elsewhere.
    model: Arc<BoxedAsyncLanguageModel>,
}
25
26impl AsyncLanguageModelAdapter {
27    /// Creates a new adapter wrapping a BoxedAsyncLanguageModel.
28    ///
29    /// # Arguments
30    /// * `model` - The async language model to wrap in the adapter
31    ///
32    /// # Returns
33    /// A new AsyncLanguageModelAdapter instance
34    pub fn new(model: Arc<BoxedAsyncLanguageModel>) -> Self {
35        Self { model }
36    }
37}
38
39#[async_trait::async_trait]
40impl LLMClient for AsyncLanguageModelAdapter {
41    async fn generate_summary(
42        &self,
43        text: &str,
44        prompt: &str,
45        _max_tokens: usize,
46        _temperature: f32,
47    ) -> crate::Result<String> {
48        let full_prompt = format!("{}\n\nText: {}", prompt, text);
49
50        let response = self.model
51            .complete(&full_prompt)
52            .await.map_err(|e| crate::core::GraphRAGError::Generation { message: e.to_string() })?;
53
54        Ok(response)
55    }
56
57    fn model_name(&self) -> &str {
58        "async_language_model"
59    }
60}
61
/// Async version of the main GraphRAG system
pub struct AsyncGraphRAG {
    /// System configuration; retained but not read after construction.
    #[allow(dead_code)] config: Config,
    /// Lazily-initialized knowledge graph; `None` until `initialize` runs.
    knowledge_graph: Arc<RwLock<Option<KnowledgeGraph>>>,
    /// Per-document hierarchical summarization trees, keyed by document id.
    document_trees: Arc<RwLock<HashMap<DocumentId, DocumentTree>>>,
    /// Settings controlling hierarchical tree construction.
    hierarchical_config: HierarchicalConfig,
    /// Optional async language model used for summarization and answering.
    language_model: Option<Arc<BoxedAsyncLanguageModel>>,
}
70
71impl AsyncGraphRAG {
72    /// Create a new async GraphRAG instance
73    pub async fn new(config: Config) -> Result<Self> {
74        let hierarchical_config = config.summarization.clone();
75        Ok(Self {
76            config,
77            knowledge_graph: Arc::new(RwLock::new(None)),
78            document_trees: Arc::new(RwLock::new(HashMap::new())),
79            hierarchical_config,
80            language_model: None,
81        })
82    }
83
84    /// Create with custom hierarchical configuration
85    pub async fn with_hierarchical_config(
86        config: Config,
87        hierarchical_config: HierarchicalConfig,
88    ) -> Result<Self> {
89        Ok(Self {
90            config,
91            knowledge_graph: Arc::new(RwLock::new(None)),
92            document_trees: Arc::new(RwLock::new(HashMap::new())),
93            hierarchical_config,
94            language_model: None,
95        })
96    }
97
98    /// Set the async language model
99    pub async fn set_language_model(&mut self, model: Arc<BoxedAsyncLanguageModel>) {
100        self.language_model = Some(model);
101    }
102
103    /// Initialize the async GraphRAG system
104    pub async fn initialize(&mut self) -> Result<()> {
105        tracing::info!("Initializing async GraphRAG system");
106
107        // Initialize knowledge graph
108        {
109            let mut graph_guard = self.knowledge_graph.write().await;
110            *graph_guard = Some(KnowledgeGraph::new());
111        }
112
113        // Initialize with default async mock LLM if none provided
114        if self.language_model.is_none() {
115            #[cfg(feature = "async-traits")]
116            {
117                let mock_llm = crate::generation::async_mock_llm::AsyncMockLLM::new().await?;
118                self.language_model = Some(Arc::new(mock_llm));
119            }
120            #[cfg(not(feature = "async-traits"))]
121            {
122                return Err(GraphRAGError::Config {
123                    message: "No async language model available and async-traits feature disabled".to_string(),
124                });
125            }
126        }
127
128        tracing::info!("Async GraphRAG system initialized successfully");
129        Ok(())
130    }
131
132    /// Add a document to the system asynchronously
133    pub async fn add_document(&mut self, document: Document) -> Result<()> {
134        // Build hierarchical tree for the document first
135        self.build_document_tree(&document).await?;
136
137        let mut graph_guard = self.knowledge_graph.write().await;
138        let graph = graph_guard
139            .as_mut()
140            .ok_or_else(|| GraphRAGError::Config {
141                message: "Knowledge graph not initialized".to_string(),
142            })?;
143
144        graph.add_document(document)
145    }
146
147    /// Build hierarchical tree for a document asynchronously
148    pub async fn build_document_tree(&mut self, document: &Document) -> Result<()> {
149        if document.chunks.is_empty() {
150            return Ok(());
151        }
152
153        tracing::debug!(document_id = %document.id, "Building hierarchical tree for document");
154
155        let tree = if self.hierarchical_config.llm_config.enabled {
156            // Use LLM-powered summarization if enabled in config
157            if let Some(ref lm) = self.language_model {
158                let llm_client = Arc::new(AsyncLanguageModelAdapter::new(Arc::clone(lm)));
159                DocumentTree::with_llm_client(
160                    document.id.clone(),
161                    self.hierarchical_config.clone(),
162                    llm_client,
163                )?
164            } else {
165                DocumentTree::new(document.id.clone(), self.hierarchical_config.clone())?
166            }
167        } else {
168            // Use extractive summarization
169            DocumentTree::new(document.id.clone(), self.hierarchical_config.clone())?
170        };
171        // Note: In a full async implementation, DocumentTree would also be async
172
173        {
174            let mut trees_guard = self.document_trees.write().await;
175            trees_guard.insert(document.id.clone(), tree);
176        }
177
178        Ok(())
179    }
180
181    /// Build the knowledge graph from documents asynchronously
182    pub async fn build_graph(&mut self) -> Result<()> {
183        let mut graph_guard = self.knowledge_graph.write().await;
184        let graph = graph_guard
185            .as_mut()
186            .ok_or_else(|| GraphRAGError::Config {
187                message: "Knowledge graph not initialized".to_string(),
188            })?;
189
190        tracing::info!("Building knowledge graph asynchronously");
191
192        // Extract entities from all chunks asynchronously
193        let chunks: Vec<_> = graph.chunks().cloned().collect();
194        let mut total_entities = 0;
195
196        // For each chunk, extract entities (would use AsyncEntityExtractor in full implementation)
197        for chunk in &chunks {
198            // Simulate async entity extraction
199            let entities = self.extract_entities_async(chunk).await?;
200
201            // Add entities to the graph
202            let mut chunk_entity_ids = Vec::new();
203            for entity in entities {
204                chunk_entity_ids.push(entity.id.clone());
205                graph.add_entity(entity)?;
206                total_entities += 1;
207            }
208
209            // Update chunk with entity references
210            if let Some(existing_chunk) = graph.get_chunk_mut(&chunk.id) {
211                existing_chunk.entities = chunk_entity_ids;
212            }
213        }
214
215        tracing::info!(entity_count = total_entities, "Knowledge graph built asynchronously");
216        Ok(())
217    }
218
219    /// Simulate async entity extraction (would use actual AsyncEntityExtractor)
220    async fn extract_entities_async(&self, chunk: &TextChunk) -> Result<Vec<Entity>> {
221        // Simulate async processing delay
222        tokio::time::sleep(std::time::Duration::from_millis(1)).await;
223
224        // Simple entity extraction for demo (would use actual async implementation)
225        let content = chunk.content.to_lowercase();
226        let mut entities = Vec::new();
227
228        // Extract simple named entities
229        let names = ["tom", "huck", "polly", "sid", "mary", "jim"];
230        for (i, name) in names.iter().enumerate() {
231            if content.contains(name) {
232                let entity = Entity::new(
233                    EntityId::new(format!("{name}_{i}")),
234                    name.to_string(),
235                    "PERSON".to_string(),
236                    0.8,
237                );
238                entities.push(entity);
239            }
240        }
241
242        Ok(entities)
243    }
244
245    /// Query the system asynchronously
246    pub async fn query(&self, query: &str) -> Result<Vec<String>> {
247        // Simulate async retrieval (would use actual AsyncRetriever)
248        tokio::time::sleep(std::time::Duration::from_millis(10)).await;
249
250        // For demo, return simple response
251        Ok(vec![format!("Async result for: {}", query)])
252    }
253
254    /// Generate an answer to a question using async pipeline
255    pub async fn answer_question(&self, question: &str) -> Result<GeneratedAnswer> {
256        let graph_guard = self.knowledge_graph.read().await;
257        let graph = graph_guard
258            .as_ref()
259            .ok_or_else(|| GraphRAGError::Generation {
260                message: "Knowledge graph not initialized".to_string(),
261            })?;
262
263        let llm = self
264            .language_model
265            .as_ref()
266            .ok_or_else(|| GraphRAGError::Generation {
267                message: "Language model not initialized".to_string(),
268            })?;
269
270        // Perform async retrieval
271        let search_results = self.async_retrieval(question, graph).await?;
272
273        // Get hierarchical results
274        let hierarchical_results = self.hierarchical_query(question, 5).await?;
275
276        // Generate answer using async LLM
277        self.generate_answer_async(question, search_results, hierarchical_results, llm)
278            .await
279    }
280
281    /// Perform async retrieval
282    async fn async_retrieval(
283        &self,
284        query: &str,
285        graph: &KnowledgeGraph,
286    ) -> Result<Vec<SearchResult>> {
287        // Simulate async retrieval processing
288        tokio::time::sleep(std::time::Duration::from_millis(5)).await;
289
290        // Simple search simulation
291        let mut results = Vec::new();
292        for (i, chunk) in graph.chunks().enumerate().take(3) {
293            if chunk.content.to_lowercase().contains(&query.to_lowercase()) {
294                results.push(SearchResult {
295                    id: chunk.id.to_string(),
296                    content: chunk.content.clone(),
297                    score: 0.8 - (i as f32 * 0.1),
298                    result_type: crate::retrieval::ResultType::Chunk,
299                    entities: chunk.entities.iter().map(|e| e.to_string()).collect(),
300                    source_chunks: vec![chunk.id.to_string()],
301                });
302            }
303        }
304
305        Ok(results)
306    }
307
308    /// Query using hierarchical summarization asynchronously
309    pub async fn hierarchical_query(
310        &self,
311        query: &str,
312        max_results: usize,
313    ) -> Result<Vec<QueryResult>> {
314        let trees_guard = self.document_trees.read().await;
315        let mut all_results = Vec::new();
316
317        // Query all document trees
318        for tree in trees_guard.values() {
319            // In full implementation, DocumentTree would have async query method
320            let tree_results = tree.query(query, max_results)?;
321            all_results.extend(tree_results);
322        }
323
324        // Sort by score and limit results
325        all_results.sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap());
326        all_results.truncate(max_results);
327
328        Ok(all_results)
329    }
330
331    /// Generate answer using async language model
332    async fn generate_answer_async(
333        &self,
334        question: &str,
335        search_results: Vec<SearchResult>,
336        hierarchical_results: Vec<QueryResult>,
337        llm: &BoxedAsyncLanguageModel,
338    ) -> Result<GeneratedAnswer> {
339        // Assemble context
340        let context = self.assemble_context_async(search_results, hierarchical_results).await?;
341
342        // Create prompt
343        let prompt = self.create_qa_prompt(question, &context)?;
344
345        // Generate response using async LLM
346        let response = llm.complete(&prompt).await?;
347
348        // Create answer with metadata
349        Ok(GeneratedAnswer {
350            answer_text: response,
351            confidence_score: context.confidence_score,
352            sources: context.get_sources(),
353            entities_mentioned: context.entities,
354            mode_used: crate::generation::AnswerMode::Abstractive,
355            context_quality: context.confidence_score,
356        })
357    }
358
359    /// Assemble context asynchronously
360    async fn assemble_context_async(
361        &self,
362        search_results: Vec<SearchResult>,
363        hierarchical_results: Vec<QueryResult>,
364    ) -> Result<AnswerContext> {
365        // Simulate async context assembly
366        tokio::time::sleep(std::time::Duration::from_millis(1)).await;
367
368        let mut context = AnswerContext::new();
369
370        // Process search results
371        for result in search_results {
372            context.primary_chunks.push(result);
373        }
374
375        // Process hierarchical results
376        context.hierarchical_summaries = hierarchical_results;
377
378        // Calculate confidence score
379        let avg_score = if context.primary_chunks.is_empty() {
380            0.0
381        } else {
382            context.primary_chunks.iter().map(|r| r.score).sum::<f32>()
383                / context.primary_chunks.len() as f32
384        };
385
386        context.confidence_score = avg_score;
387        context.source_count = context.primary_chunks.len() + context.hierarchical_summaries.len();
388
389        Ok(context)
390    }
391
392    /// Create QA prompt from context
393    fn create_qa_prompt(&self, question: &str, context: &AnswerContext) -> Result<String> {
394        let combined_content = context.get_combined_content();
395
396        let mut values = HashMap::new();
397        values.insert("context".to_string(), combined_content);
398        values.insert("question".to_string(), question.to_string());
399
400        let template = PromptTemplate::new(
401            "Context:\n{context}\n\nQuestion: {question}\n\nBased on the provided context, please answer the question. If the context doesn't contain enough information, please say so.".to_string()
402        );
403
404        template.fill(&values)
405    }
406
407    /// Batch process multiple documents concurrently
408    pub async fn add_documents_batch(&mut self, documents: Vec<Document>) -> Result<()> {
409        tracing::info!(document_count = documents.len(), "Processing documents concurrently");
410
411        // Process documents sequentially for now to avoid borrowing issues
412        // In a production implementation, you'd use channels or other concurrency patterns
413        for document in documents {
414            self.add_document(document).await?;
415        }
416
417        tracing::info!("All documents processed successfully");
418        Ok(())
419    }
420
421    /// Batch answer multiple questions concurrently
422    pub async fn answer_questions_batch(&self, questions: &[&str]) -> Result<Vec<GeneratedAnswer>> {
423        use futures::stream::{FuturesUnordered, StreamExt};
424
425        let mut futures = FuturesUnordered::new();
426
427        for question in questions {
428            futures.push(self.answer_question(question));
429        }
430
431        let mut answers = Vec::with_capacity(questions.len());
432        while let Some(result) = futures.next().await {
433            answers.push(result?);
434        }
435
436        Ok(answers)
437    }
438
439    /// Get performance statistics
440    pub async fn get_performance_stats(&self) -> AsyncPerformanceStats {
441        let graph_guard = self.knowledge_graph.read().await;
442        let trees_guard = self.document_trees.read().await;
443
444        AsyncPerformanceStats {
445            total_documents: trees_guard.len(),
446            total_entities: graph_guard
447                .as_ref()
448                .map(|g| g.entity_count())
449                .unwrap_or(0),
450            total_chunks: graph_guard
451                .as_ref()
452                .map(|g| g.chunks().count())
453                .unwrap_or(0),
454            health_status: AsyncHealthStatus::Healthy,
455        }
456    }
457
458    /// Health check for all async components
459    pub async fn health_check(&self) -> Result<AsyncHealthStatus> {
460        // Check language model
461        if let Some(llm) = &self.language_model {
462            if !llm.health_check().await.unwrap_or(false) {
463                return Ok(AsyncHealthStatus::Degraded);
464            }
465        }
466
467        // Check if knowledge graph is initialized
468        let graph_guard = self.knowledge_graph.read().await;
469        if graph_guard.is_none() {
470            return Ok(AsyncHealthStatus::Degraded);
471        }
472
473        Ok(AsyncHealthStatus::Healthy)
474    }
475
476    /// Save state asynchronously
477    pub async fn save_state_async(&self, output_dir: &str) -> Result<()> {
478        use std::fs;
479
480        // Create output directory
481        fs::create_dir_all(output_dir)?;
482
483        // Save knowledge graph
484        let graph_guard = self.knowledge_graph.read().await;
485        if let Some(graph) = &*graph_guard {
486            graph.save_to_json(&format!("{output_dir}/async_knowledge_graph.json"))?;
487        }
488
489        // Save document trees
490        let trees_guard = self.document_trees.read().await;
491        for (doc_id, tree) in trees_guard.iter() {
492            let filename = format!("{output_dir}/{doc_id}_async_tree.json");
493            let json_content = tree.to_json()?;
494            fs::write(&filename, json_content)?;
495        }
496
497        tracing::info!(output_dir = %output_dir, "Async state saved");
498        Ok(())
499    }
500}
501
/// Performance statistics for async GraphRAG
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct AsyncPerformanceStats {
    /// Total number of documents processed in the system
    pub total_documents: usize,
    /// Total number of entities extracted across all documents
    pub total_entities: usize,
    /// Total number of text chunks created from documents
    pub total_chunks: usize,
    /// Current health status of the async GraphRAG system
    pub health_status: AsyncHealthStatus,
}

/// Health status for async components
// `Copy` and `Hash` are free for a fieldless enum and let callers pass the
// status by value or use it as a map key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum AsyncHealthStatus {
    /// All async components are functioning normally with no issues detected
    Healthy,
    /// Some async components are experiencing issues but the system remains operational
    Degraded,
    /// Critical async components have failed and the system is not functioning properly
    Unhealthy,
}
525
/// Builder for AsyncGraphRAG
pub struct AsyncGraphRAGBuilder {
    /// System configuration; defaults to `Config::default()`.
    config: Config,
    /// Optional pre-built async language model to install on the instance.
    language_model: Option<Arc<BoxedAsyncLanguageModel>>,
    /// Optional override for the hierarchical summarization configuration.
    hierarchical_config: Option<HierarchicalConfig>,
}
532
533impl AsyncGraphRAGBuilder {
534    /// Create a new async builder
535    pub fn new() -> Self {
536        Self {
537            config: Config::default(),
538            language_model: None,
539            hierarchical_config: None,
540        }
541    }
542
543    /// Set configuration
544    pub fn config(mut self, config: Config) -> Self {
545        self.config = config;
546        self
547    }
548
549    /// Set async language model
550    pub fn language_model(mut self, model: BoxedAsyncLanguageModel) -> Self {
551        self.language_model = Some(Arc::new(model));
552        self
553    }
554
555    /// Set hierarchical configuration
556    pub fn hierarchical_config(mut self, config: HierarchicalConfig) -> Self {
557        self.hierarchical_config = Some(config);
558        self
559    }
560
561    /// Build with async mock LLM
562    #[cfg(feature = "async-traits")]
563    pub async fn with_async_mock_llm(mut self) -> Result<Self> {
564        let mock_llm = crate::generation::async_mock_llm::AsyncMockLLM::new().await?;
565        self.language_model = Some(Arc::new(mock_llm));
566        Ok(self)
567    }
568
569    /// Build with async Ollama LLM
570    #[cfg(all(feature = "ollama", feature = "async-traits"))]
571    pub async fn with_async_ollama(
572        mut self,
573        config: crate::ollama::OllamaConfig,
574    ) -> Result<Self> {
575        let ollama_llm = crate::ollama::AsyncOllamaGenerator::new(config).await?;
576        self.language_model = Some(Arc::new(ollama_llm));
577        Ok(self)
578    }
579
580    /// Build the async GraphRAG instance
581    pub async fn build(self) -> Result<AsyncGraphRAG> {
582        let hierarchical_config = self.hierarchical_config.unwrap_or_default();
583
584        let mut graphrag = AsyncGraphRAG::with_hierarchical_config(self.config, hierarchical_config).await?;
585
586        if let Some(llm) = self.language_model {
587            graphrag.set_language_model(llm).await;
588        }
589
590        graphrag.initialize().await?;
591
592        Ok(graphrag)
593    }
594}
595
596impl Default for AsyncGraphRAGBuilder {
597    fn default() -> Self {
598        Self::new()
599    }
600}
601
#[cfg(test)]
mod tests {
    use super::*;

    /// Construction from a default config succeeds.
    #[tokio::test]
    async fn test_async_graphrag_creation() {
        assert!(AsyncGraphRAG::new(Config::default()).await.is_ok());
    }

    /// Initialization after construction succeeds.
    #[tokio::test]
    async fn test_async_graphrag_initialization() {
        let mut graphrag = AsyncGraphRAG::new(Config::default()).await.unwrap();
        assert!(graphrag.initialize().await.is_ok());
    }

    /// The builder with all defaults produces a working instance.
    #[tokio::test]
    async fn test_async_builder() {
        assert!(AsyncGraphRAGBuilder::new().build().await.is_ok());
    }

    /// The builder accepts the async mock LLM when the feature is enabled.
    #[tokio::test]
    #[cfg(feature = "async-traits")]
    async fn test_with_async_mock_llm() {
        let builder = AsyncGraphRAGBuilder::new()
            .with_async_mock_llm()
            .await
            .unwrap();
        assert!(builder.build().await.is_ok());
    }

    /// A freshly initialized system reports itself healthy.
    #[tokio::test]
    async fn test_health_check() {
        let mut graphrag = AsyncGraphRAG::new(Config::default()).await.unwrap();
        graphrag.initialize().await.unwrap();

        assert_eq!(
            graphrag.health_check().await.unwrap(),
            AsyncHealthStatus::Healthy
        );
    }

    /// An empty, initialized system has zero documents and is healthy.
    #[tokio::test]
    async fn test_performance_stats() {
        let mut graphrag = AsyncGraphRAG::new(Config::default()).await.unwrap();
        graphrag.initialize().await.unwrap();

        let stats = graphrag.get_performance_stats().await;
        assert_eq!(stats.total_documents, 0);
        assert_eq!(stats.health_status, AsyncHealthStatus::Healthy);
    }
}