Skip to main content

anno/backends/gliner_onnx/
mod.rs

1//! GLiNER-based NER implementation using ONNX Runtime.
2//!
3//! GLiNER (Generalist and Lightweight Model for Named Entity Recognition) is
4//! a popular approach to “open/zero-shot” NER. This implementation follows the GLiNER prompt format
5//! and common community conventions.
6//!
7//! ## Prompt Format
8//!
9//! GLiNER uses a special prompt format:
10//!
11//! ```text
12//! [START] <<ENT>> type1 <<ENT>> type2 <<SEP>> word1 word2 ... [END]
13//! ```
14//!
15//! Token IDs (for GLiNER tokenizer):
16//! - START = 1
17//! - END = 2
18//! - `<<ENT>>` = 128002
19//! - `<<SEP>>` = 128003
20//!
21//! ## Key Insight
22//!
23//! Each word is encoded SEPARATELY, preserving word boundaries.
24//! Output shape: [batch, num_words, max_width, num_entity_types]
25
26#![allow(missing_docs)] // Stub implementation
27#![allow(dead_code)] // Placeholder constants
28#![allow(clippy::type_complexity)] // Complex return tuples
29#![allow(clippy::manual_contains)] // Shape check style
30#![allow(unused_variables)] // Feature-gated code
31#![allow(clippy::items_after_test_module)] // Large file; keep local tests near helpers
32#![allow(unused_imports)] // EntityType used conditionally
33
34#[cfg(feature = "onnx")]
35use crate::sync::{lock, try_lock, Mutex};
36use crate::{Entity, Error, Result};
37use anno_core::{EntityCategory, EntityType};
38
// Special token IDs for GLiNER models, matching the prompt layout described
// in the module documentation.

/// ID of the `[START]` token that opens the prompt.
const TOKEN_START: u32 = 1;
/// ID of the `[END]` token that closes the prompt.
const TOKEN_END: u32 = 2;
/// ID of the `<<ENT>>` marker placed before each entity type label.
const TOKEN_ENT: u32 = 128_002;
/// ID of the `<<SEP>>` marker between the entity type labels and the words.
const TOKEN_SEP: u32 = 128_003;

/// Default max span width from GLiNER config
const MAX_SPAN_WIDTH: usize = 12;
47
/// Configuration for GLiNER model loading.
#[cfg(feature = "onnx")]
pub mod config;
// The re-export must carry the same feature gate as the module it re-exports;
// otherwise builds without `onnx` fail to resolve `config`.
#[cfg(feature = "onnx")]
pub use config::*;
52
/// GLiNER zero-shot NER model backed by an ONNX Runtime session.
///
/// Only compiled with the `onnx` feature. Builds without it use the stub type
/// of the same name declared under `#[cfg(not(feature = "onnx"))]` in this
/// file, so this definition has to be gated to avoid a duplicate type.
#[cfg(feature = "onnx")]
pub struct GLiNEROnnx {
    /// ONNX Runtime session, kept behind a mutex.
    session: Mutex<ort::session::Session>,
    /// Arc-wrapped tokenizer for cheap cloning across threads.
    tokenizer: std::sync::Arc<tokenizers::Tokenizer>,
    /// HuggingFace model identifier (e.g., "onnx-community/gliner_small-v2.1").
    model_name: String,
    /// Whether a quantized model was loaded.
    is_quantized: bool,
    /// LRU cache for prompt encodings (keyed by text + entity types).
    prompt_cache: Option<Mutex<lru::LruCache<PromptCacheKey, PromptCacheValue>>>,
}
64
#[cfg(feature = "onnx")]
mod inference;
#[cfg(feature = "onnx")]
pub(crate) use inference::looks_like_company_name;
// `inference` only exists when the `onnx` feature is on, so every import from
// it needs the same gate to keep the featureless build compiling.
#[cfg(feature = "onnx")]
use inference::DEFAULT_GLINER_LABELS;
// Gated on `onnx`: this impl relies on `extract` and `DEFAULT_GLINER_LABELS`,
// and the featureless build provides its own `Model` impl for the stub type.
#[cfg(feature = "onnx")]
impl crate::Model for GLiNEROnnx {
    /// Extracts entities using the default label set and a 0.5 threshold.
    ///
    /// The `_language` hint is ignored. For custom labels, call
    /// `extract(text, labels, threshold)` directly.
    fn extract_entities(&self, text: &str, _language: Option<&str>) -> crate::Result<Vec<Entity>> {
        self.extract(text, DEFAULT_GLINER_LABELS, 0.5)
    }

