1use crate::{
7 config::Config,
8 core::{
9 traits::BoxedAsyncLanguageModel, Document, DocumentId, Entity, EntityId, GraphRAGError,
10 KnowledgeGraph, Result, TextChunk,
11 },
12 generation::{AnswerContext, GeneratedAnswer, PromptTemplate},
13 retrieval::SearchResult,
14 summarization::{DocumentTree, HierarchicalConfig, LLMClient, QueryResult},
15};
16use std::collections::HashMap;
17use std::sync::Arc;
18use tokio::sync::RwLock;
19
/// Adapter exposing a [`BoxedAsyncLanguageModel`] through the [`LLMClient`]
/// trait used by the hierarchical summarization layer.
pub struct AsyncLanguageModelAdapter {
    /// Shared handle to the wrapped async language model.
    model: Arc<BoxedAsyncLanguageModel>,
}
24
impl AsyncLanguageModelAdapter {
    /// Wraps a shared async language model so it can serve as an `LLMClient`.
    pub fn new(model: Arc<BoxedAsyncLanguageModel>) -> Self {
        Self { model }
    }
}
37
38#[async_trait::async_trait]
39impl LLMClient for AsyncLanguageModelAdapter {
40 async fn generate_summary(
41 &self,
42 text: &str,
43 prompt: &str,
44 _max_tokens: usize,
45 _temperature: f32,
46 ) -> crate::Result<String> {
47 let full_prompt = format!("{}\n\nText: {}", prompt, text);
48
49 let response = self.model.complete(&full_prompt).await.map_err(|e| {
50 crate::core::GraphRAGError::Generation {
51 message: e.to_string(),
52 }
53 })?;
54
55 Ok(response)
56 }
57
58 fn model_name(&self) -> &str {
59 "async_language_model"
60 }
61}
62
/// Async-first GraphRAG engine: a knowledge graph plus one hierarchical
/// summary tree per document, each guarded by a tokio `RwLock`.
pub struct AsyncGraphRAG {
    // Retained for future use; not read after construction in this file.
    #[allow(dead_code)]
    config: Config,
    /// Shared knowledge graph; `None` until `initialize` has run.
    knowledge_graph: Arc<RwLock<Option<KnowledgeGraph>>>,
    /// Per-document hierarchical summary trees, keyed by document id.
    document_trees: Arc<RwLock<HashMap<DocumentId, DocumentTree>>>,
    /// Settings used when building `DocumentTree`s.
    hierarchical_config: HierarchicalConfig,
    /// Optional async language model for summaries and answer generation.
    language_model: Option<Arc<BoxedAsyncLanguageModel>>,
}
72
73impl AsyncGraphRAG {
74 pub async fn new(config: Config) -> Result<Self> {
76 let hierarchical_config = config.summarization.clone();
77 Ok(Self {
78 config,
79 knowledge_graph: Arc::new(RwLock::new(None)),
80 document_trees: Arc::new(RwLock::new(HashMap::new())),
81 hierarchical_config,
82 language_model: None,
83 })
84 }
85
    /// Creates an instance with explicit hierarchical summarization settings,
    /// ignoring `config.summarization`. The knowledge graph starts out
    /// uninitialized (`None`) until `initialize` is called.
    pub async fn with_hierarchical_config(
        config: Config,
        hierarchical_config: HierarchicalConfig,
    ) -> Result<Self> {
        Ok(Self {
            config,
            knowledge_graph: Arc::new(RwLock::new(None)),
            document_trees: Arc::new(RwLock::new(HashMap::new())),
            hierarchical_config,
            language_model: None,
        })
    }
99
    /// Sets the language model used for summaries and answer generation.
    /// Declared `async` only for API symmetry — it performs no awaiting.
    pub async fn set_language_model(&mut self, model: Arc<BoxedAsyncLanguageModel>) {
        self.language_model = Some(model);
    }
104
    /// Initializes the system: installs an empty knowledge graph and, when no
    /// language model has been set, falls back to the async mock LLM. The
    /// fallback requires the `async-traits` feature; without it, having no
    /// model is a configuration error.
    pub async fn initialize(&mut self) -> Result<()> {
        tracing::info!("Initializing async GraphRAG system");

        // Scope the write guard so it is released before any further awaits.
        {
            let mut graph_guard = self.knowledge_graph.write().await;
            *graph_guard = Some(KnowledgeGraph::new());
        }

        if self.language_model.is_none() {
            // Fallback: the mock model keeps the system usable for tests/demos.
            #[cfg(feature = "async-traits")]
            {
                let mock_llm = crate::generation::async_mock_llm::AsyncMockLLM::new().await?;
                self.language_model = Some(Arc::new(mock_llm));
            }
            // Without the feature there is nothing to fall back to.
            #[cfg(not(feature = "async-traits"))]
            {
                return Err(GraphRAGError::Config {
                    message: "No async language model available and async-traits feature disabled"
                        .to_string(),
                });
            }
        }

        tracing::info!("Async GraphRAG system initialized successfully");
        Ok(())
    }
134
135 pub async fn add_document(&mut self, document: Document) -> Result<()> {
137 self.build_document_tree(&document).await?;
139
140 let mut graph_guard = self.knowledge_graph.write().await;
141 let graph = graph_guard.as_mut().ok_or_else(|| GraphRAGError::Config {
142 message: "Knowledge graph not initialized".to_string(),
143 })?;
144
145 graph.add_document(document)
146 }
147
148 pub async fn build_document_tree(&mut self, document: &Document) -> Result<()> {
150 if document.chunks.is_empty() {
151 return Ok(());
152 }
153
154 tracing::debug!(document_id = %document.id, "Building hierarchical tree for document");
155
156 let tree = if self.hierarchical_config.llm_config.enabled {
157 if let Some(ref lm) = self.language_model {
159 let llm_client = Arc::new(AsyncLanguageModelAdapter::new(Arc::clone(lm)));
160 DocumentTree::with_llm_client(
161 document.id.clone(),
162 self.hierarchical_config.clone(),
163 llm_client,
164 )?
165 } else {
166 DocumentTree::new(document.id.clone(), self.hierarchical_config.clone())?
167 }
168 } else {
169 DocumentTree::new(document.id.clone(), self.hierarchical_config.clone())?
171 };
172 {
175 let mut trees_guard = self.document_trees.write().await;
176 trees_guard.insert(document.id.clone(), tree);
177 }
178
179 Ok(())
180 }
181
    /// Builds the knowledge graph: extracts entities from every stored chunk
    /// and back-links each chunk to the entity ids found in it.
    pub async fn build_graph(&mut self) -> Result<()> {
        let mut graph_guard = self.knowledge_graph.write().await;
        let graph = graph_guard.as_mut().ok_or_else(|| GraphRAGError::Config {
            message: "Knowledge graph not initialized".to_string(),
        })?;

        tracing::info!("Building knowledge graph asynchronously");

        // Clone the chunks up front so the graph can be mutated while we
        // iterate over their contents.
        let chunks: Vec<_> = graph.chunks().cloned().collect();
        let mut total_entities = 0;

        for chunk in &chunks {
            let entities = self.extract_entities_async(chunk).await?;

            // Register each entity and remember its id for back-linking.
            let mut chunk_entity_ids = Vec::new();
            for entity in entities {
                chunk_entity_ids.push(entity.id.clone());
                graph.add_entity(entity)?;
                total_entities += 1;
            }

            // Attach the extracted entity ids to the stored chunk.
            if let Some(existing_chunk) = graph.get_chunk_mut(&chunk.id) {
                existing_chunk.entities = chunk_entity_ids;
            }
        }

        tracing::info!(
            entity_count = total_entities,
            "Knowledge graph built asynchronously"
        );
        Ok(())
    }
220
    /// Placeholder entity extractor: case-insensitively matches a fixed list
    /// of names (the Tom Sawyer cast) against the chunk text.
    /// NOTE(review): the sleep simulates model latency — presumably this is
    /// meant to be replaced by a real NER/LLM extraction pass.
    async fn extract_entities_async(&self, chunk: &TextChunk) -> Result<Vec<Entity>> {
        // Simulated async work so the call yields to the executor.
        tokio::time::sleep(std::time::Duration::from_millis(1)).await;

        let content = chunk.content.to_lowercase();
        let mut entities = Vec::new();

        // Hardcoded demo vocabulary; the index `i` keeps entity ids unique.
        let names = ["tom", "huck", "polly", "sid", "mary", "jim"];
        for (i, name) in names.iter().enumerate() {
            if content.contains(name) {
                let entity = Entity::new(
                    EntityId::new(format!("{name}_{i}")),
                    name.to_string(),
                    "PERSON".to_string(),
                    0.8, // fixed confidence for the demo extractor
                );
                entities.push(entity);
            }
        }

        Ok(entities)
    }
246
    /// Placeholder query API: returns a canned result after simulated
    /// latency. NOTE(review): not wired to retrieval — see `answer_question`
    /// for the full pipeline.
    pub async fn query(&self, query: &str) -> Result<Vec<String>> {
        // Simulated async query latency.
        tokio::time::sleep(std::time::Duration::from_millis(10)).await;

        Ok(vec![format!("Async result for: {}", query)])
    }
255
    /// Answers a question end-to-end: requires an initialized graph and
    /// model, runs chunk retrieval plus hierarchical summary lookup, then
    /// generates an abstractive answer from the combined context.
    pub async fn answer_question(&self, question: &str) -> Result<GeneratedAnswer> {
        let graph_guard = self.knowledge_graph.read().await;
        let graph = graph_guard
            .as_ref()
            .ok_or_else(|| GraphRAGError::Generation {
                message: "Knowledge graph not initialized".to_string(),
            })?;

        let llm = self
            .language_model
            .as_ref()
            .ok_or_else(|| GraphRAGError::Generation {
                message: "Language model not initialized".to_string(),
            })?;

        // Chunk-level retrieval against the knowledge graph.
        let search_results = self.async_retrieval(question, graph).await?;

        // Summary-level retrieval across the per-document trees (top 5).
        let hierarchical_results = self.hierarchical_query(question, 5).await?;

        self.generate_answer_async(question, search_results, hierarchical_results, llm)
            .await
    }
282
283 async fn async_retrieval(
285 &self,
286 query: &str,
287 graph: &KnowledgeGraph,
288 ) -> Result<Vec<SearchResult>> {
289 tokio::time::sleep(std::time::Duration::from_millis(5)).await;
291
292 let mut results = Vec::new();
294 for (i, chunk) in graph.chunks().enumerate().take(3) {
295 if chunk.content.to_lowercase().contains(&query.to_lowercase()) {
296 results.push(SearchResult {
297 id: chunk.id.to_string(),
298 content: chunk.content.clone(),
299 score: 0.8 - (i as f32 * 0.1),
300 result_type: crate::retrieval::ResultType::Chunk,
301 entities: chunk.entities.iter().map(|e| e.to_string()).collect(),
302 source_chunks: vec![chunk.id.to_string()],
303 });
304 }
305 }
306
307 Ok(results)
308 }
309
310 pub async fn hierarchical_query(
312 &self,
313 query: &str,
314 max_results: usize,
315 ) -> Result<Vec<QueryResult>> {
316 let trees_guard = self.document_trees.read().await;
317 let mut all_results = Vec::new();
318
319 for tree in trees_guard.values() {
321 let tree_results = tree.query(query, max_results)?;
323 all_results.extend(tree_results);
324 }
325
326 all_results.sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap());
328 all_results.truncate(max_results);
329
330 Ok(all_results)
331 }
332
333 async fn generate_answer_async(
335 &self,
336 question: &str,
337 search_results: Vec<SearchResult>,
338 hierarchical_results: Vec<QueryResult>,
339 llm: &BoxedAsyncLanguageModel,
340 ) -> Result<GeneratedAnswer> {
341 let context = self
343 .assemble_context_async(search_results, hierarchical_results)
344 .await?;
345
346 let prompt = self.create_qa_prompt(question, &context)?;
348
349 let response = llm.complete(&prompt).await?;
351
352 Ok(GeneratedAnswer {
354 answer_text: response,
355 confidence_score: context.confidence_score,
356 sources: context.get_sources(),
357 entities_mentioned: context.entities,
358 mode_used: crate::generation::AnswerMode::Abstractive,
359 context_quality: context.confidence_score,
360 })
361 }
362
363 async fn assemble_context_async(
365 &self,
366 search_results: Vec<SearchResult>,
367 hierarchical_results: Vec<QueryResult>,
368 ) -> Result<AnswerContext> {
369 tokio::time::sleep(std::time::Duration::from_millis(1)).await;
371
372 let mut context = AnswerContext::new();
373
374 for result in search_results {
376 context.primary_chunks.push(result);
377 }
378
379 context.hierarchical_summaries = hierarchical_results;
381
382 let avg_score = if context.primary_chunks.is_empty() {
384 0.0
385 } else {
386 context.primary_chunks.iter().map(|r| r.score).sum::<f32>()
387 / context.primary_chunks.len() as f32
388 };
389
390 context.confidence_score = avg_score;
391 context.source_count = context.primary_chunks.len() + context.hierarchical_summaries.len();
392
393 Ok(context)
394 }
395
396 fn create_qa_prompt(&self, question: &str, context: &AnswerContext) -> Result<String> {
398 let combined_content = context.get_combined_content();
399
400 let mut values = HashMap::new();
401 values.insert("context".to_string(), combined_content);
402 values.insert("question".to_string(), question.to_string());
403
404 let template = PromptTemplate::new(
405 "Context:\n{context}\n\nQuestion: {question}\n\nBased on the provided context, please answer the question. If the context doesn't contain enough information, please say so.".to_string()
406 );
407
408 template.fill(&values)
409 }
410
411 pub async fn add_documents_batch(&mut self, documents: Vec<Document>) -> Result<()> {
413 tracing::info!(
414 document_count = documents.len(),
415 "Processing documents concurrently"
416 );
417
418 for document in documents {
421 self.add_document(document).await?;
422 }
423
424 tracing::info!("All documents processed successfully");
425 Ok(())
426 }
427
428 pub async fn answer_questions_batch(&self, questions: &[&str]) -> Result<Vec<GeneratedAnswer>> {
430 use futures::stream::{FuturesUnordered, StreamExt};
431
432 let mut futures = FuturesUnordered::new();
433
434 for question in questions {
435 futures.push(self.answer_question(question));
436 }
437
438 let mut answers = Vec::with_capacity(questions.len());
439 while let Some(result) = futures.next().await {
440 answers.push(result?);
441 }
442
443 Ok(answers)
444 }
445
446 pub async fn get_performance_stats(&self) -> AsyncPerformanceStats {
448 let graph_guard = self.knowledge_graph.read().await;
449 let trees_guard = self.document_trees.read().await;
450
451 AsyncPerformanceStats {
452 total_documents: trees_guard.len(),
453 total_entities: graph_guard.as_ref().map(|g| g.entity_count()).unwrap_or(0),
454 total_chunks: graph_guard
455 .as_ref()
456 .map(|g| g.chunks().count())
457 .unwrap_or(0),
458 health_status: AsyncHealthStatus::Healthy,
459 }
460 }
461
462 pub async fn health_check(&self) -> Result<AsyncHealthStatus> {
464 if let Some(llm) = &self.language_model {
466 if !llm.health_check().await.unwrap_or(false) {
467 return Ok(AsyncHealthStatus::Degraded);
468 }
469 }
470
471 let graph_guard = self.knowledge_graph.read().await;
473 if graph_guard.is_none() {
474 return Ok(AsyncHealthStatus::Degraded);
475 }
476
477 Ok(AsyncHealthStatus::Healthy)
478 }
479
480 pub async fn save_state_async(&self, output_dir: &str) -> Result<()> {
482 use std::fs;
483
484 fs::create_dir_all(output_dir)?;
486
487 let graph_guard = self.knowledge_graph.read().await;
489 if let Some(graph) = &*graph_guard {
490 graph.save_to_json(&format!("{output_dir}/async_knowledge_graph.json"))?;
491 }
492
493 let trees_guard = self.document_trees.read().await;
495 for (doc_id, tree) in trees_guard.iter() {
496 let filename = format!("{output_dir}/{doc_id}_async_tree.json");
497 let json_content = tree.to_json()?;
498 fs::write(&filename, json_content)?;
499 }
500
501 tracing::info!(output_dir = %output_dir, "Async state saved");
502 Ok(())
503 }
504}
505
/// Snapshot of system size and health, produced by `get_performance_stats`.
#[derive(Debug)]
pub struct AsyncPerformanceStats {
    /// Number of documents with a built hierarchical tree.
    pub total_documents: usize,
    /// Entities stored in the knowledge graph (0 if uninitialized).
    pub total_entities: usize,
    /// Chunks stored in the knowledge graph (0 if uninitialized).
    pub total_chunks: usize,
    /// Health state at the time of the snapshot.
    pub health_status: AsyncHealthStatus,
}
518
/// Coarse health states reported by `health_check`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum AsyncHealthStatus {
    /// Model (if any) responds and the knowledge graph is initialized.
    Healthy,
    /// Model probe failed or the knowledge graph is not initialized.
    Degraded,
    /// Reserved for hard failures; not currently produced in this file.
    Unhealthy,
}
529
/// Fluent builder for [`AsyncGraphRAG`]; finish with `build()`.
pub struct AsyncGraphRAGBuilder {
    /// Base configuration (defaults to `Config::default()`).
    config: Config,
    /// Optional pre-configured language model.
    language_model: Option<Arc<BoxedAsyncLanguageModel>>,
    /// Optional hierarchical summarization overrides.
    hierarchical_config: Option<HierarchicalConfig>,
}
536
impl AsyncGraphRAGBuilder {
    /// Starts a builder with the default configuration and no model.
    pub fn new() -> Self {
        Self {
            config: Config::default(),
            language_model: None,
            hierarchical_config: None,
        }
    }

