Skip to main content

redact_ner/
recognizer.rs

1// Copyright (c) 2026 Censgate LLC.
2// Licensed under the Business Source License 1.1 (BUSL-1.1).
3// See the LICENSE file in the project root for license details,
4// including the Additional Use Grant, Change Date, and Change License.
5
6use anyhow::{anyhow, Result};
7use ort::session::builder::GraphOptimizationLevel;
8use ort::session::Session;
9use ort::value::Value;
10use redact_core::{EntityType, Recognizer, RecognizerResult};
11use serde::{Deserialize, Serialize};
12use std::collections::HashMap;
13use std::path::Path;
14use std::sync::Mutex;
15use tracing::{debug, info, warn};
16
17use crate::tokenizer_wrapper::TokenizerWrapper;
18
/// Configuration for NER recognizer
///
/// Can be built programmatically, deserialized with serde, or loaded from a
/// `config.json` sitting next to the model via `NerRecognizer::from_file`.
/// Fields marked `#[serde(default)]` may be omitted from serialized input.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct NerConfig {
    /// Path to ONNX model file
    ///
    /// An empty string means "no model": the recognizer is still constructed
    /// but reports itself as unavailable.
    pub model_path: String,

    /// Path to tokenizer file (optional - will use model_path directory)
    ///
    /// When `None`, `tokenizer.json` is looked up in the directory that
    /// contains `model_path`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tokenizer_path: Option<String>,

    /// Minimum confidence threshold
    ///
    /// Compared against the average per-token probability of an entity span;
    /// spans scoring below it are dropped. Defaults to 0.7.
    #[serde(default = "default_confidence")]
    pub min_confidence: f32,

    /// Maximum sequence length
    ///
    /// Passed to the tokenizer encoding's `pad_to_length` before inference.
    /// Defaults to 512.
    #[serde(default = "default_max_length")]
    pub max_seq_length: usize,

    /// Entity type mappings from NER labels
    ///
    /// Keys are full BIO labels (e.g. `"B-PER"`, `"I-PER"`). Labels without an
    /// entry never start or extend an entity.
    #[serde(default)]
    pub label_mappings: HashMap<String, EntityType>,

    /// Label IDs to label strings mapping
    ///
    /// Keys are the model's output class indices. Predicted ids that are not
    /// present in this map are treated as the outside tag `"O"`.
    #[serde(default)]
    pub id2label: HashMap<usize, String>,
}
45
/// Serde default for `NerConfig::min_confidence`; also used by `NerConfig::default()`.
fn default_confidence() -> f32 {
    const DEFAULT_MIN_CONFIDENCE: f32 = 0.7;
    DEFAULT_MIN_CONFIDENCE
}
49
/// Serde default for `NerConfig::max_seq_length`; also used by `NerConfig::default()`.
fn default_max_length() -> usize {
    const DEFAULT_MAX_SEQ_LENGTH: usize = 512;
    DEFAULT_MAX_SEQ_LENGTH
}
53
54impl Default for NerConfig {
55    fn default() -> Self {
56        let mut label_mappings = HashMap::new();
57        let mut id2label = HashMap::new();
58
59        // Default BIO tagging scheme mappings
60        label_mappings.insert("B-PER".to_string(), EntityType::Person);
61        label_mappings.insert("I-PER".to_string(), EntityType::Person);
62        label_mappings.insert("B-ORG".to_string(), EntityType::Organization);
63        label_mappings.insert("I-ORG".to_string(), EntityType::Organization);
64        label_mappings.insert("B-LOC".to_string(), EntityType::Location);
65        label_mappings.insert("I-LOC".to_string(), EntityType::Location);
66        label_mappings.insert("B-DATE".to_string(), EntityType::DateTime);
67        label_mappings.insert("I-DATE".to_string(), EntityType::DateTime);
68        label_mappings.insert("B-TIME".to_string(), EntityType::DateTime);
69        label_mappings.insert("I-TIME".to_string(), EntityType::DateTime);
70
71        // Default id2label for CoNLL-2003 style models
72        id2label.insert(0, "O".to_string());
73        id2label.insert(1, "B-PER".to_string());
74        id2label.insert(2, "I-PER".to_string());
75        id2label.insert(3, "B-ORG".to_string());
76        id2label.insert(4, "I-ORG".to_string());
77        id2label.insert(5, "B-LOC".to_string());
78        id2label.insert(6, "I-LOC".to_string());
79        id2label.insert(7, "B-MISC".to_string());
80        id2label.insert(8, "I-MISC".to_string());
81
82        Self {
83            model_path: String::new(),
84            tokenizer_path: None,
85            min_confidence: default_confidence(),
86            max_seq_length: default_max_length(),
87            label_mappings,
88            id2label,
89        }
90    }
91}
92
/// NER-based recognizer using ONNX Runtime
///
/// **Status**: ✅ Fully operational with complete ONNX Runtime integration
///
/// This recognizer uses transformer-based Named Entity Recognition models for contextual
/// PII detection. It automatically loads and runs ONNX models with:
/// - Tokenization with HuggingFace tokenizers
/// - ONNX Runtime inference with optimizations
/// - BIO tag parsing for entity span extraction
/// - Thread-safe session management
///
/// **To enable NER**:
/// 1. Export your NER model to ONNX format using `scripts/export_ner_model.py`
/// 2. Set `model_path` to point to your `.onnx` file
/// 3. Optionally provide `tokenizer_path` or place `tokenizer.json` in the same directory
///
/// Without a model, this recognizer gracefully returns empty results and the system
/// falls back to pattern-based detection (36+ entity types).
pub struct NerRecognizer {
    /// Configuration this recognizer was built from (label tables, thresholds, paths).
    config: NerConfig,
    /// Loaded tokenizer, or `None` when no tokenizer file could be found or loaded.
    tokenizer: Option<TokenizerWrapper>,
    /// Loaded ONNX session, or `None` when no model could be loaded.
    /// Kept behind a `Mutex` because inference takes the session mutably
    /// while the recognizer itself is only borrowed immutably.
    session: Option<Mutex<Session>>,
    /// Whether the ONNX model accepts `token_type_ids` as an input.
    /// BERT-family models require it; DistilBERT and others do not.
    /// Determined at model-load time by inspecting `Session::inputs()`.
    needs_token_type_ids: bool,
}
120
121impl std::fmt::Debug for NerRecognizer {
122    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
123        f.debug_struct("NerRecognizer")
124            .field("config", &self.config)
125            .field("tokenizer", &self.tokenizer)
126            .field("session", &self.session.as_ref().map(|_| "Session"))
127            .field("needs_token_type_ids", &self.needs_token_type_ids)
128            .finish()
129    }
130}
131
132impl NerRecognizer {
133    /// Create a new NER recognizer from a model file.
134    ///
135    /// Automatically loads `config.json` from the model directory (if present)
136    /// to get the correct `id2label` and `label_mappings` for the exported model.
137    /// Falls back to default CoNLL-2003 mappings when no config is found.
138    pub fn from_file<P: AsRef<Path>>(model_path: P) -> Result<Self> {
139        let model_path_ref = model_path.as_ref();
140        let model_path_str = model_path_ref.to_string_lossy().to_string();
141
142        // Try loading config.json from model directory (written by export_ner_model.py)
143        let config = if let Some(model_dir) = model_path_ref.parent() {
144            let config_path = model_dir.join("config.json");
145            if config_path.exists() {
146                debug!("Loading NER config from: {}", config_path.display());
147                match Self::load_config_from_file(&config_path, &model_path_str) {
148                    Ok(cfg) => cfg,
149                    Err(e) => {
150                        warn!("Failed to load NER config.json: {}. Using defaults.", e);
151                        NerConfig {
152                            model_path: model_path_str,
153                            ..Default::default()
154                        }
155                    }
156                }
157            } else {
158                debug!("No config.json in model directory, using default label mappings");
159                NerConfig {
160                    model_path: model_path_str,
161                    ..Default::default()
162                }
163            }
164        } else {
165            NerConfig {
166                model_path: model_path_str,
167                ..Default::default()
168            }
169        };
170
171        Self::from_config(config)
172    }
173
174    /// Load NER config from a JSON file produced by `export_ner_model.py`.
175    ///
176    /// Handles format differences between the Python export (string keys, PascalCase
177    /// entity names) and Rust types (usize keys, SCREAMING_SNAKE_CASE EntityType).
178    fn load_config_from_file(config_path: &Path, model_path: &str) -> Result<NerConfig> {
179        let json_str = std::fs::read_to_string(config_path)?;
180        let raw: serde_json::Value = serde_json::from_str(&json_str)?;
181
182        let defaults = NerConfig::default();
183
184        // Parse id2label: JSON has string keys like {"0": "O", "1": "B-MISC", ...}
185        let id2label = if let Some(obj) = raw.get("id2label").and_then(|v| v.as_object()) {
186            let mut map = HashMap::new();
187            for (k, v) in obj {
188                if let (Ok(id), Some(label)) = (k.parse::<usize>(), v.as_str()) {
189                    map.insert(id, label.to_string());
190                }
191            }
192            map
193        } else {
194            defaults.id2label.clone()
195        };
196
197        // Parse label_mappings: JSON has {"B-PER": "Person", ...}
198        // EntityType::from() handles case-insensitive conversion
199        let label_mappings =
200            if let Some(obj) = raw.get("label_mappings").and_then(|v| v.as_object()) {
201                let mut map = HashMap::new();
202                for (k, v) in obj {
203                    if let Some(entity_str) = v.as_str() {
204                        map.insert(k.clone(), EntityType::from(entity_str.to_string()));
205                    }
206                }
207                map
208            } else {
209                // Build label_mappings purely from id2label (no stale defaults).
210                let mut map = HashMap::new();
211                for label in id2label.values() {
212                    if label == "O" {
213                        continue;
214                    }
215                    let entity_type = label.split('-').next_back().unwrap_or(label);
216                    match entity_type {
217                        "PER" | "PERSON" => {
218                            map.insert(label.clone(), EntityType::Person);
219                        }
220                        "ORG" | "ORGANIZATION" => {
221                            map.insert(label.clone(), EntityType::Organization);
222                        }
223                        "LOC" | "LOCATION" | "GPE" => {
224                            map.insert(label.clone(), EntityType::Location);
225                        }
226                        "DATE" | "TIME" | "DATETIME" => {
227                            map.insert(label.clone(), EntityType::DateTime);
228                        }
229                        _ => {
230                            debug!("Unmapped NER label: {} — no EntityType match", label);
231                        }
232                    }
233                }
234                map
235            };
236
237        let min_confidence = raw
238            .get("min_confidence")
239            .and_then(|v| v.as_f64())
240            .map(|v| v as f32)
241            .unwrap_or(defaults.min_confidence);
242
243        let max_seq_length = raw
244            .get("max_seq_length")
245            .and_then(|v| v.as_u64())
246            .map(|v| v as usize)
247            .unwrap_or(defaults.max_seq_length);
248
249        // Intentionally ignore tokenizer_path from config.json: the export script
250        // writes a build-time path (e.g. /out/models/tokenizer.json) that won't exist
251        // at runtime. from_config() auto-discovers tokenizer.json from the model directory.
252        let tokenizer_path = None;
253
254        info!(
255            "Loaded NER config from {} ({} label mappings, {} id2label entries)",
256            config_path.display(),
257            label_mappings.len(),
258            id2label.len()
259        );
260
261        Ok(NerConfig {
262            model_path: model_path.to_string(),
263            tokenizer_path,
264            min_confidence,
265            max_seq_length,
266            label_mappings,
267            id2label,
268        })
269    }
270
    /// Create a new NER recognizer from configuration
    ///
    /// Construction is best-effort: a missing or unloadable tokenizer or model
    /// is logged and leaves the corresponding field as `None`, in which case
    /// `is_available()` reports `false`.
    ///
    /// # Errors
    /// Only failures while configuring the ONNX session builder (creating the
    /// builder, setting the optimization level or the thread count) are
    /// propagated as errors; a failure to load the model file itself is not.
    pub fn from_config(config: NerConfig) -> Result<Self> {
        // Try to load tokenizer if available
        // Resolution order: explicit `tokenizer_path`, then `tokenizer.json`
        // next to the model file, otherwise no tokenizer.
        let tokenizer = if let Some(ref tokenizer_path) = config.tokenizer_path {
            debug!("Loading tokenizer from: {}", tokenizer_path);
            match TokenizerWrapper::from_file(tokenizer_path) {
                Ok(t) => {
                    info!("✓ Tokenizer loaded successfully from: {}", tokenizer_path);
                    Some(t)
                }
                Err(e) => {
                    warn!(
                        "Failed to load tokenizer: {}. NER will not be available.",
                        e
                    );
                    None
                }
            }
        } else if !config.model_path.is_empty() {
            // Try to find tokenizer in same directory as model
            let model_dir = Path::new(&config.model_path).parent();
            if let Some(dir) = model_dir {
                let tokenizer_json = dir.join("tokenizer.json");
                if tokenizer_json.exists() {
                    debug!("Loading tokenizer from: {}", tokenizer_json.display());
                    match TokenizerWrapper::from_file(&tokenizer_json) {
                        Ok(t) => {
                            info!("✓ Tokenizer loaded successfully from model directory");
                            Some(t)
                        }
                        Err(e) => {
                            warn!("Failed to load tokenizer from model directory: {}", e);
                            None
                        }
                    }
                } else {
                    debug!("No tokenizer.json found in model directory");
                    None
                }
            } else {
                None
            }
        } else {
            None
        };

        // Try to load ONNX model if path is provided
        let session = if !config.model_path.is_empty() {
            let model_path = Path::new(&config.model_path);
            if model_path.exists() {
                debug!("Loading ONNX model from: {}", config.model_path);
                // Builder configuration errors are propagated with `?`; only the
                // final `commit_from_file` result is matched and downgraded to a warning.
                match Session::builder()?
                    .with_optimization_level(GraphOptimizationLevel::Level3)
                    .map_err(|e| anyhow::anyhow!("{e}"))?
                    .with_intra_threads(4)
                    .map_err(|e| anyhow::anyhow!("{e}"))?
                    .commit_from_file(&config.model_path)
                {
                    Ok(s) => {
                        info!("✓ ONNX model loaded successfully: {}", config.model_path);
                        Some(Mutex::new(s))
                    }
                    Err(e) => {
                        warn!(
                            "Failed to load ONNX model: {}. NER will not be available.",
                            e
                        );
                        None
                    }
                }
            } else {
                debug!(
                    "Model path provided but file does not exist: {}",
                    config.model_path
                );
                None
            }
        } else {
            debug!("No model path provided, NER will not be available");
            None
        };

        // Inspect model inputs at construction time to determine whether the
        // model expects token_type_ids (BERT-family) or not (DistilBERT, etc.).
        // Evaluates to `false` when no session was loaded.
        let needs_token_type_ids = session.as_ref().is_some_and(|s| {
            // The mutex was created just above and has never been locked.
            let guard = s.lock().expect("session lock poisoned during init");
            let has_it = guard
                .inputs()
                .iter()
                .any(|input| input.name() == "token_type_ids");
            if has_it {
                debug!("Model declares token_type_ids input — will include in inference");
            } else {
                debug!("Model does not declare token_type_ids — omitting from inference");
            }
            has_it
        });

        // Report overall availability; both tokenizer and session are required.
        let is_available = tokenizer.is_some() && session.is_some();
        if is_available {
            info!("✓ NER is fully operational with ONNX Runtime");
        } else {
            info!("⚠ NER not available - using pattern-based detection (36+ entity types)");
            if tokenizer.is_none() {
                debug!("  Missing: tokenizer");
            }
            if session.is_none() {
                debug!("  Missing: ONNX model");
            }
        }

        Ok(Self {
            config,
            tokenizer,
            session,
            needs_token_type_ids,
        })
    }
389
    /// Get the configuration
    ///
    /// Returns a shared reference to the configuration this recognizer was
    /// constructed with.
    pub fn config(&self) -> &NerConfig {
        &self.config
    }
394
395    /// Check if NER is available (model and tokenizer loaded)
396    pub fn is_available(&self) -> bool {
397        self.tokenizer.is_some() && self.session.is_some()
398    }
399
400    /// Map NER label to entity type
401    fn map_label_to_entity(&self, label: &str) -> Option<EntityType> {
402        self.config.label_mappings.get(label).cloned()
403    }
404
405    /// Run inference on tokenized input
406    fn infer(&self, input_ids: &[u32], attention_mask: &[u32]) -> Result<Vec<Vec<f32>>> {
407        let session_mutex = self
408            .session
409            .as_ref()
410            .ok_or_else(|| anyhow!("ONNX session not loaded"))?;
411
412        let mut session = session_mutex
413            .lock()
414            .map_err(|e| anyhow!("Failed to lock session: {}", e))?;
415
416        // Create 2D arrays with shape [1, seq_len]
417        let seq_len = input_ids.len();
418        let input_ids_i64: Vec<i64> = input_ids.iter().map(|&x| x as i64).collect();
419        let attention_mask_i64: Vec<i64> = attention_mask.iter().map(|&x| x as i64).collect();
420
421        let input_ids_value = Value::from_array(([1, seq_len], input_ids_i64))?;
422        let attention_mask_value = Value::from_array(([1, seq_len], attention_mask_i64))?;
423
424        // Build inputs list — only include token_type_ids when the model expects it
425        // (BERT-family needs it; DistilBERT and others do not).
426        let mut inputs: Vec<(std::borrow::Cow<'_, str>, Value)> = vec![
427            ("input_ids".into(), input_ids_value.into()),
428            ("attention_mask".into(), attention_mask_value.into()),
429        ];
430
431        if self.needs_token_type_ids {
432            let token_type_ids_i64: Vec<i64> = vec![0i64; seq_len];
433            let token_type_ids_value = Value::from_array(([1, seq_len], token_type_ids_i64))?;
434            inputs.push(("token_type_ids".into(), token_type_ids_value.into()));
435        }
436
437        let outputs = session.run(inputs)?;
438
439        // Extract logits - shape should be [1, seq_len, num_labels]
440        let (shape, logits_data) = outputs["logits"].try_extract_tensor::<f32>()?;
441        let shape_dims: &[i64] = shape.as_ref();
442
443        if shape_dims.len() != 3 || shape_dims[0] != 1 {
444            return Err(anyhow!("Unexpected logits shape: {:?}", shape_dims));
445        }
446
447        let seq_len_out = shape_dims[1] as usize;
448        let num_labels = shape_dims[2] as usize;
449
450        // Convert to Vec<Vec<f32>> where outer vec is tokens, inner vec is label scores
451        let mut result = Vec::new();
452        for i in 0..seq_len_out {
453            let mut token_logits = Vec::new();
454            for j in 0..num_labels {
455                let idx = i * num_labels + j;
456                token_logits.push(logits_data[idx]);
457            }
458            result.push(token_logits);
459        }
460
461        Ok(result)
462    }
463
464    /// Apply softmax to convert logits to probabilities
465    fn softmax(logits: &[f32]) -> Vec<f32> {
466        let max_logit = logits.iter().copied().fold(f32::NEG_INFINITY, f32::max);
467        let exp_sum: f32 = logits.iter().map(|&x| (x - max_logit).exp()).sum();
468        logits
469            .iter()
470            .map(|&x| (x - max_logit).exp() / exp_sum)
471            .collect()
472    }
473
474    /// Parse BIO tags and extract entity spans
475    fn parse_bio_tags(
476        &self,
477        _text: &str,
478        predictions: &[usize],
479        probabilities: &[f32],
480        offsets: &[(usize, usize)],
481    ) -> Vec<RecognizerResult> {
482        let mut results = Vec::new();
483        let mut current_entity: Option<(EntityType, usize, usize, Vec<f32>)> = None;
484
485        for (idx, (&pred_id, &prob)) in predictions.iter().zip(probabilities.iter()).enumerate() {
486            // Skip padding tokens (offset (0,0))
487            if offsets[idx] == (0, 0) {
488                continue;
489            }
490
491            let label = self
492                .config
493                .id2label
494                .get(&pred_id)
495                .map(|s| s.as_str())
496                .unwrap_or("O");
497
498            if label.starts_with("B-") {
499                // Begin new entity - save previous if exists
500                if let Some((entity_type, start, end, probs)) = current_entity.take() {
501                    let avg_confidence = probs.iter().sum::<f32>() / probs.len() as f32;
502                    if avg_confidence >= self.config.min_confidence {
503                        results.push(RecognizerResult::new(
504                            entity_type,
505                            start,
506                            end,
507                            avg_confidence,
508                            self.name(),
509                        ));
510                    }
511                }
512
513                // Start new entity
514                if let Some(entity_type) = self.map_label_to_entity(label) {
515                    let start = offsets[idx].0;
516                    let end = offsets[idx].1;
517                    current_entity = Some((entity_type, start, end, vec![prob]));
518                }
519            } else if label.starts_with("I-") {
520                // Continue current entity
521                if let Some((ref entity_type, start, ref mut end, ref mut probs)) = current_entity {
522                    // Check if label matches current entity type
523                    if let Some(label_entity) = self.map_label_to_entity(label) {
524                        if label_entity == *entity_type {
525                            *end = offsets[idx].1;
526                            probs.push(prob);
527                        } else {
528                            // Different entity type - save current and start new
529                            let avg_confidence = probs.iter().sum::<f32>() / probs.len() as f32;
530                            if avg_confidence >= self.config.min_confidence {
531                                results.push(RecognizerResult::new(
532                                    entity_type.clone(),
533                                    start,
534                                    *end,
535                                    avg_confidence,
536                                    self.name(),
537                                ));
538                            }
539                            current_entity = None;
540                        }
541                    }
542                }
543            } else {
544                // "O" tag or unknown - end current entity
545                if let Some((entity_type, start, end, probs)) = current_entity.take() {
546                    let avg_confidence = probs.iter().sum::<f32>() / probs.len() as f32;
547                    if avg_confidence >= self.config.min_confidence {
548                        results.push(RecognizerResult::new(
549                            entity_type,
550                            start,
551                            end,
552                            avg_confidence,
553                            self.name(),
554                        ));
555                    }
556                }
557            }
558        }
559
560        // Don't forget the last entity
561        if let Some((entity_type, start, end, probs)) = current_entity {
562            let avg_confidence = probs.iter().sum::<f32>() / probs.len() as f32;
563            if avg_confidence >= self.config.min_confidence {
564                results.push(RecognizerResult::new(
565                    entity_type,
566                    start,
567                    end,
568                    avg_confidence,
569                    self.name(),
570                ));
571            }
572        }
573
574        results
575    }
576}
577
578impl Recognizer for NerRecognizer {
579    fn name(&self) -> &str {
580        "NerRecognizer"
581    }
582
583    fn supported_entities(&self) -> &[EntityType] {
584        &[
585            EntityType::Person,
586            EntityType::Organization,
587            EntityType::Location,
588            EntityType::DateTime,
589        ]
590    }
591
592    fn analyze(&self, text: &str, _language: &str) -> Result<Vec<RecognizerResult>> {
593        // Check if NER is available
594        if !self.is_available() {
595            return Ok(vec![]);
596        }
597
598        let tokenizer = self.tokenizer.as_ref().unwrap();
599
600        // Tokenize input
601        let mut encoding = tokenizer.encode(text, true)?;
602
603        // Get padding token ID
604        let pad_id = tokenizer.get_padding_id().unwrap_or(0);
605
606        // Pad/truncate to max sequence length
607        encoding.pad_to_length(self.config.max_seq_length, pad_id);
608
609        // Run inference
610        let logits = self.infer(&encoding.ids, &encoding.attention_mask)?;
611
612        // Convert logits to predictions
613        let mut predictions = Vec::new();
614        let mut probabilities = Vec::new();
615
616        for token_logits in &logits {
617            let probs = Self::softmax(token_logits);
618            let (pred_id, &max_prob) = probs
619                .iter()
620                .enumerate()
621                .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
622                .unwrap();
623            predictions.push(pred_id);
624            probabilities.push(max_prob);
625        }
626
627        // Parse BIO tags to extract entities
628        let results = self.parse_bio_tags(text, &predictions, &probabilities, &encoding.offsets);
629
630        Ok(results)
631    }
632
633    fn supports_language(&self, language: &str) -> bool {
634        // Most multilingual NER models support these languages
635        matches!(
636            language,
637            "en" | "es" | "fr" | "de" | "it" | "pt" | "nl" | "pl" | "ru" | "zh" | "ja" | "ko"
638        )
639    }
640}
641
// Unit tests for NerConfig defaults, recognizer construction without a model,
// and parsing of config.json via load_config_from_file.
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Write;

    // Default config exposes the documented threshold, length and a non-empty label table.
    #[test]
    fn test_default_config() {
        let config = NerConfig::default();
        assert_eq!(config.min_confidence, 0.7);
        assert_eq!(config.max_seq_length, 512);
        assert!(!config.label_mappings.is_empty());
    }

    // Label lookup uses the default mappings; "O" has no entity type.
    #[test]
    fn test_label_mapping() {
        let config = NerConfig::default();
        let recognizer = NerRecognizer::from_config(config).unwrap();

        assert_eq!(
            recognizer.map_label_to_entity("B-PER"),
            Some(EntityType::Person)
        );
        assert_eq!(
            recognizer.map_label_to_entity("B-ORG"),
            Some(EntityType::Organization)
        );
        assert_eq!(recognizer.map_label_to_entity("O"), None);
    }

    // With an empty model path the recognizer is constructed but inert.
    #[test]
    fn test_recognizer_without_model() {
        let config = NerConfig::default();
        let recognizer = NerRecognizer::from_config(config).unwrap();

        // Should not be available without model
        assert!(!recognizer.is_available());

        // Should return empty results
        let results = recognizer.analyze("John Doe", "en").unwrap();
        assert_eq!(results.len(), 0);
    }

    #[test]
    fn test_recognizer_without_model_has_no_token_type_ids() {
        let config = NerConfig::default();
        let recognizer = NerRecognizer::from_config(config).unwrap();

        // No session loaded → flag defaults to false
        assert!(!recognizer.needs_token_type_ids);
    }

    // ---- load_config_from_file tests ----

    /// Helper: write `contents` to a temp file and return its path.
    /// The returned handle must be kept alive for as long as the path is used.
    fn write_temp_config(contents: &str) -> tempfile::NamedTempFile {
        let mut f = tempfile::NamedTempFile::new().unwrap();
        f.write_all(contents.as_bytes()).unwrap();
        f.flush().unwrap();
        f
    }

    // Full config: every supported key is present and must be honoured,
    // except tokenizer_path which is always dropped.
    #[test]
    fn test_load_config_valid_with_both_id2label_and_label_mappings() {
        let json = r#"{
            "id2label": {
                "0": "O",
                "1": "B-MISC",
                "2": "I-MISC",
                "3": "B-PER",
                "4": "I-PER",
                "5": "B-ORG",
                "6": "I-ORG",
                "7": "B-LOC",
                "8": "I-LOC"
            },
            "label_mappings": {
                "B-PER": "Person",
                "I-PER": "Person",
                "B-ORG": "Organization",
                "I-ORG": "Organization",
                "B-LOC": "Location",
                "I-LOC": "Location"
            },
            "min_confidence": 0.8,
            "max_seq_length": 256,
            "tokenizer_path": "/build/time/tokenizer.json"
        }"#;

        let f = write_temp_config(json);
        let cfg = NerRecognizer::load_config_from_file(f.path(), "/runtime/model.onnx").unwrap();

        // id2label parsed correctly
        assert_eq!(cfg.id2label.len(), 9);
        assert_eq!(cfg.id2label[&3], "B-PER");
        assert_eq!(cfg.id2label[&5], "B-ORG");

        // label_mappings parsed correctly (PascalCase → EntityType)
        assert_eq!(cfg.label_mappings.len(), 6);
        assert_eq!(cfg.label_mappings["B-PER"], EntityType::Person);
        assert_eq!(cfg.label_mappings["B-ORG"], EntityType::Organization);
        assert_eq!(cfg.label_mappings["B-LOC"], EntityType::Location);

        // Scalars honoured
        assert_eq!(cfg.min_confidence, 0.8);
        assert_eq!(cfg.max_seq_length, 256);

        // model_path overridden to runtime value
        assert_eq!(cfg.model_path, "/runtime/model.onnx");

        // tokenizer_path always suppressed regardless of config.json content
        assert!(cfg.tokenizer_path.is_none());
    }

    #[test]
    fn test_load_config_fallback_derives_label_mappings_from_id2label() {
        // config.json has id2label but no label_mappings → derived path
        let json = r#"{
            "id2label": {
                "0": "O",
                "1": "B-MISC",
                "2": "I-MISC",
                "3": "B-PER",
                "4": "I-PER",
                "5": "B-ORG",
                "6": "I-ORG",
                "7": "B-LOC",
                "8": "I-LOC"
            }
        }"#;

        let f = write_temp_config(json);
        let cfg = NerRecognizer::load_config_from_file(f.path(), "/m.onnx").unwrap();

        // Derived mappings should include PER, ORG, LOC but NOT MISC or stale defaults
        assert_eq!(cfg.label_mappings.get("B-PER"), Some(&EntityType::Person));
        assert_eq!(cfg.label_mappings.get("I-PER"), Some(&EntityType::Person));
        assert_eq!(
            cfg.label_mappings.get("B-ORG"),
            Some(&EntityType::Organization)
        );
        assert_eq!(cfg.label_mappings.get("B-LOC"), Some(&EntityType::Location));

        // MISC labels should NOT appear (no EntityType mapping exists)
        assert!(cfg.label_mappings.get("B-MISC").is_none());
        assert!(cfg.label_mappings.get("I-MISC").is_none());

        // No stale defaults: B-DATE / I-DATE should NOT leak in because
        // they are not present in the provided id2label
        assert!(cfg.label_mappings.get("B-DATE").is_none());
        assert!(cfg.label_mappings.get("I-DATE").is_none());
    }

    #[test]
    fn test_load_config_tokenizer_path_always_none() {
        // Even when config.json explicitly sets tokenizer_path, the loader
        // must suppress it (build-time path is stale at runtime).
        let json = r#"{
            "tokenizer_path": "/out/models/tokenizer.json",
            "id2label": { "0": "O", "1": "B-PER" }
        }"#;

        let f = write_temp_config(json);
        let cfg = NerRecognizer::load_config_from_file(f.path(), "/m.onnx").unwrap();
        assert!(cfg.tokenizer_path.is_none());
    }

    // Unparseable JSON must surface as an error rather than a default config.
    #[test]
    fn test_load_config_malformed_json_returns_err() {
        let f = write_temp_config("{ this is not valid json }}}");
        let result = NerRecognizer::load_config_from_file(f.path(), "/m.onnx");
        assert!(result.is_err());
    }

    #[test]
    fn test_load_config_empty_json_uses_defaults() {
        // An empty JSON object should fall back to defaults for the scalar
        // fields and id2label (label_mappings is not checked here)
        let f = write_temp_config("{}");
        let cfg = NerRecognizer::load_config_from_file(f.path(), "/m.onnx").unwrap();

        let defaults = NerConfig::default();
        assert_eq!(cfg.min_confidence, defaults.min_confidence);
        assert_eq!(cfg.max_seq_length, defaults.max_seq_length);
        // id2label falls back to defaults (no "id2label" key in JSON)
        assert_eq!(cfg.id2label.len(), defaults.id2label.len());
    }
}