use crate::error::{InferenceError, Result};
use ort::inputs;
use ort::session::builder::GraphOptimizationLevel;
use ort::session::Session;
use ort::value::Tensor;
use parking_lot::Mutex;
use regex::Regex;
use std::collections::HashMap;
use std::path::PathBuf;
use std::sync::Arc;
use tokenizers::Tokenizer;
use tracing::{debug, info, instrument, warn};
/// A single entity found in a piece of text, produced either by the regex rules
/// or by the GLiNER model.
#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)]
pub struct ExtractedEntity {
    /// Entity label, e.g. `"email"`, `"date"`, or a normalised GLiNER label.
    pub entity_type: String,
    /// Matched surface text, trimmed of surrounding whitespace.
    pub value: String,
    /// Confidence in `[0, 1]`; rule-based matches always use `1.0`.
    pub score: f32,
    /// Byte offset of the start of the match in the extracted text.
    pub start: usize,
    /// Byte offset one past the end of the match.
    pub end: usize,
}
impl ExtractedEntity {
    /// Renders the entity as a tag of the form `entity:<type>:<value>`.
    ///
    /// The value goes through `normalize_tag_value` (whitespace collapsed,
    /// lowercased, `:` replaced by `_`), so the value part never adds extra
    /// colon-separated segments.
    pub fn to_tag(&self) -> String {
        let mut tag = String::from("entity:");
        tag.push_str(&self.entity_type);
        tag.push(':');
        tag.push_str(&normalize_tag_value(&self.value));
        tag
    }

