//! Named-entity extraction: fast regex rules plus optional zero-shot NER via a
//! quantized GLiNER ONNX model.
use crate::error::{InferenceError, Result};
use ort::inputs;
use ort::session::builder::GraphOptimizationLevel;
use ort::session::Session;
use ort::value::Tensor;
use parking_lot::Mutex;
use regex::Regex;
use std::path::PathBuf;
use std::sync::Arc;
use tokenizers::Tokenizer;
use tracing::{debug, info, instrument, warn};
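/// A single entity found in text, either by the regex rules or by GLiNER.
/// `start` and `end` are byte offsets into the source text; rule-based matches
/// always carry a score of 1.0.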
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct ExtractedEntity {
pub entity_type: String,
pub value: String,
pub score: f32,
pub start: usize,
pub end: usize,
}
impl ExtractedEntity {
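/// Renders the entity as an `entity:<type>:<value>` tag, replacing colons in the
/// value with underscores so the tag keeps exactly three colon-separated parts.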
pub fn to_tag(&self) -> String {
let v = self.value.replace(':', "_");
format!("entity:{}:{}", self.entity_type, v)
}
}
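/// Pre-compiled regexes for the deterministic extractors: UUIDs, URLs, emails,
/// ISO and natural-language dates, and IPv4 addresses.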
struct RulePatterns {
uuid: Regex,
url: Regex,
email: Regex,
iso_date: Regex,
natural_date: Regex,
ip_v4: Regex,
}
impl RulePatterns {
fn new() -> Self {
Self {
uuid: Regex::new(
r"(?i)\b[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}\b",
)
.expect("uuid regex"),
url: Regex::new(r#"https?://[^\s<>\[\]()"']+"#)
.expect("url regex"),
email: Regex::new(r"[a-zA-Z0-9._%+\-]+@[a-zA-Z0-9.\-]+\.[a-zA-Z]{2,}")
.expect("email regex"),
iso_date: Regex::new(
r"\b\d{4}-(?:0[1-9]|1[0-2])-(?:0[1-9]|[12]\d|3[01])\b",
)
.expect("iso_date regex"),
natural_date: Regex::new(
r"(?i)\b(?:Jan(?:uary)?|Feb(?:ruary)?|Mar(?:ch)?|Apr(?:il)?|May|Jun(?:e)?|Jul(?:y)?|Aug(?:ust)?|Sep(?:tember)?|Oct(?:ober)?|Nov(?:ember)?|Dec(?:ember)?)\s+\d{1,2}(?:,\s*\d{4})?\b",
)
.expect("natural_date regex"),
ip_v4: Regex::new(
r"\b(?:(?:25[0-5]|2[0-4]\d|[01]?\d\d?)\.){3}(?:25[0-5]|2[0-4]\d|[01]?\d\d?)\b",
)
.expect("ipv4 regex"),
}
}
}
lazy_static::lazy_static! {
static ref RULE_PATTERNS: RulePatterns = RulePatterns::new();
}
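/// Runs every regex extractor over `text` and returns the matches with score 1.0.
/// URL matches are skipped when another entity already starts at the same offset,
/// and natural-language dates are skipped when a date already starts there.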
pub fn rule_based_extract(text: &str) -> Vec<ExtractedEntity> {
let mut entities: Vec<ExtractedEntity> = Vec::new();
let push = |entities: &mut Vec<ExtractedEntity>, entity_type: &str, m: regex::Match| {
entities.push(ExtractedEntity {
entity_type: entity_type.to_string(),
value: m.as_str().to_string(),
score: 1.0,
start: m.start(),
end: m.end(),
});
};
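// Match order matters: emails first, then URLs (skipped when an entity already
// starts at the same offset), then the remaining patterns.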
for m in RULE_PATTERNS.email.find_iter(text) {
push(&mut entities, "email", m);
}
for m in RULE_PATTERNS.url.find_iter(text) {
if !entities.iter().any(|e| e.start == m.start()) {
push(&mut entities, "url", m);
}
}
for m in RULE_PATTERNS.uuid.find_iter(text) {
push(&mut entities, "uuid", m);
}
for m in RULE_PATTERNS.iso_date.find_iter(text) {
push(&mut entities, "date", m);
}
for m in RULE_PATTERNS.natural_date.find_iter(text) {
if !entities
.iter()
.any(|e| e.start == m.start() && e.entity_type == "date")
{
push(&mut entities, "date", m);
}
}
for m in RULE_PATTERNS.ip_v4.find_iter(text) {
push(&mut entities, "ip", m);
}
entities
}
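// Hugging Face repositories and file name for the quantized GLiNER model and its tokenizer.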
const GLINER_MODEL_REPO: &str = "onnx-community/gliner-medium-v2.1";
const GLINER_TOKENIZER_REPO: &str = "knowledgator/gliner-medium-v2.1";
const GLINER_ONNX_FILE: &str = "onnx/model_quantized.onnx";
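/// Maximum entity span width, in words.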
const MAX_SPAN_WIDTH: usize = 12;
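/// Minimum sigmoid score for a predicted span to be kept.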
const SCORE_THRESHOLD: f32 = 0.5;
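/// Zero-shot named-entity recognition backed by a quantized GLiNER ONNX model.
/// The ORT session sits behind a mutex so one loaded model can be shared across tasks.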
pub struct GlinerEngine {
session: Arc<Mutex<Session>>,
tokenizer: Arc<Tokenizer>,
}
impl GlinerEngine {
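/// Loads the tokenizer and ONNX model (downloading them on first use) and builds
/// an ORT session with the requested number of intra-op threads (default: 1).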
#[instrument(skip_all)]
pub async fn new(num_threads: Option<usize>) -> Result<Self> {
let threads = num_threads.unwrap_or(1);
info!("Initializing GLiNER NER engine (threads={})", threads);
let (tokenizer_path, onnx_path) = Self::download_model_files().await?;
let tokenizer = Tokenizer::from_file(&tokenizer_path)
.map_err(|e| InferenceError::TokenizationError(e.to_string()))?;
let session = Session::builder()
.map_err(|e| InferenceError::ModelLoadError(e.to_string()))?
.with_optimization_level(GraphOptimizationLevel::Level3)
.map_err(|e| InferenceError::ModelLoadError(e.to_string()))?
.with_intra_threads(threads)
.map_err(|e| InferenceError::ModelLoadError(e.to_string()))?
.commit_from_file(&onnx_path)
.map_err(|e| InferenceError::ModelLoadError(e.to_string()))?;
info!("GLiNER engine ready");
Ok(Self {
session: Arc::new(Mutex::new(session)),
tokenizer: Arc::new(tokenizer),
})
}
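/// Extracts entities of the given `entity_types` from `text`. Returns an empty
/// list when either input is empty; otherwise inference runs on a blocking
/// thread so the async runtime is not stalled.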
pub async fn extract(&self, text: &str, entity_types: &[&str]) -> Result<Vec<ExtractedEntity>> {
if entity_types.is_empty() || text.is_empty() {
return Ok(Vec::new());
}
let text_owned = text.to_string();
let entity_types_owned: Vec<String> = entity_types.iter().map(|s| s.to_string()).collect();
let session = self.session.clone();
let tokenizer = self.tokenizer.clone();
tokio::task::spawn_blocking(move || {
Self::run_inference(
&text_owned,
&entity_types_owned
.iter()
.map(|s| s.as_str())
.collect::<Vec<_>>(),
&session,
&tokenizer,
)
})
.await
.map_err(|e| InferenceError::HubError(format!("GLiNER inference task panicked: {}", e)))?
}
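/// Runs the full GLiNER inference pipeline synchronously:
/// 1. build the prompt (entity types + separator + text) and tokenize it,
/// 2. mark the first subword token of every text word in `words_mask`,
/// 3. enumerate candidate word spans up to `MAX_SPAN_WIDTH` words wide,
/// 4. run the ONNX model and score every (span, type) pair with a sigmoid,
/// 5. greedily keep the best non-overlapping spans per entity type,
/// 6. map the kept word spans back to byte offsets in the original text.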
fn run_inference(
text: &str,
entity_types: &[&str],
session: &Arc<Mutex<Session>>,
tokenizer: &Tokenizer,
) -> Result<Vec<ExtractedEntity>> {
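// Build the model prompt: entity types joined by " << >> ", then the separator
// once more, then the raw input text.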
let mut full_text = entity_types.join(" << >> ");
full_text.push_str(" << >> ");
full_text.push_str(text);
let encoding = tokenizer
.encode(full_text.as_str(), true)
.map_err(|e| InferenceError::TokenizationError(e.to_string()))?;
let token_ids: Vec<i64> = encoding.get_ids().iter().map(|&x| x as i64).collect();
let attention_mask: Vec<i64> = encoding
.get_attention_mask()
.iter()
.map(|&x| x as i64)
.collect();
let seq_len = token_ids.len();
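// `words_mask` flags the first subword token of every word belonging to the
// input text. The entity-type prefix is tokenized separately so its word count
// can be used to tell prefix words apart from text words.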
let word_ids = encoding.get_word_ids();
let mut words_mask = vec![0i64; seq_len];
let mut last_word_id: Option<u32> = None;
let mut text_token_start = usize::MAX;
let prefix = entity_types.join(" << >> ");
let prefix_plus_sep = format!("{} << >> ", prefix);
let prefix_encoding = tokenizer
.encode(prefix_plus_sep.as_str(), false)
.map_err(|e| InferenceError::TokenizationError(e.to_string()))?;
let prefix_word_count = prefix_encoding
.get_word_ids()
.iter()
.filter_map(|&w| w)
.collect::<std::collections::HashSet<_>>()
.len();
let mut text_word_count = 0usize;
for (i, &wid_opt) in word_ids.iter().enumerate() {
let wid = match wid_opt {
Some(w) => w,
None => {
last_word_id = None;
continue;
}
};
let is_new_word = last_word_id.map(|lw| lw != wid).unwrap_or(true);
if is_new_word {
let global_word_idx = {
word_ids[..i]
.iter()
.filter_map(|&w| w)
.collect::<std::collections::HashSet<_>>()
.len()
};
if global_word_idx >= prefix_word_count {
if text_token_start == usize::MAX {
text_token_start = i;
}
words_mask[i] = 1;
text_word_count += 1;
}
}
last_word_id = Some(wid);
}
if text_word_count == 0 || text_token_start == usize::MAX {
debug!("No text words found after entity type prefix, skipping inference");
return Ok(Vec::new());
}
let text_lengths = vec![text_word_count as i64];
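// Enumerate candidate spans as (start word, inclusive end word) pairs at most
// MAX_SPAN_WIDTH words wide; `span_idx` is flattened to [1, num_spans, 2] and
// every span is marked valid in `span_mask`.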
let mut span_idx_flat: Vec<i64> = Vec::new();
let mut span_mask: Vec<bool> = Vec::new();
for start in 0..text_word_count {
for end in start..text_word_count.min(start + MAX_SPAN_WIDTH) {
span_idx_flat.push(start as i64);
span_idx_flat.push(end as i64);
span_mask.push(true);
}
}
let num_spans = span_mask.len();
if num_spans == 0 {
return Ok(Vec::new());
}
let span_mask_values: Vec<i64> = span_mask
.iter()
.map(|&b| if b { 1i64 } else { 0 })
.collect();
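// Run the model; the session lock is held only for the duration of the forward pass.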
let logits_raw: Vec<f32> = {
let mut session_guard = session.lock();
let input_ids_t = Tensor::<i64>::from_array(([1usize, seq_len], token_ids))
.map_err(|e| InferenceError::ModelLoadError(e.to_string()))?;
let attn_mask_t = Tensor::<i64>::from_array(([1usize, seq_len], attention_mask))
.map_err(|e| InferenceError::ModelLoadError(e.to_string()))?;
let words_mask_t = Tensor::<i64>::from_array(([1usize, seq_len], words_mask))
.map_err(|e| InferenceError::ModelLoadError(e.to_string()))?;
let text_lengths_t = Tensor::<i64>::from_array(([1usize], text_lengths))
.map_err(|e| InferenceError::ModelLoadError(e.to_string()))?;
let span_idx_t = Tensor::<i64>::from_array(([1usize, num_spans, 2], span_idx_flat))
.map_err(|e| InferenceError::ModelLoadError(e.to_string()))?;
let span_mask_t = Tensor::<i64>::from_array(([1usize, num_spans], span_mask_values))
.map_err(|e| InferenceError::ModelLoadError(e.to_string()))?;
let outputs = session_guard
.run(inputs![
"input_ids" => input_ids_t,
"attention_mask" => attn_mask_t,
"words_mask" => words_mask_t,
"text_lengths" => text_lengths_t,
"span_idx" => span_idx_t,
"span_mask" => span_mask_t,
])
.map_err(|e| InferenceError::ModelLoadError(e.to_string()))?;
let (_shape, logits_slice) = outputs[0]
.try_extract_tensor::<f32>()
.map_err(|e| InferenceError::ModelLoadError(e.to_string()))?;
logits_slice.to_vec()
};
let num_entity_types = entity_types.len();
let expected = num_spans * num_entity_types;
if logits_raw.len() != expected {
warn!(
"GLiNER logits shape mismatch: got {}, expected {}",
logits_raw.len(),
expected
);
return Ok(Vec::new());
}
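// Convert logits to probabilities and keep every (span, type) pair above the threshold.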
let mut raw_entities: Vec<(usize, usize, usize, f32)> = Vec::new();
for (span_i, (start_w, end_w)) in iter_spans(text_word_count).enumerate() {
for (type_i, _entity_type) in entity_types.iter().enumerate() {
let logit = logits_raw[span_i * num_entity_types + type_i];
let score = sigmoid(logit);
if score >= SCORE_THRESHOLD {
raw_entities.push((type_i, start_w, end_w, score));
}
}
}
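// Greedy selection: sort by score descending and drop any candidate that
// overlaps an already-kept span of the same entity type.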
raw_entities.sort_by(|a, b| b.3.partial_cmp(&a.3).unwrap_or(std::cmp::Ordering::Equal));
let mut kept: Vec<(usize, usize, usize, f32)> = Vec::new();
'outer: for candidate in &raw_entities {
for kept_span in &kept {
if kept_span.0 == candidate.0
&& kept_span.1 <= candidate.2
&& candidate.1 <= kept_span.2
{
continue 'outer;
}
}
kept.push(*candidate);
}
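// Map word indices back to offsets in the original text by re-splitting on
// whitespace; this assumes the tokenizer's word segmentation lines up with
// whitespace splitting. Offsets are byte positions, as produced by `find` and `len`.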
let words: Vec<&str> = text.split_whitespace().collect();
let mut word_char_starts: Vec<usize> = Vec::with_capacity(words.len());
let mut word_char_ends: Vec<usize> = Vec::with_capacity(words.len());
{
let mut char_pos = 0usize;
for word in &words {
if let Some(rel) = text[char_pos..].find(word) {
let start = char_pos + rel;
let end = start + word.len();
word_char_starts.push(start);
word_char_ends.push(end);
char_pos = end;
} else {
word_char_starts.push(char_pos);
word_char_ends.push(char_pos);
}
}
}
let mut entities: Vec<ExtractedEntity> = kept
.into_iter()
.filter_map(|(type_i, start_w, end_w, score)| {
let start_char = *word_char_starts.get(start_w)?;
let end_char = *word_char_ends.get(end_w)?;
let value = text[start_char..end_char].to_string();
Some(ExtractedEntity {
entity_type: entity_types[type_i].to_lowercase().replace(' ', "_"),
value,
score,
start: start_char,
end: end_char,
})
})
.collect();
entities.sort_by_key(|e| e.start);
debug!("GLiNER extracted {} entities", entities.len());
Ok(entities)
}
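/// Ensures `tokenizer.json` and the quantized ONNX model exist in the local
/// cache, downloading them from the Hugging Face Hub on first use, and returns
/// their paths as (tokenizer, onnx).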
#[instrument(skip_all)]
async fn download_model_files() -> Result<(PathBuf, PathBuf)> {
info!(
"Resolving GLiNER model files: tokenizer={}, onnx={}",
GLINER_TOKENIZER_REPO, GLINER_MODEL_REPO
);
let tokenizer_cache = Self::model_cache_dir(GLINER_TOKENIZER_REPO)?;
let onnx_cache = Self::model_cache_dir(GLINER_MODEL_REPO)?;
let onnx_subdir = onnx_cache.join("onnx");
std::fs::create_dir_all(&onnx_subdir)?;
let local_tokenizer = tokenizer_cache.join("tokenizer.json");
let local_onnx = onnx_subdir.join("model_quantized.onnx");
if !local_tokenizer.exists() || !local_onnx.exists() {
let tok_cache = tokenizer_cache.clone();
let onnx_c = onnx_cache.clone();
let tok_exists = local_tokenizer.exists();
let onnx_exists = local_onnx.exists();
tokio::task::spawn_blocking(move || {
if !tok_exists {
crate::engine::EmbeddingEngine::download_hf_file_pub(
GLINER_TOKENIZER_REPO,
"tokenizer.json",
&tok_cache,
)
.map_err(|e| {
InferenceError::HubError(format!(
"Failed to download GLiNER tokenizer: {}",
e
))
})?;
}
if !onnx_exists {
crate::engine::EmbeddingEngine::download_hf_file_pub(
GLINER_MODEL_REPO,
GLINER_ONNX_FILE,
&onnx_c,
)
.map_err(|e| {
InferenceError::HubError(format!(
"Failed to download GLiNER ONNX model: {}",
e
))
})?;
}
Ok::<_, InferenceError>(())
})
.await
.map_err(|e| InferenceError::HubError(format!("Download task panicked: {}", e)))??;
} else {
info!("GLiNER model files found in local cache");
}
Ok((local_tokenizer, local_onnx))
}
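/// Per-model cache directory under `$HF_HOME` (falling back to
/// `~/.cache/huggingface`), namespaced as `dakera/<repo id with '/' replaced by '--'>`.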
fn model_cache_dir(model_id: &str) -> Result<PathBuf> {
let base = std::env::var("HF_HOME")
.map(PathBuf::from)
.unwrap_or_else(|_| {
let home = std::env::var("HOME").unwrap_or_else(|_| "/tmp".to_string());
PathBuf::from(home).join(".cache").join("huggingface")
});
let dir = base.join("dakera").join(model_id.replace('/', "--"));
std::fs::create_dir_all(&dir)?;
Ok(dir)
}
}
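/// Entity extraction front-end: always applies the regex rules and optionally
/// layers GLiNER results on top.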
pub struct NerEngine {
gliner: Option<Arc<GlinerEngine>>,
}
impl NerEngine {
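/// Builds an engine that uses only the regex rules (no model download or ONNX session).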
pub fn rule_based_only() -> Self {
Self { gliner: None }
}
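/// Builds an engine that augments the regex rules with GLiNER; the model is
/// downloaded on first use.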
pub async fn with_gliner(num_threads: Option<usize>) -> Result<Self> {
let gliner = GlinerEngine::new(num_threads).await?;
Ok(Self {
gliner: Some(Arc::new(gliner)),
})
}
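/// Extracts entities from `text`. Rule-based matches come first; when a GLiNER
/// engine is configured and `gliner_types` is non-empty, neural entities are
/// added unless they duplicate an existing span. GLiNER failures fall back to
/// the rule-based results with a warning.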
pub async fn extract(&self, text: &str, gliner_types: &[&str]) -> Vec<ExtractedEntity> {
let mut entities = rule_based_extract(text);
if let Some(ref gliner) = self.gliner {
if !gliner_types.is_empty() {
match gliner.extract(text, gliner_types).await {
Ok(neural) => {
for ne in neural {
if !entities
.iter()
.any(|e| e.start == ne.start && e.end == ne.end)
{
entities.push(ne);
}
}
}
Err(e) => {
warn!("GLiNER extraction failed, using rule-based only: {}", e);
}
}
}
}
entities.sort_by_key(|e| e.start);
entities
}
}
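/// Yields (start, end) word-index pairs in the same order spans were enumerated
/// for the model input; `end` is inclusive.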
fn iter_spans(num_words: usize) -> impl Iterator<Item = (usize, usize)> {
(0..num_words).flat_map(move |start| {
let max_end = num_words.min(start + MAX_SPAN_WIDTH);
(start..max_end).map(move |end| (start, end))
})
}
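/// Numerically stable logistic sigmoid.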
#[inline]
fn sigmoid(x: f32) -> f32 {
if x >= 0.0 {
1.0 / (1.0 + (-x).exp())
} else {
let ex = x.exp();
ex / (1.0 + ex)
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_rule_based_uuid() {
let text = "session id is 550e8400-e29b-41d4-a716-446655440000 here";
let entities = rule_based_extract(text);
assert!(entities.iter().any(|e| e.entity_type == "uuid"));
}
#[test]
fn test_rule_based_url() {
let text = "check https://example.com/path?q=1 for details";
let entities = rule_based_extract(text);
assert!(entities.iter().any(|e| e.entity_type == "url"));
}
#[test]
fn test_rule_based_email() {
let text = "contact alice@example.com for support";
let entities = rule_based_extract(text);
assert!(entities.iter().any(|e| e.entity_type == "email"));
assert!(!entities.iter().any(|e| e.entity_type == "url"));
}
#[test]
fn test_rule_based_iso_date() {
let text = "released on 2024-03-15 at noon";
let entities = rule_based_extract(text);
assert!(entities
.iter()
.any(|e| e.entity_type == "date" && e.value == "2024-03-15"));
}
#[test]
fn test_rule_based_natural_date() {
let text = "meeting on March 15, 2024 at noon";
let entities = rule_based_extract(text);
assert!(entities.iter().any(|e| e.entity_type == "date"));
}
#[test]
fn test_entity_to_tag() {
let e = ExtractedEntity {
entity_type: "person".to_string(),
value: "Alice Smith".to_string(),
score: 0.9,
start: 0,
end: 11,
};
assert_eq!(e.to_tag(), "entity:person:Alice Smith");
}
#[test]
fn test_entity_to_tag_colon_escaping() {
let e = ExtractedEntity {
entity_type: "url".to_string(),
value: "http://example.com:8080/path".to_string(),
score: 1.0,
start: 0,
end: 28,
};
let tag = e.to_tag();
let parts: Vec<&str> = tag.splitn(3, ':').collect();
assert_eq!(parts.len(), 3, "tag should have 3 parts: {}", tag);
assert_eq!(parts[0], "entity");
assert_eq!(parts[1], "url");
assert!(
!parts[2].contains(':'),
"value should not contain colons: {}",
parts[2]
);
}
#[test]
fn test_sigmoid() {
assert!((sigmoid(0.0) - 0.5).abs() < 1e-6);
assert!((sigmoid(100.0) - 1.0).abs() < 1e-4);
assert!((sigmoid(-100.0) - 0.0).abs() < 1e-4);
}
}