use super::concept_graph::ConceptGraph;
use petgraph::algo::page_rank;
use petgraph::prelude::*;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
/// Tunable parameters that control how `ConceptRanker` scores and selects concepts.
///
/// The three `*_weight` fields are the multipliers applied to the degree,
/// PageRank and IDF components when they are summed into a combined score.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ConceptSelectionConfig {
/// Maximum number of concept names returned by `ConceptRanker::get_top_k`.
pub top_k: usize,
/// Score threshold: `ConceptRanker` keeps or counts a concept only when its
/// combined score is at least this value.
pub min_score: f32,
/// Multiplier applied to the degree-centrality component of the combined score.
pub degree_weight: f32,
/// Multiplier applied to the PageRank component of the combined score.
pub pagerank_weight: f32,
/// Multiplier applied to the IDF component of the combined score.
pub idf_weight: f32,
/// Damping factor handed to petgraph's `page_rank`.
pub pagerank_damping: f64,
/// Tolerance associated with the PageRank computation.
/// NOTE(review): petgraph's `page_rank` takes no tolerance argument; confirm
/// how the ranker is expected to apply this value.
pub pagerank_tolerance: f64,
/// NOTE(review): not consumed by `ConceptRanker` in this file; presumably read
/// by another part of the pipeline. Confirm before relying on it.
pub use_semantic_matching: bool,
}
impl Default for ConceptSelectionConfig {
fn default() -> Self {
Self {
top_k: 20,
min_score: 0.1,
degree_weight: 0.4,
pagerank_weight: 0.4,
idf_weight: 0.2,
pagerank_damping: 0.85,
pagerank_tolerance: 1e-6,
use_semantic_matching: true,
}
}
}
/// A concept together with its combined ranking score and the individual
/// components that score was built from.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RankedConcept {
/// Concept text, copied from the key under which the graph stores the concept.
pub concept: String,
/// Weighted sum of the degree, PageRank and IDF components.
pub score: f32,
/// Degree-centrality component, before weighting.
pub degree_score: f32,
/// PageRank component, before weighting.
pub pagerank_score: f32,
/// Inverse-document-frequency component, before weighting.
pub idf_score: f32,
/// Number of document ids recorded for the concept in the graph.
pub document_frequency: usize,
/// The concept's `frequency` value as stored in the graph.
pub total_frequency: usize,
}
/// Scores the concepts of a `ConceptGraph` and selects the most relevant ones.
pub struct ConceptRanker {
// Weights, thresholds and PageRank parameters used by every method.
config: ConceptSelectionConfig,
}
impl ConceptRanker {
    /// Hard cap on the number of PageRank iterations, whatever the configured
    /// damping and tolerance ask for.
    const MAX_PAGERANK_ITERATIONS: usize = 100;

    /// Creates a ranker that scores concepts according to `config`.
    pub fn new(config: ConceptSelectionConfig) -> Self {
        Self { config }
    }

    /// Scores every concept in `graph` and returns the ones that reach
    /// `config.min_score`, best first.
    ///
    /// The combined score is
    /// `degree_weight * degree + pagerank_weight * pagerank + idf_weight * idf`.
    /// A component that is missing for a concept counts as `0.0`.
    ///
    /// `total_documents` is the corpus size used for the IDF component.
    ///
    /// Concepts with equal scores are ordered by concept text so the result
    /// does not depend on the iteration order of the graph's concept map.
    ///
    /// # Panics
    ///
    /// petgraph's `page_rank` asserts that the damping factor lies in
    /// `[0, 1]`, so an out-of-range `config.pagerank_damping` panics when the
    /// graph has at least one node.
    pub fn rank_concepts(
        &self,
        graph: &ConceptGraph,
        total_documents: usize,
    ) -> Vec<RankedConcept> {
        let degree_scores = self.calculate_degree_centrality(graph);
        let pagerank_scores = self.calculate_pagerank(graph);
        let idf_scores = self.calculate_idf(graph, total_documents);

        let mut ranked_concepts = Vec::new();
        for (concept_text, concept_data) in &graph.concepts {
            let degree_score = degree_scores.get(concept_text).copied().unwrap_or(0.0);
            let pagerank_score = pagerank_scores.get(concept_text).copied().unwrap_or(0.0);
            let idf_score = idf_scores.get(concept_text).copied().unwrap_or(0.0);

            let combined_score = (self.config.degree_weight * degree_score)
                + (self.config.pagerank_weight * pagerank_score)
                + (self.config.idf_weight * idf_score);

            // Same rule as `filter_by_threshold` (`score >= min_score`): a NaN
            // score never satisfies the threshold, so it is dropped here too.
            if combined_score.is_nan() || combined_score < self.config.min_score {
                continue;
            }

            ranked_concepts.push(RankedConcept {
                concept: concept_text.clone(),
                score: combined_score,
                degree_score,
                pagerank_score,
                idf_score,
                document_frequency: concept_data.document_ids.len(),
                total_frequency: concept_data.frequency,
            });
        }

        // `total_cmp` is a genuine total order (unlike `partial_cmp` with an
        // `Equal` fallback), and the text tie-break makes the order deterministic.
        ranked_concepts.sort_by(|a, b| {
            b.score
                .total_cmp(&a.score)
                .then_with(|| a.concept.cmp(&b.concept))
        });
        ranked_concepts
    }