    /// Key under which two entities count as duplicates: the entity type as-is
    /// plus the normalised value.
    pub fn dedup_key(&self) -> (String, String) {
        let normalized_value = normalize_tag_value(&self.value);
        (self.entity_type.clone(), normalized_value)
    }
}
/// Canonicalises an entity label: trims it, lowercases it, and turns every
/// ASCII space into `_` (e.g. `" Law Firm "` becomes `"law_firm"`).
pub fn normalize_label(label: &str) -> String {
    label
        .trim()
        .to_lowercase()
        .chars()
        .map(|c| if c == ' ' { '_' } else { c })
        .collect()
}
/// Normalises an entity value for use inside a tag.
///
/// Collapses every whitespace run to a single space (dropping leading and
/// trailing whitespace), lowercases the result, and replaces `:` with `_` so
/// the value cannot introduce extra tag separators.
fn normalize_tag_value(value: &str) -> String {
    let mut collapsed = String::with_capacity(value.len());
    for word in value.split_whitespace() {
        if !collapsed.is_empty() {
            collapsed.push(' ');
        }
        collapsed.push_str(word);
    }
    collapsed.to_lowercase().replace(':', "_")
}
pub fn deduplicate_entities(mut entities: Vec<ExtractedEntity>) -> Vec<ExtractedEntity> {
entities.sort_by(|a, b| {
b.score
.partial_cmp(&a.score)
.unwrap_or(std::cmp::Ordering::Equal)
});
let mut seen: HashMap<(String, String), ()> = HashMap::new();
let mut out: Vec<ExtractedEntity> = Vec::with_capacity(entities.len());
for entity in entities {
let key = entity.dedup_key();
if seen.insert(key, ()).is_none() {
out.push(entity);
}
}
out.sort_by_key(|e| e.start);
out
}
/// Precompiled regexes used by `rule_based_extract`, one per rule-based
/// entity kind.
struct RulePatterns {
    /// Email addresses (`local@domain.tld`).
    email: Regex,
    /// `http://` and `https://` URLs.
    url: Regex,
    /// UUIDs in 8-4-4-4-12 hexadecimal form (case-insensitive).
    uuid: Regex,
    /// `YYYY-MM-DD` dates.
    iso_date: Regex,
    /// English month-name dates such as `March 15, 2024` or `Mar 15`.
    natural_date: Regex,
    /// Dotted-quad IPv4 addresses with each octet in 0-255.
    ip_v4: Regex,
}
impl RulePatterns {
    /// Compiles every rule pattern.
    ///
    /// # Panics
    /// Panics if a pattern fails to compile. All patterns are literals, so that
    /// would be a programming error rather than a runtime condition.
    fn new() -> Self {
        Self {
            uuid: Regex::new(
                r"(?i)\b[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}\b",
            )
            .expect("uuid regex"),
            // The final character may not be `.`, `,`, `;`, `:`, `!` or `?`, so
            // sentence punctuation right after a URL ("see https://example.com.")
            // is not swallowed into the match.
            url: Regex::new(r#"https?://[^\s<>\[\]()"']*[^\s<>\[\]()"'.,;:!?]"#)
                .expect("url regex"),
            email: Regex::new(r"[a-zA-Z0-9._%+\-]+@[a-zA-Z0-9.\-]+\.[a-zA-Z]{2,}")
                .expect("email regex"),
            iso_date: Regex::new(r"\b\d{4}-(?:0[1-9]|1[0-2])-(?:0[1-9]|[12]\d|3[01])\b")
                .expect("iso_date regex"),
            natural_date: Regex::new(
                r"(?i)\b(?:Jan(?:uary)?|Feb(?:ruary)?|Mar(?:ch)?|Apr(?:il)?|May|Jun(?:e)?|Jul(?:y)?|Aug(?:ust)?|Sep(?:tember)?|Oct(?:ober)?|Nov(?:ember)?|Dec(?:ember)?)\s+\d{1,2}(?:,\s*\d{4})?\b",
            )
            .expect("natural_date regex"),
            ip_v4: Regex::new(
                r"\b(?:(?:25[0-5]|2[0-4]\d|[01]?\d\d?)\.){3}(?:25[0-5]|2[0-4]\d|[01]?\d\d?)\b",
            )
            .expect("ipv4 regex"),
        }
    }
}
// Rule regexes, compiled once on first access and shared by every call to
// `rule_based_extract`.
lazy_static::lazy_static! {
    static ref RULE_PATTERNS: RulePatterns = RulePatterns::new();
}
/// Extracts structured entities (emails, URLs, UUIDs, dates, IPv4 addresses)
/// from `text` using the precompiled rule regexes.
///
/// Every match gets score `1.0` and byte offsets into `text`. Results are
/// grouped by rule in this order: email, url, uuid, ISO date, natural-language
/// date, ip. Within each rule they are in match order. A URL is skipped if an
/// already-recorded entity starts at the same byte. A natural-language date is
/// skipped if a `date` entity already starts there.
pub fn rule_based_extract(text: &str) -> Vec<ExtractedEntity> {
    fn entity_from(kind: &str, m: regex::Match<'_>) -> ExtractedEntity {
        ExtractedEntity {
            entity_type: kind.to_string(),
            value: m.as_str().trim().to_string(),
            score: 1.0,
            start: m.start(),
            end: m.end(),
        }
    }

    let patterns = &*RULE_PATTERNS;
    let mut found: Vec<ExtractedEntity> = patterns
        .email
        .find_iter(text)
        .map(|m| entity_from("email", m))
        .collect();

    for m in patterns.url.find_iter(text) {
        let starts_taken = found.iter().any(|e| e.start == m.start());
        if !starts_taken {
            found.push(entity_from("url", m));
        }
    }

    found.extend(patterns.uuid.find_iter(text).map(|m| entity_from("uuid", m)));
    found.extend(patterns.iso_date.find_iter(text).map(|m| entity_from("date", m)));

    for m in patterns.natural_date.find_iter(text) {
        let already_dated = found
            .iter()
            .any(|e| e.entity_type == "date" && e.start == m.start());
        if !already_dated {
            found.push(entity_from("date", m));
        }
    }

    found.extend(patterns.ip_v4.find_iter(text).map(|m| entity_from("ip", m)));
    found
}
/// Repository the GLiNER ONNX graph is downloaded from.
const GLINER_MODEL_REPO: &str = "onnx-community/gliner_medium-v2.1";
/// Repository `tokenizer.json` is downloaded from (currently the same as the model repo).
const GLINER_TOKENIZER_REPO: &str = "onnx-community/gliner_medium-v2.1";
/// Path of the quantized ONNX graph inside `GLINER_MODEL_REPO`.
const GLINER_ONNX_FILE: &str = "onnx/model_quantized.onnx";
/// Longest candidate span, in words.
/// NOTE(review): presumably must equal the model's `max_width` — confirm against the export config.
const MAX_SPAN_WIDTH: usize = 12;
/// Minimum sigmoid probability for a (span, label) pair to be reported.
const DEFAULT_SCORE_THRESHOLD: f32 = 0.5;
/// Input text is cut to at most this many whitespace-separated words before inference.
const MAX_TEXT_WORDS: usize = 300;
/// GLiNER span-NER model queried with caller-supplied labels: an ONNX Runtime
/// session plus its tokenizer.
///
/// Both are behind `Arc` so they can be moved into `spawn_blocking` tasks. The
/// session is additionally guarded by a mutex, so forward passes on one engine
/// run one at a time.
pub struct GlinerEngine {
    session: Arc<Mutex<Session>>,
    tokenizer: Arc<Tokenizer>,
}
impl GlinerEngine {
    /// Loads GLiNER. The tokenizer and quantized ONNX graph are first
    /// downloaded into the local cache if they are missing.
    ///
    /// `num_threads` is passed to ONNX Runtime as the intra-op thread count and
    /// defaults to 1.
    ///
    /// # Errors
    /// - `TokenizationError` if the tokenizer file cannot be loaded.
    /// - `ModelLoadError` if the ONNX session cannot be built.
    /// - Any error raised while resolving or downloading the model files.
    #[instrument(skip_all)]
    pub async fn new(num_threads: Option<usize>) -> Result<Self> {
        let threads = num_threads.unwrap_or(1);
        info!("Initializing GLiNER NER engine (threads={})", threads);
        let (tokenizer_path, onnx_path) = Self::download_model_files().await?;
        let tokenizer = Tokenizer::from_file(&tokenizer_path)
            .map_err(|e| InferenceError::TokenizationError(e.to_string()))?;
        let session = Session::builder()
            .map_err(|e| InferenceError::ModelLoadError(e.to_string()))?
            .with_optimization_level(GraphOptimizationLevel::Level3)
            .map_err(|e| InferenceError::ModelLoadError(e.to_string()))?
            .with_intra_threads(threads)
            .map_err(|e| InferenceError::ModelLoadError(e.to_string()))?
            .commit_from_file(&onnx_path)
            .map_err(|e| InferenceError::ModelLoadError(e.to_string()))?;
        info!("GLiNER engine ready");
        Ok(Self {
            session: Arc::new(Mutex::new(session)),
            tokenizer: Arc::new(tokenizer),
        })
    }