    /// Returns the default labels as `EntityType::Custom` values in the
    /// `Misc` category.
    ///
    /// GLiNER supports any type via zero-shot, so this list is only the
    /// default set, not an exhaustive one.
    fn supported_types(&self) -> Vec<anno_core::EntityType> {
        DEFAULT_GLINER_LABELS
            .iter()
            .map(|label| anno_core::EntityType::Custom {
                name: (*label).to_string(),
                category: EntityCategory::Misc,
            })
            .collect()
    }

    /// Always `true`: a value of this type only exists if loading succeeded.
    fn is_available(&self) -> bool {
        true // If we got this far, it's available
    }

    /// Short backend name.
    fn name(&self) -> &'static str {
        "GLiNER-ONNX"
    }

    /// One-line human-readable description of the backend.
    fn description(&self) -> &'static str {
        "Zero-shot NER using GLiNER with ONNX Runtime backend"
    }

    /// Version string of the form `gliner-onnx-<model_name>-<q|fp32>`.
    fn version(&self) -> String {
        // Version depends on the model weights and quantization status
        format!(
            "gliner-onnx-{}-{}",
            self.model_name,
            if self.is_quantized { "q" } else { "fp32" }
        )
    }
}
109
#[cfg(feature = "onnx")]
impl crate::backends::inference::ZeroShotNER for GLiNEROnnx {
    /// Extracts entities for the caller-supplied type labels at the given
    /// confidence threshold.
    fn extract_with_types(
        &self,
        text: &str,
        entity_types: &[&str],
        threshold: f32,
    ) -> Result<Vec<Entity>> {
        let labels = entity_types;
        self.extract(text, labels, threshold)
    }

    /// Extracts entities using free-text descriptions in place of labels.
    fn extract_with_descriptions(
        &self,
        text: &str,
        descriptions: &[&str],
        threshold: f32,
    ) -> Result<Vec<Entity>> {
        // GLiNER encodes labels as text, so descriptions take exactly the
        // same path as plain type labels.
        let labels = descriptions;
        self.extract(text, labels, threshold)
    }

    /// The label set used when the caller does not supply one.
    fn default_types(&self) -> &[&'static str] {
        DEFAULT_GLINER_LABELS
    }
}
135
136// =============================================================================
137// Stub when feature disabled
138// =============================================================================
139
140#[cfg(not(feature = "onnx"))]
141#[derive(Debug)]
142pub struct GLiNEROnnx;
143
144#[cfg(not(feature = "onnx"))]
145impl GLiNEROnnx {
146    /// Create a new GLiNER model (stub - requires onnx feature).
147    pub fn new(_model_name: &str) -> Result<Self> {
148        Err(Error::InvalidInput(
149            "GLiNER-ONNX requires the 'onnx' feature. \
150             Build with: cargo build --features onnx"
151                .to_string(),
152        ))
153    }
154
155    /// Get the model name (stub).
156    pub fn model_name(&self) -> &str {
157        "gliner-not-enabled"
158    }
159
160    /// Extract entities (stub - requires onnx feature).
161    pub fn extract(
162        &self,
163        _text: &str,
164        _entity_types: &[&str],
165        _threshold: f32,
166    ) -> Result<Vec<Entity>> {
167        Err(Error::InvalidInput(
168            "GLiNER-ONNX requires the 'onnx' feature".to_string(),
169        ))
170    }
171}
172
173#[cfg(not(feature = "onnx"))]
174impl crate::Model for GLiNEROnnx {
175    fn extract_entities(&self, _text: &str, _language: Option<&str>) -> crate::Result<Vec<Entity>> {
176        Err(Error::InvalidInput(
177            "GLiNER-ONNX requires the 'onnx' feature".to_string(),
178        ))
179    }
180
181    fn supported_types(&self) -> Vec<anno_core::EntityType> {
182        vec![]
183    }
184
185    fn is_available(&self) -> bool {
186        false
187    }
188
189    fn name(&self) -> &'static str {
190        "GLiNER-ONNX (unavailable)"
191    }
192
193    fn description(&self) -> &'static str {
194        "GLiNER with ONNX Runtime backend - requires 'onnx' feature"
195    }
196}
197
198#[cfg(not(feature = "onnx"))]
199impl crate::backends::inference::ZeroShotNER for GLiNEROnnx {
200    fn extract_with_types(
201        &self,
202        _text: &str,
203        _entity_types: &[&str],
204        _threshold: f32,
205    ) -> crate::Result<Vec<Entity>> {
206        Err(Error::InvalidInput(
207            "GLiNER-ONNX requires the 'onnx' feature".to_string(),
208        ))
209    }
210
211    fn extract_with_descriptions(
212        &self,
213        _text: &str,
214        _descriptions: &[&str],
215        _threshold: f32,
216    ) -> crate::Result<Vec<Entity>> {
217        Err(Error::InvalidInput(
218            "GLiNER-ONNX requires the 'onnx' feature".to_string(),
219        ))
220    }
221}
222
223// =============================================================================
224// BatchCapable Trait Implementation
225// =============================================================================
226
#[cfg(feature = "onnx")]
impl crate::BatchCapable for GLiNEROnnx {
    /// Extracts entities from each text in order, using the default labels
    /// and a 0.5 threshold.
    ///
    /// Returns one entity list per input text. Processing stops at the first
    /// text whose extraction fails and that error is returned. An empty
    /// input yields an empty result.
    fn extract_entities_batch(
        &self,
        texts: &[&str],
        _language: Option<&str>,
    ) -> Result<Vec<Vec<Entity>>> {
        const THRESHOLD: f32 = 0.5;

        // GLiNER supports true batching with padded sequences, but for
        // simplicity the texts are run one after another through `extract`.
        let mut per_text = Vec::with_capacity(texts.len());
        for text in texts {
            per_text.push(self.extract(text, DEFAULT_GLINER_LABELS, THRESHOLD)?);
        }
        Ok(per_text)
    }