    /// Replaces the base configuration.
    pub fn config(mut self, config: Config) -> Self {
        self.config = config;
        self
    }

    /// Supplies the async language model to use.
    pub fn language_model(mut self, model: BoxedAsyncLanguageModel) -> Self {
        self.language_model = Some(Arc::new(model));
        self
    }

    /// Overrides the hierarchical summarization settings.
    pub fn hierarchical_config(mut self, config: HierarchicalConfig) -> Self {
        self.hierarchical_config = Some(config);
        self
    }

    /// Installs the async mock LLM (test/demo model).
    #[cfg(feature = "async-traits")]
    pub async fn with_async_mock_llm(mut self) -> Result<Self> {
        let mock_llm = crate::generation::async_mock_llm::AsyncMockLLM::new().await?;
        self.language_model = Some(Arc::new(mock_llm));
        Ok(self)
    }

    /// Installs an Ollama-backed async generator.
    #[cfg(all(feature = "ollama", feature = "async-traits"))]
    pub async fn with_async_ollama(mut self, config: crate::ollama::OllamaConfig) -> Result<Self> {
        let ollama_llm = crate::ollama::AsyncOllamaGenerator::new(config).await?;
        self.language_model = Some(Arc::new(ollama_llm));
        Ok(self)
    }

    /// Consumes the builder, constructs the engine, and initializes it
    /// (initialization may install the mock LLM when no model was provided).
    pub async fn build(self) -> Result<AsyncGraphRAG> {
        // Fall back to default hierarchical settings when none were supplied.
        let hierarchical_config = self.hierarchical_config.unwrap_or_default();

        let mut graphrag =
            AsyncGraphRAG::with_hierarchical_config(self.config, hierarchical_config).await?;

        if let Some(llm) = self.language_model {
            graphrag.set_language_model(llm).await;
        }

        graphrag.initialize().await?;

        Ok(graphrag)
    }
}
597
impl Default for AsyncGraphRAGBuilder {
    /// Equivalent to [`AsyncGraphRAGBuilder::new`].
    fn default() -> Self {
        Self::new()
    }
}
603
#[cfg(test)]
mod tests {
    use super::*;

