#![cfg(feature = "pagerank")]
use crate::{
core::traits::Retriever,
core::{ChunkId, EntityId, GraphRAGError, KnowledgeGraph, Result},
graph::pagerank::{MultiModalScores, PageRankConfig, PersonalizedPageRank, ScoreWeights},
vector::VectorIndex,
};
use lru::LruCache;
use parking_lot::RwLock;
use rayon::prelude::*;
use std::collections::HashMap;
use std::num::NonZeroUsize;
use std::sync::Arc;
/// Retrieval front-end that combines lexical entity matching with personalized
/// PageRank scores computed over a `KnowledgeGraph`.
///
/// Construct it with `new` and adjust it through the `with_*` builder methods.
pub struct PageRankRetrievalSystem {
/// `None` until `initialize_vector_index` stores a freshly created `VectorIndex`.
vector_index: Option<VectorIndex>,
/// Weights handed to `MultiModalScores::combine_scores` when merging score sources.
score_weights: ScoreWeights,
/// Result cap used when a search call does not pass an explicit limit.
max_results: usize,
/// Entities whose combined score is below this value are skipped by the searches.
min_score_threshold: f64,
/// PageRank settings stored by `with_pagerank_config`.
// NOTE(review): no code visible in this file reads this field — confirm where it is consumed.
pagerank_config: PageRankConfig,
/// Per-query result cache.
// NOTE(review): in the code visible here it is only cleared (by `Retriever::update`), never filled.
query_cache: Arc<RwLock<LruCache<String, Vec<ScoredResult>>>>,
/// Per-key entity rank cache; likewise only cleared in the code visible here.
entity_rank_cache: Arc<RwLock<LruCache<String, HashMap<EntityId, f64>>>>,
/// Flag set by `new` (true) and `with_incremental_mode`; not read in the code visible here.
incremental_mode: bool,
/// Scores stored by `precompute_global_pagerank`; `None` until that method succeeds.
global_pagerank: Option<HashMap<EntityId, f64>>,
/// Graph used by `Retriever::search_with_context`; installed through `set_graph`.
graph: Option<Arc<KnowledgeGraph>>,
}
/// One search hit: an entity paired with a chunk (or a synthetic entity-only entry)
/// together with the individual and combined scores that produced it.
#[derive(Debug, Clone)]
pub struct ScoredResult {
/// Entity this hit was produced for.
pub entity_id: EntityId,
/// Chunk the content came from; the searches in this file use an `entity_<id>`
/// identifier when the hit is built from the entity itself rather than a chunk.
pub chunk_id: ChunkId,
/// Chunk text, or `"<name>: <type>"` for entity-only hits.
pub content: String,
/// Ranking score; every constructor in this file sets it equal to `combined_score`.
pub score: f64,
/// Lexical/vector similarity contribution for the entity.
pub vector_score: f64,
/// PageRank contribution for the entity (`0.0` for `quick_entity_search` hits).
pub pagerank_score: f64,
/// Final merged score used for thresholding and ordering.
pub combined_score: f64,
}
impl PageRankRetrievalSystem {
/// Builds a retrieval system that returns at most `max_results` hits per query
/// unless a search call overrides the limit.
///
/// Defaults: minimum score threshold `0.1`, incremental mode enabled, default
/// score weights and PageRank configuration, no vector index, no graph, and two
/// LRU caches holding up to 1000 entries each.
pub fn new(max_results: usize) -> Self {
    let capacity = NonZeroUsize::new(1000).expect("non-zero literal");
    let query_cache = Arc::new(RwLock::new(LruCache::new(capacity)));
    let entity_rank_cache = Arc::new(RwLock::new(LruCache::new(capacity)));

    Self {
        max_results,
        min_score_threshold: 0.1,
        incremental_mode: true,
        score_weights: ScoreWeights::default(),
        pagerank_config: PageRankConfig::default(),
        vector_index: None,
        global_pagerank: None,
        graph: None,
        query_cache,
        entity_rank_cache,
    }
}
pub fn with_pagerank_config(mut self, config: PageRankConfig) -> Self {
self.pagerank_config = config;
self
}
pub fn with_incremental_mode(mut self, enabled: bool) -> Self {
self.incremental_mode = enabled;
self
}
pub fn with_score_weights(mut self, weights: ScoreWeights) -> Self {
self.score_weights = weights;
self
}
pub fn with_min_threshold(mut self, threshold: f64) -> Self {
self.min_score_threshold = threshold;
self
}
/// Creates a fresh `VectorIndex` for this system and reports how many graph items
/// (entities plus chunks) were gathered as candidate index content.
///
/// Always returns `Ok(())`.
pub fn initialize_vector_index(&mut self, graph: &KnowledgeGraph) -> Result<()> {
    // One (id, text) pair per entity, followed by one per chunk.
    let mut indexed_items: Vec<(String, String)> = Vec::new();

    for entity in graph.entities() {
        let text = format!("{} {}", entity.name, entity.entity_type);
        indexed_items.push((entity.id.to_string(), text));
    }

    for chunk in graph.chunks() {
        let text = chunk.content.clone();
        indexed_items.push((chunk.id.to_string(), text));
    }

    // NOTE(review): the gathered items are only counted below; nothing is inserted
    // into the new index here — confirm whether population happens elsewhere.
    self.vector_index = Some(VectorIndex::new());

    let item_count = indexed_items.len();
    println!("🔍 Vector index initialized with {} items", item_count);
    Ok(())
}
/// Stores a shared handle to `graph`, replacing any previously stored graph.
///
/// The stored graph is what `Retriever::search_with_context` searches against.
pub fn set_graph(&mut self, graph: Arc<KnowledgeGraph>) {
self.graph = Some(graph);
}
/// Searches `graph` for `query`, ranking entities by a blend of lexical similarity
/// and personalized PageRank.
///
/// Steps:
/// 1. score entities lexically with `vector_similarity_search`;
/// 2. run personalized PageRank seeded with those scores;
/// 3. merge both score sets using the configured `ScoreWeights`;
/// 4. emit one result per (entity, chunk) pair for every entity at or above the
///    minimum score threshold, falling back to a synthetic entity-only result
///    when no chunk lists the entity;
/// 5. sort by descending score and keep at most `max_results` hits
///    (defaults to the limit passed to `new`).
///
/// Returns an empty list when no entity matches the query lexically.
///
/// # Errors
/// Propagates failures from the lexical search, from building the PageRank
/// calculator, and from the PageRank computation.
pub fn search_with_pagerank(
    &self,
    query: &str,
    graph: &KnowledgeGraph,
    max_results: Option<usize>,
) -> Result<Vec<ScoredResult>> {
    let max_results = max_results.unwrap_or(self.max_results);
    println!("🔍 Starting PageRank-enhanced search for: '{query}'");

    let vector_scores = self.vector_similarity_search(query, graph)?;
    println!("📊 Vector search found {} candidates", vector_scores.len());
    if vector_scores.is_empty() {
        return Ok(Vec::new());
    }

    let pagerank_calculator = graph.build_pagerank_calculator()?;
    let pagerank_scores =
        self.compute_personalized_pagerank(&vector_scores, &pagerank_calculator)?;
    println!("📈 PageRank computation completed");

    let mut multi_scores = MultiModalScores::new();
    multi_scores.vector_scores = vector_scores;
    multi_scores.pagerank_scores = pagerank_scores;
    let combined_scores = multi_scores.combine_scores(&self.score_weights);

    let mut scored_results = Vec::new();
    for (entity_id, combined_score) in combined_scores {
        if combined_score < self.min_score_threshold {
            continue;
        }

        // The per-entity scores do not depend on the chunk, so look them up once.
        let vector_score = *multi_scores.vector_scores.get(&entity_id).unwrap_or(&0.0);
        let pagerank_score = *multi_scores.pagerank_scores.get(&entity_id).unwrap_or(&0.0);

        // Remember how many results existed before this entity so we can tell
        // whether any chunk produced a hit without rescanning the whole list.
        let results_before = scored_results.len();
        for chunk in graph.chunks() {
            if chunk.entities.contains(&entity_id) {
                scored_results.push(ScoredResult {
                    entity_id: entity_id.clone(),
                    chunk_id: chunk.id.clone(),
                    content: chunk.content.clone(),
                    score: combined_score,
                    vector_score,
                    pagerank_score,
                    combined_score,
                });
            }
        }

        // No chunk lists this entity: surface the entity itself so it is not lost.
        if scored_results.len() == results_before {
            if let Some(entity) = graph.get_entity(&entity_id) {
                scored_results.push(ScoredResult {
                    entity_id: entity_id.clone(),
                    chunk_id: ChunkId::new(format!("entity_{entity_id}")),
                    content: format!("{}: {}", entity.name, entity.entity_type),
                    score: combined_score,
                    vector_score,
                    pagerank_score,
                    combined_score,
                });
            }
        }
    }

    // Highest score first; incomparable scores (NaN) are treated as equal.
    scored_results.sort_by(|a, b| {
        b.score
            .partial_cmp(&a.score)
            .unwrap_or(std::cmp::Ordering::Equal)
    });
    scored_results.truncate(max_results);

    println!(
        "✅ Search completed: {} results returned",
        scored_results.len()
    );
    Ok(scored_results)
}
/// Scores every entity in `graph` against `query` using word overlap.
///
/// For each entity the score is built from two parts:
/// * the Jaccard similarity between the set of lowercase query words and the set
///   of lowercase words in the entity's name and type, kept only when it exceeds
///   `0.1`;
/// * a flat `0.3` bonus for every mention whose chunk text contains the whole
///   (trimmed, lowercase) query as a substring.
///
/// Entities with neither contribution are absent from the returned map. A blank
/// query has no terms to match and yields an empty map.
fn vector_similarity_search(
    &self,
    query: &str,
    graph: &KnowledgeGraph,
) -> Result<HashMap<EntityId, f64>> {
    const MIN_SIMILARITY: f64 = 0.1;
    const MENTION_BONUS: f64 = 0.3;

    let mut scores = HashMap::new();
    let query_lower = query.to_lowercase();

    // Work on word *sets*: with plain lists a repeated query word is counted
    // several times and can push the "Jaccard" ratio above 1.0.
    let query_words: std::collections::HashSet<&str> = query_lower.split_whitespace().collect();
    if query_words.is_empty() {
        // An empty phrase is a substring of every chunk, which would hand the
        // mention bonus to every mentioned entity; treat it as "no match".
        return Ok(scores);
    }
    // Padding whitespace must not take part in the substring match.
    let query_phrase = query_lower.trim();

    for entity in graph.entities() {
        let entity_text = format!(
            "{} {}",
            entity.name.to_lowercase(),
            entity.entity_type.to_lowercase()
        );
        let entity_words: std::collections::HashSet<&str> =
            entity_text.split_whitespace().collect();

        let intersection_count = query_words.intersection(&entity_words).count();
        if intersection_count > 0 {
            let union_count = query_words.len() + entity_words.len() - intersection_count;
            let similarity = intersection_count as f64 / union_count as f64;
            if similarity > MIN_SIMILARITY {
                scores.insert(entity.id.clone(), similarity);
            }
        }

        for mention in &entity.mentions {
            if let Some(chunk) = graph.get_chunk(&mention.chunk_id) {
                if chunk.content.to_lowercase().contains(query_phrase) {
                    *scores.entry(entity.id.clone()).or_insert(0.0) += MENTION_BONUS;
                }
            }
        }
    }

    Ok(scores)
}
/// Runs personalized PageRank seeded by the lexical scores.
///
/// The scores are first normalized into reset probabilities and then handed to
/// `pagerank_calculator`. An empty `vector_scores` map short-circuits to an empty
/// result without touching the calculator.
///
/// # Errors
/// Propagates any error returned by `calculate_scores`.
fn compute_personalized_pagerank(
    &self,
    vector_scores: &HashMap<EntityId, f64>,
    pagerank_calculator: &PersonalizedPageRank,
) -> Result<HashMap<EntityId, f64>> {
    if !vector_scores.is_empty() {
        let seed_distribution = self.normalize_reset_probabilities(vector_scores);
        let ranked = pagerank_calculator.calculate_scores(&seed_distribution)?;
        return Ok(ranked);
    }
    Ok(HashMap::new())
}
/// Rescales `vector_scores` so that the values sum to one.
///
/// Returns an empty map when the total is not strictly positive (this includes
/// an empty input and a NaN total), since no meaningful distribution exists.
fn normalize_reset_probabilities(
    &self,
    vector_scores: &HashMap<EntityId, f64>,
) -> HashMap<EntityId, f64> {
    let total: f64 = vector_scores.values().sum();

    let mut normalized = HashMap::new();
    if total > 0.0 {
        for (entity_id, raw_score) in vector_scores {
            normalized.insert(entity_id.clone(), raw_score / total);
        }
    }
    normalized
}
pub fn get_search_statistics(&self) -> SearchStatistics {
SearchStatistics {
has_vector_index: self.vector_index.is_some(),
score_weights: self.score_weights.clone(),
max_results: self.max_results,
min_score_threshold: self.min_score_threshold,
}
}
/// Replaces the score-combination weights in place.
///
/// Non-consuming counterpart of the `with_score_weights` builder method.
pub fn update_score_weights(&mut self, weights: ScoreWeights) {
self.score_weights = weights;
}
/// Finds entities whose name contains `entity_name`, ignoring case, without
/// running PageRank.
///
/// An exact (case-insensitive) name match scores `1.0`; any other substring
/// match scores `0.8`. Hits are entity-only results (`pagerank_score` is `0.0`),
/// sorted by descending score and capped at the configured result limit.
pub fn quick_entity_search(
    &self,
    entity_name: &str,
    graph: &KnowledgeGraph,
) -> Vec<ScoredResult> {
    let needle = entity_name.to_lowercase();
    let mut matches = Vec::new();

    for entity in graph.entities() {
        let candidate = entity.name.to_lowercase();
        if !candidate.contains(&needle) {
            continue;
        }

        let score = if candidate == needle { 1.0 } else { 0.8 };
        matches.push(ScoredResult {
            entity_id: entity.id.clone(),
            chunk_id: ChunkId::new(format!("entity_{}", entity.id)),
            content: format!("{}: {}", entity.name, entity.entity_type),
            score,
            vector_score: score,
            pagerank_score: 0.0,
            combined_score: score,
        });
    }

    // Best match first; incomparable scores are treated as equal.
    matches.sort_by(|lhs, rhs| {
        rhs.score
            .partial_cmp(&lhs.score)
            .unwrap_or(std::cmp::Ordering::Equal)
    });
    matches.truncate(self.max_results);
    matches
}
/// Computes PageRank scores for `graph` with an empty reset distribution and
/// stores them on this system for later use.
///
/// # Errors
/// Propagates failures from building the PageRank calculator or from the score
/// computation; the stored scores are left untouched in that case.
pub fn precompute_global_pagerank(&mut self, graph: &KnowledgeGraph) -> Result<()> {
    println!("🚀 Pre-computing global PageRank scores...");

    let pagerank_calculator = graph.build_pagerank_calculator()?;
    let empty_reset = HashMap::new();
    let global_scores = pagerank_calculator.calculate_scores(&empty_reset)?;

    // Take the count before moving the map into `self`, so no `expect` is needed
    // to read it back out of the `Option`.
    let entity_count = global_scores.len();
    self.global_pagerank = Some(global_scores);

    println!(
        "✅ Global PageRank scores computed for {} entities",
        entity_count
    );
    Ok(())
}
/// Runs `search_with_pagerank` for every query in parallel and returns the hits
/// in the same order as `queries`.
///
/// An empty `queries` slice yields an empty result without any work.
///
/// # Errors
/// Returns the error of a failing query; no partial results are returned.
pub fn batch_search(
    &self,
    queries: &[&str],
    graph: &KnowledgeGraph,
    max_results_per_query: Option<usize>,
) -> Result<Vec<Vec<ScoredResult>>> {
    if queries.is_empty() {
        return Ok(Vec::new());
    }

    println!("🔍 Starting batch search for {} queries", queries.len());
    let per_query_hits = queries
        .par_iter()
        .map(|query| self.search_with_pagerank(query, graph, max_results_per_query))
        .collect::<Result<Vec<_>>>()?;
    println!("✅ Batch search completed");

    Ok(per_query_hits)
}
/// Variant of the PageRank search that boosts scores using per-query relationship
/// weights obtained from `graph.dynamic_weight`.
///
/// The query is reduced to concepts via `extract_query_concepts`; each relationship
/// then receives a dynamic weight computed from the relationship, the optional
/// `query_embedding` and those concepts. After a regular personalized PageRank run,
/// each relationship target receives an additive boost derived from its source's
/// PageRank score and the dynamic weight. Scores are merged with the configured
/// `ScoreWeights`, thresholded, expanded into one result per (entity, chunk) pair,
/// sorted by descending score and truncated to `max_results` (defaulting to the
/// system-wide limit).
///
/// Returns an empty list when the lexical search finds no candidates.
///
/// # Errors
/// Propagates failures from the lexical search, from building the PageRank
/// calculator, and from the PageRank computation.
pub fn search_with_dynamic_weights(
&self,
query: &str,
graph: &KnowledgeGraph,
query_embedding: Option<&[f32]>,
max_results: Option<usize>,
) -> Result<Vec<ScoredResult>> {
let max_results = max_results.unwrap_or(self.max_results);
#[cfg(feature = "tracing")]
tracing::debug!(
query = %query,
has_embedding = query_embedding.is_some(),
"Starting dynamic weight search"
);
let query_concepts = extract_query_concepts(query);
#[cfg(feature = "tracing")]
tracing::debug!(
concepts = ?query_concepts,
"Extracted query concepts"
);
let vector_scores = self.vector_similarity_search(query, graph)?;
if vector_scores.is_empty() {
return Ok(Vec::new());
}
// Dynamic weight per (source, target) pair. Because the map is keyed by the pair,
// a later relationship between the same two entities overwrites an earlier one.
let mut weighted_edges: HashMap<(EntityId, EntityId), f32> = HashMap::new();
for rel in graph.get_all_relationships() {
let dynamic_weight = graph.dynamic_weight(rel, query_embedding, &query_concepts);
weighted_edges.insert((rel.source.clone(), rel.target.clone()), dynamic_weight);
#[cfg(feature = "tracing")]
if dynamic_weight > rel.confidence * 1.5 {
tracing::trace!(
source = %rel.source.0,
target = %rel.target.0,
original_weight = rel.confidence,
dynamic_weight = dynamic_weight,
"Applied significant boost"
);
}
}
// The calculator is built from the graph alone; `weighted_edges` is not passed to
// it, so the dynamic weights only act through the post-hoc boost below.
let pagerank_calculator = graph.build_pagerank_calculator()?;
let pagerank_scores =
self.compute_personalized_pagerank(&vector_scores, &pagerank_calculator)?;
// Boost = source PageRank * max(weight - 1, 0) * 0.1, so only edges whose dynamic
// weight exceeds 1.0 change anything. Source scores are read from the unboosted
// map, which keeps the result independent of edge iteration order.
let mut boosted_scores = pagerank_scores.clone();
for ((source_id, target_id), weight) in &weighted_edges {
if let Some(target_score) = boosted_scores.get_mut(target_id) {
if let Some(&source_pr) = pagerank_scores.get(source_id) {
let boost = source_pr * (*weight as f64 - 1.0).max(0.0) * 0.1;
*target_score += boost;
}
}
}
let mut multi_scores = MultiModalScores::new();
multi_scores.vector_scores = vector_scores;
multi_scores.pagerank_scores = boosted_scores;
let combined_scores = multi_scores.combine_scores(&self.score_weights);
let mut scored_results = Vec::new();
for (entity_id, combined_score) in combined_scores {
if combined_score < self.min_score_threshold {
continue;
}
// One result per chunk that lists the entity.
// NOTE(review): an entity that appears in no chunk yields no result here; there
// is no entity-only fallback in this method — confirm that this is intended.
for chunk in graph.chunks() {
if chunk.entities.contains(&entity_id) {
let vector_score = multi_scores.vector_scores.get(&entity_id).unwrap_or(&0.0);
let pagerank_score =
multi_scores.pagerank_scores.get(&entity_id).unwrap_or(&0.0);
let result = ScoredResult {
entity_id: entity_id.clone(),
chunk_id: chunk.id.clone(),
content: chunk.content.clone(),
score: combined_score,
vector_score: *vector_score,
pagerank_score: *pagerank_score,
combined_score,
};
scored_results.push(result);
}
}
}
// Highest score first; incomparable scores (NaN) are treated as equal.
scored_results.sort_by(|a, b| {
b.score
.partial_cmp(&a.score)
.unwrap_or(std::cmp::Ordering::Equal)
});
scored_results.truncate(max_results);
#[cfg(feature = "tracing")]
tracing::info!(
results_count = scored_results.len(),
"Dynamic weight search completed"
);
Ok(scored_results)
}
}
/// Extracts the content-bearing words ("concepts") from a natural-language query.
///
/// The query is lowercased and split on whitespace. Each word has leading and
/// trailing non-alphanumeric characters removed, and is kept only if the cleaned
/// word is longer than three bytes and is not a stopword.
///
/// Punctuation is stripped *before* the length and stopword checks, so tokens
/// such as `"what?"` or `"the,"` are correctly rejected.
fn extract_query_concepts(query: &str) -> Vec<String> {
    const STOPWORDS: &[&str] = &[
        "what", "when", "where", "who", "which", "how", "does", "the", "a", "an", "is", "are",
        "was", "were", "been", "be",
    ];

    query
        .to_lowercase()
        .split_whitespace()
        .map(|word| word.trim_matches(|c: char| !c.is_alphanumeric()))
        // `len()` is a byte length, matching the original threshold semantics.
        .filter(|word| word.len() > 3 && !STOPWORDS.contains(word))
        .map(|word| word.to_string())
        .collect()
}
/// `Retriever` adapter for the PageRank retrieval system.
///
/// The trait methods carry no graph parameter, so they search the graph installed
/// with `set_graph` and fail with a `GraphRAGError::Retrieval` when none is set.
impl Retriever for PageRankRetrievalSystem {
    type Query = String;
    type Result = ScoredResult;
    type Error = GraphRAGError;