    /// Runs GLiNER over `text` for the given entity labels.
    ///
    /// The forward pass runs on Tokio's blocking pool. Returns an empty list
    /// when `text` or `entity_types` is empty.
    ///
    /// # Errors
    /// Tokenization or ONNX Runtime failures from the forward pass, or a
    /// `HubError` if the blocking task cannot be joined (e.g. it panicked).
    pub async fn extract(&self, text: &str, entity_types: &[&str]) -> Result<Vec<ExtractedEntity>> {
        if entity_types.is_empty() || text.is_empty() {
            return Ok(Vec::new());
        }
        let text_owned = text.to_string();
        let entity_types_owned: Vec<String> = entity_types.iter().map(|s| s.to_string()).collect();
        let session = self.session.clone();
        let tokenizer = self.tokenizer.clone();
        tokio::task::spawn_blocking(move || {
            Self::run_inference(
                &text_owned,
                &entity_types_owned
                    .iter()
                    .map(|s| s.as_str())
                    .collect::<Vec<_>>(),
                &session,
                &tokenizer,
            )
        })
        .await
        .map_err(|e| InferenceError::HubError(format!("GLiNER inference task panicked: {}", e)))?
    }

    /// Synchronous GLiNER forward pass followed by span decoding.
    ///
    /// Steps:
    /// 1. Build the GLiNER prompt `<<ENT>> t1 <<ENT>> t2 ... <<SEP>> text`.
    /// 2. In `words_mask`, mark the first sub-token of every text word with its
    ///    1-based word index.
    /// 3. Enumerate `text_words * MAX_SPAN_WIDTH` candidate spans, zeroing and
    ///    masking the out-of-range ones.
    /// 4. Decode the `[words, MAX_SPAN_WIDTH, labels]` logits. Overlapping
    ///    spans of the same label are resolved greedily by score.
    ///
    /// Returned offsets are byte offsets into the (word-truncated) `text`.
    fn run_inference(
        text: &str,
        entity_types: &[&str],
        session: &Arc<Mutex<Session>>,
        tokenizer: &Tokenizer,
    ) -> Result<Vec<ExtractedEntity>> {
        // GLiNER prompt markers. The model reads label embeddings at the
        // `<<ENT>>` positions, so the markers must appear verbatim.
        // NOTE(review): assumes both are added tokens in the model's tokenizer.json.
        const ENT_TOKEN: &str = "<<ENT>>";
        const SEP_TOKEN: &str = "<<SEP>>";

        let text = truncate_to_word_limit(text, MAX_TEXT_WORDS);

        let mut prompt = String::new();
        for entity_type in entity_types {
            prompt.push_str(ENT_TOKEN);
            prompt.push(' ');
            prompt.push_str(entity_type);
            prompt.push(' ');
        }
        prompt.push_str(SEP_TOKEN);
        // The prompt is encoded on its own *without* a trailing space. Its word
        // count then equals the leading words of the full encoding. A trailing
        // space could tokenize as an extra stand-alone word and shift the
        // prompt/text boundary by one.
        let full_text = format!("{} {}", prompt, text);
        let text_byte_offset = prompt.len() + 1;

        let encoding = tokenizer
            .encode(full_text.as_str(), true)
            .map_err(|e| InferenceError::TokenizationError(e.to_string()))?;
        let token_ids: Vec<i64> = encoding.get_ids().iter().map(|&x| x as i64).collect();
        let attention_mask: Vec<i64> = encoding
            .get_attention_mask()
            .iter()
            .map(|&x| x as i64)
            .collect();
        let seq_len = token_ids.len();

        let prompt_encoding = tokenizer
            .encode(prompt.as_str(), false)
            .map_err(|e| InferenceError::TokenizationError(e.to_string()))?;
        let prompt_word_count = count_distinct_word_ids(prompt_encoding.get_word_ids());

        let word_ids = encoding.get_word_ids();
        let token_offsets = encoding.get_offsets();
        let mut words_mask = vec![0i64; seq_len];
        let mut text_word_ids: Vec<u32> = Vec::new();
        let mut word_byte_ranges: HashMap<u32, (usize, usize)> = HashMap::new();
        let mut words_seen = 0usize;
        let mut last_word_id: Option<u32> = None;
        for (i, &wid_opt) in word_ids.iter().enumerate() {
            let wid = match wid_opt {
                Some(w) => w,
                None => {
                    last_word_id = None;
                    continue;
                }
            };
            // Widen the word's byte range to cover all of its sub-tokens.
            let (tok_start, tok_end) = token_offsets[i];
            let range = word_byte_ranges.entry(wid).or_insert((tok_start, tok_end));
            range.0 = range.0.min(tok_start);
            range.1 = range.1.max(tok_end);
            if last_word_id != Some(wid) {
                if words_seen >= prompt_word_count {
                    text_word_ids.push(wid);
                    // GLiNER scatters token embeddings into word slots using
                    // `words_mask - 1`, so the first sub-token of each text word
                    // carries its 1-based index (0 = not a text-word start).
                    words_mask[i] = text_word_ids.len() as i64;
                }
                words_seen += 1;
            }
            last_word_id = Some(wid);
        }
        let text_word_count = text_word_ids.len();
        if text_word_count == 0 {
            debug!("No text words after entity type prefix — skipping inference");
            return Ok(Vec::new());
        }
        let text_lengths = vec![text_word_count as i64];

        // The span layer reshapes its output to [batch, words, MAX_SPAN_WIDTH, dim],
        // so every (start, width) slot must be present. Spans that would run past
        // the last word are zeroed and masked out.
        let num_spans = text_word_count * MAX_SPAN_WIDTH;
        let mut span_idx_flat: Vec<i64> = Vec::with_capacity(num_spans * 2);
        let mut span_mask: Vec<bool> = Vec::with_capacity(num_spans);
        for start in 0..text_word_count {
            for width in 0..MAX_SPAN_WIDTH {
                let end = start + width;
                let in_range = end < text_word_count;
                let (s, e) = if in_range {
                    (start as i64, end as i64)
                } else {
                    (0, 0)
                };
                span_idx_flat.push(s);
                span_idx_flat.push(e);
                span_mask.push(in_range);
            }
        }

        let logits_raw: Vec<f32> = {
            let mut session_guard = session.lock();
            let input_ids_t = Tensor::<i64>::from_array(([1usize, seq_len], token_ids))
                .map_err(|e| InferenceError::ModelLoadError(e.to_string()))?;
            let attn_mask_t = Tensor::<i64>::from_array(([1usize, seq_len], attention_mask))
                .map_err(|e| InferenceError::ModelLoadError(e.to_string()))?;
            let words_mask_t = Tensor::<i64>::from_array(([1usize, seq_len], words_mask))
                .map_err(|e| InferenceError::ModelLoadError(e.to_string()))?;
            let text_lengths_t = Tensor::<i64>::from_array(([1usize, 1usize], text_lengths))
                .map_err(|e| InferenceError::ModelLoadError(e.to_string()))?;
            let span_idx_t = Tensor::<i64>::from_array(([1usize, num_spans, 2], span_idx_flat))
                .map_err(|e| InferenceError::ModelLoadError(e.to_string()))?;
            let span_mask_t = Tensor::<bool>::from_array(([1usize, num_spans], span_mask))
                .map_err(|e| InferenceError::ModelLoadError(e.to_string()))?;
            let outputs = session_guard
                .run(inputs![
                    "input_ids" => input_ids_t,
                    "attention_mask" => attn_mask_t,
                    "words_mask" => words_mask_t,
                    "text_lengths" => text_lengths_t,
                    "span_idx" => span_idx_t,
                    "span_mask" => span_mask_t,
                ])
                .map_err(|e| InferenceError::ModelLoadError(e.to_string()))?;
            let (_shape, logits_slice) = outputs[0]
                .try_extract_tensor::<f32>()
                .map_err(|e| InferenceError::ModelLoadError(e.to_string()))?;
            logits_slice.to_vec()
        };

        let num_entity_types = entity_types.len();
        let expected_logits = num_spans * num_entity_types;
        if logits_raw.len() != expected_logits {
            warn!(
                "GLiNER logits shape mismatch: got {}, expected {}",
                logits_raw.len(),
                expected_logits
            );
            return Ok(Vec::new());
        }

        // Logits are laid out as [words, MAX_SPAN_WIDTH, entity_types]. Only
        // in-range spans (exactly those yielded by `iter_spans`) are decoded.
        let mut raw_entities: Vec<(usize, usize, usize, f32)> = Vec::new();
        for (start_w, end_w) in iter_spans(text_word_count) {
            let base = (start_w * MAX_SPAN_WIDTH + (end_w - start_w)) * num_entity_types;
            for (type_i, logit) in logits_raw[base..base + num_entity_types].iter().enumerate() {
                let score = sigmoid(*logit);
                if score >= DEFAULT_SCORE_THRESHOLD {
                    raw_entities.push((type_i, start_w, end_w, score));
                }
            }
        }

        // Greedy selection: highest score first; a candidate is dropped if it
        // overlaps an already-kept span of the same label.
        raw_entities.sort_by(|a, b| b.3.partial_cmp(&a.3).unwrap_or(std::cmp::Ordering::Equal));
        let mut kept: Vec<(usize, usize, usize, f32)> = Vec::new();
        'outer: for candidate in &raw_entities {
            for kept_span in &kept {
                if kept_span.0 == candidate.0
                    && kept_span.1 <= candidate.2
                    && candidate.1 <= kept_span.2
                {
                    continue 'outer;
                }
            }
            kept.push(*candidate);
        }

