use super::{OntologySource, SemanticEntry};
/// How [`disambiguate`] narrows a list of candidate entries.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum DisambiguationStrategy {
    /// Keep only the candidate with the largest `confidence`.
    HighestConfidence,
    /// Keep only the best candidate after a per-source prior
    /// (Wikidata ×1.2, DBpedia ×1.1, anything else ×1.0).
    PopularityPrior,
    /// Keep only the best candidate after boosting entities that are already
    /// listed as resolved in the supplied context.
    ContextBased,
    /// Keep every candidate, ordered by descending `confidence`.
    NoDisambiguation,
}

impl Default for DisambiguationStrategy {
    /// The default strategy is [`DisambiguationStrategy::HighestConfidence`].
    fn default() -> Self {
        Self::HighestConfidence
    }
}
/// Evidence gathered around a mention that context-aware disambiguation can use.
#[derive(Default, Debug)]
pub struct DisambiguationContext {
    /// Words surrounding the mention, in insertion order.
    pub context_words: Vec<String>,
    /// URIs of entities already resolved elsewhere; the `ContextBased` strategy
    /// and `score_by_context` boost candidates whose URI appears here.
    pub resolved_entities: Vec<String>,
    /// Free-form topic hints, in insertion order.
    pub topic_hints: Vec<String>,
}

impl DisambiguationContext {
    /// Returns a context with no words, resolved entities, or topic hints.
    pub fn new() -> Self {
        Default::default()
    }

    /// Records one surrounding word (an owned copy of `word` is stored).
    pub fn add_context_word(&mut self, word: &str) {
        let owned = word.to_owned();
        self.context_words.push(owned);
    }

    /// Records the URI of an entity that has already been resolved.
    pub fn add_resolved_entity(&mut self, uri: &str) {
        let owned = uri.to_owned();
        self.resolved_entities.push(owned);
    }

    /// Records a topic hint.
    pub fn add_topic_hint(&mut self, topic: &str) {
        let owned = topic.to_owned();
        self.topic_hints.push(owned);
    }
}
/// Picks the best interpretation(s) of an ambiguous mention from `candidates`.
///
/// Every strategy except [`DisambiguationStrategy::NoDisambiguation`] returns at most
/// one entry:
///
/// * `HighestConfidence` — the candidate with the largest `confidence`.
/// * `PopularityPrior` — the largest `confidence` after a per-source prior
///   (Wikidata ×1.2, DBpedia ×1.1, anything else ×1.0).
/// * `ContextBased` — the largest `confidence` after a ×1.5 boost for candidates whose
///   URI is in `context.resolved_entities`. When `context` is `None` it behaves like
///   `HighestConfidence`. `context_words` and `topic_hints` are not consulted.
/// * `NoDisambiguation` — every candidate, ordered by descending `confidence`.
///
/// Ties go to the candidate that appears first in `candidates`. `NaN` scores always
/// rank below real numbers. An empty `candidates` slice yields an empty vector.
pub fn disambiguate(
    candidates: &[SemanticEntry],
    strategy: DisambiguationStrategy,
    context: Option<&DisambiguationContext>,
) -> Vec<SemanticEntry> {
    match strategy {
        DisambiguationStrategy::HighestConfidence => {
            best_by_score(candidates, |entry| entry.confidence)
        }
        DisambiguationStrategy::PopularityPrior => best_by_score(candidates, popularity_score),
        DisambiguationStrategy::ContextBased => match context {
            Some(ctx) => best_by_score(candidates, |entry| context_score(entry, ctx)),
            // Without context there is nothing to boost: fall back to raw confidence.
            None => best_by_score(candidates, |entry| entry.confidence),
        },
        DisambiguationStrategy::NoDisambiguation => {
            let mut ranked = candidates.to_vec();
            // `sort_by` is stable, so equal confidences keep their input order.
            ranked.sort_by(|a, b| rank_desc(a.confidence, b.confidence));
            ranked
        }
    }
}

