1#![allow(missing_docs)] #![allow(dead_code)] #![allow(clippy::type_complexity)] #![allow(clippy::manual_contains)] #![allow(unused_variables)] #![allow(clippy::items_after_test_module)] #![allow(unused_imports)] #[cfg(feature = "onnx")]
35use crate::sync::{lock, try_lock, Mutex};
36use crate::{Entity, Error, Result};
37use anno_core::{EntityCategory, EntityType};
38
// Special token ids used to assemble the GLiNER prompt sequence.
// NOTE(review): these ids are specific to the exported model's vocabulary
// ([ENT]/[SEP] are added tokens) — confirm against the repo's tokenizer.json.
const TOKEN_START: u32 = 1;
const TOKEN_END: u32 = 2;
const TOKEN_ENT: u32 = 128002;
const TOKEN_SEP: u32 = 128003;

// Maximum candidate entity span width, in words, considered during decoding.
const MAX_SPAN_WIDTH: usize = 12;
47
#[cfg(feature = "onnx")]
/// Configuration for loading a GLiNER model into an ONNX Runtime session.
#[derive(Debug, Clone)]
pub struct GLiNERConfig {
    /// Prefer an INT8-quantized export when the repository provides one.
    pub prefer_quantized: bool,
    /// ONNX Runtime graph optimization level: 1, 2, or anything else => 3.
    pub optimization_level: u8,
    /// Intra-op thread count; 0 keeps the ONNX Runtime default.
    pub num_threads: usize,
    /// Capacity of the encoded-prompt LRU cache; 0 disables caching.
    pub prompt_cache_size: usize,
}
65
#[cfg(feature = "onnx")]
impl Default for GLiNERConfig {
    /// Defaults: quantized model preferred, maximum graph optimization,
    /// 4 intra-op threads, 100-entry prompt cache.
    fn default() -> Self {
        Self {
            prefer_quantized: true,
            optimization_level: 3,
            num_threads: 4,
            prompt_cache_size: 100,
        }
    }
}
77
#[cfg(feature = "onnx")]
/// Key for the encoded-prompt cache: hashes of the joined text and of the
/// (sorted) entity-type set, plus the model id so different models never
/// share cache entries.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
struct PromptCacheKey {
    text_hash: u64,
    entity_types_hash: u64,
    model_id: String,
}
89
#[cfg(feature = "onnx")]
/// Cached output of `encode_prompt`: the model-ready input vectors for a
/// given (text, entity types, model) combination.
#[derive(Debug, Clone)]
struct PromptCacheValue {
    input_ids: Vec<i64>,
    attention_mask: Vec<i64>,
    words_mask: Vec<i64>,
    /// Number of whitespace-separated words in the text.
    text_lengths: i64,
    /// Number of entity-type labels encoded into the prompt.
    entity_count: usize,
}
100
#[cfg(feature = "onnx")]
/// Zero-shot NER using a GLiNER model executed through ONNX Runtime.
#[derive(Debug)]
pub struct GLiNEROnnx {
    /// ORT session; wrapped in a Mutex because `run` requires `&mut`.
    session: Mutex<ort::session::Session>,
    /// Shared tokenizer, handed out to callers via `tokenizer()`.
    tokenizer: std::sync::Arc<tokenizers::Tokenizer>,
    /// HuggingFace model id this instance was loaded from.
    model_name: String,
    /// Whether an INT8-quantized export was selected at load time.
    is_quantized: bool,
    /// Optional LRU cache of encoded prompts (`None` when size == 0).
    prompt_cache: Option<Mutex<lru::LruCache<PromptCacheKey, PromptCacheValue>>>,
}
117
118#[cfg(feature = "onnx")]
119impl GLiNEROnnx {
120 pub fn new(model_name: &str) -> Result<Self> {
122 Self::with_config(model_name, GLiNERConfig::default())
123 }
124
    /// Downloads the model and tokenizer from the HuggingFace Hub and builds
    /// the ONNX Runtime session according to `config`.
    ///
    /// When `config.prefer_quantized` is set, model files are tried in order:
    /// `onnx/model_quantized.onnx`, `model_quantized.onnx`,
    /// `onnx/model_int8.onnx`, then the FP32 `onnx/model.onnx`/`model.onnx`.
    ///
    /// # Errors
    /// Returns `Error::Retrieval` if any download fails, the session cannot
    /// be built, or the tokenizer cannot be loaded.
    pub fn with_config(model_name: &str, config: GLiNERConfig) -> Result<Self> {
        use hf_hub::api::sync::{Api, ApiBuilder};
        use ort::execution_providers::CPUExecutionProvider;
        use ort::session::builder::GraphOptimizationLevel;
        use ort::session::Session;

        // Pick up HF_TOKEN etc. from a local .env before querying the Hub.
        crate::env::load_dotenv();

        // Use an authenticated client when a token is configured (needed for
        // gated/private repos); otherwise fall back to anonymous access.
        let api = if let Some(token) = crate::env::hf_token() {
            ApiBuilder::new()
                .with_token(Some(token))
                .build()
                .map_err(|e| Error::Retrieval(format!("HuggingFace API with token: {}", e)))?
        } else {
            Api::new().map_err(|e| {
                Error::Retrieval(format!("Failed to initialize HuggingFace API: {}", e))
            })?
        };

        let repo = api.model(model_name.to_string());

        // Several naming conventions exist for quantized exports; try them in
        // order of preference before falling back to FP32.
        let (model_path, is_quantized) = if config.prefer_quantized {
            if let Ok(path) = repo.get("onnx/model_quantized.onnx") {
                log::info!("[GLiNER] Using quantized model (INT8)");
                (path, true)
            } else if let Ok(path) = repo.get("model_quantized.onnx") {
                log::info!("[GLiNER] Using quantized model (INT8)");
                (path, true)
            } else if let Ok(path) = repo.get("onnx/model_int8.onnx") {
                log::info!("[GLiNER] Using INT8 quantized model");
                (path, true)
            } else {
                let path = repo
                    .get("onnx/model.onnx")
                    .or_else(|_| repo.get("model.onnx"))
                    .map_err(|e| {
                        Error::Retrieval(format!("Failed to download model.onnx: {}", e))
                    })?;
                log::info!("[GLiNER] Using FP32 model (quantized not available)");
                (path, false)
            }
        } else {
            let path = repo
                .get("onnx/model.onnx")
                .or_else(|_| repo.get("model.onnx"))
                .map_err(|e| Error::Retrieval(format!("Failed to download model.onnx: {}", e)))?;
            (path, false)
        };

        let tokenizer_path = repo
            .get("tokenizer.json")
            .map_err(|e| Error::Retrieval(format!("Failed to download tokenizer.json: {}", e)))?;

        // Anything other than 1 or 2 (including 0) maps to the highest level.
        let opt_level = match config.optimization_level {
            1 => GraphOptimizationLevel::Level1,
            2 => GraphOptimizationLevel::Level2,
            _ => GraphOptimizationLevel::Level3,
        };

        let mut builder = Session::builder()
            .map_err(|e| Error::Retrieval(format!("Failed to create ONNX session builder: {}", e)))?
            .with_optimization_level(opt_level)
            .map_err(|e| Error::Retrieval(format!("Failed to set optimization level: {}", e)))?
            .with_execution_providers([CPUExecutionProvider::default().build()])
            .map_err(|e| Error::Retrieval(format!("Failed to set execution providers: {}", e)))?;

        // 0 means "leave the ORT default thread count".
        if config.num_threads > 0 {
            builder = builder
                .with_intra_threads(config.num_threads)
                .map_err(|e| Error::Retrieval(format!("Failed to set threads: {}", e)))?;
        }

        let session = builder
            .commit_from_file(&model_path)
            .map_err(|e| Error::Retrieval(format!("Failed to load ONNX model: {}", e)))?;

        let tokenizer = tokenizers::Tokenizer::from_file(&tokenizer_path)
            .map_err(|e| Error::Retrieval(format!("Failed to load tokenizer: {}", e)))?;

        log::debug!("[GLiNER] Model loaded");

        // The `> 0` guard makes the `expect` below infallible.
        let prompt_cache = if config.prompt_cache_size > 0 {
            use lru::LruCache;
            use std::num::NonZeroUsize;
            Some(Mutex::new(LruCache::new(
                NonZeroUsize::new(config.prompt_cache_size).expect("prompt_cache_size must be > 0"),
            )))
        } else {
            None
        };

        Ok(Self {
            session: Mutex::new(session),
            tokenizer: std::sync::Arc::new(tokenizer),
            model_name: model_name.to_string(),
            is_quantized,
            prompt_cache,
        })
    }
249
    /// Returns `true` if an INT8-quantized ONNX export was loaded.
    #[must_use]
    pub fn is_quantized(&self) -> bool {
        self.is_quantized
    }
255
    /// Returns a shared handle to the loaded tokenizer (cheap refcount bump).
    #[must_use]
    pub fn tokenizer(&self) -> std::sync::Arc<tokenizers::Tokenizer> {
        std::sync::Arc::clone(&self.tokenizer)
    }
261
262 pub fn model_name(&self) -> &str {
264 &self.model_name
265 }
266
    /// Runs zero-shot NER over `text` for the given `entity_types`, keeping
    /// spans whose sigmoid score is at least `threshold`.
    ///
    /// The text is split on whitespace; character offsets in the returned
    /// entities are recovered from word spans during decoding.
    ///
    /// # Errors
    /// Returns `Error::Parse`/`Error::Inference` on tokenization, tensor
    /// construction, or ONNX inference failures.
    pub fn extract(
        &self,
        text: &str,
        entity_types: &[&str],
        threshold: f32,
    ) -> Result<Vec<Entity>> {
        if text.is_empty() || entity_types.is_empty() {
            return Ok(vec![]);
        }

        let text_words: Vec<&str> = text.split_whitespace().collect();
        let num_text_words = text_words.len();

        // Whitespace-only input produces no words.
        if num_text_words == 0 {
            return Ok(vec![]);
        }

        // NOTE(review): the cached `text_lengths` is unused here — the
        // tensor below is rebuilt from `num_text_words` instead; confirm the
        // two can never diverge.
        let (input_ids, attention_mask, words_mask, text_lengths, entity_count) =
            self.encode_prompt_cached(&text_words, entity_types)?;

        let (span_idx, span_mask) = self.make_span_tensors(num_text_words);

        use ndarray::{Array2, Array3};
        use ort::value::Tensor;

        let batch_size = 1;
        let seq_len = input_ids.len();
        // One candidate span per (start word, width) pair.
        let num_spans = num_text_words.checked_mul(MAX_SPAN_WIDTH).ok_or_else(|| {
            Error::InvalidInput(format!(
                "Span count overflow: {} words * {} MAX_SPAN_WIDTH",
                num_text_words, MAX_SPAN_WIDTH
            ))
        })?;

        // Shape the flat vectors into the batched tensors the model expects.
        let input_ids_array = Array2::from_shape_vec((batch_size, seq_len), input_ids)
            .map_err(|e| Error::Parse(format!("Array error: {}", e)))?;
        let attention_mask_array = Array2::from_shape_vec((batch_size, seq_len), attention_mask)
            .map_err(|e| Error::Parse(format!("Array error: {}", e)))?;
        let words_mask_array = Array2::from_shape_vec((batch_size, seq_len), words_mask)
            .map_err(|e| Error::Parse(format!("Array error: {}", e)))?;
        let text_lengths_array =
            Array2::from_shape_vec((batch_size, 1), vec![num_text_words as i64])
                .map_err(|e| Error::Parse(format!("Array error: {}", e)))?;
        let span_idx_array = Array3::from_shape_vec((batch_size, num_spans, 2), span_idx)
            .map_err(|e| Error::Parse(format!("Array error: {}", e)))?;
        let span_mask_array = Array2::from_shape_vec((batch_size, num_spans), span_mask)
            .map_err(|e| Error::Parse(format!("Array error: {}", e)))?;

        let input_ids_t = super::ort_compat::tensor_from_ndarray(input_ids_array)
            .map_err(|e| Error::Parse(format!("Tensor error: {}", e)))?;
        let attention_mask_t = super::ort_compat::tensor_from_ndarray(attention_mask_array)
            .map_err(|e| Error::Parse(format!("Tensor error: {}", e)))?;
        let words_mask_t = super::ort_compat::tensor_from_ndarray(words_mask_array)
            .map_err(|e| Error::Parse(format!("Tensor error: {}", e)))?;
        let text_lengths_t = super::ort_compat::tensor_from_ndarray(text_lengths_array)
            .map_err(|e| Error::Parse(format!("Tensor error: {}", e)))?;
        let span_idx_t = super::ort_compat::tensor_from_ndarray(span_idx_array)
            .map_err(|e| Error::Parse(format!("Tensor error: {}", e)))?;
        let span_mask_t = super::ort_compat::tensor_from_ndarray(span_mask_array)
            .map_err(|e| Error::Parse(format!("Tensor error: {}", e)))?;

        // The session lock is held across inference and decoding, then
        // released explicitly below.
        let mut session = lock(&self.session);

        let outputs = session
            .run(ort::inputs![
                "input_ids" => input_ids_t.into_dyn(),
                "attention_mask" => attention_mask_t.into_dyn(),
                "words_mask" => words_mask_t.into_dyn(),
                "text_lengths" => text_lengths_t.into_dyn(),
                "span_idx" => span_idx_t.into_dyn(),
                "span_mask" => span_mask_t.into_dyn(),
            ])
            .map_err(|e| Error::Parse(format!("ONNX inference failed: {}", e)))?;

        let entities = self.decode_output(
            &outputs,
            text,
            &text_words,
            entity_types,
            entity_count,
            threshold,
        )?;
        // `outputs` borrows from `session`; drop both before returning.
        drop(outputs);
        drop(session);

        Ok(entities)
    }
375
376 fn hash_text(text: &str) -> u64 {
378 use std::collections::hash_map::DefaultHasher;
379 use std::hash::{Hash, Hasher};
380 let mut hasher = DefaultHasher::new();
381 text.hash(&mut hasher);
382 hasher.finish()
383 }
384
385 fn hash_entity_types(entity_types: &[&str]) -> u64 {
387 use std::collections::hash_map::DefaultHasher;
388 use std::hash::{Hash, Hasher};
389 let mut hasher = DefaultHasher::new();
390 let mut sorted: Vec<&str> = entity_types.to_vec();
392 sorted.sort();
393 sorted.hash(&mut hasher);
394 hasher.finish()
395 }
396
    /// Cache-aware wrapper around [`Self::encode_prompt`].
    ///
    /// Looks up the encoded prompt by (text hash, sorted entity-type hash,
    /// model id); on a miss, encodes and stores the result. With no cache
    /// configured this is a plain pass-through.
    ///
    /// # Errors
    /// Propagates tokenizer errors from `encode_prompt` and lock failures
    /// from `try_lock`. NOTE(review): a contended cache lock surfaces as an
    /// error rather than falling back to uncached encoding — confirm that is
    /// the intended behavior.
    fn encode_prompt_cached(
        &self,
        text_words: &[&str],
        entity_types: &[&str],
    ) -> Result<(Vec<i64>, Vec<i64>, Vec<i64>, i64, usize)> {
        let cache = match &self.prompt_cache {
            Some(c) => c,
            None => return self.encode_prompt(text_words, entity_types),
        };

        // Key on the whitespace-joined words so equivalent tokenizations of
        // the same text share entries.
        let text = text_words.join(" ");
        let text_hash = Self::hash_text(&text);
        let entity_types_hash = Self::hash_entity_types(entity_types);
        let key = PromptCacheKey {
            text_hash,
            entity_types_hash,
            model_id: self.model_name.clone(),
        };

        // Scope the guard so the lock is released before re-encoding.
        let cached_result = {
            let mut cache_guard = try_lock(cache)?;
            cache_guard.get(&key).cloned()
        };

        if let Some(cached) = cached_result {
            return Ok((
                cached.input_ids,
                cached.attention_mask,
                cached.words_mask,
                cached.text_lengths,
                cached.entity_count,
            ));
        }

        let result = self.encode_prompt(text_words, entity_types)?;

        // Re-acquire the lock only for insertion; clones are stored so the
        // freshly-computed result can still be returned by value.
        {
            let mut cache_guard = try_lock(cache)?;
            cache_guard.put(
                key,
                PromptCacheValue {
                    input_ids: result.0.clone(),
                    attention_mask: result.1.clone(),
                    words_mask: result.2.clone(),
                    text_lengths: result.3,
                    entity_count: result.4,
                },
            );
        }

        Ok(result)
    }
465
466 pub(crate) fn encode_prompt(
478 &self,
479 text_words: &[&str],
480 entity_types: &[&str],
481 ) -> Result<(Vec<i64>, Vec<i64>, Vec<i64>, i64, usize)> {
482 let mut input_ids: Vec<i64> = Vec::new();
484 let mut word_mask: Vec<i64> = Vec::new();
485
486 input_ids.push(TOKEN_START as i64);
488 word_mask.push(0);
489
490 for entity_type in entity_types {
492 input_ids.push(TOKEN_ENT as i64);
494 word_mask.push(0);
495
496 let encoding = self
499 .tokenizer
500 .encode(entity_type.to_string(), false)
501 .map_err(|e| Error::Parse(format!("Tokenizer error: {}", e)))?;
502 for token_id in encoding.get_ids() {
503 input_ids.push(*token_id as i64);
504 word_mask.push(0);
505 }
506 }
507
508 input_ids.push(TOKEN_SEP as i64);
510 word_mask.push(0);
511
512 let mut word_id: i64 = 0;
514 for word in text_words {
515 let encoding = self
518 .tokenizer
519 .encode(word.to_string(), false)
520 .map_err(|e| Error::Parse(format!("Tokenizer error: {}", e)))?;
521
522 word_id += 1; for (token_idx, token_id) in encoding.get_ids().iter().enumerate() {
525 input_ids.push(*token_id as i64);
526 if token_idx == 0 {
528 word_mask.push(word_id);
529 } else {
530 word_mask.push(0);
531 }
532 }
533 }
534
535 input_ids.push(TOKEN_END as i64);
537 word_mask.push(0);
538
539 let seq_len = input_ids.len();
540 let mut attention_mask = Vec::with_capacity(seq_len);
542 attention_mask.resize(seq_len, 1);
543
544 Ok((
545 input_ids,
546 attention_mask,
547 word_mask,
548 word_id,
549 entity_types.len(),
550 ))
551 }
552
    /// Builds the `span_idx`/`span_mask` model inputs.
    ///
    /// For every start word and every width in `0..MAX_SPAN_WIDTH`,
    /// `span_idx` holds the inclusive `[start, end]` word pair (flattened,
    /// two values per span) and `span_mask` marks spans that fit inside the
    /// text. Arithmetic overflow degrades to warnings plus skipped spans
    /// rather than panics.
    fn make_span_tensors(&self, num_words: usize) -> (Vec<i64>, Vec<bool>) {
        // NOTE(review): on overflow the fallbacks below use usize::MAX, which
        // the subsequent vec! allocations could not satisfy — in practice
        // unreachable for realistic word counts.
        let num_spans = num_words.checked_mul(MAX_SPAN_WIDTH).unwrap_or_else(|| {
            log::warn!(
                "Span count overflow: {} words * {} MAX_SPAN_WIDTH, using max",
                num_words,
                MAX_SPAN_WIDTH
            );
            usize::MAX
        });
        let span_idx_len = num_spans.checked_mul(2).unwrap_or_else(|| {
            log::warn!(
                "Span idx length overflow: {} spans * 2, using max",
                num_spans
            );
            usize::MAX
        });
        let mut span_idx: Vec<i64> = vec![0; span_idx_len];
        let mut span_mask: Vec<bool> = vec![false; num_spans];

        for start in 0..num_words {
            // Spans are clipped so they never extend past the last word.
            let remaining_width = num_words - start;
            let actual_max_width = MAX_SPAN_WIDTH.min(remaining_width);

            for width in 0..actual_max_width {
                // Flat span index: start * MAX_SPAN_WIDTH + width, with every
                // step overflow-checked.
                let dim = match start.checked_mul(MAX_SPAN_WIDTH) {
                    Some(v) => match v.checked_add(width) {
                        Some(d) => d,
                        None => {
                            log::warn!(
                                "Dim calculation overflow: {} * {} + {}, skipping span",
                                start,
                                MAX_SPAN_WIDTH,
                                width
                            );
                            continue;
                        }
                    },
                    None => {
                        log::warn!(
                            "Dim calculation overflow: {} * {}, skipping span",
                            start,
                            MAX_SPAN_WIDTH
                        );
                        continue;
                    }
                };
                if let Some(dim2) = dim.checked_mul(2) {
                    if dim2 + 1 < span_idx_len && dim < num_spans {
                        // Inclusive end word: start + width (width 0 = one word).
                        span_idx[dim2] = start as i64; span_idx[dim2 + 1] = (start + width) as i64; span_mask[dim] = true;
                    } else {
                        log::warn!(
                            "Span idx access out of bounds: dim={}, dim*2={}, span_idx_len={}, num_spans={}, skipping",
                            dim, dim2, span_idx_len, num_spans
                        );
                    }
                } else {
                    log::warn!("Dim * 2 overflow: dim={}, skipping span", dim);
                }
            }
        }

        (span_idx, span_mask)
    }
626
    /// Decodes raw model logits into `Entity` values.
    ///
    /// Supports two export layouts:
    /// - 4-D `[1, words, max_width, classes]`
    /// - 3-D `[1, spans, classes]` with spans flattened as
    ///   `word_idx * MAX_SPAN_WIDTH + width`
    ///
    /// Post-processing: sigmoid + threshold, dedup of identical spans,
    /// shortest-span-first overlap removal, trimming of leading/trailing
    /// punctuation, and re-typing PRODUCT spans that look like company names
    /// as Organization.
    ///
    /// # Errors
    /// `Error::Parse` when the output tensor is missing or malformed,
    /// `Error::Inference` for empty or zero-class outputs.
    fn decode_output(
        &self,
        outputs: &ort::session::SessionOutputs,
        text: &str,
        text_words: &[&str],
        entity_types: &[&str],
        expected_num_classes: usize,
        threshold: f32,
    ) -> Result<Vec<Entity>> {
        // Computed once; reused for every span extraction below.
        let text_char_count = text.chars().count();
        // The first (and only expected) output carries the span logits.
        let output = outputs
            .iter()
            .next()
            .map(|(_, v)| v)
            .ok_or_else(|| Error::Parse("No output from GLiNER model".to_string()))?;

        let (_, data_slice) = output
            .try_extract_tensor::<f32>()
            .map_err(|e| Error::Parse(format!("Failed to extract output tensor: {}", e)))?;
        let output_data: Vec<f32> = data_slice.to_vec();

        let shape: Vec<i64> = match output.dtype() {
            ort::value::ValueType::Tensor { shape, .. } => shape.iter().copied().collect(),
            _ => return Err(Error::Parse("Output is not a tensor".to_string())),
        };

        log::debug!(
            "[GLiNER] Output shape: {:?}, data len: {}, expected classes: {}",
            shape,
            output_data.len(),
            expected_num_classes
        );

        if output_data.is_empty() || shape.iter().any(|&d| d == 0) {
            return Err(Error::Inference(
                "GLiNER ONNX returned empty/degenerate output tensor. This usually indicates an incompatible ONNX export for this implementation (shape mismatch or missing dynamic axes).".to_string(),
            ));
        }

        let mut entities = Vec::with_capacity(32);
        let num_text_words = text_words.len();

        // Layout 1: [1, words, max_width, classes].
        if shape.len() == 4 && shape[0] == 1 {
            let out_num_words = shape[1] as usize;
            let out_max_width = shape[2] as usize;
            let num_classes = shape[3] as usize;

            log::debug!(
                "[GLiNER] Decoding: num_words={}, max_width={}, num_classes={}",
                out_num_words,
                out_max_width,
                num_classes
            );

            if num_classes == 0 {
                return Err(Error::Inference(
                    "GLiNER ONNX model produced num_classes=0. This export likely does not support dynamic entity types for the requested schema.".to_string(),
                ));
            }

            // Iterate bounded by both the model's dims and our own text/span
            // limits so a larger export never indexes out of range.
            for word_idx in 0..out_num_words.min(num_text_words) {
                for width in 0..out_max_width.min(MAX_SPAN_WIDTH) {
                    let end_word = word_idx + width;
                    if end_word >= num_text_words {
                        continue;
                    }

                    for class_idx in 0..num_classes.min(entity_types.len()) {
                        // Row-major offset into the flattened logits.
                        let idx = (word_idx * out_max_width * num_classes)
                            + (width * num_classes)
                            + class_idx;

                        if idx < output_data.len() {
                            let logit = output_data[idx];
                            // Sigmoid converts the logit to a probability.
                            let score = 1.0 / (1.0 + (-logit).exp());

                            if score >= threshold {
                                let (char_start, char_end) = self.word_span_to_char_offsets(
                                    text, text_words, word_idx, end_word,
                                );

                                let span_text = extract_char_slice_with_len(
                                    text,
                                    char_start,
                                    char_end,
                                    text_char_count,
                                );

                                let entity_type_str =
                                    entity_types.get(class_idx).unwrap_or(&"OTHER");
                                let entity_type = Self::map_entity_type(entity_type_str);

                                entities.push(Entity::new(
                                    span_text,
                                    entity_type,
                                    char_start,
                                    char_end,
                                    score as f64,
                                ));
                            }
                        }
                    }
                }
            }
        // Layout 2: [1, spans, classes], spans = word * MAX_SPAN_WIDTH + width.
        } else if shape.len() == 3 && shape[0] == 1 {
            let num_spans = shape[1] as usize;
            let num_classes = shape[2] as usize;

            if num_classes == 0 {
                return Err(Error::Inference(
                    "GLiNER ONNX model produced num_classes=0. This export likely does not support dynamic entity types for the requested schema.".to_string(),
                ));
            }

            for span_idx in 0..num_spans {
                // Recover (word, width) from the flattened span index.
                let word_idx = span_idx / MAX_SPAN_WIDTH;
                let width = span_idx % MAX_SPAN_WIDTH;
                let end_word = word_idx + width;

                if word_idx >= num_text_words || end_word >= num_text_words {
                    continue;
                }

                for class_idx in 0..num_classes.min(entity_types.len()) {
                    let idx = span_idx * num_classes + class_idx;
                    if idx < output_data.len() {
                        let logit = output_data[idx];
                        let score = 1.0 / (1.0 + (-logit).exp());

                        if score >= threshold {
                            let (char_start, char_end) = self
                                .word_span_to_char_offsets(text, text_words, word_idx, end_word);

                            let span_text = extract_char_slice_with_len(
                                text,
                                char_start,
                                char_end,
                                text_char_count,
                            );

                            let entity_type_str = entity_types.get(class_idx).unwrap_or(&"OTHER");
                            let entity_type = Self::map_entity_type(entity_type_str);

                            entities.push(Entity::new(
                                span_text,
                                entity_type,
                                char_start,
                                char_end,
                                score as f64,
                            ));
                        }
                    }
                }
            }
        }
        // NOTE(review): any other output shape silently yields no entities —
        // consider whether that should be an Inference error instead.

        // Sort so identical spans are adjacent, highest confidence first;
        // dedup then keeps the highest-confidence duplicate.
        entities.sort_unstable_by(|a, b| {
            a.start
                .cmp(&b.start)
                .then_with(|| b.end.cmp(&a.end))
                .then_with(|| {
                    b.confidence
                        .partial_cmp(&a.confidence)
                        .unwrap_or(std::cmp::Ordering::Equal)
                })
        });

        entities.dedup_by(|a, b| a.start == b.start && a.end == b.end);

        let entities = remove_overlapping_spans(entities);

        // Trim boundary punctuation, keeping char offsets in sync with the
        // text (one char popped <=> one offset step).
        let entities = entities
            .into_iter()
            .map(|mut e| {
                while e.text.ends_with(['.', ',', ';', ':', '!', '?']) {
                    e.text.pop();
                    if e.end > e.start {
                        e.end -= 1;
                    }
                }
                while e.text.starts_with(['.', ',', ';', ':', '!', '?']) {
                    e.text.remove(0);
                    e.start += 1;
                }

                // Heuristic re-typing: PRODUCT spans with corporate markers
                // are really organizations.
                if e.entity_type.as_label().eq_ignore_ascii_case("PRODUCT")
                    && looks_like_company_name(&e.text)
                {
                    e.entity_type = EntityType::Organization;
                }

                e
            })
            .filter(|e| !e.text.is_empty() && e.start < e.end)
            .collect();

        Ok(entities)
    }
858
859 fn map_entity_type(type_str: &str) -> EntityType {
861 match type_str.to_lowercase().as_str() {
862 "person" | "per" => EntityType::Person,
863 "organization" | "org" | "company" => EntityType::Organization,
864 "location" | "loc" | "gpe" | "geo-loc" => EntityType::Location,
865 "facility" | "fac" => EntityType::custom("FACILITY", anno_core::EntityCategory::Place),
866 "product" | "prod" => EntityType::custom("PRODUCT", anno_core::EntityCategory::Misc),
867 "misc" | "other" => EntityType::Other("MISC".to_string()),
868 "date" | "time" => EntityType::Date,
869 "money" | "currency" => EntityType::Money,
870 "percent" | "percentage" => EntityType::Percent,
871 other => EntityType::Other(other.to_string()),
872 }
873 }
874
    /// Converts an inclusive word-index span `[start_word, end_word]` into
    /// character offsets within `text`.
    ///
    /// Words are located by scanning `text` left-to-right with `find`, so
    /// repeated words resolve to their in-order occurrences. Returns `(0, 0)`
    /// for invalid spans or when a boundary word cannot be located.
    fn word_span_to_char_offsets(
        &self,
        text: &str,
        words: &[&str],
        start_word: usize,
        end_word: usize,
    ) -> (usize, usize) {
        if words.is_empty()
            || start_word >= words.len()
            || end_word >= words.len()
            || start_word > end_word
        {
            return (0, 0);
        }

        let mut byte_pos = 0;
        let mut start_byte = 0;
        let mut end_byte = text.len();
        let mut found_start = false;
        let mut found_end = false;

        for (idx, word) in words.iter().enumerate() {
            if let Some(pos) = text[byte_pos..].find(word) {
                let word_start_byte = byte_pos + pos;
                let word_end_byte = word_start_byte + word.len();

                if idx == start_word {
                    start_byte = word_start_byte;
                    found_start = true;
                }
                if idx == end_word {
                    end_byte = word_end_byte;
                    found_end = true;
                    break;
                }
                // Resume the search after this word so duplicates map to
                // successive occurrences.
                byte_pos = word_end_byte;
            } else {
                // Word not found past the cursor — should not happen when
                // `words` came from `text.split_whitespace()`. The cursor is
                // left unchanged and the miss is reported via found_* below.
            }
        }

        if !found_start || !found_end {
            (0, 0)
        } else {
            // Convert byte offsets to char offsets (multi-byte aware).
            crate::offset::bytes_to_chars(text, start_byte, end_byte)
        }
    }
933}
934
/// Heuristic check for corporate names.
///
/// Matches common Western legal-entity suffixes (case-insensitively), plus
/// Japanese/Chinese corporate markers and the Arabic word for "company".
fn looks_like_company_name(text: &str) -> bool {
    // Leading space on each suffix ensures a word boundary before it.
    const SUFFIXES: [&str; 16] = [
        " inc", " inc.", " ltd", " ltd.", " llc", " llp", " plc", " co",
        " co.", " company", " corp", " corp.", " corporation", " gmbh",
        " s.a.", " sa",
    ];
    const CJK_MARKERS: [&str; 4] = ["株式会社", "有限会社", "公司", "集团"];

    let trimmed = text.trim();
    if trimmed.is_empty() {
        return false;
    }

    let lowered = trimmed.to_lowercase();
    if SUFFIXES.iter().any(|suffix| lowered.ends_with(suffix)) {
        return true;
    }

    if CJK_MARKERS.iter().any(|marker| trimmed.contains(marker)) {
        return true;
    }

    // "شركة" is the Arabic word for "company".
    trimmed.contains("شركة")
}
980
#[cfg(test)]
mod postprocess_tests {
    use super::looks_like_company_name;