        let mut entities: Vec<ExtractedEntity> = kept
            .into_iter()
            .filter_map(|(type_i, start_w, end_w, score)| {
                let start_wid = *text_word_ids.get(start_w)?;
                let end_wid = *text_word_ids.get(end_w)?;
                let start_full = word_byte_ranges.get(&start_wid)?.0;
                let end_full = word_byte_ranges.get(&end_wid)?.1;
                let start_byte = start_full.saturating_sub(text_byte_offset);
                let end_byte = end_full.saturating_sub(text_byte_offset);
                if start_byte >= end_byte {
                    return None;
                }
                // `get` instead of indexing: an offset that is out of range or
                // not on a char boundary drops the span rather than panicking.
                let raw = text.get(start_byte..end_byte)?;
                let value = raw.trim();
                if value.is_empty() {
                    return None;
                }
                // Token offsets may include the separating whitespace, so
                // report the span of the trimmed value.
                let start = start_byte + (raw.len() - raw.trim_start().len());
                let end = start + value.len();
                Some(ExtractedEntity {
                    entity_type: normalize_label(entity_types[type_i]),
                    value: value.to_string(),
                    score,
                    start,
                    end,
                })
            })
            .collect();
        entities.sort_by_key(|e| e.start);
        debug!("GLiNER extracted {} entities", entities.len());
        Ok(entities)
    }

