use std::collections::HashMap;
/// Identifies which kind of lookup proposed an expansion term.
///
/// Every variant has a fixed base weight, assigned by
/// `QueryExpander::score_expansion`.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum ExpansionSource {
    /// The term was found in the expander's synonym table.
    Synonym,
    /// NOTE(review): presumably a broader (parent) concept; no code visible
    /// in this file produces this variant -- confirm against callers.
    Hypernym,
    /// NOTE(review): presumably a narrower (child) concept; no code visible
    /// in this file produces this variant -- confirm against callers.
    Hyponym,
    /// The term is the subject or object at the other end of a graph triple
    /// that mentions a query token.
    RelatedEntity,
    /// The term shares graph triples with a query token and is weighted by
    /// how many such triples exist.
    CoOccurrence,
}
/// One candidate term proposed for addition to a query.
#[derive(Clone, Debug)]
pub struct ExpansionTerm {
    /// The proposed term text.
    pub term: String,
    /// Weight of the proposal; `QueryExpander::expand` orders its results by
    /// this value, highest first.
    pub score: f64,
    /// The kind of lookup that proposed the term.
    pub source: ExpansionSource,
}
/// The outcome of expanding a single query string.
#[derive(Clone, Debug)]
pub struct ExpandedQuery {
    /// The query text exactly as it was supplied.
    pub original: String,
    /// The proposed expansion terms together with their score and source.
    pub expansions: Vec<ExpansionTerm>,
    /// Plain term strings; `QueryExpander::expand` fills this with the
    /// `term` of each entry in `expansions`, in the same order.
    pub expanded_terms: Vec<String>,
}
/// An in-memory list of knowledge-graph triples.
#[derive(Clone, Debug, Default)]
pub struct GraphContext {
    /// The stored `(subject, predicate, object)` triples, in insertion order.
    pub triples: Vec<(String, String, String)>,
}

impl GraphContext {
    /// Creates a context that holds no triples.
    pub fn new() -> Self {
        GraphContext {
            triples: Vec::new(),
        }
    }

    /// Appends one triple to the end of the context.
    ///
    /// Each argument is converted into an owned `String`. No normalisation or
    /// de-duplication is performed, so adding the same triple twice stores it
    /// twice.
    pub fn add_triple(
        &mut self,
        subject: impl Into<String>,
        predicate: impl Into<String>,
        object: impl Into<String>,
    ) {
        let triple = (subject.into(), predicate.into(), object.into());
        self.triples.push(triple);
    }
}
/// Expands a whitespace-tokenised query with registered synonyms and with
/// neighbours taken from a [`GraphContext`].
pub struct QueryExpander {
    /// Lower-cased term mapped to its lower-cased synonyms. Every pair is
    /// stored in both directions.
    synonyms: HashMap<String, Vec<String>>,
    /// Triples consulted by the graph-based lookups.
    context: GraphContext,
    /// Maximum number of expansions returned by `expand`; never below 1.
    max_expansions: usize,
}

impl QueryExpander {
    /// Creates an expander with no synonyms and an empty graph context.
    ///
    /// `max_expansions` caps the number of terms returned by [`expand`];
    /// a value of `0` is raised to `1`.
    ///
    /// [`expand`]: QueryExpander::expand
    pub fn new(max_expansions: usize) -> Self {
        Self {
            synonyms: HashMap::new(),
            context: GraphContext::new(),
            max_expansions: max_expansions.max(1),
        }
    }

    /// Registers `term` and `synonym` as synonyms of each other.
    ///
    /// Both strings are lower-cased before being stored. Registering a pair
    /// that is already known, or a term as a synonym of itself, is a no-op,
    /// so the stored lists never contain duplicates.
    pub fn add_synonym(&mut self, term: &str, synonym: &str) {
        let term = term.to_lowercase();
        let synonym = synonym.to_lowercase();
        if term == synonym {
            // A term is trivially equal to itself; storing it adds nothing.
            return;
        }
        self.link_synonym(&term, &synonym);
        self.link_synonym(&synonym, &term);
    }