    /// Covers Western legal suffixes, Japanese and Arabic corporate markers,
    /// and plain names that must not match.
    #[test]
    fn test_looks_like_company_name() {
        assert!(looks_like_company_name("Apple Inc"));
        assert!(looks_like_company_name("Acme Corp."));
        assert!(looks_like_company_name("Example GmbH"));
        assert!(looks_like_company_name("株式会社トヨタ自動車"));
        assert!(looks_like_company_name("شركة أرامكو"));

        assert!(!looks_like_company_name("Apple"));
        assert!(!looks_like_company_name("New York"));
    }
}
997
998fn extract_char_slice(text: &str, char_start: usize, char_end: usize) -> String {
1007 let text_char_count = text.chars().count();
1010 extract_char_slice_with_len(text, char_start, char_end, text_char_count)
1011}
1012
/// Extracts the substring covering `[char_start, char_end)` measured in
/// Unicode scalar values, given a precomputed total character count.
///
/// Returns an empty string for out-of-range or inverted bounds.
fn extract_char_slice_with_len(
    text: &str,
    char_start: usize,
    char_end: usize,
    text_char_count: usize,
) -> String {
    let bounds_ok =
        char_start < text_char_count && char_end <= text_char_count && char_start < char_end;
    if !bounds_ok {
        return String::new();
    }

    let mut slice = String::new();
    for ch in text.chars().skip(char_start).take(char_end - char_start) {
        slice.push(ch);
    }
    slice
}
1031
#[cfg(feature = "onnx")]
/// Default zero-shot label set used when a caller does not supply
/// entity types.
const DEFAULT_GLINER_LABELS: &[&str] = &[
    "person",
    "organization",
    "location",
    "date",
    "time",
    "money",
    "percent",
    "product",
    "event",
    "facility",
    "work_of_art",
    "law",
    "language",
];
1053
#[cfg(feature = "onnx")]
impl crate::Model for GLiNEROnnx {
    /// Runs zero-shot extraction with the default label set at threshold 0.5.
    fn extract_entities(&self, text: &str, _language: Option<&str>) -> crate::Result<Vec<Entity>> {
        self.extract(text, DEFAULT_GLINER_LABELS, 0.5)
    }

