use serde::{Deserialize, Serialize};
use crate::memory::score::extract::{EntityKind, ExtractedEntities};
/// An extracted entity or topic paired with the canonical id derived from its
/// kind and surface text.
///
/// Values of this type are built by [`canonicalise`].
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct CanonicalEntity {
/// Identifier of the form `<kind>:<normalised surface>`, as produced by
/// [`canonical_id_for`].
pub canonical_id: String,
/// Kind of the entry; `canonicalise` uses `EntityKind::Topic` for entries
/// derived from the topic list.
pub kind: EntityKind,
/// Surface text copied unmodified from the extraction result (entity text or
/// topic label).
pub surface: String,
/// Start of the span copied from the extracted entity; `canonicalise` sets it
/// to 0 for topics. NOTE(review): unit (bytes vs. chars) is defined by the
/// extractor — confirm there.
pub span_start: u32,
/// End of the span copied from the extracted entity; `canonicalise` sets it
/// to 0 for topics. NOTE(review): inclusive/exclusive is defined by the
/// extractor — confirm there.
pub span_end: u32,
/// Score copied from the extracted entity or topic.
pub score: f32,
}
/// Flattens an extraction result into a list of [`CanonicalEntity`] values.
///
/// Every item of `extracted.entities` is carried over one-to-one and in order,
/// with its canonical id computed by [`canonical_id_for`]. Each item of
/// `extracted.topics` is then appended as an `EntityKind::Topic` entry with
/// `span_start` and `span_end` both set to 0, unless a topic entry with the
/// same canonical id is already in the output (coming either from the entity
/// list or from an earlier topic).
///
/// Entities themselves are never de-duplicated against each other.
pub fn canonicalise(extracted: &ExtractedEntities) -> Vec<CanonicalEntity> {
    let mut resolved: Vec<CanonicalEntity> = Vec::new();

    // Pass 1: entities map straight across, keeping their spans and scores.
    resolved.extend(extracted.entities.iter().map(|entity| CanonicalEntity {
        canonical_id: canonical_id_for(entity.kind, &entity.text),
        kind: entity.kind,
        surface: entity.text.clone(),
        span_start: entity.span_start,
        span_end: entity.span_end,
        score: entity.score,
    }));

    // Pass 2: topics are added only when no topic entry with the same id has
    // been emitted so far; the scan also sees topics pushed by this loop.
    for topic in &extracted.topics {
        let topic_id = canonical_id_for(EntityKind::Topic, &topic.label);
        let already_listed = resolved
            .iter()
            .any(|seen| seen.kind == EntityKind::Topic && seen.canonical_id == topic_id);

        if !already_listed {
            resolved.push(CanonicalEntity {
                canonical_id: topic_id,
                kind: EntityKind::Topic,
                surface: topic.label.clone(),
                span_start: 0,
                span_end: 0,
                score: topic.score,
            });
        }
    }

    resolved
}
/// Builds the canonical id `<kind>:<key>` for a surface string.
///
/// The surface is first trimmed of surrounding whitespace. For
/// `EntityKind::Url` the trimmed text is used verbatim (case preserved). For
/// every other kind the key is lower-cased and has its leading sigils removed:
/// first any run of `@`, then any run of `#` (in that order, so `"#@x"` keeps
/// its `@`).
///
/// The `<kind>` prefix is whatever `kind.as_str()` returns.
pub fn canonical_id_for(kind: EntityKind, surface: &str) -> String {
    let body = surface.trim();

    // URLs keep their original casing and characters.
    if kind == EntityKind::Url {
        return format!("{}:{}", kind.as_str(), body);
    }

    // Strip sigils on the borrowed slice, then lower-case once at the end;
    // `@` and `#` are unaffected by case mapping, so the result is the same
    // as lower-casing first.
    let key = body
        .trim_start_matches('@')
        .trim_start_matches('#')
        .to_lowercase();
    format!("{}:{}", kind.as_str(), key)
}
// Unit tests for this module are kept in `resolver_tests.rs` (located via the
// `#[path]` attribute) and are compiled only under `cfg(test)`.
#[cfg(test)]
#[path = "resolver_tests.rs"]
mod tests;