    /// Searches the stored graph for `query`, returning at most `k` hits.
    ///
    /// # Errors
    /// Fails when no graph has been set, or when the underlying PageRank search fails.
    fn search(&self, query: String, k: usize) -> Result<Vec<ScoredResult>> {
        match &self.graph {
            Some(graph) => self.search_with_pagerank(&query, graph, Some(k)),
            None => Err(GraphRAGError::Retrieval {
                message: "No KnowledgeGraph set. Call set_graph(Arc<KnowledgeGraph>) or use search_with_pagerank(query, &graph, ...)".to_string(),
            }),
        }
    }

    /// Searches the stored graph for `query` extended with `context`, returning at
    /// most `k` hits.
    ///
    /// A blank `context` leaves the query untouched, so no stray trailing space ends
    /// up in the phrase match.
    ///
    /// # Errors
    /// Fails when no graph has been set, or when the underlying PageRank search fails.
    fn search_with_context(
        &self,
        query: String,
        context: &str,
        k: usize,
    ) -> Result<Vec<ScoredResult>> {
        let enhanced_query = if context.trim().is_empty() {
            query
        } else {
            format!("{query} {context}")
        };

        match &self.graph {
            Some(graph) => self.search_with_pagerank(&enhanced_query, graph, Some(k)),
            None => Err(GraphRAGError::Retrieval {
                message: "No KnowledgeGraph set. Call set_graph(Arc<KnowledgeGraph>) or use search_with_pagerank(query, &graph, ...)".to_string(),
            }),
        }
    }

