// graphrag_core/async_graphrag.rs

1//! Async GraphRAG System
2//!
3//! This module provides a complete async implementation of the GraphRAG system
4//! that leverages all async traits for maximum performance and scalability.
5
6use crate::{
7    config::Config,
8    core::{
9        traits::BoxedAsyncLanguageModel, Document, DocumentId, Entity, EntityId, GraphRAGError,
10        KnowledgeGraph, Result, TextChunk,
11    },
12    generation::{AnswerContext, GeneratedAnswer, PromptTemplate},
13    retrieval::SearchResult,
14    summarization::{DocumentTree, HierarchicalConfig, LLMClient, QueryResult},
15};
16use std::collections::HashMap;
17use std::sync::Arc;
18use tokio::sync::RwLock;
19
/// Adapter to connect BoxedAsyncLanguageModel to LLMClient trait
pub struct AsyncLanguageModelAdapter {
    // Shared handle to the underlying async model; `Arc` lets the adapter
    // and its creator each keep a cheap reference to the same model.
    model: Arc<BoxedAsyncLanguageModel>,
}
24
25impl AsyncLanguageModelAdapter {
26    /// Creates a new adapter wrapping a BoxedAsyncLanguageModel.
27    ///
28    /// # Arguments
29    /// * `model` - The async language model to wrap in the adapter
30    ///
31    /// # Returns
32    /// A new AsyncLanguageModelAdapter instance
33    pub fn new(model: Arc<BoxedAsyncLanguageModel>) -> Self {
34        Self { model }
35    }
36}
37
38#[async_trait::async_trait]
39impl LLMClient for AsyncLanguageModelAdapter {
40    async fn generate_summary(
41        &self,
42        text: &str,
43        prompt: &str,
44        _max_tokens: usize,
45        _temperature: f32,
46    ) -> crate::Result<String> {
47        let full_prompt = format!("{}\n\nText: {}", prompt, text);
48
49        let response = self.model.complete(&full_prompt).await.map_err(|e| {
50            crate::core::GraphRAGError::Generation {
51                message: e.to_string(),
52            }
53        })?;
54
55        Ok(response)
56    }
57
58    fn model_name(&self) -> &str {
59        "async_language_model"
60    }
61}
62
/// Async version of the main GraphRAG system
pub struct AsyncGraphRAG {
    // Retained for future use; not read after construction in this module.
    #[allow(dead_code)]
    config: Config,
    // `None` until `initialize` runs; wrapped in an async RwLock so
    // concurrent readers (queries, stats, health checks) don't serialize.
    knowledge_graph: Arc<RwLock<Option<KnowledgeGraph>>>,
    // One hierarchical summarization tree per document, keyed by document id.
    document_trees: Arc<RwLock<HashMap<DocumentId, DocumentTree>>>,
    // Settings used when building document trees.
    hierarchical_config: HierarchicalConfig,
    // Optional async LLM; `initialize` installs a mock when this is `None`
    // and the `async-traits` feature is enabled.
    language_model: Option<Arc<BoxedAsyncLanguageModel>>,
}
72
73impl AsyncGraphRAG {
74    /// Create a new async GraphRAG instance
75    pub async fn new(config: Config) -> Result<Self> {
76        let hierarchical_config = config.summarization.clone();
77        Ok(Self {
78            config,
79            knowledge_graph: Arc::new(RwLock::new(None)),
80            document_trees: Arc::new(RwLock::new(HashMap::new())),
81            hierarchical_config,
82            language_model: None,
83        })
84    }
85
86    /// Create with custom hierarchical configuration
87    pub async fn with_hierarchical_config(
88        config: Config,
89        hierarchical_config: HierarchicalConfig,
90    ) -> Result<Self> {
91        Ok(Self {
92            config,
93            knowledge_graph: Arc::new(RwLock::new(None)),
94            document_trees: Arc::new(RwLock::new(HashMap::new())),
95            hierarchical_config,
96            language_model: None,
97        })
98    }
99
    /// Set the async language model
    ///
    /// Replaces any previously configured model. Declared `async` for
    /// interface consistency even though no awaiting occurs here.
    pub async fn set_language_model(&mut self, model: Arc<BoxedAsyncLanguageModel>) {
        self.language_model = Some(model);
    }
104
105    /// Initialize the async GraphRAG system
106    pub async fn initialize(&mut self) -> Result<()> {
107        tracing::info!("Initializing async GraphRAG system");
108
109        // Initialize knowledge graph
110        {
111            let mut graph_guard = self.knowledge_graph.write().await;
112            *graph_guard = Some(KnowledgeGraph::new());
113        }
114
115        // Initialize with default async mock LLM if none provided
116        if self.language_model.is_none() {
117            #[cfg(feature = "async-traits")]
118            {
119                let mock_llm = crate::generation::async_mock_llm::AsyncMockLLM::new().await?;
120                self.language_model = Some(Arc::new(mock_llm));
121            }
122            #[cfg(not(feature = "async-traits"))]
123            {
124                return Err(GraphRAGError::Config {
125                    message: "No async language model available and async-traits feature disabled"
126                        .to_string(),
127                });
128            }
129        }
130
131        tracing::info!("Async GraphRAG system initialized successfully");
132        Ok(())
133    }
134
135    /// Add a document to the system asynchronously
136    pub async fn add_document(&mut self, document: Document) -> Result<()> {
137        // Build hierarchical tree for the document first
138        self.build_document_tree(&document).await?;
139
140        let mut graph_guard = self.knowledge_graph.write().await;
141        let graph = graph_guard.as_mut().ok_or_else(|| GraphRAGError::Config {
142            message: "Knowledge graph not initialized".to_string(),
143        })?;
144
145        graph.add_document(document)
146    }
147
148    /// Build hierarchical tree for a document asynchronously
149    pub async fn build_document_tree(&mut self, document: &Document) -> Result<()> {
150        if document.chunks.is_empty() {
151            return Ok(());
152        }
153
154        tracing::debug!(document_id = %document.id, "Building hierarchical tree for document");
155
156        let tree = if self.hierarchical_config.llm_config.enabled {
157            // Use LLM-powered summarization if enabled in config
158            if let Some(ref lm) = self.language_model {
159                let llm_client = Arc::new(AsyncLanguageModelAdapter::new(Arc::clone(lm)));
160                DocumentTree::with_llm_client(
161                    document.id.clone(),
162                    self.hierarchical_config.clone(),
163                    llm_client,
164                )?
165            } else {
166                DocumentTree::new(document.id.clone(), self.hierarchical_config.clone())?
167            }
168        } else {
169            // Use extractive summarization
170            DocumentTree::new(document.id.clone(), self.hierarchical_config.clone())?
171        };
172        // Note: In a full async implementation, DocumentTree would also be async
173
174        {
175            let mut trees_guard = self.document_trees.write().await;
176            trees_guard.insert(document.id.clone(), tree);
177        }
178
179        Ok(())
180    }
181
    /// Build the knowledge graph from documents asynchronously
    ///
    /// Runs (simulated) async entity extraction over every chunk, inserts
    /// the extracted entities into the graph, and back-references each
    /// chunk to the entities found in it.
    pub async fn build_graph(&mut self) -> Result<()> {
        let mut graph_guard = self.knowledge_graph.write().await;
        let graph = graph_guard.as_mut().ok_or_else(|| GraphRAGError::Config {
            message: "Knowledge graph not initialized".to_string(),
        })?;

        tracing::info!("Building knowledge graph asynchronously");

        // Snapshot the chunks up front so the graph can be mutated below
        // without holding an iterator borrow on it.
        let chunks: Vec<_> = graph.chunks().cloned().collect();
        let mut total_entities = 0;

        // For each chunk, extract entities (would use AsyncEntityExtractor in full implementation)
        for chunk in &chunks {
            // Simulate async entity extraction
            let entities = self.extract_entities_async(chunk).await?;

            // Add entities to the graph, remembering their ids for the chunk.
            let mut chunk_entity_ids = Vec::new();
            for entity in entities {
                chunk_entity_ids.push(entity.id.clone());
                graph.add_entity(entity)?;
                total_entities += 1;
            }

            // Update the stored chunk (not the snapshot copy) with the
            // entity references.
            if let Some(existing_chunk) = graph.get_chunk_mut(&chunk.id) {
                existing_chunk.entities = chunk_entity_ids;
            }
        }

        tracing::info!(
            entity_count = total_entities,
            "Knowledge graph built asynchronously"
        );
        Ok(())
    }
220
221    /// Simulate async entity extraction (would use actual AsyncEntityExtractor)
222    async fn extract_entities_async(&self, chunk: &TextChunk) -> Result<Vec<Entity>> {
223        // Simulate async processing delay
224        tokio::time::sleep(std::time::Duration::from_millis(1)).await;
225
226        // Simple entity extraction for demo (would use actual async implementation)
227        let content = chunk.content.to_lowercase();
228        let mut entities = Vec::new();
229
230        // Extract simple named entities
231        let names = ["tom", "huck", "polly", "sid", "mary", "jim"];
232        for (i, name) in names.iter().enumerate() {
233            if content.contains(name) {
234                let entity = Entity::new(
235                    EntityId::new(format!("{name}_{i}")),
236                    name.to_string(),
237                    "PERSON".to_string(),
238                    0.8,
239                );
240                entities.push(entity);
241            }
242        }
243
244        Ok(entities)
245    }
246
247    /// Query the system asynchronously
248    pub async fn query(&self, query: &str) -> Result<Vec<String>> {
249        // Simulate async retrieval (would use actual AsyncRetriever)
250        tokio::time::sleep(std::time::Duration::from_millis(10)).await;
251
252        // For demo, return simple response
253        Ok(vec![format!("Async result for: {}", query)])
254    }
255
256    /// Generate an answer to a question using async pipeline
257    pub async fn answer_question(&self, question: &str) -> Result<GeneratedAnswer> {
258        let graph_guard = self.knowledge_graph.read().await;
259        let graph = graph_guard
260            .as_ref()
261            .ok_or_else(|| GraphRAGError::Generation {
262                message: "Knowledge graph not initialized".to_string(),
263            })?;
264
265        let llm = self
266            .language_model
267            .as_ref()
268            .ok_or_else(|| GraphRAGError::Generation {
269                message: "Language model not initialized".to_string(),
270            })?;
271
272        // Perform async retrieval
273        let search_results = self.async_retrieval(question, graph).await?;
274
275        // Get hierarchical results
276        let hierarchical_results = self.hierarchical_query(question, 5).await?;
277
278        // Generate answer using async LLM
279        self.generate_answer_async(question, search_results, hierarchical_results, llm)
280            .await
281    }
282
283    /// Perform async retrieval
284    async fn async_retrieval(
285        &self,
286        query: &str,
287        graph: &KnowledgeGraph,
288    ) -> Result<Vec<SearchResult>> {
289        // Simulate async retrieval processing
290        tokio::time::sleep(std::time::Duration::from_millis(5)).await;
291
292        // Simple search simulation
293        let mut results = Vec::new();
294        for (i, chunk) in graph.chunks().enumerate().take(3) {
295            if chunk.content.to_lowercase().contains(&query.to_lowercase()) {
296                results.push(SearchResult {
297                    id: chunk.id.to_string(),
298                    content: chunk.content.clone(),
299                    score: 0.8 - (i as f32 * 0.1),
300                    result_type: crate::retrieval::ResultType::Chunk,
301                    entities: chunk.entities.iter().map(|e| e.to_string()).collect(),
302                    source_chunks: vec![chunk.id.to_string()],
303                });
304            }
305        }
306
307        Ok(results)
308    }
309
310    /// Query using hierarchical summarization asynchronously
311    pub async fn hierarchical_query(
312        &self,
313        query: &str,
314        max_results: usize,
315    ) -> Result<Vec<QueryResult>> {
316        let trees_guard = self.document_trees.read().await;
317        let mut all_results = Vec::new();
318
319        // Query all document trees
320        for tree in trees_guard.values() {
321            // In full implementation, DocumentTree would have async query method
322            let tree_results = tree.query(query, max_results)?;
323            all_results.extend(tree_results);
324        }
325
326        // Sort by score and limit results
327        all_results.sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap());
328        all_results.truncate(max_results);
329
330        Ok(all_results)
331    }
332
333    /// Generate answer using async language model
334    async fn generate_answer_async(
335        &self,
336        question: &str,
337        search_results: Vec<SearchResult>,
338        hierarchical_results: Vec<QueryResult>,
339        llm: &BoxedAsyncLanguageModel,
340    ) -> Result<GeneratedAnswer> {
341        // Assemble context
342        let context = self
343            .assemble_context_async(search_results, hierarchical_results)
344            .await?;
345
346        // Create prompt
347        let prompt = self.create_qa_prompt(question, &context)?;
348
349        // Generate response using async LLM
350        let response = llm.complete(&prompt).await?;
351
352        // Create answer with metadata
353        Ok(GeneratedAnswer {
354            answer_text: response,
355            confidence_score: context.confidence_score,
356            sources: context.get_sources(),
357            entities_mentioned: context.entities,
358            mode_used: crate::generation::AnswerMode::Abstractive,
359            context_quality: context.confidence_score,
360        })
361    }
362
363    /// Assemble context asynchronously
364    async fn assemble_context_async(
365        &self,
366        search_results: Vec<SearchResult>,
367        hierarchical_results: Vec<QueryResult>,
368    ) -> Result<AnswerContext> {
369        // Simulate async context assembly
370        tokio::time::sleep(std::time::Duration::from_millis(1)).await;
371
372        let mut context = AnswerContext::new();
373
374        // Process search results
375        for result in search_results {
376            context.primary_chunks.push(result);
377        }
378
379        // Process hierarchical results
380        context.hierarchical_summaries = hierarchical_results;
381
382        // Calculate confidence score
383        let avg_score = if context.primary_chunks.is_empty() {
384            0.0
385        } else {
386            context.primary_chunks.iter().map(|r| r.score).sum::<f32>()
387                / context.primary_chunks.len() as f32
388        };
389
390        context.confidence_score = avg_score;
391        context.source_count = context.primary_chunks.len() + context.hierarchical_summaries.len();
392
393        Ok(context)
394    }
395
396    /// Create QA prompt from context
397    fn create_qa_prompt(&self, question: &str, context: &AnswerContext) -> Result<String> {
398        let combined_content = context.get_combined_content();
399
400        let mut values = HashMap::new();
401        values.insert("context".to_string(), combined_content);
402        values.insert("question".to_string(), question.to_string());
403
404        let template = PromptTemplate::new(
405            "Context:\n{context}\n\nQuestion: {question}\n\nBased on the provided context, please answer the question. If the context doesn't contain enough information, please say so.".to_string()
406        );
407
408        template.fill(&values)
409    }
410
411    /// Batch process multiple documents concurrently
412    pub async fn add_documents_batch(&mut self, documents: Vec<Document>) -> Result<()> {
413        tracing::info!(
414            document_count = documents.len(),
415            "Processing documents concurrently"
416        );
417
418        // Process documents sequentially for now to avoid borrowing issues
419        // In a production implementation, you'd use channels or other concurrency patterns
420        for document in documents {
421            self.add_document(document).await?;
422        }
423
424        tracing::info!("All documents processed successfully");
425        Ok(())
426    }
427
428    /// Batch answer multiple questions concurrently
429    pub async fn answer_questions_batch(&self, questions: &[&str]) -> Result<Vec<GeneratedAnswer>> {
430        use futures::stream::{FuturesUnordered, StreamExt};
431
432        let mut futures = FuturesUnordered::new();
433
434        for question in questions {
435            futures.push(self.answer_question(question));
436        }
437
438        let mut answers = Vec::with_capacity(questions.len());
439        while let Some(result) = futures.next().await {
440            answers.push(result?);
441        }
442
443        Ok(answers)
444    }
445
446    /// Get performance statistics
447    pub async fn get_performance_stats(&self) -> AsyncPerformanceStats {
448        let graph_guard = self.knowledge_graph.read().await;
449        let trees_guard = self.document_trees.read().await;
450
451        AsyncPerformanceStats {
452            total_documents: trees_guard.len(),
453            total_entities: graph_guard.as_ref().map(|g| g.entity_count()).unwrap_or(0),
454            total_chunks: graph_guard
455                .as_ref()
456                .map(|g| g.chunks().count())
457                .unwrap_or(0),
458            health_status: AsyncHealthStatus::Healthy,
459        }
460    }
461
462    /// Health check for all async components
463    pub async fn health_check(&self) -> Result<AsyncHealthStatus> {
464        // Check language model
465        if let Some(llm) = &self.language_model {
466            if !llm.health_check().await.unwrap_or(false) {
467                return Ok(AsyncHealthStatus::Degraded);
468            }
469        }
470
471        // Check if knowledge graph is initialized
472        let graph_guard = self.knowledge_graph.read().await;
473        if graph_guard.is_none() {
474            return Ok(AsyncHealthStatus::Degraded);
475        }
476
477        Ok(AsyncHealthStatus::Healthy)
478    }
479
480    /// Save state asynchronously
481    pub async fn save_state_async(&self, output_dir: &str) -> Result<()> {
482        use std::fs;
483
484        // Create output directory
485        fs::create_dir_all(output_dir)?;
486
487        // Save knowledge graph
488        let graph_guard = self.knowledge_graph.read().await;
489        if let Some(graph) = &*graph_guard {
490            graph.save_to_json(&format!("{output_dir}/async_knowledge_graph.json"))?;
491        }
492
493        // Save document trees
494        let trees_guard = self.document_trees.read().await;
495        for (doc_id, tree) in trees_guard.iter() {
496            let filename = format!("{output_dir}/{doc_id}_async_tree.json");
497            let json_content = tree.to_json()?;
498            fs::write(&filename, json_content)?;
499        }
500
501        tracing::info!(output_dir = %output_dir, "Async state saved");
502        Ok(())
503    }
504}
505
/// Performance statistics for async GraphRAG
///
/// Snapshot produced by `AsyncGraphRAG::get_performance_stats`.
#[derive(Debug)]
pub struct AsyncPerformanceStats {
    /// Total number of documents processed in the system
    pub total_documents: usize,
    /// Total number of entities extracted across all documents
    pub total_entities: usize,
    /// Total number of text chunks created from documents
    pub total_chunks: usize,
    /// Current health status of the async GraphRAG system
    pub health_status: AsyncHealthStatus,
}
518
/// Health status for async components
///
/// Returned by `AsyncGraphRAG::health_check` and embedded in
/// `AsyncPerformanceStats`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum AsyncHealthStatus {
    /// All async components are functioning normally with no issues detected
    Healthy,
    /// Some async components are experiencing issues but the system remains operational
    Degraded,
    /// Critical async components have failed and the system is not functioning properly
    Unhealthy,
}
529
/// Builder for AsyncGraphRAG
pub struct AsyncGraphRAGBuilder {
    // Base system configuration; starts as `Config::default()`.
    config: Config,
    // Optional pre-built model; when left `None`, `build` relies on
    // `initialize` to install a fallback model.
    language_model: Option<Arc<BoxedAsyncLanguageModel>>,
    // Optional override; `build` uses `HierarchicalConfig::default()` if unset.
    hierarchical_config: Option<HierarchicalConfig>,
}
536
impl AsyncGraphRAGBuilder {
    /// Create a new async builder
    ///
    /// Starts from `Config::default()` with no language model and no
    /// hierarchical-config override.
    pub fn new() -> Self {
        Self {
            config: Config::default(),
            language_model: None,
            hierarchical_config: None,
        }
    }

