use super::heuristic::HeuristicNER;
use super::regex::RegexNER;
use crate::{Entity, EntityType, Language, Model, Result};
use itertools::Itertools;
use std::borrow::Cow;
use std::sync::Arc;
use crate::backends::method_for_backend_name;
/// How [`StackedNER`] settles a new entity that overlaps one it has already accepted.
///
/// Layers are queried in order, so "existing" entities come from earlier layers (or earlier
/// in the same layer's output). Whenever a strategy is consulted, a non-empty candidate of a
/// structured type (money, date, time, percent, email, URL, phone, quantity) that fully
/// covers an existing `Custom` entity replaces it before the strategy itself is applied.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub enum ConflictStrategy {
/// The already-accepted entity wins. This is the default.
#[default]
Priority,
/// The longer span (in characters) wins; on a tie the existing entity is kept.
LongestSpan,
/// The higher-confidence entity wins; on a tie the existing entity is kept.
HighestConf,
/// Overlapping entities are all kept, and the final same-span/same-type de-duplication is skipped.
Union,
}
impl ConflictStrategy {
    /// Decides the fate of `candidate` when it overlaps the already-accepted `existing`.
    ///
    /// Override: a non-empty candidate of a structured type that fully covers a generic
    /// (`Custom`) entity always replaces it. Otherwise:
    /// - `Priority`: the existing entity is kept.
    /// - `LongestSpan`: the candidate replaces it only when strictly longer.
    /// - `HighestConf`: the candidate replaces it only when strictly more confident.
    /// - `Union`: both are kept.
    fn resolve(&self, existing: &Entity, candidate: &Entity) -> Resolution {
        let covers_existing =
            candidate.start() <= existing.start() && candidate.end() >= existing.end();
        let structured_over_generic = covers_existing
            && candidate.end() > candidate.start()
            && is_structured_type(&candidate.entity_type)
            && is_generic_type(&existing.entity_type);
        if structured_over_generic {
            return Resolution::Replace;
        }

        let candidate_wins = match self {
            Self::Priority => false,
            Self::LongestSpan => {
                let existing_len = existing.end() - existing.start();
                let candidate_len = candidate.end() - candidate.start();
                candidate_len > existing_len
            }
            // Ties (and NaN comparisons) keep the existing entity.
            Self::HighestConf => candidate.confidence > existing.confidence,
            Self::Union => return Resolution::KeepBoth,
        };

        if candidate_wins {
            Resolution::Replace
        } else {
            Resolution::KeepExisting
        }
    }
}
/// Returns `true` for the structured entity types: date, time, money, percent, quantity,
/// email, URL and phone.
///
/// `ConflictStrategy::resolve` lets such an entity override a generic (`Custom`) entity it
/// fully covers.
fn is_structured_type(t: &EntityType) -> bool {
    matches!(
        *t,
        EntityType::Date
            | EntityType::Time
            | EntityType::Money
            | EntityType::Percent
            | EntityType::Quantity
            | EntityType::Email
            | EntityType::Url
            | EntityType::Phone
    )
}
/// Returns `true` for user-defined `EntityType::Custom` types.
///
/// `ConflictStrategy::resolve` treats these as generic: a structured entity that fully covers
/// one replaces it.
fn is_generic_type(t: &EntityType) -> bool {
    matches!(*t, EntityType::Custom { .. })
}
/// Outcome of comparing a candidate entity with an overlapping, already-accepted one.
#[derive(Debug)]
enum Resolution {
/// The existing entity stays; the candidate is not added.
KeepExisting,
/// The candidate takes the existing entity's place.
Replace,
/// Both entities are kept (returned only for `ConflictStrategy::Union`).
KeepBoth,
}
/// A NER model that queries several backends ("layers") in order and merges their entities
/// according to a [`ConflictStrategy`].
pub struct StackedNER {
// Backends in query order; entities from earlier layers are already in place when later
// layers' entities are merged.
layers: Vec<Arc<dyn Model + Send + Sync>>,
// Overlap resolution policy.
strategy: ConflictStrategy,
// Display name of the form `stacked(<layer>+<layer>+...)`, built by the builder.
name: String,
// Lazily leaked `'static` copy of `name`, because `Model::name` returns `&'static str`.
name_static: std::sync::OnceLock<&'static str>,
}
/// Builder for [`StackedNER`]: add layers in query order, optionally pick a strategy, then
/// `build`/`try_build`. Starts with no layers and the default [`ConflictStrategy`].
#[derive(Default)]
pub struct StackedNERBuilder {
// Layers in the order they will be queried.
layers: Vec<Box<dyn Model + Send + Sync>>,
// Strategy handed to the built `StackedNER`.
strategy: ConflictStrategy,
}
impl StackedNERBuilder {
    /// Appends `model` as the next layer (layers are queried in insertion order).
    #[must_use]
    pub fn layer<M: Model + Send + Sync + 'static>(self, model: M) -> Self {
        self.layer_boxed(Box::new(model))
    }

    /// Appends an already-boxed model as the next layer.
    #[must_use]
    pub fn layer_boxed(self, model: Box<dyn Model + Send + Sync>) -> Self {
        let Self {
            mut layers,
            strategy,
        } = self;
        layers.push(model);
        Self { layers, strategy }
    }

    /// Sets the overlap [`ConflictStrategy`] used by the built stack.
    #[must_use]
    pub fn strategy(self, strategy: ConflictStrategy) -> Self {
        Self { strategy, ..self }
    }

    /// Finishes the builder.
    ///
    /// # Panics
    /// Panics if no layer was added; use [`Self::try_build`] to get an error instead.
    #[must_use]
    pub fn build(self) -> StackedNER {
        let outcome = self.try_build();
        outcome.expect(
            "StackedNER requires at least one layer. Use StackedNER::builder().layer(...).build()",
        )
    }

    /// Finishes the builder; the stack is named `stacked(<layer1>+<layer2>+...)`.
    ///
    /// # Errors
    /// Returns `Error::InvalidInput` if no layer was added.
    pub fn try_build(self) -> crate::Result<StackedNER> {
        let Self { layers, strategy } = self;
        if layers.is_empty() {
            return Err(crate::Error::InvalidInput(
                "StackedNER requires at least one layer".to_string(),
            ));
        }
        let layer_names: Vec<&str> = layers.iter().map(|layer| layer.name()).collect();
        let name = format!("stacked({})", layer_names.join("+"));
        Ok(StackedNER {
            layers: layers.into_iter().map(Arc::from).collect(),
            strategy,
            name,
            name_static: std::sync::OnceLock::new(),
        })
    }
}
impl StackedNER {
    /// Creates the default stack; identical to `StackedNER::default()`.
    #[must_use]
    pub fn new() -> Self {
        <Self as Default>::default()
    }

