use crate::{
core::{EntityId, KnowledgeGraph},
retrieval::SearchResult,
};
use std::collections::HashMap;
use std::sync::Arc;
/// A concept taken from a query, together with the knowledge-graph entities
/// it was grounded to and the scores used to rank it.
#[derive(Debug, Clone)]
pub struct SymbolicAnchor {
    /// The concept text as extracted from the query.
    pub concept: String,
    /// Entities the concept was matched to. `add_entity` skips IDs that are
    /// already present.
    pub grounded_entities: Vec<EntityId>,
    /// Relevance of this anchor; used to rank anchors and as the per-entity
    /// boost applied to matching search results.
    pub relevance_score: f32,
    /// Similarity score; starts at `0.0` and is clamped to `[0, 1]` by
    /// `with_similarity`.
    pub similarity_score: f32,
}
impl SymbolicAnchor {
    /// Creates an anchor for `concept` with the given relevance, no grounded
    /// entities yet, and a similarity score of `0.0`.
    pub fn new(concept: String, relevance_score: f32) -> Self {
        let grounded_entities = Vec::new();
        Self {
            concept,
            relevance_score,
            similarity_score: 0.0,
            grounded_entities,
        }
    }

    /// Records `entity_id` as grounding this anchor; an ID that is already
    /// recorded is ignored so the list never contains duplicates.
    pub fn add_entity(&mut self, entity_id: EntityId) {
        let already_grounded = self
            .grounded_entities
            .iter()
            .any(|known| *known == entity_id);
        if already_grounded {
            return;
        }
        self.grounded_entities.push(entity_id);
    }

