use std::collections::HashMap;
use crate::error::{Result, TextError};
use super::graph::{EntityId, KnowledgeGraph};
/// A textual mention of an entity, as produced by an upstream mention detector.
///
/// All fields are plain owned data, so the type also derives `Eq` and `Hash`
/// and can be used as a key in sets and maps (e.g. to de-duplicate mentions).
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct EntityMention {
    /// Surface form of the mention exactly as it appears in the text.
    pub surface: String,
    /// `(start, end)` offsets of the mention in the source text.
    // NOTE(review): presumably byte offsets — `is_nil_mention` builds
    // `(0, surface.len())`; confirm against the mention detector.
    pub span: (usize, usize),
    /// Optional NER label for the mention (e.g. `"Organization"`); used as a
    /// soft type constraint when scoring candidates.
    pub ner_type: Option<String>,
    /// Identifier of the coreference chain this mention belongs to, if any.
    /// Mentions sharing a chain id can inherit each other's link.
    pub coref_chain: Option<usize>,
}
/// Outcome of resolving a single [`EntityMention`] against the knowledge graph.
#[derive(Debug, Clone, PartialEq)]
pub struct LinkedMention {
/// Copy of the mention that was linked.
pub mention: EntityMention,
/// Identifier of the selected entity; `None` when the mention is NIL.
pub entity_id: Option<EntityId>,
/// Name of the selected entity; `None` when the mention is NIL.
pub entity_name: Option<String>,
/// Score of the selected candidate, or `0.0` when the mention is NIL.
/// Links inherited through a coreference chain carry a discounted copy of
/// the chain's best confidence.
pub confidence: f64,
/// `true` when no candidate reached the linker's NIL threshold.
pub is_nil: bool,
/// All scored candidates for the mention, best first. Not rewritten when a
/// link is inherited via coreference propagation.
pub candidates: Vec<CandidateEntity>,
}
/// A knowledge-graph entity proposed as a possible referent of a mention.
#[derive(Debug, Clone, PartialEq)]
pub struct CandidateEntity {
/// Identifier of the proposed entity in the knowledge graph.
pub entity_id: EntityId,
/// Name of the proposed entity.
pub entity_name: String,
/// Candidate score: the alias prior (or fuzzy-match score) right after
/// candidate generation, replaced by the context-adjusted score once the
/// candidate has been scored for a specific mention.
pub score: f64,
}
/// One target of an alias-table key: an entity a surface form may refer to.
#[derive(Debug, Clone)]
struct AliasEntry {
/// Identifier of the target entity. The value `usize::MAX` is a sentinel
/// meaning "not resolved yet": such entries are looked up by
/// `entity_name` in the knowledge graph when candidates are generated.
entity_id: EntityId,
/// Name of the target entity.
entity_name: String,
/// Prior score given to this entity when the alias matches.
prior: f64,
}
/// Dictionary-based entity linker.
///
/// Holds an alias table mapping lower-cased surface forms to the entities they
/// may denote, plus the parameters that control candidate generation and the
/// NIL decision. Derives `Debug` and `Clone` so a configured linker can be
/// inspected in logs and duplicated.
#[derive(Debug, Clone)]
pub struct EntityLinker {
    /// Lower-cased surface form -> entities that surface may refer to.
    alias_table: HashMap<String, Vec<AliasEntry>>,
    /// Minimum score the best candidate must reach to be accepted; below this
    /// the mention is reported as NIL.
    nil_threshold: f64,
    /// Upper bound on the number of candidates returned per mention.
    max_candidates: usize,
}
impl EntityLinker {
pub fn new() -> Self {
EntityLinker {
alias_table: HashMap::new(),
nil_threshold: 0.15,
max_candidates: 10,
}
}
pub fn with_nil_threshold(mut self, threshold: f64) -> Self {
self.nil_threshold = threshold;
self
}
pub fn build_from_kg(&mut self, kg: &KnowledgeGraph) {
for name in kg.entities() {
if let Some(id) = kg.entity_id(name) {
self.insert_alias(name, id, name, 1.0);
for token in name.split_whitespace() {
if token.len() >= 3 {
self.insert_alias(token, id, name, 0.4);
}
}
let initialism: String = name
.split_whitespace()
.filter_map(|w| w.chars().next())
.filter(|c| c.is_uppercase())
.collect();
if initialism.len() >= 2 {
self.insert_alias(&initialism, id, name, 0.3);
}
}
}
}
pub fn add_alias(&mut self, surface: &str, entity_name: &str, prior: f64) {
let entry = AliasEntry {
entity_id: usize::MAX,
entity_name: entity_name.to_string(),
prior,
};
self.alias_table
.entry(surface.to_lowercase())
.or_default()
.push(entry);
}
fn insert_alias(&mut self, surface: &str, id: EntityId, name: &str, prior: f64) {
let entries = self.alias_table.entry(surface.to_lowercase()).or_default();
if !entries.iter().any(|e| e.entity_id == id) {
entries.push(AliasEntry {
entity_id: id,
entity_name: name.to_string(),
prior,
});
}
}
pub fn generate_candidates(
&self,
surface: &str,
kg: &KnowledgeGraph,
) -> Vec<CandidateEntity> {
let key = surface.to_lowercase();
let mut seen_ids: std::collections::HashSet<EntityId> =
std::collections::HashSet::new();
let mut candidates: Vec<CandidateEntity> = Vec::new();
if let Some(entries) = self.alias_table.get(&key) {
for entry in entries {
let id = if entry.entity_id == usize::MAX {
match kg.entity_id(&entry.entity_name) {
Some(id) => id,
None => continue,
}
} else {
entry.entity_id
};
if seen_ids.insert(id) {
candidates.push(CandidateEntity {
entity_id: id,
entity_name: entry.entity_name.clone(),
score: entry.prior,
});
}
}
}
if candidates.is_empty() {
let lower_surface = surface.to_lowercase();
for name in kg.entities() {
let lower_name = name.to_lowercase();
if lower_name.starts_with(&lower_surface)
|| lower_surface.starts_with(&lower_name)
|| lower_name.contains(&lower_surface)
{
if let Some(id) = kg.entity_id(name) {
if seen_ids.insert(id) {
let overlap = lower_name.len().min(lower_surface.len()) as f64
/ lower_name.len().max(lower_surface.len()).max(1) as f64;
candidates.push(CandidateEntity {
entity_id: id,
entity_name: name.to_string(),
score: overlap * 0.5,
});
}
}
}
}
}
candidates.sort_by(|a, b| {
b.score
.partial_cmp(&a.score)
.unwrap_or(std::cmp::Ordering::Equal)
});
candidates.truncate(self.max_candidates);
candidates
}
fn score_candidate(
&self,
candidate: &CandidateEntity,
mention: &EntityMention,
context: &str,
kg: &KnowledgeGraph,
) -> f64 {
let mut score = candidate.score;
if let Some(ref ner_type) = mention.ner_type {
let entity_types = kg.entity_types(&candidate.entity_name);
let ner_lower = ner_type.to_lowercase();
for et in &entity_types {
if et.to_lowercase().contains(&ner_lower)
|| ner_lower.contains(&et.to_lowercase())
{
score += 0.2;
break;
}
}
}
let ctx_tokens: std::collections::HashSet<String> = context
.split_whitespace()
.map(|t| t.to_lowercase().trim_matches(|c: char| !c.is_alphabetic()).to_string())
.filter(|t| t.len() > 2)
.collect();
let neighbours: Vec<&super::graph::Triple> =
kg.query_all(&candidate.entity_name);
for nb in neighbours {
if let Some(obj_name) = kg.entity_name(nb.object) {
for tok in obj_name.split_whitespace() {
if ctx_tokens.contains(&tok.to_lowercase()) {
score += 0.05;
}
}
}
if let Some(subj_name) = kg.entity_name(nb.subject) {
for tok in subj_name.split_whitespace() {
if ctx_tokens.contains(&tok.to_lowercase()) {
score += 0.05;
}
}
}
}
score.min(1.0)
}
pub fn link_mention(
&self,
mention: &EntityMention,
context: &str,
kg: &KnowledgeGraph,
) -> LinkedMention {
let mut candidates = self.generate_candidates(&mention.surface, kg);
for c in candidates.iter_mut() {
c.score = self.score_candidate(c, mention, context, kg);
}
candidates.sort_by(|a, b| {
b.score
.partial_cmp(&a.score)
.unwrap_or(std::cmp::Ordering::Equal)
});
let top = candidates.first();
let (entity_id, entity_name, confidence, is_nil) = match top {
Some(c) if c.score >= self.nil_threshold => {
(Some(c.entity_id), Some(c.entity_name.clone()), c.score, false)
}
_ => (None, None, 0.0, true),
};
LinkedMention {
mention: mention.clone(),
entity_id,
entity_name,
confidence,
is_nil,
candidates,
}
}
pub fn link_document(
&self,
mentions: &[EntityMention],
document: &str,
kg: &KnowledgeGraph,
) -> Vec<LinkedMention> {
let mut linked: Vec<LinkedMention> = mentions
.iter()
.map(|m| self.link_mention(m, document, kg))
.collect();
self.propagate_coref_links(&mut linked);
linked
}
fn propagate_coref_links(&self, linked: &mut Vec<LinkedMention>) {
let mut chain_best: HashMap<usize, usize> = HashMap::new();
for (i, lm) in linked.iter().enumerate() {
if let Some(chain_id) = lm.mention.coref_chain {
if !lm.is_nil {
let entry = chain_best.entry(chain_id).or_insert(i);
if linked[i].confidence > linked[*entry].confidence {
*entry = i;
}
}
}
}
for i in 0..linked.len() {
if linked[i].is_nil {
if let Some(chain_id) = linked[i].mention.coref_chain {
if let Some(&best_idx) = chain_best.get(&chain_id) {
let eid = linked[best_idx].entity_id;
let ename = linked[best_idx].entity_name.clone();
let conf = linked[best_idx].confidence * 0.8; linked[i].entity_id = eid;
linked[i].entity_name = ename;
linked[i].confidence = conf;
linked[i].is_nil = false;
}
}
}
}
}
}
impl Default for EntityLinker {
fn default() -> Self {
Self::new()
}
}
/// Reports whether `surface` should be treated as a NIL mention.
///
/// The surface is wrapped in an untyped, chain-less mention spanning the whole
/// string and linked with `linker`, using the surface itself as the context.
/// The mention counts as NIL when the linker says so, or when the resulting
/// confidence is strictly below `threshold`.
pub fn is_nil_mention(
    surface: &str,
    kg: &KnowledgeGraph,
    linker: &EntityLinker,
    threshold: f64,
) -> bool {
    let probe = EntityMention {
        surface: surface.to_string(),
        span: (0, surface.len()),
        ner_type: None,
        coref_chain: None,
    };
    let outcome = linker.link_mention(&probe, surface, kg);
    if outcome.is_nil {
        return true;
    }
    outcome.confidence < threshold
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Small graph: one organisation, its founder and its headquarters.
    fn sample_kg() -> KnowledgeGraph {
        let mut kg = KnowledgeGraph::new();
        kg.add_entity("Apple Inc.", "Organization");
        kg.add_entity("Steve Jobs", "Person");
        kg.add_entity("Cupertino", "Location");
        kg.add_relation("Apple Inc.", "founded_by", "Steve Jobs", 0.99);
        kg.add_relation("Apple Inc.", "headquartered_in", "Cupertino", 0.99);
        kg
    }

    /// Linker with default settings whose alias table is built from `kg`.
    fn linker_for(kg: &KnowledgeGraph) -> EntityLinker {
        let mut linker = EntityLinker::new();
        linker.build_from_kg(kg);
        linker
    }

    /// Shorthand for building a mention.
    fn mention(
        surface: &str,
        span: (usize, usize),
        ner_type: Option<&str>,
        coref_chain: Option<usize>,
    ) -> EntityMention {
        EntityMention {
            surface: surface.to_string(),
            span,
            ner_type: ner_type.map(String::from),
            coref_chain,
        }
    }

    #[test]
    fn test_build_from_kg() {
        let kg = sample_kg();
        let linker = linker_for(&kg);
        // Full name, a name token and the initialism are all registered.
        assert!(linker.alias_table.contains_key("apple inc."));
        assert!(linker.alias_table.contains_key("apple"));
        assert!(linker.alias_table.contains_key("ai"));
    }

    #[test]
    fn test_generate_candidates() {
        let kg = sample_kg();
        let linker = linker_for(&kg);
        let candidates = linker.generate_candidates("Apple Inc.", &kg);
        assert!(!candidates.is_empty());
        assert!(candidates.iter().any(|c| c.entity_name == "Apple Inc."));
        // An exact full-name match carries the full-name prior and ranks first.
        assert_eq!(candidates[0].entity_name, "Apple Inc.");
        assert!((candidates[0].score - 1.0).abs() < 1e-12);
    }

    #[test]
    fn test_link_known_entity() {
        let kg = sample_kg();
        let linker = linker_for(&kg);
        let m = mention("Apple Inc.", (0, 10), Some("Organization"), None);
        let linked = linker.link_mention(&m, "Apple Inc. was founded by Steve Jobs.", &kg);
        assert!(!linked.is_nil, "should link to known entity");
        assert_eq!(linked.entity_name.as_deref(), Some("Apple Inc."));
        assert!(linked.entity_id.is_some());
        // Confidence is capped at 1.0.
        assert!((linked.confidence - 1.0).abs() < 1e-12);
    }

    #[test]
    fn test_nil_for_unknown() {
        let kg = sample_kg();
        let linker = linker_for(&kg);
        let m = mention("Banana Corp", (0, 11), None, None);
        let linked = linker.link_mention(&m, "Banana Corp sells tropical fruit.", &kg);
        assert!(linked.is_nil, "unknown entity should be NIL");
        assert!(linked.entity_id.is_none());
        assert!(linked.entity_name.is_none());
        assert_eq!(linked.confidence, 0.0);
    }

    #[test]
    fn test_coref_propagation() {
        let kg = sample_kg();
        let linker = linker_for(&kg);
        let mentions = vec![
            mention("Apple Inc.", (0, 10), Some("Organization"), Some(0)),
            mention("the company", (12, 23), None, Some(0)),
        ];
        let linked =
            linker.link_document(&mentions, "Apple Inc. the company was founded in 1976.", &kg);
        assert!(!linked[0].is_nil, "antecedent should link directly");
        assert!(!linked[1].is_nil, "coref mention should be propagated");
        // The propagated link points at the antecedent's entity, with a
        // discounted confidence.
        assert_eq!(linked[1].entity_name.as_deref(), Some("Apple Inc."));
        assert_eq!(linked[1].entity_id, linked[0].entity_id);
        assert!((linked[1].confidence - linked[0].confidence * 0.8).abs() < 1e-12);
    }

    #[test]
    fn test_is_nil_mention() {
        let kg = sample_kg();
        let linker = linker_for(&kg);
        assert!(!is_nil_mention("Apple Inc.", &kg, &linker, 0.1));
        assert!(is_nil_mention("Unknown Corp XYZ", &kg, &linker, 0.1));
    }
}