    /// Makes sure `tokenizer.json` and the quantized ONNX graph are in the
    /// local cache, and returns `(tokenizer_path, onnx_path)`.
    ///
    /// Missing files are downloaded on the blocking pool.
    #[instrument(skip_all)]
    async fn download_model_files() -> Result<(PathBuf, PathBuf)> {
        info!(
            "Resolving GLiNER model files: tokenizer={}, onnx={}",
            GLINER_TOKENIZER_REPO, GLINER_MODEL_REPO
        );
        let tokenizer_cache = Self::model_cache_dir(GLINER_TOKENIZER_REPO)?;
        let onnx_cache = Self::model_cache_dir(GLINER_MODEL_REPO)?;
        let onnx_subdir = onnx_cache.join("onnx");
        std::fs::create_dir_all(&onnx_subdir)?;
        let local_tokenizer = tokenizer_cache.join("tokenizer.json");
        let local_onnx = onnx_subdir.join("model_quantized.onnx");
        if !local_tokenizer.exists() || !local_onnx.exists() {
            let tok_cache = tokenizer_cache.clone();
            let onnx_c = onnx_cache.clone();
            let tok_exists = local_tokenizer.exists();
            let onnx_exists = local_onnx.exists();
            tokio::task::spawn_blocking(move || {
                if !tok_exists {
                    crate::engine::EmbeddingEngine::download_hf_file_pub(
                        GLINER_TOKENIZER_REPO,
                        "tokenizer.json",
                        &tok_cache,
                    )
                    .map_err(|e| {
                        InferenceError::HubError(format!(
                            "Failed to download GLiNER tokenizer: {}",
                            e
                        ))
                    })?;
                }
                if !onnx_exists {
                    crate::engine::EmbeddingEngine::download_hf_file_pub(
                        GLINER_MODEL_REPO,
                        GLINER_ONNX_FILE,
                        &onnx_c,
                    )
                    .map_err(|e| {
                        InferenceError::HubError(format!(
                            "Failed to download GLiNER ONNX model: {}",
                            e
                        ))
                    })?;
                }
                Ok::<_, InferenceError>(())
            })
            .await
            .map_err(|e| InferenceError::HubError(format!("Download task panicked: {}", e)))??;
        } else {
            info!("GLiNER model files found in local cache");
        }
        let final_onnx = onnx_cache.join(GLINER_ONNX_FILE);
        Ok((local_tokenizer, final_onnx))
    }