    /// Every default label is reported as a Custom/Misc type.
    fn supported_types(&self) -> Vec<anno_core::EntityType> {
        let mut types = Vec::with_capacity(DEFAULT_GLINER_LABELS.len());
        for label in DEFAULT_GLINER_LABELS {
            types.push(anno_core::EntityType::Custom {
                name: (*label).to_string(),
                category: EntityCategory::Misc,
            });
        }
        types
    }

    fn is_available(&self) -> bool {
        true
    }

    fn name(&self) -> &'static str {
        "GLiNER-ONNX"
    }

    fn description(&self) -> &'static str {
        "Zero-shot NER using GLiNER with ONNX Runtime backend"
    }

    /// Version string encodes the model id and quantization state.
    fn version(&self) -> String {
        let quant = if self.is_quantized { "q" } else { "fp32" };
        format!("gliner-onnx-{}-{}", self.model_name, quant)
    }
}
1094
#[cfg(feature = "onnx")]
impl crate::backends::inference::ZeroShotNER for GLiNEROnnx {
    /// Extracts entities restricted to `entity_types`, keeping spans whose
    /// score is at least `threshold`.
    fn extract_with_types(
        &self,
        text: &str,
        entity_types: &[&str],
        threshold: f32,
    ) -> crate::Result<Vec<Entity>> {
        self.extract(text, entity_types, threshold)
    }