    /// Computes a degree score in `[0, 1]` for every concept mapped to a node.
    ///
    /// The score is `(in-degree + out-degree) / (node_count - 1)`, capped at
    /// `1.0`. Returns an empty map when the graph has at most one node.
    fn calculate_degree_centrality(&self, graph: &ConceptGraph) -> HashMap<String, f32> {
        let total_nodes = graph.graph.node_count();
        if total_nodes <= 1 {
            return HashMap::new();
        }
        let max_possible_degree = (total_nodes - 1) as f32;

        let mut scores = HashMap::new();
        for (concept_text, &node_idx) in &graph.concept_to_node {
            let in_degree = graph
                .graph
                .neighbors_directed(node_idx, Direction::Incoming)
                .count();
            let out_degree = graph
                .graph
                .neighbors_directed(node_idx, Direction::Outgoing)
                .count();
            let total_degree = (in_degree + out_degree) as f32;

            // A node with both incoming and outgoing edges (or parallel edges)
            // can reach a total degree above `node_count - 1`; cap the ratio so
            // the component stays within the unit range the combined score assumes.
            let normalized_degree = (total_degree / max_possible_degree).min(1.0);
            scores.insert(concept_text.clone(), normalized_degree);
        }
        scores
    }

    /// Number of PageRank iterations to run.
    ///
    /// Uses the standard estimate that the power-iteration error shrinks by
    /// roughly a factor of `damping` per round, i.e. the smallest `n` with
    /// `damping^n <= tolerance`, clamped to `1..=MAX_PAGERANK_ITERATIONS`.
    /// Falls back to the cap when either setting lies outside `(0, 1)`, where
    /// the estimate is meaningless.
    fn pagerank_iterations(&self) -> usize {
        let damping = self.config.pagerank_damping;
        let tolerance = self.config.pagerank_tolerance;
        let in_open_unit_interval = |x: f64| x > 0.0 && x < 1.0;

        if !(in_open_unit_interval(damping) && in_open_unit_interval(tolerance)) {
            return Self::MAX_PAGERANK_ITERATIONS;
        }
        // Both logarithms are negative here, so the ratio is positive and finite.
        let rounds = (tolerance.ln() / damping.ln()).ceil();
        (rounds as usize).clamp(1, Self::MAX_PAGERANK_ITERATIONS)
    }

    /// Computes PageRank for every concept mapped to a node, rescaled so the
    /// highest-ranked node scores `1.0`.
    ///
    /// When the ranks cannot be normalised (largest rank is zero or not
    /// finite) every concept gets `0.0` instead of NaN or infinity.
    fn calculate_pagerank(&self, graph: &ConceptGraph) -> HashMap<String, f32> {
        let mut scores = HashMap::new();

        // petgraph's `page_rank` takes an iteration count (`usize`) as its third
        // argument, not an `Option`.
        let pr_scores = page_rank(
            &graph.graph,
            self.config.pagerank_damping,
            self.pagerank_iterations(),
        );

        let max_score = pr_scores.iter().copied().fold(0.0_f64, f64::max);
        let can_normalize = max_score.is_finite() && max_score > 0.0;

        for (concept_text, &node_idx) in &graph.concept_to_node {
            // The rank vector is addressed by the node's index.
            if let Some(&pr_score) = pr_scores.get(node_idx.index()) {
                let normalized_score = if can_normalize && pr_score.is_finite() {
                    (pr_score / max_score) as f32
                } else {
                    0.0
                };
                scores.insert(concept_text.clone(), normalized_score);
            }
        }
        scores
    }

    /// Computes a squashed inverse-document-frequency score in `[0, 1)` for
    /// every concept.
    ///
    /// The score is `tanh(ln(total_documents / document_frequency) / 5)`,
    /// floored at `0.0`. Concepts without any document get `0.0`. Returns an
    /// empty map when `total_documents` is zero.
    fn calculate_idf(&self, graph: &ConceptGraph, total_documents: usize) -> HashMap<String, f32> {
        let mut scores = HashMap::new();
        if total_documents == 0 {
            return scores;
        }
        let total_docs = total_documents as f32;

        for (concept_text, concept_data) in &graph.concepts {
            let doc_freq = concept_data.document_ids.len();
            let normalized_idf = if doc_freq == 0 {
                0.0
            } else {
                let idf = (total_docs / doc_freq as f32).ln();
                // If the caller's `total_documents` is smaller than the concept's
                // document count the logarithm is negative; floor at zero so the
                // component never subtracts from the combined score.
                (idf / 5.0).tanh().max(0.0)
            };
            scores.insert(concept_text.clone(), normalized_idf);
        }
        scores
    }