    /// Returns (and creates if needed) the per-model cache directory.
    ///
    /// The path is `<base>/dakera/<org>--<name>`, where `<base>` is `$HF_HOME`.
    /// If `HF_HOME` is unset it falls back to `$HOME/.cache/huggingface`, or to
    /// `/tmp/.cache/huggingface` when `HOME` is also unset.
    fn model_cache_dir(model_id: &str) -> Result<PathBuf> {
        let base = std::env::var("HF_HOME")
            .map(PathBuf::from)
            .unwrap_or_else(|_| {
                let home = std::env::var("HOME").unwrap_or_else(|_| "/tmp".to_string());
                PathBuf::from(home).join(".cache").join("huggingface")
            });
        let dir = base.join("dakera").join(model_id.replace('/', "--"));
        std::fs::create_dir_all(&dir)?;
        Ok(dir)
    }
}
/// Entity extractor combining the regex rules with an optional GLiNER model.
pub struct NerEngine {
    // `None` when built with `rule_based_only`; only the regex rules run then.
    gliner: Option<Arc<GlinerEngine>>,
}
impl NerEngine {
pub fn rule_based_only() -> Self {
Self { gliner: None }
}
pub async fn with_gliner(num_threads: Option<usize>) -> Result<Self> {
let gliner = GlinerEngine::new(num_threads).await?;
Ok(Self {
gliner: Some(Arc::new(gliner)),
})
}
pub async fn extract(&self, text: &str, gliner_types: &[&str]) -> Vec<ExtractedEntity> {
let mut entities = rule_based_extract(text);
if let Some(ref gliner) = self.gliner {
if !gliner_types.is_empty() {
match gliner.extract(text, gliner_types).await {
Ok(neural) => {
for ne in neural {
if !entities
.iter()
.any(|e| e.start == ne.start && e.end == ne.end)
{
entities.push(ne);
}
}
}
Err(e) => {
warn!("GLiNER extraction failed, using rule-based only: {}", e);
}
}
}
}
entities.sort_by_key(|e| e.start);
deduplicate_entities(entities)
}
}
/// Number of distinct word ids in a tokenizer `word_ids` slice. Tokens without
/// a word (`None`, e.g. special tokens) are ignored.
fn count_distinct_word_ids(word_ids: &[Option<u32>]) -> usize {
    word_ids
        .iter()
        .flatten()
        .collect::<std::collections::HashSet<&u32>>()
        .len()
}
/// Returns the prefix of `text` that holds at most `max_words`
/// whitespace-separated words.
///
/// The cut is made at the first whitespace character after the
/// `max_words`-th word, so the result never ends inside a word. If there is no
/// such character, `text` is returned whole. A limit of 0 yields `""`.
/// Previously a limit of 0 still kept the first word.
fn truncate_to_word_limit(text: &str, max_words: usize) -> &str {
    if max_words == 0 {
        return "";
    }
    let mut words_closed = 0usize;
    let mut in_word = false;
    for (i, ch) in text.char_indices() {
        if !ch.is_whitespace() {
            in_word = true;
            continue;
        }
        if in_word {
            words_closed += 1;
            if words_closed >= max_words {
                // `i` is the byte index of a whitespace char, hence a char boundary.
                return &text[..i];
            }
        }
        in_word = false;
    }
    text
}
/// Enumerates every candidate span `(start, end)` over `num_words` words.
///
/// `end` is an inclusive word index, and each span covers at most
/// `MAX_SPAN_WIDTH` words. Spans are ordered by `start`, then by `end`.
fn iter_spans(num_words: usize) -> impl Iterator<Item = (usize, usize)> {
    (0..num_words).flat_map(move |start| {
        (0..MAX_SPAN_WIDTH)
            .map(move |width| (start, start + width))
            .take_while(move |&(_, end)| end < num_words)
    })
}
/// Logistic function `1 / (1 + e^-x)`.
///
/// `exp` is only ever evaluated on a non-positive argument, so it cannot
/// overflow for large `|x|`.
#[inline]
fn sigmoid(x: f32) -> f32 {
    let decay = (-x.abs()).exp();
    if x >= 0.0 {
        1.0 / (1.0 + decay)
    } else {
        decay / (1.0 + decay)
    }
}
// Unit tests for the rule-based extractor, tag/label normalisation,
// de-duplication, and the pure inference helpers. None of them loads the
// GLiNER model.
#[cfg(test)]
mod tests {
    use super::*;

