1use crate::error::{InferenceError, Result};
13use ort::inputs;
14use ort::session::builder::GraphOptimizationLevel;
15use ort::session::Session;
16use ort::value::Tensor;
17use parking_lot::Mutex;
18use regex::Regex;
19use std::path::PathBuf;
20use std::sync::Arc;
21use tokenizers::Tokenizer;
22use tracing::{debug, info, instrument, warn};
23
/// A single entity found in a piece of text, produced either by the regex
/// rules or by the GLiNER model in this module.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct ExtractedEntity {
    /// Entity label, e.g. `"email"`, `"url"`, `"date"`, or a normalised GLiNER type.
    pub entity_type: String,
    /// The matched substring of the input text.
    pub value: String,
    /// Confidence: `1.0` for rule-based matches, the sigmoid score for GLiNER spans.
    pub score: f32,
    /// Byte offset in the input text where the entity starts (inclusive).
    pub start: usize,
    /// Byte offset in the input text where the entity ends (exclusive).
    pub end: usize,
}
42
43impl ExtractedEntity {
44 pub fn to_tag(&self) -> String {
46 let v = self.value.replace(':', "_");
47 format!("entity:{}:{}", self.entity_type, v)
48 }
49}
50
/// Compiled regular expressions used by the rule-based extractor.
struct RulePatterns {
    /// Hyphenated 8-4-4-4-12 hexadecimal UUIDs (case-insensitive).
    uuid: Regex,
    /// `http://` / `https://` URLs.
    url: Regex,
    /// E-mail addresses.
    email: Regex,
    /// `YYYY-MM-DD` dates with a validated month and day range.
    iso_date: Regex,
    /// English month-name dates such as `March 15, 2024`.
    natural_date: Regex,
    /// Dotted-quad IPv4 addresses with octets limited to 0-255.
    ip_v4: Regex,
}
63
64impl RulePatterns {
65 fn new() -> Self {
66 Self {
67 uuid: Regex::new(
68 r"(?i)\b[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}\b",
69 )
70 .expect("uuid regex"),
71 url: Regex::new(r#"https?://[^\s<>\[\]()"']+"#)
72 .expect("url regex"),
73 email: Regex::new(r"[a-zA-Z0-9._%+\-]+@[a-zA-Z0-9.\-]+\.[a-zA-Z]{2,}")
74 .expect("email regex"),
75 iso_date: Regex::new(
76 r"\b\d{4}-(?:0[1-9]|1[0-2])-(?:0[1-9]|[12]\d|3[01])\b",
77 )
78 .expect("iso_date regex"),
79 natural_date: Regex::new(
80 r"(?i)\b(?:Jan(?:uary)?|Feb(?:ruary)?|Mar(?:ch)?|Apr(?:il)?|May|Jun(?:e)?|Jul(?:y)?|Aug(?:ust)?|Sep(?:tember)?|Oct(?:ober)?|Nov(?:ember)?|Dec(?:ember)?)\s+\d{1,2}(?:,\s*\d{4})?\b",
81 )
82 .expect("natural_date regex"),
83 ip_v4: Regex::new(
84 r"\b(?:(?:25[0-5]|2[0-4]\d|[01]?\d\d?)\.){3}(?:25[0-5]|2[0-4]\d|[01]?\d\d?)\b",
85 )
86 .expect("ipv4 regex"),
87 }
88 }
89}
90
// Process-wide pattern set, compiled lazily on first use so the regexes are
// built only once instead of on every call to `rule_based_extract`.
lazy_static::lazy_static! {
    static ref RULE_PATTERNS: RulePatterns = RulePatterns::new();
}
94
95pub fn rule_based_extract(text: &str) -> Vec<ExtractedEntity> {
99 let mut entities: Vec<ExtractedEntity> = Vec::new();
100
101 let push = |entities: &mut Vec<ExtractedEntity>, entity_type: &str, m: regex::Match| {
102 entities.push(ExtractedEntity {
103 entity_type: entity_type.to_string(),
104 value: m.as_str().to_string(),
105 score: 1.0,
106 start: m.start(),
107 end: m.end(),
108 });
109 };
110
111 for m in RULE_PATTERNS.email.find_iter(text) {
113 push(&mut entities, "email", m);
114 }
115 for m in RULE_PATTERNS.url.find_iter(text) {
116 if !entities.iter().any(|e| e.start == m.start()) {
118 push(&mut entities, "url", m);
119 }
120 }
121 for m in RULE_PATTERNS.uuid.find_iter(text) {
122 push(&mut entities, "uuid", m);
123 }
124 for m in RULE_PATTERNS.iso_date.find_iter(text) {
125 push(&mut entities, "date", m);
126 }
127 for m in RULE_PATTERNS.natural_date.find_iter(text) {
128 if !entities
130 .iter()
131 .any(|e| e.start == m.start() && e.entity_type == "date")
132 {
133 push(&mut entities, "date", m);
134 }
135 }
136 for m in RULE_PATTERNS.ip_v4.find_iter(text) {
137 push(&mut entities, "ip", m);
138 }
139
140 entities
141}
142
// Hugging Face repository the ONNX model file is downloaded from.
const GLINER_MODEL_REPO: &str = "onnx-community/gliner_medium-v2.1";
// Hugging Face repository `tokenizer.json` is downloaded from.
const GLINER_TOKENIZER_REPO: &str = "onnx-community/gliner_medium-v2.1";
// Path of the ONNX file inside the model repository / local cache directory.
const GLINER_ONNX_FILE: &str = "onnx/model_quantized.onnx";

// Maximum number of words in a candidate span fed to the model.
const MAX_SPAN_WIDTH: usize = 12;
// Minimum sigmoid score a (span, type) pair needs to be reported.
const SCORE_THRESHOLD: f32 = 0.5;
155
/// GLiNER named-entity recogniser backed by an ONNX Runtime session.
pub struct GlinerEngine {
    /// ONNX session; mutex-guarded and reference-counted so it can be handed
    /// to blocking worker tasks.
    session: Arc<Mutex<Session>>,
    /// Tokenizer loaded from the model's `tokenizer.json`.
    tokenizer: Arc<Tokenizer>,
}
163
164impl GlinerEngine {
165 #[instrument(skip_all)]
167 pub async fn new(num_threads: Option<usize>) -> Result<Self> {
168 let threads = num_threads.unwrap_or(1);
169 info!("Initializing GLiNER NER engine (threads={})", threads);
170
171 let (tokenizer_path, onnx_path) = Self::download_model_files().await?;
172
173 let tokenizer = Tokenizer::from_file(&tokenizer_path)
174 .map_err(|e| InferenceError::TokenizationError(e.to_string()))?;
175
176 let session = Session::builder()
177 .map_err(|e| InferenceError::ModelLoadError(e.to_string()))?
178 .with_optimization_level(GraphOptimizationLevel::Level3)
179 .map_err(|e| InferenceError::ModelLoadError(e.to_string()))?
180 .with_intra_threads(threads)
181 .map_err(|e| InferenceError::ModelLoadError(e.to_string()))?
182 .commit_from_file(&onnx_path)
183 .map_err(|e| InferenceError::ModelLoadError(e.to_string()))?;
184
185 info!("GLiNER engine ready");
186 Ok(Self {
187 session: Arc::new(Mutex::new(session)),
188 tokenizer: Arc::new(tokenizer),
189 })
190 }
191
192 pub async fn extract(&self, text: &str, entity_types: &[&str]) -> Result<Vec<ExtractedEntity>> {
196 if entity_types.is_empty() || text.is_empty() {
197 return Ok(Vec::new());
198 }
199
200 let text_owned = text.to_string();
201 let entity_types_owned: Vec<String> = entity_types.iter().map(|s| s.to_string()).collect();
202 let session = self.session.clone();
203 let tokenizer = self.tokenizer.clone();
204
205 tokio::task::spawn_blocking(move || {
206 Self::run_inference(
207 &text_owned,
208 &entity_types_owned
209 .iter()
210 .map(|s| s.as_str())
211 .collect::<Vec<_>>(),
212 &session,
213 &tokenizer,
214 )
215 })
216 .await
217 .map_err(|e| InferenceError::HubError(format!("GLiNER inference task panicked: {}", e)))?
218 }
219
220 fn run_inference(
221 text: &str,
222 entity_types: &[&str],
223 session: &Arc<Mutex<Session>>,
224 tokenizer: &Tokenizer,
225 ) -> Result<Vec<ExtractedEntity>> {
226 let mut full_text = entity_types.join(" << >> ");
229 full_text.push_str(" << >> ");
230 full_text.push_str(text);
231
232 let encoding = tokenizer
234 .encode(full_text.as_str(), true)
235 .map_err(|e| InferenceError::TokenizationError(e.to_string()))?;
236
237 let token_ids: Vec<i64> = encoding.get_ids().iter().map(|&x| x as i64).collect();
238 let attention_mask: Vec<i64> = encoding
239 .get_attention_mask()
240 .iter()
241 .map(|&x| x as i64)
242 .collect();
243 let seq_len = token_ids.len();
244
245 let word_ids = encoding.get_word_ids();
249
250 let mut words_mask = vec![0i64; seq_len];
251 let mut last_word_id: Option<u32> = None;
252 let mut text_token_start = usize::MAX;
253
254 let prefix = entity_types.join(" << >> ");
260 let prefix_plus_sep = format!("{} << >> ", prefix);
261 let prefix_encoding = tokenizer
262 .encode(prefix_plus_sep.as_str(), false)
263 .map_err(|e| InferenceError::TokenizationError(e.to_string()))?;
264 let prefix_word_count = prefix_encoding
265 .get_word_ids()
266 .iter()
267 .filter_map(|&w| w)
268 .collect::<std::collections::HashSet<_>>()
269 .len();
270
271 let mut text_word_count = 0usize;
272 for (i, &wid_opt) in word_ids.iter().enumerate() {
273 let wid = match wid_opt {
274 Some(w) => w,
275 None => {
276 last_word_id = None;
277 continue;
278 }
279 };
280 let is_new_word = last_word_id.map(|lw| lw != wid).unwrap_or(true);
281 if is_new_word {
282 let global_word_idx = {
283 word_ids[..i]
285 .iter()
286 .filter_map(|&w| w)
287 .collect::<std::collections::HashSet<_>>()
288 .len()
289 };
290 if global_word_idx >= prefix_word_count {
291 if text_token_start == usize::MAX {
292 text_token_start = i;
293 }
294 words_mask[i] = 1;
295 text_word_count += 1;
296 }
297 }
298 last_word_id = Some(wid);
299 }
300
301 if text_word_count == 0 || text_token_start == usize::MAX {
302 debug!("No text words found after entity type prefix, skipping inference");
303 return Ok(Vec::new());
304 }
305 let text_lengths = vec![text_word_count as i64];
306
307 let mut span_idx_flat: Vec<i64> = Vec::new();
310 let mut span_mask: Vec<bool> = Vec::new();
311
312 for start in 0..text_word_count {
313 for end in start..text_word_count.min(start + MAX_SPAN_WIDTH) {
314 span_idx_flat.push(start as i64);
315 span_idx_flat.push(end as i64);
316 span_mask.push(true);
317 }
318 }
319
320 let num_spans = span_mask.len();
321 if num_spans == 0 {
322 return Ok(Vec::new());
323 }
324
325 let span_mask_values: Vec<i64> = span_mask
327 .iter()
328 .map(|&b| if b { 1i64 } else { 0 })
329 .collect();
330
331 let logits_raw: Vec<f32> = {
332 let mut session_guard = session.lock();
333
334 let input_ids_t = Tensor::<i64>::from_array(([1usize, seq_len], token_ids))
336 .map_err(|e| InferenceError::ModelLoadError(e.to_string()))?;
337 let attn_mask_t = Tensor::<i64>::from_array(([1usize, seq_len], attention_mask))
338 .map_err(|e| InferenceError::ModelLoadError(e.to_string()))?;
339 let words_mask_t = Tensor::<i64>::from_array(([1usize, seq_len], words_mask))
340 .map_err(|e| InferenceError::ModelLoadError(e.to_string()))?;
341 let text_lengths_t = Tensor::<i64>::from_array(([1usize, 1usize], text_lengths))
344 .map_err(|e| InferenceError::ModelLoadError(e.to_string()))?;
345 let span_idx_t = Tensor::<i64>::from_array(([1usize, num_spans, 2], span_idx_flat))
346 .map_err(|e| InferenceError::ModelLoadError(e.to_string()))?;
347 let span_mask_t = Tensor::<i64>::from_array(([1usize, num_spans], span_mask_values))
348 .map_err(|e| InferenceError::ModelLoadError(e.to_string()))?;
349
350 let outputs = session_guard
351 .run(inputs![
352 "input_ids" => input_ids_t,
353 "attention_mask" => attn_mask_t,
354 "words_mask" => words_mask_t,
355 "text_lengths" => text_lengths_t,
356 "span_idx" => span_idx_t,
357 "span_mask" => span_mask_t,
358 ])
359 .map_err(|e| InferenceError::ModelLoadError(e.to_string()))?;
360
361 let (_shape, logits_slice) = outputs[0]
363 .try_extract_tensor::<f32>()
364 .map_err(|e| InferenceError::ModelLoadError(e.to_string()))?;
365 logits_slice.to_vec()
366 };
367
368 let num_entity_types = entity_types.len();
370 let expected = num_spans * num_entity_types;
371 if logits_raw.len() != expected {
372 warn!(
373 "GLiNER logits shape mismatch: got {}, expected {}",
374 logits_raw.len(),
375 expected
376 );
377 return Ok(Vec::new());
378 }
379
380 let mut raw_entities: Vec<(usize, usize, usize, f32)> = Vec::new(); for (span_i, (start_w, end_w)) in iter_spans(text_word_count).enumerate() {
384 for (type_i, _entity_type) in entity_types.iter().enumerate() {
385 let logit = logits_raw[span_i * num_entity_types + type_i];
386 let score = sigmoid(logit);
387 if score >= SCORE_THRESHOLD {
388 raw_entities.push((type_i, start_w, end_w, score));
389 }
390 }
391 }
392
393 raw_entities.sort_by(|a, b| b.3.partial_cmp(&a.3).unwrap_or(std::cmp::Ordering::Equal));
395 let mut kept: Vec<(usize, usize, usize, f32)> = Vec::new();
396 'outer: for candidate in &raw_entities {
397 for kept_span in &kept {
398 if kept_span.0 == candidate.0
400 && kept_span.1 <= candidate.2
401 && candidate.1 <= kept_span.2
402 {
403 continue 'outer;
404 }
405 }
406 kept.push(*candidate);
407 }
408
409 let words: Vec<&str> = text.split_whitespace().collect();
411 let mut word_char_starts: Vec<usize> = Vec::with_capacity(words.len());
412 let mut word_char_ends: Vec<usize> = Vec::with_capacity(words.len());
413 {
414 let mut char_pos = 0usize;
415 for word in &words {
416 if let Some(rel) = text[char_pos..].find(word) {
418 let start = char_pos + rel;
419 let end = start + word.len();
420 word_char_starts.push(start);
421 word_char_ends.push(end);
422 char_pos = end;
423 } else {
424 word_char_starts.push(char_pos);
425 word_char_ends.push(char_pos);
426 }
427 }
428 }
429
430 let mut entities: Vec<ExtractedEntity> = kept
431 .into_iter()
432 .filter_map(|(type_i, start_w, end_w, score)| {
433 let start_char = *word_char_starts.get(start_w)?;
434 let end_char = *word_char_ends.get(end_w)?;
435 let value = text[start_char..end_char].to_string();
436 Some(ExtractedEntity {
437 entity_type: entity_types[type_i].to_lowercase().replace(' ', "_"),
438 value,
439 score,
440 start: start_char,
441 end: end_char,
442 })
443 })
444 .collect();
445
446 entities.sort_by_key(|e| e.start);
447 debug!("GLiNER extracted {} entities", entities.len());
448 Ok(entities)
449 }
450
451 #[instrument(skip_all)]
454 async fn download_model_files() -> Result<(PathBuf, PathBuf)> {
455 info!(
456 "Resolving GLiNER model files: tokenizer={}, onnx={}",
457 GLINER_TOKENIZER_REPO, GLINER_MODEL_REPO
458 );
459
460 let tokenizer_cache = Self::model_cache_dir(GLINER_TOKENIZER_REPO)?;
461 let onnx_cache = Self::model_cache_dir(GLINER_MODEL_REPO)?;
462 let onnx_subdir = onnx_cache.join("onnx");
463 std::fs::create_dir_all(&onnx_subdir)?;
464
465 let local_tokenizer = tokenizer_cache.join("tokenizer.json");
466 let local_onnx = onnx_subdir.join("model_quantized.onnx");
467
468 if !local_tokenizer.exists() || !local_onnx.exists() {
469 let tok_cache = tokenizer_cache.clone();
470 let onnx_c = onnx_cache.clone();
471 let tok_exists = local_tokenizer.exists();
472 let onnx_exists = local_onnx.exists();
473
474 tokio::task::spawn_blocking(move || {
475 if !tok_exists {
476 crate::engine::EmbeddingEngine::download_hf_file_pub(
477 GLINER_TOKENIZER_REPO,
478 "tokenizer.json",
479 &tok_cache,
480 )
481 .map_err(|e| {
482 InferenceError::HubError(format!(
483 "Failed to download GLiNER tokenizer: {}",
484 e
485 ))
486 })?;
487 }
488 if !onnx_exists {
489 crate::engine::EmbeddingEngine::download_hf_file_pub(
490 GLINER_MODEL_REPO,
491 GLINER_ONNX_FILE,
492 &onnx_c,
493 )
494 .map_err(|e| {
495 InferenceError::HubError(format!(
496 "Failed to download GLiNER ONNX model: {}",
497 e
498 ))
499 })?;
500 }
501 Ok::<_, InferenceError>(())
502 })
503 .await
504 .map_err(|e| InferenceError::HubError(format!("Download task panicked: {}", e)))??;
505 } else {
506 info!("GLiNER model files found in local cache");
507 }
508
509 let final_onnx = onnx_cache.join(GLINER_ONNX_FILE);
510 Ok((local_tokenizer, final_onnx))
511 }
512
513 fn model_cache_dir(model_id: &str) -> Result<PathBuf> {
514 let base = std::env::var("HF_HOME")
515 .map(PathBuf::from)
516 .unwrap_or_else(|_| {
517 let home = std::env::var("HOME").unwrap_or_else(|_| "/tmp".to_string());
518 PathBuf::from(home).join(".cache").join("huggingface")
519 });
520 let dir = base.join("dakera").join(model_id.replace('/', "--"));
521 std::fs::create_dir_all(&dir)?;
522 Ok(dir)
523 }
524}
525
/// Entity extractor combining the regex rules with an optional GLiNER model.
pub struct NerEngine {
    /// Neural engine; `None` means only the rule-based extractor is used.
    gliner: Option<Arc<GlinerEngine>>,
}
534
535impl NerEngine {
536 pub fn rule_based_only() -> Self {
538 Self { gliner: None }
539 }
540
541 pub async fn with_gliner(num_threads: Option<usize>) -> Result<Self> {
543 let gliner = GlinerEngine::new(num_threads).await?;
544 Ok(Self {
545 gliner: Some(Arc::new(gliner)),
546 })
547 }
548
549 pub async fn extract(&self, text: &str, gliner_types: &[&str]) -> Vec<ExtractedEntity> {
555 let mut entities = rule_based_extract(text);
556
557 if let Some(ref gliner) = self.gliner {
558 if !gliner_types.is_empty() {
559 match gliner.extract(text, gliner_types).await {
560 Ok(neural) => {
561 for ne in neural {
562 if !entities
564 .iter()
565 .any(|e| e.start == ne.start && e.end == ne.end)
566 {
567 entities.push(ne);
568 }
569 }
570 }
571 Err(e) => {
572 warn!("GLiNER extraction failed, using rule-based only: {}", e);
573 }
574 }
575 }
576 }
577
578 entities.sort_by_key(|e| e.start);
579 entities
580 }
581}
582
583fn iter_spans(num_words: usize) -> impl Iterator<Item = (usize, usize)> {
589 (0..num_words).flat_map(move |start| {
590 let max_end = num_words.min(start + MAX_SPAN_WIDTH);
591 (start..max_end).map(move |end| (start, end))
592 })
593}
594
/// Logistic function `1 / (1 + e^-x)`.
///
/// The negative branch uses `e^x / (1 + e^x)` so that `exp` is only ever
/// called on a non-positive argument and cannot overflow.
#[inline]
fn sigmoid(x: f32) -> f32 {
    if x < 0.0 {
        let e = x.exp();
        return e / (1.0 + e);
    }
    1.0 / (1.0 + (-x).exp())
}
605
// Unit tests for the rule-based extractor, tag rendering and numeric helpers.
// The GLiNER path needs model files and is not exercised here.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_rule_based_uuid() {
        let text = "session id is 550e8400-e29b-41d4-a716-446655440000 here";
        let entities = rule_based_extract(text);
        assert!(entities.iter().any(|e| e.entity_type == "uuid"));
    }

    #[test]
    fn test_rule_based_url() {
        let text = "check https://example.com/path?q=1 for details";
        let entities = rule_based_extract(text);
        assert!(entities.iter().any(|e| e.entity_type == "url"));
    }

    #[test]
    fn test_rule_based_email() {
        let text = "contact alice@example.com for support";
        let entities = rule_based_extract(text);
        assert!(entities.iter().any(|e| e.entity_type == "email"));
        // A bare e-mail address must not also be reported as a URL.
        assert!(!entities.iter().any(|e| e.entity_type == "url"));
    }

    #[test]
    fn test_rule_based_iso_date() {
        let text = "released on 2024-03-15 at noon";
        let entities = rule_based_extract(text);
        assert!(entities
            .iter()
            .any(|e| e.entity_type == "date" && e.value == "2024-03-15"));
    }

    #[test]
    fn test_rule_based_natural_date() {
        let text = "meeting on March 15, 2024 at noon";
        let entities = rule_based_extract(text);
        assert!(entities.iter().any(|e| e.entity_type == "date"));
    }

    #[test]
    fn test_rule_based_ipv4() {
        let text = "server at 192.168.1.10 responded";
        let entities = rule_based_extract(text);
        assert!(entities
            .iter()
            .any(|e| e.entity_type == "ip" && e.value == "192.168.1.10"));
    }

    #[test]
    fn test_entity_to_tag() {
        let e = ExtractedEntity {
            entity_type: "person".to_string(),
            value: "Alice Smith".to_string(),
            score: 0.9,
            start: 0,
            end: 11,
        };
        assert_eq!(e.to_tag(), "entity:person:Alice Smith");
    }

    #[test]
    fn test_entity_to_tag_colon_escaping() {
        let value = "http://example.com:8080/path";
        let e = ExtractedEntity {
            entity_type: "url".to_string(),
            value: value.to_string(),
            score: 1.0,
            start: 0,
            // Derived from the fixture so the offsets stay consistent with it.
            end: value.len(),
        };
        let tag = e.to_tag();
        // Split into at most three fields: any colon left in the value would
        // end up inside the third field and be caught below.
        let parts: Vec<&str> = tag.splitn(3, ':').collect();
        assert_eq!(parts.len(), 3, "tag should have 3 parts: {}", tag);
        assert_eq!(parts[0], "entity");
        assert_eq!(parts[1], "url");
        assert!(
            !parts[2].contains(':'),
            "value should not contain colons: {}",
            parts[2]
        );
    }

    #[test]
    fn test_iter_spans_order() {
        let spans: Vec<(usize, usize)> = iter_spans(3).collect();
        assert_eq!(spans, vec![(0, 0), (0, 1), (0, 2), (1, 1), (1, 2), (2, 2)]);
        assert_eq!(iter_spans(0).count(), 0);
        // No span may cover more than MAX_SPAN_WIDTH words.
        assert!(iter_spans(MAX_SPAN_WIDTH + 5).all(|(s, e)| e - s < MAX_SPAN_WIDTH));
    }

    #[test]
    fn test_sigmoid() {
        assert!((sigmoid(0.0) - 0.5).abs() < 1e-6);
        assert!((sigmoid(100.0) - 1.0).abs() < 1e-4);
        assert!((sigmoid(-100.0) - 0.0).abs() < 1e-4);
    }
}