    /// Set configuration (consumes and returns the builder for chaining)
    pub fn config(mut self, config: Config) -> Self {
        self.config = config;
        self
    }

    /// Set async language model (takes ownership and wraps it in an `Arc`)
    pub fn language_model(mut self, model: BoxedAsyncLanguageModel) -> Self {
        self.language_model = Some(Arc::new(model));
        self
    }

    /// Set hierarchical configuration
    pub fn hierarchical_config(mut self, config: HierarchicalConfig) -> Self {
        self.hierarchical_config = Some(config);
        self
    }

    /// Build with async mock LLM
    ///
    /// Async because the mock's constructor is async; hence `Result<Self>`
    /// instead of the plain chaining style of the other setters.
    #[cfg(feature = "async-traits")]
    pub async fn with_async_mock_llm(mut self) -> Result<Self> {
        let mock_llm = crate::generation::async_mock_llm::AsyncMockLLM::new().await?;
        self.language_model = Some(Arc::new(mock_llm));
        Ok(self)
    }

    /// Build with async Ollama LLM
    ///
    /// Available only when both the `ollama` and `async-traits` features
    /// are enabled.
    #[cfg(all(feature = "ollama", feature = "async-traits"))]
    pub async fn with_async_ollama(mut self, config: crate::ollama::OllamaConfig) -> Result<Self> {
        let ollama_llm = crate::ollama::AsyncOllamaGenerator::new(config).await?;
        self.language_model = Some(Arc::new(ollama_llm));
        Ok(self)
    }

