// anno/backends/gliner_onnx.rs
//! GLiNER-based NER implementation using ONNX Runtime.
//!
//! GLiNER (Generalist and Lightweight Model for Named Entity Recognition) is
//! a popular approach to "open/zero-shot" NER. This implementation follows the
//! GLiNER prompt format and common community conventions.
//!
//! ## Prompt Format
//!
//! GLiNER uses a special prompt format:
//!
//! ```text
//! [START] <<ENT>> type1 <<ENT>> type2 <<SEP>> word1 word2 ... [END]
//! ```
//!
//! Token IDs (for the GLiNER tokenizer):
//! - START = 1
//! - END = 2
//! - `<<ENT>>` = 128002
//! - `<<SEP>>` = 128003
//!
//! ## Key Insight
//!
//! Each word is encoded SEPARATELY, preserving word boundaries.
//! Output shape: [batch, num_words, max_width, num_entity_types]
26#![allow(missing_docs)] // Stub implementation
27#![allow(dead_code)] // Placeholder constants
28#![allow(clippy::type_complexity)] // Complex return tuples
29#![allow(clippy::manual_contains)] // Shape check style
30#![allow(unused_variables)] // Feature-gated code
31#![allow(clippy::items_after_test_module)] // Large file; keep local tests near helpers
32#![allow(unused_imports)] // EntityType used conditionally
33
34#[cfg(feature = "onnx")]
35use crate::sync::{lock, try_lock, Mutex};
36use crate::{Entity, Error, Result};
37use anno_core::{EntityCategory, EntityType};
38
/// Special token IDs for GLiNER models.
///
/// These match the `<<ENT>>`/`<<SEP>>` special tokens added to the GLiNER
/// tokenizer vocabulary (see module docs). They are assumed fixed across the
/// supported GLiNER model family — TODO(review): confirm per model.
const TOKEN_START: u32 = 1;
const TOKEN_END: u32 = 2;
const TOKEN_ENT: u32 = 128002;
const TOKEN_SEP: u32 = 128003;

