Skip to main content

redact_ner/
recognizer.rs

1// Copyright (c) 2026 Censgate LLC.
2// Licensed under the Business Source License 1.1 (BUSL-1.1).
3// See the LICENSE file in the project root for license details,
4// including the Additional Use Grant, Change Date, and Change License.
5
6use anyhow::{anyhow, Result};
7use ort::session::builder::GraphOptimizationLevel;
8use ort::session::Session;
9use ort::value::Value;
10use redact_core::{EntityType, Recognizer, RecognizerResult};
11use serde::{Deserialize, Serialize};
12use std::collections::HashMap;
13use std::path::Path;
14use std::sync::Mutex;
15use tracing::{debug, info, warn};
16
17use crate::tokenizer_wrapper::TokenizerWrapper;
18
19/// Configuration for NER recognizer
20#[derive(Debug, Clone, Serialize, Deserialize)]
21pub struct NerConfig {
22    /// Path to ONNX model file
23    pub model_path: String,
24
25    /// Path to tokenizer file (optional - will use model_path directory)
26    #[serde(skip_serializing_if = "Option::is_none")]
27    pub tokenizer_path: Option<String>,
28
29    /// Minimum confidence threshold
30    #[serde(default = "default_confidence")]
31    pub min_confidence: f32,
32
33    /// Maximum sequence length
34    #[serde(default = "default_max_length")]
35    pub max_seq_length: usize,
36
37    /// Entity type mappings from NER labels
38    #[serde(default)]
39    pub label_mappings: HashMap<String, EntityType>,
40
41    /// Label IDs to label strings mapping
42    #[serde(default)]
43    pub id2label: HashMap<usize, String>,
44}
45
/// Serde/`Default` fallback for `NerConfig::min_confidence`.
fn default_confidence() -> f32 {
    const MIN_CONFIDENCE: f32 = 0.7;
    MIN_CONFIDENCE
}