    /// Returns the concept text of the first `config.top_k` entries of
    /// `ranked_concepts`, in the order given.
    ///
    /// The slice is not re-sorted; pass the output of
    /// [`rank_concepts`](Self::rank_concepts) to get the best-scoring concepts.
    pub fn get_top_k(&self, ranked_concepts: &[RankedConcept]) -> Vec<String> {
        ranked_concepts
            .iter()
            .take(self.config.top_k)
            .map(|rc| rc.concept.clone())
            .collect()
    }

    /// Returns clones of the entries whose score is at least
    /// `config.min_score`, preserving their order.
    pub fn filter_by_threshold(&self, ranked_concepts: &[RankedConcept]) -> Vec<RankedConcept> {
        ranked_concepts
            .iter()
            .filter(|rc| rc.score >= self.config.min_score)
            .cloned()
            .collect()
    }
}
/// Summary statistics over a list of ranked concepts, as produced by
/// `ConceptRanker::calculate_stats`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ConceptRankingStats {
/// Number of concepts the statistics were computed over.
pub total_concepts: usize,
/// Number of those concepts whose score is at least the ranker's configured `min_score`.
pub concepts_above_threshold: usize,
/// Arithmetic mean of the scores (`0.0` when there were no concepts).
pub average_score: f32,
/// Highest score seen (`0.0` when there were no concepts).
pub max_score: f32,
/// Lowest score seen (`0.0` when there were no concepts). This is an observed
/// value, not the configured threshold of the same name.
pub min_score: f32,
/// Histogram over ten equal-width bins spanning `min_score..=max_score`; each
/// entry is `(lower bound of the bin, number of concepts in the bin)`.
/// Empty when there were no concepts.
pub score_distribution: Vec<(f32, usize)>,
}
impl ConceptRanker {
    /// Summarises the scores of `ranked_concepts`.
    ///
    /// Reports the count, how many scores reach `config.min_score`, the mean,
    /// the extremes, and a ten-bin equal-width histogram between the lowest
    /// and highest score. Each histogram entry is `(bin lower bound, count)`;
    /// the highest score is counted in the last bin. When every score is
    /// identical all concepts land in the first bin.
    ///
    /// An empty slice yields all-zero statistics and an empty histogram.
    pub fn calculate_stats(&self, ranked_concepts: &[RankedConcept]) -> ConceptRankingStats {
        use std::cmp::Ordering;

        const NUM_BINS: usize = 10;

        let total_concepts = ranked_concepts.len();
        if total_concepts == 0 {
            return ConceptRankingStats {
                total_concepts: 0,
                concepts_above_threshold: 0,
                average_score: 0.0,
                max_score: 0.0,
                min_score: 0.0,
                score_distribution: Vec::new(),
            };
        }

        // Fresh iterator over the raw scores, in input order, on every call.
        let scores = || ranked_concepts.iter().map(|rc| rc.score);
        // Incomparable pairs (NaN) are treated as equal.
        let by_value = |a: &f32, b: &f32| a.partial_cmp(b).unwrap_or(Ordering::Equal);

        let threshold = self.config.min_score;
        let concepts_above_threshold = scores().filter(|&score| score >= threshold).count();
        let average_score = scores().sum::<f32>() / total_concepts as f32;
        let max_score = scores().max_by(by_value).unwrap_or(0.0);
        let min_score = scores().min_by(by_value).unwrap_or(0.0);

        let bin_size = (max_score - min_score) / NUM_BINS as f32;
        let mut histogram = [0usize; NUM_BINS];
        for score in scores() {
            let raw_bin = if bin_size > 0.0 {
                ((score - min_score) / bin_size).floor() as usize
            } else {
                0
            };
            // The maximum score computes to bin `NUM_BINS`; fold it into the last bin.
            histogram[raw_bin.min(NUM_BINS - 1)] += 1;
        }

        let score_distribution: Vec<(f32, usize)> = histogram
            .iter()
            .enumerate()
            .map(|(i, &count)| (min_score + (i as f32 * bin_size), count))
            .collect();

        ConceptRankingStats {
            total_concepts,
            concepts_above_threshold,
            average_score,
            max_score,
            min_score,
            score_distribution,
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::lightrag::concept_graph::{ConceptExtractor, ConceptGraphBuilder};

    /// Turns string literals into the owned `Vec<String>` handed to the builder.
    fn owned(items: &[&str]) -> Vec<String> {
        items.iter().map(|item| item.to_string()).collect()
    }

    /// A ranker using the default configuration.
    fn default_ranker() -> ConceptRanker {
        ConceptRanker::new(ConceptSelectionConfig::default())
    }

    /// True when `value` lies in `0.0..=1.0`.
    fn in_unit_interval(value: f32) -> bool {
        (0.0..=1.0).contains(&value)
    }

    /// Ranking three mutually co-occurring concepts yields non-empty output
    /// whose combined and component scores all stay within the unit interval.
    #[test]
    fn test_concept_ranking() {
        const PAIRS: [[&str; 2]; 3] = [
            ["machine learning", "neural networks"],
            ["machine learning", "deep learning"],
            ["neural networks", "deep learning"],
        ];
        const DOC_IDS: [&str; 3] = ["doc1", "doc2", "doc3"];
        const CHUNK_IDS: [&str; 3] = ["chunk1", "chunk2", "chunk3"];

        let mut builder = ConceptGraphBuilder::new();
        for (&doc_id, pair) in DOC_IDS.iter().zip(PAIRS.iter()) {
            builder.add_document_concepts(doc_id, owned(pair));
        }
        for (&chunk_id, pair) in CHUNK_IDS.iter().zip(PAIRS.iter()) {
            builder.add_chunk_concepts(chunk_id, owned(pair));
        }
        let graph = builder.build();

        let ranked = default_ranker().rank_concepts(&graph, 3);

        assert!(!ranked.is_empty());
        assert!(ranked.iter().any(|rc| rc.concept.contains("machine")));
        for rc in &ranked {
            assert!(in_unit_interval(rc.score));
            assert!(in_unit_interval(rc.degree_score));
            assert!(in_unit_interval(rc.pagerank_score));
            assert!(in_unit_interval(rc.idf_score));
        }
    }

    /// A hub concept sharing a chunk with three others gets a positive degree
    /// score, and every degree score stays within the unit interval.
    #[test]
    fn test_degree_centrality() {
        let mut builder = ConceptGraphBuilder::new();
        for &(chunk_id, neighbour) in &[("chunk1", "b"), ("chunk2", "c"), ("chunk3", "d")] {
            builder.add_chunk_concepts(chunk_id, owned(&["a", neighbour]));
        }
        let graph = builder.build();

        let degree_scores = default_ranker().calculate_degree_centrality(&graph);

        let hub_score = degree_scores.get("a").map_or(0.0, |score| *score);
        assert!(hub_score > 0.0);
        assert!(degree_scores.values().all(|&score| in_unit_interval(score)));
    }

    /// A concept found in one document out of four scores a higher IDF than
    /// one found in three.
    #[test]
    fn test_idf_calculation() {
        let mut builder = ConceptGraphBuilder::new();
        let fixtures = [
            ("doc1", "rare_term"),
            ("doc2", "common_term"),
            ("doc3", "common_term"),
            ("doc4", "common_term"),
        ];
        for &(doc_id, term) in &fixtures {
            builder.add_document_concepts(doc_id, owned(&[term]));
        }
        let graph = builder.build();

        let idf_scores = default_ranker().calculate_idf(&graph, 4);

        let score_of = |term: &str| idf_scores.get(term).map_or(0.0, |score| *score);
        assert!(score_of("rare_term") > score_of("common_term"));
    }

    /// With fifty ranked concepts and `top_k` set to ten, exactly ten names
    /// come back.
    #[test]
    fn test_top_k_selection() {
        let mut builder = ConceptGraphBuilder::new();
        for doc_number in 0..50 {
            let doc_id = format!("doc{}", doc_number);
            let concept = format!("concept_{}", doc_number);
            builder.add_document_concepts(&doc_id, vec![concept]);
        }
        let graph = builder.build();

        let config = ConceptSelectionConfig {
            top_k: 10,
            ..Default::default()
        };
        let ranker = ConceptRanker::new(config);

        let ranked = ranker.rank_concepts(&graph, 50);
        assert_eq!(ranker.get_top_k(&ranked).len(), 10);
    }

    /// Statistics agree with the ranked list they were computed from and
    /// always expose a ten-bin histogram for non-empty input.
    #[test]
    fn test_stats_calculation() {
        let mut builder = ConceptGraphBuilder::new();
        builder.add_document_concepts("doc1", owned(&["a", "b"]));
        builder.add_document_concepts("doc2", owned(&["c", "d"]));
        let graph = builder.build();

        let ranker = default_ranker();
        let ranked = ranker.rank_concepts(&graph, 2);
        let stats = ranker.calculate_stats(&ranked);

        assert_eq!(stats.total_concepts, ranked.len());
        assert!(stats.average_score >= 0.0);
        assert!(stats.max_score >= stats.min_score);
        assert_eq!(stats.score_distribution.len(), 10);
    }
}