1use crate::{
7 config::Config,
8 core::{
9 traits::BoxedAsyncLanguageModel,
10 Document, DocumentId, Entity, EntityId, GraphRAGError, KnowledgeGraph, Result,
11 TextChunk,
12 },
13 generation::{AnswerContext, GeneratedAnswer, PromptTemplate},
14 retrieval::SearchResult,
15 summarization::{DocumentTree, HierarchicalConfig, QueryResult, LLMClient},
16};
17use std::collections::HashMap;
18use std::sync::Arc;
19use tokio::sync::RwLock;
20
/// Bridges a [`BoxedAsyncLanguageModel`] to the [`LLMClient`] trait that the
/// hierarchical summarization layer expects.
pub struct AsyncLanguageModelAdapter {
    /// Shared handle to the underlying async language model.
    model: Arc<BoxedAsyncLanguageModel>,
}
25
26impl AsyncLanguageModelAdapter {
27 pub fn new(model: Arc<BoxedAsyncLanguageModel>) -> Self {
35 Self { model }
36 }
37}
38
39#[async_trait::async_trait]
40impl LLMClient for AsyncLanguageModelAdapter {
41 async fn generate_summary(
42 &self,
43 text: &str,
44 prompt: &str,
45 _max_tokens: usize,
46 _temperature: f32,
47 ) -> crate::Result<String> {
48 let full_prompt = format!("{}\n\nText: {}", prompt, text);
49
50 let response = self.model
51 .complete(&full_prompt)
52 .await.map_err(|e| crate::core::GraphRAGError::Generation { message: e.to_string() })?;
53
54 Ok(response)
55 }
56
57 fn model_name(&self) -> &str {
58 "async_language_model"
59 }
60}
61
/// Async variant of the GraphRAG pipeline: maintains a knowledge graph plus
/// per-document hierarchical summary trees, and answers questions through an
/// async language model.
pub struct AsyncGraphRAG {
    // Retained for future use; not read after construction in this file.
    #[allow(dead_code)] config: Config,
    /// Shared graph; `None` until `initialize` has run.
    knowledge_graph: Arc<RwLock<Option<KnowledgeGraph>>>,
    /// One hierarchical summary tree per document id.
    document_trees: Arc<RwLock<HashMap<DocumentId, DocumentTree>>>,
    /// Settings used when building document trees.
    hierarchical_config: HierarchicalConfig,
    /// Optional async LLM; `initialize` installs a mock when this is unset.
    language_model: Option<Arc<BoxedAsyncLanguageModel>>,
}
70
71impl AsyncGraphRAG {
72 pub async fn new(config: Config) -> Result<Self> {
74 let hierarchical_config = config.summarization.clone();
75 Ok(Self {
76 config,
77 knowledge_graph: Arc::new(RwLock::new(None)),
78 document_trees: Arc::new(RwLock::new(HashMap::new())),
79 hierarchical_config,
80 language_model: None,
81 })
82 }
83
84 pub async fn with_hierarchical_config(
86 config: Config,
87 hierarchical_config: HierarchicalConfig,
88 ) -> Result<Self> {
89 Ok(Self {
90 config,
91 knowledge_graph: Arc::new(RwLock::new(None)),
92 document_trees: Arc::new(RwLock::new(HashMap::new())),
93 hierarchical_config,
94 language_model: None,
95 })
96 }
97
    /// Installs the async language model used for summarization and QA.
    /// Declared `async` although the body performs no awaiting — callers
    /// (e.g. the builder) `.await` it.
    pub async fn set_language_model(&mut self, model: Arc<BoxedAsyncLanguageModel>) {
        self.language_model = Some(model);
    }
102
103 pub async fn initialize(&mut self) -> Result<()> {
105 tracing::info!("Initializing async GraphRAG system");
106
107 {
109 let mut graph_guard = self.knowledge_graph.write().await;
110 *graph_guard = Some(KnowledgeGraph::new());
111 }
112
113 if self.language_model.is_none() {
115 #[cfg(feature = "async-traits")]
116 {
117 let mock_llm = crate::generation::async_mock_llm::AsyncMockLLM::new().await?;
118 self.language_model = Some(Arc::new(mock_llm));
119 }
120 #[cfg(not(feature = "async-traits"))]
121 {
122 return Err(GraphRAGError::Config {
123 message: "No async language model available and async-traits feature disabled".to_string(),
124 });
125 }
126 }
127
128 tracing::info!("Async GraphRAG system initialized successfully");
129 Ok(())
130 }
131
132 pub async fn add_document(&mut self, document: Document) -> Result<()> {
134 self.build_document_tree(&document).await?;
136
137 let mut graph_guard = self.knowledge_graph.write().await;
138 let graph = graph_guard
139 .as_mut()
140 .ok_or_else(|| GraphRAGError::Config {
141 message: "Knowledge graph not initialized".to_string(),
142 })?;
143
144 graph.add_document(document)
145 }
146
147 pub async fn build_document_tree(&mut self, document: &Document) -> Result<()> {
149 if document.chunks.is_empty() {
150 return Ok(());
151 }
152
153 tracing::debug!(document_id = %document.id, "Building hierarchical tree for document");
154
155 let tree = if self.hierarchical_config.llm_config.enabled {
156 if let Some(ref lm) = self.language_model {
158 let llm_client = Arc::new(AsyncLanguageModelAdapter::new(Arc::clone(lm)));
159 DocumentTree::with_llm_client(
160 document.id.clone(),
161 self.hierarchical_config.clone(),
162 llm_client,
163 )?
164 } else {
165 DocumentTree::new(document.id.clone(), self.hierarchical_config.clone())?
166 }
167 } else {
168 DocumentTree::new(document.id.clone(), self.hierarchical_config.clone())?
170 };
171 {
174 let mut trees_guard = self.document_trees.write().await;
175 trees_guard.insert(document.id.clone(), tree);
176 }
177
178 Ok(())
179 }
180
181 pub async fn build_graph(&mut self) -> Result<()> {
183 let mut graph_guard = self.knowledge_graph.write().await;
184 let graph = graph_guard
185 .as_mut()
186 .ok_or_else(|| GraphRAGError::Config {
187 message: "Knowledge graph not initialized".to_string(),
188 })?;
189
190 tracing::info!("Building knowledge graph asynchronously");
191
192 let chunks: Vec<_> = graph.chunks().cloned().collect();
194 let mut total_entities = 0;
195
196 for chunk in &chunks {
198 let entities = self.extract_entities_async(chunk).await?;
200
201 let mut chunk_entity_ids = Vec::new();
203 for entity in entities {
204 chunk_entity_ids.push(entity.id.clone());
205 graph.add_entity(entity)?;
206 total_entities += 1;
207 }
208
209 if let Some(existing_chunk) = graph.get_chunk_mut(&chunk.id) {
211 existing_chunk.entities = chunk_entity_ids;
212 }
213 }
214
215 tracing::info!(entity_count = total_entities, "Knowledge graph built asynchronously");
216 Ok(())
217 }
218
219 async fn extract_entities_async(&self, chunk: &TextChunk) -> Result<Vec<Entity>> {
221 tokio::time::sleep(std::time::Duration::from_millis(1)).await;
223
224 let content = chunk.content.to_lowercase();
226 let mut entities = Vec::new();
227
228 let names = ["tom", "huck", "polly", "sid", "mary", "jim"];
230 for (i, name) in names.iter().enumerate() {
231 if content.contains(name) {
232 let entity = Entity::new(
233 EntityId::new(format!("{name}_{i}")),
234 name.to_string(),
235 "PERSON".to_string(),
236 0.8,
237 );
238 entities.push(entity);
239 }
240 }
241
242 Ok(entities)
243 }
244
245 pub async fn query(&self, query: &str) -> Result<Vec<String>> {
247 tokio::time::sleep(std::time::Duration::from_millis(10)).await;
249
250 Ok(vec![format!("Async result for: {}", query)])
252 }
253
254 pub async fn answer_question(&self, question: &str) -> Result<GeneratedAnswer> {
256 let graph_guard = self.knowledge_graph.read().await;
257 let graph = graph_guard
258 .as_ref()
259 .ok_or_else(|| GraphRAGError::Generation {
260 message: "Knowledge graph not initialized".to_string(),
261 })?;
262
263 let llm = self
264 .language_model
265 .as_ref()
266 .ok_or_else(|| GraphRAGError::Generation {
267 message: "Language model not initialized".to_string(),
268 })?;
269
270 let search_results = self.async_retrieval(question, graph).await?;
272
273 let hierarchical_results = self.hierarchical_query(question, 5).await?;
275
276 self.generate_answer_async(question, search_results, hierarchical_results, llm)
278 .await
279 }
280
281 async fn async_retrieval(
283 &self,
284 query: &str,
285 graph: &KnowledgeGraph,
286 ) -> Result<Vec<SearchResult>> {
287 tokio::time::sleep(std::time::Duration::from_millis(5)).await;
289
290 let mut results = Vec::new();
292 for (i, chunk) in graph.chunks().enumerate().take(3) {
293 if chunk.content.to_lowercase().contains(&query.to_lowercase()) {
294 results.push(SearchResult {
295 id: chunk.id.to_string(),
296 content: chunk.content.clone(),
297 score: 0.8 - (i as f32 * 0.1),
298 result_type: crate::retrieval::ResultType::Chunk,
299 entities: chunk.entities.iter().map(|e| e.to_string()).collect(),
300 source_chunks: vec![chunk.id.to_string()],
301 });
302 }
303 }
304
305 Ok(results)
306 }
307
308 pub async fn hierarchical_query(
310 &self,
311 query: &str,
312 max_results: usize,
313 ) -> Result<Vec<QueryResult>> {
314 let trees_guard = self.document_trees.read().await;
315 let mut all_results = Vec::new();
316
317 for tree in trees_guard.values() {
319 let tree_results = tree.query(query, max_results)?;
321 all_results.extend(tree_results);
322 }
323
324 all_results.sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap());
326 all_results.truncate(max_results);
327
328 Ok(all_results)
329 }
330
331 async fn generate_answer_async(
333 &self,
334 question: &str,
335 search_results: Vec<SearchResult>,
336 hierarchical_results: Vec<QueryResult>,
337 llm: &BoxedAsyncLanguageModel,
338 ) -> Result<GeneratedAnswer> {
339 let context = self.assemble_context_async(search_results, hierarchical_results).await?;
341
342 let prompt = self.create_qa_prompt(question, &context)?;
344
345 let response = llm.complete(&prompt).await?;
347
348 Ok(GeneratedAnswer {
350 answer_text: response,
351 confidence_score: context.confidence_score,
352 sources: context.get_sources(),
353 entities_mentioned: context.entities,
354 mode_used: crate::generation::AnswerMode::Abstractive,
355 context_quality: context.confidence_score,
356 })
357 }
358
359 async fn assemble_context_async(
361 &self,
362 search_results: Vec<SearchResult>,
363 hierarchical_results: Vec<QueryResult>,
364 ) -> Result<AnswerContext> {
365 tokio::time::sleep(std::time::Duration::from_millis(1)).await;
367
368 let mut context = AnswerContext::new();
369
370 for result in search_results {
372 context.primary_chunks.push(result);
373 }
374
375 context.hierarchical_summaries = hierarchical_results;
377
378 let avg_score = if context.primary_chunks.is_empty() {
380 0.0
381 } else {
382 context.primary_chunks.iter().map(|r| r.score).sum::<f32>()
383 / context.primary_chunks.len() as f32
384 };
385
386 context.confidence_score = avg_score;
387 context.source_count = context.primary_chunks.len() + context.hierarchical_summaries.len();
388
389 Ok(context)
390 }
391
392 fn create_qa_prompt(&self, question: &str, context: &AnswerContext) -> Result<String> {
394 let combined_content = context.get_combined_content();
395
396 let mut values = HashMap::new();
397 values.insert("context".to_string(), combined_content);
398 values.insert("question".to_string(), question.to_string());
399
400 let template = PromptTemplate::new(
401 "Context:\n{context}\n\nQuestion: {question}\n\nBased on the provided context, please answer the question. If the context doesn't contain enough information, please say so.".to_string()
402 );
403
404 template.fill(&values)
405 }
406
407 pub async fn add_documents_batch(&mut self, documents: Vec<Document>) -> Result<()> {
409 tracing::info!(document_count = documents.len(), "Processing documents concurrently");
410
411 for document in documents {
414 self.add_document(document).await?;
415 }
416
417 tracing::info!("All documents processed successfully");
418 Ok(())
419 }
420
421 pub async fn answer_questions_batch(&self, questions: &[&str]) -> Result<Vec<GeneratedAnswer>> {
423 use futures::stream::{FuturesUnordered, StreamExt};
424
425 let mut futures = FuturesUnordered::new();
426
427 for question in questions {
428 futures.push(self.answer_question(question));
429 }
430
431 let mut answers = Vec::with_capacity(questions.len());
432 while let Some(result) = futures.next().await {
433 answers.push(result?);
434 }
435
436 Ok(answers)
437 }
438
439 pub async fn get_performance_stats(&self) -> AsyncPerformanceStats {
441 let graph_guard = self.knowledge_graph.read().await;
442 let trees_guard = self.document_trees.read().await;
443
444 AsyncPerformanceStats {
445 total_documents: trees_guard.len(),
446 total_entities: graph_guard
447 .as_ref()
448 .map(|g| g.entity_count())
449 .unwrap_or(0),
450 total_chunks: graph_guard
451 .as_ref()
452 .map(|g| g.chunks().count())
453 .unwrap_or(0),
454 health_status: AsyncHealthStatus::Healthy,
455 }
456 }
457
458 pub async fn health_check(&self) -> Result<AsyncHealthStatus> {
460 if let Some(llm) = &self.language_model {
462 if !llm.health_check().await.unwrap_or(false) {
463 return Ok(AsyncHealthStatus::Degraded);
464 }
465 }
466
467 let graph_guard = self.knowledge_graph.read().await;
469 if graph_guard.is_none() {
470 return Ok(AsyncHealthStatus::Degraded);
471 }
472
473 Ok(AsyncHealthStatus::Healthy)
474 }
475
476 pub async fn save_state_async(&self, output_dir: &str) -> Result<()> {
478 use std::fs;
479
480 fs::create_dir_all(output_dir)?;
482
483 let graph_guard = self.knowledge_graph.read().await;
485 if let Some(graph) = &*graph_guard {
486 graph.save_to_json(&format!("{output_dir}/async_knowledge_graph.json"))?;
487 }
488
489 let trees_guard = self.document_trees.read().await;
491 for (doc_id, tree) in trees_guard.iter() {
492 let filename = format!("{output_dir}/{doc_id}_async_tree.json");
493 let json_content = tree.to_json()?;
494 fs::write(&filename, json_content)?;
495 }
496
497 tracing::info!(output_dir = %output_dir, "Async state saved");
498 Ok(())
499 }
500}
501
/// Point-in-time counters returned by `AsyncGraphRAG::get_performance_stats`.
#[derive(Debug)]
pub struct AsyncPerformanceStats {
    /// Number of document trees currently stored.
    pub total_documents: usize,
    /// Entities in the knowledge graph (0 when uninitialized).
    pub total_entities: usize,
    /// Chunks in the knowledge graph (0 when uninitialized).
    pub total_chunks: usize,
    /// Always `Healthy` as produced by `get_performance_stats`.
    pub health_status: AsyncHealthStatus,
}
514
/// Coarse health classification for the async pipeline.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum AsyncHealthStatus {
    /// All checked components responded.
    Healthy,
    /// The LLM probe failed or the knowledge graph is uninitialized.
    Degraded,
    /// Reserved for fatal states; not currently produced by `health_check`.
    Unhealthy,
}
525
/// Fluent builder for `AsyncGraphRAG`; `build` also runs initialization.
pub struct AsyncGraphRAGBuilder {
    /// Base configuration; starts as `Config::default()`.
    config: Config,
    /// Optional pre-selected model; when unset, `initialize` installs a mock.
    language_model: Option<Arc<BoxedAsyncLanguageModel>>,
    /// Optional override for `config.summarization`.
    hierarchical_config: Option<HierarchicalConfig>,
}
532
533impl AsyncGraphRAGBuilder {
534 pub fn new() -> Self {
536 Self {
537 config: Config::default(),
538 language_model: None,
539 hierarchical_config: None,
540 }
541 }
542
543 pub fn config(mut self, config: Config) -> Self {
545 self.config = config;
546 self
547 }
548
549 pub fn language_model(mut self, model: BoxedAsyncLanguageModel) -> Self {
551 self.language_model = Some(Arc::new(model));
552 self
553 }
554
555 pub fn hierarchical_config(mut self, config: HierarchicalConfig) -> Self {
557 self.hierarchical_config = Some(config);
558 self
559 }
560
561 #[cfg(feature = "async-traits")]
563 pub async fn with_async_mock_llm(mut self) -> Result<Self> {
564 let mock_llm = crate::generation::async_mock_llm::AsyncMockLLM::new().await?;
565 self.language_model = Some(Arc::new(mock_llm));
566 Ok(self)
567 }
568
569 #[cfg(all(feature = "ollama", feature = "async-traits"))]
571 pub async fn with_async_ollama(
572 mut self,
573 config: crate::ollama::OllamaConfig,
574 ) -> Result<Self> {
575 let ollama_llm = crate::ollama::AsyncOllamaGenerator::new(config).await?;
576 self.language_model = Some(Arc::new(ollama_llm));
577 Ok(self)
578 }
579
580 pub async fn build(self) -> Result<AsyncGraphRAG> {
582 let hierarchical_config = self.hierarchical_config.unwrap_or_default();
583
584 let mut graphrag = AsyncGraphRAG::with_hierarchical_config(self.config, hierarchical_config).await?;
585
586 if let Some(llm) = self.language_model {
587 graphrag.set_language_model(llm).await;
588 }
589
590 graphrag.initialize().await?;
591
592 Ok(graphrag)
593 }
594}
595
596impl Default for AsyncGraphRAGBuilder {
597 fn default() -> Self {
598 Self::new()
599 }
600}
601
#[cfg(test)]
mod tests {
    use super::*;