/// Default max span width from GLiNER config.
///
/// Candidate entity spans cover 1..=MAX_SPAN_WIDTH consecutive words.
const MAX_SPAN_WIDTH: usize = 12;
47
/// Configuration for GLiNER model loading.
#[cfg(feature = "onnx")]
#[derive(Debug, Clone)]
pub struct GLiNERConfig {
    /// Prefer quantized models (INT8) for faster CPU inference.
    ///
    /// Falls back to the FP32 model when the repo publishes no quantized file.
    pub prefer_quantized: bool,
    /// ONNX optimization level (1-3, default 3).
    ///
    /// Values other than 1 or 2 are treated as level 3.
    pub optimization_level: u8,
    /// Number of threads for inference (0 = auto).
    pub num_threads: usize,
    /// Cache size for prompt encodings (0 = disabled, default 100).
    ///
    /// The prompt cache stores encoded prompts keyed by (text, entity_types, model_id).
    /// This can materially reduce repeated work in evaluation loops and API usage patterns
    /// where the same text is queried with multiple type sets.
    pub prompt_cache_size: usize,
}
65
#[cfg(feature = "onnx")]
impl Default for GLiNERConfig {
    /// Defaults: prefer INT8 quantization, full graph optimization (level 3),
    /// 4 intra-op threads, and a 100-entry prompt cache.
    fn default() -> Self {
        GLiNERConfig {
            prompt_cache_size: 100,
            num_threads: 4,
            optimization_level: 3,
            prefer_quantized: true,
        }
    }
}
77
/// Cache key for prompt encodings.
///
/// Keyed by (text_hash, entity_types_hash, model_id) to ensure cache hits
/// only when text, entity types, and model are identical.
///
/// Note: the two 64-bit hashes make collisions theoretically possible but
/// astronomically unlikely for an in-memory LRU of this size.
#[cfg(feature = "onnx")]
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
struct PromptCacheKey {
    /// Hash of the whitespace-joined input text.
    text_hash: u64,
    /// Order-insensitive hash of the entity-type labels.
    entity_types_hash: u64,
    /// HuggingFace model ID, so different models never share entries.
    model_id: String,
}
89
/// Cached prompt encoding result.
///
/// Mirrors the tuple returned by `GLiNEROnnx::encode_prompt`.
#[cfg(feature = "onnx")]
#[derive(Debug, Clone)]
struct PromptCacheValue {
    /// Token IDs for the full prompt (special tokens + labels + text words).
    input_ids: Vec<i64>,
    /// Attention mask (all ones; the encoder produces no padding).
    attention_mask: Vec<i64>,
    /// 1-based word index on the first subword of each text word, 0 elsewhere.
    words_mask: Vec<i64>,
    /// Number of text words encoded.
    text_lengths: i64,
    /// Number of entity-type labels in the prompt.
    entity_count: usize,
}
100
/// GLiNER model for zero-shot NER.
///
/// Thread-safe with `Arc<Tokenizer>` for efficient sharing across threads.
/// The ONNX session is serialized behind a `Mutex`, so concurrent `extract`
/// calls run inference one at a time.
#[cfg(feature = "onnx")]
#[derive(Debug)]
pub struct GLiNEROnnx {
    /// ONNX Runtime session; `run` requires `&mut`, hence the Mutex.
    session: Mutex<ort::session::Session>,
    /// Arc-wrapped tokenizer for cheap cloning across threads.
    tokenizer: std::sync::Arc<tokenizers::Tokenizer>,
    /// HuggingFace model identifier (e.g., "onnx-community/gliner_small-v2.1").
    model_name: String,
    /// Whether a quantized model was loaded.
    is_quantized: bool,
    /// LRU cache for prompt encodings (keyed by text + entity types).
    /// `None` when `GLiNERConfig::prompt_cache_size == 0`.
    prompt_cache: Option<Mutex<lru::LruCache<PromptCacheKey, PromptCacheValue>>>,
}
117
118#[cfg(feature = "onnx")]
119impl GLiNEROnnx {
    /// Create a new GLiNER model from HuggingFace with default config.
    ///
    /// Equivalent to `Self::with_config(model_name, GLiNERConfig::default())`.
    ///
    /// # Errors
    ///
    /// Returns [`Error::Retrieval`] if the model or tokenizer cannot be
    /// downloaded or loaded (see [`Self::with_config`]).
    pub fn new(model_name: &str) -> Result<Self> {
        Self::with_config(model_name, GLiNERConfig::default())
    }
124
    /// Create a new GLiNER model with custom configuration.
    ///
    /// Downloads the ONNX model and tokenizer from the HuggingFace Hub
    /// (honoring `HF_TOKEN` from the environment / `.env`), then builds a
    /// CPU-only ONNX Runtime session.
    ///
    /// # Arguments
    ///
    /// * `model_name` - HuggingFace model ID (e.g., "onnx-community/gliner_small-v2.1")
    /// * `config` - Configuration for model loading
    ///
    /// # Errors
    ///
    /// Returns [`Error::Retrieval`] if any download, session-builder, or
    /// tokenizer-loading step fails.
    ///
    /// # Example
    ///
    /// ```rust,ignore
    /// let config = GLiNERConfig {
    ///     prefer_quantized: true,  // Use INT8 model for 2-4x speedup
    ///     optimization_level: 3,
    ///     num_threads: 8,
    ///     prompt_cache_size: 100,
    /// };
    /// let model = GLiNEROnnx::with_config("onnx-community/gliner_small-v2.1", config)?;
    /// ```
    ///
    /// Automatically loads `.env` for HF_TOKEN if present.
    pub fn with_config(model_name: &str, config: GLiNERConfig) -> Result<Self> {
        use hf_hub::api::sync::{Api, ApiBuilder};
        use ort::execution_providers::CPUExecutionProvider;
        use ort::session::builder::GraphOptimizationLevel;
        use ort::session::Session;

        // Load .env if present (for HF_TOKEN)
        crate::env::load_dotenv();

        // Use an authenticated API when a token is available (needed for
        // gated/private repos); fall back to anonymous access otherwise.
        let api = if let Some(token) = crate::env::hf_token() {
            ApiBuilder::new()
                .with_token(Some(token))
                .build()
                .map_err(|e| Error::Retrieval(format!("HuggingFace API with token: {}", e)))?
        } else {
            Api::new().map_err(|e| {
                Error::Retrieval(format!("Failed to initialize HuggingFace API: {}", e))
            })?
        };

        let repo = api.model(model_name.to_string());

        // Download model - try quantized first if preferred.
        // Repos differ in layout, so several known file names are probed in order.
        let (model_path, is_quantized) = if config.prefer_quantized {
            // Try quantized variants first
            if let Ok(path) = repo.get("onnx/model_quantized.onnx") {
                log::info!("[GLiNER] Using quantized model (INT8)");
                (path, true)
            } else if let Ok(path) = repo.get("model_quantized.onnx") {
                log::info!("[GLiNER] Using quantized model (INT8)");
                (path, true)
            } else if let Ok(path) = repo.get("onnx/model_int8.onnx") {
                log::info!("[GLiNER] Using INT8 quantized model");
                (path, true)
            } else {
                // Fall back to FP32
                let path = repo
                    .get("onnx/model.onnx")
                    .or_else(|_| repo.get("model.onnx"))
                    .map_err(|e| {
                        Error::Retrieval(format!("Failed to download model.onnx: {}", e))
                    })?;
                log::info!("[GLiNER] Using FP32 model (quantized not available)");
                (path, false)
            }
        } else {
            let path = repo
                .get("onnx/model.onnx")
                .or_else(|_| repo.get("model.onnx"))
                .map_err(|e| Error::Retrieval(format!("Failed to download model.onnx: {}", e)))?;
            (path, false)
        };

        let tokenizer_path = repo
            .get("tokenizer.json")
            .map_err(|e| Error::Retrieval(format!("Failed to download tokenizer.json: {}", e)))?;

        // Build session with optimization settings.
        // Levels 1 and 2 map directly; anything else (including the default 3)
        // gets the maximum Level3.
        let opt_level = match config.optimization_level {
            1 => GraphOptimizationLevel::Level1,
            2 => GraphOptimizationLevel::Level2,
            _ => GraphOptimizationLevel::Level3,
        };

        let mut builder = Session::builder()
            .map_err(|e| Error::Retrieval(format!("Failed to create ONNX session builder: {}", e)))?
            .with_optimization_level(opt_level)
            .map_err(|e| Error::Retrieval(format!("Failed to set optimization level: {}", e)))?
            .with_execution_providers([CPUExecutionProvider::default().build()])
            .map_err(|e| Error::Retrieval(format!("Failed to set execution providers: {}", e)))?;

        // 0 means "auto": leave the runtime's default thread count in place.
        if config.num_threads > 0 {
            builder = builder
                .with_intra_threads(config.num_threads)
                .map_err(|e| Error::Retrieval(format!("Failed to set threads: {}", e)))?;
        }

        let session = builder
            .commit_from_file(&model_path)
            .map_err(|e| Error::Retrieval(format!("Failed to load ONNX model: {}", e)))?;

        let tokenizer = tokenizers::Tokenizer::from_file(&tokenizer_path)
            .map_err(|e| Error::Retrieval(format!("Failed to load tokenizer: {}", e)))?;

        log::debug!("[GLiNER] Model loaded");

        // Initialize prompt cache if enabled.
        // The expect() is safe: this branch only runs when size > 0.
        let prompt_cache = if config.prompt_cache_size > 0 {
            use lru::LruCache;
            use std::num::NonZeroUsize;
            Some(Mutex::new(LruCache::new(
                NonZeroUsize::new(config.prompt_cache_size).expect("prompt_cache_size must be > 0"),
            )))
        } else {
            None
        };

        Ok(Self {
            session: Mutex::new(session),
            tokenizer: std::sync::Arc::new(tokenizer),
            model_name: model_name.to_string(),
            is_quantized,
            prompt_cache,
        })
    }
249
    /// Check if a quantized model was loaded.
    ///
    /// `true` when one of the INT8 ONNX variants (`model_quantized.onnx` /
    /// `model_int8.onnx`) was found in the repo; `false` for the FP32 model.
    #[must_use]
    pub fn is_quantized(&self) -> bool {
        self.is_quantized
    }
255
    /// Get a clone of the tokenizer Arc (cheap — only a refcount bump, no
    /// tokenizer copy).
    #[must_use]
    pub fn tokenizer(&self) -> std::sync::Arc<tokenizers::Tokenizer> {
        std::sync::Arc::clone(&self.tokenizer)
    }
261
262    /// Get model name.
263    pub fn model_name(&self) -> &str {
264        &self.model_name
265    }
266
    /// Extract entities from text using GLiNER zero-shot NER.
    ///
    /// # Arguments
    /// * `text` - The text to extract entities from
    /// * `entity_types` - Entity type labels to detect (e.g., ["person", "organization"])
    /// * `threshold` - Confidence threshold (0.0-1.0, recommended: 0.5)
    ///
    /// Returns an empty vec for empty `text`, empty `entity_types`, or
    /// whitespace-only text.
    ///
    /// # Errors
    ///
    /// * [`Error::InvalidInput`] if the candidate-span count would overflow
    /// * [`Error::Parse`] on tensor construction, tokenization, or inference
    ///   failures
    ///
    /// # Example
    ///
    /// ```rust,ignore
    /// let gliner = GLiNEROnnx::new("onnx-community/gliner_small-v2.1")?;
    /// let entities = gliner.extract("John works at Apple", &["person", "organization"], 0.5)?;
    /// ```
    pub fn extract(
        &self,
        text: &str,
        entity_types: &[&str],
        threshold: f32,
    ) -> Result<Vec<Entity>> {
        if text.is_empty() || entity_types.is_empty() {
            return Ok(vec![]);
        }

        // Split text into words (this implementation uses whitespace splitting)
        let text_words: Vec<&str> = text.split_whitespace().collect();
        let num_text_words = text_words.len();

        if num_text_words == 0 {
            return Ok(vec![]);
        }

        // Encode input following the GLiNER prompt format: word-by-word encoding.
        // Uses the LRU prompt cache when enabled in config.
        // `text_lengths` is intentionally unused: it always equals
        // `num_text_words` (the encoder bumps its counter once per word), and
        // the tensor below is built from `num_text_words` directly.
        let (input_ids, attention_mask, words_mask, text_lengths, entity_count) =
            self.encode_prompt_cached(&text_words, entity_types)?;

        // Generate span tensors
        let (span_idx, span_mask) = self.make_span_tensors(num_text_words);

        // Build ort tensors
        use ndarray::{Array2, Array3};
        use ort::value::Tensor;

        let batch_size = 1;
        let seq_len = input_ids.len();
        // Use checked_mul to prevent overflow (same pattern as gliner2.rs:2388)
        let num_spans = num_text_words.checked_mul(MAX_SPAN_WIDTH).ok_or_else(|| {
            Error::InvalidInput(format!(
                "Span count overflow: {} words * {} MAX_SPAN_WIDTH",
                num_text_words, MAX_SPAN_WIDTH
            ))
        })?;

        // Shapes follow the GLiNER ONNX export convention:
        // 2-D [batch, seq] for token inputs, [batch, 1] for text_lengths,
        // 3-D [batch, num_spans, 2] for span start/end pairs.
        let input_ids_array = Array2::from_shape_vec((batch_size, seq_len), input_ids)
            .map_err(|e| Error::Parse(format!("Array error: {}", e)))?;
        let attention_mask_array = Array2::from_shape_vec((batch_size, seq_len), attention_mask)
            .map_err(|e| Error::Parse(format!("Array error: {}", e)))?;
        let words_mask_array = Array2::from_shape_vec((batch_size, seq_len), words_mask)
            .map_err(|e| Error::Parse(format!("Array error: {}", e)))?;
        let text_lengths_array =
            Array2::from_shape_vec((batch_size, 1), vec![num_text_words as i64])
                .map_err(|e| Error::Parse(format!("Array error: {}", e)))?;
        let span_idx_array = Array3::from_shape_vec((batch_size, num_spans, 2), span_idx)
            .map_err(|e| Error::Parse(format!("Array error: {}", e)))?;
        let span_mask_array = Array2::from_shape_vec((batch_size, num_spans), span_mask)
            .map_err(|e| Error::Parse(format!("Array error: {}", e)))?;

        let input_ids_t = super::ort_compat::tensor_from_ndarray(input_ids_array)
            .map_err(|e| Error::Parse(format!("Tensor error: {}", e)))?;
        let attention_mask_t = super::ort_compat::tensor_from_ndarray(attention_mask_array)
            .map_err(|e| Error::Parse(format!("Tensor error: {}", e)))?;
        let words_mask_t = super::ort_compat::tensor_from_ndarray(words_mask_array)
            .map_err(|e| Error::Parse(format!("Tensor error: {}", e)))?;
        let text_lengths_t = super::ort_compat::tensor_from_ndarray(text_lengths_array)
            .map_err(|e| Error::Parse(format!("Tensor error: {}", e)))?;
        let span_idx_t = super::ort_compat::tensor_from_ndarray(span_idx_array)
            .map_err(|e| Error::Parse(format!("Tensor error: {}", e)))?;
        let span_mask_t = super::ort_compat::tensor_from_ndarray(span_mask_array)
            .map_err(|e| Error::Parse(format!("Tensor error: {}", e)))?;

        // Run inference with blocking lock for thread-safe parallel access.
        // NOTE: this serializes inference across threads by design.
        let mut session = lock(&self.session);

        // Input names must match the ONNX export's graph inputs.
        let outputs = session
            .run(ort::inputs![
                "input_ids" => input_ids_t.into_dyn(),
                "attention_mask" => attention_mask_t.into_dyn(),
                "words_mask" => words_mask_t.into_dyn(),
                "text_lengths" => text_lengths_t.into_dyn(),
                "span_idx" => span_idx_t.into_dyn(),
                "span_mask" => span_mask_t.into_dyn(),
            ])
            .map_err(|e| Error::Parse(format!("ONNX inference failed: {}", e)))?;

        // Decode output into entities (thresholding happens inside).
        let entities = self.decode_output(
            &outputs,
            text,
            &text_words,
            entity_types,
            entity_count,
            threshold,
        )?;
        // Explicit drops: `outputs` borrows the session guard, so release both
        // before returning (redundant with scope end, but documents intent).
        drop(outputs);
        drop(session);

        Ok(entities)
    }
375
376    /// Hash text for cache key.
377    fn hash_text(text: &str) -> u64 {
378        use std::collections::hash_map::DefaultHasher;
379        use std::hash::{Hash, Hasher};
380        let mut hasher = DefaultHasher::new();
381        text.hash(&mut hasher);
382        hasher.finish()
383    }
384
385    /// Hash entity types for cache key (sorted for consistency).
386    fn hash_entity_types(entity_types: &[&str]) -> u64 {
387        use std::collections::hash_map::DefaultHasher;
388        use std::hash::{Hash, Hasher};
389        let mut hasher = DefaultHasher::new();
390        // Sort entity types for consistent hashing regardless of input order
391        let mut sorted: Vec<&str> = entity_types.to_vec();
392        sorted.sort();
393        sorted.hash(&mut hasher);
394        hasher.finish()
395    }
396
    /// Encode prompt with LRU caching for performance.
    ///
    /// Caches the result of `encode_prompt` keyed by (text_hash, entity_types_hash, model_id).
    /// This provides significant speedup when the same text is queried with different entity types
    /// (common in evaluation loops).
    ///
    /// # Lock Strategy
    ///
    /// The lock is dropped before the expensive `encode_prompt` operation to avoid blocking
    /// other threads. This allows concurrent cache lookups while encoding proceeds.
    /// Trade-off: two threads that miss on the same key simultaneously will both
    /// encode; the second `put` simply overwrites with an identical value.
    ///
    /// # Errors
    ///
    /// Propagates `encode_prompt` errors, and any error from `try_lock` on the
    /// cache mutex (semantics defined in `crate::sync` — NOTE(review): confirm
    /// whether contention vs. poisoning is the failure mode here).
    fn encode_prompt_cached(
        &self,
        text_words: &[&str],
        entity_types: &[&str],
    ) -> Result<(Vec<i64>, Vec<i64>, Vec<i64>, i64, usize)> {
        // If cache is disabled, use direct encoding
        let cache = match &self.prompt_cache {
            Some(c) => c,
            None => return self.encode_prompt(text_words, entity_types),
        };

        // Build cache key. Joining with single spaces is lossy w.r.t. the
        // original whitespace, but `text_words` already came from
        // split_whitespace, so equal word lists are equivalent inputs here.
        let text = text_words.join(" ");
        let text_hash = Self::hash_text(&text);
        let entity_types_hash = Self::hash_entity_types(entity_types);
        let key = PromptCacheKey {
            text_hash,
            entity_types_hash,
            model_id: self.model_name.clone(),
        };

        // Check cache (lock scope minimized)
        let cached_result = {
            let mut cache_guard = try_lock(cache)?;
            cache_guard.get(&key).cloned()
        };

        // Cache hit: return immediately
        if let Some(cached) = cached_result {
            return Ok((
                cached.input_ids,
                cached.attention_mask,
                cached.words_mask,
                cached.text_lengths,
                cached.entity_count,
            ));
        }

        // Cache miss: compute encoding (lock is dropped, allowing other threads to proceed)
        let result = self.encode_prompt(text_words, entity_types)?;

        // Store in cache (re-acquire lock)
        {
            let mut cache_guard = try_lock(cache)?;
            cache_guard.put(
                key,
                PromptCacheValue {
                    input_ids: result.0.clone(),
                    attention_mask: result.1.clone(),
                    words_mask: result.2.clone(),
                    text_lengths: result.3,
                    entity_count: result.4,
                },
            );
        }

        Ok(result)
    }
465
466    /// Encode prompt following the GLiNER prompt format: word-by-word encoding.
467    ///
468    /// Structure: [START] <<ENT>> type1 <<ENT>> type2 <<SEP>> word1 word2 ... [END]
469    ///
470    /// # Performance
471    ///
472    /// This method performs tokenization and encoding, which can be expensive.
473    /// Consider caching the result if the same (text, entity_types) combination
474    /// is queried multiple times.
475    ///
476    /// For cached encoding, use `encode_prompt_cached` instead.
477    pub(crate) fn encode_prompt(
478        &self,
479        text_words: &[&str],
480        entity_types: &[&str],
481    ) -> Result<(Vec<i64>, Vec<i64>, Vec<i64>, i64, usize)> {
482        // Build token sequence word by word
483        let mut input_ids: Vec<i64> = Vec::new();
484        let mut word_mask: Vec<i64> = Vec::new();
485
486        // Add start token
487        input_ids.push(TOKEN_START as i64);
488        word_mask.push(0);
489
490        // Add entity types: <<ENT>> type1 <<ENT>> type2 ...
491        for entity_type in entity_types {
492            // Add <<ENT>> token
493            input_ids.push(TOKEN_ENT as i64);
494            word_mask.push(0);
495
496            // Encode entity type word(s)
497            // Note: tokenizers::Tokenizer::encode requires String, not &str
498            let encoding = self
499                .tokenizer
500                .encode(entity_type.to_string(), false)
501                .map_err(|e| Error::Parse(format!("Tokenizer error: {}", e)))?;
502            for token_id in encoding.get_ids() {
503                input_ids.push(*token_id as i64);
504                word_mask.push(0);
505            }
506        }
507
508        // Add <<SEP>> token
509        input_ids.push(TOKEN_SEP as i64);
510        word_mask.push(0);
511
512        // Add text words (this is where word_mask starts counting from 1)
513        let mut word_id: i64 = 0;
514        for word in text_words {
515            // Encode word
516            // Note: tokenizers::Tokenizer::encode requires String, not &str
517            let encoding = self
518                .tokenizer
519                .encode(word.to_string(), false)
520                .map_err(|e| Error::Parse(format!("Tokenizer error: {}", e)))?;
521
522            word_id += 1; // Increment before first token of word
523
524            for (token_idx, token_id) in encoding.get_ids().iter().enumerate() {
525                input_ids.push(*token_id as i64);
526                // First subword token gets the word ID, rest get 0
527                if token_idx == 0 {
528                    word_mask.push(word_id);
529                } else {
530                    word_mask.push(0);
531                }
532            }
533        }
534
535        // Add end token
536        input_ids.push(TOKEN_END as i64);
537        word_mask.push(0);
538
539        let seq_len = input_ids.len();
540        // Performance: Pre-allocate attention_mask with known size
541        let mut attention_mask = Vec::with_capacity(seq_len);
542        attention_mask.resize(seq_len, 1);
543
544        Ok((
545            input_ids,
546            attention_mask,
547            word_mask,
548            word_id,
549            entity_types.len(),
550        ))
551    }
552
    /// Generate span tensors following the GLiNER span layout.
    ///
    /// Enumerates every candidate span of 1..=MAX_SPAN_WIDTH consecutive
    /// words starting at each word position.
    ///
    /// Shape: [num_words * max_width, 2] for span_idx (flattened; each pair is
    /// (start_word, end_word) with end INCLUSIVE — width 0 means a 1-word span)
    /// Shape: [num_words * max_width] for span_mask (true for spans that fit
    /// inside the text, false for padding slots near the end)
    fn make_span_tensors(&self, num_words: usize) -> (Vec<i64>, Vec<bool>) {
        // Use checked_mul to prevent overflow (same pattern as gliner2.rs:2388)
        // NOTE(review): on overflow this falls back to usize::MAX, and the
        // vec! allocations below would then abort — unreachable for realistic
        // word counts, but inconsistent with extract(), which returns Err on
        // the same overflow; consider unifying.
        let num_spans = num_words.checked_mul(MAX_SPAN_WIDTH).unwrap_or_else(|| {
            log::warn!(
                "Span count overflow: {} words * {} MAX_SPAN_WIDTH, using max",
                num_words,
                MAX_SPAN_WIDTH
            );
            usize::MAX
        });
        // Check for overflow in num_spans * 2
        let span_idx_len = num_spans.checked_mul(2).unwrap_or_else(|| {
            log::warn!(
                "Span idx length overflow: {} spans * 2, using max",
                num_spans
            );
            usize::MAX
        });
        let mut span_idx: Vec<i64> = vec![0; span_idx_len];
        let mut span_mask: Vec<bool> = vec![false; num_spans];

        for start in 0..num_words {
            // Spans may not extend past the last word.
            let remaining_width = num_words - start;
            let actual_max_width = MAX_SPAN_WIDTH.min(remaining_width);

            for width in 0..actual_max_width {
                // Flat slot index for (start, width); row-major over widths.
                // Check for overflow in dim calculation (same pattern as nuner.rs:399)
                let dim = match start.checked_mul(MAX_SPAN_WIDTH) {
                    Some(v) => match v.checked_add(width) {
                        Some(d) => d,
                        None => {
                            log::warn!(
                                "Dim calculation overflow: {} * {} + {}, skipping span",
                                start,
                                MAX_SPAN_WIDTH,
                                width
                            );
                            continue;
                        }
                    },
                    None => {
                        log::warn!(
                            "Dim calculation overflow: {} * {}, skipping span",
                            start,
                            MAX_SPAN_WIDTH
                        );
                        continue;
                    }
                };
                // Check bounds before array access (dim * 2 could overflow or exceed span_idx_len)
                if let Some(dim2) = dim.checked_mul(2) {
                    if dim2 + 1 < span_idx_len && dim < num_spans {
                        span_idx[dim2] = start as i64; // start offset
                        span_idx[dim2 + 1] = (start + width) as i64; // end offset (inclusive)
                        span_mask[dim] = true;
                    } else {
                        log::warn!(
                            "Span idx access out of bounds: dim={}, dim*2={}, span_idx_len={}, num_spans={}, skipping",
                            dim, dim2, span_idx_len, num_spans
                        );
                    }
                } else {
                    log::warn!("Dim * 2 overflow: dim={}, skipping span", dim);
                }
            }
        }

        (span_idx, span_mask)
    }
626
627    /// Decode model output following the GLiNER output layout.
628    ///
629    /// Expected output shape: [batch, num_words, max_width, num_entity_types]
630    fn decode_output(
631        &self,
632        outputs: &ort::session::SessionOutputs,
633        text: &str,
634        text_words: &[&str],
635        entity_types: &[&str],
636        expected_num_classes: usize,
637        threshold: f32,
638    ) -> Result<Vec<Entity>> {
639        // Performance: Cache text length once (used in extract_char_slice calls)
640        // ROI: High - called once, saves O(n) per entity in decode loops
641        let text_char_count = text.chars().count();
642        // Get output tensor
643        let output = outputs
644            .iter()
645            .next()
646            .map(|(_, v)| v)
647            .ok_or_else(|| Error::Parse("No output from GLiNER model".to_string()))?;
648
649        // Extract tensor data
650        let (_, data_slice) = output
651            .try_extract_tensor::<f32>()
652            .map_err(|e| Error::Parse(format!("Failed to extract output tensor: {}", e)))?;
653        let output_data: Vec<f32> = data_slice.to_vec();
654
655        // Get output shape
656        let shape: Vec<i64> = match output.dtype() {
657            ort::value::ValueType::Tensor { shape, .. } => shape.iter().copied().collect(),
658            _ => return Err(Error::Parse("Output is not a tensor".to_string())),
659        };
660
661        log::debug!(
662            "[GLiNER] Output shape: {:?}, data len: {}, expected classes: {}",
663            shape,
664            output_data.len(),
665            expected_num_classes
666        );
667
668        if output_data.is_empty() || shape.iter().any(|&d| d == 0) {
669            return Err(Error::Inference(
670                "GLiNER ONNX returned empty/degenerate output tensor. This usually indicates an incompatible ONNX export for this implementation (shape mismatch or missing dynamic axes).".to_string(),
671            ));
672        }
673
674        // Performance: Pre-allocate entities vec with estimated capacity
675        // Most texts have 0-50 entities, but we'll start with a reasonable default
676        let mut entities = Vec::with_capacity(32);
677        let num_text_words = text_words.len();
678
679        // Expected shape: [batch, num_words, max_width, num_classes]
680        if shape.len() == 4 && shape[0] == 1 {
681            let out_num_words = shape[1] as usize;
682            let out_max_width = shape[2] as usize;
683            let num_classes = shape[3] as usize;
684
685            log::debug!(
686                "[GLiNER] Decoding: num_words={}, max_width={}, num_classes={}",
687                out_num_words,
688                out_max_width,
689                num_classes
690            );
691
692            if num_classes == 0 {
693                return Err(Error::Inference(
694                    "GLiNER ONNX model produced num_classes=0. This export likely does not support dynamic entity types for the requested schema.".to_string(),
695                ));
696            }
697
698            // Iterate over spans and apply sigmoid threshold
699            for word_idx in 0..out_num_words.min(num_text_words) {
700                for width in 0..out_max_width.min(MAX_SPAN_WIDTH) {
701                    let end_word = word_idx + width;
702                    if end_word >= num_text_words {
703                        continue;
704                    }
705
706                    for class_idx in 0..num_classes.min(entity_types.len()) {
707                        let idx = (word_idx * out_max_width * num_classes)
708                            + (width * num_classes)
709                            + class_idx;
710
711                        if idx < output_data.len() {
712                            let logit = output_data[idx];
713                            // Apply sigmoid
714                            let score = 1.0 / (1.0 + (-logit).exp());
715
716                            if score >= threshold {
717                                let (char_start, char_end) = self.word_span_to_char_offsets(
718                                    text, text_words, word_idx, end_word,
719                                );
720
721                                // Extract actual text from source to preserve original whitespace
722                                // Performance: Use optimized extraction with cached length
723                                let span_text = extract_char_slice_with_len(
724                                    text,
725                                    char_start,
726                                    char_end,
727                                    text_char_count,
728                                );
729
730                                let entity_type_str =
731                                    entity_types.get(class_idx).unwrap_or(&"OTHER");
732                                let entity_type = Self::map_entity_type(entity_type_str);
733
734                                entities.push(Entity::new(
735                                    span_text,
736                                    entity_type,
737                                    char_start,
738                                    char_end,
739                                    score as f64,
740                                ));
741                            }
742                        }
743                    }
744                }
745            }
746        } else if shape.len() == 3 && shape[0] == 1 {
747            // Alternative shape: [batch, num_spans, num_classes]
748            let num_spans = shape[1] as usize;
749            let num_classes = shape[2] as usize;
750
751            if num_classes == 0 {
752                return Err(Error::Inference(
753                    "GLiNER ONNX model produced num_classes=0. This export likely does not support dynamic entity types for the requested schema.".to_string(),
754                ));
755            }
756
757            for span_idx in 0..num_spans {
758                let word_idx = span_idx / MAX_SPAN_WIDTH;
759                let width = span_idx % MAX_SPAN_WIDTH;
760                let end_word = word_idx + width;
761
762                if word_idx >= num_text_words || end_word >= num_text_words {
763                    continue;
764                }
765
766                for class_idx in 0..num_classes.min(entity_types.len()) {
767                    let idx = span_idx * num_classes + class_idx;
768                    if idx < output_data.len() {
769                        let logit = output_data[idx];
770                        let score = 1.0 / (1.0 + (-logit).exp());
771
772                        if score >= threshold {
773                            let (char_start, char_end) = self
774                                .word_span_to_char_offsets(text, text_words, word_idx, end_word);
775
776                            // Extract actual text from source to preserve original whitespace
777                            // Performance: Use optimized extraction with cached length
778                            let span_text = extract_char_slice_with_len(
779                                text,
780                                char_start,
781                                char_end,
782                                text_char_count,
783                            );
784
785                            let entity_type_str = entity_types.get(class_idx).unwrap_or(&"OTHER");
786                            let entity_type = Self::map_entity_type(entity_type_str);
787
788                            entities.push(Entity::new(
789                                span_text,
790                                entity_type,
791                                char_start,
792                                char_end,
793                                score as f64,
794                            ));
795                        }
796                    }
797                }
798            }
799        }
800
        // Performance: Use unstable sort (we don't need stable sort here)
803        // Sort by start position, then by descending span length, then by descending confidence
804        entities.sort_unstable_by(|a, b| {
805            a.start
806                .cmp(&b.start)
807                .then_with(|| b.end.cmp(&a.end))
808                .then_with(|| {
809                    b.confidence
810                        .partial_cmp(&a.confidence)
811                        .unwrap_or(std::cmp::Ordering::Equal)
812                })
813        });
814
815        // Remove exact duplicates
816        entities.dedup_by(|a, b| a.start == b.start && a.end == b.end);
817
818        // Remove overlapping spans, keeping the highest confidence one
819        // This addresses the common issue where GLiNER detects both
820        // "The Department of Defense" and "Department of Defense"
821        let entities = remove_overlapping_spans(entities);
822
823        // Post-process: strip trailing punctuation from entity spans
824        let entities = entities
825            .into_iter()
826            .map(|mut e| {
827                // Strip trailing punctuation that shouldn't be part of entities
828                while e.text.ends_with(['.', ',', ';', ':', '!', '?']) {
829                    e.text.pop();
830                    if e.end > e.start {
831                        e.end -= 1;
832                    }
833                }
834                // Also strip leading punctuation
835                while e.text.starts_with(['.', ',', ';', ':', '!', '?']) {
836                    e.text.remove(0);
837                    e.start += 1;
838                }
839
840                // Post-process: GLiNER sometimes tags obvious companies as PRODUCT.
841                // If the surface form has strong company markers, remap PRODUCT → ORG.
842                //
843                // Keep this conservative: only remap when the mention itself looks like a company
844                // ("Inc", "Ltd", "LLC", "株式会社", etc.) to avoid collapsing real products.
845                if e.entity_type.as_label().eq_ignore_ascii_case("PRODUCT")
846                    && looks_like_company_name(&e.text)
847                {
848                    e.entity_type = EntityType::Organization;
849                }
850
851                e
852            })
853            .filter(|e| !e.text.is_empty() && e.start < e.end)
854            .collect();
855
856        Ok(entities)
857    }
858
859    /// Map entity type string to EntityType enum.
860    fn map_entity_type(type_str: &str) -> EntityType {
861        match type_str.to_lowercase().as_str() {
862            "person" | "per" => EntityType::Person,
863            "organization" | "org" | "company" => EntityType::Organization,
864            "location" | "loc" | "gpe" | "geo-loc" => EntityType::Location,
865            "facility" | "fac" => EntityType::custom("FACILITY", anno_core::EntityCategory::Place),
866            "product" | "prod" => EntityType::custom("PRODUCT", anno_core::EntityCategory::Misc),
867            "misc" | "other" => EntityType::Other("MISC".to_string()),
868            "date" | "time" => EntityType::Date,
869            "money" | "currency" => EntityType::Money,
870            "percent" | "percentage" => EntityType::Percent,
871            other => EntityType::Other(other.to_string()),
872        }
873    }
874
875    /// Convert word indices to character offsets.
876    ///
877    /// This function correctly handles Unicode text by converting byte offsets
878    /// to character offsets using the offset module's bytes_to_chars function.
879    fn word_span_to_char_offsets(
880        &self,
881        text: &str,
882        words: &[&str],
883        start_word: usize,
884        end_word: usize,
885    ) -> (usize, usize) {
886        // Defensive: Validate bounds
887        if words.is_empty()
888            || start_word >= words.len()
889            || end_word >= words.len()
890            || start_word > end_word
891        {
892            // Return safe defaults: empty span (0, 0)
893            return (0, 0);
894        }
895
896        let mut byte_pos = 0;
897        let mut start_byte = 0;
898        let mut end_byte = text.len();
899        let mut found_start = false;
900        let mut found_end = false;
901
902        for (idx, word) in words.iter().enumerate() {
903            // Search for the word in the remaining text (by bytes)
904            if let Some(pos) = text[byte_pos..].find(word) {
905                let word_start_byte = byte_pos + pos;
906                let word_end_byte = word_start_byte + word.len();
907
908                if idx == start_word {
909                    start_byte = word_start_byte;
910                    found_start = true;
911                }
912                if idx == end_word {
913                    end_byte = word_end_byte;
914                    found_end = true;
915                    break;
916                }
917                byte_pos = word_end_byte;
918            } else {
919                // Word not found - this shouldn't happen in normal operation,
920                // but if it does, we can't reliably compute offsets
921            }
922        }
923
924        // If we didn't find the words, return safe defaults
925        if !found_start || !found_end {
926            // Return empty span to avoid incorrect entity extraction
927            (0, 0)
928        } else {
929            // Convert byte offsets to character offsets
930            crate::offset::bytes_to_chars(text, start_byte, end_byte)
931        }
932    }
933}
934
/// Heuristic: does this mention's surface form carry an explicit company
/// marker (a Western legal-form suffix, or a CJK/Arabic "company" word)?
///
/// Deliberately cheap and conservative — plain suffix/substring checks, no
/// regex — so it can run on every candidate entity without noticeable cost.
fn looks_like_company_name(text: &str) -> bool {
    let trimmed = text.trim();
    if trimmed.is_empty() {
        return false;
    }

    // CJK organization markers (substring match is appropriate here since
    // these scripts don't use spaces).
    const CJK_MARKERS: [&str; 4] = ["株式会社", "有限会社", "公司", "集团"];
    if CJK_MARKERS.iter().any(|marker| trimmed.contains(marker)) {
        return true;
    }

    // Arabic "company" marker.
    if trimmed.contains("شركة") {
        return true;
    }

    // Western-ish legal suffixes, matched case-insensitively. The leading
    // space ensures we only match whole trailing words ("Apple Inc", not
    // a name that merely ends in the letters "inc").
    let lowered = trimmed.to_lowercase();
    const SUFFIXES: [&str; 16] = [
        " inc",
        " inc.",
        " ltd",
        " ltd.",
        " llc",
        " llp",
        " plc",
        " co",
        " co.",
        " company",
        " corp",
        " corp.",
        " corporation",
        " gmbh",
        " s.a.",
        " sa",
    ];
    SUFFIXES.iter().any(|suffix| lowered.ends_with(suffix))
}
980
#[cfg(test)]
mod postprocess_tests {
    use super::looks_like_company_name;