    /// Returns an empty [`StackedNERBuilder`] with the default [`ConflictStrategy`].
    #[must_use]
    pub fn builder() -> StackedNERBuilder {
        StackedNERBuilder::default()
    }

    /// Builds a stack whose layers are `layers`, queried in the given order.
    ///
    /// # Panics
    /// Panics if `layers` is empty.
    #[must_use]
    pub fn with_layers(layers: Vec<Box<dyn Model + Send + Sync>>) -> Self {
        layers
            .into_iter()
            .fold(Self::builder(), StackedNERBuilder::layer_boxed)
            .build()
    }

    /// Regex layer followed by the heuristic layer.
    ///
    /// NOTE: `_threshold` is currently ignored; the heuristic layer uses
    /// `HeuristicNER::new()` defaults.
    #[must_use]
    pub fn with_heuristic_threshold(_threshold: f64) -> Self {
        let regex = RegexNER::new();
        let heuristic = HeuristicNER::new();
        Self::builder().layer(regex).layer(heuristic).build()
    }

    /// A single regex layer.
    #[must_use]
    pub fn pattern_only() -> Self {
        let regex = RegexNER::new();
        Self::builder().layer(regex).build()
    }

    /// A single heuristic layer.
    #[must_use]
    pub fn heuristic_only() -> Self {
        let heuristic = HeuristicNER::new();
        Self::builder().layer(heuristic).build()
    }

    /// `ml_backend` first, then the regex and heuristic layers.
    #[must_use]
    pub fn with_ml_first(ml_backend: Box<dyn Model + Send + Sync>) -> Self {
        let builder = Self::builder().layer_boxed(ml_backend);
        builder
            .layer(RegexNER::new())
            .layer(HeuristicNER::new())
            .build()
    }

