1use anyhow::{anyhow, Result};
7use ort::session::builder::GraphOptimizationLevel;
8use ort::session::Session;
9use ort::value::Value;
10use redact_core::{EntityType, Recognizer, RecognizerResult};
11use serde::{Deserialize, Serialize};
12use std::collections::HashMap;
13use std::path::Path;
14use std::sync::Mutex;
15use tracing::{debug, info, warn};
16
17use crate::tokenizer_wrapper::TokenizerWrapper;
18
/// Settings that control how the NER recognizer loads its model and turns
/// model output into entities.
///
/// The struct is (de)serializable with serde. Note that the field-level
/// `#[serde(default)]` on the two maps yields *empty* maps when the keys are
/// absent from the input, which is different from the populated maps that
/// `NerConfig::default()` builds.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct NerConfig {
    /// Path to the ONNX model file. `from_config` treats an empty string as
    /// "no model configured".
    pub model_path: String,

    /// Explicit tokenizer file to load. When `None`, `from_config` looks for
    /// a `tokenizer.json` next to the model file instead. Omitted from
    /// serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tokenizer_path: Option<String>,

    /// Minimum average per-token probability an entity needs in order to be
    /// reported. Defaults to `default_confidence()` when deserializing.
    #[serde(default = "default_confidence")]
    pub min_confidence: f32,

    /// Length that token sequences are padded to before inference. Defaults
    /// to `default_max_length()` when deserializing.
    #[serde(default = "default_max_length")]
    pub max_seq_length: usize,

    /// Maps a model label such as `"B-PER"` to the entity type it denotes.
    /// Labels without an entry never produce results.
    #[serde(default)]
    pub label_mappings: HashMap<String, EntityType>,

    /// Maps a class index from the model output to its label string.
    /// Indices without an entry are treated as the outside label `"O"`.
    #[serde(default)]
    pub id2label: HashMap<usize, String>,
}
45
/// Serde fallback for `NerConfig::min_confidence`; also used by
/// `NerConfig::default()`.
fn default_confidence() -> f32 {
    const DEFAULT_MIN_CONFIDENCE: f32 = 0.7;
    DEFAULT_MIN_CONFIDENCE
}
49
/// Serde fallback for `NerConfig::max_seq_length`; also used by
/// `NerConfig::default()`.
fn default_max_length() -> usize {
    const DEFAULT_MAX_SEQ_LENGTH: usize = 512;
    DEFAULT_MAX_SEQ_LENGTH
}
53
54impl Default for NerConfig {
55 fn default() -> Self {
56 let mut label_mappings = HashMap::new();
57 let mut id2label = HashMap::new();
58
59 label_mappings.insert("B-PER".to_string(), EntityType::Person);
61 label_mappings.insert("I-PER".to_string(), EntityType::Person);
62 label_mappings.insert("B-ORG".to_string(), EntityType::Organization);
63 label_mappings.insert("I-ORG".to_string(), EntityType::Organization);
64 label_mappings.insert("B-LOC".to_string(), EntityType::Location);
65 label_mappings.insert("I-LOC".to_string(), EntityType::Location);
66 label_mappings.insert("B-DATE".to_string(), EntityType::DateTime);
67 label_mappings.insert("I-DATE".to_string(), EntityType::DateTime);
68 label_mappings.insert("B-TIME".to_string(), EntityType::DateTime);
69 label_mappings.insert("I-TIME".to_string(), EntityType::DateTime);
70
71 id2label.insert(0, "O".to_string());
73 id2label.insert(1, "B-PER".to_string());
74 id2label.insert(2, "I-PER".to_string());
75 id2label.insert(3, "B-ORG".to_string());
76 id2label.insert(4, "I-ORG".to_string());
77 id2label.insert(5, "B-LOC".to_string());
78 id2label.insert(6, "I-LOC".to_string());
79 id2label.insert(7, "B-MISC".to_string());
80 id2label.insert(8, "I-MISC".to_string());
81
82 Self {
83 model_path: String::new(),
84 tokenizer_path: None,
85 min_confidence: default_confidence(),
86 max_seq_length: default_max_length(),
87 label_mappings,
88 id2label,
89 }
90 }
91}
92
/// Named-entity recognizer that runs a token-classification model through
/// ONNX Runtime.
///
/// Both the tokenizer and the session are optional: when either is missing
/// the recognizer reports itself as unavailable and `analyze` returns no
/// results.
pub struct NerRecognizer {
    /// Configuration this recognizer was built from.
    config: NerConfig,
    /// Tokenizer used to encode input text; `None` if it could not be loaded.
    tokenizer: Option<TokenizerWrapper>,
    /// ONNX Runtime session; `None` if no model could be loaded. Kept behind
    /// a mutex because inference takes the session through a mutable guard
    /// while the recognizer itself is only borrowed immutably.
    session: Option<Mutex<Session>>,
    /// Whether the loaded model declares a `token_type_ids` input and so
    /// needs it supplied at inference time.
    needs_token_type_ids: bool,
}
120
121impl std::fmt::Debug for NerRecognizer {
122 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
123 f.debug_struct("NerRecognizer")
124 .field("config", &self.config)
125 .field("tokenizer", &self.tokenizer)
126 .field("session", &self.session.as_ref().map(|_| "Session"))
127 .field("needs_token_type_ids", &self.needs_token_type_ids)
128 .finish()
129 }
130}
131
132impl NerRecognizer {
133 pub fn from_file<P: AsRef<Path>>(model_path: P) -> Result<Self> {
139 let model_path_ref = model_path.as_ref();
140 let model_path_str = model_path_ref.to_string_lossy().to_string();
141
142 let config = if let Some(model_dir) = model_path_ref.parent() {
144 let config_path = model_dir.join("config.json");
145 if config_path.exists() {
146 debug!("Loading NER config from: {}", config_path.display());
147 match Self::load_config_from_file(&config_path, &model_path_str) {
148 Ok(cfg) => cfg,
149 Err(e) => {
150 warn!("Failed to load NER config.json: {}. Using defaults.", e);
151 NerConfig {
152 model_path: model_path_str,
153 ..Default::default()
154 }
155 }
156 }
157 } else {
158 debug!("No config.json in model directory, using default label mappings");
159 NerConfig {
160 model_path: model_path_str,
161 ..Default::default()
162 }
163 }
164 } else {
165 NerConfig {
166 model_path: model_path_str,
167 ..Default::default()
168 }
169 };
170
171 Self::from_config(config)
172 }
173
174 fn load_config_from_file(config_path: &Path, model_path: &str) -> Result<NerConfig> {
179 let json_str = std::fs::read_to_string(config_path)?;
180 let raw: serde_json::Value = serde_json::from_str(&json_str)?;
181
182 let defaults = NerConfig::default();
183
184 let id2label = if let Some(obj) = raw.get("id2label").and_then(|v| v.as_object()) {
186 let mut map = HashMap::new();
187 for (k, v) in obj {
188 if let (Ok(id), Some(label)) = (k.parse::<usize>(), v.as_str()) {
189 map.insert(id, label.to_string());
190 }
191 }
192 map
193 } else {
194 defaults.id2label.clone()
195 };
196
197 let label_mappings =
200 if let Some(obj) = raw.get("label_mappings").and_then(|v| v.as_object()) {
201 let mut map = HashMap::new();
202 for (k, v) in obj {
203 if let Some(entity_str) = v.as_str() {
204 map.insert(k.clone(), EntityType::from(entity_str.to_string()));
205 }
206 }
207 map
208 } else {
209 let mut map = HashMap::new();
211 for label in id2label.values() {
212 if label == "O" {
213 continue;
214 }
215 let entity_type = label.split('-').next_back().unwrap_or(label);
216 match entity_type {
217 "PER" | "PERSON" => {
218 map.insert(label.clone(), EntityType::Person);
219 }
220 "ORG" | "ORGANIZATION" => {
221 map.insert(label.clone(), EntityType::Organization);
222 }
223 "LOC" | "LOCATION" | "GPE" => {
224 map.insert(label.clone(), EntityType::Location);
225 }
226 "DATE" | "TIME" | "DATETIME" => {
227 map.insert(label.clone(), EntityType::DateTime);
228 }
229 _ => {
230 debug!("Unmapped NER label: {} — no EntityType match", label);
231 }
232 }
233 }
234 map
235 };
236
237 let min_confidence = raw
238 .get("min_confidence")
239 .and_then(|v| v.as_f64())
240 .map(|v| v as f32)
241 .unwrap_or(defaults.min_confidence);
242
243 let max_seq_length = raw
244 .get("max_seq_length")
245 .and_then(|v| v.as_u64())
246 .map(|v| v as usize)
247 .unwrap_or(defaults.max_seq_length);
248
249 let tokenizer_path = None;
253
254 info!(
255 "Loaded NER config from {} ({} label mappings, {} id2label entries)",
256 config_path.display(),
257 label_mappings.len(),
258 id2label.len()
259 );
260
261 Ok(NerConfig {
262 model_path: model_path.to_string(),
263 tokenizer_path,
264 min_confidence,
265 max_seq_length,
266 label_mappings,
267 id2label,
268 })
269 }
270
271 pub fn from_config(config: NerConfig) -> Result<Self> {
273 let tokenizer = if let Some(ref tokenizer_path) = config.tokenizer_path {
275 debug!("Loading tokenizer from: {}", tokenizer_path);
276 match TokenizerWrapper::from_file(tokenizer_path) {
277 Ok(t) => {
278 info!("✓ Tokenizer loaded successfully from: {}", tokenizer_path);
279 Some(t)
280 }
281 Err(e) => {
282 warn!(
283 "Failed to load tokenizer: {}. NER will not be available.",
284 e
285 );
286 None
287 }
288 }
289 } else if !config.model_path.is_empty() {
290 let model_dir = Path::new(&config.model_path).parent();
292 if let Some(dir) = model_dir {
293 let tokenizer_json = dir.join("tokenizer.json");
294 if tokenizer_json.exists() {
295 debug!("Loading tokenizer from: {}", tokenizer_json.display());
296 match TokenizerWrapper::from_file(&tokenizer_json) {
297 Ok(t) => {
298 info!("✓ Tokenizer loaded successfully from model directory");
299 Some(t)
300 }
301 Err(e) => {
302 warn!("Failed to load tokenizer from model directory: {}", e);
303 None
304 }
305 }
306 } else {
307 debug!("No tokenizer.json found in model directory");
308 None
309 }
310 } else {
311 None
312 }
313 } else {
314 None
315 };
316
317 let session = if !config.model_path.is_empty() {
319 let model_path = Path::new(&config.model_path);
320 if model_path.exists() {
321 debug!("Loading ONNX model from: {}", config.model_path);
322 match Session::builder()?
323 .with_optimization_level(GraphOptimizationLevel::Level3)
324 .map_err(|e| anyhow::anyhow!("{e}"))?
325 .with_intra_threads(4)
326 .map_err(|e| anyhow::anyhow!("{e}"))?
327 .commit_from_file(&config.model_path)
328 {
329 Ok(s) => {
330 info!("✓ ONNX model loaded successfully: {}", config.model_path);
331 Some(Mutex::new(s))
332 }
333 Err(e) => {
334 warn!(
335 "Failed to load ONNX model: {}. NER will not be available.",
336 e
337 );
338 None
339 }
340 }
341 } else {
342 debug!(
343 "Model path provided but file does not exist: {}",
344 config.model_path
345 );
346 None
347 }
348 } else {
349 debug!("No model path provided, NER will not be available");
350 None
351 };
352
353 let needs_token_type_ids = session.as_ref().is_some_and(|s| {
356 let guard = s.lock().expect("session lock poisoned during init");
357 let has_it = guard
358 .inputs()
359 .iter()
360 .any(|input| input.name() == "token_type_ids");
361 if has_it {
362 debug!("Model declares token_type_ids input — will include in inference");
363 } else {
364 debug!("Model does not declare token_type_ids — omitting from inference");
365 }
366 has_it
367 });
368
369 let is_available = tokenizer.is_some() && session.is_some();
370 if is_available {
371 info!("✓ NER is fully operational with ONNX Runtime");
372 } else {
373 info!("⚠ NER not available - using pattern-based detection (36+ entity types)");
374 if tokenizer.is_none() {
375 debug!(" Missing: tokenizer");
376 }
377 if session.is_none() {
378 debug!(" Missing: ONNX model");
379 }
380 }
381
382 Ok(Self {
383 config,
384 tokenizer,
385 session,
386 needs_token_type_ids,
387 })
388 }
389
390 pub fn config(&self) -> &NerConfig {
392 &self.config
393 }
394
395 pub fn is_available(&self) -> bool {
397 self.tokenizer.is_some() && self.session.is_some()
398 }
399
400 fn map_label_to_entity(&self, label: &str) -> Option<EntityType> {
402 self.config.label_mappings.get(label).cloned()
403 }
404
405 fn infer(&self, input_ids: &[u32], attention_mask: &[u32]) -> Result<Vec<Vec<f32>>> {
407 let session_mutex = self
408 .session
409 .as_ref()
410 .ok_or_else(|| anyhow!("ONNX session not loaded"))?;
411
412 let mut session = session_mutex
413 .lock()
414 .map_err(|e| anyhow!("Failed to lock session: {}", e))?;
415
416 let seq_len = input_ids.len();
418 let input_ids_i64: Vec<i64> = input_ids.iter().map(|&x| x as i64).collect();
419 let attention_mask_i64: Vec<i64> = attention_mask.iter().map(|&x| x as i64).collect();
420
421 let input_ids_value = Value::from_array(([1, seq_len], input_ids_i64))?;
422 let attention_mask_value = Value::from_array(([1, seq_len], attention_mask_i64))?;
423
424 let mut inputs: Vec<(std::borrow::Cow<'_, str>, Value)> = vec![
427 ("input_ids".into(), input_ids_value.into()),
428 ("attention_mask".into(), attention_mask_value.into()),
429 ];
430
431 if self.needs_token_type_ids {
432 let token_type_ids_i64: Vec<i64> = vec![0i64; seq_len];
433 let token_type_ids_value = Value::from_array(([1, seq_len], token_type_ids_i64))?;
434 inputs.push(("token_type_ids".into(), token_type_ids_value.into()));
435 }
436
437 let outputs = session.run(inputs)?;
438
439 let (shape, logits_data) = outputs["logits"].try_extract_tensor::<f32>()?;
441 let shape_dims: &[i64] = shape.as_ref();
442
443 if shape_dims.len() != 3 || shape_dims[0] != 1 {
444 return Err(anyhow!("Unexpected logits shape: {:?}", shape_dims));
445 }
446
447 let seq_len_out = shape_dims[1] as usize;
448 let num_labels = shape_dims[2] as usize;
449
450 let mut result = Vec::new();
452 for i in 0..seq_len_out {
453 let mut token_logits = Vec::new();
454 for j in 0..num_labels {
455 let idx = i * num_labels + j;
456 token_logits.push(logits_data[idx]);
457 }
458 result.push(token_logits);
459 }
460
461 Ok(result)
462 }
463
464 fn softmax(logits: &[f32]) -> Vec<f32> {
466 let max_logit = logits.iter().copied().fold(f32::NEG_INFINITY, f32::max);
467 let exp_sum: f32 = logits.iter().map(|&x| (x - max_logit).exp()).sum();
468 logits
469 .iter()
470 .map(|&x| (x - max_logit).exp() / exp_sum)
471 .collect()
472 }
473
474 fn parse_bio_tags(
476 &self,
477 _text: &str,
478 predictions: &[usize],
479 probabilities: &[f32],
480 offsets: &[(usize, usize)],
481 ) -> Vec<RecognizerResult> {
482 let mut results = Vec::new();
483 let mut current_entity: Option<(EntityType, usize, usize, Vec<f32>)> = None;
484
485 for (idx, (&pred_id, &prob)) in predictions.iter().zip(probabilities.iter()).enumerate() {
486 if offsets[idx] == (0, 0) {
488 continue;
489 }
490
491 let label = self
492 .config
493 .id2label
494 .get(&pred_id)
495 .map(|s| s.as_str())
496 .unwrap_or("O");
497
498 if label.starts_with("B-") {
499 if let Some((entity_type, start, end, probs)) = current_entity.take() {
501 let avg_confidence = probs.iter().sum::<f32>() / probs.len() as f32;
502 if avg_confidence >= self.config.min_confidence {
503 results.push(RecognizerResult::new(
504 entity_type,
505 start,
506 end,
507 avg_confidence,
508 self.name(),
509 ));
510 }
511 }
512
513 if let Some(entity_type) = self.map_label_to_entity(label) {
515 let start = offsets[idx].0;
516 let end = offsets[idx].1;
517 current_entity = Some((entity_type, start, end, vec![prob]));
518 }
519 } else if label.starts_with("I-") {
520 if let Some((ref entity_type, start, ref mut end, ref mut probs)) = current_entity {
522 if let Some(label_entity) = self.map_label_to_entity(label) {
524 if label_entity == *entity_type {
525 *end = offsets[idx].1;
526 probs.push(prob);
527 } else {
528 let avg_confidence = probs.iter().sum::<f32>() / probs.len() as f32;
530 if avg_confidence >= self.config.min_confidence {
531 results.push(RecognizerResult::new(
532 entity_type.clone(),
533 start,
534 *end,
535 avg_confidence,
536 self.name(),
537 ));
538 }
539 current_entity = None;
540 }
541 }
542 }
543 } else {
544 if let Some((entity_type, start, end, probs)) = current_entity.take() {
546 let avg_confidence = probs.iter().sum::<f32>() / probs.len() as f32;
547 if avg_confidence >= self.config.min_confidence {
548 results.push(RecognizerResult::new(
549 entity_type,
550 start,
551 end,
552 avg_confidence,
553 self.name(),
554 ));
555 }
556 }
557 }
558 }
559
560 if let Some((entity_type, start, end, probs)) = current_entity {
562 let avg_confidence = probs.iter().sum::<f32>() / probs.len() as f32;
563 if avg_confidence >= self.config.min_confidence {
564 results.push(RecognizerResult::new(
565 entity_type,
566 start,
567 end,
568 avg_confidence,
569 self.name(),
570 ));
571 }
572 }
573
574 results
575 }
576}
577
578impl Recognizer for NerRecognizer {
579 fn name(&self) -> &str {
580 "NerRecognizer"
581 }
582
583 fn supported_entities(&self) -> &[EntityType] {
584 &[
585 EntityType::Person,
586 EntityType::Organization,
587 EntityType::Location,
588 EntityType::DateTime,
589 ]
590 }
591
592 fn analyze(&self, text: &str, _language: &str) -> Result<Vec<RecognizerResult>> {
593 if !self.is_available() {
595 return Ok(vec![]);
596 }
597
598 let tokenizer = self.tokenizer.as_ref().unwrap();
599
600 let mut encoding = tokenizer.encode(text, true)?;
602
603 let pad_id = tokenizer.get_padding_id().unwrap_or(0);
605
606 encoding.pad_to_length(self.config.max_seq_length, pad_id);
608
609 let logits = self.infer(&encoding.ids, &encoding.attention_mask)?;
611
612 let mut predictions = Vec::new();
614 let mut probabilities = Vec::new();
615
616 for token_logits in &logits {
617 let probs = Self::softmax(token_logits);
618 let (pred_id, &max_prob) = probs
619 .iter()
620 .enumerate()
621 .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
622 .unwrap();
623 predictions.push(pred_id);
624 probabilities.push(max_prob);
625 }
626
627 let results = self.parse_bio_tags(text, &predictions, &probabilities, &encoding.offsets);
629
630 Ok(results)
631 }
632
633 fn supports_language(&self, language: &str) -> bool {
634 matches!(
636 language,
637 "en" | "es" | "fr" | "de" | "it" | "pt" | "nl" | "pl" | "ru" | "zh" | "ja" | "ko"
638 )
639 }
640}
641
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Write;

    /// Writes `contents` to a fresh temporary file and returns the handle,
    /// which keeps the file alive for the duration of the test.
    fn write_temp_config(contents: &str) -> tempfile::NamedTempFile {
        let mut file = tempfile::NamedTempFile::new().unwrap();
        write!(file, "{}", contents).unwrap();
        file.flush().unwrap();
        file
    }

    /// Builds a recognizer from the default config, which has no model.
    fn recognizer_without_model() -> NerRecognizer {
        NerRecognizer::from_config(NerConfig::default()).unwrap()
    }

    #[test]
    fn test_default_config() {
        let NerConfig {
            min_confidence,
            max_seq_length,
            label_mappings,
            ..
        } = NerConfig::default();

        assert_eq!(min_confidence, 0.7);
        assert_eq!(max_seq_length, 512);
        assert!(!label_mappings.is_empty());
    }

    #[test]
    fn test_label_mapping() {
        let recognizer = recognizer_without_model();

        let cases = [
            ("B-PER", Some(EntityType::Person)),
            ("B-ORG", Some(EntityType::Organization)),
            ("O", None),
        ];
        for (label, expected) in cases {
            assert_eq!(recognizer.map_label_to_entity(label), expected);
        }
    }

    #[test]
    fn test_recognizer_without_model() {
        let recognizer = recognizer_without_model();

        // No model and no tokenizer: the recognizer must report unavailable...
        assert!(!recognizer.is_available());

        // ...and analysis must succeed with no results rather than fail.
        let results = recognizer.analyze("John Doe", "en").unwrap();
        assert_eq!(results.len(), 0);
    }

    #[test]
    fn test_recognizer_without_model_has_no_token_type_ids() {
        let recognizer = recognizer_without_model();

        // Without a session there are no model inputs to inspect.
        assert!(!recognizer.needs_token_type_ids);
    }

    #[test]
    fn test_load_config_valid_with_both_id2label_and_label_mappings() {
        let json = r#"{
            "id2label": {
                "0": "O",
                "1": "B-MISC",
                "2": "I-MISC",
                "3": "B-PER",
                "4": "I-PER",
                "5": "B-ORG",
                "6": "I-ORG",
                "7": "B-LOC",
                "8": "I-LOC"
            },
            "label_mappings": {
                "B-PER": "Person",
                "I-PER": "Person",
                "B-ORG": "Organization",
                "I-ORG": "Organization",
                "B-LOC": "Location",
                "I-LOC": "Location"
            },
            "min_confidence": 0.8,
            "max_seq_length": 256,
            "tokenizer_path": "/build/time/tokenizer.json"
        }"#;

        let file = write_temp_config(json);
        let cfg =
            NerRecognizer::load_config_from_file(file.path(), "/runtime/model.onnx").unwrap();

        // id2label is taken from the file verbatim.
        assert_eq!(cfg.id2label.len(), 9);
        for (id, label) in [(3usize, "B-PER"), (5usize, "B-ORG")] {
            assert_eq!(cfg.id2label[&id], label);
        }

        // Explicit label_mappings are used as given.
        assert_eq!(cfg.label_mappings.len(), 6);
        let expected_mappings = [
            ("B-PER", EntityType::Person),
            ("B-ORG", EntityType::Organization),
            ("B-LOC", EntityType::Location),
        ];
        for (label, entity) in expected_mappings {
            assert_eq!(cfg.label_mappings[label], entity);
        }

        // Scalar settings come from the file.
        assert_eq!(cfg.min_confidence, 0.8);
        assert_eq!(cfg.max_seq_length, 256);

        // The model path comes from the argument, not from the file.
        assert_eq!(cfg.model_path, "/runtime/model.onnx");

        // A tokenizer_path in the file is ignored.
        assert!(cfg.tokenizer_path.is_none());
    }

    #[test]
    fn test_load_config_fallback_derives_label_mappings_from_id2label() {
        // No "label_mappings" key: mappings must be derived from id2label.
        let json = r#"{
            "id2label": {
                "0": "O",
                "1": "B-MISC",
                "2": "I-MISC",
                "3": "B-PER",
                "4": "I-PER",
                "5": "B-ORG",
                "6": "I-ORG",
                "7": "B-LOC",
                "8": "I-LOC"
            }
        }"#;

        let file = write_temp_config(json);
        let cfg = NerRecognizer::load_config_from_file(file.path(), "/m.onnx").unwrap();

        // Labels with a recognized suffix are mapped.
        let derived = [
            ("B-PER", EntityType::Person),
            ("I-PER", EntityType::Person),
            ("B-ORG", EntityType::Organization),
            ("B-LOC", EntityType::Location),
        ];
        for (label, entity) in derived {
            assert_eq!(cfg.label_mappings.get(label), Some(&entity));
        }

        // MISC has no entity type, and DATE labels are not in this id2label,
        // so neither may appear in the derived mappings.
        for absent in ["B-MISC", "I-MISC", "B-DATE", "I-DATE"] {
            assert!(cfg.label_mappings.get(absent).is_none());
        }
    }

    #[test]
    fn test_load_config_tokenizer_path_always_none() {
        let json = r#"{
            "tokenizer_path": "/out/models/tokenizer.json",
            "id2label": { "0": "O", "1": "B-PER" }
        }"#;

        let file = write_temp_config(json);
        let cfg = NerRecognizer::load_config_from_file(file.path(), "/m.onnx").unwrap();
        assert!(cfg.tokenizer_path.is_none());
    }

    #[test]
    fn test_load_config_malformed_json_returns_err() {
        let file = write_temp_config("{ this is not valid json }}}");
        let outcome = NerRecognizer::load_config_from_file(file.path(), "/m.onnx");
        assert!(outcome.is_err());
    }

    #[test]
    fn test_load_config_empty_json_uses_defaults() {
        // An empty object is valid JSON with every key absent.
        let file = write_temp_config("{}");
        let cfg = NerRecognizer::load_config_from_file(file.path(), "/m.onnx").unwrap();

        let defaults = NerConfig::default();
        assert_eq!(cfg.min_confidence, defaults.min_confidence);
        assert_eq!(cfg.max_seq_length, defaults.max_seq_length);
        assert_eq!(cfg.id2label.len(), defaults.id2label.len());
    }
}