1use anyhow::{anyhow, Result};
7use ort::session::builder::GraphOptimizationLevel;
8use ort::session::Session;
9use ort::value::Value;
10use redact_core::{EntityType, Recognizer, RecognizerResult};
11use serde::{Deserialize, Serialize};
12use std::collections::HashMap;
13use std::path::Path;
14use std::sync::Mutex;
15use tracing::{debug, info, warn};
16
17use crate::tokenizer_wrapper::TokenizerWrapper;
18
19#[derive(Debug, Clone, Serialize, Deserialize)]
21pub struct NerConfig {
22 pub model_path: String,
24
25 #[serde(skip_serializing_if = "Option::is_none")]
27 pub tokenizer_path: Option<String>,
28
29 #[serde(default = "default_confidence")]
31 pub min_confidence: f32,
32
33 #[serde(default = "default_max_length")]
35 pub max_seq_length: usize,
36
37 #[serde(default)]
39 pub label_mappings: HashMap<String, EntityType>,
40
41 #[serde(default)]
43 pub id2label: HashMap<usize, String>,
44}
45
/// Serde fallback for `NerConfig::min_confidence`.
fn default_confidence() -> f32 {
    const DEFAULT_MIN_CONFIDENCE: f32 = 0.7;
    DEFAULT_MIN_CONFIDENCE
}
49
/// Serde fallback for `NerConfig::max_seq_length`.
fn default_max_length() -> usize {
    const DEFAULT_MAX_SEQ_LENGTH: usize = 512;
    DEFAULT_MAX_SEQ_LENGTH
}
53
54impl Default for NerConfig {
55 fn default() -> Self {
56 let mut label_mappings = HashMap::new();
57 let mut id2label = HashMap::new();
58
59 label_mappings.insert("B-PER".to_string(), EntityType::Person);
61 label_mappings.insert("I-PER".to_string(), EntityType::Person);
62 label_mappings.insert("B-ORG".to_string(), EntityType::Organization);
63 label_mappings.insert("I-ORG".to_string(), EntityType::Organization);
64 label_mappings.insert("B-LOC".to_string(), EntityType::Location);
65 label_mappings.insert("I-LOC".to_string(), EntityType::Location);
66 label_mappings.insert("B-DATE".to_string(), EntityType::DateTime);
67 label_mappings.insert("I-DATE".to_string(), EntityType::DateTime);
68 label_mappings.insert("B-TIME".to_string(), EntityType::DateTime);
69 label_mappings.insert("I-TIME".to_string(), EntityType::DateTime);
70
71 id2label.insert(0, "O".to_string());
73 id2label.insert(1, "B-PER".to_string());
74 id2label.insert(2, "I-PER".to_string());
75 id2label.insert(3, "B-ORG".to_string());
76 id2label.insert(4, "I-ORG".to_string());
77 id2label.insert(5, "B-LOC".to_string());
78 id2label.insert(6, "I-LOC".to_string());
79 id2label.insert(7, "B-MISC".to_string());
80 id2label.insert(8, "I-MISC".to_string());
81
82 Self {
83 model_path: String::new(),
84 tokenizer_path: None,
85 min_confidence: default_confidence(),
86 max_seq_length: default_max_length(),
87 label_mappings,
88 id2label,
89 }
90 }
91}
92
93pub struct NerRecognizer {
112 config: NerConfig,
113 tokenizer: Option<TokenizerWrapper>,
114 session: Option<Mutex<Session>>,
115 needs_token_type_ids: bool,
119}
120
121impl std::fmt::Debug for NerRecognizer {
122 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
123 f.debug_struct("NerRecognizer")
124 .field("config", &self.config)
125 .field("tokenizer", &self.tokenizer)
126 .field("session", &self.session.as_ref().map(|_| "Session"))
127 .field("needs_token_type_ids", &self.needs_token_type_ids)
128 .finish()
129 }
130}
131
132impl NerRecognizer {
133 pub fn from_file<P: AsRef<Path>>(model_path: P) -> Result<Self> {
139 let model_path_ref = model_path.as_ref();
140 let model_path_str = model_path_ref.to_string_lossy().to_string();
141
142 let config = if let Some(model_dir) = model_path_ref.parent() {
144 let config_path = model_dir.join("config.json");
145 if config_path.exists() {
146 debug!("Loading NER config from: {}", config_path.display());
147 match Self::load_config_from_file(&config_path, &model_path_str) {
148 Ok(cfg) => cfg,
149 Err(e) => {
150 warn!("Failed to load NER config.json: {}. Using defaults.", e);
151 NerConfig {
152 model_path: model_path_str,
153 ..Default::default()
154 }
155 }
156 }
157 } else {
158 debug!("No config.json in model directory, using default label mappings");
159 NerConfig {
160 model_path: model_path_str,
161 ..Default::default()
162 }
163 }
164 } else {
165 NerConfig {
166 model_path: model_path_str,
167 ..Default::default()
168 }
169 };
170
171 Self::from_config(config)
172 }
173
174 fn load_config_from_file(config_path: &Path, model_path: &str) -> Result<NerConfig> {
179 let json_str = std::fs::read_to_string(config_path)?;
180 let raw: serde_json::Value = serde_json::from_str(&json_str)?;
181
182 let defaults = NerConfig::default();
183
184 let id2label = if let Some(obj) = raw.get("id2label").and_then(|v| v.as_object()) {
186 let mut map = HashMap::new();
187 for (k, v) in obj {
188 if let (Ok(id), Some(label)) = (k.parse::<usize>(), v.as_str()) {
189 map.insert(id, label.to_string());
190 }
191 }
192 map
193 } else {
194 defaults.id2label.clone()
195 };
196
197 let label_mappings =
200 if let Some(obj) = raw.get("label_mappings").and_then(|v| v.as_object()) {
201 let mut map = HashMap::new();
202 for (k, v) in obj {
203 if let Some(entity_str) = v.as_str() {
204 map.insert(k.clone(), EntityType::from(entity_str.to_string()));
205 }
206 }
207 map
208 } else {
209 let mut map = HashMap::new();
211 for label in id2label.values() {
212 if label == "O" {
213 continue;
214 }
215 let entity_type = label.split('-').next_back().unwrap_or(label);
216 match entity_type {
217 "PER" | "PERSON" => {
218 map.insert(label.clone(), EntityType::Person);
219 }
220 "ORG" | "ORGANIZATION" => {
221 map.insert(label.clone(), EntityType::Organization);
222 }
223 "LOC" | "LOCATION" | "GPE" => {
224 map.insert(label.clone(), EntityType::Location);
225 }
226 "DATE" | "TIME" | "DATETIME" => {
227 map.insert(label.clone(), EntityType::DateTime);
228 }
229 _ => {
230 debug!("Unmapped NER label: {} — no EntityType match", label);
231 }
232 }
233 }
234 map
235 };
236
237 let min_confidence = raw
238 .get("min_confidence")
239 .and_then(|v| v.as_f64())
240 .map(|v| v as f32)
241 .unwrap_or(defaults.min_confidence);
242
243 let max_seq_length = raw
244 .get("max_seq_length")
245 .and_then(|v| v.as_u64())
246 .map(|v| v as usize)
247 .unwrap_or(defaults.max_seq_length);
248
249 let tokenizer_path = None;
253
254 info!(
255 "Loaded NER config from {} ({} label mappings, {} id2label entries)",
256 config_path.display(),
257 label_mappings.len(),
258 id2label.len()
259 );
260
261 Ok(NerConfig {
262 model_path: model_path.to_string(),
263 tokenizer_path,
264 min_confidence,
265 max_seq_length,
266 label_mappings,
267 id2label,
268 })
269 }
270
271 pub fn from_config(config: NerConfig) -> Result<Self> {
273 let tokenizer = if let Some(ref tokenizer_path) = config.tokenizer_path {
275 debug!("Loading tokenizer from: {}", tokenizer_path);
276 match TokenizerWrapper::from_file(tokenizer_path) {
277 Ok(t) => {
278 info!("✓ Tokenizer loaded successfully from: {}", tokenizer_path);
279 Some(t)
280 }
281 Err(e) => {
282 warn!(
283 "Failed to load tokenizer: {}. NER will not be available.",
284 e
285 );
286 None
287 }
288 }
289 } else if !config.model_path.is_empty() {
290 let model_dir = Path::new(&config.model_path).parent();
292 if let Some(dir) = model_dir {
293 let tokenizer_json = dir.join("tokenizer.json");
294 if tokenizer_json.exists() {
295 debug!("Loading tokenizer from: {}", tokenizer_json.display());
296 match TokenizerWrapper::from_file(&tokenizer_json) {
297 Ok(t) => {
298 info!("✓ Tokenizer loaded successfully from model directory");
299 Some(t)
300 }
301 Err(e) => {
302 warn!("Failed to load tokenizer from model directory: {}", e);
303 None
304 }
305 }
306 } else {
307 debug!("No tokenizer.json found in model directory");
308 None
309 }
310 } else {
311 None
312 }
313 } else {
314 None
315 };
316
317 let session = if !config.model_path.is_empty() {
319 let model_path = Path::new(&config.model_path);
320 if model_path.exists() {
321 debug!("Loading ONNX model from: {}", config.model_path);
322 match Session::builder()?
323 .with_optimization_level(GraphOptimizationLevel::Level3)?
324 .with_intra_threads(4)?
325 .commit_from_file(&config.model_path)
326 {
327 Ok(s) => {
328 info!("✓ ONNX model loaded successfully: {}", config.model_path);
329 Some(Mutex::new(s))
330 }
331 Err(e) => {
332 warn!(
333 "Failed to load ONNX model: {}. NER will not be available.",
334 e
335 );
336 None
337 }
338 }
339 } else {
340 debug!(
341 "Model path provided but file does not exist: {}",
342 config.model_path
343 );
344 None
345 }
346 } else {
347 debug!("No model path provided, NER will not be available");
348 None
349 };
350
351 let needs_token_type_ids = session.as_ref().is_some_and(|s| {
354 let guard = s.lock().expect("session lock poisoned during init");
355 let has_it = guard
356 .inputs()
357 .iter()
358 .any(|input| input.name() == "token_type_ids");
359 if has_it {
360 debug!("Model declares token_type_ids input — will include in inference");
361 } else {
362 debug!("Model does not declare token_type_ids — omitting from inference");
363 }
364 has_it
365 });
366
367 let is_available = tokenizer.is_some() && session.is_some();
368 if is_available {
369 info!("✓ NER is fully operational with ONNX Runtime");
370 } else {
371 info!("⚠ NER not available - using pattern-based detection (36+ entity types)");
372 if tokenizer.is_none() {
373 debug!(" Missing: tokenizer");
374 }
375 if session.is_none() {
376 debug!(" Missing: ONNX model");
377 }
378 }
379
380 Ok(Self {
381 config,
382 tokenizer,
383 session,
384 needs_token_type_ids,
385 })
386 }
387
388 pub fn config(&self) -> &NerConfig {
390 &self.config
391 }
392
393 pub fn is_available(&self) -> bool {
395 self.tokenizer.is_some() && self.session.is_some()
396 }
397
398 fn map_label_to_entity(&self, label: &str) -> Option<EntityType> {
400 self.config.label_mappings.get(label).cloned()
401 }
402
403 fn infer(&self, input_ids: &[u32], attention_mask: &[u32]) -> Result<Vec<Vec<f32>>> {
405 let session_mutex = self
406 .session
407 .as_ref()
408 .ok_or_else(|| anyhow!("ONNX session not loaded"))?;
409
410 let mut session = session_mutex
411 .lock()
412 .map_err(|e| anyhow!("Failed to lock session: {}", e))?;
413
414 let seq_len = input_ids.len();
416 let input_ids_i64: Vec<i64> = input_ids.iter().map(|&x| x as i64).collect();
417 let attention_mask_i64: Vec<i64> = attention_mask.iter().map(|&x| x as i64).collect();
418
419 let input_ids_value = Value::from_array(([1, seq_len], input_ids_i64))?;
420 let attention_mask_value = Value::from_array(([1, seq_len], attention_mask_i64))?;
421
422 let mut inputs: Vec<(std::borrow::Cow<'_, str>, Value)> = vec![
425 ("input_ids".into(), input_ids_value.into()),
426 ("attention_mask".into(), attention_mask_value.into()),
427 ];
428
429 if self.needs_token_type_ids {
430 let token_type_ids_i64: Vec<i64> = vec![0i64; seq_len];
431 let token_type_ids_value = Value::from_array(([1, seq_len], token_type_ids_i64))?;
432 inputs.push(("token_type_ids".into(), token_type_ids_value.into()));
433 }
434
435 let outputs = session.run(inputs)?;
436
437 let (shape, logits_data) = outputs["logits"].try_extract_tensor::<f32>()?;
439 let shape_dims: &[i64] = shape.as_ref();
440
441 if shape_dims.len() != 3 || shape_dims[0] != 1 {
442 return Err(anyhow!("Unexpected logits shape: {:?}", shape_dims));
443 }
444
445 let seq_len_out = shape_dims[1] as usize;
446 let num_labels = shape_dims[2] as usize;
447
448 let mut result = Vec::new();
450 for i in 0..seq_len_out {
451 let mut token_logits = Vec::new();
452 for j in 0..num_labels {
453 let idx = i * num_labels + j;
454 token_logits.push(logits_data[idx]);
455 }
456 result.push(token_logits);
457 }
458
459 Ok(result)
460 }
461
/// Converts raw logits into a probability distribution.
///
/// The maximum logit is subtracted before exponentiation so large inputs do
/// not overflow. An empty slice yields an empty vector.
fn softmax(logits: &[f32]) -> Vec<f32> {
    let mut peak = f32::NEG_INFINITY;
    for &value in logits {
        peak = f32::max(peak, value);
    }

    let exps: Vec<f32> = logits.iter().map(|&value| (value - peak).exp()).collect();
    let total: f32 = exps.iter().sum();

    exps.into_iter().map(|e| e / total).collect()
}
471
472 fn parse_bio_tags(
474 &self,
475 _text: &str,
476 predictions: &[usize],
477 probabilities: &[f32],
478 offsets: &[(usize, usize)],
479 ) -> Vec<RecognizerResult> {
480 let mut results = Vec::new();
481 let mut current_entity: Option<(EntityType, usize, usize, Vec<f32>)> = None;
482
483 for (idx, (&pred_id, &prob)) in predictions.iter().zip(probabilities.iter()).enumerate() {
484 if offsets[idx] == (0, 0) {
486 continue;
487 }
488
489 let label = self
490 .config
491 .id2label
492 .get(&pred_id)
493 .map(|s| s.as_str())
494 .unwrap_or("O");
495
496 if label.starts_with("B-") {
497 if let Some((entity_type, start, end, probs)) = current_entity.take() {
499 let avg_confidence = probs.iter().sum::<f32>() / probs.len() as f32;
500 if avg_confidence >= self.config.min_confidence {
501 results.push(RecognizerResult::new(
502 entity_type,
503 start,
504 end,
505 avg_confidence,
506 self.name(),
507 ));
508 }
509 }
510
511 if let Some(entity_type) = self.map_label_to_entity(label) {
513 let start = offsets[idx].0;
514 let end = offsets[idx].1;
515 current_entity = Some((entity_type, start, end, vec![prob]));
516 }
517 } else if label.starts_with("I-") {
518 if let Some((ref entity_type, start, ref mut end, ref mut probs)) = current_entity {
520 if let Some(label_entity) = self.map_label_to_entity(label) {
522 if label_entity == *entity_type {
523 *end = offsets[idx].1;
524 probs.push(prob);
525 } else {
526 let avg_confidence = probs.iter().sum::<f32>() / probs.len() as f32;
528 if avg_confidence >= self.config.min_confidence {
529 results.push(RecognizerResult::new(
530 entity_type.clone(),
531 start,
532 *end,
533 avg_confidence,
534 self.name(),
535 ));
536 }
537 current_entity = None;
538 }
539 }
540 }
541 } else {
542 if let Some((entity_type, start, end, probs)) = current_entity.take() {
544 let avg_confidence = probs.iter().sum::<f32>() / probs.len() as f32;
545 if avg_confidence >= self.config.min_confidence {
546 results.push(RecognizerResult::new(
547 entity_type,
548 start,
549 end,
550 avg_confidence,
551 self.name(),
552 ));
553 }
554 }
555 }
556 }
557
558 if let Some((entity_type, start, end, probs)) = current_entity {
560 let avg_confidence = probs.iter().sum::<f32>() / probs.len() as f32;
561 if avg_confidence >= self.config.min_confidence {
562 results.push(RecognizerResult::new(
563 entity_type,
564 start,
565 end,
566 avg_confidence,
567 self.name(),
568 ));
569 }
570 }
571
572 results
573 }
574}
575
576impl Recognizer for NerRecognizer {
577 fn name(&self) -> &str {
578 "NerRecognizer"
579 }
580
581 fn supported_entities(&self) -> &[EntityType] {
582 &[
583 EntityType::Person,
584 EntityType::Organization,
585 EntityType::Location,
586 EntityType::DateTime,
587 ]
588 }
589
590 fn analyze(&self, text: &str, _language: &str) -> Result<Vec<RecognizerResult>> {
591 if !self.is_available() {
593 return Ok(vec![]);
594 }
595
596 let tokenizer = self.tokenizer.as_ref().unwrap();
597
598 let mut encoding = tokenizer.encode(text, true)?;
600
601 let pad_id = tokenizer.get_padding_id().unwrap_or(0);
603
604 encoding.pad_to_length(self.config.max_seq_length, pad_id);
606
607 let logits = self.infer(&encoding.ids, &encoding.attention_mask)?;
609
610 let mut predictions = Vec::new();
612 let mut probabilities = Vec::new();
613
614 for token_logits in &logits {
615 let probs = Self::softmax(token_logits);
616 let (pred_id, &max_prob) = probs
617 .iter()
618 .enumerate()
619 .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
620 .unwrap();
621 predictions.push(pred_id);
622 probabilities.push(max_prob);
623 }
624
625 let results = self.parse_bio_tags(text, &predictions, &probabilities, &encoding.offsets);
627
628 Ok(results)
629 }
630
631 fn supports_language(&self, language: &str) -> bool {
632 matches!(
634 language,
635 "en" | "es" | "fr" | "de" | "it" | "pt" | "nl" | "pl" | "ru" | "zh" | "ja" | "ko"
636 )
637 }
638}
639
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Write;

    /// Recognizer built from the default config: no model, no tokenizer.
    fn default_recognizer() -> NerRecognizer {
        NerRecognizer::from_config(NerConfig::default()).unwrap()
    }

    /// Writes `contents` to a fresh temporary file and returns its handle
    /// (the file lives as long as the handle does).
    fn write_temp_config(contents: &str) -> tempfile::NamedTempFile {
        let mut file = tempfile::NamedTempFile::new().unwrap();
        file.write_all(contents.as_bytes()).unwrap();
        file.flush().unwrap();
        file
    }

    /// Writes `json` to a temp file and loads it as a config for `model_path`.
    fn load_from_json(json: &str, model_path: &str) -> Result<NerConfig> {
        let file = write_temp_config(json);
        NerRecognizer::load_config_from_file(file.path(), model_path)
    }

    #[test]
    fn test_default_config() {
        let defaults = NerConfig::default();
        assert_eq!(defaults.min_confidence, 0.7);
        assert_eq!(defaults.max_seq_length, 512);
        assert!(!defaults.label_mappings.is_empty());
    }

    #[test]
    fn test_label_mapping() {
        let recognizer = default_recognizer();

        assert_eq!(
            recognizer.map_label_to_entity("B-PER"),
            Some(EntityType::Person)
        );
        assert_eq!(
            recognizer.map_label_to_entity("B-ORG"),
            Some(EntityType::Organization)
        );
        assert_eq!(recognizer.map_label_to_entity("O"), None);
    }

    #[test]
    fn test_recognizer_without_model() {
        let recognizer = default_recognizer();
        assert!(!recognizer.is_available());

        // Without a model the recognizer degrades to "no findings".
        let findings = recognizer.analyze("John Doe", "en").unwrap();
        assert!(findings.is_empty());
    }

    #[test]
    fn test_recognizer_without_model_has_no_token_type_ids() {
        assert!(!default_recognizer().needs_token_type_ids);
    }

    #[test]
    fn test_load_config_valid_with_both_id2label_and_label_mappings() {
        let json = r#"{
            "id2label": {
                "0": "O",
                "1": "B-MISC",
                "2": "I-MISC",
                "3": "B-PER",
                "4": "I-PER",
                "5": "B-ORG",
                "6": "I-ORG",
                "7": "B-LOC",
                "8": "I-LOC"
            },
            "label_mappings": {
                "B-PER": "Person",
                "I-PER": "Person",
                "B-ORG": "Organization",
                "I-ORG": "Organization",
                "B-LOC": "Location",
                "I-LOC": "Location"
            },
            "min_confidence": 0.8,
            "max_seq_length": 256,
            "tokenizer_path": "/build/time/tokenizer.json"
        }"#;

        let cfg = load_from_json(json, "/runtime/model.onnx").unwrap();

        // id2label comes straight from the file.
        assert_eq!(cfg.id2label.len(), 9);
        assert_eq!(cfg.id2label[&3], "B-PER");
        assert_eq!(cfg.id2label[&5], "B-ORG");

        // Explicit label_mappings win over derivation.
        assert_eq!(cfg.label_mappings.len(), 6);
        assert_eq!(cfg.label_mappings["B-PER"], EntityType::Person);
        assert_eq!(cfg.label_mappings["B-ORG"], EntityType::Organization);
        assert_eq!(cfg.label_mappings["B-LOC"], EntityType::Location);

        assert_eq!(cfg.min_confidence, 0.8);
        assert_eq!(cfg.max_seq_length, 256);

        // The model path is the one passed in, not anything from the file.
        assert_eq!(cfg.model_path, "/runtime/model.onnx");

        // tokenizer_path from the file is ignored.
        assert!(cfg.tokenizer_path.is_none());
    }

    #[test]
    fn test_load_config_fallback_derives_label_mappings_from_id2label() {
        let json = r#"{
            "id2label": {
                "0": "O",
                "1": "B-MISC",
                "2": "I-MISC",
                "3": "B-PER",
                "4": "I-PER",
                "5": "B-ORG",
                "6": "I-ORG",
                "7": "B-LOC",
                "8": "I-LOC"
            }
        }"#;

        let cfg = load_from_json(json, "/m.onnx").unwrap();
        let mappings = &cfg.label_mappings;

        assert_eq!(mappings.get("B-PER"), Some(&EntityType::Person));
        assert_eq!(mappings.get("I-PER"), Some(&EntityType::Person));
        assert_eq!(mappings.get("B-ORG"), Some(&EntityType::Organization));
        assert_eq!(mappings.get("B-LOC"), Some(&EntityType::Location));

        // MISC has no entity type, so it is not mapped.
        assert!(!mappings.contains_key("B-MISC"));
        assert!(!mappings.contains_key("I-MISC"));

        // DATE labels are absent from id2label, so nothing is derived for them.
        assert!(!mappings.contains_key("B-DATE"));
        assert!(!mappings.contains_key("I-DATE"));
    }

    #[test]
    fn test_load_config_tokenizer_path_always_none() {
        let json = r#"{
            "tokenizer_path": "/out/models/tokenizer.json",
            "id2label": { "0": "O", "1": "B-PER" }
        }"#;

        let cfg = load_from_json(json, "/m.onnx").unwrap();
        assert!(cfg.tokenizer_path.is_none());
    }

    #[test]
    fn test_load_config_malformed_json_returns_err() {
        let outcome = load_from_json("{ this is not valid json }}}", "/m.onnx");
        assert!(outcome.is_err());
    }

    #[test]
    fn test_load_config_empty_json_uses_defaults() {
        let cfg = load_from_json("{}", "/m.onnx").unwrap();
        let defaults = NerConfig::default();

        assert_eq!(cfg.min_confidence, defaults.min_confidence);
        assert_eq!(cfg.max_seq_length, defaults.max_seq_length);
        assert_eq!(cfg.id2label.len(), defaults.id2label.len());
    }
}