    /// Positive cases cover Western legal suffixes plus CJK and Arabic
    /// company markers; negative cases are plain names without any marker.
    #[test]
    fn test_looks_like_company_name() {
        // Western legal-form suffixes
        assert!(looks_like_company_name("Apple Inc"));
        assert!(looks_like_company_name("Acme Corp."));
        assert!(looks_like_company_name("Example GmbH"));
        assert!(looks_like_company_name("Widgets Ltd"));
        assert!(looks_like_company_name("Foo Bar LLC"));

        // CJK organization markers
        assert!(looks_like_company_name("株式会社トヨタ自動車"));
        assert!(looks_like_company_name("阿里巴巴集团"));

        // Arabic "company" marker
        assert!(looks_like_company_name("شركة أرامكو"));

        // Plain names / places / empty input must not match
        assert!(!looks_like_company_name("Apple"));
        assert!(!looks_like_company_name("New York"));
        assert!(!looks_like_company_name(""));
    }
}
997
/// Extract a substring by character offsets (not byte offsets).
///
/// This handles Unicode text correctly by iterating over characters.
/// Invalid ranges (out of bounds, or `char_start >= char_end`) yield an
/// empty string instead of panicking.
///
/// # Performance
///
/// For repeated calls on the same text, consider using
/// `extract_char_slice_with_len` with a cached text length to avoid
/// recalculating `text.chars().count()`.
fn extract_char_slice(text: &str, char_start: usize, char_end: usize) -> String {
    // Single-call path: compute the char count here. Batch callers should
    // cache this and use the `_with_len` variant instead.
    let total_chars = text.chars().count();
    if char_start >= total_chars || char_end > total_chars || char_start >= char_end {
        return String::new();
    }
    text.chars()
        .skip(char_start)
        .take(char_end - char_start)
        .collect()
}
1012
/// Extract a substring by character offsets with a pre-computed text length.
///
/// Performance variant of `extract_char_slice` for batch operations where
/// `text.chars().count()` has already been computed once. Invalid ranges
/// (out of bounds, or `char_start >= char_end`) yield an empty string.
fn extract_char_slice_with_len(
    text: &str,
    char_start: usize,
    char_end: usize,
    text_char_count: usize,
) -> String {
    // Reject out-of-range or empty/inverted spans up front.
    if char_start >= text_char_count || char_end > text_char_count || char_start >= char_end {
        return String::new();
    }
    // Walk characters once, copying only those inside [char_start, char_end).
    let mut out = String::new();
    for (i, ch) in text.chars().enumerate() {
        if i >= char_end {
            break;
        }
        if i >= char_start {
            out.push(ch);
        }
    }
    out
}
1031
1032// =============================================================================
1033// Model Trait Implementation
1034// =============================================================================
1035
/// Default entity types for zero-shot GLiNER when used via the Model trait.
///
/// These labels are only the fallback for the label-less trait interfaces
/// (`Model::extract_entities`, `BatchCapable`); callers that need a custom
/// schema should call `extract(text, labels, threshold)` directly.
#[cfg(feature = "onnx")]
const DEFAULT_GLINER_LABELS: &[&str] = &[
    "person",
    "organization",
    "location",
    "date",
    "time",
    "money",
    "percent",
    "product",
    "event",
    "facility",
    "work_of_art",
    "law",
    "language",
];
1053
#[cfg(feature = "onnx")]
impl crate::Model for GLiNEROnnx {
    /// Zero-shot extraction over [`DEFAULT_GLINER_LABELS`] at threshold 0.5.
    ///
    /// For a custom label set, call `extract(text, labels, threshold)` on the
    /// inherent impl directly instead of going through this trait.
    fn extract_entities(&self, text: &str, _language: Option<&str>) -> crate::Result<Vec<Entity>> {
        self.extract(text, DEFAULT_GLINER_LABELS, 0.5)
    }