    /// Build the async GraphRAG instance
    ///
    /// Applies the configured (or default) hierarchical config, installs
    /// the language model if one was supplied, then runs `initialize`.
    pub async fn build(self) -> Result<AsyncGraphRAG> {
        let hierarchical_config = self.hierarchical_config.unwrap_or_default();

        let mut graphrag =
            AsyncGraphRAG::with_hierarchical_config(self.config, hierarchical_config).await?;

        if let Some(llm) = self.language_model {
            graphrag.set_language_model(llm).await;
        }

        graphrag.initialize().await?;

        Ok(graphrag)
    }
}
597
598impl Default for AsyncGraphRAGBuilder {
599    fn default() -> Self {
600        Self::new()
601    }
602}
603
#[cfg(test)]
mod tests {
    use super::*;

    // Construction with a default config should always succeed.
    #[tokio::test]
    async fn test_async_graphrag_creation() {
        let config = Config::default();
        let graphrag = AsyncGraphRAG::new(config).await;
        assert!(graphrag.is_ok());
    }

    // `initialize` should succeed even without an explicit language model
    // (it falls back to the mock when the feature is enabled).
    #[tokio::test]
    async fn test_async_graphrag_initialization() {
        let config = Config::default();
        let mut graphrag = AsyncGraphRAG::new(config).await.unwrap();
        let result = graphrag.initialize().await;
        assert!(result.is_ok());
    }

