1use super::{
6 algorithms::PageRankConfig,
7 entity::{
8 entities_to_nodes, relationships_to_edges, EntityExtractionConfig, EntityExtractor,
9 RuleBasedEntityExtractor,
10 },
11 query_expansion::{ExpansionConfig, ExpansionStrategy},
12 storage::{GraphStorage, GraphStorageConfig, InMemoryGraphStorage},
13 GraphNode, GraphRetrievalConfig, GraphRetriever, KnowledgeGraph,
14};
15use crate::{Document, DocumentChunk, RragResult};
16use async_trait::async_trait;
17use serde::{Deserialize, Serialize};
18use std::collections::HashMap;
19
20pub struct GraphRetrievalBuilder {
22 config: GraphBuildConfig,
24
25 entity_extractor: Option<Box<dyn EntityExtractor>>,
27
28 storage: Option<Box<dyn GraphStorage>>,
30
31 _embedding_service: Option<()>,
33
34 retrieval_config: GraphRetrievalConfig,
36}
37
38#[derive(Debug, Clone)]
40pub struct GraphBuildConfig {
41 pub entity_config: EntityExtractionConfig,
43
44 pub storage_config: GraphStorageConfig,
46
47 pub expansion_config: ExpansionConfig,
49
50 pub generate_entity_embeddings: bool,
52
53 pub calculate_pagerank: bool,
55
56 pub batch_size: usize,
58
59 pub enable_parallel_processing: bool,
61
62 pub num_workers: usize,
64}
65
66impl Default for GraphBuildConfig {
67 fn default() -> Self {
68 Self {
69 entity_config: EntityExtractionConfig::default(),
70 storage_config: GraphStorageConfig::default(),
71 expansion_config: ExpansionConfig::default(),
72 generate_entity_embeddings: true,
73 calculate_pagerank: true,
74 batch_size: 100,
75 enable_parallel_processing: true,
76 num_workers: num_cpus::get(),
77 }
78 }
79}
80
81#[derive(Debug, Clone, Serialize, Deserialize)]
83pub struct GraphBuildProgress {
84 pub phase: BuildPhase,
86
87 pub documents_processed: usize,
89
90 pub total_documents: usize,
92
93 pub entities_extracted: usize,
95
96 pub relationships_found: usize,
98
99 pub graph_nodes: usize,
101
102 pub graph_edges: usize,
104
105 pub processing_speed: f32,
107
108 pub estimated_remaining_seconds: u64,
110
111 pub errors: Vec<String>,
113}
114
115#[derive(Debug, Clone, Serialize, Deserialize)]
117pub enum BuildPhase {
118 Initializing,
120
121 EntityExtraction,
123
124 GraphConstruction,
126
127 EmbeddingGeneration,
129
130 MetricComputation,
132
133 Indexing,
135
136 Completed,
138
139 Failed(String),
141}
142
143#[async_trait]
145pub trait ProgressCallback: Send + Sync {
146 async fn on_progress(&self, progress: &GraphBuildProgress);
147}
148
149impl GraphRetrievalBuilder {
150 pub fn new() -> Self {
152 Self {
153 config: GraphBuildConfig::default(),
154 entity_extractor: None,
155 storage: None,
156 _embedding_service: None,
157 retrieval_config: GraphRetrievalConfig::default(),
158 }
159 }
160
161 pub fn with_config(mut self, config: GraphBuildConfig) -> Self {
163 self.config = config;
164 self
165 }
166
167 pub fn with_entity_extractor(mut self, extractor: Box<dyn EntityExtractor>) -> Self {
169 self.entity_extractor = Some(extractor);
170 self
171 }
172
173 pub fn with_rule_based_entity_extractor(
175 mut self,
176 config: EntityExtractionConfig,
177 ) -> RragResult<Self> {
178 let extractor = RuleBasedEntityExtractor::new(config)?;
179 self.entity_extractor = Some(Box::new(extractor));
180 Ok(self)
181 }
182
183 pub fn with_storage(mut self, storage: Box<dyn GraphStorage>) -> Self {
185 self.storage = Some(storage);
186 self
187 }
188
189 pub fn with_in_memory_storage(mut self, config: GraphStorageConfig) -> Self {
191 let storage = InMemoryGraphStorage::with_config(config);
192 self.storage = Some(Box::new(storage));
193 self
194 }
195
196 pub fn with_embedding_service(mut self) -> Self {
198 self._embedding_service = Some(());
199 self
200 }
201
202 pub fn with_retrieval_config(mut self, config: GraphRetrievalConfig) -> Self {
204 self.retrieval_config = config;
205 self
206 }
207
208 pub fn with_query_expansion(mut self, enabled: bool) -> Self {
210 self.retrieval_config.enable_query_expansion = enabled;
211 self
212 }
213
214 pub fn with_pagerank_scoring(mut self, enabled: bool) -> Self {
216 self.retrieval_config.enable_pagerank_scoring = enabled;
217 self
218 }
219
220 pub fn with_scoring_weights(mut self, graph_weight: f32, similarity_weight: f32) -> Self {
222 self.retrieval_config.graph_weight = graph_weight;
223 self.retrieval_config.similarity_weight = similarity_weight;
224 self
225 }
226
227 pub fn with_max_graph_hops(mut self, max_hops: usize) -> Self {
229 self.retrieval_config.max_graph_hops = max_hops;
230 self
231 }
232
233 pub fn with_expansion_strategies(mut self, strategies: Vec<ExpansionStrategy>) -> Self {
235 self.retrieval_config.expansion_options.strategies = strategies;
236 self
237 }
238
239 pub fn with_batch_size(mut self, batch_size: usize) -> Self {
241 self.config.batch_size = batch_size;
242 self
243 }
244
245 pub fn with_parallel_processing(mut self, enabled: bool) -> Self {
247 self.config.enable_parallel_processing = enabled;
248 self
249 }
250
251 pub async fn build_from_documents(
253 mut self,
254 documents: Vec<Document>,
255 progress_callback: Option<Box<dyn ProgressCallback>>,
256 ) -> RragResult<GraphRetriever> {
257 let entity_extractor = self.entity_extractor.take().unwrap_or_else(|| {
259 Box::new(RuleBasedEntityExtractor::new(self.config.entity_config.clone()).unwrap())
260 });
261
262 let storage = self.storage.take().unwrap_or_else(|| {
263 Box::new(InMemoryGraphStorage::with_config(
264 self.config.storage_config.clone(),
265 ))
266 });
267
268 let graph = self
270 .build_graph_from_documents(&documents, &*entity_extractor, progress_callback)
271 .await?;
272
273 GraphRetriever::new(graph, storage, self.retrieval_config)
275 }
276
277 pub async fn build_from_chunks(
279 self,
280 chunks: Vec<DocumentChunk>,
281 progress_callback: Option<Box<dyn ProgressCallback>>,
282 ) -> RragResult<GraphRetriever> {
283 let documents: Vec<Document> = chunks
285 .into_iter()
286 .map(|chunk| {
287 Document::with_id(
288 format!("chunk_{}_{}", chunk.document_id, chunk.chunk_index),
289 chunk.content.clone(),
290 )
291 .with_metadata(
292 "source_document",
293 serde_json::Value::String(chunk.document_id),
294 )
295 .with_metadata(
296 "chunk_index",
297 serde_json::Value::Number(chunk.chunk_index.into()),
298 )
299 })
300 .collect();
301
302 self.build_from_documents(documents, progress_callback)
303 .await
304 }
305
306 async fn build_graph_from_documents(
308 &self,
309 documents: &[Document],
310 entity_extractor: &dyn EntityExtractor,
311 progress_callback: Option<Box<dyn ProgressCallback>>,
312 ) -> RragResult<KnowledgeGraph> {
313 let mut progress = GraphBuildProgress {
314 phase: BuildPhase::Initializing,
315 documents_processed: 0,
316 total_documents: documents.len(),
317 entities_extracted: 0,
318 relationships_found: 0,
319 graph_nodes: 0,
320 graph_edges: 0,
321 processing_speed: 0.0,
322 estimated_remaining_seconds: 0,
323 errors: Vec::new(),
324 };
325
326 if let Some(callback) = &progress_callback {
327 callback.on_progress(&progress).await;
328 }
329
330 let mut graph = KnowledgeGraph::new();
331 let start_time = std::time::Instant::now();
332
333 progress.phase = BuildPhase::EntityExtraction;
335 if let Some(callback) = &progress_callback {
336 callback.on_progress(&progress).await;
337 }
338
339 let mut all_entities = Vec::new();
340 let mut all_relationships = Vec::new();
341
342 if self.config.enable_parallel_processing && documents.len() > self.config.batch_size {
343 for (_batch_idx, batch) in documents.chunks(self.config.batch_size).enumerate() {
345 let batch_start = std::time::Instant::now();
346 let mut batch_entities = Vec::new();
347 let mut batch_relationships = Vec::new();
348
349 for document in batch {
351 match entity_extractor
352 .extract_all(&document.content_str(), &document.id)
353 .await
354 {
355 Ok((entities, relationships)) => {
356 progress.entities_extracted += entities.len();
357 progress.relationships_found += relationships.len();
358 batch_entities.extend(entities);
359 batch_relationships.extend(relationships);
360 }
361 Err(e) => {
362 progress
363 .errors
364 .push(format!("Document {}: {}", document.id, e));
365 }
366 }
367 progress.documents_processed += 1;
368 }
369
370 all_entities.extend(batch_entities);
371 all_relationships.extend(batch_relationships);
372
373 let batch_time = batch_start.elapsed().as_secs_f32();
375 progress.processing_speed = batch.len() as f32 / batch_time;
376 let remaining_docs = documents.len() - progress.documents_processed;
377 progress.estimated_remaining_seconds =
378 (remaining_docs as f32 / progress.processing_speed.max(0.1)) as u64;
379
380 if let Some(callback) = &progress_callback {
381 callback.on_progress(&progress).await;
382 }
383 }
384 } else {
385 for (doc_idx, document) in documents.iter().enumerate() {
387 let _doc_start = std::time::Instant::now();
388
389 match entity_extractor
390 .extract_all(&document.content_str(), &document.id)
391 .await
392 {
393 Ok((entities, relationships)) => {
394 progress.entities_extracted += entities.len();
395 progress.relationships_found += relationships.len();
396 all_entities.extend(entities);
397 all_relationships.extend(relationships);
398 }
399 Err(e) => {
400 progress
401 .errors
402 .push(format!("Document {}: {}", document.id, e));
403 }
404 }
405
406 progress.documents_processed += 1;
407
408 if doc_idx % 10 == 0 {
410 let elapsed = start_time.elapsed().as_secs_f32();
411 progress.processing_speed = progress.documents_processed as f32 / elapsed;
412 let remaining_docs = documents.len() - progress.documents_processed;
413 progress.estimated_remaining_seconds =
414 (remaining_docs as f32 / progress.processing_speed.max(0.1)) as u64;
415
416 if let Some(callback) = &progress_callback {
417 callback.on_progress(&progress).await;
418 }
419 }
420 }
421 }
422
423 progress.phase = BuildPhase::GraphConstruction;
425 if let Some(callback) = &progress_callback {
426 callback.on_progress(&progress).await;
427 }
428
429 let entity_nodes = entities_to_nodes(&all_entities);
431 progress.graph_nodes = entity_nodes.len();
432
433 let mut entity_node_map = HashMap::new();
435 for node in &entity_nodes {
436 if let Some(original_text) = node.attributes.get("original_text") {
438 if let Some(text) = original_text.as_str() {
439 entity_node_map.insert(text.to_string(), node.id.clone());
440 }
441 }
442 entity_node_map.insert(node.label.clone(), node.id.clone());
443 }
444
445 for node in entity_nodes {
447 graph.add_node(node)?;
448 }
449
450 let relationship_edges = relationships_to_edges(&all_relationships, &entity_node_map);
452 progress.graph_edges = relationship_edges.len();
453
454 for edge in relationship_edges {
456 if let Err(e) = graph.add_edge(edge) {
457 progress.errors.push(format!("Failed to add edge: {}", e));
458 }
459 }
460
461 for document in documents {
463 let doc_node =
464 GraphNode::new(format!("doc_{}", document.id), super::NodeType::Document)
465 .with_source_document(document.id.clone())
466 .with_attribute(
467 "title",
468 serde_json::Value::String(
469 document
470 .metadata
471 .get("title")
472 .and_then(|v| v.as_str())
473 .unwrap_or(&document.id)
474 .to_string(),
475 ),
476 );
477
478 graph.add_node(doc_node)?;
479 progress.graph_nodes += 1;
480 }
481
482 if self.config.generate_entity_embeddings && self._embedding_service.is_some() {
484 progress.phase = BuildPhase::EmbeddingGeneration;
485 if let Some(callback) = &progress_callback {
486 callback.on_progress(&progress).await;
487 }
488
489 }
493
494 if self.config.calculate_pagerank {
496 progress.phase = BuildPhase::MetricComputation;
497 if let Some(callback) = &progress_callback {
498 callback.on_progress(&progress).await;
499 }
500
501 let pagerank_config = PageRankConfig::default();
503 match super::algorithms::GraphAlgorithms::pagerank(&graph, &pagerank_config) {
504 Ok(pagerank_scores) => {
505 for (node_id, score) in pagerank_scores {
507 if let Some(node) = graph.nodes.get_mut(&node_id) {
508 node.pagerank_score = Some(score);
509 }
510 }
511 }
512 Err(e) => {
513 progress
514 .errors
515 .push(format!("PageRank computation failed: {}", e));
516 }
517 }
518 }
519
520 progress.phase = BuildPhase::Indexing;
522 if let Some(callback) = &progress_callback {
523 callback.on_progress(&progress).await;
524 }
525
526 progress.phase = BuildPhase::Completed;
531 progress.processing_speed =
532 progress.documents_processed as f32 / start_time.elapsed().as_secs_f32();
533 progress.estimated_remaining_seconds = 0;
534
535 if let Some(callback) = &progress_callback {
536 callback.on_progress(&progress).await;
537 }
538
539 Ok(graph)
540 }
541
542 pub async fn build_empty(mut self) -> RragResult<GraphRetriever> {
544 let storage = self.storage.take().unwrap_or_else(|| {
545 Box::new(InMemoryGraphStorage::with_config(
546 self.config.storage_config.clone(),
547 ))
548 });
549
550 let graph = KnowledgeGraph::new();
551 GraphRetriever::new(graph, storage, self.retrieval_config)
552 }
553}
554
555impl Default for GraphRetrievalBuilder {
556 fn default() -> Self {
557 Self::new()
558 }
559}
560
/// Stateless [`ProgressCallback`] that writes progress updates to the
/// `tracing` log at debug level.
pub struct PrintProgressCallback;
563
564#[async_trait]
565impl ProgressCallback for PrintProgressCallback {
566 async fn on_progress(&self, progress: &GraphBuildProgress) {
567 match &progress.phase {
568 BuildPhase::Initializing => {
569 tracing::debug!("Initializing graph builder...");
570 }
571 BuildPhase::EntityExtraction => {
572 tracing::debug!(
573 "Extracting entities: {}/{} documents processed ({:.1} docs/sec), {} entities found, {} relationships found",
574 progress.documents_processed,
575 progress.total_documents,
576 progress.processing_speed,
577 progress.entities_extracted,
578 progress.relationships_found
579 );
580 }
581 BuildPhase::GraphConstruction => {
582 tracing::debug!(
583 "Building graph: {} nodes, {} edges",
584 progress.graph_nodes,
585 progress.graph_edges
586 );
587 }
588 BuildPhase::EmbeddingGeneration => {
589 tracing::debug!("Generating embeddings for entities...");
590 }
591 BuildPhase::MetricComputation => {
592 tracing::debug!("Computing graph metrics (PageRank, centrality, etc.)...");
593 }
594 BuildPhase::Indexing => {
595 tracing::debug!("Building search indices...");
596 }
597 BuildPhase::Completed => {
598 tracing::debug!(
599 "Graph construction completed! Processed {} documents, extracted {} entities, found {} relationships",
600 progress.documents_processed,
601 progress.entities_extracted,
602 progress.relationships_found
603 );
604 tracing::debug!(
605 "Final graph: {} nodes, {} edges",
606 progress.graph_nodes,
607 progress.graph_edges
608 );
609 if !progress.errors.is_empty() {
610 tracing::debug!(
611 "Encountered {} errors during processing",
612 progress.errors.len()
613 );
614 }
615 }
616 BuildPhase::Failed(error) => {
617 tracing::debug!("Graph construction failed: {}", error);
618 }
619 }
620
621 if progress.estimated_remaining_seconds > 0 {
622 tracing::debug!(
623 "Estimated time remaining: {} seconds",
624 progress.estimated_remaining_seconds
625 );
626 }
627 }
628}
629
#[cfg(test)]
mod tests {
    use super::*;