/// Serde/`Default` fallback for `NerConfig::max_seq_length` (in tokens).
fn default_max_length() -> usize {
    const MAX_SEQ_LENGTH: usize = 512;
    MAX_SEQ_LENGTH
}
53
54impl Default for NerConfig {
55    fn default() -> Self {
56        let mut label_mappings = HashMap::new();
57        let mut id2label = HashMap::new();
58
59        // Default BIO tagging scheme mappings
60        label_mappings.insert("B-PER".to_string(), EntityType::Person);
61        label_mappings.insert("I-PER".to_string(), EntityType::Person);
62        label_mappings.insert("B-ORG".to_string(), EntityType::Organization);
63        label_mappings.insert("I-ORG".to_string(), EntityType::Organization);
64        label_mappings.insert("B-LOC".to_string(), EntityType::Location);
65        label_mappings.insert("I-LOC".to_string(), EntityType::Location);
66        label_mappings.insert("B-DATE".to_string(), EntityType::DateTime);
67        label_mappings.insert("I-DATE".to_string(), EntityType::DateTime);
68        label_mappings.insert("B-TIME".to_string(), EntityType::DateTime);
69        label_mappings.insert("I-TIME".to_string(), EntityType::DateTime);
70
71        // Default id2label for CoNLL-2003 style models
72        id2label.insert(0, "O".to_string());
73        id2label.insert(1, "B-PER".to_string());
74        id2label.insert(2, "I-PER".to_string());
75        id2label.insert(3, "B-ORG".to_string());
76        id2label.insert(4, "I-ORG".to_string());
77        id2label.insert(5, "B-LOC".to_string());
78        id2label.insert(6, "I-LOC".to_string());
79        id2label.insert(7, "B-MISC".to_string());
80        id2label.insert(8, "I-MISC".to_string());
81
82        Self {
83            model_path: String::new(),
84            tokenizer_path: None,
85            min_confidence: default_confidence(),
86            max_seq_length: default_max_length(),
87            label_mappings,
88            id2label,
89        }
90    }
91}
92
93/// NER-based recognizer using ONNX Runtime
94///
95/// **Status**: ✅ Fully operational with complete ONNX Runtime integration
96///
97/// This recognizer uses transformer-based Named Entity Recognition models for contextual
98/// PII detection. It automatically loads and runs ONNX models with:
99/// - Tokenization with HuggingFace tokenizers
100/// - ONNX Runtime inference with optimizations
101/// - BIO tag parsing for entity span extraction
102/// - Thread-safe session management
103///
104/// **To enable NER**:
105/// 1. Export your NER model to ONNX format using `scripts/export_ner_model.py`
106/// 2. Set `model_path` to point to your `.onnx` file
107/// 3. Optionally provide `tokenizer_path` or place `tokenizer.json` in the same directory
108///
109/// Without a model, this recognizer gracefully returns empty results and the system
110/// falls back to pattern-based detection (36+ entity types).
111pub struct NerRecognizer {
112    config: NerConfig,
113    tokenizer: Option<TokenizerWrapper>,
114    session: Option<Mutex<Session>>,
115}
116
117impl std::fmt::Debug for NerRecognizer {
118    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
119        f.debug_struct("NerRecognizer")
120            .field("config", &self.config)
121            .field("tokenizer", &self.tokenizer)
122            .field("session", &self.session.as_ref().map(|_| "Session"))
123            .finish()
124    }
125}
126
127impl NerRecognizer {
128    /// Create a new NER recognizer from a model file
129    pub fn from_file<P: AsRef<Path>>(model_path: P) -> Result<Self> {
130        let config = NerConfig {
131            model_path: model_path.as_ref().to_string_lossy().to_string(),
132            ..Default::default()
133        };
134        Self::from_config(config)
135    }
136
137    /// Create a new NER recognizer from configuration
138    pub fn from_config(config: NerConfig) -> Result<Self> {
139        // Try to load tokenizer if available
140        let tokenizer = if let Some(ref tokenizer_path) = config.tokenizer_path {
141            debug!("Loading tokenizer from: {}", tokenizer_path);
142            match TokenizerWrapper::from_file(tokenizer_path) {
143                Ok(t) => {
144                    info!("✓ Tokenizer loaded successfully from: {}", tokenizer_path);
145                    Some(t)
146                }
147                Err(e) => {
148                    warn!(
149                        "Failed to load tokenizer: {}. NER will not be available.",
150                        e
151                    );
152                    None
153                }
154            }
155        } else if !config.model_path.is_empty() {
156            // Try to find tokenizer in same directory as model
157            let model_dir = Path::new(&config.model_path).parent();
158            if let Some(dir) = model_dir {
159                let tokenizer_json = dir.join("tokenizer.json");
160                if tokenizer_json.exists() {
161                    debug!("Loading tokenizer from: {}", tokenizer_json.display());
162                    match TokenizerWrapper::from_file(&tokenizer_json) {
163                        Ok(t) => {
164                            info!("✓ Tokenizer loaded successfully from model directory");
165                            Some(t)
166                        }
167                        Err(e) => {
168                            warn!("Failed to load tokenizer from model directory: {}", e);
169                            None
170                        }
171                    }
172                } else {
173                    debug!("No tokenizer.json found in model directory");
174                    None
175                }
176            } else {
177                None
178            }
179        } else {
180            None
181        };
182
183        // Try to load ONNX model if path is provided
184        let session = if !config.model_path.is_empty() {
185            let model_path = Path::new(&config.model_path);
186            if model_path.exists() {
187                debug!("Loading ONNX model from: {}", config.model_path);
188                match Session::builder()?
189                    .with_optimization_level(GraphOptimizationLevel::Level3)?
190                    .with_intra_threads(4)?
191                    .commit_from_file(&config.model_path)
192                {
193                    Ok(s) => {
194                        info!("✓ ONNX model loaded successfully: {}", config.model_path);
195                        Some(Mutex::new(s))
196                    }
197                    Err(e) => {
198                        warn!(
199                            "Failed to load ONNX model: {}. NER will not be available.",
200                            e
201                        );
202                        None
203                    }
204                }
205            } else {
206                debug!(
207                    "Model path provided but file does not exist: {}",
208                    config.model_path
209                );
210                None
211            }
212        } else {
213            debug!("No model path provided, NER will not be available");
214            None
215        };
216
217        let is_available = tokenizer.is_some() && session.is_some();
218        if is_available {
219            info!("✓ NER is fully operational with ONNX Runtime");
220        } else {
221            info!("⚠ NER not available - using pattern-based detection (36+ entity types)");
222            if tokenizer.is_none() {
223                debug!("  Missing: tokenizer");
224            }
225            if session.is_none() {
226                debug!("  Missing: ONNX model");
227            }
228        }
229
230        Ok(Self {
231            config,
232            tokenizer,
233            session,
234        })
235    }
236
237    /// Get the configuration
238    pub fn config(&self) -> &NerConfig {
239        &self.config
240    }
241
242    /// Check if NER is available (model and tokenizer loaded)
243    pub fn is_available(&self) -> bool {
244        self.tokenizer.is_some() && self.session.is_some()
245    }
246
247    /// Map NER label to entity type
248    fn map_label_to_entity(&self, label: &str) -> Option<EntityType> {
249        self.config.label_mappings.get(label).cloned()
250    }
251
252    /// Run inference on tokenized input
253    fn infer(&self, input_ids: &[u32], attention_mask: &[u32]) -> Result<Vec<Vec<f32>>> {
254        let session_mutex = self
255            .session
256            .as_ref()
257            .ok_or_else(|| anyhow!("ONNX session not loaded"))?;
258
259        let mut session = session_mutex
260            .lock()
261            .map_err(|e| anyhow!("Failed to lock session: {}", e))?;
262
263        // Create 2D arrays with shape [1, seq_len]
264        let seq_len = input_ids.len();
265        let input_ids_i64: Vec<i64> = input_ids.iter().map(|&x| x as i64).collect();
266        let attention_mask_i64: Vec<i64> = attention_mask.iter().map(|&x| x as i64).collect();
267
268        // Create Value objects using shape + data tuple approach
269        let input_ids_value = Value::from_array(([1, seq_len], input_ids_i64))?;
270        let attention_mask_value = Value::from_array(([1, seq_len], attention_mask_i64))?;
271
272        // Run inference
273        let outputs = session.run(ort::inputs![
274            "input_ids" => input_ids_value,
275            "attention_mask" => attention_mask_value,
276        ])?;
277
278        // Extract logits - shape should be [1, seq_len, num_labels]
279        let (shape, logits_data) = outputs["logits"].try_extract_tensor::<f32>()?;
280        let shape_dims = shape.as_ref();
281
282        if shape_dims.len() != 3 || shape_dims[0] != 1 {
283            return Err(anyhow!("Unexpected logits shape: {:?}", shape_dims));
284        }
285
286        let seq_len_out = shape_dims[1] as usize;
287        let num_labels = shape_dims[2] as usize;
288
289        // Convert to Vec<Vec<f32>> where outer vec is tokens, inner vec is label scores
290        let mut result = Vec::new();
291        for i in 0..seq_len_out {
292            let mut token_logits = Vec::new();
293            for j in 0..num_labels {
294                let idx = i * num_labels + j;
295                token_logits.push(logits_data[idx]);
296            }
297            result.push(token_logits);
298        }
299
300        Ok(result)
301    }
302
303    /// Apply softmax to convert logits to probabilities
304    fn softmax(logits: &[f32]) -> Vec<f32> {
305        let max_logit = logits.iter().copied().fold(f32::NEG_INFINITY, f32::max);
306        let exp_sum: f32 = logits.iter().map(|&x| (x - max_logit).exp()).sum();
307        logits
308            .iter()
309            .map(|&x| (x - max_logit).exp() / exp_sum)
310            .collect()
311    }
312
313    /// Parse BIO tags and extract entity spans
314    fn parse_bio_tags(
315        &self,
316        _text: &str,
317        predictions: &[usize],
318        probabilities: &[f32],
319        offsets: &[(usize, usize)],
320    ) -> Vec<RecognizerResult> {
321        let mut results = Vec::new();
322        let mut current_entity: Option<(EntityType, usize, usize, Vec<f32>)> = None;
323
324        for (idx, (&pred_id, &prob)) in predictions.iter().zip(probabilities.iter()).enumerate() {
325            // Skip padding tokens (offset (0,0))
326            if offsets[idx] == (0, 0) {
327                continue;
328            }
329
330            let label = self
331                .config
332                .id2label
333                .get(&pred_id)
334                .map(|s| s.as_str())
335                .unwrap_or("O");
336
337            if label.starts_with("B-") {
338                // Begin new entity - save previous if exists
339                if let Some((entity_type, start, end, probs)) = current_entity.take() {
340                    let avg_confidence = probs.iter().sum::<f32>() / probs.len() as f32;
341                    if avg_confidence >= self.config.min_confidence {
342                        results.push(RecognizerResult::new(
343                            entity_type,
344                            start,
345                            end,
346                            avg_confidence,
347                            self.name(),
348                        ));
349                    }
350                }
351
352                // Start new entity
353                if let Some(entity_type) = self.map_label_to_entity(label) {
354                    let start = offsets[idx].0;
355                    let end = offsets[idx].1;
356                    current_entity = Some((entity_type, start, end, vec![prob]));
357                }
358            } else if label.starts_with("I-") {
359                // Continue current entity
360                if let Some((ref entity_type, start, ref mut end, ref mut probs)) = current_entity {
361                    // Check if label matches current entity type
362                    if let Some(label_entity) = self.map_label_to_entity(label) {
363                        if label_entity == *entity_type {
364                            *end = offsets[idx].1;
365                            probs.push(prob);
366                        } else {
367                            // Different entity type - save current and start new
368                            let avg_confidence = probs.iter().sum::<f32>() / probs.len() as f32;
369                            if avg_confidence >= self.config.min_confidence {
370                                results.push(RecognizerResult::new(
371                                    entity_type.clone(),
372                                    start,
373                                    *end,
374                                    avg_confidence,
375                                    self.name(),
376                                ));
377                            }
378                            current_entity = None;
379                        }
380                    }
381                }
382            } else {
383                // "O" tag or unknown - end current entity
384                if let Some((entity_type, start, end, probs)) = current_entity.take() {
385                    let avg_confidence = probs.iter().sum::<f32>() / probs.len() as f32;
386                    if avg_confidence >= self.config.min_confidence {
387                        results.push(RecognizerResult::new(
388                            entity_type,
389                            start,
390                            end,
391                            avg_confidence,
392                            self.name(),
393                        ));
394                    }
395                }
396            }
397        }
398
399        // Don't forget the last entity
400        if let Some((entity_type, start, end, probs)) = current_entity {
401            let avg_confidence = probs.iter().sum::<f32>() / probs.len() as f32;
402            if avg_confidence >= self.config.min_confidence {
403                results.push(RecognizerResult::new(
404                    entity_type,
405                    start,
406                    end,
407                    avg_confidence,
408                    self.name(),
409                ));
410            }
411        }
412
413        results
414    }
415}
416
417impl Recognizer for NerRecognizer {
418    fn name(&self) -> &str {
419        "NerRecognizer"
420    }
421
422    fn supported_entities(&self) -> &[EntityType] {
423        &[
424            EntityType::Person,
425            EntityType::Organization,
426            EntityType::Location,
427            EntityType::DateTime,
428        ]
429    }
430
431    fn analyze(&self, text: &str, _language: &str) -> Result<Vec<RecognizerResult>> {
432        // Check if NER is available
433        if !self.is_available() {
434            return Ok(vec![]);
435        }
436
437        let tokenizer = self.tokenizer.as_ref().unwrap();
438
439        // Tokenize input
440        let mut encoding = tokenizer.encode(text, true)?;
441
442        // Get padding token ID
443        let pad_id = tokenizer.get_padding_id().unwrap_or(0);
444
445        // Pad/truncate to max sequence length
446        encoding.pad_to_length(self.config.max_seq_length, pad_id);
447
448        // Run inference
449        let logits = self.infer(&encoding.ids, &encoding.attention_mask)?;
450
451        // Convert logits to predictions
452        let mut predictions = Vec::new();
453        let mut probabilities = Vec::new();
454
455        for token_logits in &logits {
456            let probs = Self::softmax(token_logits);
457            let (pred_id, &max_prob) = probs
458                .iter()
459                .enumerate()
460                .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
461                .unwrap();
462            predictions.push(pred_id);
463            probabilities.push(max_prob);
464        }
465
466        // Parse BIO tags to extract entities
467        let results = self.parse_bio_tags(text, &predictions, &probabilities, &encoding.offsets);
468
469        Ok(results)
470    }
471
472    fn supports_language(&self, language: &str) -> bool {
473        // Most multilingual NER models support these languages
474        matches!(
475            language,
476            "en" | "es" | "fr" | "de" | "it" | "pt" | "nl" | "pl" | "ru" | "zh" | "ja" | "ko"
477        )
478    }
479}
480
#[cfg(test)]
mod tests {
    use super::*;

    /// Recognizer built from the default config. There is no model path, so
    /// neither a tokenizer nor an ONNX session is loaded.
    fn model_less_recognizer() -> NerRecognizer {
        NerRecognizer::from_config(NerConfig::default()).unwrap()
    }

    #[test]
    fn test_default_config() {
        let NerConfig {
            min_confidence,
            max_seq_length,
            label_mappings,
            ..
        } = NerConfig::default();

        assert_eq!(min_confidence, 0.7);
        assert_eq!(max_seq_length, 512);
        assert!(!label_mappings.is_empty());
    }

    #[test]
    fn test_label_mapping() {
        let recognizer = model_less_recognizer();

        let cases = [
            ("B-PER", Some(EntityType::Person)),
            ("B-ORG", Some(EntityType::Organization)),
            ("O", None),
        ];
        for (label, expected) in cases {
            assert_eq!(recognizer.map_label_to_entity(label), expected);
        }
    }

    #[test]
    fn test_recognizer_without_model() {
        let recognizer = model_less_recognizer();

        // With no model the recognizer is unavailable and finds nothing.
        assert!(!recognizer.is_available());
        let found = recognizer.analyze("John Doe", "en").unwrap();
        assert!(found.is_empty());
    }
}