/// Returns the single highest-scoring candidate (or nothing for empty input).
///
/// Only the winner is cloned. `min_by` returns the *first* of several equal minima.
/// Under `rank_desc` the "minimum" is the best score, so the earliest candidate wins
/// ties, matching a stable descending sort.
fn best_by_score<F>(candidates: &[SemanticEntry], score: F) -> Vec<SemanticEntry>
where
    F: Fn(&SemanticEntry) -> f32,
{
    candidates
        .iter()
        .map(|entry| (score(entry), entry))
        .min_by(|a, b| rank_desc(a.0, b.0))
        .map(|(_, entry)| entry.clone())
        .into_iter()
        .collect()
}

/// Orders two scores best-first: higher values first, `NaN` after every real number.
///
/// Unlike `partial_cmp(..).unwrap_or(Equal)`, this is a total order. `slice::sort_by`
/// requires a total order and may panic (Rust >= 1.81) if it does not get one.
fn rank_desc(a: f32, b: f32) -> std::cmp::Ordering {
    use std::cmp::Ordering;
    match (a.is_nan(), b.is_nan()) {
        // Neither is NaN, so `partial_cmp` always yields `Some`.
        (false, false) => b.partial_cmp(&a).unwrap_or(Ordering::Equal),
        (true, true) => Ordering::Equal,
        (true, false) => Ordering::Greater,
        (false, true) => Ordering::Less,
    }
}

/// Confidence scaled by a per-source popularity prior.
fn popularity_score(entry: &SemanticEntry) -> f32 {
    let boost: f32 = match entry.source {
        OntologySource::Wikidata => 1.2,
        OntologySource::DBpedia => 1.1,
        _ => 1.0,
    };
    entry.confidence * boost
}