    // A bare builder (all defaults) should produce an initialized system.
    #[tokio::test]
    async fn test_async_builder() {
        let result = AsyncGraphRAGBuilder::new().build().await;
        assert!(result.is_ok());
    }

    // Explicitly wiring the async mock LLM through the builder should work.
    #[tokio::test]
    #[cfg(feature = "async-traits")]
    async fn test_with_async_mock_llm() {
        let result = AsyncGraphRAGBuilder::new()
            .with_async_mock_llm()
            .await
            .unwrap()
            .build()
            .await;
        assert!(result.is_ok());
    }

    // A freshly initialized system reports Healthy.
    #[tokio::test]
    async fn test_health_check() {
        let config = Config::default();
        let mut graphrag = AsyncGraphRAG::new(config).await.unwrap();
        graphrag.initialize().await.unwrap();

        let health = graphrag.health_check().await.unwrap();
        assert_eq!(health, AsyncHealthStatus::Healthy);
    }

    // With no documents added, all counts are zero and status is Healthy.
    #[tokio::test]
    async fn test_performance_stats() {
        let config = Config::default();
        let mut graphrag = AsyncGraphRAG::new(config).await.unwrap();
        graphrag.initialize().await.unwrap();

        let stats = graphrag.get_performance_stats().await;
        assert_eq!(stats.total_documents, 0);
        assert_eq!(stats.health_status, AsyncHealthStatus::Healthy);
    }
}