    /// GLiNER is zero-shot, so "supported" types are open-ended; report the
    /// default label set wrapped as custom types.
    fn supported_types(&self) -> Vec<anno_core::EntityType> {
        let mut types = Vec::with_capacity(DEFAULT_GLINER_LABELS.len());
        for label in DEFAULT_GLINER_LABELS {
            types.push(anno_core::EntityType::Custom {
                name: (*label).to_string(),
                category: EntityCategory::Misc,
            });
        }
        types
    }

    fn is_available(&self) -> bool {
        // Construction already succeeded, so the backend is usable.
        true
    }

    fn name(&self) -> &'static str {
        "GLiNER-ONNX"
    }

    fn description(&self) -> &'static str {
        "Zero-shot NER using GLiNER with ONNX Runtime backend"
    }

    /// Version string derived from the model weights and quantization status.
    fn version(&self) -> String {
        let precision = if self.is_quantized { "q" } else { "fp32" };
        format!("gliner-onnx-{}-{}", self.model_name, precision)
    }
}
1094
#[cfg(feature = "onnx")]
impl crate::backends::inference::ZeroShotNER for GLiNEROnnx {
    /// Extract entities for an explicit set of type labels.
    ///
    /// Thin delegation to the inherent `extract` method.
    fn extract_with_types(
        &self,
        text: &str,
        entity_types: &[&str],
        threshold: f32,
    ) -> crate::Result<Vec<Entity>> {
        self.extract(text, entity_types, threshold)
    }