    /// Construction alone should never fail.
    #[tokio::test]
    async fn test_async_graphrag_creation() {
        assert!(AsyncGraphRAG::new(Config::default()).await.is_ok());
    }

    /// `initialize` must succeed on a freshly created instance.
    #[tokio::test]
    async fn test_async_graphrag_initialization() {
        let mut graphrag = AsyncGraphRAG::new(Config::default()).await.unwrap();
        assert!(graphrag.initialize().await.is_ok());
    }

    /// A default builder should produce a working system.
    #[tokio::test]
    async fn test_async_builder() {
        assert!(AsyncGraphRAGBuilder::new().build().await.is_ok());
    }

    /// The mock-LLM convenience path should also build successfully.
    #[tokio::test]
    #[cfg(feature = "async-traits")]
    async fn test_with_async_mock_llm() {
        let builder = AsyncGraphRAGBuilder::new()
            .with_async_mock_llm()
            .await
            .unwrap();
        assert!(builder.build().await.is_ok());
    }

    /// After initialization the health check should report `Healthy`.
    #[tokio::test]
    async fn test_health_check() {
        let mut graphrag = AsyncGraphRAG::new(Config::default()).await.unwrap();
        graphrag.initialize().await.unwrap();

        let health = graphrag.health_check().await.unwrap();
        assert_eq!(health, AsyncHealthStatus::Healthy);
    }

    /// A freshly initialized system has zero documents and is healthy.
    #[tokio::test]
    async fn test_performance_stats() {
        let mut graphrag = AsyncGraphRAG::new(Config::default()).await.unwrap();
        graphrag.initialize().await.unwrap();

        let stats = graphrag.get_performance_stats().await;
        assert_eq!(stats.total_documents, 0);
        assert_eq!(stats.health_status, AsyncHealthStatus::Healthy);
    }
}