    /// Regex and heuristic layers first, then `ml_backend`.
    #[must_use]
    pub fn with_ml_fallback(ml_backend: Box<dyn Model + Send + Sync>) -> Self {
        let builder = Self::builder()
            .layer(RegexNER::new())
            .layer(HeuristicNER::new());
        builder.layer_boxed(ml_backend).build()
    }

    /// Number of layers in the stack (including any that may be skipped at runtime).
    #[must_use]
    pub fn num_layers(&self) -> usize {
        self.layers.len()
    }

    /// Layer names in query order.
    #[must_use]
    pub fn layer_names(&self) -> Vec<String> {
        self.layers
            .iter()
            .map(|layer| layer.name().to_owned())
            .collect()
    }

    /// The configured overlap strategy.
    #[must_use]
    pub fn strategy(&self) -> ConflictStrategy {
        self.strategy
    }

    /// Snapshot of the stack configuration.
    #[must_use]
    pub fn stats(&self) -> StackStats {
        let layer_count = self.layers.len();
        let layer_names = self.layer_names();
        StackStats {
            layer_count,
            strategy: self.strategy,
            layer_names,
        }
    }
}
/// Configuration snapshot returned by [`StackedNER::stats`].
#[derive(Debug, Clone)]
pub struct StackStats {
/// Number of configured layers.
pub layer_count: usize,
/// Overlap resolution strategy.
pub strategy: ConflictStrategy,
/// Layer names in query order.
pub layer_names: Vec<String>,
}
impl Default for StackedNER {
/// Builds the default stack (default `Priority` strategy), preferring ML backends when they load.
///
/// With the `onnx` feature enabled:
/// 1. Try BERT (`DEFAULT_BERT_ONNX_MODEL`) and NuNER (`DEFAULT_NUNER_MODEL`, threshold 0.9).
///    If at least one loads, the stack is `[BERT?, NuNER?, Regex, Heuristic]`.
/// 2. Otherwise try GLiNER (`DEFAULT_GLINER_MODEL`): `[GLiNER, Regex, Heuristic]`.
///
/// If no ML backend loads, or the feature is off, the stack is `[Regex, Heuristic]`.
/// Load errors are discarded (`.ok()` / `if let Ok`), never reported.
/// NOTE(review): constructing ML backends presumably reads (or downloads) model files, so this
/// may be slow — confirm before calling `default()` on hot paths.
fn default() -> Self {
#[cfg(feature = "onnx")]
{
use crate::backends::onnx::BertNEROnnx;
use crate::DEFAULT_BERT_ONNX_MODEL;
// Either model may fail to load; failures become `None`.
let bert = BertNEROnnx::new(DEFAULT_BERT_ONNX_MODEL).ok();
let nuner = crate::backends::nuner::NuNER::from_pretrained(crate::DEFAULT_NUNER_MODEL)
.map(|n| n.with_threshold(0.9))
.ok();
if bert.is_some() || nuner.is_some() {
let mut builder = Self::builder();
if let Some(b) = bert {
builder = builder.layer_boxed(Box::new(b));
}
if let Some(n) = nuner {
builder = builder.layer_boxed(Box::new(n));
}
// ML layers first so they take priority; rule-based layers fill gaps.
return builder
.layer(RegexNER::new())
.layer(HeuristicNER::new())
.build();
}
// Neither BERT nor NuNER is available: fall back to GLiNER.
use crate::{GLiNEROnnx, DEFAULT_GLINER_MODEL};
if let Ok(gliner) = GLiNEROnnx::new(DEFAULT_GLINER_MODEL) {
return Self::builder()
.layer_boxed(Box::new(gliner))
.layer(RegexNER::new())
.layer(HeuristicNER::new())
.build();
}
}
// Rule-based stack: always constructible without model files.
Self::builder()
.layer(RegexNER::new())
.layer(HeuristicNER::new())
.build()
}
}
impl Model for StackedNER {
    /// Runs every layer in order and merges their entities into one list.
    ///
    /// For each candidate a layer returns:
    /// 1. a span running past the end of `text` is clamped (offsets are character indices),
    ///    and empty/invalid spans are dropped;
    /// 2. missing provenance is filled in from the layer's name;
    /// 3. overlaps with already-accepted entities are settled by `self.strategy`.
    ///
    /// The merged list is sorted by span. Same-span/same-type duplicates are removed unless the
    /// strategy is [`ConflictStrategy::Union`]. The `heal_adjacent_spans`,
    /// `extend_person_spans` and `filter_title_words` passes then run.
    ///
    /// Layers whose name contains "nuner" (case-insensitive) are skipped when
    /// `text_may_need_nuner(text)` is false.
    ///
    /// # Errors
    /// A failing layer does not abort extraction: the other layers' results are returned and
    /// the failures are logged. Only when every layer that was actually run failed *and* no
    /// entity was produced is the last layer error returned.
    #[cfg_attr(feature = "production", tracing::instrument(skip(self, text), fields(text_len = text.len(), num_layers = self.layers.len())))]
    fn extract_entities(&self, text: &str, language: Option<Language>) -> Result<Vec<Entity>> {
        let mut entities: Vec<Entity> = Vec::with_capacity(16);
        let mut layer_errors = Vec::new();
        // Layers actually invoked. Skipped NuNER layers must not count; otherwise a run where
        // every attempted layer failed would be reported as an empty success.
        let mut attempted_layers = 0usize;
        let text_char_count = text.chars().count();
        let skip_nuner = !text_may_need_nuner(text);
        for layer in &self.layers {
            let layer_name = layer.name();
            if skip_nuner && layer_name.to_lowercase().contains("nuner") {
                log::debug!("StackedNER: skipping NuNER (text appears well-capitalized)");
                continue;
            }
            attempted_layers += 1;
            let layer_entities = match layer.extract_entities(text, language) {
                Ok(ents) => ents,
                Err(e) => {
                    layer_errors.push((layer_name.to_string(), e));
                    continue;
                }
            };
            for mut candidate in layer_entities {
                // Clamp spans that run past the end of the text (offsets are char indices).
                if candidate.end() > text_char_count {
                    log::debug!(
                        "StackedNER: Clamping entity end offset from {} to {} (text length: {})",
                        candidate.end(),
                        text_char_count,
                        text_char_count
                    );
                    candidate.set_end(text_char_count);
                    if candidate.start() < candidate.end() {
                        candidate.text = crate::offset::TextSpan::from_chars(
                            text,
                            candidate.start(),
                            candidate.end(),
                        )
                        .extract(text)
                        .to_string();
                    }
                }
                if candidate.start() >= candidate.end() || candidate.start() > text_char_count {
                    log::debug!(
                        "StackedNER: Skipping entity with invalid span: start={}, end={}, text_len={}",
                        candidate.start(),
                        candidate.end(),
                        text_char_count
                    );
                    continue;
                }
                if candidate.provenance.is_none() {
                    candidate.provenance = Some(crate::Provenance {
                        source: Cow::Borrowed(layer_name),
                        method: method_for_backend_name(layer_name),
                        pattern: None,
                        raw_confidence: Some(candidate.confidence),
                        model_version: None,
                        timestamp: None,
                    });
                }
                // Ascending positions of accepted entities that overlap the candidate.
                let overlapping_indices: Vec<usize> = entities
                    .iter()
                    .enumerate()
                    .filter_map(|(idx, e)| {
                        (candidate.end() > e.start() && candidate.start() < e.end())
                            .then_some(idx)
                    })
                    .collect();
                match overlapping_indices.as_slice() {
                    [] => entities.push(candidate),
                    [single] => {
                        let idx = *single;
                        match self.strategy.resolve(&entities[idx], &candidate) {
                            Resolution::KeepExisting => {}
                            Resolution::Replace => entities[idx] = candidate,
                            Resolution::KeepBoth => entities.push(candidate),
                        }
                    }
                    _ if self.strategy == ConflictStrategy::Union => entities.push(candidate),
                    overlapping => {
                        // Several accepted entities overlap the candidate: compare it with the
                        // one the strategy ranks highest (ties prefer the earliest).
                        let best_idx = overlapping
                            .iter()
                            .copied()
                            .max_by(|&a, &b| match self.strategy {
                                // Earliest-accepted wins: reversing the index order makes the
                                // smallest index the maximum. (`Union` is handled above.)
                                ConflictStrategy::Priority | ConflictStrategy::Union => b.cmp(&a),
                                ConflictStrategy::LongestSpan => {
                                    let len_a = entities[a].end() - entities[a].start();
                                    let len_b = entities[b].end() - entities[b].start();
                                    len_a.cmp(&len_b).then_with(|| b.cmp(&a))
                                }
                                ConflictStrategy::HighestConf => entities[a]
                                    .confidence
                                    .partial_cmp(&entities[b].confidence)
                                    .unwrap_or(std::cmp::Ordering::Equal)
                                    .then_with(|| b.cmp(&a)),
                            })
                            .unwrap_or(overlapping[0]);
                        match self.strategy.resolve(&entities[best_idx], &candidate) {
                            // The candidate lost. Under non-Union strategies accepted entities
                            // never overlap each other, so a rejected candidate has no grounds
                            // to evict any of them: leave the list untouched.
                            Resolution::KeepExisting => {}
                            Resolution::Replace => {
                                entities[best_idx] = candidate;
                                // Drop the remaining overlaps, highest index first so the
                                // indices still to be removed stay valid.
                                for &idx in overlapping.iter().rev() {
                                    if idx != best_idx {
                                        entities.remove(idx);
                                    }
                                }
                            }
                            Resolution::KeepBoth => entities.push(candidate),
                        }
                    }
                }
            }
        }
        // Deterministic order: span first, then type label, source and text as tie-breakers.
        entities.sort_unstable_by(|a, b| {
            let a_ty = a.entity_type.as_label();
            let b_ty = b.entity_type.as_label();
            let a_src = a
                .provenance
                .as_ref()
                .map(|p| p.source.as_ref())
                .unwrap_or("");
            let b_src = b
                .provenance
                .as_ref()
                .map(|p| p.source.as_ref())
                .unwrap_or("");
            (a.start(), a.end(), a_ty, a_src, a.text.as_str()).cmp(&(
                b.start(),
                b.end(),
                b_ty,
                b_src,
                b.text.as_str(),
            ))
        });
        if self.strategy != ConflictStrategy::Union {
            entities.dedup_by(|a, b| {
                a.start() == b.start() && a.end() == b.end() && a.entity_type == b.entity_type
            });
        }
        // Surface an error only when nothing was found and every layer that ran failed.
        if entities.is_empty() && layer_errors.len() == attempted_layers {
            if let Some((_, last_err)) = layer_errors.pop() {
                return Err(last_err);
            }
        }
        if !layer_errors.is_empty() {
            log::warn!(
                "StackedNER: Some layers failed but returning partial results. Errors: {:?}",
                layer_errors
                    .iter()
                    .map(|(n, e)| format!("{n}: {e}"))
                    .collect::<Vec<_>>()
            );
        }
        heal_adjacent_spans(text, &mut entities);
        extend_person_spans(text, &mut entities);
        filter_title_words(&mut entities);
        for entity in &entities {
            if entity.start() >= entity.end() {
                log::warn!(
                    "StackedNER: Invalid entity span detected: start={}, end={}, text={:?}, type={:?}",
                    entity.start(),
                    entity.end(),
                    entity.text,
                    entity.entity_type
                );
            }
        }
        Ok(entities)
    }

