use crate::core::{EntityId, Result};
use lru::LruCache;
use nalgebra::{DMatrix, DVector};
use parking_lot::RwLock;
use rayon::prelude::*;
use sprs::CsMat;
use std::collections::HashMap;
use std::num::NonZeroUsize;
use std::sync::Arc;
/// Tuning parameters for [`PersonalizedPageRank`].
#[derive(Debug, Clone)]
pub struct PageRankConfig {
/// Weight `d` on following edges; teleportation to the reset distribution gets `1 - d`.
pub damping_factor: f64,
/// Upper bound on the number of power-iteration steps.
pub max_iterations: usize,
/// Iteration stops once the largest absolute per-node score change drops below this value.
pub tolerance: f64,
/// NOTE(review): not read anywhere in this file — confirm whether it is still needed.
pub personalized: bool,
/// Selects the rayon-parallel solver for graphs with at least `sparse_threshold` nodes;
/// also consulted by `PersonalizedPageRank::new` when precomputing matrices.
pub parallel_enabled: bool,
/// Capacity of the score cache; `0` is replaced by 1000 because the cache needs a non-zero size.
pub cache_size: usize,
/// Graphs with fewer nodes than this are solved with a dense copy of the adjacency matrix.
pub sparse_threshold: usize,
/// NOTE(review): not read anywhere in this file — confirm whether it is still needed.
pub incremental_updates: bool,
/// NOTE(review): not read anywhere in this file — confirm whether it is still needed.
pub simd_block_size: usize,
}
impl Default for PageRankConfig {
fn default() -> Self {
Self {
damping_factor: 0.85,
max_iterations: 100,
tolerance: 1e-6,
personalized: true,
parallel_enabled: true,
cache_size: 1000,
sparse_threshold: 1000,
incremental_updates: true,
simd_block_size: 32,
}
}
}
/// Personalized PageRank engine over a sparse weighted graph, with memoized results.
pub struct PersonalizedPageRank {
/// Solver settings.
config: PageRankConfig,
/// Weighted adjacency; out-degrees are computed as row sums, so row `i` is treated as
/// node `i`'s out-edges.
adjacency_matrix: CsMat<f64>,
/// Entity id -> matrix index.
node_mapping: HashMap<EntityId, usize>,
/// Matrix index -> entity id, used when turning score vectors back into maps.
reverse_mapping: HashMap<usize, EntityId>,
/// Dense copy of `adjacency_matrix`, built only for graphs below `config.sparse_threshold` nodes.
dense_matrix: Option<DMatrix<f64>>,
/// LRU cache of computed score maps keyed by a hash of the reset probabilities; wrapped in a
/// lock so that `&self` methods can insert into it.
score_cache: Arc<RwLock<LruCache<u64, HashMap<EntityId, f64>>>>,
/// Adjacency with each row divided by that node's out-degree; only built by `new` for some
/// configurations.
transition_matrix: Option<CsMat<f64>>,
/// Weighted out-degree (row sum of `adjacency_matrix`) of every node.
out_degrees: Vec<f64>,
}
impl PersonalizedPageRank {
pub fn new(
config: PageRankConfig,
adjacency_matrix: CsMat<f64>,
node_mapping: HashMap<EntityId, usize>,
reverse_mapping: HashMap<usize, EntityId>,
) -> Self {
let n = adjacency_matrix.rows();
let cache_size = NonZeroUsize::new(config.cache_size)
.unwrap_or(NonZeroUsize::new(1000).expect("non-zero literal"));
let out_degrees = Self::compute_out_degrees(&adjacency_matrix);
let dense_matrix = if n < config.sparse_threshold {
Some(Self::convert_to_dense(&adjacency_matrix))
} else {
None
};
let transition_matrix = if config.parallel_enabled && n > 100 {
Some(Self::build_transition_matrix(
&adjacency_matrix,
&out_degrees,
))
} else {
None
};
Self {
config,
adjacency_matrix,
node_mapping,
reverse_mapping,
dense_matrix,
score_cache: Arc::new(RwLock::new(LruCache::new(cache_size))),
transition_matrix,
out_degrees,
}
}
fn compute_out_degrees(matrix: &CsMat<f64>) -> Vec<f64> {
let n = matrix.rows();
let mut degrees = vec![0.0; n];
for (i, degree) in degrees.iter_mut().enumerate().take(n) {
if let Some(row) = matrix.outer_view(i) {
*degree = row.iter().map(|(_, &weight)| weight).sum();
}
}
degrees
}
fn convert_to_dense(sparse_matrix: &CsMat<f64>) -> DMatrix<f64> {
let n = sparse_matrix.rows();
let m = sparse_matrix.cols();
let mut dense = DMatrix::zeros(n, m);
for i in 0..n {
if let Some(row) = sparse_matrix.outer_view(i) {
for (j, &value) in row.iter() {
dense[(i, j)] = value;
}
}
}
dense
}
fn build_transition_matrix(adjacency: &CsMat<f64>, out_degrees: &[f64]) -> CsMat<f64> {
let mut builder = sprs::TriMat::new((adjacency.rows(), adjacency.cols()));
for (i, °ree) in out_degrees.iter().enumerate().take(adjacency.rows()) {
if let Some(row) = adjacency.outer_view(i) {
if degree > 0.0 {
for (j, &weight) in row.iter() {
builder.add_triplet(i, j, weight / degree);
}
}
}
}
builder.to_csr()
}
fn generate_cache_key(reset_probabilities: &HashMap<EntityId, f64>) -> u64 {
use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher};
let mut hasher = DefaultHasher::new();
let mut sorted_entries: Vec<_> = reset_probabilities.iter().collect();
sorted_entries.sort_by_key(|(id, _)| id.to_string());
for (id, score) in sorted_entries {
id.to_string().hash(&mut hasher);
score.to_bits().hash(&mut hasher);
}
hasher.finish()
}
pub fn calculate_scores(
&self,
reset_probabilities: &HashMap<EntityId, f64>,
) -> Result<HashMap<EntityId, f64>> {
let n = self.adjacency_matrix.rows();
if n == 0 {
return Ok(HashMap::new());
}
let cache_key = Self::generate_cache_key(reset_probabilities);
{
let cache = self.score_cache.read();
if let Some(cached_scores) = cache.peek(&cache_key) {
return Ok(cached_scores.clone());
}
}
let scores = if n < self.config.sparse_threshold {
self.calculate_scores_dense(reset_probabilities)?
} else if self.config.parallel_enabled {
self.calculate_scores_parallel(reset_probabilities)?
} else {
self.calculate_scores_sparse_optimized(reset_probabilities)?
};
{
let mut cache = self.score_cache.write();
cache.put(cache_key, scores.clone());
}
Ok(scores)
}
fn calculate_scores_dense(
&self,
reset_probabilities: &HashMap<EntityId, f64>,
) -> Result<HashMap<EntityId, f64>> {
let n = self.adjacency_matrix.rows();
let reset_vector = self.build_reset_vector(reset_probabilities)?;
if let Some(dense_matrix) = &self.dense_matrix {
let mut scores = DVector::from_element(n, 1.0 / n as f64);
let reset_vec = DVector::from_vec(reset_vector);
for _iteration in 0..self.config.max_iterations {
let new_scores = &reset_vec * (1.0 - self.config.damping_factor)
+ dense_matrix * &scores * self.config.damping_factor;
let diff = (&new_scores - &scores).abs().max();
if diff < self.config.tolerance {
break;
}
scores = new_scores;
}
self.scores_to_entity_map(scores.as_slice())
} else {
self.calculate_scores_sparse_optimized(reset_probabilities)
}
}
fn calculate_scores_parallel(
&self,
reset_probabilities: &HashMap<EntityId, f64>,
) -> Result<HashMap<EntityId, f64>> {
let n = self.adjacency_matrix.rows();
let mut scores = vec![1.0 / n as f64; n];
let mut new_scores = vec![0.0; n];
let reset_vector = self.build_reset_vector(reset_probabilities)?;
for _iteration in 0..self.config.max_iterations {
self.pagerank_iteration_parallel(&scores, &mut new_scores, &reset_vector);
let diff = self.calculate_difference(&scores, &new_scores);
if diff < self.config.tolerance {
break;
}
std::mem::swap(&mut scores, &mut new_scores);
}
self.scores_to_entity_map(&scores)
}
fn calculate_scores_sparse_optimized(
&self,
reset_probabilities: &HashMap<EntityId, f64>,
) -> Result<HashMap<EntityId, f64>> {
let n = self.adjacency_matrix.rows();
let mut scores = vec![1.0 / n as f64; n];
let mut new_scores = vec![0.0; n];
let reset_vector = self.build_reset_vector(reset_probabilities)?;
if let Some(transition_matrix) = &self.transition_matrix {
for _iteration in 0..self.config.max_iterations {
self.pagerank_iteration_with_transition_matrix(
&scores,
&mut new_scores,
&reset_vector,
transition_matrix,
);
let diff = self.calculate_difference(&scores, &new_scores);
if diff < self.config.tolerance {
break;
}
std::mem::swap(&mut scores, &mut new_scores);
}
} else {
for _iteration in 0..self.config.max_iterations {
self.pagerank_iteration(&scores, &mut new_scores, &reset_vector);
let diff = self.calculate_difference(&scores, &new_scores);
if diff < self.config.tolerance {
break;
}
std::mem::swap(&mut scores, &mut new_scores);
}
}
self.scores_to_entity_map(&scores)
}
fn pagerank_iteration_parallel(
&self,
current_scores: &[f64],
new_scores: &mut [f64],
reset_vector: &[f64],
) {
let d = self.config.damping_factor;
let n = current_scores.len();
new_scores
.par_iter_mut()
.zip(reset_vector.par_iter())
.for_each(|(new_score, &reset_prob)| {
*new_score = (1.0 - d) * reset_prob;
});
let contributions: Vec<Vec<f64>> = (0..n)
.into_par_iter()
.map(|j| {
let mut local_contributions = vec![0.0; n];
let current_score = current_scores[j];
let out_degree = self.out_degrees[j];
if out_degree > 0.0 {
let score_contribution = d * current_score / out_degree;
if let Some(row) = self.adjacency_matrix.outer_view(j) {
for (neighbor_i, &weight) in row.iter() {
if neighbor_i < n {
local_contributions[neighbor_i] += score_contribution * weight;
}
}
}
} else {
let score_contribution = d * current_score / n as f64;
for contrib in &mut local_contributions {
*contrib += score_contribution;
}
}
local_contributions
})
.collect();
for contrib_vec in contributions {
for (i, contrib) in contrib_vec.iter().enumerate() {
new_scores[i] += contrib;
}
}
}
fn pagerank_iteration_with_transition_matrix(
&self,
current_scores: &[f64],
new_scores: &mut [f64],
reset_vector: &[f64],
transition_matrix: &CsMat<f64>,
) {
let d = self.config.damping_factor;
let n = current_scores.len();
for i in 0..n {
new_scores[i] = (1.0 - d) * reset_vector[i];
}
for (j, ¤t_score) in current_scores.iter().enumerate() {
if let Some(row) = transition_matrix.outer_view(j) {
for (neighbor_i, &transition_prob) in row.iter() {
if neighbor_i < n {
new_scores[neighbor_i] += d * transition_prob * current_score;
}
}
}
}
}
fn build_reset_vector(&self, reset_probabilities: &HashMap<EntityId, f64>) -> Result<Vec<f64>> {
let n = self.adjacency_matrix.rows();
let mut reset_vector = vec![1.0 / n as f64; n];
if !reset_probabilities.is_empty() {
let total: f64 = reset_probabilities.values().sum();
if total > 0.0 {
for (entity_id, &prob) in reset_probabilities {
if let Some(&index) = self.node_mapping.get(entity_id) {
if index < n {
reset_vector[index] = prob / total;
}
}
}
}
}
Ok(reset_vector)
}
fn pagerank_iteration(
&self,
current_scores: &[f64],
new_scores: &mut [f64],
reset_vector: &[f64],
) {
let d = self.config.damping_factor;
let n = current_scores.len();
for i in 0..n {
new_scores[i] = (1.0 - d) * reset_vector[i];
}
for (j, ¤t_score) in current_scores.iter().enumerate() {
let out_degree = self.get_out_degree(j);
if out_degree > 0 {
let score_contribution = d * current_score / out_degree as f64;
if let Some(row) = self.adjacency_matrix.outer_view(j) {
for (neighbor_i, &weight) in row.iter() {
if neighbor_i < n {
new_scores[neighbor_i] += score_contribution * weight;
}
}
}
} else {
let score_contribution = d * current_score / n as f64;
for score in new_scores.iter_mut() {
*score += score_contribution;
}
}
}
}
fn get_out_degree(&self, node_index: usize) -> usize {
if let Some(row) = self.adjacency_matrix.outer_view(node_index) {
row.nnz()
} else {
0
}
}
fn calculate_difference(&self, scores1: &[f64], scores2: &[f64]) -> f64 {
scores1
.iter()
.zip(scores2.iter())
.map(|(&a, &b)| (a - b).abs())
.fold(0.0f64, f64::max)
}
fn scores_to_entity_map(&self, scores: &[f64]) -> Result<HashMap<EntityId, f64>> {
let total: f64 = scores.iter().sum();
let inv = if total > 0.0 { 1.0 / total } else { 1.0 };
let mut result = HashMap::new();
for (index, &score) in scores.iter().enumerate() {
if let Some(entity_id) = self.reverse_mapping.get(&index) {
result.insert(entity_id.clone(), score * inv);
}
}
Ok(result)
}
pub fn node_count(&self) -> usize {
self.adjacency_matrix.rows()
}
pub fn config(&self) -> &PageRankConfig {
&self.config
}
}
/// Score maps from several retrieval signals; `combine_scores` blends them per entity.
#[derive(Debug, Clone)]
pub struct MultiModalScores {
/// Per-entity vector score (NOTE(review): presumably embedding similarity — confirm).
pub vector_scores: HashMap<EntityId, f64>,
/// Per-entity PageRank score.
pub pagerank_scores: HashMap<EntityId, f64>,
/// Per-chunk score; not currently consulted by `combine_scores`.
pub chunk_scores: HashMap<crate::core::ChunkId, f64>,
/// Scores keyed by a relationship string; not currently consulted by `combine_scores`.
pub relationship_scores: HashMap<String, f64>,
}
/// Linear blending weights used by `MultiModalScores::combine_scores`.
#[derive(Debug, Clone)]
pub struct ScoreWeights {
/// Multiplier for the vector score.
pub vector_weight: f64,
/// Multiplier for the PageRank score.
pub pagerank_weight: f64,
/// Multiplier for the per-entity chunk score.
pub chunk_weight: f64,
/// NOTE(review): not applied by `combine_scores` in this file — confirm intended use.
pub relationship_weight: f64,
}
impl Default for ScoreWeights {
fn default() -> Self {
Self {
vector_weight: 0.3,
pagerank_weight: 0.4,
chunk_weight: 0.2,
relationship_weight: 0.1,
}
}
}
impl MultiModalScores {
pub fn new() -> Self {
Self {
vector_scores: HashMap::new(),
pagerank_scores: HashMap::new(),
chunk_scores: HashMap::new(),
relationship_scores: HashMap::new(),
}
}
pub fn combine_scores(&self, weights: &ScoreWeights) -> HashMap<EntityId, f64> {
use std::collections::HashSet;
let mut combined_scores = HashMap::new();
let all_entities: HashSet<EntityId> = self
.vector_scores
.keys()
.chain(self.pagerank_scores.keys())
.cloned()
.collect();
for entity_id in all_entities {
let vector_score = self.vector_scores.get(&entity_id).unwrap_or(&0.0);
let pagerank_score = self.pagerank_scores.get(&entity_id).unwrap_or(&0.0);
let chunk_score = self.get_entity_chunk_score(&entity_id);
let combined = weights.vector_weight * vector_score
+ weights.pagerank_weight * pagerank_score
+ weights.chunk_weight * chunk_score;
combined_scores.insert(entity_id, combined);
}
combined_scores
}
fn get_entity_chunk_score(&self, _entity_id: &EntityId) -> f64 {
0.0
}
}
impl Default for MultiModalScores {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::EntityId;