/// Confidence boosted ×1.5 when the entry's URI was already resolved in `ctx`.
/// This is the same rule `score_by_context` applies.
fn context_score(entry: &SemanticEntry, ctx: &DisambiguationContext) -> f32 {
    if ctx.resolved_entities.contains(&entry.uri) {
        entry.confidence * 1.5
    } else {
        entry.confidence
    }
}
/// Scores every candidate against `context` without discarding any of them.
///
/// Each candidate's score is its `confidence`, multiplied by 1.5 when its URI is
/// listed in `context.resolved_entities`. Output order matches `candidates`, and
/// each pair holds a clone of the entry alongside its score.
pub fn score_by_context(
    candidates: &[SemanticEntry],
    context: &DisambiguationContext,
) -> Vec<(SemanticEntry, f32)> {
    let mut scored = Vec::with_capacity(candidates.len());
    for candidate in candidates {
        let already_resolved = context.resolved_entities.contains(&candidate.uri);
        let score = if already_resolved {
            candidate.confidence * 1.5
        } else {
            candidate.confidence
        };
        scored.push((candidate.clone(), score));
    }
    scored
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a candidate entry through the public constructor.
    fn make_entry(uri: &str, confidence: f32, source: OntologySource) -> SemanticEntry {
        SemanticEntry::new(uri, confidence, source)
    }

    #[test]
    fn test_highest_confidence() {
        let candidates = vec![
            make_entry("http://example.org/1", 0.7, OntologySource::Custom),
            make_entry("http://example.org/2", 0.9, OntologySource::Custom),
            make_entry("http://example.org/3", 0.5, OntologySource::Custom),
        ];
        let result = disambiguate(&candidates, DisambiguationStrategy::HighestConfidence, None);
        assert_eq!(result.len(), 1);
        assert_eq!(result[0].uri, "http://example.org/2");
    }

    #[test]
    fn test_highest_confidence_tie_keeps_first() {
        // Equal scores must resolve to the earliest candidate.
        let candidates = vec![
            make_entry("http://example.org/1", 0.9, OntologySource::Custom),
            make_entry("http://example.org/2", 0.9, OntologySource::Custom),
        ];
        let result = disambiguate(&candidates, DisambiguationStrategy::HighestConfidence, None);
        assert_eq!(result.len(), 1);
        assert_eq!(result[0].uri, "http://example.org/1");
    }

    #[test]
    fn test_popularity_prior() {
        let candidates = vec![
            make_entry("http://example.org/1", 0.8, OntologySource::Custom),
            make_entry(
                "http://www.wikidata.org/entity/Q1",
                0.7,
                OntologySource::Wikidata,
            ),
        ];
        let result = disambiguate(&candidates, DisambiguationStrategy::PopularityPrior, None);
        assert_eq!(result.len(), 1);
        assert!(result[0].uri.contains("wikidata"));
    }

    #[test]
    fn test_context_based_boosts_resolved_entity() {
        let candidates = vec![
            make_entry("http://example.org/1", 0.7, OntologySource::Custom),
            make_entry("http://example.org/2", 0.5, OntologySource::Custom),
        ];
        let mut ctx = DisambiguationContext::new();
        ctx.add_resolved_entity("http://example.org/2");
        // 0.5 * 1.5 = 0.75 outranks the unboosted 0.7.
        let result = disambiguate(&candidates, DisambiguationStrategy::ContextBased, Some(&ctx));
        assert_eq!(result.len(), 1);
        assert_eq!(result[0].uri, "http://example.org/2");
    }

    #[test]
    fn test_context_based_without_context_falls_back() {
        let candidates = vec![
            make_entry("http://example.org/1", 0.7, OntologySource::Custom),
            make_entry("http://example.org/2", 0.5, OntologySource::Custom),
        ];
        let result = disambiguate(&candidates, DisambiguationStrategy::ContextBased, None);
        assert_eq!(result.len(), 1);
        assert_eq!(result[0].uri, "http://example.org/1");
    }

    #[test]
    fn test_no_disambiguation() {
        let candidates = vec![
            make_entry("http://example.org/1", 0.7, OntologySource::Custom),
            make_entry("http://example.org/2", 0.9, OntologySource::Custom),
        ];
        let result = disambiguate(&candidates, DisambiguationStrategy::NoDisambiguation, None);
        assert_eq!(result.len(), 2);
        assert_eq!(result[0].uri, "http://example.org/2");
        assert_eq!(result[1].uri, "http://example.org/1");
    }

    #[test]
    fn test_unique_candidate() {
        let candidates = vec![make_entry(
            "http://example.org/only",
            0.99,
            OntologySource::Wikidata,
        )];
        let result = disambiguate(&candidates, DisambiguationStrategy::HighestConfidence, None);
        assert_eq!(result.len(), 1);
        assert_eq!(result[0].uri, "http://example.org/only");
    }

    #[test]
    fn test_empty_candidates() {
        let candidates: Vec<SemanticEntry> = Vec::new();
        let ctx = DisambiguationContext::new();
        for &strategy in [
            DisambiguationStrategy::HighestConfidence,
            DisambiguationStrategy::PopularityPrior,
            DisambiguationStrategy::ContextBased,
            DisambiguationStrategy::NoDisambiguation,
        ]
        .iter()
        {
            assert!(disambiguate(&candidates, strategy, Some(&ctx)).is_empty());
            assert!(disambiguate(&candidates, strategy, None).is_empty());
        }
    }

    #[test]
    fn test_score_by_context() {
        let candidates = vec![
            make_entry("http://example.org/1", 0.7, OntologySource::Custom),
            make_entry("http://example.org/2", 0.5, OntologySource::Custom),
        ];
        let mut ctx = DisambiguationContext::new();
        ctx.add_resolved_entity("http://example.org/2");
        let scored = score_by_context(&candidates, &ctx);
        assert_eq!(scored.len(), 2);
        assert_eq!(scored[0].0.uri, "http://example.org/1");
        assert_eq!(scored[0].1, candidates[0].confidence);
        assert_eq!(scored[1].0.uri, "http://example.org/2");
        assert_eq!(scored[1].1, candidates[1].confidence * 1.5);
    }

    #[test]
    fn test_default_strategy() {
        assert_eq!(
            DisambiguationStrategy::default(),
            DisambiguationStrategy::HighestConfidence
        );
    }

    #[test]
    fn test_add_context_words() {
        let mut ctx = DisambiguationContext::new();
        ctx.add_context_word("東京");
        ctx.add_context_word("都庁");
        assert_eq!(ctx.context_words.len(), 2);
    }

    #[test]
    fn test_add_resolved_entities_and_topic_hints() {
        let mut ctx = DisambiguationContext::new();
        ctx.add_resolved_entity("http://example.org/1");
        ctx.add_topic_hint("geography");
        ctx.add_topic_hint("politics");
        assert_eq!(ctx.resolved_entities, vec!["http://example.org/1".to_string()]);
        assert_eq!(
            ctx.topic_hints,
            vec!["geography".to_string(), "politics".to_string()]
        );
    }
}