    /// Records `to` under `from` unless that link is already present.
    /// Both arguments must already be lower-cased.
    fn link_synonym(&mut self, from: &str, to: &str) {
        let bucket = self.synonyms.entry(from.to_string()).or_default();
        if !bucket.iter().any(|known| known.as_str() == to) {
            bucket.push(to.to_string());
        }
    }

    /// Replaces the graph context used by the graph-based lookups.
    pub fn set_context(&mut self, context: GraphContext) {
        self.context = context;
    }

    /// Returns the fixed base weight for an expansion source.
    ///
    /// Weights are `0.9` (synonym), `0.7` (hypernym), `0.6` (hyponym),
    /// `0.5` (related entity) and `0.4` (co-occurrence).
    pub fn score_expansion(source: &ExpansionSource) -> f64 {
        match source {
            ExpansionSource::Synonym => 0.9,
            ExpansionSource::Hypernym => 0.7,
            ExpansionSource::Hyponym => 0.6,
            ExpansionSource::RelatedEntity => 0.5,
            ExpansionSource::CoOccurrence => 0.4,
        }
    }

    /// Returns the synonyms registered for `term`, matched case-insensitively.
    ///
    /// The returned strings are lower-cased and appear in registration order.
    /// An unknown term yields an empty vector.
    pub fn synonyms_for<'a>(&'a self, term: &str) -> Vec<&'a str> {
        match self.synonyms.get(&term.to_lowercase()) {
            Some(list) => list.iter().map(String::as_str).collect(),
            None => Vec::new(),
        }
    }

    /// Yields, in triple order, the other end of every context triple whose
    /// subject or object equals `term` case-insensitively.
    ///
    /// A triple whose subject matches contributes its object; otherwise a
    /// triple whose object matches contributes its subject. Repeated triples
    /// yield repeated items.
    fn neighbours<'a>(&'a self, term: &str) -> impl Iterator<Item = &'a str> + 'a {
        let needle = term.to_lowercase();
        self.context
            .triples
            .iter()
            .filter_map(move |(subject, _, object)| {
                if subject.to_lowercase() == needle {
                    Some(object.as_str())
                } else if object.to_lowercase() == needle {
                    Some(subject.as_str())
                } else {
                    None
                }
            })
    }

    /// Returns the distinct graph neighbours of `term`.
    ///
    /// `term` is matched case-insensitively against triple subjects and
    /// objects. Results keep their original casing, are de-duplicated by exact
    /// string, and appear in order of first occurrence.
    pub fn related_entities(&self, term: &str) -> Vec<String> {
        let mut seen = std::collections::HashSet::new();
        self.neighbours(term)
            // `insert` returns false for a repeat, which drops it.
            .filter(|name| seen.insert(*name))
            .map(String::from)
            .collect()
    }

    /// Returns each graph neighbour of `term` with the number of triples it
    /// shares with `term`.
    ///
    /// Results are ordered by descending count, ties broken by ascending term.
    pub fn co_occurring_terms(&self, term: &str) -> Vec<(String, usize)> {
        let mut counts: HashMap<&str, usize> = HashMap::new();
        for name in self.neighbours(term) {
            *counts.entry(name).or_insert(0) += 1;
        }
        let mut ranked: Vec<(String, usize)> = counts
            .into_iter()
            .map(|(name, count)| (name.to_string(), count))
            .collect();
        ranked.sort_by(|a, b| b.1.cmp(&a.1).then_with(|| a.0.cmp(&b.0)));
        ranked
    }

    /// Adds `term` to `out` unless its lower-cased form was already seen.
    fn push_candidate(
        seen: &mut std::collections::HashSet<String>,
        out: &mut Vec<ExpansionTerm>,
        term: String,
        source: ExpansionSource,
        score: f64,
    ) {
        if seen.insert(term.to_lowercase()) {
            out.push(ExpansionTerm {
                term,
                score,
                source,
            });
        }
    }

    /// Expands `query` into a ranked list of additional terms.
    ///
    /// The query is split on whitespace. For each token, candidates are
    /// gathered from the synonym table, then from related entities, then from
    /// co-occurring terms. A candidate is kept only the first time its
    /// lower-cased form is seen, and never when it equals any token of the
    /// query, so the result does not repeat words the query already contains.
    ///
    /// Candidates are ordered by descending score, ties broken by ascending
    /// term, and truncated to the configured maximum. `original` holds `query`
    /// unchanged and `expanded_terms` mirrors `expansions` term by term.
    pub fn expand(&self, query: &str) -> ExpandedQuery {
        // Seeding with every token (not just the current one) keeps terms that
        // are already part of the query out of the expansions.
        let mut seen: std::collections::HashSet<String> = query
            .split_whitespace()
            .map(|token| token.to_lowercase())
            .collect();
        let mut candidates: Vec<ExpansionTerm> = Vec::new();

        for token in query.split_whitespace() {
            for synonym in self.synonyms_for(token) {
                Self::push_candidate(
                    &mut seen,
                    &mut candidates,
                    synonym.to_string(),
                    ExpansionSource::Synonym,
                    Self::score_expansion(&ExpansionSource::Synonym),
                );
            }

            for entity in self.related_entities(token) {
                Self::push_candidate(
                    &mut seen,
                    &mut candidates,
                    entity,
                    ExpansionSource::RelatedEntity,
                    Self::score_expansion(&ExpansionSource::RelatedEntity),
                );
            }

            // NOTE(review): `co_occurring_terms` draws from the same triples
            // as `related_entities`, so every term reaching this loop has
            // already been recorded above and no `CoOccurrence` expansion is
            // emitted at present. Kept so the pass takes effect if the two
            // lookups ever diverge -- confirm the intended design.
            for (co_term, freq) in self.co_occurring_terms(token) {
                // Frequency adds 0.05 per shared triple, capped at 0.1.
                let bonus = (freq as f64 * 0.05).min(0.1);
                let score =
                    (Self::score_expansion(&ExpansionSource::CoOccurrence) + bonus).min(1.0);
                Self::push_candidate(
                    &mut seen,
                    &mut candidates,
                    co_term,
                    ExpansionSource::CoOccurrence,
                    score,
                );
            }
        }

        candidates.sort_by(|a, b| {
            b.score
                .partial_cmp(&a.score)
                .unwrap_or(std::cmp::Ordering::Equal)
                .then_with(|| a.term.cmp(&b.term))
        });
        candidates.truncate(self.max_expansions);

        let expanded_terms: Vec<String> = candidates
            .iter()
            .map(|candidate| candidate.term.clone())
            .collect();
        ExpandedQuery {
            original: query.to_string(),
            expansions: candidates,
            expanded_terms,
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an expander whose graph context holds exactly `triples`.
    fn expander_with_triples(
        max_expansions: usize,
        triples: &[(&str, &str, &str)],
    ) -> QueryExpander {
        let mut graph = GraphContext::new();
        for &(subject, predicate, object) in triples {
            graph.add_triple(subject, predicate, object);
        }
        let mut expander = QueryExpander::new(max_expansions);
        expander.set_context(graph);
        expander
    }

    // ---- GraphContext -----------------------------------------------------

    #[test]
    fn test_graph_context_new_empty() {
        let graph = GraphContext::new();
        assert!(graph.triples.is_empty());
    }

    #[test]
    fn test_graph_context_add_triple() {
        let mut graph = GraphContext::new();
        graph.add_triple("Battery", "type", "EnergyStorage");
        assert_eq!(graph.triples.len(), 1);
        assert_eq!(graph.triples[0].0, "Battery");
    }

    #[test]
    fn test_graph_context_multiple_triples() {
        let mut graph = GraphContext::new();
        graph.add_triple("A", "rel", "B");
        graph.add_triple("B", "rel", "C");
        assert_eq!(graph.triples.len(), 2);
    }

    // ---- score_expansion --------------------------------------------------

    #[test]
    fn test_score_synonym_is_highest() {
        let synonym_score = QueryExpander::score_expansion(&ExpansionSource::Synonym);
        let related_score = QueryExpander::score_expansion(&ExpansionSource::RelatedEntity);
        assert!(synonym_score > related_score);
    }

    #[test]
    fn test_score_all_sources() {
        assert_eq!(
            QueryExpander::score_expansion(&ExpansionSource::Synonym),
            0.9
        );
        assert_eq!(
            QueryExpander::score_expansion(&ExpansionSource::Hypernym),
            0.7
        );
        assert_eq!(
            QueryExpander::score_expansion(&ExpansionSource::Hyponym),
            0.6
        );
        assert_eq!(
            QueryExpander::score_expansion(&ExpansionSource::RelatedEntity),
            0.5
        );
        assert_eq!(
            QueryExpander::score_expansion(&ExpansionSource::CoOccurrence),
            0.4
        );
    }

    #[test]
    fn test_score_ordering() {
        // Listed from the highest expected weight to the lowest.
        let ranked = [
            ExpansionSource::Synonym,
            ExpansionSource::Hypernym,
            ExpansionSource::Hyponym,
            ExpansionSource::RelatedEntity,
            ExpansionSource::CoOccurrence,
        ];
        for pair in ranked.windows(2) {
            let higher = QueryExpander::score_expansion(&pair[0]);
            let lower = QueryExpander::score_expansion(&pair[1]);
            assert!(
                higher > lower,
                "{:?} should score higher than {:?}",
                pair[0],
                pair[1]
            );
        }
    }

    // ---- construction -----------------------------------------------------

    #[test]
    fn test_new_stores_max_expansions() {
        let expander = QueryExpander::new(10);
        assert_eq!(expander.max_expansions, 10);
    }

    #[test]
    fn test_new_zero_clamps_to_one() {
        let expander = QueryExpander::new(0);
        assert_eq!(expander.max_expansions, 1);
    }

    // ---- synonyms ---------------------------------------------------------

    #[test]
    fn test_add_synonym_basic() {
        let mut expander = QueryExpander::new(10);
        expander.add_synonym("car", "automobile");
        let found = expander.synonyms_for("car");
        assert!(found.contains(&"automobile"));
    }

    #[test]
    fn test_add_synonym_bidirectional() {
        let mut expander = QueryExpander::new(10);
        expander.add_synonym("car", "vehicle");
        assert!(expander.synonyms_for("vehicle").contains(&"car"));
    }

    #[test]
    fn test_add_synonym_case_insensitive() {
        let mut expander = QueryExpander::new(10);
        expander.add_synonym("Car", "Automobile");
        let found = expander.synonyms_for("car");
        assert!(found.contains(&"automobile"));
    }

    #[test]
    fn test_synonyms_for_unknown_returns_empty() {
        let expander = QueryExpander::new(10);
        assert!(expander.synonyms_for("unknown").is_empty());
    }

    #[test]
    fn test_add_multiple_synonyms() {
        let mut expander = QueryExpander::new(10);
        for synonym in ["auto", "vehicle", "automobile"].iter() {
            expander.add_synonym("car", synonym);
        }
        assert_eq!(expander.synonyms_for("car").len(), 3);
    }

    // ---- related_entities -------------------------------------------------

    #[test]
    fn test_related_entities_as_subject() {
        let expander = expander_with_triples(
            20,
            &[
                ("Battery", "hasComponent", "Anode"),
                ("Battery", "hasComponent", "Cathode"),
            ],
        );
        let related = expander.related_entities("Battery");
        assert!(related.contains(&"Anode".to_string()));
        assert!(related.contains(&"Cathode".to_string()));
    }

    #[test]
    fn test_related_entities_as_object() {
        let expander = expander_with_triples(20, &[("Plant", "produces", "Oxygen")]);
        let related = expander.related_entities("Oxygen");
        assert!(related.contains(&"Plant".to_string()));
    }

    #[test]
    fn test_related_entities_deduplication() {
        let expander = expander_with_triples(20, &[("A", "rel1", "B"), ("A", "rel2", "B")]);
        let related = expander.related_entities("A");
        let occurrences = related.iter().filter(|name| name.as_str() == "B").count();
        assert_eq!(occurrences, 1);
    }

    #[test]
    fn test_related_entities_empty_context() {
        let expander = QueryExpander::new(10);
        assert!(expander.related_entities("anything").is_empty());
    }

    #[test]
    fn test_related_entities_case_insensitive() {
        let expander = expander_with_triples(20, &[("battery", "type", "LiIon")]);
        let related = expander.related_entities("Battery");
        assert!(related.contains(&"LiIon".to_string()));
    }

    // ---- co_occurring_terms -----------------------------------------------

    #[test]
    fn test_co_occurring_terms_basic() {
        let expander = expander_with_triples(
            20,
            &[("A", "rel", "B"), ("A", "rel", "B"), ("A", "rel", "C")],
        );
        let ranked = expander.co_occurring_terms("A");
        assert_eq!(ranked[0].0, "B");
        assert_eq!(ranked[0].1, 2);
    }

    #[test]
    fn test_co_occurring_terms_empty() {
        let expander = QueryExpander::new(10);
        assert!(expander.co_occurring_terms("anything").is_empty());
    }

    #[test]
    fn test_co_occurring_terms_sorted_desc() {
        let expander = expander_with_triples(
            20,
            &[
                ("X", "p", "Z"),
                ("X", "p", "Y"),
                ("X", "p", "Y"),
                ("X", "p", "Y"),
            ],
        );
        let ranked = expander.co_occurring_terms("X");
        assert_eq!(ranked[0].0, "Y");
        assert_eq!(ranked[0].1, 3);
    }

    // ---- expand -----------------------------------------------------------

    #[test]
    fn test_expand_empty_query() {
        let expander = QueryExpander::new(10);
        let expanded = expander.expand("");
        assert_eq!(expanded.original, "");
        assert!(expanded.expansions.is_empty());
    }

    #[test]
    fn test_expand_with_synonym() {
        let mut expander = QueryExpander::new(10);
        expander.add_synonym("cat", "feline");
        let expanded = expander.expand("cat");
        assert!(expanded.expanded_terms.contains(&"feline".to_string()));
    }

    #[test]
    fn test_expand_respects_max_expansions() {
        let mut expander = QueryExpander::new(2);
        for synonym in ["a", "b", "c"].iter() {
            expander.add_synonym("x", synonym);
        }
        let expanded = expander.expand("x");
        assert!(expanded.expansions.len() <= 2);
    }

    #[test]
    fn test_expand_no_duplicates() {
        let mut expander = QueryExpander::new(20);
        // Registering the same pair twice must not duplicate the expansion.
        expander.add_synonym("dog", "hound");
        expander.add_synonym("dog", "hound");
        let expanded = expander.expand("dog");
        let occurrences = expanded
            .expanded_terms
            .iter()
            .filter(|term| term.as_str() == "hound")
            .count();
        assert_eq!(occurrences, 1);
    }

    #[test]
    fn test_expand_original_preserved() {
        let expander = QueryExpander::new(10);
        let expanded = expander.expand("find batteries");
        assert_eq!(expanded.original, "find batteries");
    }

    #[test]
    fn test_expand_with_graph_context() {
        let expander = expander_with_triples(20, &[("Battery", "hasComponent", "Anode")]);
        let expanded = expander.expand("Battery");
        assert!(
            expanded.expanded_terms.contains(&"Anode".to_string()),
            "expanded_terms: {:?}",
            expanded.expanded_terms
        );
    }

    #[test]
    fn test_expand_combined_sources() {
        let mut expander = expander_with_triples(20, &[("cell", "partOf", "module")]);
        expander.add_synonym("cell", "unit");
        let expanded = expander.expand("cell");
        let sources: Vec<&ExpansionSource> = expanded
            .expansions
            .iter()
            .map(|expansion| &expansion.source)
            .collect();
        assert!(sources.contains(&&ExpansionSource::Synonym));
        assert!(sources.contains(&&ExpansionSource::RelatedEntity));
    }

    #[test]
    fn test_expand_multi_token_query() {
        let mut expander = QueryExpander::new(20);
        expander.add_synonym("find", "search");
        expander.add_synonym("cell", "battery");
        let expanded = expander.expand("find cell");
        assert!(expanded.expanded_terms.contains(&"search".to_string()));
        assert!(expanded.expanded_terms.contains(&"battery".to_string()));
    }

    #[test]
    fn test_expand_term_not_repeated_in_expansions() {
        let mut expander = QueryExpander::new(20);
        expander.add_synonym("car", "auto");
        let expanded = expander.expand("car");
        assert!(!expanded.expanded_terms.contains(&"car".to_string()));
    }

    #[test]
    fn test_expand_sorted_by_score_descending() {
        let mut expander = expander_with_triples(20, &[("item", "rel", "object")]);
        expander.add_synonym("item", "thing");
        let expanded = expander.expand("item");
        let count = expanded.expansions.len();
        if count >= 2 {
            let top = expanded.expansions[0].score;
            let bottom = expanded.expansions[count - 1].score;
            assert!(top >= bottom);
        }
    }

    #[test]
    fn test_expand_co_occurrence_included() {
        let expander = expander_with_triples(20, &[("anode", "partOf", "cell")]);
        let expanded = expander.expand("anode");
        assert!(expanded.expanded_terms.contains(&"cell".to_string()));
    }

    // ---- plain data types -------------------------------------------------

    #[test]
    fn test_expansion_term_struct() {
        let term = ExpansionTerm {
            term: "test".to_string(),
            score: 0.8,
            source: ExpansionSource::Hypernym,
        };
        assert_eq!(term.term, "test");
        assert_eq!(term.score, 0.8);
        assert_eq!(term.source, ExpansionSource::Hypernym);
    }

    #[test]
    fn test_expanded_query_struct() {
        let expanded = ExpandedQuery {
            original: "hello".to_string(),
            expansions: Vec::new(),
            expanded_terms: vec!["world".to_string()],
        };
        assert_eq!(expanded.original, "hello");
        assert_eq!(expanded.expanded_terms.len(), 1);
    }

    #[test]
    fn test_set_context_replaces_previous() {
        let mut expander = QueryExpander::new(10);
        let mut first = GraphContext::new();
        first.add_triple("A", "rel", "B");
        expander.set_context(first);
        // Installing an empty context must discard the earlier triples.
        expander.set_context(GraphContext::new());
        assert!(expander.related_entities("A").is_empty());
    }

    #[test]
    fn test_co_occurring_from_both_positions() {
        let expander =
            expander_with_triples(20, &[("alpha", "to", "beta"), ("gamma", "to", "alpha")]);
        let ranked = expander.co_occurring_terms("alpha");
        let names: Vec<&str> = ranked.iter().map(|(name, _)| name.as_str()).collect();
        assert!(names.contains(&"beta"));
        assert!(names.contains(&"gamma"));
    }

    #[test]
    fn test_expand_case_normalisation_in_expansion() {
        let mut expander = QueryExpander::new(10);
        expander.add_synonym("Battery", "accumulator");
        let expanded = expander.expand("battery");
        assert!(
            expanded.expanded_terms.contains(&"accumulator".to_string()),
            "expanded_terms: {:?}",
            expanded.expanded_terms
        );
    }

    #[test]
    fn test_expansion_source_debug() {
        let rendered = format!("{:?}", ExpansionSource::Synonym);
        assert!(rendered.contains("Synonym"));
    }

    #[test]
    fn test_expansion_source_equality() {
        assert_eq!(ExpansionSource::Hypernym, ExpansionSource::Hypernym);
        assert_ne!(ExpansionSource::Hypernym, ExpansionSource::Hyponym);
    }

    #[test]
    fn test_graph_context_clone() {
        let mut graph = GraphContext::new();
        graph.add_triple("A", "b", "C");
        let copy = graph.clone();
        assert_eq!(copy.triples.len(), 1);
    }
}