    /// Extract entities using free-form label descriptions.
    fn extract_with_descriptions(
        &self,
        text: &str,
        descriptions: &[&str],
        threshold: f32,
    ) -> crate::Result<Vec<Entity>> {
        // GLiNER encodes labels as text, so descriptions work the same way
        self.extract(text, descriptions, threshold)
    }

    /// Label set used when the caller does not supply one.
    fn default_types(&self) -> &[&'static str] {
        DEFAULT_GLINER_LABELS
    }
}
1120
1121// =============================================================================
1122// Stub when feature disabled
1123// =============================================================================
1124
/// Stub GLiNER handle used when the crate is built without the `onnx` feature.
///
/// All inference entry points fail with `Error::InvalidInput`; see the
/// feature-gated implementation above for the real backend.
#[cfg(not(feature = "onnx"))]
#[derive(Debug)]
pub struct GLiNEROnnx;
1128
#[cfg(not(feature = "onnx"))]
impl GLiNEROnnx {
    /// Create a new GLiNER model (stub - requires onnx feature).
    ///
    /// # Errors
    ///
    /// Always returns `Error::InvalidInput` with a hint on how to enable
    /// the feature.
    pub fn new(_model_name: &str) -> Result<Self> {
        Err(Error::InvalidInput(
            "GLiNER-ONNX requires the 'onnx' feature. \
             Build with: cargo build --features onnx"
                .to_string(),
        ))
    }