    /// Invalidates both caches; the supplied content itself is not inspected.
    fn update(&mut self, _content: Vec<String>) -> Result<()> {
        // Each write guard is a temporary released at the end of its statement,
        // so the two locks are never held at the same time.
        self.query_cache.write().clear();
        self.entity_rank_cache.write().clear();
        Ok(())
    }
}
/// Snapshot of a retrieval system's configuration, as produced by
/// `PageRankRetrievalSystem::get_search_statistics`.
#[derive(Debug, Clone)]
pub struct SearchStatistics {
/// Whether a vector index had been created when the snapshot was taken.
pub has_vector_index: bool,
/// Copy of the score-combination weights.
pub score_weights: ScoreWeights,
/// Default cap on the number of results per query.
pub max_results: usize,
/// Minimum combined score a result needs to be returned.
pub min_score_threshold: f64,
}
impl SearchStatistics {
    /// Writes a human-readable summary of the statistics to standard output.
    pub fn print(&self) {
        let index_state = if self.has_vector_index {
            "Available"
        } else {
            "Not initialized"
        };
        let weights = &self.score_weights;

        println!("🔍 PageRank Retrieval Statistics");
        println!(" Vector index: {}", index_state);
        println!(" Score weights:");
        println!(" Vector: {:.2}", weights.vector_weight);
        println!(" PageRank: {:.2}", weights.pagerank_weight);
        println!(" Chunk: {:.2}", weights.chunk_weight);
        println!(" Relationship: {:.2}", weights.relationship_weight);
        println!(" Max results: {}", self.max_results);
        println!(" Min threshold: {:.3}", self.min_score_threshold);
    }
}
#[cfg(test)]
mod tests {
use super::*;
use crate::core::{
ChunkId, DocumentId, Entity, EntityId, KnowledgeGraph, Relationship, TextChunk,
};
/// Builds a small fixture graph: three entities (Apple Inc, iPhone, Steve Jobs),
/// one `PRODUCES` relationship from Apple Inc to iPhone, and one chunk that
/// lists both of those entities.
fn create_test_graph() -> KnowledgeGraph {
    let mut graph = KnowledgeGraph::new();

    // (id, name, type, confidence)
    let entity_specs = [
        ("entity1", "Apple Inc", "ORGANIZATION", 0.9),
        ("entity2", "iPhone", "PRODUCT", 0.8),
        ("entity3", "Steve Jobs", "PERSON", 0.9),
    ];
    for (id, name, entity_type, confidence) in entity_specs {
        let entity = Entity::new(
            EntityId::new(id.to_string()),
            name.to_string(),
            entity_type.to_string(),
            confidence,
        );
        graph.add_entity(entity).unwrap();
    }

    let produces = Relationship {
        source: EntityId::new("entity1".to_string()),
        target: EntityId::new("entity2".to_string()),
        relation_type: "PRODUCES".to_string(),
        confidence: 0.8,
        context: vec![],
        embedding: None,
        temporal_type: None,
        temporal_range: None,
        causal_strength: None,
    };
    graph.add_relationship(produces).unwrap();

    let mentioned_entities = vec![
        EntityId::new("entity1".to_string()),
        EntityId::new("entity2".to_string()),
    ];
    let chunk = TextChunk::new(
        ChunkId::new("chunk1".to_string()),
        DocumentId::new("doc1".to_string()),
        "Apple Inc is a technology company that produces the iPhone.".to_string(),
        0,
        56,
    )
    .with_entities(mentioned_entities);
    graph.add_chunk(chunk).unwrap();

    graph
}
/// A freshly built system reports its result cap and has no vector index yet.
#[test]
fn test_pagerank_retrieval_system_creation() {
    let stats = PageRankRetrievalSystem::new(10).get_search_statistics();

    assert_eq!(stats.max_results, 10);
    assert!(!stats.has_vector_index);
}
/// The lexical search scores the Apple entity for a query containing "Apple".
#[test]
fn test_vector_similarity_search() {
    let graph = create_test_graph();
    let apple_id = EntityId::new("entity1".to_string());

    let scores = PageRankRetrievalSystem::new(10)
        .vector_similarity_search("Apple technology", &graph)
        .unwrap();

    assert!(!scores.is_empty());
    assert!(scores.contains_key(&apple_id));
}
/// A name-substring lookup returns the Apple entity first with a high score.
#[test]
fn test_quick_entity_search() {
    let graph = create_test_graph();
    let hits = PageRankRetrievalSystem::new(10).quick_entity_search("Apple", &graph);

    assert!(!hits.is_empty());
    let top_hit = &hits[0];
    assert_eq!(top_hit.entity_id, EntityId::new("entity1".to_string()));
    assert!(top_hit.score > 0.7);
}
/// The full PageRank search returns hits for a matching query, and every hit
/// carries sane component scores.
#[test]
fn test_search_with_pagerank() {
    let graph = create_test_graph();
    let mut retrieval = PageRankRetrievalSystem::new(10);
    retrieval.initialize_vector_index(&graph).unwrap();

    let results = retrieval
        .search_with_pagerank("Apple iPhone", &graph, None)
        .unwrap();

    // Without this assertion the checks below would be skipped silently whenever
    // the search returned nothing, letting a broken search pass the test.
    assert!(
        !results.is_empty(),
        "expected at least one result for a query matching fixture entities"
    );
    for result in &results {
        assert!(result.vector_score >= 0.0);
        assert!(result.pagerank_score >= 0.0);
        assert!(result.combined_score > 0.0);
    }
}
/// Pre-computed global PageRank scores exist, are non-empty and sum to about one.
#[test]
fn test_precompute_global_pagerank() {
    let graph = create_test_graph();
    let mut retrieval = PageRankRetrievalSystem::new(10);

    retrieval.precompute_global_pagerank(&graph).unwrap();
    assert!(retrieval.global_pagerank.is_some());

    let stored_scores = retrieval.global_pagerank.as_ref().unwrap();
    assert!(!stored_scores.is_empty());

    let score_sum: f64 = stored_scores.values().sum();
    assert!((score_sum - 1.0).abs() < 0.01);
}
/// A batch search returns one result list per query, each within the per-query cap.
#[test]
fn test_batch_search() {
    let graph = create_test_graph();
    let mut retrieval = PageRankRetrievalSystem::new(5);
    retrieval.initialize_vector_index(&graph).unwrap();

    let queries = ["Apple", "iPhone", "Steve Jobs"];
    let per_query_hits = retrieval.batch_search(&queries, &graph, Some(3)).unwrap();

    assert_eq!(per_query_hits.len(), 3);
    for hits in &per_query_hits {
        assert!(hits.len() <= 3);
    }
}
/// The search produces results whether the PageRank configuration enables or
/// disables parallel execution.
#[test]
fn test_pagerank_config_performance() {
    let graph = create_test_graph();

    // Builds a system with the given parallelism setting and runs the shared query.
    let run_search = |parallel_enabled: bool| {
        let config = PageRankConfig {
            parallel_enabled,
            cache_size: 100,
            ..PageRankConfig::default()
        };
        let mut retrieval = PageRankRetrievalSystem::new(5).with_pagerank_config(config);
        retrieval.initialize_vector_index(&graph).unwrap();
        retrieval
            .search_with_pagerank("Apple iPhone", &graph, None)
            .unwrap()
    };

    let parallel_results = run_search(true);
    let sequential_results = run_search(false);

    assert!(!parallel_results.is_empty());
    assert!(!sequential_results.is_empty());
}
/// Repeating the same query on the same system yields a result list of the same
/// size.
///
/// The earlier version only asserted that each call took a non-zero amount of
/// time, which depends on clock resolution and says nothing about the search.
#[test]
fn test_cache_effectiveness() {
    let graph = create_test_graph();
    let mut retrieval = PageRankRetrievalSystem::new(5).with_pagerank_config(PageRankConfig {
        cache_size: 1000,
        ..PageRankConfig::default()
    });
    retrieval.initialize_vector_index(&graph).unwrap();

    let query = "Apple";
    let first_results = retrieval.search_with_pagerank(query, &graph, None).unwrap();
    let second_results = retrieval.search_with_pagerank(query, &graph, None).unwrap();

    assert_eq!(
        first_results.len(),
        second_results.len(),
        "repeated identical queries must return the same number of results"
    );
}
/// With incremental mode on, searching works and a subsequent `update` succeeds.
#[test]
fn test_incremental_mode() {
    let graph = create_test_graph();
    let mut retrieval = PageRankRetrievalSystem::new(5).with_incremental_mode(true);
    retrieval.initialize_vector_index(&graph).unwrap();

    let hits = retrieval
        .search_with_pagerank("Apple", &graph, None)
        .unwrap();
    assert!(!hits.is_empty());

    let new_content = vec!["new content".to_string()];
    assert!(retrieval.update(new_content).is_ok());
}
}