use crate::{Entity, EntityCategory, EntityType, Language, Model, Result};
use std::collections::HashMap;
/// Configuration for [`HeuristicCrfNER`].
///
/// NOTE(review): none of these fields are consulted by the heuristic
/// emission/Viterbi pipeline visible in this file; they appear to describe a
/// learned model (e.g. an ONNX network) — confirm intended use before relying
/// on them.
#[derive(Debug, Clone)]
pub struct HeuristicCrfConfig {
/// Hidden representation width (informational for the heuristic path).
pub hidden_size: usize,
/// Number of encoder layers (informational for the heuristic path).
pub num_layers: usize,
/// Dropout probability (informational for the heuristic path).
pub dropout: f32,
/// Whether character-level embeddings would be used.
pub use_char_embeddings: bool,
/// Maximum sequence length in tokens.
pub max_seq_len: usize,
}
impl Default for HeuristicCrfConfig {
/// Default configuration: 256 hidden units, 2 layers, 0.5 dropout,
/// character embeddings enabled, 512-token sequence cap.
fn default() -> Self {
Self {
hidden_size: 256,
num_layers: 2,
dropout: 0.5,
use_char_embeddings: true,
max_seq_len: 512,
}
}
}
/// Named-entity recognizer combining hand-written emission heuristics
/// (gazetteers, capitalization, word shape) with CRF-style Viterbi decoding
/// over a fixed BIO label set.
#[derive(Debug)]
pub struct HeuristicCrfNER {
/// Model configuration.
config: HeuristicCrfConfig,
/// BIO label inventory; vector index doubles as the emission/transition index.
labels: Vec<String>,
/// Reverse lookup from label string to its index in `labels`.
label_to_idx: HashMap<String, usize>,
/// `transitions[from][to]` log-score added when moving between labels.
transitions: Vec<Vec<f64>>,
/// Word-to-id vocabulary; empty unless populated externally.
vocab: HashMap<String, usize>,
/// Optional ONNX runtime session (only with the `onnx` feature).
#[cfg(feature = "onnx")]
session: Option<ort::session::Session>,
}
impl HeuristicCrfNER {
/// Creates a model with the default [`HeuristicCrfConfig`].
#[must_use]
pub fn new() -> Self {
    // Delegate to the configurable constructor with default settings.
    let config = HeuristicCrfConfig::default();
    Self::with_config(config)
}
/// Builds a model with the given configuration.
///
/// Initializes the CoNLL-style BIO label set and a transition matrix that
/// rewards valid BIO sequences (`B-X`/`I-X` -> `I-X` scores 1.0) and strongly
/// penalizes invalid continuations (anything else -> `I-X` scores -10.0).
/// Transitions into `O` or any `B-*` label are left neutral (0.0).
#[must_use]
pub fn with_config(config: HeuristicCrfConfig) -> Self {
    let labels = vec![
        "O".to_string(),
        "B-PER".to_string(),
        "I-PER".to_string(),
        "B-ORG".to_string(),
        "I-ORG".to_string(),
        "B-LOC".to_string(),
        "I-LOC".to_string(),
        "B-MISC".to_string(),
        "I-MISC".to_string(),
    ];
    let label_to_idx: HashMap<String, usize> = labels
        .iter()
        .enumerate()
        .map(|(i, l)| (l.clone(), i))
        .collect();
    let n = labels.len();
    // Matrix starts all-neutral; only transitions into I-<type> are constrained.
    let mut transitions = vec![vec![0.0; n]; n];
    for (i, from_label) in labels.iter().enumerate() {
        for (j, to_label) in labels.iter().enumerate() {
            if let Some(entity_type) = to_label.strip_prefix("I-") {
                // Compare label suffixes directly instead of building
                // "B-<type>" / "I-<type>" strings per cell — the original
                // performed 2 * n^2 `format!` allocations for a constant
                // matrix.
                let valid = from_label.strip_prefix("B-") == Some(entity_type)
                    || from_label.strip_prefix("I-") == Some(entity_type);
                transitions[i][j] = if valid { 1.0 } else { -10.0 };
            }
        }
    }
    Self {
        config,
        labels,
        label_to_idx,
        transitions,
        vocab: HashMap::new(),
        #[cfg(feature = "onnx")]
        session: None,
    }
}
/// Loads an ONNX session from `model_path` and attaches it to a freshly
/// constructed model (heuristic defaults otherwise unchanged).
///
/// # Errors
///
/// Returns a model-init error if the session builder cannot be created, the
/// optimization level cannot be set, or the model file fails to load.
///
/// NOTE(review): the stored session is not consulted anywhere in this file's
/// inference path — confirm whether ONNX inference is wired up elsewhere or
/// still pending.
#[cfg(feature = "onnx")]
pub fn from_onnx(model_path: &str) -> Result<Self> {
use crate::Error;
use ort::session::{builder::GraphOptimizationLevel, Session};
// Build the session eagerly with maximal graph optimization.
let session = Session::builder()
.map_err(|e| Error::model_init(format!("Failed to create session builder: {}", e)))?
.with_optimization_level(GraphOptimizationLevel::Level3)
.map_err(|e| Error::model_init(format!("Failed to set optimization level: {}", e)))?
.commit_from_file(model_path)
.map_err(|e| Error::model_init(format!("Failed to load ONNX model: {}", e)))?;
let mut model = Self::new();
model.session = Some(session);
Ok(model)
}
/// Returns the model configuration.
#[must_use]
pub fn config(&self) -> &HeuristicCrfConfig {
&self.config
}
/// Returns the word-to-index vocabulary (empty unless populated externally).
#[must_use]
pub fn vocab(&self) -> &HashMap<String, usize> {
&self.vocab
}
/// Looks up `word` in the vocabulary, returning its index if present.
#[must_use]
pub fn vocab_lookup(&self, word: &str) -> Option<usize> {
self.vocab.get(word).copied()
}
/// Returns the BIO label inventory in emission/transition index order.
#[must_use]
pub fn labels(&self) -> &[String] {
&self.labels
}
/// Splits `text` into whitespace-delimited tokens borrowed from the input.
fn tokenize(text: &str) -> Vec<&str> {
    let mut tokens: Vec<&str> = Vec::new();
    tokens.extend(text.split_whitespace());
    tokens
}
/// Computes per-token emission scores over the label set from surface
/// heuristics: small gazetteers (person/org/location words), capitalization,
/// all-caps acronym shape, corporate suffixes, honorifics, and left context.
///
/// Scores are additive. The "O" (outside) label gets a constant prior of 1.5
/// so tokens stay outside entities unless evidence accumulates.
fn get_emissions(&self, tokens: &[&str]) -> Vec<Vec<f64>> {
    let n_labels = self.labels.len();
    let mut emissions = vec![vec![0.0; n_labels]; tokens.len()];
    // Lowercased gazetteers of common first names / honorifics,
    // organization words and location words.
    const PERSON_NAMES: &[&str] = &[
        "john", "mary", "james", "david", "michael", "robert", "william",
        "richard", "sarah", "jennifer", "elizabeth", "lisa", "marie", "jane",
        "emily", "anna", "barack", "donald", "joe", "george", "bill",
        "hillary", "satya", "jeff", "mr", "mrs", "ms", "dr", "prof", "sir",
        "lord", "president", "ceo",
    ];
    const ORG_NAMES: &[&str] = &[
        "google", "apple", "microsoft", "amazon", "facebook", "meta", "tesla",
        "ibm", "intel", "nvidia", "oracle", "cisco", "adobe", "netflix",
        "uber", "university", "institute", "corporation", "company", "inc",
        "corp", "ltd", "llc", "foundation", "association", "organization",
        "department", "agency", "fbi", "cia", "nsa", "nasa", "un", "nato",
        "who", "imf", "eu", "usa",
    ];
    const LOC_NAMES: &[&str] = &[
        "new", "york", "california", "texas", "florida", "london", "paris",
        "berlin", "tokyo", "beijing", "moscow", "washington", "chicago",
        "boston", "seattle", "san", "francisco", "los", "angeles", "las",
        "vegas", "united", "states", "america", "china", "russia", "germany",
        "france", "japan", "india", "brazil", "city", "county", "state",
        "country", "river", "mountain", "lake", "ocean",
    ];
    // Hoist label-index lookups out of the token loop: the original indexed
    // the HashMap (a panicking operation) up to a dozen times per token.
    let b_per = self.label_to_idx["B-PER"];
    let i_per = self.label_to_idx["I-PER"];
    let b_org = self.label_to_idx["B-ORG"];
    let i_org = self.label_to_idx["I-ORG"];
    let b_loc = self.label_to_idx["B-LOC"];
    let i_loc = self.label_to_idx["I-LOC"];
    for (i, token) in tokens.iter().enumerate() {
        let lower = token.to_lowercase();
        let is_capitalized = token.chars().next().is_some_and(|c| c.is_uppercase());
        // "All caps": every alphabetic char uppercase, more than one byte long.
        let is_all_caps = token
            .chars()
            .all(|c| c.is_uppercase() || !c.is_alphabetic())
            && token.len() > 1;
        let has_digit = token.chars().any(|c| c.is_ascii_digit());
        let is_first = i == 0;
        // Prior for the "O" label (index 0 in the label inventory).
        emissions[i][0] = 1.5;
        // Gazetteer hits: strong begin signal, weaker continuation signal.
        if PERSON_NAMES.contains(&lower.as_str()) {
            emissions[i][b_per] += 2.0;
            emissions[i][i_per] += 1.0;
        }
        if ORG_NAMES.contains(&lower.as_str()) {
            emissions[i][b_org] += 2.0;
            emissions[i][i_org] += 1.0;
        }
        if LOC_NAMES.contains(&lower.as_str()) {
            emissions[i][b_loc] += 2.0;
            emissions[i][i_loc] += 1.0;
        }
        // Mid-sentence capitalization is weak evidence for any entity start;
        // sentence-initial capitalization is ignored as too ambiguous.
        if is_capitalized && !has_digit && !is_first {
            emissions[i][b_per] += 0.8;
            emissions[i][b_org] += 0.6;
            emissions[i][b_loc] += 0.5;
        }
        // Corporate suffixes ("Acme Inc.", "Foo LLC", ...).
        if lower.ends_with("inc.")
            || lower.ends_with("corp.")
            || lower.ends_with("ltd.")
            || lower.ends_with("llc")
            || lower.ends_with("co.")
        {
            emissions[i][b_org] += 1.5;
            emissions[i][i_org] += 1.0;
        }
        // Short all-caps tokens look like org acronyms (FBI, NASA, ...).
        if is_all_caps && token.len() >= 2 && token.len() <= 5 && !has_digit {
            emissions[i][b_org] += 1.2;
        }
        // Honorifics strongly suggest a person name.
        if ["mr.", "mrs.", "ms.", "dr.", "prof."].contains(&lower.as_str()) {
            emissions[i][b_per] += 1.5;
        }
        // "the <Capitalized>" often introduces an organization or place.
        if i > 0 && tokens[i - 1].to_lowercase() == "the" && is_capitalized {
            emissions[i][b_org] += 0.5;
            emissions[i][b_loc] += 0.3;
        }
        // Two capitalized tokens in a row suggest a multi-word entity
        // continuation. (The original also tested `!is_first` here, which is
        // always true when i > 0 — dropped as dead code.)
        if i > 0 {
            let prev_cap = tokens[i - 1]
                .chars()
                .next()
                .is_some_and(|c| c.is_uppercase());
            if prev_cap && is_capitalized {
                emissions[i][i_per] += 0.6;
                emissions[i][i_org] += 0.6;
                emissions[i][i_loc] += 0.4;
            }
        }
    }
    emissions
}
/// Finds the highest-scoring label sequence for `emissions` via Viterbi
/// decoding over `self.transitions`.
///
/// Returns one label index per position; an empty input yields an empty path.
fn viterbi_decode(&self, emissions: &[Vec<f64>]) -> Vec<usize> {
    if emissions.is_empty() {
        return vec![];
    }
    let n = emissions.len();
    let m = self.labels.len();
    // scores[i][j]: best score of any path ending at position i with label j.
    let mut scores = vec![vec![f64::NEG_INFINITY; m]; n];
    // backpointers[i][j]: the label at position i-1 on that best path.
    let mut backpointers = vec![vec![0usize; m]; n];
    for j in 0..m {
        scores[0][j] = emissions[0][j];
    }
    for i in 1..n {
        for j in 0..m {
            // The emission term does not depend on the previous label k, so
            // add it once after the max instead of inside the inner loop —
            // the original recomputed it m times per cell. Ties resolve to
            // the smallest k, exactly as before.
            let (best_prev, best_score) = (0..m)
                .map(|k| (k, scores[i - 1][k] + self.transitions[k][j]))
                .fold((0usize, f64::NEG_INFINITY), |best, cand| {
                    if cand.1 > best.1 {
                        cand
                    } else {
                        best
                    }
                });
            scores[i][j] = best_score + emissions[i][j];
            backpointers[i][j] = best_prev;
        }
    }
    // Pick the best final label, then walk backpointers to rebuild the path.
    let mut path = vec![0usize; n];
    let mut best_final = 0;
    let mut best_score = f64::NEG_INFINITY;
    for (j, &score) in scores[n - 1].iter().enumerate() {
        if score > best_score {
            best_score = score;
            best_final = j;
        }
    }
    path[n - 1] = best_final;
    for i in (0..n - 1).rev() {
        path[i] = backpointers[i + 1][path[i + 1]];
    }
    path
}
/// Converts a Viterbi label-index sequence back into `Entity` spans over `text`.
///
/// Walks tokens and BIO labels in lockstep: opens a span on `B-*`, extends it
/// on `I-*`, and flushes the open span on `O`, on a new `B-*`, and at end of
/// input.
///
/// NOTE(review): an `I-X` label extends the current entity even when X
/// differs from the open entity's type (e.g. `B-PER` followed by `I-ORG`).
/// The -10.0 transition penalty makes this unlikely but not impossible —
/// confirm whether mixed continuations should instead close the span.
fn labels_to_entities(
&self,
text: &str,
tokens: &[&str],
label_indices: &[usize],
) -> Vec<Entity> {
use crate::offset::SpanConverter;
let converter = SpanConverter::new(text);
let mut entities = Vec::new();
// Byte (start, end) for each token, needed to map tokens back to text spans.
let token_positions: Vec<(usize, usize)> = Self::calculate_token_positions(text, tokens);
// Open span: (first token idx, last token idx, type, collected token texts).
let mut current_entity: Option<(usize, usize, EntityType, Vec<&str>)> = None;
for (i, (&label_idx, &token)) in label_indices.iter().zip(tokens.iter()).enumerate() {
let label = &self.labels[label_idx];
if let Some(entity_suffix) = label.strip_prefix("B-") {
// A new entity begins: flush whatever span was open.
if let Some((start_token_idx, end_token_idx, entity_type, words)) =
current_entity.take()
{
Self::push_entity_from_positions(
&converter,
&token_positions,
start_token_idx,
end_token_idx,
&words,
entity_type,
&mut entities,
);
}
// Map the BIO suffix to an entity type; unknown suffixes become
// custom Misc-category types (covers "MISC").
let entity_type = match entity_suffix {
"PER" => EntityType::Person,
"ORG" => EntityType::Organization,
"LOC" => EntityType::Location,
other => EntityType::custom(other, EntityCategory::Misc),
};
current_entity = Some((i, i, entity_type, vec![token]));
} else if label.starts_with("I-") && current_entity.is_some() {
// Continuation: extend the open span by this token.
if let Some((_, ref mut end_idx, _, ref mut words)) = current_entity {
words.push(token);
*end_idx = i;
}
} else {
// "O" label (or dangling "I-" with nothing open): flush any open span.
if let Some((start_token_idx, end_token_idx, entity_type, words)) =
current_entity.take()
{
Self::push_entity_from_positions(
&converter,
&token_positions,
start_token_idx,
end_token_idx,
&words,
entity_type,
&mut entities,
);
}
}
}
// Flush a span still open at end of input.
if let Some((start_token_idx, end_token_idx, entity_type, words)) = current_entity.take() {
Self::push_entity_from_positions(
&converter,
&token_positions,
start_token_idx,
end_token_idx,
&words,
entity_type,
&mut entities,
);
}
entities
}
/// Maps each token to its (start, end) byte range within `text`.
///
/// Scans left to right with a moving cursor so repeated tokens resolve to
/// successive occurrences. A token that cannot be located yields an empty
/// `(cursor, cursor)` range.
fn calculate_token_positions(text: &str, tokens: &[&str]) -> Vec<(usize, usize)> {
    let mut cursor = 0usize;
    tokens
        .iter()
        .map(|token| match text[cursor..].find(token) {
            Some(offset) => {
                let start = cursor + offset;
                let end = start + token.len();
                cursor = end;
                (start, end)
            }
            // Not found past the cursor: record an empty range in place.
            None => (cursor, cursor),
        })
        .collect()
}
/// Builds an `Entity` from the byte ranges of its first and last tokens and
/// appends it to `entities`. Out-of-range token indices are ignored silently.
fn push_entity_from_positions(
    converter: &crate::offset::SpanConverter,
    positions: &[(usize, usize)],
    start_token_idx: usize,
    end_token_idx: usize,
    words: &[&str],
    entity_type: EntityType,
    entities: &mut Vec<Entity>,
) {
    // Checked lookups: bail out if either token index is out of range.
    let (Some(&(byte_start, _)), Some(&(_, byte_end))) =
        (positions.get(start_token_idx), positions.get(end_token_idx))
    else {
        return;
    };
    // Entity spans are expressed in character offsets, not bytes.
    let char_start = converter.byte_to_char(byte_start);
    let char_end = converter.byte_to_char(byte_end);
    entities.push(Entity::new(
        words.join(" "),
        entity_type,
        char_start,
        char_end,
        // Fixed confidence for heuristic matches.
        0.75,
    ));
}
}
impl Default for HeuristicCrfNER {
/// Equivalent to [`HeuristicCrfNER::new`].
fn default() -> Self {
Self::new()
}
}
impl Model for HeuristicCrfNER {
/// Stable model identifier.
fn name(&self) -> &'static str {
"heuristic-crf"
}
/// One-line human-readable description.
fn description(&self) -> &'static str {
"CRF sequence labeling with heuristic emission features (capitalization, word shape, gazetteer)"
}
/// Runs the full pipeline: whitespace tokenization -> heuristic emissions ->
/// Viterbi decoding -> BIO-to-span conversion. The language hint is ignored.
///
/// # Errors
///
/// Currently always returns `Ok`; the `Result` return mirrors the trait.
fn extract_entities(&self, text: &str, _language: Option<Language>) -> Result<Vec<Entity>> {
// Fast path: blank or token-free input yields no entities.
if text.trim().is_empty() {
return Ok(vec![]);
}
let tokens = Self::tokenize(text);
if tokens.is_empty() {
return Ok(vec![]);
}
let emissions = self.get_emissions(&tokens);
let label_indices = self.viterbi_decode(&emissions);
let entities = self.labels_to_entities(text, &tokens, &label_indices);
Ok(entities)
}
/// Entity types this model can emit.
fn supported_types(&self) -> Vec<EntityType> {
vec![
EntityType::Person,
EntityType::Organization,
EntityType::Location,
EntityType::custom("MISC", EntityCategory::Misc),
]
}
/// Always available: no weights or external resources are required.
fn is_available(&self) -> bool {
true }
/// Default capability set.
fn capabilities(&self) -> crate::ModelCapabilities {
crate::ModelCapabilities::default()
}
}
#[cfg(test)]
mod tests;