    /// Three-node graph with edges A -> B, A -> C and B -> C (C has no out-edges).
    fn create_simple_test_graph() -> (
        CsMat<f64>,
        HashMap<EntityId, usize>,
        HashMap<usize, EntityId>,
    ) {
        let entity_a = EntityId::new("A".to_string());
        let entity_b = EntityId::new("B".to_string());
        let entity_c = EntityId::new("C".to_string());
        let mut node_mapping = HashMap::new();
        let mut reverse_mapping = HashMap::new();
        node_mapping.insert(entity_a.clone(), 0);
        node_mapping.insert(entity_b.clone(), 1);
        node_mapping.insert(entity_c.clone(), 2);
        reverse_mapping.insert(0, entity_a);
        reverse_mapping.insert(1, entity_b);
        reverse_mapping.insert(2, entity_c);
        let mut triplet_mat = sprs::TriMat::new((3, 3));
        triplet_mat.add_triplet(0, 1, 1.0);
        triplet_mat.add_triplet(0, 2, 1.0);
        triplet_mat.add_triplet(1, 2, 1.0);
        let matrix = triplet_mat.to_csr();
        (matrix, node_mapping, reverse_mapping)
    }

    #[test]
    fn test_pagerank_convergence() {
        let (matrix, node_mapping, reverse_mapping) = create_simple_test_graph();
        let config = PageRankConfig::default();
        let pagerank = PersonalizedPageRank::new(config, matrix, node_mapping, reverse_mapping);
        let reset_probs = HashMap::new();
        let scores = pagerank.calculate_scores(&reset_probs).unwrap();
        let total_score: f64 = scores.values().sum();
        assert!((total_score - 1.0).abs() < 1e-6);
        assert_eq!(scores.len(), 3);
    }