    /// GLiNER consumes free-text labels, so descriptions are handled exactly
    /// like type names and forwarded unchanged.
    fn extract_with_descriptions(
        &self,
        text: &str,
        descriptions: &[&str],
        threshold: f32,
    ) -> crate::Result<Vec<Entity>> {
        self.extract(text, descriptions, threshold)
    }

    /// Label set used when the caller supplies none.
    fn default_types(&self) -> &[&'static str] {
        DEFAULT_GLINER_LABELS
    }
}
1120
#[cfg(not(feature = "onnx"))]
/// Stub so the type name resolves without the 'onnx' feature; every
/// operation returns an error.
#[derive(Debug)]
pub struct GLiNEROnnx;
1128
#[cfg(not(feature = "onnx"))]
impl GLiNEROnnx {
    /// Always fails: constructing the backend requires the 'onnx' feature.
    pub fn new(_model_name: &str) -> Result<Self> {
        Err(Error::InvalidInput(
            "GLiNER-ONNX requires the 'onnx' feature. \
             Build with: cargo build --features onnx"
                .to_string(),
        ))
    }

    /// Placeholder id reported when the backend is compiled out.
    pub fn model_name(&self) -> &str {
        "gliner-not-enabled"
    }

    /// Always fails: extraction requires the 'onnx' feature.
    pub fn extract(
        &self,
        _text: &str,
        _entity_types: &[&str],
        _threshold: f32,
    ) -> Result<Vec<Entity>> {
        Err(Error::InvalidInput(
            "GLiNER-ONNX requires the 'onnx' feature".to_string(),
        ))
    }
}
1157
#[cfg(not(feature = "onnx"))]
impl crate::Model for GLiNEROnnx {
    /// Always fails: extraction requires the 'onnx' feature.
    fn extract_entities(&self, _text: &str, _language: Option<&str>) -> crate::Result<Vec<Entity>> {
        Err(Error::InvalidInput(
            "GLiNER-ONNX requires the 'onnx' feature".to_string(),
        ))
    }