    /// Union of all layers' supported types, sorted by `Debug` rendering and de-duplicated.
    fn supported_types(&self) -> Vec<EntityType> {
        let mut types: Vec<EntityType> = self
            .layers
            .iter()
            .flat_map(|layer| layer.supported_types())
            .collect();
        // Stable sort; each element's `Debug` key is formatted once instead of per comparison.
        types.sort_by_cached_key(|t| format!("{t:?}"));
        types.dedup();
        types
    }

    /// Available if at least one layer is available.
    fn is_available(&self) -> bool {
        self.layers.iter().any(|l| l.is_available())
    }

    /// Returns `stacked(<layer>+...)`.
    ///
    /// `Model::name` must return `&'static str`, so the name is leaked once per instance, on
    /// first call.
    fn name(&self) -> &'static str {
        self.name_static
            .get_or_init(|| Box::leak(self.name.clone().into_boxed_str()))
    }

    fn description(&self) -> &'static str {
        "Stacked NER (multi-backend composition)"
    }

    fn capabilities(&self) -> crate::ModelCapabilities {
        crate::ModelCapabilities::default()
    }
}
fn heal_adjacent_spans(text: &str, entities: &mut Vec<Entity>) {
if entities.len() < 2 {
return;
}
let mut merged = Vec::with_capacity(entities.len());
let mut current = entities[0].clone();
for next in entities.iter().skip(1) {
let gap = next.start().saturating_sub(current.end());
let truly_adjacent = next.start() >= current.end() && next.start() > current.start();
let same_type = current.entity_type == next.entity_type;
let gap_ok = truly_adjacent
&& gap <= 1
&& (gap == 0
|| text
.chars()
.nth(current.end())
.is_some_and(|c| c.is_alphanumeric() || c == ' '));
let next_len = next.end().saturating_sub(next.start());
let is_fragment = truly_adjacent && gap == 0 && next_len <= 3;
if (same_type && gap_ok) || is_fragment {
current.set_end(next.end());
current.text = text
.chars()
.skip(current.start())
.take(current.end() - current.start())
.collect();
if next.confidence > current.confidence {
current.confidence = next.confidence;
}
} else {
merged.push(current);
current = next.clone();
}
}
merged.push(current);
*entities = merged;
}
/// Greedily extends each `Person` entity over the capitalized words that follow it.
///
/// Offsets are character indices into `text`. After a person span, whitespace (at least one
/// character) is skipped. If the next word meets all of these conditions, the span and its
/// `text` are extended to cover the word:
/// - it starts with an uppercase letter;
/// - it is not in `NON_NAME_WORDS`;
/// - it does not overlap any entity span as it was on entry.
///
/// A word ends at whitespace or one of `, . ; : ( )`, and its trailing non-alphanumeric
/// characters (`!`, `?`, quotes, ...) are not claimed. Passes repeat until nothing changes,
/// so names grow one word per pass.
fn extend_person_spans(text: &str, entities: &mut [Entity]) {
    const NON_NAME_WORDS: &[&str] = &[
        "the",
        "a",
        "an",
        "this",
        "that",
        "these",
        "those",
        "it",
        "he",
        "she",
        "we",
        "they",
        "in",
        "on",
        "at",
        "to",
        "for",
        "from",
        "by",
        "with",
        "and",
        "but",
        "or",
        "so",
        "if",
        "is",
        "are",
        "was",
        "were",
        "be",
        "been",
        "have",
        "has",
        "had",
        "what",
        "where",
        "when",
        "who",
        "why",
        "how",
        "here",
        "about",
        "more",
        "next",
        "back",
        "home",
        "however",
        "meanwhile",
        "furthermore",
        "moreover",
        "therefore",
        "although",
        "indeed",
        "perhaps",
        "certainly",
        "no",
        "yes",
        "some",
        "many",
        "each",
        "every",
        "both",
        "all",
        "few",
        "several",
        "other",
        "another",
        "monday",
        "tuesday",
        "wednesday",
        "thursday",
        "friday",
        "saturday",
        "sunday",
        "january",
        "february",
        "march",
        "april",
        "may",
        "june",
        "july",
        "august",
        "september",
        "october",
        "november",
        "december",
        "ceo",
        "cto",
        "cfo",
        "coo",
        "vp",
        "president",
        "chairman",
        "director",
        "manager",
        "secretary",
        "minister",
        "kanzler",
        "bundeskanzler",
        "phone",
        "fax",
        "mobile",
        "address",
        "website",
        "name",
        "company",
        "contact",
    ];
    // Characters (besides whitespace) that terminate a candidate word.
    const WORD_DELIMITERS: &[char] = &[',', '.', ';', ':', '(', ')'];

    let text_chars: Vec<char> = text.chars().collect();
    let text_len = text_chars.len();
    // Spans as they were on entry; extensions never claim characters inside these.
    let occupied: Vec<(usize, usize)> = entities.iter().map(|e| (e.start(), e.end())).collect();
    let mut changed = true;
    while changed {
        changed = false;
        for entity in entities.iter_mut() {
            if entity.entity_type != EntityType::Person {
                continue;
            }
            let end = entity.end();
            if end >= text_len {
                continue;
            }
            let mut pos = end;
            while pos < text_len && text_chars[pos].is_whitespace() {
                pos += 1;
            }
            // Require at least one whitespace character between the name and the next word.
            if pos >= text_len || pos == end {
                continue;
            }
            if !text_chars[pos].is_uppercase() {
                continue;
            }
            let word_start = pos;
            while pos < text_len
                && !text_chars[pos].is_whitespace()
                && !WORD_DELIMITERS.contains(&text_chars[pos])
            {
                pos += 1;
            }
            // Leave trailing punctuation such as `!`, `?` or a closing quote out of the span.
            // Claiming it would put it inside the name and let the next pass extend the span
            // across a sentence boundary ("John Smith! Today ...").
            let mut word_end = pos;
            while word_end > word_start + 1 && !text_chars[word_end - 1].is_alphanumeric() {
                word_end -= 1;
            }
            let overlaps_existing = occupied
                .iter()
                .any(|&(s, e)| word_start < e && word_end > s);
            if overlaps_existing {
                continue;
            }
            let word_lower = text_chars[word_start..word_end]
                .iter()
                .collect::<String>()
                .to_lowercase();
            if NON_NAME_WORDS.contains(&word_lower.as_str()) {
                continue;
            }
            entity.set_end(word_end);
            entity.text = text_chars[entity.start()..word_end].iter().collect();
            changed = true;
        }
    }
}
/// Drops single-word entities that are really titles or generic nouns.
///
/// Removes `Organization`/`Custom` entities whose whole text is a title word (e.g.
/// "president", "kanzler"). Also removes `Person` entities whose whole text is a generic
/// group noun (e.g. "police", "officials"). Matching is case-insensitive. Entities whose text
/// contains a space are always kept.
fn filter_title_words(entities: &mut Vec<Entity>) {
    const TITLE_WORDS: &[&str] = &[
        "bundeskanzler",
        "bundeskanzlerin",
        "kanzler",
        "kanzlerin",
        "bundespraesident",
        "bundespraesidentin",
        "buergermeister",
        "buergermeisterin",
        "president",
        "chairman",
        "chairwoman",
        "director",
        "secretary",
        "minister",
        "chancellor",
        "governor",
        "senator",
        "congressman",
        "congresswoman",
        "mayor",
    ];
    const COMMON_NOUNS_NOT_PER: &[&str] = &[
        "death",
        "police",
        "military",
        "authorities",
        "officials",
        "analysts",
        "scientists",
        "researchers",
        "experts",
        "voters",
        "residents",
    ];
    entities.retain(|entity| {
        // Multi-word spans are never a bare title or generic noun.
        if entity.text.contains(' ') {
            return true;
        }
        let lowered = entity.text.to_lowercase();
        let word = lowered.as_str();
        let bare_title = matches!(
            entity.entity_type,
            EntityType::Organization | EntityType::Custom { .. }
        ) && TITLE_WORDS.contains(&word);
        let generic_person = matches!(entity.entity_type, EntityType::Person)
            && COMMON_NOUNS_NOT_PER.contains(&word);
        !(bare_title || generic_person)
    });
}
/// Heuristic gate for NuNER layers: returns `true` when NuNER should run on `text`.
///
/// Inputs shorter than 30 bytes always return `true`. Otherwise the text is split on
/// whitespace and each word has surrounding non-alphanumerics trimmed. The function returns
/// `true` as soon as two consecutive words are lowercase-initial, longer than 2 bytes, and
/// not in `STOPWORDS`; any other word resets the run. If no such pair occurs, the text is
/// treated as well-capitalized and `false` is returned.
fn text_may_need_nuner(text: &str) -> bool {
    if text.len() < 30 {
        return true;
    }
    const STOPWORDS: &[&str] = &[
        "the", "a", "an", "and", "or", "but", "in", "on", "at", "to", "for", "of", "with", "by",
        "from", "is", "are", "was", "were", "be", "been", "has", "have", "had", "that", "this",
        "it", "he", "she", "we", "they", "not", "would", "could", "should", "will", "can", "may",
        "also", "its", "his", "her", "our", "their", "who", "which", "what", "when", "where",
        "how", "than", "then", "into", "over", "about", "after", "before", "between", "under",
        "up", "out", "new", "said", "told", "expects",
    ];
    text.split_whitespace()
        .scan(0u32, |run, word| {
            let core = word.trim_matches(|c: char| !c.is_alphanumeric());
            let lowercase_content_word = core.len() > 2
                && core
                    .chars()
                    .next()
                    .is_some_and(|c| c.is_alphabetic() && c.is_lowercase())
                && !STOPWORDS.contains(&core.to_lowercase().as_str());
            *run = if lowercase_content_word { *run + 1 } else { 0 };
            Some(*run)
        })
        .any(|run| run >= 2)
}
#[cfg(test)]
mod tests;