    #[test]
    fn test_personalized_pagerank() {
        let (matrix, node_mapping, reverse_mapping) = create_simple_test_graph();
        let config = PageRankConfig::default();
        let pagerank = PersonalizedPageRank::new(config, matrix, node_mapping, reverse_mapping);
        let entity_a = EntityId::new("A".to_string());
        let entity_b = EntityId::new("B".to_string());
        let uniform = pagerank.calculate_scores(&HashMap::new()).unwrap();
        let mut reset_probs = HashMap::new();
        reset_probs.insert(entity_a.clone(), 0.8);
        reset_probs.insert(entity_b, 0.2);
        let personalized = pagerank.calculate_scores(&reset_probs).unwrap();
        // Biasing teleportation towards A must raise A's score relative to plain PageRank.
        // A has no in-links, so an absolute threshold would depend on the graph shape.
        assert!(personalized[&entity_a] > uniform[&entity_a]);
        let total: f64 = personalized.values().sum();
        assert!((total - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_multimodal_scores_combination() {
        let mut multi_scores = MultiModalScores::new();
        let entity_a = EntityId::new("A".to_string());
        let entity_b = EntityId::new("B".to_string());
        multi_scores.vector_scores.insert(entity_a.clone(), 0.8);
        multi_scores.vector_scores.insert(entity_b.clone(), 0.4);
        multi_scores.pagerank_scores.insert(entity_a.clone(), 0.6);
        multi_scores.pagerank_scores.insert(entity_b.clone(), 0.9);
        let weights = ScoreWeights::default();
        let combined = multi_scores.combine_scores(&weights);
        assert!(combined.contains_key(&entity_a));
        assert!(combined.contains_key(&entity_b));
        let score_a = combined.get(&entity_a).unwrap();
        let score_b = combined.get(&entity_b).unwrap();
        assert!(*score_a > 0.0);
        assert!(*score_b > 0.0);
    }
}