    /// Returns this anchor with its similarity score set to `score`, clamped
    /// into `[0, 1]`.
    pub fn with_similarity(self, score: f32) -> Self {
        Self {
            similarity_score: score.clamp(0.0, 1.0),
            ..self
        }
    }
}
/// Extracts symbolic anchors from queries by grounding concepts in a
/// `KnowledgeGraph`, and uses those anchors to boost search results.
pub struct SymbolicAnchoringStrategy {
    /// Graph that concepts are grounded against.
    graph: Arc<KnowledgeGraph>,
    /// Anchors whose relevance is below this are discarded
    /// (default `0.3`; `with_min_relevance` clamps it to `[0, 1]`).
    min_relevance: f32,
    /// Maximum number of anchors returned by `extract_anchors` (default `5`).
    max_anchors: usize,
    /// Maximum number of grounded entities kept per anchor (default `10`).
    max_entities_per_anchor: usize,
    /// Optional per-entity PageRank scores blended into anchor relevance.
    pagerank_scores: Option<HashMap<EntityId, f32>>,
}
impl SymbolicAnchoringStrategy {
pub fn new(graph: Arc<KnowledgeGraph>) -> Self {
Self {
graph,
min_relevance: 0.3,
max_anchors: 5,
max_entities_per_anchor: 10,
pagerank_scores: None,
}
}
pub fn with_pagerank_scores(mut self, scores: HashMap<EntityId, f32>) -> Self {
self.pagerank_scores = Some(scores);
self
}
pub fn with_min_relevance(mut self, min_relevance: f32) -> Self {
self.min_relevance = min_relevance.clamp(0.0, 1.0);
self
}
pub fn with_max_anchors(mut self, max_anchors: usize) -> Self {
self.max_anchors = max_anchors;
self
}
pub fn extract_anchors(&self, query: &str) -> Vec<SymbolicAnchor> {
let mut anchors = Vec::new();
let concepts = self.extract_concepts(query);
for concept in concepts {
let mut anchor = SymbolicAnchor::new(concept.clone(), 1.0);
let grounded = self.ground_concept(&concept);
for entity_id in grounded.into_iter().take(self.max_entities_per_anchor) {
anchor.add_entity(entity_id);
}
if !anchor.grounded_entities.is_empty() {
let relevance = self.calculate_relevance(&anchor);
anchor.relevance_score = relevance;
if anchor.relevance_score >= self.min_relevance {
anchors.push(anchor);
}
}
}
anchors.sort_by(|a, b| {
b.relevance_score
.partial_cmp(&a.relevance_score)
.unwrap_or(std::cmp::Ordering::Equal)
});
anchors.truncate(self.max_anchors);
anchors
}
fn extract_concepts(&self, query: &str) -> Vec<String> {
let mut concepts = Vec::new();
let words: Vec<&str> = query.split_whitespace().collect();
let conceptual_patterns = [
"what is",
"nature of",
"meaning of",
"definition of",
"concept of",
"idea of",
"philosophy of",
"theory of",
];
let query_lower = query.to_lowercase();
let is_conceptual = conceptual_patterns
.iter()
.any(|pattern| query_lower.contains(pattern));
if is_conceptual {
for (i, word) in words.iter().enumerate() {
let word_lower = word.to_lowercase();
if i > 0 {
let prev_lower = words[i - 1].to_lowercase();
if ["is", "of", "about"].contains(&prev_lower.as_str()) {
let clean = word.trim_matches(|c: char| !c.is_alphanumeric());
if !clean.is_empty() && clean.len() > 2 {
concepts.push(clean.to_string());
}
}
}
if Self::is_likely_concept(&word_lower) {
let clean = word.trim_matches(|c: char| !c.is_alphanumeric());
if !clean.is_empty() && !concepts.contains(&clean.to_string()) {
concepts.push(clean.to_string());
}
}
}
}
if concepts.is_empty() {
for word in words {
if word.len() > 4
&& word
.chars()
.next()
.expect("non-empty string")
.is_uppercase()
{
let clean = word.trim_matches(|c: char| !c.is_alphanumeric());
if !clean.is_empty() {
concepts.push(clean.to_string());
}
}
}
}
concepts
}
fn is_likely_concept(word: &str) -> bool {
let concept_words = [
"love",
"virtue",
"justice",
"truth",
"beauty",
"good",
"evil",
"knowledge",
"wisdom",
"courage",
"philosophy",
"ethics",
"morality",
"freedom",
"happiness",
"meaning",
"purpose",
"existence",
"reality",
"consciousness",
"mind",
"soul",
"spirit",
"nature",
"essence",
];
concept_words.contains(&word)
}
fn ground_concept(&self, concept: &str) -> Vec<EntityId> {
let mut grounded = Vec::new();
let concept_lower = concept.to_lowercase();
for entity in self.graph.entities() {
let entity_name_lower = entity.name.to_lowercase();
let entity_type_lower = entity.entity_type.to_lowercase();
if entity_name_lower.contains(&concept_lower) {
grounded.push(entity.id.clone());
continue;
}
if entity_type_lower == "concept" && entity_name_lower.contains(&concept_lower) {
grounded.push(entity.id.clone());
continue;
}
for rel in self.graph.get_entity_relationships(&entity.id.0) {
if rel.relation_type.to_lowercase().contains(&concept_lower) {
grounded.push(entity.id.clone());
break;
}
}
}
grounded
}
fn calculate_relevance(&self, anchor: &SymbolicAnchor) -> f32 {
if anchor.grounded_entities.is_empty() {
return 0.0;
}
let count_score = (anchor.grounded_entities.len() as f32 / 10.0).min(1.0);
if let Some(ref pagerank) = self.pagerank_scores {
let mut total_pr = 0.0;
let mut found_count = 0;
for entity_id in &anchor.grounded_entities {
if let Some(&pr_score) = pagerank.get(entity_id) {
total_pr += pr_score;
found_count += 1;
}
}
if found_count > 0 {
let avg_pr = total_pr / found_count as f32;
return (count_score * 0.4) + (avg_pr * 0.6);
}
}
count_score
}
pub fn boost_with_anchors(
&self,
mut results: Vec<SearchResult>,
anchors: &[SymbolicAnchor],
) -> Vec<SearchResult> {
if anchors.is_empty() {
return results;
}
let mut entity_anchors: HashMap<String, Vec<&SymbolicAnchor>> = HashMap::new();
for anchor in anchors {
for entity_id in &anchor.grounded_entities {
let entity_str = entity_id.0.clone();
entity_anchors.entry(entity_str).or_default().push(anchor);
}
}
for result in &mut results {
let mut total_boost = 0.0;
let mut match_count = 0;
for entity_name in &result.entities {
if let Some(matching_anchors) = entity_anchors.get(entity_name) {
let boost: f32 = matching_anchors
.iter()
.map(|a| a.relevance_score)
.sum::<f32>()
/ matching_anchors.len() as f32;
total_boost += boost;
match_count += 1;
}
}
if match_count > 0 {
let avg_boost = total_boost / match_count as f32;
let _original_score = result.score;
result.score *= 1.0 + avg_boost;
#[cfg(feature = "tracing")]
tracing::debug!(
result_id = %result.id,
original_score = _original_score,
boost = avg_boost,
boosted_score = result.score,
matched_entities = match_count,
"Applied symbolic anchor boost"
);
}
}
results.sort_by(|a, b| {
b.score
.partial_cmp(&a.score)
.unwrap_or(std::cmp::Ordering::Equal)
});
results
}
}
/// Returns `true` when `query`, compared case-insensitively, contains one of
/// a fixed set of phrases typical of conceptual questions (for example
/// "what is", "meaning of", "how does" or "explain").
pub fn is_conceptual_query(query: &str) -> bool {
    const CONCEPTUAL_MARKERS: [&str; 12] = [
        "what is",
        "what are",
        "nature of",
        "meaning of",
        "definition of",
        "concept of",
        "idea of",
        "philosophy of",
        "theory of",
        "how does",
        "why does",
        "explain",
    ];
    let lowered = query.to_lowercase();
    for marker in CONCEPTUAL_MARKERS.iter() {
        if lowered.contains(*marker) {
            return true;
        }
    }
    false
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::Entity;

    /// Builds a graph holding the `love` concept and the `Phaedrus` dialog.
    fn create_test_graph() -> KnowledgeGraph {
        let fixtures = [
            ("concept_love", "love", "CONCEPT", 0.9),
            ("dialog_phaedrus", "Phaedrus", "DIALOG", 0.95),
        ];
        let mut graph = KnowledgeGraph::new();
        for (id, name, entity_type, confidence) in fixtures {
            let entity = Entity::new(
                EntityId::new(id.to_string()),
                name.to_string(),
                entity_type.to_string(),
                confidence,
            );
            graph.add_entity(entity).unwrap();
        }
        graph
    }

    /// A `love` anchor (relevance 0.8) grounded in both fixture entities.
    fn love_anchor() -> SymbolicAnchor {
        let mut anchor = SymbolicAnchor::new("love".to_string(), 0.8);
        anchor.add_entity(EntityId::new("concept_love".to_string()));
        anchor.add_entity(EntityId::new("dialog_phaedrus".to_string()));
        anchor
    }

    #[test]
    fn test_is_conceptual_query() {
        assert!(is_conceptual_query("What is the nature of love?"));
        assert!(is_conceptual_query("Explain the concept of virtue"));
        assert!(is_conceptual_query("What are the key ideas in Platonism?"));
        assert!(!is_conceptual_query("Who taught Plato?"));
        assert!(!is_conceptual_query("When was Socrates born?"));
    }

    #[test]
    fn test_extract_concepts() {
        let strategy = SymbolicAnchoringStrategy::new(Arc::new(create_test_graph()));
        let concepts = strategy.extract_concepts("What is the nature of love?");
        assert!(concepts.contains(&"love".to_string()));
    }

    #[test]
    fn test_is_likely_concept() {
        assert!(SymbolicAnchoringStrategy::is_likely_concept("love"));
        assert!(SymbolicAnchoringStrategy::is_likely_concept("virtue"));
        assert!(SymbolicAnchoringStrategy::is_likely_concept("justice"));
        assert!(!SymbolicAnchoringStrategy::is_likely_concept("table"));
        assert!(!SymbolicAnchoringStrategy::is_likely_concept("book"));
    }

    #[test]
    fn test_symbolic_anchor_creation() {
        let mut anchor = SymbolicAnchor::new("love".to_string(), 0.8);
        for id in ["entity1", "entity2"] {
            anchor.add_entity(EntityId::new(id.to_string()));
        }
        assert_eq!(anchor.concept, "love");
        assert_eq!(anchor.grounded_entities.len(), 2);
        assert_eq!(anchor.relevance_score, 0.8);
    }

    #[test]
    fn test_pagerank_boost() {
        let mut pagerank_scores = HashMap::new();
        pagerank_scores.insert(EntityId::new("concept_love".to_string()), 0.3);
        pagerank_scores.insert(EntityId::new("dialog_phaedrus".to_string()), 0.9);
        let strategy = SymbolicAnchoringStrategy::new(Arc::new(create_test_graph()))
            .with_pagerank_scores(pagerank_scores);

        // 2/10 entities -> 0.2 * 0.4 + mean(0.3, 0.9) * 0.6 = 0.44
        let relevance = strategy.calculate_relevance(&love_anchor());
        assert!(
            relevance > 0.4 && relevance < 0.5,
            "Expected ~0.44, got {}",
            relevance
        );
    }

    #[test]
    fn test_pagerank_boost_fallback() {
        let strategy = SymbolicAnchoringStrategy::new(Arc::new(create_test_graph()));

        // Without PageRank scores only the count term applies: 2/10 = 0.2.
        let relevance = strategy.calculate_relevance(&love_anchor());
        assert!(
            (relevance - 0.2).abs() < 0.01,
            "Expected 0.2, got {}",
            relevance
        );
    }
}