// anno/backends/onnx.rs
1//! BERT-based NER using ONNX Runtime.
2//!
3//! This module provides a reliable ONNX-based NER backend using standard
4//! BERT models fine-tuned for token classification (BIO tags).
5//!
6//! Unlike GLiNER which has ONNX export issues, this uses properly exported
7//! BERT NER models like `protectai/bert-base-NER-onnx`.
8//!
9//! ## Default Model
10//!
11//! Uses `protectai/bert-base-NER-onnx` which recognizes:
12//! - PER (Person)
13//! - ORG (Organization)
14//! - LOC (Location)
15//! - MISC (Miscellaneous)
16
17#![allow(missing_docs)] // Stub implementation
18#![allow(dead_code)] // Placeholder constants
19#![allow(clippy::manual_strip)] // Complex BIO tag parsing
20
21use crate::{Entity, Error, Result};
22#[cfg(feature = "onnx")]
23use anno_core::EntityType;
24
25#[cfg(feature = "onnx")]
26use {
27    crate::sync::lock,
28    hf_hub::api::sync::Api,
29    ndarray::Array2,
30    ort::{session::builder::GraphOptimizationLevel, session::Session},
31    std::collections::HashMap,
32    tokenizers::Tokenizer,
33};
34
/// Default BERT NER ONNX model (properly exported, reliable).
///
/// Recognizes the CoNLL-03 label set: PER / ORG / LOC / MISC (see module docs).
pub const DEFAULT_BERT_NER_MODEL: &str = "protectai/bert-base-NER-onnx";
37
/// Configuration for BERT NER model loading.
#[cfg(feature = "onnx")]
#[derive(Debug, Clone)]
pub struct BertNERConfig {
    /// Prefer quantized models (INT8) for faster CPU inference.
    /// Falls back to FP32 automatically when no quantized file is published.
    pub prefer_quantized: bool,
    /// ONNX optimization level (1-3, default 3).
    /// Any value other than 1 or 2 is treated as level 3.
    pub optimization_level: u8,
    /// Number of threads for inference (0 = auto).
    pub num_threads: usize,
}
49
#[cfg(feature = "onnx")]
impl Default for BertNERConfig {
    /// Defaults tuned for CPU inference: INT8 weights when available,
    /// maximum graph optimization, four intra-op threads.
    fn default() -> Self {
        BertNERConfig {
            prefer_quantized: true,
            optimization_level: 3,
            num_threads: 4,
        }
    }
}
60
/// BERT-based NER using ONNX Runtime.
///
/// Uses standard BERT models fine-tuned for NER with BIO tagging scheme.
/// Thread-safe with `Arc<Tokenizer>` for efficient sharing.
#[cfg(feature = "onnx")]
pub struct BertNEROnnx {
    /// ONNX Runtime session; mutex-guarded because inference takes `&mut` access.
    session: crate::sync::Mutex<Session>,
    /// Arc-wrapped tokenizer for cheap cloning across threads.
    tokenizer: std::sync::Arc<Tokenizer>,
    /// Label id -> BIO tag string (e.g. 3 -> "B-PER"), loaded from `config.json`.
    id_to_label: HashMap<usize, String>,
    /// BIO tag string -> `EntityType` (e.g. "B-PER" -> Person).
    label_to_entity_type: HashMap<String, EntityType>,
    /// HuggingFace model identifier this instance was loaded from.
    model_name: String,
    /// Whether a quantized model was loaded.
    is_quantized: bool,
}
76
77#[cfg(feature = "onnx")]
78impl BertNEROnnx {
    /// Create a new BERT NER ONNX model with default config.
    ///
    /// Equivalent to `Self::with_config(model_name, BertNERConfig::default())`.
    ///
    /// # Arguments
    /// * `model_name` - HuggingFace model identifier (e.g., "protectai/bert-base-NER-onnx")
    ///
    /// # Errors
    /// Propagates any download or session-build error from [`Self::with_config`].
    ///
    /// # Returns
    /// BERT NER ONNX model instance
    pub fn new(model_name: &str) -> Result<Self> {
        Self::with_config(model_name, BertNERConfig::default())
    }
89
90    /// Create a new BERT NER ONNX model with custom configuration.
91    ///
92    /// # Arguments
93    /// * `model_name` - HuggingFace model identifier
94    /// * `config` - Configuration for model loading
95    ///
96    /// # Example
97    ///
98    /// ```rust,ignore
99    /// let config = BertNERConfig {
100    ///     prefer_quantized: true,  // Use INT8 model for 2-4x speedup
101    ///     optimization_level: 3,
102    ///     num_threads: 8,
103    /// };
104    /// let model = BertNEROnnx::with_config("protectai/bert-base-NER-onnx", config)?;
105    /// ```
106    pub fn with_config(model_name: &str, config: BertNERConfig) -> Result<Self> {
107        let api = Api::new().map_err(|e| {
108            Error::Retrieval(format!("Failed to initialize HuggingFace API: {}", e))
109        })?;
110
111        let repo = api.model(model_name.to_string());
112
113        // Download model - try quantized first if preferred
114        let (model_path, is_quantized) = if config.prefer_quantized {
115            if let Ok(path) = repo.get("model_quantized.onnx") {
116                log::info!("[BERT-NER] Using quantized model (INT8)");
117                (path, true)
118            } else if let Ok(path) = repo.get("onnx/model_quantized.onnx") {
119                log::info!("[BERT-NER] Using quantized model (INT8)");
120                (path, true)
121            } else if let Ok(path) = repo.get("model_int8.onnx") {
122                log::info!("[BERT-NER] Using INT8 quantized model");
123                (path, true)
124            } else {
125                // Fall back to FP32
126                let path = repo
127                    .get("model.onnx")
128                    .or_else(|_| repo.get("onnx/model.onnx"))
129                    .map_err(|e| {
130                        Error::Retrieval(format!("Failed to download model.onnx: {}", e))
131                    })?;
132                log::info!("[BERT-NER] Using FP32 model (quantized not available)");
133                (path, false)
134            }
135        } else {
136            let path = repo
137                .get("model.onnx")
138                .or_else(|_| repo.get("onnx/model.onnx"))
139                .map_err(|e| Error::Retrieval(format!("Failed to download model.onnx: {}", e)))?;
140            (path, false)
141        };
142
143        // Download tokenizer.json
144        let tokenizer_path = repo
145            .get("tokenizer.json")
146            .map_err(|e| Error::Retrieval(format!("Failed to download tokenizer.json: {}", e)))?;
147
148        // Download config.json for label mapping
149        let config_path = repo
150            .get("config.json")
151            .map_err(|e| Error::Retrieval(format!("Failed to download config.json: {}", e)))?;
152
153        // Load tokenizer
154        let tokenizer = Tokenizer::from_file(&tokenizer_path)
155            .map_err(|e| Error::Retrieval(format!("Failed to load tokenizer: {}", e)))?;
156
157        // Load config and extract id2label mapping
158        let config_str = std::fs::read_to_string(&config_path)
159            .map_err(|e| Error::Retrieval(format!("Failed to read config.json: {}", e)))?;
160        let config_json: serde_json::Value = serde_json::from_str(&config_str)
161            .map_err(|e| Error::Parse(format!("Failed to parse config.json: {}", e)))?;
162
163        // Build label mappings
164        let id_to_label = Self::build_id_to_label(&config_json);
165        let label_to_entity_type = Self::build_label_to_entity_type();
166
167        // Build session with optimization settings
168        let opt_level = match config.optimization_level {
169            1 => GraphOptimizationLevel::Level1,
170            2 => GraphOptimizationLevel::Level2,
171            _ => GraphOptimizationLevel::Level3,
172        };
173
174        let mut builder = Session::builder()
175            .map_err(|e| Error::Retrieval(format!("Failed to create session builder: {}", e)))?
176            .with_optimization_level(opt_level)
177            .map_err(|e| Error::Retrieval(format!("Failed to set optimization level: {}", e)))?;
178
179        if config.num_threads > 0 {
180            builder = builder
181                .with_intra_threads(config.num_threads)
182                .map_err(|e| Error::Retrieval(format!("Failed to set threads: {}", e)))?;
183        }
184
185        let session = builder
186            .commit_from_file(&model_path)
187            .map_err(|e| Error::Retrieval(format!("Failed to load ONNX model: {}", e)))?;
188
189        Ok(Self {
190            session: crate::sync::Mutex::new(session),
191            tokenizer: std::sync::Arc::new(tokenizer),
192            id_to_label,
193            label_to_entity_type,
194            model_name: model_name.to_string(),
195            is_quantized,
196        })
197    }
198
    /// Check if a quantized model was loaded.
    ///
    /// `true` when an INT8 ONNX file was found at load time, `false` for FP32.
    #[must_use]
    pub fn is_quantized(&self) -> bool {
        self.is_quantized
    }

    /// Get a clone of the tokenizer Arc (cheap).
    ///
    /// Only the reference count is bumped; the tokenizer itself is shared.
    #[must_use]
    pub fn tokenizer(&self) -> std::sync::Arc<Tokenizer> {
        std::sync::Arc::clone(&self.tokenizer)
    }
210
211    /// Build id_to_label mapping from config.
212    fn build_id_to_label(config_json: &serde_json::Value) -> HashMap<usize, String> {
213        let mut map = HashMap::new();
214        if let Some(id2label) = config_json.get("id2label") {
215            if let Some(obj) = id2label.as_object() {
216                for (id_str, label_value) in obj {
217                    if let (Ok(id), Some(label)) = (id_str.parse::<usize>(), label_value.as_str()) {
218                        map.insert(id, label.to_string());
219                    }
220                }
221            }
222        }
223        // Fallback for CoNLL-03 format
224        if map.is_empty() {
225            map.insert(0, "O".to_string());
226            map.insert(1, "B-MISC".to_string());
227            map.insert(2, "I-MISC".to_string());
228            map.insert(3, "B-PER".to_string());
229            map.insert(4, "I-PER".to_string());
230            map.insert(5, "B-ORG".to_string());
231            map.insert(6, "I-ORG".to_string());
232            map.insert(7, "B-LOC".to_string());
233            map.insert(8, "I-LOC".to_string());
234        }
235        map
236    }
237
238    /// Build label_to_entity_type mapping for common NER labels.
239    fn build_label_to_entity_type() -> HashMap<String, EntityType> {
240        let mut map = HashMap::new();
241        // Standard CoNLL-03 labels
242        map.insert("B-PER".to_string(), EntityType::Person);
243        map.insert("I-PER".to_string(), EntityType::Person);
244        map.insert("B-ORG".to_string(), EntityType::Organization);
245        map.insert("I-ORG".to_string(), EntityType::Organization);
246        map.insert("B-LOC".to_string(), EntityType::Location);
247        map.insert("I-LOC".to_string(), EntityType::Location);
248        map.insert("B-MISC".to_string(), EntityType::Other("misc".to_string()));
249        map.insert("I-MISC".to_string(), EntityType::Other("misc".to_string()));
250        // Alternative formats
251        map.insert("PER".to_string(), EntityType::Person);
252        map.insert("ORG".to_string(), EntityType::Organization);
253        map.insert("LOC".to_string(), EntityType::Location);
254        map.insert("MISC".to_string(), EntityType::Other("misc".to_string()));
255        map
256    }
257
258    /// Extract entities from text using BERT NER.
259    ///
260    /// # Arguments
261    /// * `text` - Text to extract entities from
262    /// * `_language` - Optional language hint (unused, model handles multiple languages)
263    ///
264    /// # Returns
265    /// Vector of NER entities with positions, types, and confidence scores
266    pub fn extract_entities(&self, text: &str, _language: Option<&str>) -> Result<Vec<Entity>> {
267        if text.is_empty() {
268            return Ok(vec![]);
269        }
270
271        // Tokenize input text
272        let encoding = self
273            .tokenizer
274            .encode(text, true)
275            .map_err(|e| Error::Parse(format!("Failed to tokenize input: {}", e)))?;
276
277        let input_ids: Vec<i64> = encoding.get_ids().iter().map(|&id| id as i64).collect();
278        let attention_mask: Vec<i64> = encoding
279            .get_attention_mask()
280            .iter()
281            .map(|&mask| mask as i64)
282            .collect();
283        // Performance: Pre-allocate token_type_ids with known size
284        // token_type_ids: all zeros for single-sequence NER
285        let token_type_ids: Vec<i64> = vec![0i64; input_ids.len()];
286
287        let batch_size = 1;
288        let seq_len = input_ids.len();
289
290        // Create input tensors
291        let input_ids_array: Array2<i64> =
292            Array2::from_shape_vec((batch_size, seq_len), input_ids.clone())
293                .map_err(|e| Error::Parse(format!("Failed to create input_ids array: {}", e)))?;
294
295        let attention_mask_array: Array2<i64> =
296            Array2::from_shape_vec((batch_size, seq_len), attention_mask.clone()).map_err(|e| {
297                Error::Parse(format!("Failed to create attention_mask array: {}", e))
298            })?;
299
300        let token_type_ids_array: Array2<i64> =
301            Array2::from_shape_vec((batch_size, seq_len), token_type_ids).map_err(|e| {
302                Error::Parse(format!("Failed to create token_type_ids array: {}", e))
303            })?;
304
305        let input_ids_tensor = super::ort_compat::tensor_from_ndarray(input_ids_array)
306            .map_err(|e| Error::Parse(format!("Failed to create input_ids tensor: {}", e)))?;
307
308        let attention_mask_tensor = super::ort_compat::tensor_from_ndarray(attention_mask_array)
309            .map_err(|e| Error::Parse(format!("Failed to create attention_mask tensor: {}", e)))?;
310
311        let token_type_ids_tensor = super::ort_compat::tensor_from_ndarray(token_type_ids_array)
312            .map_err(|e| Error::Parse(format!("Failed to create token_type_ids tensor: {}", e)))?;
313
314        // Run inference with blocking lock for thread-safe parallel access
315        let mut session = lock(&self.session);
316
317        let outputs = session
318            .run(ort::inputs![
319                "input_ids" => input_ids_tensor.into_dyn(),
320                "attention_mask" => attention_mask_tensor.into_dyn(),
321                "token_type_ids" => token_type_ids_tensor.into_dyn(),
322            ])
323            .map_err(|e| Error::Parse(format!("ONNX inference failed: {}", e)))?;
324
325        // Get logits output - BERT NER models have "logits" as output
326        let logits = outputs.get("logits").ok_or_else(|| {
327            Error::Parse("ONNX model output does not contain 'logits' key".to_string())
328        })?;
329
330        // Decode logits to entities
331        self.decode_output(logits, text, &encoding)
332    }
333
334    /// Decode model output logits to NER entities.
335    fn decode_output(
336        &self,
337        output: &ort::value::DynValue,
338        text: &str,
339        encoding: &tokenizers::Encoding,
340    ) -> Result<Vec<Entity>> {
341        // Extract logits as f32 array - ort returns (Shape, &[f32])
342        let (shape, logits_data) = output
343            .try_extract_tensor::<f32>()
344            .map_err(|e| Error::Parse(format!("Failed to extract logits tensor: {}", e)))?;
345
346        // Expected shape: [batch_size, seq_len, num_labels]
347        if shape.len() != 3 || shape[0] != 1 {
348            return Err(Error::Parse(format!(
349                "Unexpected logits shape: {:?}",
350                shape
351            )));
352        }
353
354        let seq_len = shape[1] as usize;
355        let num_labels = shape[2] as usize;
356
357        // Get token offsets for mapping back to character positions
358        let offsets = encoding.get_offsets();
359
360        // `tokenizers::Encoding::get_offsets()` uses byte offsets in Rust. `Entity` uses character
361        // offsets, so convert via SpanConverter when constructing entities.
362        let span_converter = crate::offset::SpanConverter::new(text);
363
364        // Helper to access logits[0, token_idx, label_idx] in flattened array
365        let get_logit = |token_idx: usize, label_idx: usize| -> f32 {
366            logits_data[token_idx * num_labels + label_idx]
367        };
368
369        // Performance: Pre-allocate entities vec with estimated capacity
370        // Most texts have 0-20 entities, but we'll start with a reasonable default
371        let mut entities = Vec::with_capacity(16);
372        let mut current_entity: Option<(usize, usize, EntityType, f64)> = None; // (start_byte, end_byte, type, confidence)
373
374        for token_idx in 0..seq_len {
375            // Skip special tokens (no offset)
376            if token_idx >= offsets.len() {
377                continue;
378            }
379            let (byte_start, byte_end) = offsets[token_idx];
380            if byte_start == byte_end {
381                // Special token, finalize current entity
382                if let Some((start, end, entity_type, conf)) = current_entity.take() {
383                    if start < end && end <= text.len() {
384                        if let Some(entity_text) = text.get(start..end) {
385                            let entity_text = entity_text.trim();
386                            if !entity_text.is_empty() {
387                                entities.push(Entity::new(
388                                    entity_text.to_string(),
389                                    entity_type,
390                                    span_converter.byte_to_char(start),
391                                    span_converter.byte_to_char(end),
392                                    conf,
393                                ));
394                            }
395                        }
396                    }
397                }
398                continue;
399            }
400
401            // Get logits for this token and find argmax
402            let mut max_idx = 0;
403            let mut max_val = f32::NEG_INFINITY;
404            for label_idx in 0..num_labels {
405                let val = get_logit(token_idx, label_idx);
406                if val > max_val {
407                    max_val = val;
408                    max_idx = label_idx;
409                }
410            }
411
412            // Convert to probability using softmax
413            let exp_sum: f32 = (0..num_labels)
414                .map(|i| (get_logit(token_idx, i) - max_val).exp())
415                .sum();
416            // Handle division by zero: if exp_sum == 0.0 or num_labels == 0, use fallback
417            let confidence = if exp_sum > 0.0 && num_labels > 0 {
418                (1.0_f32 / exp_sum) as f64 // exp(0) / exp_sum = 1/exp_sum
419            } else {
420                0.0 // Fallback for edge cases
421            };
422
423            let label = self
424                .id_to_label
425                .get(&max_idx)
426                .cloned()
427                .unwrap_or_else(|| format!("LABEL_{}", max_idx));
428
429            // Skip "O" (outside) labels
430            if label == "O" {
431                if let Some((start, end, entity_type, conf)) = current_entity.take() {
432                    if start < end && end <= text.len() {
433                        if let Some(entity_text) = text.get(start..end) {
434                            let entity_text = entity_text.trim();
435                            if !entity_text.is_empty() {
436                                entities.push(Entity::new(
437                                    entity_text.to_string(),
438                                    entity_type,
439                                    span_converter.byte_to_char(start),
440                                    span_converter.byte_to_char(end),
441                                    conf,
442                                ));
443                            }
444                        }
445                    }
446                }
447                continue;
448            }
449
450            // Parse BIO tag
451            let (bio, entity_label) = if label.starts_with("B-") {
452                ("B", label[2..].to_string())
453            } else if label.starts_with("I-") {
454                ("I", label[2..].to_string())
455            } else {
456                ("B", label.clone())
457            };
458
459            let entity_type = self
460                .label_to_entity_type
461                .get(&format!("B-{}", entity_label))
462                .or_else(|| self.label_to_entity_type.get(&entity_label))
463                .cloned()
464                .unwrap_or_else(|| EntityType::Other(entity_label.clone()));
465
466            match bio {
467                "B" => {
468                    // Check if this B- tag should merge with previous entity
469                    // This handles subword tokenization where "Biden" becomes ["B", "##iden"]
470                    // and both tokens get B-PER labels
471                    let should_merge = if let Some((_, prev_end, ref prev_type, _)) = current_entity
472                    {
473                        // Merge if: same type AND adjacent (no gap or only whitespace)
474                        std::mem::discriminant(prev_type) == std::mem::discriminant(&entity_type)
475                            && byte_start <= prev_end + 1 // Adjacent or overlapping
476                    } else {
477                        false
478                    };
479
480                    if should_merge {
481                        // Extend the current entity instead of starting new
482                        if let Some((start, _, prev_type, conf)) = current_entity.take() {
483                            current_entity = Some((start, byte_end, prev_type, conf));
484                        }
485                    } else {
486                        // Finalize previous entity and start new
487                        if let Some((start, end, prev_type, conf)) = current_entity.take() {
488                            if start < end && end <= text.len() {
489                                if let Some(entity_text) = text.get(start..end) {
490                                    let entity_text = entity_text.trim();
491                                    if !entity_text.is_empty() {
492                                        entities.push(Entity::new(
493                                            entity_text.to_string(),
494                                            prev_type,
495                                            span_converter.byte_to_char(start),
496                                            span_converter.byte_to_char(end),
497                                            conf,
498                                        ));
499                                    }
500                                }
501                            }
502                        }
503                        // Start new entity
504                        current_entity = Some((byte_start, byte_end, entity_type, confidence));
505                    }
506                }
507                "I" => {
508                    // Continue current entity if same type
509                    if let Some((start, _end, ref prev_type, conf)) = current_entity {
510                        if std::mem::discriminant(prev_type) == std::mem::discriminant(&entity_type)
511                        {
512                            current_entity = Some((start, byte_end, entity_type, conf));
513                        } else {
514                            // Different type - finalize and start new
515                            if start < _end && _end <= text.len() {
516                                if let Some(entity_text) = text.get(start.._end) {
517                                    let entity_text = entity_text.trim();
518                                    if !entity_text.is_empty() {
519                                        entities.push(Entity::new(
520                                            entity_text.to_string(),
521                                            prev_type.clone(),
522                                            span_converter.byte_to_char(start),
523                                            span_converter.byte_to_char(_end),
524                                            conf,
525                                        ));
526                                    }
527                                }
528                            }
529                            current_entity = Some((byte_start, byte_end, entity_type, confidence));
530                        }
531                    } else {
532                        // No current entity, treat I- as B-
533                        current_entity = Some((byte_start, byte_end, entity_type, confidence));
534                    }
535                }
536                _ => {}
537            }
538        }
539
540        // Finalize last entity
541        if let Some((start, end, entity_type, conf)) = current_entity {
542            if start < end && end <= text.len() {
543                if let Some(entity_text) = text.get(start..end) {
544                    let entity_text = entity_text.trim();
545                    if !entity_text.is_empty() {
546                        entities.push(Entity::new(
547                            entity_text.to_string(),
548                            entity_type,
549                            span_converter.byte_to_char(start),
550                            span_converter.byte_to_char(end),
551                            conf,
552                        ));
553                    }
554                }
555            }
556        }
557
558        Ok(entities)
559    }
560
    /// Get the model name.
    ///
    /// Returns the HuggingFace identifier passed to [`Self::new`] /
    /// [`Self::with_config`].
    pub fn model_name(&self) -> &str {
        &self.model_name
    }
565}
566
#[cfg(feature = "onnx")]
impl crate::Model for BertNEROnnx {
    /// Delegates to the inherent [`BertNEROnnx::extract_entities`].
    fn extract_entities(&self, text: &str, language: Option<&str>) -> Result<Vec<Entity>> {
        BertNEROnnx::extract_entities(self, text, language)
    }

