use crate::profile::ContextTerms;
use crate::types::{Detection, DetectionExplanation, EntityType, NlpArtifacts};
use std::collections::{HashMap, HashSet};
/// Hook for post-processing a batch of detections in place.
///
/// Implementors must be `Send + Sync`.
pub trait ContextEnhancer: Send + Sync {
/// Receives the detections as a mutable slice, together with a `text` and
/// the NLP `artifacts`, and may modify the detections in place.
fn enhance(&self, detections: &mut [Detection], text: &str, artifacts: &NlpArtifacts);
}
/// Enhancer configured with one [`ContextTerms`] entry per [`EntityType`].
#[derive(Clone, Debug)]
pub struct LemmaContextEnhancer {
/// Context-term configuration, keyed by entity type.
context: HashMap<EntityType, ContextTerms>,
}
impl LemmaContextEnhancer {
pub fn new(context: HashMap<EntityType, ContextTerms>) -> Self {
Self { context }
}
}
impl ContextEnhancer for LemmaContextEnhancer {
    /// Boosts the score of every detection that has configured context terms
    /// nearby.
    ///
    /// For each detection whose entity type has an entry in the context table,
    /// the surrounding tokens are searched via `find_context_terms`. When at
    /// least one term is found, the score is raised by the entry's `boost`
    /// (clamped to `0.0..=1.0`) and the explanation is replaced with a
    /// `ContextBoost` recording the previous score, the boost and the matched
    /// terms. All other detections are left untouched. The `_text` argument is
    /// not used by this implementation.
    fn enhance(&self, detections: &mut [Detection], _text: &str, artifacts: &NlpArtifacts) {
        // Lemma matching is only attempted when the artifacts advertise it.
        let lemma_enabled = artifacts.capabilities.lemma;

        for hit in detections.iter_mut() {
            if let Some(entity_terms) = self.context.get(&hit.entity_type) {
                let found =
                    find_context_terms(artifacts, hit.start, hit.end, entity_terms, lemma_enabled);

                if !found.is_empty() {
                    let original_score = hit.score;
                    let bonus = entity_terms.boost;

                    hit.score = (original_score + bonus).clamp(0.0, 1.0);
                    hit.explanation = DetectionExplanation::ContextBoost {
                        base: original_score,
                        boost: bonus,
                        matched_terms: found,
                    };
                }
            }
        }
    }
}
/// Returns the configured context terms that occur near a detection.
///
/// The first token overlapping the half-open span `start..end` anchors a
/// window of `terms.window_tokens` tokens on each side (the anchor token and
/// any other tokens inside the window, including those belonging to the
/// detection itself, are inspected). Note that the window is measured from
/// that first token only, so a detection spanning many tokens has less
/// right-hand context than left-hand context.
///
/// Comparison is case-insensitive: both the configured terms and the token
/// candidates are lowercased. When `use_lemma` is true a token's lemma is used
/// if present, falling back to its text; otherwise the text is always used.
///
/// Returns the matched terms (lowercased) in order of first occurrence, each
/// at most once. Returns an empty vector when no token overlaps the span.
fn find_context_terms(
    artifacts: &NlpArtifacts,
    start: usize,
    end: usize,
    terms: &ContextTerms,
    use_lemma: bool,
) -> Vec<String> {
    let tokens = &artifacts.tokens;

    // Locate the anchor before doing any allocation: with no overlapping
    // token there is nothing to match against.
    let anchor = match tokens
        .iter()
        .position(|token| token.start < end && token.end > start)
    {
        Some(idx) => idx,
        None => return Vec::new(),
    };

    let window = terms.window_tokens;
    let first = anchor.saturating_sub(window);
    // Saturating arithmetic keeps a very large window (e.g. `usize::MAX` used
    // as "whole document") from overflowing; the result is then capped to the
    // token count.
    let last = anchor
        .saturating_add(window)
        .saturating_add(1)
        .min(tokens.len());

    // Terms still waiting to be seen; removing on match de-duplicates results.
    let mut remaining: HashSet<String> =
        terms.terms.iter().map(|t| t.to_lowercase()).collect();

    let mut matched = Vec::new();
    for token in &tokens[first..last] {
        if remaining.is_empty() {
            // Every configured term has already been found.
            break;
        }
        let candidate = if use_lemma {
            token
                .lemma
                .as_ref()
                .map(|s| s.to_lowercase())
                .unwrap_or_else(|| token.text.to_lowercase())
        } else {
            token.text.to_lowercase()
        };
        if remaining.remove(&candidate) {
            matched.push(candidate);
        }
    }
    matched
}