    /// No types are supported when the backend is compiled out.
    fn supported_types(&self) -> Vec<anno_core::EntityType> {
        vec![]
    }

    /// Reports unavailable so callers can skip this backend gracefully.
    fn is_available(&self) -> bool {
        false
    }

    fn name(&self) -> &'static str {
        "GLiNER-ONNX (unavailable)"
    }

    fn description(&self) -> &'static str {
        "GLiNER with ONNX Runtime backend - requires 'onnx' feature"
    }
}
1182
#[cfg(not(feature = "onnx"))]
impl crate::backends::inference::ZeroShotNER for GLiNEROnnx {
    /// Always fails: zero-shot extraction requires the 'onnx' feature.
    fn extract_with_types(
        &self,
        _text: &str,
        _entity_types: &[&str],
        _threshold: f32,
    ) -> crate::Result<Vec<Entity>> {
        Err(Error::InvalidInput(
            "GLiNER-ONNX requires the 'onnx' feature".to_string(),
        ))
    }

    /// Always fails: zero-shot extraction requires the 'onnx' feature.
    fn extract_with_descriptions(
        &self,
        _text: &str,
        _descriptions: &[&str],
        _threshold: f32,
    ) -> crate::Result<Vec<Entity>> {
        Err(Error::InvalidInput(
            "GLiNER-ONNX requires the 'onnx' feature".to_string(),
        ))
    }
}
1207
#[cfg(feature = "onnx")]
impl crate::BatchCapable for GLiNEROnnx {
    /// Runs single-text extraction sequentially over `texts` with the
    /// default label set and a 0.5 threshold, stopping at the first error.
    fn extract_entities_batch(
        &self,
        texts: &[&str],
        _language: Option<&str>,
    ) -> Result<Vec<Vec<Entity>>> {
        if texts.is_empty() {
            return Ok(Vec::new());
        }

        let threshold = 0.5;
        let default_types = DEFAULT_GLINER_LABELS;

        let mut batches = Vec::with_capacity(texts.len());
        for text in texts {
            batches.push(self.extract(text, default_types, threshold)?);
        }
        Ok(batches)
    }