    /// Get the model name (stub).
    ///
    /// Returns a fixed placeholder since no model is actually loaded.
    pub fn model_name(&self) -> &str {
        "gliner-not-enabled"
    }

    /// Extract entities (stub - requires onnx feature).
    ///
    /// # Errors
    ///
    /// Always returns `Error::InvalidInput`; all inputs are ignored.
    pub fn extract(
        &self,
        _text: &str,
        _entity_types: &[&str],
        _threshold: f32,
    ) -> Result<Vec<Entity>> {
        Err(Error::InvalidInput(
            "GLiNER-ONNX requires the 'onnx' feature".to_string(),
        ))
    }
}
1157
#[cfg(not(feature = "onnx"))]
impl crate::Model for GLiNEROnnx {
    /// Always fails: inference requires the 'onnx' feature.
    fn extract_entities(&self, _text: &str, _language: Option<&str>) -> crate::Result<Vec<Entity>> {
        Err(Error::InvalidInput(
            "GLiNER-ONNX requires the 'onnx' feature".to_string(),
        ))
    }

    /// No types are supported without a loaded model.
    fn supported_types(&self) -> Vec<anno_core::EntityType> {
        vec![]
    }

    /// The stub backend is never available.
    fn is_available(&self) -> bool {
        false
    }

