1use crate::{
7 config::Config,
8 core::{
9 traits::BoxedAsyncLanguageModel, Document, DocumentId, Entity, EntityId, GraphRAGError,
10 KnowledgeGraph, Result, TextChunk,
11 },
12 generation::{AnswerContext, GeneratedAnswer, PromptTemplate},
13 retrieval::SearchResult,
14 summarization::{DocumentTree, HierarchicalConfig, LLMClient, QueryResult},
15};
16use std::collections::HashMap;
17use std::sync::Arc;
18use tokio::sync::RwLock;
19
20pub struct AsyncLanguageModelAdapter {
22 model: Arc<BoxedAsyncLanguageModel>,
23}
24
25impl AsyncLanguageModelAdapter {
26 pub fn new(model: Arc<BoxedAsyncLanguageModel>) -> Self {
34 Self { model }
35 }
36}
37
38#[async_trait::async_trait]
39impl LLMClient for AsyncLanguageModelAdapter {
40 async fn generate_summary(
41 &self,
42 text: &str,
43 prompt: &str,
44 _max_tokens: usize,
45 _temperature: f32,
46 ) -> crate::Result<String> {
47 let full_prompt = format!("{}\n\nText: {}", prompt, text);
48
49 let response = self.model.complete(&full_prompt).await.map_err(|e| {
50 crate::core::GraphRAGError::Generation {
51 message: e.to_string(),
52 }
53 })?;
54
55 Ok(response)
56 }
57
58 fn model_name(&self) -> &str {
59 "async_language_model"
60 }
61}
62
63pub struct AsyncGraphRAG {
65 #[allow(dead_code)]
66 config: Config,
67 knowledge_graph: Arc<RwLock<Option<KnowledgeGraph>>>,
68 document_trees: Arc<RwLock<HashMap<DocumentId, DocumentTree>>>,
69 hierarchical_config: HierarchicalConfig,
70 language_model: Option<Arc<BoxedAsyncLanguageModel>>,
71}
72
73impl AsyncGraphRAG {
74 pub async fn new(config: Config) -> Result<Self> {
76 let hierarchical_config = config.summarization.clone();
77 Ok(Self {
78 config,
79 knowledge_graph: Arc::new(RwLock::new(None)),
80 document_trees: Arc::new(RwLock::new(HashMap::new())),
81 hierarchical_config,
82 language_model: None,
83 })
84 }
85
86 pub async fn with_hierarchical_config(
88 config: Config,
89 hierarchical_config: HierarchicalConfig,
90 ) -> Result<Self> {
91 Ok(Self {
92 config,
93 knowledge_graph: Arc::new(RwLock::new(None)),
94 document_trees: Arc::new(RwLock::new(HashMap::new())),
95 hierarchical_config,
96 language_model: None,
97 })
98 }
99
100 pub async fn set_language_model(&mut self, model: Arc<BoxedAsyncLanguageModel>) {
102 self.language_model = Some(model);
103 }
104
105 pub async fn initialize(&mut self) -> Result<()> {
107 tracing::info!("Initializing async GraphRAG system");
108
109 {
111 let mut graph_guard = self.knowledge_graph.write().await;
112 *graph_guard = Some(KnowledgeGraph::new());
113 }
114
115 if self.language_model.is_none() {
117 #[cfg(feature = "async-traits")]
118 {
119 let mock_llm = crate::generation::async_mock_llm::AsyncMockLLM::new().await?;
120 self.language_model = Some(Arc::new(mock_llm));
121 }
122 #[cfg(not(feature = "async-traits"))]
123 {
124 return Err(GraphRAGError::Config {
125 message: "No async language model available and async-traits feature disabled"
126 .to_string(),
127 });
128 }
129 }
130
131 tracing::info!("Async GraphRAG system initialized successfully");
132 Ok(())
133 }
134
135 pub async fn add_document(&mut self, document: Document) -> Result<()> {
137 self.build_document_tree(&document).await?;
139
140 let mut graph_guard = self.knowledge_graph.write().await;
141 let graph = graph_guard.as_mut().ok_or_else(|| GraphRAGError::Config {
142 message: "Knowledge graph not initialized".to_string(),
143 })?;
144
145 graph.add_document(document)
146 }
147
148 pub async fn build_document_tree(&mut self, document: &Document) -> Result<()> {
150 if document.chunks.is_empty() {
151 return Ok(());
152 }
153
154 tracing::debug!(document_id = %document.id, "Building hierarchical tree for document");
155
156 let tree = if self.hierarchical_config.llm_config.enabled {
157 if let Some(ref lm) = self.language_model {
159 let llm_client = Arc::new(AsyncLanguageModelAdapter::new(Arc::clone(lm)));
160 DocumentTree::with_llm_client(
161 document.id.clone(),
162 self.hierarchical_config.clone(),
163 llm_client,
164 )?
165 } else {
166 DocumentTree::new(document.id.clone(), self.hierarchical_config.clone())?
167 }
168 } else {
169 DocumentTree::new(document.id.clone(), self.hierarchical_config.clone())?
171 };
172 {
175 let mut trees_guard = self.document_trees.write().await;
176 trees_guard.insert(document.id.clone(), tree);
177 }
178
179 Ok(())
180 }
181
182 pub async fn build_graph(&mut self) -> Result<()> {
184 let mut graph_guard = self.knowledge_graph.write().await;
185 let graph = graph_guard.as_mut().ok_or_else(|| GraphRAGError::Config {
186 message: "Knowledge graph not initialized".to_string(),
187 })?;
188
189 tracing::info!("Building knowledge graph asynchronously");
190
191 let chunks: Vec<_> = graph.chunks().cloned().collect();
193 let mut total_entities = 0;
194
195 for chunk in &chunks {
197 let entities = self.extract_entities_async(chunk).await?;
199
200 let mut chunk_entity_ids = Vec::new();
202 for entity in entities {
203 chunk_entity_ids.push(entity.id.clone());
204 graph.add_entity(entity)?;
205 total_entities += 1;
206 }
207
208 if let Some(existing_chunk) = graph.get_chunk_mut(&chunk.id) {
210 existing_chunk.entities = chunk_entity_ids;
211 }
212 }
213
214 tracing::info!(
215 entity_count = total_entities,
216 "Knowledge graph built asynchronously"
217 );
218 Ok(())
219 }
220
221 async fn extract_entities_async(&self, chunk: &TextChunk) -> Result<Vec<Entity>> {
223 tokio::time::sleep(std::time::Duration::from_millis(1)).await;
225
226 let content = chunk.content.to_lowercase();
228 let mut entities = Vec::new();
229
230 let names = ["tom", "huck", "polly", "sid", "mary", "jim"];
232 for (i, name) in names.iter().enumerate() {
233 if content.contains(name) {
234 let entity = Entity::new(
235 EntityId::new(format!("{name}_{i}")),
236 name.to_string(),
237 "PERSON".to_string(),
238 0.8,
239 );
240 entities.push(entity);
241 }
242 }
243
244 Ok(entities)
245 }
246
247 pub async fn query(&self, query: &str) -> Result<Vec<String>> {
249 tokio::time::sleep(std::time::Duration::from_millis(10)).await;
251
252 Ok(vec![format!("Async result for: {}", query)])
254 }
255
256 pub async fn answer_question(&self, question: &str) -> Result<GeneratedAnswer> {
258 let graph_guard = self.knowledge_graph.read().await;
259 let graph = graph_guard
260 .as_ref()
261 .ok_or_else(|| GraphRAGError::Generation {
262 message: "Knowledge graph not initialized".to_string(),
263 })?;
264
265 let llm = self
266 .language_model
267 .as_ref()
268 .ok_or_else(|| GraphRAGError::Generation {
269 message: "Language model not initialized".to_string(),
270 })?;
271
272 let search_results = self.async_retrieval(question, graph).await?;
274
275 let hierarchical_results = self.hierarchical_query(question, 5).await?;
277
278 self.generate_answer_async(question, search_results, hierarchical_results, llm)
280 .await
281 }
282
283 async fn async_retrieval(
285 &self,
286 query: &str,
287 graph: &KnowledgeGraph,
288 ) -> Result<Vec<SearchResult>> {
289 tokio::time::sleep(std::time::Duration::from_millis(5)).await;
291
292 let mut results = Vec::new();
294 for (i, chunk) in graph.chunks().enumerate().take(3) {
295 if chunk.content.to_lowercase().contains(&query.to_lowercase()) {
296 results.push(SearchResult {
297 id: chunk.id.to_string(),
298 content: chunk.content.clone(),
299 score: 0.8 - (i as f32 * 0.1),
300 result_type: crate::retrieval::ResultType::Chunk,
301 entities: chunk.entities.iter().map(|e| e.to_string()).collect(),
302 source_chunks: vec![chunk.id.to_string()],
303 });
304 }
305 }
306
307 Ok(results)
308 }
309
310 pub async fn hierarchical_query(
312 &self,
313 query: &str,
314 max_results: usize,
315 ) -> Result<Vec<QueryResult>> {
316 let trees_guard = self.document_trees.read().await;
317 let mut all_results = Vec::new();
318
319 for tree in trees_guard.values() {
321 let tree_results = tree.query(query, max_results)?;
323 all_results.extend(tree_results);
324 }
325
326 all_results.sort_by(|a, b| {
328 b.score
329 .partial_cmp(&a.score)
330 .unwrap_or(std::cmp::Ordering::Equal)
331 });
332 all_results.truncate(max_results);
333
334 Ok(all_results)
335 }
336
337 async fn generate_answer_async(
339 &self,
340 question: &str,
341 search_results: Vec<SearchResult>,
342 hierarchical_results: Vec<QueryResult>,
343 llm: &BoxedAsyncLanguageModel,
344 ) -> Result<GeneratedAnswer> {
345 let context = self
347 .assemble_context_async(search_results, hierarchical_results)
348 .await?;
349
350 let prompt = self.create_qa_prompt(question, &context)?;
352
353 let response = llm.complete(&prompt).await?;
355
356 Ok(GeneratedAnswer {
358 answer_text: response,
359 confidence_score: context.confidence_score,
360 sources: context.get_sources(),
361 entities_mentioned: context.entities,
362 mode_used: crate::generation::AnswerMode::Abstractive,
363 context_quality: context.confidence_score,
364 })
365 }
366
367 async fn assemble_context_async(
369 &self,
370 search_results: Vec<SearchResult>,
371 hierarchical_results: Vec<QueryResult>,
372 ) -> Result<AnswerContext> {
373 tokio::time::sleep(std::time::Duration::from_millis(1)).await;
375
376 let mut context = AnswerContext::new();
377
378 for result in search_results {
380 context.primary_chunks.push(result);
381 }
382
383 context.hierarchical_summaries = hierarchical_results;
385
386 let avg_score = if context.primary_chunks.is_empty() {
388 0.0
389 } else {
390 context.primary_chunks.iter().map(|r| r.score).sum::<f32>()
391 / context.primary_chunks.len() as f32
392 };
393
394 context.confidence_score = avg_score;
395 context.source_count = context.primary_chunks.len() + context.hierarchical_summaries.len();
396
397 Ok(context)
398 }
399
400 fn create_qa_prompt(&self, question: &str, context: &AnswerContext) -> Result<String> {
402 let combined_content = context.get_combined_content();
403
404 let mut values = HashMap::new();
405 values.insert("context".to_string(), combined_content);
406 values.insert("question".to_string(), question.to_string());
407
408 let template = PromptTemplate::new(
409 "Context:\n{context}\n\nQuestion: {question}\n\nBased on the provided context, please answer the question. If the context doesn't contain enough information, please say so.".to_string()
410 );
411
412 template.fill(&values)
413 }
414
415 pub async fn add_documents_batch(&mut self, documents: Vec<Document>) -> Result<()> {
417 tracing::info!(
418 document_count = documents.len(),
419 "Processing documents concurrently"
420 );
421
422 for document in documents {
425 self.add_document(document).await?;
426 }
427
428 tracing::info!("All documents processed successfully");
429 Ok(())
430 }
431
432 pub async fn answer_questions_batch(&self, questions: &[&str]) -> Result<Vec<GeneratedAnswer>> {
434 use futures::stream::{FuturesUnordered, StreamExt};
435
436 let mut futures = FuturesUnordered::new();
437
438 for question in questions {
439 futures.push(self.answer_question(question));
440 }
441
442 let mut answers = Vec::with_capacity(questions.len());
443 while let Some(result) = futures.next().await {
444 answers.push(result?);
445 }
446
447 Ok(answers)
448 }
449
450 pub async fn get_performance_stats(&self) -> AsyncPerformanceStats {
452 let graph_guard = self.knowledge_graph.read().await;
453 let trees_guard = self.document_trees.read().await;
454
455 AsyncPerformanceStats {
456 total_documents: trees_guard.len(),
457 total_entities: graph_guard.as_ref().map(|g| g.entity_count()).unwrap_or(0),
458 total_chunks: graph_guard
459 .as_ref()
460 .map(|g| g.chunks().count())
461 .unwrap_or(0),
462 health_status: AsyncHealthStatus::Healthy,
463 }
464 }
465
466 pub async fn health_check(&self) -> Result<AsyncHealthStatus> {
468 if let Some(llm) = &self.language_model {
470 if !llm.health_check().await.unwrap_or(false) {
471 return Ok(AsyncHealthStatus::Degraded);
472 }
473 }
474
475 let graph_guard = self.knowledge_graph.read().await;
477 if graph_guard.is_none() {
478 return Ok(AsyncHealthStatus::Degraded);
479 }
480
481 Ok(AsyncHealthStatus::Healthy)
482 }
483
484 pub async fn save_state_async(&self, output_dir: &str) -> Result<()> {
486 use std::fs;
487
488 fs::create_dir_all(output_dir)?;
490
491 let graph_guard = self.knowledge_graph.read().await;
493 if let Some(graph) = &*graph_guard {
494 graph.save_to_json(&format!("{output_dir}/async_knowledge_graph.json"))?;
495 }
496
497 let trees_guard = self.document_trees.read().await;
499 for (doc_id, tree) in trees_guard.iter() {
500 let filename = format!("{output_dir}/{doc_id}_async_tree.json");
501 let json_content = tree.to_json()?;
502 fs::write(&filename, json_content)?;
503 }
504
505 tracing::info!(output_dir = %output_dir, "Async state saved");
506 Ok(())
507 }
508}
509
510#[derive(Debug)]
512pub struct AsyncPerformanceStats {
513 pub total_documents: usize,
515 pub total_entities: usize,
517 pub total_chunks: usize,
519 pub health_status: AsyncHealthStatus,
521}
522
523#[derive(Debug, Clone, PartialEq, Eq)]
525pub enum AsyncHealthStatus {
526 Healthy,
528 Degraded,
530 Unhealthy,
532}
533
534pub struct AsyncGraphRAGBuilder {
536 config: Config,
537 language_model: Option<Arc<BoxedAsyncLanguageModel>>,
538 hierarchical_config: Option<HierarchicalConfig>,
539}
540
541impl AsyncGraphRAGBuilder {
542 pub fn new() -> Self {
544 Self {
545 config: Config::default(),
546 language_model: None,
547 hierarchical_config: None,
548 }
549 }
550
551 pub fn config(mut self, config: Config) -> Self {
553 self.config = config;
554 self
555 }
556
557 pub fn language_model(mut self, model: BoxedAsyncLanguageModel) -> Self {
559 self.language_model = Some(Arc::new(model));
560 self
561 }
562
563 pub fn hierarchical_config(mut self, config: HierarchicalConfig) -> Self {
565 self.hierarchical_config = Some(config);
566 self
567 }
568
569 #[cfg(feature = "async-traits")]
571 pub async fn with_async_mock_llm(mut self) -> Result<Self> {
572 let mock_llm = crate::generation::async_mock_llm::AsyncMockLLM::new().await?;
573 self.language_model = Some(Arc::new(mock_llm));
574 Ok(self)
575 }
576
577 #[cfg(all(feature = "ollama", feature = "async-traits"))]
579 pub async fn with_async_ollama(mut self, config: crate::ollama::OllamaConfig) -> Result<Self> {
580 let ollama_llm = crate::ollama::AsyncOllamaGenerator::new(config).await?;
581 self.language_model = Some(Arc::new(ollama_llm));
582 Ok(self)
583 }
584
585 pub async fn build(self) -> Result<AsyncGraphRAG> {
587 let hierarchical_config = self.hierarchical_config.unwrap_or_default();
588
589 let mut graphrag =
590 AsyncGraphRAG::with_hierarchical_config(self.config, hierarchical_config).await?;
591
592 if let Some(llm) = self.language_model {
593 graphrag.set_language_model(llm).await;
594 }
595
596 graphrag.initialize().await?;
597
598 Ok(graphrag)
599 }
600}
601
602impl Default for AsyncGraphRAGBuilder {
603 fn default() -> Self {
604 Self::new()
605 }
606}