    /// Batching hint for callers; inference itself runs per text.
    fn optimal_batch_size(&self) -> Option<usize> {
        Some(16)
    }
}
1239
1240#[cfg(not(feature = "onnx"))]
1241impl crate::BatchCapable for GLiNEROnnx {
1242 fn extract_entities_batch(
1243 &self,
1244 texts: &[&str],
1245 _language: Option<&str>,
1246 ) -> Result<Vec<Vec<Entity>>> {
1247 Err(Error::InvalidInput(
1248 "GLiNER-ONNX requires the 'onnx' feature".to_string(),
1249 ))
1250 }
1251
1252 fn optimal_batch_size(&self) -> Option<usize> {
1253 None
1254 }
1255}
1256
/// Resolves overlapping entity spans by greedily preferring shorter spans.
///
/// Candidates are processed shortest-first (confidence descending as a
/// tiebreak); a candidate is dropped if it fully contains an already-kept
/// span or partially overlaps any kept span. Survivors are returned in
/// document order.
fn remove_overlapping_spans(mut entities: Vec<Entity>) -> Vec<Entity> {
    if entities.len() <= 1 {
        return entities;
    }

    // Shortest spans first; higher confidence wins among equal lengths.
    entities.sort_unstable_by(|a, b| {
        let len_a = a.end - a.start;
        let len_b = b.end - b.start;
        len_a.cmp(&len_b).then_with(|| {
            b.confidence
                .partial_cmp(&a.confidence)
                .unwrap_or(std::cmp::Ordering::Equal)
        })
    });

    let mut result: Vec<Entity> = Vec::with_capacity(entities.len());

    for entity in entities {
        // Drop candidates that fully contain something already kept —
        // the finer-grained span wins.
        let is_superset_of_existing = result.iter().any(|kept| {
            entity.start <= kept.start && entity.end >= kept.end
        });

        if is_superset_of_existing {
            continue;
        }

        // Drop candidates with any partial overlap against kept spans.
        let overlaps_existing = result.iter().any(|kept| {
            let entity_range = entity.start..entity.end;
            let kept_range = kept.start..kept.end;
            entity_range.start < kept_range.end && kept_range.start < entity_range.end
        });

        if !overlaps_existing {
            result.push(entity);
        }
    }

    // Restore document order for callers.
    result.sort_unstable_by_key(|e| e.start);
    result
}
1321
#[cfg(feature = "onnx")]
impl crate::StreamingCapable for GLiNEROnnx {
    /// Recommended chunk size (in characters) for streaming extraction.
    fn recommended_chunk_size(&self) -> usize {
        4096
    }
}
1332
#[cfg(not(feature = "onnx"))]
impl crate::StreamingCapable for GLiNEROnnx {
    /// Same chunk-size hint as the real backend, kept for API parity.
    fn recommended_chunk_size(&self) -> usize {
        4096
    }
}