    // A canonical 8-4-4-4-12 UUID must be recognised.
    #[test]
    fn test_rule_based_uuid() {
        let text = "session id is 550e8400-e29b-41d4-a716-446655440000 here";
        let entities = rule_based_extract(text);
        assert!(entities.iter().any(|e| e.entity_type == "uuid"));
    }

    // An https URL with a path and query string must be recognised.
    #[test]
    fn test_rule_based_url() {
        let text = "check https://example.com/path?q=1 for details";
        let entities = rule_based_extract(text);
        assert!(entities.iter().any(|e| e.entity_type == "url"));
    }

    // A bare email address is tagged as an email and never as a URL.
    #[test]
    fn test_rule_based_email() {
        let text = "contact alice@example.com for support";
        let entities = rule_based_extract(text);
        assert!(entities.iter().any(|e| e.entity_type == "email"));
        assert!(!entities.iter().any(|e| e.entity_type == "url"));
    }

    // An ISO date yields a `date` entity whose value is the exact date text.
    #[test]
    fn test_rule_based_iso_date() {
        let text = "released on 2024-03-15 at noon";
        let entities = rule_based_extract(text);
        assert!(entities
            .iter()
            .any(|e| e.entity_type == "date" && e.value == "2024-03-15"));
    }

    // A month-name date ("March 15, 2024") yields a `date` entity.
    #[test]
    fn test_rule_based_natural_date() {
        let text = "meeting on March 15, 2024 at noon";
        let entities = rule_based_extract(text);
        assert!(entities.iter().any(|e| e.entity_type == "date"));
    }

    // Tag values are lowercased.
    #[test]
    fn test_entity_to_tag_lowercase_value() {
        let e = ExtractedEntity {
            entity_type: "person".to_string(),
            value: "Alice Smith".to_string(),
            score: 0.9,
            start: 0,
            end: 11,
        };
        assert_eq!(e.to_tag(), "entity:person:alice smith");
    }

    // Colons inside the value are escaped, so the tag splits into exactly three
    // parts. `start`/`end` are not read by `to_tag`; their values here are arbitrary.
    #[test]
    fn test_entity_to_tag_colon_escaping() {
        let e = ExtractedEntity {
            entity_type: "url".to_string(),
            value: "http://example.com:8080/path".to_string(),
            score: 1.0,
            start: 0,
            end: 27,
        };
        let tag = e.to_tag();
        let parts: Vec<&str> = tag.splitn(3, ':').collect();
        assert_eq!(parts.len(), 3, "tag should have 3 parts: {}", tag);
        assert_eq!(parts[0], "entity");
        assert_eq!(parts[1], "url");
        assert!(
            !parts[2].contains(':'),
            "value should not contain colons: {}",
            parts[2]
        );
    }