    /// `new` should succeed with the default configuration.
    #[tokio::test]
    async fn test_async_graphrag_creation() {
        let config = Config::default();
        let graphrag = AsyncGraphRAG::new(config).await;
        assert!(graphrag.is_ok());
    }

    /// `initialize` installs the graph (and the mock LLM fallback when the
    /// `async-traits` feature is on).
    #[tokio::test]
    async fn test_async_graphrag_initialization() {
        let config = Config::default();
        let mut graphrag = AsyncGraphRAG::new(config).await.unwrap();
        let result = graphrag.initialize().await;
        assert!(result.is_ok());
    }

    /// Building with all defaults should produce a working instance.
    #[tokio::test]
    async fn test_async_builder() {
        let result = AsyncGraphRAGBuilder::new().build().await;
        assert!(result.is_ok());
    }

    /// Explicitly installing the mock LLM should also build cleanly.
    #[tokio::test]
    #[cfg(feature = "async-traits")]
    async fn test_with_async_mock_llm() {
        let result = AsyncGraphRAGBuilder::new()
            .with_async_mock_llm()
            .await
            .unwrap()
            .build()
            .await;
        assert!(result.is_ok());
    }

    /// A freshly initialized system reports Healthy.
    #[tokio::test]
    async fn test_health_check() {
        let config = Config::default();
        let mut graphrag = AsyncGraphRAG::new(config).await.unwrap();
        graphrag.initialize().await.unwrap();

        let health = graphrag.health_check().await.unwrap();
        assert_eq!(health, AsyncHealthStatus::Healthy);
    }

    /// An empty initialized system has zero documents and is Healthy.
    #[tokio::test]
    async fn test_performance_stats() {
        let config = Config::default();
        let mut graphrag = AsyncGraphRAG::new(config).await.unwrap();
        graphrag.initialize().await.unwrap();

        let stats = graphrag.get_performance_stats().await;
        assert_eq!(stats.total_documents, 0);
        assert_eq!(stats.health_status, AsyncHealthStatus::Healthy);
    }
}