use super::{
algorithms::PageRankConfig,
entity::{
entities_to_nodes, relationships_to_edges, EntityExtractionConfig, EntityExtractor,
RuleBasedEntityExtractor,
},
query_expansion::{ExpansionConfig, ExpansionStrategy},
storage::{GraphStorage, GraphStorageConfig, InMemoryGraphStorage},
GraphNode, GraphRetrievalConfig, GraphRetriever, KnowledgeGraph,
};
use crate::{Document, DocumentChunk, RragResult};
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
pub struct GraphRetrievalBuilder {
config: GraphBuildConfig,
entity_extractor: Option<Box<dyn EntityExtractor>>,
storage: Option<Box<dyn GraphStorage>>,
_embedding_service: Option<()>,
retrieval_config: GraphRetrievalConfig,
}
#[derive(Debug, Clone)]
pub struct GraphBuildConfig {
pub entity_config: EntityExtractionConfig,
pub storage_config: GraphStorageConfig,
pub expansion_config: ExpansionConfig,
pub generate_entity_embeddings: bool,
pub calculate_pagerank: bool,
pub batch_size: usize,
pub enable_parallel_processing: bool,
pub num_workers: usize,
}
impl Default for GraphBuildConfig {
fn default() -> Self {
Self {
entity_config: EntityExtractionConfig::default(),
storage_config: GraphStorageConfig::default(),
expansion_config: ExpansionConfig::default(),
generate_entity_embeddings: true,
calculate_pagerank: true,
batch_size: 100,
enable_parallel_processing: true,
num_workers: num_cpus::get(),
}
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GraphBuildProgress {
pub phase: BuildPhase,
pub documents_processed: usize,
pub total_documents: usize,
pub entities_extracted: usize,
pub relationships_found: usize,
pub graph_nodes: usize,
pub graph_edges: usize,
pub processing_speed: f32,
pub estimated_remaining_seconds: u64,
pub errors: Vec<String>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum BuildPhase {
Initializing,
EntityExtraction,
GraphConstruction,
EmbeddingGeneration,
MetricComputation,
Indexing,
Completed,
Failed(String),
}
#[async_trait]
pub trait ProgressCallback: Send + Sync {
async fn on_progress(&self, progress: &GraphBuildProgress);
}
impl GraphRetrievalBuilder {
pub fn new() -> Self {
Self {
config: GraphBuildConfig::default(),
entity_extractor: None,
storage: None,
_embedding_service: None,
retrieval_config: GraphRetrievalConfig::default(),
}
}
pub fn with_config(mut self, config: GraphBuildConfig) -> Self {
self.config = config;
self
}
pub fn with_entity_extractor(mut self, extractor: Box<dyn EntityExtractor>) -> Self {
self.entity_extractor = Some(extractor);
self
}
pub fn with_rule_based_entity_extractor(
mut self,
config: EntityExtractionConfig,
) -> RragResult<Self> {
let extractor = RuleBasedEntityExtractor::new(config)?;
self.entity_extractor = Some(Box::new(extractor));
Ok(self)
}
pub fn with_storage(mut self, storage: Box<dyn GraphStorage>) -> Self {
self.storage = Some(storage);
self
}
pub fn with_in_memory_storage(mut self, config: GraphStorageConfig) -> Self {
let storage = InMemoryGraphStorage::with_config(config);
self.storage = Some(Box::new(storage));
self
}
pub fn with_embedding_service(mut self) -> Self {
self._embedding_service = Some(());
self
}
pub fn with_retrieval_config(mut self, config: GraphRetrievalConfig) -> Self {
self.retrieval_config = config;
self
}
pub fn with_query_expansion(mut self, enabled: bool) -> Self {
self.retrieval_config.enable_query_expansion = enabled;
self
}
pub fn with_pagerank_scoring(mut self, enabled: bool) -> Self {
self.retrieval_config.enable_pagerank_scoring = enabled;
self
}
pub fn with_scoring_weights(mut self, graph_weight: f32, similarity_weight: f32) -> Self {
self.retrieval_config.graph_weight = graph_weight;
self.retrieval_config.similarity_weight = similarity_weight;
self
}
pub fn with_max_graph_hops(mut self, max_hops: usize) -> Self {
self.retrieval_config.max_graph_hops = max_hops;
self
}
pub fn with_expansion_strategies(mut self, strategies: Vec<ExpansionStrategy>) -> Self {
self.retrieval_config.expansion_options.strategies = strategies;
self
}
pub fn with_batch_size(mut self, batch_size: usize) -> Self {
self.config.batch_size = batch_size;
self
}
pub fn with_parallel_processing(mut self, enabled: bool) -> Self {
self.config.enable_parallel_processing = enabled;
self
}
pub async fn build_from_documents(
mut self,
documents: Vec<Document>,
progress_callback: Option<Box<dyn ProgressCallback>>,
) -> RragResult<GraphRetriever> {
let entity_extractor = self.entity_extractor.take().unwrap_or_else(|| {
Box::new(RuleBasedEntityExtractor::new(self.config.entity_config.clone()).unwrap())
});
let storage = self.storage.take().unwrap_or_else(|| {
Box::new(InMemoryGraphStorage::with_config(
self.config.storage_config.clone(),
))
});
let graph = self
.build_graph_from_documents(&documents, &*entity_extractor, progress_callback)
.await?;
GraphRetriever::new(graph, storage, self.retrieval_config)
}
pub async fn build_from_chunks(
self,
chunks: Vec<DocumentChunk>,
progress_callback: Option<Box<dyn ProgressCallback>>,
) -> RragResult<GraphRetriever> {
let documents: Vec<Document> = chunks
.into_iter()
.map(|chunk| {
Document::with_id(
format!("chunk_{}_{}", chunk.document_id, chunk.chunk_index),
chunk.content.clone(),
)
.with_metadata(
"source_document",
serde_json::Value::String(chunk.document_id),
)
.with_metadata(
"chunk_index",
serde_json::Value::Number(chunk.chunk_index.into()),
)
})
.collect();
self.build_from_documents(documents, progress_callback)
.await
}
async fn build_graph_from_documents(
&self,
documents: &[Document],
entity_extractor: &dyn EntityExtractor,
progress_callback: Option<Box<dyn ProgressCallback>>,
) -> RragResult<KnowledgeGraph> {
let mut progress = GraphBuildProgress {
phase: BuildPhase::Initializing,
documents_processed: 0,
total_documents: documents.len(),
entities_extracted: 0,
relationships_found: 0,
graph_nodes: 0,
graph_edges: 0,
processing_speed: 0.0,
estimated_remaining_seconds: 0,
errors: Vec::new(),
};
if let Some(callback) = &progress_callback {
callback.on_progress(&progress).await;
}
let mut graph = KnowledgeGraph::new();
let start_time = std::time::Instant::now();
progress.phase = BuildPhase::EntityExtraction;
if let Some(callback) = &progress_callback {
callback.on_progress(&progress).await;
}
let mut all_entities = Vec::new();
let mut all_relationships = Vec::new();
if self.config.enable_parallel_processing && documents.len() > self.config.batch_size {
for (_batch_idx, batch) in documents.chunks(self.config.batch_size).enumerate() {
let batch_start = std::time::Instant::now();
let mut batch_entities = Vec::new();
let mut batch_relationships = Vec::new();
for document in batch {
match entity_extractor
.extract_all(&document.content_str(), &document.id)
.await
{
Ok((entities, relationships)) => {
progress.entities_extracted += entities.len();
progress.relationships_found += relationships.len();
batch_entities.extend(entities);
batch_relationships.extend(relationships);
}
Err(e) => {
progress
.errors
.push(format!("Document {}: {}", document.id, e));
}
}
progress.documents_processed += 1;
}
all_entities.extend(batch_entities);
all_relationships.extend(batch_relationships);
let batch_time = batch_start.elapsed().as_secs_f32();
progress.processing_speed = batch.len() as f32 / batch_time;
let remaining_docs = documents.len() - progress.documents_processed;
progress.estimated_remaining_seconds =
(remaining_docs as f32 / progress.processing_speed.max(0.1)) as u64;
if let Some(callback) = &progress_callback {
callback.on_progress(&progress).await;
}
}
} else {
for (doc_idx, document) in documents.iter().enumerate() {
let _doc_start = std::time::Instant::now();
match entity_extractor
.extract_all(&document.content_str(), &document.id)
.await
{
Ok((entities, relationships)) => {
progress.entities_extracted += entities.len();
progress.relationships_found += relationships.len();
all_entities.extend(entities);
all_relationships.extend(relationships);
}
Err(e) => {
progress
.errors
.push(format!("Document {}: {}", document.id, e));
}
}
progress.documents_processed += 1;
if doc_idx % 10 == 0 {
let elapsed = start_time.elapsed().as_secs_f32();
progress.processing_speed = progress.documents_processed as f32 / elapsed;
let remaining_docs = documents.len() - progress.documents_processed;
progress.estimated_remaining_seconds =
(remaining_docs as f32 / progress.processing_speed.max(0.1)) as u64;
if let Some(callback) = &progress_callback {
callback.on_progress(&progress).await;
}
}
}
}
progress.phase = BuildPhase::GraphConstruction;
if let Some(callback) = &progress_callback {
callback.on_progress(&progress).await;
}
let entity_nodes = entities_to_nodes(&all_entities);
progress.graph_nodes = entity_nodes.len();
let mut entity_node_map = HashMap::new();
for node in &entity_nodes {
if let Some(original_text) = node.attributes.get("original_text") {
if let Some(text) = original_text.as_str() {
entity_node_map.insert(text.to_string(), node.id.clone());
}
}
entity_node_map.insert(node.label.clone(), node.id.clone());
}
for node in entity_nodes {
graph.add_node(node)?;
}
let relationship_edges = relationships_to_edges(&all_relationships, &entity_node_map);
progress.graph_edges = relationship_edges.len();
for edge in relationship_edges {
if let Err(e) = graph.add_edge(edge) {
progress.errors.push(format!("Failed to add edge: {}", e));
}
}
for document in documents {
let doc_node =
GraphNode::new(format!("doc_{}", document.id), super::NodeType::Document)
.with_source_document(document.id.clone())
.with_attribute(
"title",
serde_json::Value::String(
document
.metadata
.get("title")
.and_then(|v| v.as_str())
.unwrap_or(&document.id)
.to_string(),
),
);
graph.add_node(doc_node)?;
progress.graph_nodes += 1;
}
if self.config.generate_entity_embeddings && self._embedding_service.is_some() {
progress.phase = BuildPhase::EmbeddingGeneration;
if let Some(callback) = &progress_callback {
callback.on_progress(&progress).await;
}
}
if self.config.calculate_pagerank {
progress.phase = BuildPhase::MetricComputation;
if let Some(callback) = &progress_callback {
callback.on_progress(&progress).await;
}
let pagerank_config = PageRankConfig::default();
match super::algorithms::GraphAlgorithms::pagerank(&graph, &pagerank_config) {
Ok(pagerank_scores) => {
for (node_id, score) in pagerank_scores {
if let Some(node) = graph.nodes.get_mut(&node_id) {
node.pagerank_score = Some(score);
}
}
}
Err(e) => {
progress
.errors
.push(format!("PageRank computation failed: {}", e));
}
}
}
progress.phase = BuildPhase::Indexing;
if let Some(callback) = &progress_callback {
callback.on_progress(&progress).await;
}
progress.phase = BuildPhase::Completed;
progress.processing_speed =
progress.documents_processed as f32 / start_time.elapsed().as_secs_f32();
progress.estimated_remaining_seconds = 0;
if let Some(callback) = &progress_callback {
callback.on_progress(&progress).await;
}
Ok(graph)
}
pub async fn build_empty(mut self) -> RragResult<GraphRetriever> {
let storage = self.storage.take().unwrap_or_else(|| {
Box::new(InMemoryGraphStorage::with_config(
self.config.storage_config.clone(),
))
});
let graph = KnowledgeGraph::new();
GraphRetriever::new(graph, storage, self.retrieval_config)
}
}
impl Default for GraphRetrievalBuilder {
fn default() -> Self {
Self::new()
}
}
pub struct PrintProgressCallback;
#[async_trait]
impl ProgressCallback for PrintProgressCallback {
async fn on_progress(&self, progress: &GraphBuildProgress) {
match &progress.phase {
BuildPhase::Initializing => {
println!("Initializing graph builder...");
}
BuildPhase::EntityExtraction => {
println!(
"Extracting entities: {}/{} documents processed ({:.1} docs/sec), {} entities found, {} relationships found",
progress.documents_processed,
progress.total_documents,
progress.processing_speed,
progress.entities_extracted,
progress.relationships_found
);
}
BuildPhase::GraphConstruction => {
println!(
"Building graph: {} nodes, {} edges",
progress.graph_nodes, progress.graph_edges
);
}
BuildPhase::EmbeddingGeneration => {
println!("Generating embeddings for entities...");
}
BuildPhase::MetricComputation => {
println!("Computing graph metrics (PageRank, centrality, etc.)...");
}
BuildPhase::Indexing => {
println!("Building search indices...");
}
BuildPhase::Completed => {
println!(
"Graph construction completed! Processed {} documents, extracted {} entities, found {} relationships",
progress.documents_processed,
progress.entities_extracted,
progress.relationships_found
);
println!(
"Final graph: {} nodes, {} edges",
progress.graph_nodes, progress.graph_edges
);
if !progress.errors.is_empty() {
println!(
"Encountered {} errors during processing",
progress.errors.len()
);
}
}
BuildPhase::Failed(error) => {
println!("Graph construction failed: {}", error);
}
}
if progress.estimated_remaining_seconds > 0 {
println!(
"Estimated time remaining: {} seconds",
progress.estimated_remaining_seconds
);
}
}
}
#[cfg(test)]
mod tests {
use super::*;
#[tokio::test]
async fn test_builder_creation() {
let builder = GraphRetrievalBuilder::new();
let retriever = builder.build_empty().await.unwrap();
assert_eq!(retriever.name(), "graph_retriever");
}
#[tokio::test]
async fn test_builder_configuration() {
let builder = GraphRetrievalBuilder::new()
.with_batch_size(50)
.with_parallel_processing(false)
.with_query_expansion(true)
.with_pagerank_scoring(true)
.with_max_graph_hops(2)
.with_scoring_weights(0.5, 0.5);
assert_eq!(builder.config.batch_size, 50);
assert!(!builder.config.enable_parallel_processing);
assert!(builder.retrieval_config.enable_query_expansion);
assert!(builder.retrieval_config.enable_pagerank_scoring);
assert_eq!(builder.retrieval_config.max_graph_hops, 2);
assert_eq!(builder.retrieval_config.graph_weight, 0.5);
assert_eq!(builder.retrieval_config.similarity_weight, 0.5);
}
#[tokio::test]
async fn test_build_from_documents() {
let documents = vec![
Document::new("John Smith works at Google. He is a software engineer."),
Document::new("Google is a technology company in California."),
];
let config = GraphBuildConfig {
calculate_pagerank: false,
generate_entity_embeddings: false,
enable_parallel_processing: false,
..Default::default()
};
let builder = GraphRetrievalBuilder::new().with_config(config);
let progress_callback = Box::new(PrintProgressCallback);
let result = builder
.build_from_documents(documents, Some(progress_callback))
.await;
match result {
Ok(retriever) => {
assert_eq!(retriever.name(), "graph_retriever");
let health = retriever.health_check().await.unwrap();
assert!(health);
}
Err(e) => {
println!("Builder test failed: {}", e);
}
}
}
}