    /// A default builder yields a retriever with the expected name.
    #[tokio::test]
    async fn test_builder_creation() {
        let retriever = GraphRetrievalBuilder::new().build_empty().await.unwrap();

        assert_eq!(retriever.name(), "graph_retriever");
    }

    /// Every fluent setter stores its value in the matching config field.
    #[tokio::test]
    async fn test_builder_configuration() {
        let configured = GraphRetrievalBuilder::new()
            .with_batch_size(50)
            .with_parallel_processing(false)
            .with_query_expansion(true)
            .with_pagerank_scoring(true)
            .with_max_graph_hops(2)
            .with_scoring_weights(0.5, 0.5);

        let build = &configured.config;
        assert_eq!(build.batch_size, 50);
        assert!(!build.enable_parallel_processing);

        let retrieval = &configured.retrieval_config;
        assert!(retrieval.enable_query_expansion);
        assert!(retrieval.enable_pagerank_scoring);
        assert_eq!(retrieval.max_graph_hops, 2);
        assert_eq!(retrieval.graph_weight, 0.5);
        assert_eq!(retrieval.similarity_weight, 0.5);
    }

    /// Builds a retriever from two small documents with PageRank, embeddings
    /// and batching disabled. A build error is only logged, not asserted on.
    #[tokio::test]
    async fn test_build_from_documents() {
        let config = GraphBuildConfig {
            calculate_pagerank: false,
            generate_entity_embeddings: false,
            enable_parallel_processing: false,
            ..Default::default()
        };

        let documents = vec![
            Document::new("John Smith works at Google. He is a software engineer."),
            Document::new("Google is a technology company in California."),
        ];

        let progress_callback: Box<dyn ProgressCallback> = Box::new(PrintProgressCallback);

        let outcome = GraphRetrievalBuilder::new()
            .with_config(config)
            .build_from_documents(documents, Some(progress_callback))
            .await;

        match outcome {
            Err(e) => {
                tracing::debug!("Builder test failed: {}", e);
            }
            Ok(retriever) => {
                assert_eq!(retriever.name(), "graph_retriever");
                let health = retriever.health_check().await.unwrap();
                assert!(health);
            }
        }
    }
}