use crate::error::{InferenceError, Result};
use ort::inputs;
use ort::session::builder::GraphOptimizationLevel;
use ort::session::Session;
use ort::value::Tensor;
use parking_lot::Mutex;
use regex::Regex;
use std::path::PathBuf;
use std::sync::Arc;
use tokenizers::Tokenizer;
use tracing::{debug, info, instrument, warn};

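/// An entity extracted from text, either by the rule-based regex pass or by the
/// GLiNER model. `start` and `end` are byte offsets into the original text;
/// rule-based matches always carry a `score` of 1.0.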
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct ExtractedEntity {
    pub entity_type: String,
    pub value: String,
    pub score: f32,
    pub start: usize,
    pub end: usize,
}

impl ExtractedEntity {
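    /// Renders the entity as an `entity:<type>:<value>` tag. Colons inside the
    /// value are replaced with underscores so the tag always splits into exactly
    /// three parts (e.g. the value `http://example.com:8080/path` becomes
    /// `http_//example.com_8080/path` in the tag).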
    pub fn to_tag(&self) -> String {
        let v = self.value.replace(':', "_");
        format!("entity:{}:{}", self.entity_type, v)
    }
}

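/// Pre-compiled regexes for the rule-based extractors, shared through the
/// `RULE_PATTERNS` lazy static below so they are only compiled once.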
struct RulePatterns {
    uuid: Regex,
    url: Regex,
    email: Regex,
    iso_date: Regex,
    natural_date: Regex,
    ip_v4: Regex,
}

impl RulePatterns {
    fn new() -> Self {
        Self {
            uuid: Regex::new(
                r"(?i)\b[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}\b",
            )
            .expect("uuid regex"),
            url: Regex::new(r#"https?://[^\s<>\[\]()"']+"#)
                .expect("url regex"),
            email: Regex::new(r"[a-zA-Z0-9._%+\-]+@[a-zA-Z0-9.\-]+\.[a-zA-Z]{2,}")
                .expect("email regex"),
            iso_date: Regex::new(
                r"\b\d{4}-(?:0[1-9]|1[0-2])-(?:0[1-9]|[12]\d|3[01])\b",
            )
            .expect("iso_date regex"),
            natural_date: Regex::new(
                r"(?i)\b(?:Jan(?:uary)?|Feb(?:ruary)?|Mar(?:ch)?|Apr(?:il)?|May|Jun(?:e)?|Jul(?:y)?|Aug(?:ust)?|Sep(?:tember)?|Oct(?:ober)?|Nov(?:ember)?|Dec(?:ember)?)\s+\d{1,2}(?:,\s*\d{4})?\b",
            )
            .expect("natural_date regex"),
            ip_v4: Regex::new(
                r"\b(?:(?:25[0-5]|2[0-4]\d|[01]?\d\d?)\.){3}(?:25[0-5]|2[0-4]\d|[01]?\d\d?)\b",
            )
            .expect("ipv4 regex"),
        }
    }
}

lazy_static::lazy_static! {
    static ref RULE_PATTERNS: RulePatterns = RulePatterns::new();
}

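/// Extracts entities with deterministic regex rules only (no model inference).
/// Matches are tagged `email`, `url`, `uuid`, `date`, and `ip`; URL matches that
/// start where an earlier match started and natural-language dates that duplicate
/// an ISO date are skipped.
///
/// ```ignore
/// // Minimal usage sketch (assumes this module is in scope as `ner`):
/// let entities = ner::rule_based_extract("contact alice@example.com on 2024-03-15");
/// assert!(entities.iter().any(|e| e.entity_type == "email"));
/// assert!(entities.iter().any(|e| e.entity_type == "date"));
/// ```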
pub fn rule_based_extract(text: &str) -> Vec<ExtractedEntity> {
    let mut entities: Vec<ExtractedEntity> = Vec::new();

    let push = |entities: &mut Vec<ExtractedEntity>, entity_type: &str, m: regex::Match| {
        entities.push(ExtractedEntity {
            entity_type: entity_type.to_string(),
            value: m.as_str().to_string(),
            score: 1.0,
            start: m.start(),
            end: m.end(),
        });
    };

    for m in RULE_PATTERNS.email.find_iter(text) {
        push(&mut entities, "email", m);
    }
    for m in RULE_PATTERNS.url.find_iter(text) {
        if !entities.iter().any(|e| e.start == m.start()) {
            push(&mut entities, "url", m);
        }
    }
    for m in RULE_PATTERNS.uuid.find_iter(text) {
        push(&mut entities, "uuid", m);
    }
    for m in RULE_PATTERNS.iso_date.find_iter(text) {
        push(&mut entities, "date", m);
    }
    for m in RULE_PATTERNS.natural_date.find_iter(text) {
        if !entities
            .iter()
            .any(|e| e.start == m.start() && e.entity_type == "date")
        {
            push(&mut entities, "date", m);
        }
    }
    for m in RULE_PATTERNS.ip_v4.find_iter(text) {
        push(&mut entities, "ip", m);
    }

    entities
}

const GLINER_MODEL_REPO: &str = "onnx-community/gliner-medium-v2.1";
const GLINER_TOKENIZER_REPO: &str = "knowledgator/gliner-medium-v2.1";
const GLINER_ONNX_FILE: &str = "onnx/model_quantized.onnx";

const MAX_SPAN_WIDTH: usize = 12;
const SCORE_THRESHOLD: f32 = 0.5;

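/// Zero-shot NER over a quantized GLiNER ONNX model. The ONNX session lives
/// behind a `Mutex` so one loaded model can be shared across async tasks, and
/// inference runs on a blocking thread pool.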
pub struct GlinerEngine {
    session: Arc<Mutex<Session>>,
    tokenizer: Arc<Tokenizer>,
}

impl GlinerEngine {
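    /// Downloads the GLiNER tokenizer and quantized ONNX model if they are not
    /// already cached locally, then builds an ONNX Runtime session with the
    /// requested number of intra-op threads (defaults to 1).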
    #[instrument(skip_all)]
    pub async fn new(num_threads: Option<usize>) -> Result<Self> {
        let threads = num_threads.unwrap_or(1);
        info!("Initializing GLiNER NER engine (threads={})", threads);

        let (tokenizer_path, onnx_path) = Self::download_model_files().await?;

        let tokenizer = Tokenizer::from_file(&tokenizer_path)
            .map_err(|e| InferenceError::TokenizationError(e.to_string()))?;

        let session = Session::builder()
            .map_err(|e| InferenceError::ModelLoadError(e.to_string()))?
            .with_optimization_level(GraphOptimizationLevel::Level3)
            .map_err(|e| InferenceError::ModelLoadError(e.to_string()))?
            .with_intra_threads(threads)
            .map_err(|e| InferenceError::ModelLoadError(e.to_string()))?
            .commit_from_file(&onnx_path)
            .map_err(|e| InferenceError::ModelLoadError(e.to_string()))?;

        info!("GLiNER engine ready");
        Ok(Self {
            session: Arc::new(Mutex::new(session)),
            tokenizer: Arc::new(tokenizer),
        })
    }

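    /// Runs zero-shot extraction for the given entity type labels. Returns an
    /// empty vector when either the text or the label list is empty; the actual
    /// forward pass is offloaded to `spawn_blocking` so ONNX Runtime never blocks
    /// the async executor.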
    pub async fn extract(&self, text: &str, entity_types: &[&str]) -> Result<Vec<ExtractedEntity>> {
        if entity_types.is_empty() || text.is_empty() {
            return Ok(Vec::new());
        }

        let text_owned = text.to_string();
        let entity_types_owned: Vec<String> = entity_types.iter().map(|s| s.to_string()).collect();
        let session = self.session.clone();
        let tokenizer = self.tokenizer.clone();

        tokio::task::spawn_blocking(move || {
            Self::run_inference(
                &text_owned,
                &entity_types_owned
                    .iter()
                    .map(|s| s.as_str())
                    .collect::<Vec<_>>(),
                &session,
                &tokenizer,
            )
        })
        .await
        .map_err(|e| InferenceError::HubError(format!("GLiNER inference task panicked: {}", e)))?
    }

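    /// Synchronous GLiNER forward pass. The steps are:
    /// 1. build the prompt `"<label> << >> ... << >> <text>"` and tokenize it,
    /// 2. mark the first sub-token of each word that belongs to the text (not the
    ///    label prefix) in `words_mask`,
    /// 3. enumerate candidate spans up to `MAX_SPAN_WIDTH` words wide,
    /// 4. run the ONNX session and apply a sigmoid to the span/type logits,
    /// 5. keep scores above `SCORE_THRESHOLD`, greedily drop overlapping spans of
    ///    the same type, and map word indices back to byte offsets in the text.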
    fn run_inference(
        text: &str,
        entity_types: &[&str],
        session: &Arc<Mutex<Session>>,
        tokenizer: &Tokenizer,
    ) -> Result<Vec<ExtractedEntity>> {
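        // GLiNER-style prompt: entity type labels separated (and terminated) by
        // the " << >> " marker, followed by the raw text to analyze.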
        let mut full_text = entity_types.join(" << >> ");
        full_text.push_str(" << >> ");
        full_text.push_str(text);

        let encoding = tokenizer
            .encode(full_text.as_str(), true)
            .map_err(|e| InferenceError::TokenizationError(e.to_string()))?;

        let token_ids: Vec<i64> = encoding.get_ids().iter().map(|&x| x as i64).collect();
        let attention_mask: Vec<i64> = encoding
            .get_attention_mask()
            .iter()
            .map(|&x| x as i64)
            .collect();
        let seq_len = token_ids.len();

        let word_ids = encoding.get_word_ids();

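        // `words_mask` flags the first sub-token of every word that belongs to the
        // text portion of the prompt. To find where the text starts, the label
        // prefix is tokenized separately and its distinct word count is used as an
        // offset into the full encoding's word ids.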
        let mut words_mask = vec![0i64; seq_len];
        let mut last_word_id: Option<u32> = None;
        let mut text_token_start = usize::MAX;

        let prefix = entity_types.join(" << >> ");
        let prefix_plus_sep = format!("{} << >> ", prefix);
        let prefix_encoding = tokenizer
            .encode(prefix_plus_sep.as_str(), false)
            .map_err(|e| InferenceError::TokenizationError(e.to_string()))?;
        let prefix_word_count = prefix_encoding
            .get_word_ids()
            .iter()
            .filter_map(|&w| w)
            .collect::<std::collections::HashSet<_>>()
            .len();

        let mut text_word_count = 0usize;
        for (i, &wid_opt) in word_ids.iter().enumerate() {
            let wid = match wid_opt {
                Some(w) => w,
                None => {
                    last_word_id = None;
                    continue;
                }
            };
            let is_new_word = last_word_id.map(|lw| lw != wid).unwrap_or(true);
            if is_new_word {
                let global_word_idx = word_ids[..i]
                    .iter()
                    .filter_map(|&w| w)
                    .collect::<std::collections::HashSet<_>>()
                    .len();
                if global_word_idx >= prefix_word_count {
                    if text_token_start == usize::MAX {
                        text_token_start = i;
                    }
                    words_mask[i] = 1;
                    text_word_count += 1;
                }
            }
            last_word_id = Some(wid);
        }

        if text_word_count == 0 || text_token_start == usize::MAX {
            debug!("No text words found after entity type prefix, skipping inference");
            return Ok(Vec::new());
        }
        let text_lengths = vec![text_word_count as i64];

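        // Enumerate every candidate span of 1..=MAX_SPAN_WIDTH words. `span_idx`
        // stores inclusive (start, end) word indices; `span_mask` marks each
        // candidate as valid. `iter_spans` below reproduces this exact order when
        // the logits are decoded.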
        let mut span_idx_flat: Vec<i64> = Vec::new();
        let mut span_mask: Vec<bool> = Vec::new();

        for start in 0..text_word_count {
            for end in start..text_word_count.min(start + MAX_SPAN_WIDTH) {
                span_idx_flat.push(start as i64);
                span_idx_flat.push(end as i64);
                span_mask.push(true);
            }
        }

        let num_spans = span_mask.len();
        if num_spans == 0 {
            return Ok(Vec::new());
        }

        let span_mask_values: Vec<i64> = span_mask
            .iter()
            .map(|&b| if b { 1i64 } else { 0 })
            .collect();

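        // The quantized GLiNER graph takes six inputs: token ids, attention mask,
        // the word-start mask, the number of text words, the candidate span
        // indices, and the span validity mask. The session lock is held only for
        // the duration of the forward pass.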
        let logits_raw: Vec<f32> = {
            let mut session_guard = session.lock();

            let input_ids_t = Tensor::<i64>::from_array(([1usize, seq_len], token_ids))
                .map_err(|e| InferenceError::ModelLoadError(e.to_string()))?;
            let attn_mask_t = Tensor::<i64>::from_array(([1usize, seq_len], attention_mask))
                .map_err(|e| InferenceError::ModelLoadError(e.to_string()))?;
            let words_mask_t = Tensor::<i64>::from_array(([1usize, seq_len], words_mask))
                .map_err(|e| InferenceError::ModelLoadError(e.to_string()))?;
            let text_lengths_t = Tensor::<i64>::from_array(([1usize], text_lengths))
                .map_err(|e| InferenceError::ModelLoadError(e.to_string()))?;
            let span_idx_t = Tensor::<i64>::from_array(([1usize, num_spans, 2], span_idx_flat))
                .map_err(|e| InferenceError::ModelLoadError(e.to_string()))?;
            let span_mask_t = Tensor::<i64>::from_array(([1usize, num_spans], span_mask_values))
                .map_err(|e| InferenceError::ModelLoadError(e.to_string()))?;

            let outputs = session_guard
                .run(inputs![
                    "input_ids" => input_ids_t,
                    "attention_mask" => attn_mask_t,
                    "words_mask" => words_mask_t,
                    "text_lengths" => text_lengths_t,
                    "span_idx" => span_idx_t,
                    "span_mask" => span_mask_t,
                ])
                .map_err(|e| InferenceError::ModelLoadError(e.to_string()))?;

            let (_shape, logits_slice) = outputs[0]
                .try_extract_tensor::<f32>()
                .map_err(|e| InferenceError::ModelLoadError(e.to_string()))?;
            logits_slice.to_vec()
        };

        let num_entity_types = entity_types.len();
        let expected = num_spans * num_entity_types;
        if logits_raw.len() != expected {
            warn!(
                "GLiNER logits shape mismatch: got {}, expected {}",
                logits_raw.len(),
                expected
            );
            return Ok(Vec::new());
        }

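        // Decode: walk the spans in the same order they were generated, sigmoid
        // each logit into a probability, and keep (type, start, end, score)
        // candidates that clear SCORE_THRESHOLD.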
        let mut raw_entities: Vec<(usize, usize, usize, f32)> = Vec::new();

        for (span_i, (start_w, end_w)) in iter_spans(text_word_count).enumerate() {
            for (type_i, _entity_type) in entity_types.iter().enumerate() {
                let logit = logits_raw[span_i * num_entity_types + type_i];
                let score = sigmoid(logit);
                if score >= SCORE_THRESHOLD {
                    raw_entities.push((type_i, start_w, end_w, score));
                }
            }
        }

        raw_entities.sort_by(|a, b| b.3.partial_cmp(&a.3).unwrap_or(std::cmp::Ordering::Equal));
        let mut kept: Vec<(usize, usize, usize, f32)> = Vec::new();
        'outer: for candidate in &raw_entities {
            for kept_span in &kept {
                if kept_span.0 == candidate.0
                    && kept_span.1 <= candidate.2
                    && candidate.1 <= kept_span.2
                {
                    continue 'outer;
                }
            }
            kept.push(*candidate);
        }

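        // Map word indices back to byte offsets by re-locating each
        // whitespace-separated word in the original text. This assumes GLiNER's
        // word segmentation lines up with `split_whitespace`, which holds for
        // plain text but is an approximation.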
        let words: Vec<&str> = text.split_whitespace().collect();
        let mut word_char_starts: Vec<usize> = Vec::with_capacity(words.len());
        let mut word_char_ends: Vec<usize> = Vec::with_capacity(words.len());
        {
            let mut char_pos = 0usize;
            for word in &words {
                if let Some(rel) = text[char_pos..].find(word) {
                    let start = char_pos + rel;
                    let end = start + word.len();
                    word_char_starts.push(start);
                    word_char_ends.push(end);
                    char_pos = end;
                } else {
                    word_char_starts.push(char_pos);
                    word_char_ends.push(char_pos);
                }
            }
        }

        let mut entities: Vec<ExtractedEntity> = kept
            .into_iter()
            .filter_map(|(type_i, start_w, end_w, score)| {
                let start_char = *word_char_starts.get(start_w)?;
                let end_char = *word_char_ends.get(end_w)?;
                let value = text[start_char..end_char].to_string();
                Some(ExtractedEntity {
                    entity_type: entity_types[type_i].to_lowercase().replace(' ', "_"),
                    value,
                    score,
                    start: start_char,
                    end: end_char,
                })
            })
            .collect();

        entities.sort_by_key(|e| e.start);
        debug!("GLiNER extracted {} entities", entities.len());
        Ok(entities)
    }

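    /// Ensures the GLiNER tokenizer and ONNX model are present in the local
    /// cache, downloading them from the Hugging Face Hub on first use, and
    /// returns the paths to `tokenizer.json` and the quantized ONNX file.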
    #[instrument(skip_all)]
    async fn download_model_files() -> Result<(PathBuf, PathBuf)> {
        info!(
            "Resolving GLiNER model files: tokenizer={}, onnx={}",
            GLINER_TOKENIZER_REPO, GLINER_MODEL_REPO
        );

        let tokenizer_cache = Self::model_cache_dir(GLINER_TOKENIZER_REPO)?;
        let onnx_cache = Self::model_cache_dir(GLINER_MODEL_REPO)?;
        let onnx_subdir = onnx_cache.join("onnx");
        std::fs::create_dir_all(&onnx_subdir)?;

        let local_tokenizer = tokenizer_cache.join("tokenizer.json");
        let local_onnx = onnx_subdir.join("model_quantized.onnx");

        if !local_tokenizer.exists() || !local_onnx.exists() {
            let tok_cache = tokenizer_cache.clone();
            let onnx_c = onnx_cache.clone();
            let tok_exists = local_tokenizer.exists();
            let onnx_exists = local_onnx.exists();

            tokio::task::spawn_blocking(move || {
                if !tok_exists {
                    crate::engine::EmbeddingEngine::download_hf_file_pub(
                        GLINER_TOKENIZER_REPO,
                        "tokenizer.json",
                        &tok_cache,
                    )
                    .map_err(|e| {
                        InferenceError::HubError(format!(
                            "Failed to download GLiNER tokenizer: {}",
                            e
                        ))
                    })?;
                }
                if !onnx_exists {
                    crate::engine::EmbeddingEngine::download_hf_file_pub(
                        GLINER_MODEL_REPO,
                        GLINER_ONNX_FILE,
                        &onnx_c,
                    )
                    .map_err(|e| {
                        InferenceError::HubError(format!(
                            "Failed to download GLiNER ONNX model: {}",
                            e
                        ))
                    })?;
                }
                Ok::<_, InferenceError>(())
            })
            .await
            .map_err(|e| InferenceError::HubError(format!("Download task panicked: {}", e)))??;
        } else {
            info!("GLiNER model files found in local cache");
        }

        let final_onnx = onnx_cache.join(GLINER_ONNX_FILE);
        Ok((local_tokenizer, final_onnx))
    }

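    /// Per-model cache directory under `$HF_HOME` (or `~/.cache/huggingface`),
    /// namespaced as `dakera/<owner>--<repo>`; the directory is created if needed.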
    fn model_cache_dir(model_id: &str) -> Result<PathBuf> {
        let base = std::env::var("HF_HOME")
            .map(PathBuf::from)
            .unwrap_or_else(|_| {
                let home = std::env::var("HOME").unwrap_or_else(|_| "/tmp".to_string());
                PathBuf::from(home).join(".cache").join("huggingface")
            });
        let dir = base.join("dakera").join(model_id.replace('/', "--"));
        std::fs::create_dir_all(&dir)?;
        Ok(dir)
    }
}

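/// Combined NER front end: always runs the regex rules, and optionally layers
/// GLiNER zero-shot extraction on top when an engine has been loaded.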
pub struct NerEngine {
    gliner: Option<Arc<GlinerEngine>>,
}

impl NerEngine {
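    /// `rule_based_only` skips model loading entirely; `with_gliner` additionally
    /// downloads and initializes the GLiNER engine.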
    pub fn rule_based_only() -> Self {
        Self { gliner: None }
    }

    pub async fn with_gliner(num_threads: Option<usize>) -> Result<Self> {
        let gliner = GlinerEngine::new(num_threads).await?;
        Ok(Self {
            gliner: Some(Arc::new(gliner)),
        })
    }

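    /// Runs the rule-based extractors, then (if available) GLiNER for the given
    /// label set, dropping neural matches whose span exactly duplicates a rule
    /// match. GLiNER failures are logged and degrade gracefully to rules only.
    ///
    /// ```ignore
    /// // Minimal usage sketch (assumes a tokio runtime):
    /// let engine = NerEngine::rule_based_only();
    /// let entities = engine.extract("email bob@example.com", &["person"]).await;
    /// assert!(entities.iter().any(|e| e.entity_type == "email"));
    /// ```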
    pub async fn extract(&self, text: &str, gliner_types: &[&str]) -> Vec<ExtractedEntity> {
        let mut entities = rule_based_extract(text);

        if let Some(ref gliner) = self.gliner {
            if !gliner_types.is_empty() {
                match gliner.extract(text, gliner_types).await {
                    Ok(neural) => {
                        for ne in neural {
                            if !entities
                                .iter()
                                .any(|e| e.start == ne.start && e.end == ne.end)
                            {
                                entities.push(ne);
                            }
                        }
                    }
                    Err(e) => {
                        warn!("GLiNER extraction failed, using rule-based only: {}", e);
                    }
                }
            }
        }

        entities.sort_by_key(|e| e.start);
        entities
    }
}

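/// Yields candidate spans as inclusive (start, end) word-index pairs, in the
/// same order `run_inference` builds `span_idx`, so enumerating this iterator
/// lines up one-to-one with the flattened logits.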
fn iter_spans(num_words: usize) -> impl Iterator<Item = (usize, usize)> {
    (0..num_words).flat_map(move |start| {
        let max_end = num_words.min(start + MAX_SPAN_WIDTH);
        (start..max_end).map(move |end| (start, end))
    })
}

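/// Numerically stable logistic sigmoid: the two branches avoid overflow in
/// `exp` for large positive and large negative inputs.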
#[inline]
fn sigmoid(x: f32) -> f32 {
    if x >= 0.0 {
        1.0 / (1.0 + (-x).exp())
    } else {
        let ex = x.exp();
        ex / (1.0 + ex)
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_rule_based_uuid() {
        let text = "session id is 550e8400-e29b-41d4-a716-446655440000 here";
        let entities = rule_based_extract(text);
        assert!(entities.iter().any(|e| e.entity_type == "uuid"));
    }

    #[test]
    fn test_rule_based_url() {
        let text = "check https://example.com/path?q=1 for details";
        let entities = rule_based_extract(text);
        assert!(entities.iter().any(|e| e.entity_type == "url"));
    }

    #[test]
    fn test_rule_based_email() {
        let text = "contact alice@example.com for support";
        let entities = rule_based_extract(text);
        assert!(entities.iter().any(|e| e.entity_type == "email"));
        assert!(!entities.iter().any(|e| e.entity_type == "url"));
    }

    #[test]
    fn test_rule_based_iso_date() {
        let text = "released on 2024-03-15 at noon";
        let entities = rule_based_extract(text);
        assert!(entities
            .iter()
            .any(|e| e.entity_type == "date" && e.value == "2024-03-15"));
    }

    #[test]
    fn test_rule_based_natural_date() {
        let text = "meeting on March 15, 2024 at noon";
        let entities = rule_based_extract(text);
        assert!(entities.iter().any(|e| e.entity_type == "date"));
    }

    #[test]
    fn test_entity_to_tag() {
        let e = ExtractedEntity {
            entity_type: "person".to_string(),
            value: "Alice Smith".to_string(),
            score: 0.9,
            start: 0,
            end: 11,
        };
        assert_eq!(e.to_tag(), "entity:person:Alice Smith");
    }

    #[test]
    fn test_entity_to_tag_colon_escaping() {
        let e = ExtractedEntity {
            entity_type: "url".to_string(),
            value: "http://example.com:8080/path".to_string(),
            score: 1.0,
            start: 0,
            end: 28,
        };
        let tag = e.to_tag();
        let parts: Vec<&str> = tag.splitn(3, ':').collect();
        assert_eq!(parts.len(), 3, "tag should have 3 parts: {}", tag);
        assert_eq!(parts[0], "entity");
        assert_eq!(parts[1], "url");
        assert!(
            !parts[2].contains(':'),
            "value should not contain colons: {}",
            parts[2]
        );
    }

    #[test]
    fn test_sigmoid() {
        assert!((sigmoid(0.0) - 0.5).abs() < 1e-6);
        assert!((sigmoid(100.0) - 1.0).abs() < 1e-4);
        assert!((sigmoid(-100.0) - 0.0).abs() < 1e-4);
    }
}