    fn name(&self) -> &'static str {
        "GLiNER-ONNX (unavailable)"
    }

    fn description(&self) -> &'static str {
        "GLiNER with ONNX Runtime backend - requires 'onnx' feature"
    }
}
1182
#[cfg(not(feature = "onnx"))]
impl crate::backends::inference::ZeroShotNER for GLiNEROnnx {
    /// Always fails: zero-shot extraction requires the 'onnx' feature.
    fn extract_with_types(
        &self,
        _text: &str,
        _entity_types: &[&str],
        _threshold: f32,
    ) -> crate::Result<Vec<Entity>> {
        Err(Error::InvalidInput(
            "GLiNER-ONNX requires the 'onnx' feature".to_string(),
        ))
    }

    /// Always fails: zero-shot extraction requires the 'onnx' feature.
    fn extract_with_descriptions(
        &self,
        _text: &str,
        _descriptions: &[&str],
        _threshold: f32,
    ) -> crate::Result<Vec<Entity>> {
        Err(Error::InvalidInput(
            "GLiNER-ONNX requires the 'onnx' feature".to_string(),
        ))
    }
}
1207
1208// =============================================================================
1209// BatchCapable Trait Implementation
1210// =============================================================================
1211
#[cfg(feature = "onnx")]
impl crate::BatchCapable for GLiNEROnnx {
    /// Batch extraction: runs each text through `extract` with the default
    /// label set and a 0.5 threshold, failing fast on the first error.
    ///
    /// GLiNER could batch with padded sequences; sequential calls are used
    /// here for simplicity, while the tokenizer and model weights stay
    /// cached in the session across calls.
    fn extract_entities_batch(
        &self,
        texts: &[&str],
        _language: Option<&str>,
    ) -> Result<Vec<Vec<Entity>>> {
        if texts.is_empty() {
            return Ok(Vec::new());
        }

        let mut results = Vec::with_capacity(texts.len());
        for text in texts {
            // Propagate the first failure, matching single-call semantics.
            results.push(self.extract(text, DEFAULT_GLINER_LABELS, 0.5)?);
        }
        Ok(results)
    }

