use crate::{Entity, EntityType, Language, Model, Result};
use anno_core::Lexicon;
use std::sync::Arc;
/// Gazetteer-backed NER model: extracts entities by looking up candidate
/// text spans in a [`Lexicon`] (exact match, with optional case folding).
pub struct LexiconNER {
/// Shared, thread-safe dictionary consulted for every candidate span.
lexicon: Arc<dyn Lexicon + Send + Sync>,
/// When `false` (the default), lookups also try lowercased and
/// Capitalized variants of each span.
case_sensitive: bool,
/// When `true` (the default), matches must start and end at word
/// boundaries.
word_boundary: bool,
}
impl LexiconNER {
    /// Wraps `lexicon` in a matcher with the default configuration:
    /// case-insensitive lookups restricted to word boundaries.
    pub fn new(lexicon: impl Lexicon + 'static) -> Self {
        let lexicon = Arc::new(lexicon);
        Self {
            lexicon,
            case_sensitive: false,
            word_boundary: true,
        }
    }

    /// Builder-style toggle: require exact case when matching.
    pub fn with_case_sensitive(mut self, case_sensitive: bool) -> Self {
        self.case_sensitive = case_sensitive;
        self
    }

    /// Builder-style toggle: require matches to align with word boundaries.
    pub fn with_word_boundary(mut self, word_boundary: bool) -> Self {
        self.word_boundary = word_boundary;
        self
    }

    /// Borrows the underlying lexicon.
    pub fn lexicon(&self) -> &dyn Lexicon {
        &*self.lexicon
    }
}
impl Model for LexiconNER {
    /// Scans `text` for lexicon matches and returns them as entities.
    ///
    /// Candidate spans (capped at 50 characters) are generated at every
    /// character position. Ends are tried longest-first so the first hit at
    /// a given start is the *longest* match there — previously the scan was
    /// shortest-first with an early break, which made it impossible for the
    /// overlap-resolution pass below (which prefers longer entities) to ever
    /// see the longer of two same-start matches. Offsets are in characters,
    /// not bytes.
    ///
    /// # Errors
    /// Currently always returns `Ok`; the `Result` is part of the `Model`
    /// contract.
    fn extract_entities(&self, text: &str, language: Option<Language>) -> Result<Vec<Entity>> {
        let mut entities = Vec::new();
        let text_chars: Vec<char> = text.chars().collect();
        let text_len = text_chars.len();
        let is_cjk = language.is_some_and(|l| l.is_cjk());
        // CJK scripts have no spaces between words, so only whitespace and
        // punctuation count as boundaries there; elsewhere any
        // non-alphanumeric character does.
        let is_word_boundary_char = |c: char| -> bool {
            if is_cjk {
                c.is_whitespace()
                    || matches!(
                        c,
                        '。' | ',' | '、' | ';' | ':' | '?' | '!' | '・' | '.' | ',' | ';' | ':' | '?' | '!' | '(' | ')' | '[' | ']' | '{' | '}'
                    )
            } else {
                !c.is_alphanumeric()
            }
        };
        for start in 0..text_len {
            // The word-start test does not depend on `end`: hoist it so we
            // skip this start position without building any span strings.
            if self.word_boundary && start > 0 && !is_word_boundary_char(text_chars[start - 1]) {
                continue;
            }
            // Longest-first so the first match wins as the longest at `start`.
            for end in ((start + 1)..=text_len.min(start + 50)).rev() {
                // Check the boundary before allocating the span string.
                if self.word_boundary && end < text_len && !is_word_boundary_char(text_chars[end]) {
                    continue;
                }
                let span_text: String = text_chars[start..end].iter().collect();
                let matched = if self.case_sensitive {
                    self.lexicon.lookup(&span_text)
                } else {
                    // Case-insensitive: try verbatim, then all-lowercase,
                    // then Capitalized — skipping forms equal to the input
                    // to avoid redundant lookups.
                    self.lexicon
                        .lookup(&span_text)
                        .or_else(|| {
                            let lower = span_text.to_lowercase();
                            if lower != span_text {
                                self.lexicon.lookup(&lower)
                            } else {
                                None
                            }
                        })
                        .or_else(|| {
                            let mut capitalized = span_text.to_lowercase();
                            if let Some(first) = capitalized.chars().next() {
                                capitalized.replace_range(
                                    0..first.len_utf8(),
                                    &first.to_uppercase().to_string(),
                                );
                                if capitalized != span_text {
                                    self.lexicon.lookup(&capitalized)
                                } else {
                                    None
                                }
                            } else {
                                None
                            }
                        })
                };
                if let Some((entity_type, confidence)) = matched {
                    let provenance = anno_core::Provenance {
                        source: std::borrow::Cow::Borrowed("lexicon"),
                        method: anno_core::ExtractionMethod::Heuristic,
                        pattern: Some(std::borrow::Cow::Owned(format!(
                            "lexicon:{}",
                            self.lexicon.source()
                        ))),
                        raw_confidence: Some(confidence),
                        model_version: None,
                        timestamp: None,
                    };
                    // `span_text` already equals text[start..end] in chars;
                    // reuse it instead of re-collecting the same characters.
                    entities.push(Entity::with_provenance(
                        span_text,
                        entity_type,
                        start,
                        end,
                        confidence,
                        provenance,
                    ));
                    break; // longest match at this start found
                }
            }
        }
        // Resolve overlaps across start positions: sort by position, then
        // keep the longer of any two overlapping entities (first-come wins
        // on equal length).
        entities.sort_by_key(|e| (e.start(), e.end()));
        let mut deduped: Vec<Entity> = Vec::new();
        for entity in entities {
            let disjoint = deduped.last().map_or(true, |last| !last.overlaps(&entity));
            if disjoint {
                deduped.push(entity);
            } else {
                let last = deduped.last_mut().expect("checked non-empty above");
                if entity.end() - entity.start() > last.end() - last.start() {
                    *last = entity;
                }
            }
        }
        Ok(deduped)
    }

