Skip to main content

redact_ner/
recognizer.rs

1// Copyright 2026 Censgate LLC.
2// Licensed under the Apache License, Version 2.0. See the LICENSE file
3// in the project root for license information.
4
5use anyhow::{anyhow, Result};
6use ort::session::builder::GraphOptimizationLevel;
7use ort::session::Session;
8use ort::value::Value;
9use redact_core::{EntityType, Recognizer, RecognizerResult};
10use serde::{Deserialize, Serialize};
11use std::collections::HashMap;
12use std::path::Path;
13use std::sync::Mutex;
14use tracing::{debug, info, warn};
15
16use crate::tokenizer_wrapper::TokenizerWrapper;
17
/// Configuration for the NER recognizer.
///
/// Serde notes: `model_path` carries no serde default and must be present when
/// deserializing. `label_mappings` and `id2label` use the field-level
/// `#[serde(default)]`, so a missing key deserializes to an *empty* map rather
/// than the CoNLL-2003 tables built by `NerConfig::default()`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct NerConfig {
    /// Path to the ONNX model file. An empty string means "no model configured";
    /// `NerRecognizer::from_config` then skips model loading.
    pub model_path: String,

    /// Path to the tokenizer file. When `None`, `NerRecognizer::from_config`
    /// looks for `tokenizer.json` in the directory containing `model_path`.
    /// Omitted from serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tokenizer_path: Option<String>,

    /// Minimum average token probability an entity span needs in order to be
    /// reported. Defaults to 0.7 when absent from the serialized form.
    #[serde(default = "default_confidence")]
    pub min_confidence: f32,

    /// Sequence length handed to the tokenizer's `pad_to_length` before
    /// inference. Defaults to 512 when absent from the serialized form.
    #[serde(default = "default_max_length")]
    pub max_seq_length: usize,

    /// Maps NER label strings (e.g. `"B-PER"`) to the entity type they denote.
    /// Labels without an entry here never produce results.
    #[serde(default)]
    pub label_mappings: HashMap<String, EntityType>,

    /// Maps model output class indices to label strings (e.g. `1 -> "B-PER"`).
    /// Indices without an entry are treated as the outside tag `"O"`.
    #[serde(default)]
    pub id2label: HashMap<usize, String>,
}
44
/// Serde/default fallback for `NerConfig::min_confidence`.
fn default_confidence() -> f32 {
    const DEFAULT_MIN_CONFIDENCE: f32 = 0.7;
    DEFAULT_MIN_CONFIDENCE
}
48
/// Serde/default fallback for `NerConfig::max_seq_length`.
fn default_max_length() -> usize {
    const DEFAULT_MAX_SEQ_LENGTH: usize = 512;
    DEFAULT_MAX_SEQ_LENGTH
}
52
53impl Default for NerConfig {
54    fn default() -> Self {
55        let mut label_mappings = HashMap::new();
56        let mut id2label = HashMap::new();
57
58        // Default BIO tagging scheme mappings
59        label_mappings.insert("B-PER".to_string(), EntityType::Person);
60        label_mappings.insert("I-PER".to_string(), EntityType::Person);
61        label_mappings.insert("B-ORG".to_string(), EntityType::Organization);
62        label_mappings.insert("I-ORG".to_string(), EntityType::Organization);
63        label_mappings.insert("B-LOC".to_string(), EntityType::Location);
64        label_mappings.insert("I-LOC".to_string(), EntityType::Location);
65        label_mappings.insert("B-DATE".to_string(), EntityType::DateTime);
66        label_mappings.insert("I-DATE".to_string(), EntityType::DateTime);
67        label_mappings.insert("B-TIME".to_string(), EntityType::DateTime);
68        label_mappings.insert("I-TIME".to_string(), EntityType::DateTime);
69
70        // Default id2label for CoNLL-2003 style models
71        id2label.insert(0, "O".to_string());
72        id2label.insert(1, "B-PER".to_string());
73        id2label.insert(2, "I-PER".to_string());
74        id2label.insert(3, "B-ORG".to_string());
75        id2label.insert(4, "I-ORG".to_string());
76        id2label.insert(5, "B-LOC".to_string());
77        id2label.insert(6, "I-LOC".to_string());
78        id2label.insert(7, "B-MISC".to_string());
79        id2label.insert(8, "I-MISC".to_string());
80
81        Self {
82            model_path: String::new(),
83            tokenizer_path: None,
84            min_confidence: default_confidence(),
85            max_seq_length: default_max_length(),
86            label_mappings,
87            id2label,
88        }
89    }
90}
91
/// NER-based recognizer using ONNX Runtime
///
/// **Status**: ✅ Fully operational with complete ONNX Runtime integration
///
/// This recognizer uses transformer-based Named Entity Recognition models for contextual
/// PII detection. It automatically loads and runs ONNX models with:
/// - Tokenization with HuggingFace tokenizers
/// - ONNX Runtime inference with optimizations
/// - BIO tag parsing for entity span extraction
/// - Thread-safe session management (the session is kept behind a `Mutex`)
///
/// **To enable NER**:
/// 1. Export your NER model to ONNX format using `scripts/export_ner_model.py`
/// 2. Set `model_path` to point to your `.onnx` file
/// 3. Optionally provide `tokenizer_path` or place `tokenizer.json` in the same directory
///
/// Without a model, this recognizer gracefully returns empty results and the system
/// falls back to pattern-based detection (36+ entity types).
pub struct NerRecognizer {
    /// Configuration this recognizer was built from.
    config: NerConfig,
    /// Loaded tokenizer, or `None` when no tokenizer file could be loaded.
    tokenizer: Option<TokenizerWrapper>,
    /// Loaded ONNX session, or `None` when no model could be loaded.
    /// Wrapped in a `Mutex` because inference locks it for mutable access.
    session: Option<Mutex<Session>>,
    /// Whether the ONNX model accepts `token_type_ids` as an input.
    /// BERT-family models require it; DistilBERT and others do not.
    /// Determined at model-load time by inspecting `Session::inputs()`;
    /// `false` when no session is loaded.
    needs_token_type_ids: bool,
}
119
120impl std::fmt::Debug for NerRecognizer {
121    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
122        f.debug_struct("NerRecognizer")
123            .field("config", &self.config)
124            .field("tokenizer", &self.tokenizer)
125            .field("session", &self.session.as_ref().map(|_| "Session"))
126            .field("needs_token_type_ids", &self.needs_token_type_ids)
127            .finish()
128    }
129}
130
131impl NerRecognizer {
132    /// Create a new NER recognizer from a model file.
133    ///
134    /// Automatically loads `config.json` from the model directory (if present)
135    /// to get the correct `id2label` and `label_mappings` for the exported model.
136    /// Falls back to default CoNLL-2003 mappings when no config is found.
137    pub fn from_file<P: AsRef<Path>>(model_path: P) -> Result<Self> {
138        let model_path_ref = model_path.as_ref();
139        let model_path_str = model_path_ref.to_string_lossy().to_string();
140
141        // Try loading config.json from model directory (written by export_ner_model.py)
142        let config = if let Some(model_dir) = model_path_ref.parent() {
143            let config_path = model_dir.join("config.json");
144            if config_path.exists() {
145                debug!("Loading NER config from: {}", config_path.display());
146                match Self::load_config_from_file(&config_path, &model_path_str) {
147                    Ok(cfg) => cfg,
148                    Err(e) => {
149                        warn!("Failed to load NER config.json: {}. Using defaults.", e);
150                        NerConfig {
151                            model_path: model_path_str,
152                            ..Default::default()
153                        }
154                    }
155                }
156            } else {
157                debug!("No config.json in model directory, using default label mappings");
158                NerConfig {
159                    model_path: model_path_str,
160                    ..Default::default()
161                }
162            }
163        } else {
164            NerConfig {
165                model_path: model_path_str,
166                ..Default::default()
167            }
168        };
169
170        Self::from_config(config)
171    }
172
173    /// Load NER config from a JSON file produced by `export_ner_model.py`.
174    ///
175    /// Handles format differences between the Python export (string keys, PascalCase
176    /// entity names) and Rust types (usize keys, SCREAMING_SNAKE_CASE EntityType).
177    fn load_config_from_file(config_path: &Path, model_path: &str) -> Result<NerConfig> {
178        let json_str = std::fs::read_to_string(config_path)?;
179        let raw: serde_json::Value = serde_json::from_str(&json_str)?;
180
181        let defaults = NerConfig::default();
182
183        // Parse id2label: JSON has string keys like {"0": "O", "1": "B-MISC", ...}
184        let id2label = if let Some(obj) = raw.get("id2label").and_then(|v| v.as_object()) {
185            let mut map = HashMap::new();
186            for (k, v) in obj {
187                if let (Ok(id), Some(label)) = (k.parse::<usize>(), v.as_str()) {
188                    map.insert(id, label.to_string());
189                }
190            }
191            map
192        } else {
193            defaults.id2label.clone()
194        };
195
196        // Parse label_mappings: JSON has {"B-PER": "Person", ...}
197        // EntityType::from() handles case-insensitive conversion
198        let label_mappings =
199            if let Some(obj) = raw.get("label_mappings").and_then(|v| v.as_object()) {
200                let mut map = HashMap::new();
201                for (k, v) in obj {
202                    if let Some(entity_str) = v.as_str() {
203                        map.insert(k.clone(), EntityType::from(entity_str.to_string()));
204                    }
205                }
206                map
207            } else {
208                // Build label_mappings purely from id2label (no stale defaults).
209                let mut map = HashMap::new();
210                for label in id2label.values() {
211                    if label == "O" {
212                        continue;
213                    }
214                    let entity_type = label.split('-').next_back().unwrap_or(label);
215                    match entity_type {
216                        "PER" | "PERSON" => {
217                            map.insert(label.clone(), EntityType::Person);
218                        }
219                        "ORG" | "ORGANIZATION" => {
220                            map.insert(label.clone(), EntityType::Organization);
221                        }
222                        "LOC" | "LOCATION" | "GPE" => {
223                            map.insert(label.clone(), EntityType::Location);
224                        }
225                        "DATE" | "TIME" | "DATETIME" => {
226                            map.insert(label.clone(), EntityType::DateTime);
227                        }
228                        _ => {
229                            debug!("Unmapped NER label: {} — no EntityType match", label);
230                        }
231                    }
232                }
233                map
234            };
235
236        let min_confidence = raw
237            .get("min_confidence")
238            .and_then(|v| v.as_f64())
239            .map(|v| v as f32)
240            .unwrap_or(defaults.min_confidence);
241
242        let max_seq_length = raw
243            .get("max_seq_length")
244            .and_then(|v| v.as_u64())
245            .map(|v| v as usize)
246            .unwrap_or(defaults.max_seq_length);
247
248        // Intentionally ignore tokenizer_path from config.json: the export script
249        // writes a build-time path (e.g. /out/models/tokenizer.json) that won't exist
250        // at runtime. from_config() auto-discovers tokenizer.json from the model directory.
251        let tokenizer_path = None;
252
253        info!(
254            "Loaded NER config from {} ({} label mappings, {} id2label entries)",
255            config_path.display(),
256            label_mappings.len(),
257            id2label.len()
258        );
259
260        Ok(NerConfig {
261            model_path: model_path.to_string(),
262            tokenizer_path,
263            min_confidence,
264            max_seq_length,
265            label_mappings,
266            id2label,
267        })
268    }
269
270    /// Create a new NER recognizer from configuration
271    pub fn from_config(config: NerConfig) -> Result<Self> {
272        // Try to load tokenizer if available
273        let tokenizer = if let Some(ref tokenizer_path) = config.tokenizer_path {
274            debug!("Loading tokenizer from: {}", tokenizer_path);
275            match TokenizerWrapper::from_file(tokenizer_path) {
276                Ok(t) => {
277                    info!("✓ Tokenizer loaded successfully from: {}", tokenizer_path);
278                    Some(t)
279                }
280                Err(e) => {
281                    warn!(
282                        "Failed to load tokenizer: {}. NER will not be available.",
283                        e
284                    );
285                    None
286                }
287            }
288        } else if !config.model_path.is_empty() {
289            // Try to find tokenizer in same directory as model
290            let model_dir = Path::new(&config.model_path).parent();
291            if let Some(dir) = model_dir {
292                let tokenizer_json = dir.join("tokenizer.json");
293                if tokenizer_json.exists() {
294                    debug!("Loading tokenizer from: {}", tokenizer_json.display());
295                    match TokenizerWrapper::from_file(&tokenizer_json) {
296                        Ok(t) => {
297                            info!("✓ Tokenizer loaded successfully from model directory");
298                            Some(t)
299                        }
300                        Err(e) => {
301                            warn!("Failed to load tokenizer from model directory: {}", e);
302                            None
303                        }
304                    }
305                } else {
306                    debug!("No tokenizer.json found in model directory");
307                    None
308                }
309            } else {
310                None
311            }
312        } else {
313            None
314        };
315
316        // Try to load ONNX model if path is provided
317        let session = if !config.model_path.is_empty() {
318            let model_path = Path::new(&config.model_path);
319            if model_path.exists() {
320                debug!("Loading ONNX model from: {}", config.model_path);
321                match Session::builder()?
322                    .with_optimization_level(GraphOptimizationLevel::Level3)
323                    .map_err(|e| anyhow::anyhow!("{e}"))?
324                    .with_intra_threads(4)
325                    .map_err(|e| anyhow::anyhow!("{e}"))?
326                    .commit_from_file(&config.model_path)
327                {
328                    Ok(s) => {
329                        info!("✓ ONNX model loaded successfully: {}", config.model_path);
330                        Some(Mutex::new(s))
331                    }
332                    Err(e) => {
333                        warn!(
334                            "Failed to load ONNX model: {}. NER will not be available.",
335                            e
336                        );
337                        None
338                    }
339                }
340            } else {
341                debug!(
342                    "Model path provided but file does not exist: {}",
343                    config.model_path
344                );
345                None
346            }
347        } else {
348            debug!("No model path provided, NER will not be available");
349            None
350        };
351
352        // Inspect model inputs at construction time to determine whether the
353        // model expects token_type_ids (BERT-family) or not (DistilBERT, etc.).
354        let needs_token_type_ids = session.as_ref().is_some_and(|s| {
355            let guard = s.lock().expect("session lock poisoned during init");
356            let has_it = guard
357                .inputs()
358                .iter()
359                .any(|input| input.name() == "token_type_ids");
360            if has_it {
361                debug!("Model declares token_type_ids input — will include in inference");
362            } else {
363                debug!("Model does not declare token_type_ids — omitting from inference");
364            }
365            has_it
366        });
367
368        let is_available = tokenizer.is_some() && session.is_some();
369        if is_available {
370            info!("✓ NER is fully operational with ONNX Runtime");
371        } else {
372            info!("⚠ NER not available - using pattern-based detection (36+ entity types)");
373            if tokenizer.is_none() {
374                debug!("  Missing: tokenizer");
375            }
376            if session.is_none() {
377                debug!("  Missing: ONNX model");
378            }
379        }
380
381        Ok(Self {
382            config,
383            tokenizer,
384            session,
385            needs_token_type_ids,
386        })
387    }
388
389    /// Get the configuration
390    pub fn config(&self) -> &NerConfig {
391        &self.config
392    }
393
394    /// Check if NER is available (model and tokenizer loaded)
395    pub fn is_available(&self) -> bool {
396        self.tokenizer.is_some() && self.session.is_some()
397    }
398
399    /// Map NER label to entity type
400    fn map_label_to_entity(&self, label: &str) -> Option<EntityType> {
401        self.config.label_mappings.get(label).cloned()
402    }
403
404    /// Run inference on tokenized input
405    fn infer(&self, input_ids: &[u32], attention_mask: &[u32]) -> Result<Vec<Vec<f32>>> {
406        let session_mutex = self
407            .session
408            .as_ref()
409            .ok_or_else(|| anyhow!("ONNX session not loaded"))?;
410
411        let mut session = session_mutex
412            .lock()
413            .map_err(|e| anyhow!("Failed to lock session: {}", e))?;
414
415        // Create 2D arrays with shape [1, seq_len]
416        let seq_len = input_ids.len();
417        let input_ids_i64: Vec<i64> = input_ids.iter().map(|&x| x as i64).collect();
418        let attention_mask_i64: Vec<i64> = attention_mask.iter().map(|&x| x as i64).collect();
419
420        let input_ids_value = Value::from_array(([1, seq_len], input_ids_i64))?;
421        let attention_mask_value = Value::from_array(([1, seq_len], attention_mask_i64))?;
422
423        // Build inputs list — only include token_type_ids when the model expects it
424        // (BERT-family needs it; DistilBERT and others do not).
425        let mut inputs: Vec<(std::borrow::Cow<'_, str>, Value)> = vec![
426            ("input_ids".into(), input_ids_value.into()),
427            ("attention_mask".into(), attention_mask_value.into()),
428        ];
429
430        if self.needs_token_type_ids {
431            let token_type_ids_i64: Vec<i64> = vec![0i64; seq_len];
432            let token_type_ids_value = Value::from_array(([1, seq_len], token_type_ids_i64))?;
433            inputs.push(("token_type_ids".into(), token_type_ids_value.into()));
434        }
435
436        let outputs = session.run(inputs)?;
437
438        // Extract logits - shape should be [1, seq_len, num_labels]
439        let (shape, logits_data) = outputs["logits"].try_extract_tensor::<f32>()?;
440        let shape_dims: &[i64] = shape.as_ref();
441
442        if shape_dims.len() != 3 || shape_dims[0] != 1 {
443            return Err(anyhow!("Unexpected logits shape: {:?}", shape_dims));
444        }
445
446        let seq_len_out = shape_dims[1] as usize;
447        let num_labels = shape_dims[2] as usize;
448
449        // Convert to Vec<Vec<f32>> where outer vec is tokens, inner vec is label scores
450        let mut result = Vec::new();
451        for i in 0..seq_len_out {
452            let mut token_logits = Vec::new();
453            for j in 0..num_labels {
454                let idx = i * num_labels + j;
455                token_logits.push(logits_data[idx]);
456            }
457            result.push(token_logits);
458        }
459
460        Ok(result)
461    }
462
463    /// Apply softmax to convert logits to probabilities
464    fn softmax(logits: &[f32]) -> Vec<f32> {
465        let max_logit = logits.iter().copied().fold(f32::NEG_INFINITY, f32::max);
466        let exp_sum: f32 = logits.iter().map(|&x| (x - max_logit).exp()).sum();
467        logits
468            .iter()
469            .map(|&x| (x - max_logit).exp() / exp_sum)
470            .collect()
471    }
472
473    /// Parse BIO tags and extract entity spans
474    fn parse_bio_tags(
475        &self,
476        _text: &str,
477        predictions: &[usize],
478        probabilities: &[f32],
479        offsets: &[(usize, usize)],
480    ) -> Vec<RecognizerResult> {
481        let mut results = Vec::new();
482        let mut current_entity: Option<(EntityType, usize, usize, Vec<f32>)> = None;
483
484        for (idx, (&pred_id, &prob)) in predictions.iter().zip(probabilities.iter()).enumerate() {
485            // Skip padding tokens (offset (0,0))
486            if offsets[idx] == (0, 0) {
487                continue;
488            }
489
490            let label = self
491                .config
492                .id2label
493                .get(&pred_id)
494                .map(|s| s.as_str())
495                .unwrap_or("O");
496
497            if label.starts_with("B-") {
498                // Begin new entity - save previous if exists
499                if let Some((entity_type, start, end, probs)) = current_entity.take() {
500                    let avg_confidence = probs.iter().sum::<f32>() / probs.len() as f32;
501                    if avg_confidence >= self.config.min_confidence {
502                        results.push(RecognizerResult::new(
503                            entity_type,
504                            start,
505                            end,
506                            avg_confidence,
507                            self.name(),
508                        ));
509                    }
510                }
511
512                // Start new entity
513                if let Some(entity_type) = self.map_label_to_entity(label) {
514                    let start = offsets[idx].0;
515                    let end = offsets[idx].1;
516                    current_entity = Some((entity_type, start, end, vec![prob]));
517                }
518            } else if label.starts_with("I-") {
519                // Continue current entity
520                if let Some((ref entity_type, start, ref mut end, ref mut probs)) = current_entity {
521                    // Check if label matches current entity type
522                    if let Some(label_entity) = self.map_label_to_entity(label) {
523                        if label_entity == *entity_type {
524                            *end = offsets[idx].1;
525                            probs.push(prob);
526                        } else {
527                            // Different entity type - save current and start new
528                            let avg_confidence = probs.iter().sum::<f32>() / probs.len() as f32;
529                            if avg_confidence >= self.config.min_confidence {
530                                results.push(RecognizerResult::new(
531                                    entity_type.clone(),
532                                    start,
533                                    *end,
534                                    avg_confidence,
535                                    self.name(),
536                                ));
537                            }
538                            current_entity = None;
539                        }
540                    }
541                }
542            } else {
543                // "O" tag or unknown - end current entity
544                if let Some((entity_type, start, end, probs)) = current_entity.take() {
545                    let avg_confidence = probs.iter().sum::<f32>() / probs.len() as f32;
546                    if avg_confidence >= self.config.min_confidence {
547                        results.push(RecognizerResult::new(
548                            entity_type,
549                            start,
550                            end,
551                            avg_confidence,
552                            self.name(),
553                        ));
554                    }
555                }
556            }
557        }
558
559        // Don't forget the last entity
560        if let Some((entity_type, start, end, probs)) = current_entity {
561            let avg_confidence = probs.iter().sum::<f32>() / probs.len() as f32;
562            if avg_confidence >= self.config.min_confidence {
563                results.push(RecognizerResult::new(
564                    entity_type,
565                    start,
566                    end,
567                    avg_confidence,
568                    self.name(),
569                ));
570            }
571        }
572
573        results
574    }
575}
576
577impl Recognizer for NerRecognizer {
578    fn name(&self) -> &str {
579        "NerRecognizer"
580    }
581
582    fn supported_entities(&self) -> &[EntityType] {
583        &[
584            EntityType::Person,
585            EntityType::Organization,
586            EntityType::Location,
587            EntityType::DateTime,
588        ]
589    }
590
591    fn analyze(&self, text: &str, _language: &str) -> Result<Vec<RecognizerResult>> {
592        // Check if NER is available
593        if !self.is_available() {
594            return Ok(vec![]);
595        }
596
597        let tokenizer = self.tokenizer.as_ref().unwrap();
598
599        // Tokenize input
600        let mut encoding = tokenizer.encode(text, true)?;
601
602        // Get padding token ID
603        let pad_id = tokenizer.get_padding_id().unwrap_or(0);
604
605        // Pad/truncate to max sequence length
606        encoding.pad_to_length(self.config.max_seq_length, pad_id);
607
608        // Run inference
609        let logits = self.infer(&encoding.ids, &encoding.attention_mask)?;
610
611        // Convert logits to predictions
612        let mut predictions = Vec::new();
613        let mut probabilities = Vec::new();
614
615        for token_logits in &logits {
616            let probs = Self::softmax(token_logits);
617            let (pred_id, &max_prob) = probs
618                .iter()
619                .enumerate()
620                .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
621                .unwrap();
622            predictions.push(pred_id);
623            probabilities.push(max_prob);
624        }
625
626        // Parse BIO tags to extract entities
627        let results = self.parse_bio_tags(text, &predictions, &probabilities, &encoding.offsets);
628
629        Ok(results)
630    }
631
632    fn supports_language(&self, language: &str) -> bool {
633        // Most multilingual NER models support these languages
634        matches!(
635            language,
636            "en" | "es" | "fr" | "de" | "it" | "pt" | "nl" | "pl" | "ru" | "zh" | "ja" | "ko"
637        )
638    }
639}
640
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Write;

    /// The default configuration carries the documented thresholds and a
    /// non-empty label table.
    #[test]
    fn test_default_config() {
        let config = NerConfig::default();
        assert_eq!(config.min_confidence, 0.7);
        assert_eq!(config.max_seq_length, 512);
        assert!(!config.label_mappings.is_empty());
    }

    /// Mapped labels resolve to their entity type; "O" has no mapping.
    #[test]
    fn test_label_mapping() {
        let config = NerConfig::default();
        let recognizer = NerRecognizer::from_config(config).unwrap();

        assert_eq!(
            recognizer.map_label_to_entity("B-PER"),
            Some(EntityType::Person)
        );
        assert_eq!(
            recognizer.map_label_to_entity("B-ORG"),
            Some(EntityType::Organization)
        );
        assert_eq!(recognizer.map_label_to_entity("O"), None);
    }

    /// With an empty model path the recognizer is constructed but inert.
    #[test]
    fn test_recognizer_without_model() {
        let config = NerConfig::default();
        let recognizer = NerRecognizer::from_config(config).unwrap();

        // Should not be available without model
        assert!(!recognizer.is_available());

        // Should return empty results
        let results = recognizer.analyze("John Doe", "en").unwrap();
        assert_eq!(results.len(), 0);
    }

    #[test]
    fn test_recognizer_without_model_has_no_token_type_ids() {
        let config = NerConfig::default();
        let recognizer = NerRecognizer::from_config(config).unwrap();

        // No session loaded → flag defaults to false
        assert!(!recognizer.needs_token_type_ids);
    }

    // ---- load_config_from_file tests ----

    /// Helper: write `contents` to a fresh temp file and return the open
    /// `NamedTempFile` handle. Callers read it back through `.path()` and must
    /// keep the handle bound for as long as they need the file (tempfile
    /// removes the file when the handle is dropped).
    fn write_temp_config(contents: &str) -> tempfile::NamedTempFile {
        let mut f = tempfile::NamedTempFile::new().unwrap();
        f.write_all(contents.as_bytes()).unwrap();
        f.flush().unwrap();
        f
    }

    /// A config with every supported key is parsed field by field, except
    /// `tokenizer_path`, which the loader deliberately discards.
    #[test]
    fn test_load_config_valid_with_both_id2label_and_label_mappings() {
        let json = r#"{
            "id2label": {
                "0": "O",
                "1": "B-MISC",
                "2": "I-MISC",
                "3": "B-PER",
                "4": "I-PER",
                "5": "B-ORG",
                "6": "I-ORG",
                "7": "B-LOC",
                "8": "I-LOC"
            },
            "label_mappings": {
                "B-PER": "Person",
                "I-PER": "Person",
                "B-ORG": "Organization",
                "I-ORG": "Organization",
                "B-LOC": "Location",
                "I-LOC": "Location"
            },
            "min_confidence": 0.8,
            "max_seq_length": 256,
            "tokenizer_path": "/build/time/tokenizer.json"
        }"#;

        let f = write_temp_config(json);
        let cfg = NerRecognizer::load_config_from_file(f.path(), "/runtime/model.onnx").unwrap();

        // id2label parsed correctly
        assert_eq!(cfg.id2label.len(), 9);
        assert_eq!(cfg.id2label[&3], "B-PER");
        assert_eq!(cfg.id2label[&5], "B-ORG");

        // label_mappings parsed correctly (PascalCase → EntityType)
        assert_eq!(cfg.label_mappings.len(), 6);
        assert_eq!(cfg.label_mappings["B-PER"], EntityType::Person);
        assert_eq!(cfg.label_mappings["B-ORG"], EntityType::Organization);
        assert_eq!(cfg.label_mappings["B-LOC"], EntityType::Location);

        // Scalars honoured
        assert_eq!(cfg.min_confidence, 0.8);
        assert_eq!(cfg.max_seq_length, 256);

        // model_path overridden to runtime value
        assert_eq!(cfg.model_path, "/runtime/model.onnx");

        // tokenizer_path always suppressed regardless of config.json content
        assert!(cfg.tokenizer_path.is_none());
    }

    #[test]
    fn test_load_config_fallback_derives_label_mappings_from_id2label() {
        // config.json has id2label but no label_mappings → derived path
        let json = r#"{
            "id2label": {
                "0": "O",
                "1": "B-MISC",
                "2": "I-MISC",
                "3": "B-PER",
                "4": "I-PER",
                "5": "B-ORG",
                "6": "I-ORG",
                "7": "B-LOC",
                "8": "I-LOC"
            }
        }"#;

        let f = write_temp_config(json);
        let cfg = NerRecognizer::load_config_from_file(f.path(), "/m.onnx").unwrap();

        // Derived mappings should include PER, ORG, LOC but NOT MISC or stale defaults
        assert_eq!(cfg.label_mappings.get("B-PER"), Some(&EntityType::Person));
        assert_eq!(cfg.label_mappings.get("I-PER"), Some(&EntityType::Person));
        assert_eq!(
            cfg.label_mappings.get("B-ORG"),
            Some(&EntityType::Organization)
        );
        assert_eq!(cfg.label_mappings.get("B-LOC"), Some(&EntityType::Location));

        // MISC labels should NOT appear (no EntityType mapping exists)
        assert!(cfg.label_mappings.get("B-MISC").is_none());
        assert!(cfg.label_mappings.get("I-MISC").is_none());

        // No stale defaults: B-DATE / I-DATE should NOT leak in because
        // they are not present in the provided id2label
        assert!(cfg.label_mappings.get("B-DATE").is_none());
        assert!(cfg.label_mappings.get("I-DATE").is_none());
    }

    #[test]
    fn test_load_config_tokenizer_path_always_none() {
        // Even when config.json explicitly sets tokenizer_path, the loader
        // must suppress it (build-time path is stale at runtime).
        let json = r#"{
            "tokenizer_path": "/out/models/tokenizer.json",
            "id2label": { "0": "O", "1": "B-PER" }
        }"#;

        let f = write_temp_config(json);
        let cfg = NerRecognizer::load_config_from_file(f.path(), "/m.onnx").unwrap();
        assert!(cfg.tokenizer_path.is_none());
    }

    /// Invalid JSON surfaces as an `Err` from the loader itself.
    #[test]
    fn test_load_config_malformed_json_returns_err() {
        let f = write_temp_config("{ this is not valid json }}}");
        let result = NerRecognizer::load_config_from_file(f.path(), "/m.onnx");
        assert!(result.is_err());
    }

    #[test]
    fn test_load_config_empty_json_uses_defaults() {
        // An empty JSON object should fall back to defaults for every field
        let f = write_temp_config("{}");
        let cfg = NerRecognizer::load_config_from_file(f.path(), "/m.onnx").unwrap();

        let defaults = NerConfig::default();
        assert_eq!(cfg.min_confidence, defaults.min_confidence);
        assert_eq!(cfg.max_seq_length, defaults.max_seq_length);
        // id2label falls back to defaults (no "id2label" key in JSON)
        assert_eq!(cfg.id2label.len(), defaults.id2label.len());
    }
}
826}