    /// Suggested number of texts per batch.
    fn optimal_batch_size(&self) -> Option<usize> {
        Some(16)
    }
}
254
255#[cfg(not(feature = "onnx"))]
256impl crate::BatchCapable for GLiNEROnnx {
257    fn extract_entities_batch(
258        &self,
259        texts: &[&str],
260        _language: Option<&str>,
261    ) -> Result<Vec<Vec<Entity>>> {
262        Err(Error::InvalidInput(
263            "GLiNER-ONNX requires the 'onnx' feature".to_string(),
264        ))
265    }
266
267    fn optimal_batch_size(&self) -> Option<usize> {
268        None
269    }
270}
271
// =============================================================================
// Overlap Removal
// =============================================================================
277
278/// Remove overlapping entity spans intelligently.
279///
280/// Strategy:
281/// 1. Prefer shorter spans when they have similar or higher confidence
282///    (e.g., prefer "Department of Defense" over "The Department of Defense")
283/// 2. For truly overlapping spans of similar length, keep highest confidence
284/// 3. Handle comma-separated entities (e.g., "IBM, NASA" should become "IBM" + "NASA")
285fn remove_overlapping_spans(mut entities: Vec<Entity>) -> Vec<Entity> {
286    if entities.len() <= 1 {
287        return entities;
288    }
289
290    // Performance: Use unstable sort (we don't need stable sort here)
291    // Sort by span length (shorter first), then by confidence descending
292    // This prefers shorter, more precise spans
293    entities.sort_unstable_by(|a, b| {
294        let len_a = a.end - a.start;
295        let len_b = b.end - b.start;
296        len_a.cmp(&len_b).then_with(|| {
297            b.confidence
298                .partial_cmp(&a.confidence)
299                .unwrap_or(std::cmp::Ordering::Equal)
300        })
301    });
302
303    let mut result: Vec<Entity> = Vec::with_capacity(entities.len());
304
305    for entity in entities {
306        // Check if this entity is FULLY CONTAINED by any already-kept entity
307        // If so, skip it (we already have a more precise version)
308        let is_superset_of_existing = result.iter().any(|kept| {
309            // Entity fully contains kept
310            entity.start <= kept.start && entity.end >= kept.end
311        });
312
313        if is_superset_of_existing {
314            // Skip - we have smaller, more precise entities
315            continue;
316        }
317
318        // Check if this entity overlaps (but doesn't contain) any kept entity
319        let overlaps_existing = result.iter().any(|kept| {
320            let entity_range = entity.start..entity.end;
321            let kept_range = kept.start..kept.end;
322            // Partial overlap (not full containment)
323            entity_range.start < kept_range.end && kept_range.start < entity_range.end
324        });
325
326        if !overlaps_existing {
327            result.push(entity);
328        }
329    }
330
331    // Performance: Use unstable sort (we don't need stable sort here)
332    // Re-sort by position for output
333    result.sort_unstable_by_key(|e| e.start);
334    result
335}
336
337// =============================================================================
338// StreamingCapable
339// =============================================================================
340
341#[cfg(feature = "onnx")]
342impl crate::StreamingCapable for GLiNEROnnx {
343    fn recommended_chunk_size(&self) -> usize {
344        4096 // Characters
345    }
346}
347
348#[cfg(not(feature = "onnx"))]
349impl crate::StreamingCapable for GLiNEROnnx {
350    fn recommended_chunk_size(&self) -> usize {
351        4096
352    }
353}
354
355#[cfg(test)]
356mod postprocess_tests;