    /// CoNLL-03 label set exposed by the default model.
    fn supported_types(&self) -> Vec<EntityType> {
        vec![
            EntityType::Person,
            EntityType::Organization,
            EntityType::Location,
            EntityType::Other("MISC".to_string()),
        ]
    }

    /// Always true once constructed (construction loads the model).
    fn is_available(&self) -> bool {
        true
    }

    fn name(&self) -> &'static str {
        "bert-onnx"
    }

    fn description(&self) -> &'static str {
        "BERT-based NER using ONNX Runtime (PER/ORG/LOC/MISC)"
    }

    /// Version string encodes the model id and precision ("q" vs "fp32").
    fn version(&self) -> String {
        let precision = if self.is_quantized { "q" } else { "fp32" };
        format!("bert-onnx-{}-{}", self.model_name, precision)
    }

    fn capabilities(&self) -> crate::ModelCapabilities {
        crate::ModelCapabilities {
            batch_capable: true,
            streaming_capable: true,
            ..Default::default()
        }
    }
}
610
// NOTE(review): intentionally not cfg-gated, so this applies to both the real
// (feature = "onnx") and the stub `BertNEROnnx` — confirm that is desired.
impl crate::NamedEntityCapable for BertNEROnnx {}
612
613// =============================================================================
614// BatchCapable Trait Implementation
615// =============================================================================
616
#[cfg(feature = "onnx")]
impl crate::BatchCapable for BertNEROnnx {
    /// Suggested batch size for callers that batch inputs.
    ///
    /// NOTE(review): 8 appears to be a CPU-friendly heuristic; nothing in this
    /// file benchmarks it — confirm against profiling before relying on it.
    fn optimal_batch_size(&self) -> Option<usize> {
        Some(8)
    }
}
623
624// =============================================================================
625// StreamingCapable Trait Implementation
626// =============================================================================
627
#[cfg(feature = "onnx")]
impl crate::StreamingCapable for BertNEROnnx {
    /// Chunk size (in tokens) for streaming callers.
    fn recommended_chunk_size(&self) -> usize {
        512 // BERT context window
    }
}
634
// Stub implementation when feature is disabled
/// Zero-sized stand-in for `BertNEROnnx` when the `onnx` feature is off;
/// construction and extraction always return an error.
#[cfg(not(feature = "onnx"))]
pub struct BertNEROnnx;
638
639#[cfg(not(feature = "onnx"))]
640impl BertNEROnnx {
641    pub fn new(_model_name: &str) -> Result<Self> {
642        Err(Error::Parse(
643            "BERT NER ONNX support requires 'onnx' feature".to_string(),
644        ))
645    }
646
647    pub fn extract_entities(&self, _text: &str, _language: Option<&str>) -> Result<Vec<Entity>> {
648        Err(Error::Parse(
649            "BERT NER ONNX support requires 'onnx' feature".to_string(),
650        ))
651    }
652
653    pub fn model_name(&self) -> &str {
654        "onnx-not-enabled"
655    }
656}
657
658#[cfg(not(feature = "onnx"))]
659impl crate::Model for BertNEROnnx {
660    fn extract_entities(&self, _text: &str, _language: Option<&str>) -> Result<Vec<Entity>> {
661        Err(Error::Parse(
662            "BERT NER ONNX support requires 'onnx' feature".to_string(),
663        ))
664    }
665
666    fn supported_types(&self) -> Vec<anno_core::EntityType> {
667        vec![]
668    }
669
670    fn is_available(&self) -> bool {
671        false
672    }
673}