    /// Returns an empty list: the set of producible types depends on the
    /// lexicon's contents and is not enumerated here.
    fn supported_types(&self) -> Vec<EntityType> {
        vec![]
    }

    /// A lexicon with no entries can never match anything.
    fn is_available(&self) -> bool {
        !self.lexicon.is_empty()
    }

    fn name(&self) -> &'static str {
        "lexicon"
    }

    fn description(&self) -> &'static str {
        "Exact-match lexicon/gazetteer lookup"
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use anno_core::HashMapLexicon;

    #[test]
    fn test_lexicon_ner_basic() {
        let mut lex = HashMapLexicon::new("test");
        lex.insert("Apple", EntityType::Organization, 0.99);
        lex.insert("Microsoft", EntityType::Organization, 0.99);
        let model = LexiconNER::new(lex);
        let found = model
            .extract_entities("Apple and Microsoft are tech companies.", None)
            .unwrap();
        assert_eq!(found.len(), 2);
        // Both companies should come back tagged as organizations.
        let is_org = |name: &str| {
            found
                .iter()
                .any(|e| e.text == name && e.entity_type == EntityType::Organization)
        };
        assert!(is_org("Apple"));
        assert!(is_org("Microsoft"));
    }

    #[test]
    fn test_lexicon_ner_case_insensitive() {
        let mut lex = HashMapLexicon::new("test");
        lex.insert("Apple", EntityType::Organization, 0.99);
        let model = LexiconNER::new(lex);
        // Default mode folds case, so lowercase "apple" still matches.
        let found = model.extract_entities("apple stock rose.", None).unwrap();
        assert_eq!(found.len(), 1);
        assert_eq!(found[0].text, "apple");
    }

    #[test]
    fn test_lexicon_ner_word_boundary() {
        let mut lex = HashMapLexicon::new("test");
        lex.insert("Apple", EntityType::Organization, 0.99);
        let model = LexiconNER::new(lex);
        // "Apple" embedded in "AppleInc" must not match with boundaries on.
        let found = model
            .extract_entities("AppleInc is a company.", None)
            .unwrap();
        assert_eq!(found.len(), 0);
    }

    #[test]
    fn test_lexicon_ner_no_word_boundary() {
        let mut lex = HashMapLexicon::new("test");
        lex.insert("Apple", EntityType::Organization, 0.99);
        let model = LexiconNER::new(lex).with_word_boundary(false);
        // With boundaries disabled, substring matches are allowed.
        let found = model.extract_entities("AppleInc", None).unwrap();
        assert!(found.iter().any(|e| e.text == "Apple"));
    }

    #[test]
    fn test_lexicon_ner_unicode_offsets() {
        let mut lex = HashMapLexicon::new("test");
        lex.insert("東京", EntityType::Location, 0.99);
        let model = LexiconNER::new(lex);
        let text = "Visit 東京 for tourism.";
        let found = model.extract_entities(text, None).unwrap();
        assert_eq!(found.len(), 1);
        // Offsets are character-based, so they stay within the character
        // count even for multi-byte scripts.
        let ent = &found[0];
        assert_eq!(ent.text, "東京");
        assert!(ent.start() < ent.end());
        assert!(ent.end() <= text.chars().count());
    }

    #[test]
    fn test_lexicon_ner_empty_input() {
        let mut lex = HashMapLexicon::new("test");
        lex.insert("Apple", EntityType::Organization, 0.99);
        let model = LexiconNER::new(lex);
        let found = model.extract_entities("", None).unwrap();
        assert!(found.is_empty());
    }

    #[test]
    fn test_lexicon_ner_no_matches() {
        let mut lex = HashMapLexicon::new("test");
        lex.insert("Apple", EntityType::Organization, 0.99);
        let model = LexiconNER::new(lex);
        let found = model.extract_entities("The quick brown fox.", None).unwrap();
        assert!(found.is_empty());
    }

    #[test]
    fn test_lexicon_ner_multiple_occurrences() {
        let mut lex = HashMapLexicon::new("test");
        lex.insert(
            "the",
            EntityType::custom("DET", anno_core::EntityCategory::Misc),
            0.5,
        );
        let model = LexiconNER::new(lex).with_word_boundary(false);
        let found = model.extract_entities("the cat in the hat", None).unwrap();
        assert!(
            found.len() >= 2,
            "expected >= 2 matches, got {}",
            found.len()
        );
    }

    #[test]
    fn test_lexicon_ner_overlapping_entries() {
        let mut lex = HashMapLexicon::new("test");
        lex.insert("New York", EntityType::Location, 0.9);
        lex.insert("New York City", EntityType::Location, 0.95);
        let model = LexiconNER::new(lex);
        let found = model
            .extract_entities("Visit New York City today.", None)
            .unwrap();
        assert!(!found.is_empty(), "should find at least one entity");
        assert!(
            found
                .iter()
                .any(|e| e.text == "New York" || e.text == "New York City"),
            "should find New York or New York City, got: {:?}",
            found.iter().map(|e| &e.text).collect::<Vec<_>>()
        );
    }

    #[test]
    fn test_lexicon_ner_case_sensitive() {
        let mut lex = HashMapLexicon::new("test");
        lex.insert("Apple", EntityType::Organization, 0.99);
        let model = LexiconNER::new(lex).with_case_sensitive(true);
        let found = model.extract_entities("apple stock rose.", None).unwrap();
        assert!(
            found.is_empty(),
            "case-sensitive should not match 'apple'"
        );
        let found = model.extract_entities("Apple stock rose.", None).unwrap();
        assert_eq!(found.len(), 1);
    }

    #[test]
    fn test_lexicon_ner_entity_offsets_correct() {
        let mut lex = HashMapLexicon::new("test");
        lex.insert("Bob", EntityType::Person, 0.99);
        let model = LexiconNER::new(lex);
        let text = "Hello Bob, welcome!";
        let found = model.extract_entities(text, None).unwrap();
        assert_eq!(found.len(), 1);
        let ent = &found[0];
        // "Bob" occupies character indices 6..9 of the input.
        assert_eq!(ent.start(), 6);
        assert_eq!(ent.end(), 9);
        // Slicing the text by the reported offsets recovers the entity.
        let sliced: String = text
            .chars()
            .skip(ent.start())
            .take(ent.end() - ent.start())
            .collect();
        assert_eq!(sliced, "Bob");
    }

    #[test]
    fn test_lexicon_ner_metadata() {
        let mut lex = HashMapLexicon::new("test");
        lex.insert("test", EntityType::Person, 0.9);
        let model = LexiconNER::new(lex);
        assert_eq!(model.name(), "lexicon");
        assert!(!model.description().is_empty());
    }

    #[test]
    fn test_lexicon_ner_mixed_script_offsets() {
        let mut lex = HashMapLexicon::new("test");
        lex.insert("Paris", EntityType::Location, 0.99);
        let model = LexiconNER::new(lex);
        let text = "我住在 Paris 很久了";
        let found = model.extract_entities(text, None).unwrap();
        assert_eq!(found.len(), 1);
        let ent = &found[0];
        assert_eq!(ent.text, "Paris");
        // Character-based offsets must round-trip in mixed CJK/Latin text.
        let sliced: String = text
            .chars()
            .skip(ent.start())
            .take(ent.end() - ent.start())
            .collect();
        assert_eq!(sliced, "Paris");
    }
}