    // Leading/trailing whitespace is removed and inner runs collapse to one space.
    #[test]
    fn test_entity_to_tag_normalizes_whitespace() {
        let e = ExtractedEntity {
            entity_type: "person".to_string(),
            value: " John Doe ".to_string(),
            score: 0.9,
            start: 0,
            end: 12,
        };
        assert_eq!(e.to_tag(), "entity:person:john doe");
    }

    // Labels are trimmed, lowercased, and have spaces replaced by underscores.
    #[test]
    fn test_normalize_label() {
        assert_eq!(normalize_label("Person"), "person");
        assert_eq!(normalize_label("Law Firm"), "law_firm");
        assert_eq!(normalize_label(" ORG "), "org");
        assert_eq!(normalize_label("ORGANIZATION"), "organization");
        assert_eq!(normalize_label("location"), "location");
    }

    // The same (type, value) at two positions collapses to the higher-scoring one.
    #[test]
    fn test_deduplicate_same_value_different_positions() {
        let entities = vec![
            ExtractedEntity {
                entity_type: "person".to_string(),
                value: "Alice".to_string(),
                score: 0.8,
                start: 0,
                end: 5,
            },
            ExtractedEntity {
                entity_type: "person".to_string(),
                value: "Alice".to_string(),
                score: 0.9,
                start: 20,
                end: 25,
            },
        ];
        let deduped = deduplicate_entities(entities);
        assert_eq!(
            deduped.len(),
            1,
            "same entity at different positions should be merged"
        );
        assert_eq!(deduped[0].score, 0.9, "should retain highest score");
    }

    // De-duplication compares values case-insensitively.
    #[test]
    fn test_deduplicate_case_insensitive() {
        let entities = vec![
            ExtractedEntity {
                entity_type: "person".to_string(),
                value: "alice".to_string(),
                score: 0.7,
                start: 10,
                end: 15,
            },
            ExtractedEntity {
                entity_type: "person".to_string(),
                value: "Alice".to_string(),
                score: 0.95,
                start: 0,
                end: 5,
            },
        ];
        let deduped = deduplicate_entities(entities);
        assert_eq!(
            deduped.len(),
            1,
            "case-insensitive dedup: 'Alice' == 'alice'"
        );
        assert_eq!(deduped[0].score, 0.95);
    }

    // The same value under different entity types is not a duplicate.
    #[test]
    fn test_deduplicate_different_types_kept() {
        let entities = vec![
            ExtractedEntity {
                entity_type: "person".to_string(),
                value: "Apple".to_string(),
                score: 0.6,
                start: 0,
                end: 5,
            },
            ExtractedEntity {
                entity_type: "organization".to_string(),
                value: "Apple".to_string(),
                score: 0.9,
                start: 0,
                end: 5,
            },
        ];
        let deduped = deduplicate_entities(entities);
        assert_eq!(
            deduped.len(),
            2,
            "same value with different types must be kept separately"
        );
    }

    // Text longer than the limit is cut to at most the limit's word count.
    #[test]
    fn test_truncate_to_word_limit_long() {
        let words: Vec<String> = (0..500).map(|i| format!("word{}", i)).collect();
        let text = words.join(" ");
        let truncated = truncate_to_word_limit(&text, 300);
        let word_count = truncated.split_whitespace().count();
        assert!(
            word_count <= 300,
            "truncated text must be ≤ 300 words, got {}",
            word_count
        );
    }

    // Text under the limit is returned unchanged.
    #[test]
    fn test_truncate_to_word_limit_short_pass_through() {
        let text = "Hello world this is fine";
        assert_eq!(
            truncate_to_word_limit(text, 300),
            text,
            "short text must pass through unchanged"
        );
    }

    // Midpoint and saturation behaviour of the logistic function.
    #[test]
    fn test_sigmoid() {
        assert!((sigmoid(0.0) - 0.5).abs() < 1e-6);
        assert!((sigmoid(100.0) - 1.0).abs() < 1e-4);
        assert!((sigmoid(-100.0) - 0.0).abs() < 1e-4);
    }

    // Repeated ids are counted once and `None` entries are ignored.
    #[test]
    fn test_count_distinct_word_ids() {
        let wids: Vec<Option<u32>> =
            vec![Some(0), Some(0), Some(1), Some(1), Some(2), None, Some(3)];
        assert_eq!(count_distinct_word_ids(&wids), 4);
    }
}