    fn optimal_batch_size(&self) -> Option<usize> {
        Some(16)
    }
}
1239
#[cfg(not(feature = "onnx"))]
impl crate::BatchCapable for GLiNEROnnx {
    /// Always fails: batch extraction requires the 'onnx' feature.
    fn extract_entities_batch(
        &self,
        texts: &[&str],
        _language: Option<&str>,
    ) -> Result<Vec<Vec<Entity>>> {
        Err(Error::InvalidInput(
            "GLiNER-ONNX requires the 'onnx' feature".to_string(),
        ))
    }

    /// No meaningful batch size without a backend.
    fn optimal_batch_size(&self) -> Option<usize> {
        None
    }
}
1256
1257// =============================================================================
1258// StreamingCapable Trait Implementation
1259// =============================================================================
1260// Overlap Removal
1261// =============================================================================
1262
1263/// Remove overlapping entity spans intelligently.
1264///
1265/// Strategy:
1266/// 1. Prefer shorter spans when they have similar or higher confidence
1267///    (e.g., prefer "Department of Defense" over "The Department of Defense")
1268/// 2. For truly overlapping spans of similar length, keep highest confidence
1269/// 3. Handle comma-separated entities (e.g., "IBM, NASA" should become "IBM" + "NASA")
1270fn remove_overlapping_spans(mut entities: Vec<Entity>) -> Vec<Entity> {
1271    if entities.len() <= 1 {
1272        return entities;
1273    }
1274
1275    // Performance: Use unstable sort (we don't need stable sort here)
1276    // Sort by span length (shorter first), then by confidence descending
1277    // This prefers shorter, more precise spans
1278    entities.sort_unstable_by(|a, b| {
1279        let len_a = a.end - a.start;
1280        let len_b = b.end - b.start;
1281        len_a.cmp(&len_b).then_with(|| {
1282            b.confidence
1283                .partial_cmp(&a.confidence)
1284                .unwrap_or(std::cmp::Ordering::Equal)
1285        })
1286    });
1287
1288    let mut result: Vec<Entity> = Vec::with_capacity(entities.len());
1289
1290    for entity in entities {
1291        // Check if this entity is FULLY CONTAINED by any already-kept entity
1292        // If so, skip it (we already have a more precise version)
1293        let is_superset_of_existing = result.iter().any(|kept| {
1294            // Entity fully contains kept
1295            entity.start <= kept.start && entity.end >= kept.end
1296        });
1297
1298        if is_superset_of_existing {
1299            // Skip - we have smaller, more precise entities
1300            continue;
1301        }
1302
1303        // Check if this entity overlaps (but doesn't contain) any kept entity
1304        let overlaps_existing = result.iter().any(|kept| {
1305            let entity_range = entity.start..entity.end;
1306            let kept_range = kept.start..kept.end;
1307            // Partial overlap (not full containment)
1308            entity_range.start < kept_range.end && kept_range.start < entity_range.end
1309        });
1310
1311        if !overlaps_existing {
1312            result.push(entity);
1313        }
1314    }
1315
1316    // Performance: Use unstable sort (we don't need stable sort here)
1317    // Re-sort by position for output
1318    result.sort_unstable_by_key(|e| e.start);
1319    result
1320}
1321
1322// =============================================================================
1323// StreamingCapable
1324// =============================================================================
1325
#[cfg(feature = "onnx")]
impl crate::StreamingCapable for GLiNEROnnx {
    /// Chunk size (in characters) recommended when streaming long documents.
    fn recommended_chunk_size(&self) -> usize {
        4096 // Characters
    }
}
1332
#[cfg(not(feature = "onnx"))]
impl crate::StreamingCapable for GLiNEROnnx {
    /// Same chunk size as the real backend, kept for API parity.
    fn recommended_chunk_size(&self) -> usize {
        4096
    }
}