1use anyhow::{anyhow, Result};
6use ort::session::builder::GraphOptimizationLevel;
7use ort::session::Session;
8use ort::value::Value;
9use redact_core::{EntityType, Recognizer, RecognizerResult};
10use serde::{Deserialize, Serialize};
11use std::collections::HashMap;
12use std::path::Path;
13use std::sync::Mutex;
14use tracing::{debug, info, warn};
15
16use crate::tokenizer_wrapper::TokenizerWrapper;
17
/// Configuration for the ONNX-backed NER recognizer.
///
/// Fields marked with a serde `default` may be omitted when deserializing.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct NerConfig {
    /// Path to the ONNX model file; an empty string means no model is configured.
    pub model_path: String,

    /// Optional explicit tokenizer file path; left out of serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tokenizer_path: Option<String>,

    /// Minimum average token probability an entity needs in order to be reported.
    #[serde(default = "default_confidence")]
    pub min_confidence: f32,

    /// Length that token sequences are padded to before inference.
    #[serde(default = "default_max_length")]
    pub max_seq_length: usize,

    /// Maps model label strings (for example `"B-PER"`) to entity types.
    #[serde(default)]
    pub label_mappings: HashMap<String, EntityType>,

    /// Maps the model's class indices to label strings.
    #[serde(default)]
    pub id2label: HashMap<usize, String>,
}
44
/// Fallback for `min_confidence` when the field is missing from the input.
fn default_confidence() -> f32 {
    const FALLBACK_MIN_CONFIDENCE: f32 = 0.7;
    FALLBACK_MIN_CONFIDENCE
}
48
/// Fallback for `max_seq_length` when the field is missing from the input.
fn default_max_length() -> usize {
    const FALLBACK_MAX_SEQ_LENGTH: usize = 512;
    FALLBACK_MAX_SEQ_LENGTH
}
52
53impl Default for NerConfig {
54 fn default() -> Self {
55 let mut label_mappings = HashMap::new();
56 let mut id2label = HashMap::new();
57
58 label_mappings.insert("B-PER".to_string(), EntityType::Person);
60 label_mappings.insert("I-PER".to_string(), EntityType::Person);
61 label_mappings.insert("B-ORG".to_string(), EntityType::Organization);
62 label_mappings.insert("I-ORG".to_string(), EntityType::Organization);
63 label_mappings.insert("B-LOC".to_string(), EntityType::Location);
64 label_mappings.insert("I-LOC".to_string(), EntityType::Location);
65 label_mappings.insert("B-DATE".to_string(), EntityType::DateTime);
66 label_mappings.insert("I-DATE".to_string(), EntityType::DateTime);
67 label_mappings.insert("B-TIME".to_string(), EntityType::DateTime);
68 label_mappings.insert("I-TIME".to_string(), EntityType::DateTime);
69
70 id2label.insert(0, "O".to_string());
72 id2label.insert(1, "B-PER".to_string());
73 id2label.insert(2, "I-PER".to_string());
74 id2label.insert(3, "B-ORG".to_string());
75 id2label.insert(4, "I-ORG".to_string());
76 id2label.insert(5, "B-LOC".to_string());
77 id2label.insert(6, "I-LOC".to_string());
78 id2label.insert(7, "B-MISC".to_string());
79 id2label.insert(8, "I-MISC".to_string());
80
81 Self {
82 model_path: String::new(),
83 tokenizer_path: None,
84 min_confidence: default_confidence(),
85 max_seq_length: default_max_length(),
86 label_mappings,
87 id2label,
88 }
89 }
90}
91
/// Named-entity recognizer that pairs a tokenizer with an ONNX Runtime session.
///
/// Either component may be missing, in which case `is_available` returns
/// `false` and `analyze` returns no results.
pub struct NerRecognizer {
    /// Configuration the recognizer was built from.
    config: NerConfig,
    /// Tokenizer, or `None` when none could be loaded.
    tokenizer: Option<TokenizerWrapper>,
    /// ONNX session behind a mutex (locked by `infer` to run the model), or
    /// `None` when no model was loaded.
    session: Option<Mutex<Session>>,
    /// Whether the loaded model declares a `token_type_ids` input.
    needs_token_type_ids: bool,
}
119
120impl std::fmt::Debug for NerRecognizer {
121 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
122 f.debug_struct("NerRecognizer")
123 .field("config", &self.config)
124 .field("tokenizer", &self.tokenizer)
125 .field("session", &self.session.as_ref().map(|_| "Session"))
126 .field("needs_token_type_ids", &self.needs_token_type_ids)
127 .finish()
128 }
129}
130
131impl NerRecognizer {
132 pub fn from_file<P: AsRef<Path>>(model_path: P) -> Result<Self> {
138 let model_path_ref = model_path.as_ref();
139 let model_path_str = model_path_ref.to_string_lossy().to_string();
140
141 let config = if let Some(model_dir) = model_path_ref.parent() {
143 let config_path = model_dir.join("config.json");
144 if config_path.exists() {
145 debug!("Loading NER config from: {}", config_path.display());
146 match Self::load_config_from_file(&config_path, &model_path_str) {
147 Ok(cfg) => cfg,
148 Err(e) => {
149 warn!("Failed to load NER config.json: {}. Using defaults.", e);
150 NerConfig {
151 model_path: model_path_str,
152 ..Default::default()
153 }
154 }
155 }
156 } else {
157 debug!("No config.json in model directory, using default label mappings");
158 NerConfig {
159 model_path: model_path_str,
160 ..Default::default()
161 }
162 }
163 } else {
164 NerConfig {
165 model_path: model_path_str,
166 ..Default::default()
167 }
168 };
169
170 Self::from_config(config)
171 }
172
173 fn load_config_from_file(config_path: &Path, model_path: &str) -> Result<NerConfig> {
178 let json_str = std::fs::read_to_string(config_path)?;
179 let raw: serde_json::Value = serde_json::from_str(&json_str)?;
180
181 let defaults = NerConfig::default();
182
183 let id2label = if let Some(obj) = raw.get("id2label").and_then(|v| v.as_object()) {
185 let mut map = HashMap::new();
186 for (k, v) in obj {
187 if let (Ok(id), Some(label)) = (k.parse::<usize>(), v.as_str()) {
188 map.insert(id, label.to_string());
189 }
190 }
191 map
192 } else {
193 defaults.id2label.clone()
194 };
195
196 let label_mappings =
199 if let Some(obj) = raw.get("label_mappings").and_then(|v| v.as_object()) {
200 let mut map = HashMap::new();
201 for (k, v) in obj {
202 if let Some(entity_str) = v.as_str() {
203 map.insert(k.clone(), EntityType::from(entity_str.to_string()));
204 }
205 }
206 map
207 } else {
208 let mut map = HashMap::new();
210 for label in id2label.values() {
211 if label == "O" {
212 continue;
213 }
214 let entity_type = label.split('-').next_back().unwrap_or(label);
215 match entity_type {
216 "PER" | "PERSON" => {
217 map.insert(label.clone(), EntityType::Person);
218 }
219 "ORG" | "ORGANIZATION" => {
220 map.insert(label.clone(), EntityType::Organization);
221 }
222 "LOC" | "LOCATION" | "GPE" => {
223 map.insert(label.clone(), EntityType::Location);
224 }
225 "DATE" | "TIME" | "DATETIME" => {
226 map.insert(label.clone(), EntityType::DateTime);
227 }
228 _ => {
229 debug!("Unmapped NER label: {} — no EntityType match", label);
230 }
231 }
232 }
233 map
234 };
235
236 let min_confidence = raw
237 .get("min_confidence")
238 .and_then(|v| v.as_f64())
239 .map(|v| v as f32)
240 .unwrap_or(defaults.min_confidence);
241
242 let max_seq_length = raw
243 .get("max_seq_length")
244 .and_then(|v| v.as_u64())
245 .map(|v| v as usize)
246 .unwrap_or(defaults.max_seq_length);
247
248 let tokenizer_path = None;
252
253 info!(
254 "Loaded NER config from {} ({} label mappings, {} id2label entries)",
255 config_path.display(),
256 label_mappings.len(),
257 id2label.len()
258 );
259
260 Ok(NerConfig {
261 model_path: model_path.to_string(),
262 tokenizer_path,
263 min_confidence,
264 max_seq_length,
265 label_mappings,
266 id2label,
267 })
268 }
269
270 pub fn from_config(config: NerConfig) -> Result<Self> {
272 let tokenizer = if let Some(ref tokenizer_path) = config.tokenizer_path {
274 debug!("Loading tokenizer from: {}", tokenizer_path);
275 match TokenizerWrapper::from_file(tokenizer_path) {
276 Ok(t) => {
277 info!("✓ Tokenizer loaded successfully from: {}", tokenizer_path);
278 Some(t)
279 }
280 Err(e) => {
281 warn!(
282 "Failed to load tokenizer: {}. NER will not be available.",
283 e
284 );
285 None
286 }
287 }
288 } else if !config.model_path.is_empty() {
289 let model_dir = Path::new(&config.model_path).parent();
291 if let Some(dir) = model_dir {
292 let tokenizer_json = dir.join("tokenizer.json");
293 if tokenizer_json.exists() {
294 debug!("Loading tokenizer from: {}", tokenizer_json.display());
295 match TokenizerWrapper::from_file(&tokenizer_json) {
296 Ok(t) => {
297 info!("✓ Tokenizer loaded successfully from model directory");
298 Some(t)
299 }
300 Err(e) => {
301 warn!("Failed to load tokenizer from model directory: {}", e);
302 None
303 }
304 }
305 } else {
306 debug!("No tokenizer.json found in model directory");
307 None
308 }
309 } else {
310 None
311 }
312 } else {
313 None
314 };
315
316 let session = if !config.model_path.is_empty() {
318 let model_path = Path::new(&config.model_path);
319 if model_path.exists() {
320 debug!("Loading ONNX model from: {}", config.model_path);
321 match Session::builder()?
322 .with_optimization_level(GraphOptimizationLevel::Level3)
323 .map_err(|e| anyhow::anyhow!("{e}"))?
324 .with_intra_threads(4)
325 .map_err(|e| anyhow::anyhow!("{e}"))?
326 .commit_from_file(&config.model_path)
327 {
328 Ok(s) => {
329 info!("✓ ONNX model loaded successfully: {}", config.model_path);
330 Some(Mutex::new(s))
331 }
332 Err(e) => {
333 warn!(
334 "Failed to load ONNX model: {}. NER will not be available.",
335 e
336 );
337 None
338 }
339 }
340 } else {
341 debug!(
342 "Model path provided but file does not exist: {}",
343 config.model_path
344 );
345 None
346 }
347 } else {
348 debug!("No model path provided, NER will not be available");
349 None
350 };
351
352 let needs_token_type_ids = session.as_ref().is_some_and(|s| {
355 let guard = s.lock().expect("session lock poisoned during init");
356 let has_it = guard
357 .inputs()
358 .iter()
359 .any(|input| input.name() == "token_type_ids");
360 if has_it {
361 debug!("Model declares token_type_ids input — will include in inference");
362 } else {
363 debug!("Model does not declare token_type_ids — omitting from inference");
364 }
365 has_it
366 });
367
368 let is_available = tokenizer.is_some() && session.is_some();
369 if is_available {
370 info!("✓ NER is fully operational with ONNX Runtime");
371 } else {
372 info!("⚠ NER not available - using pattern-based detection (36+ entity types)");
373 if tokenizer.is_none() {
374 debug!(" Missing: tokenizer");
375 }
376 if session.is_none() {
377 debug!(" Missing: ONNX model");
378 }
379 }
380
381 Ok(Self {
382 config,
383 tokenizer,
384 session,
385 needs_token_type_ids,
386 })
387 }
388
389 pub fn config(&self) -> &NerConfig {
391 &self.config
392 }
393
394 pub fn is_available(&self) -> bool {
396 self.tokenizer.is_some() && self.session.is_some()
397 }
398
399 fn map_label_to_entity(&self, label: &str) -> Option<EntityType> {
401 self.config.label_mappings.get(label).cloned()
402 }
403
404 fn infer(&self, input_ids: &[u32], attention_mask: &[u32]) -> Result<Vec<Vec<f32>>> {
406 let session_mutex = self
407 .session
408 .as_ref()
409 .ok_or_else(|| anyhow!("ONNX session not loaded"))?;
410
411 let mut session = session_mutex
412 .lock()
413 .map_err(|e| anyhow!("Failed to lock session: {}", e))?;
414
415 let seq_len = input_ids.len();
417 let input_ids_i64: Vec<i64> = input_ids.iter().map(|&x| x as i64).collect();
418 let attention_mask_i64: Vec<i64> = attention_mask.iter().map(|&x| x as i64).collect();
419
420 let input_ids_value = Value::from_array(([1, seq_len], input_ids_i64))?;
421 let attention_mask_value = Value::from_array(([1, seq_len], attention_mask_i64))?;
422
423 let mut inputs: Vec<(std::borrow::Cow<'_, str>, Value)> = vec![
426 ("input_ids".into(), input_ids_value.into()),
427 ("attention_mask".into(), attention_mask_value.into()),
428 ];
429
430 if self.needs_token_type_ids {
431 let token_type_ids_i64: Vec<i64> = vec![0i64; seq_len];
432 let token_type_ids_value = Value::from_array(([1, seq_len], token_type_ids_i64))?;
433 inputs.push(("token_type_ids".into(), token_type_ids_value.into()));
434 }
435
436 let outputs = session.run(inputs)?;
437
438 let (shape, logits_data) = outputs["logits"].try_extract_tensor::<f32>()?;
440 let shape_dims: &[i64] = shape.as_ref();
441
442 if shape_dims.len() != 3 || shape_dims[0] != 1 {
443 return Err(anyhow!("Unexpected logits shape: {:?}", shape_dims));
444 }
445
446 let seq_len_out = shape_dims[1] as usize;
447 let num_labels = shape_dims[2] as usize;
448
449 let mut result = Vec::new();
451 for i in 0..seq_len_out {
452 let mut token_logits = Vec::new();
453 for j in 0..num_labels {
454 let idx = i * num_labels + j;
455 token_logits.push(logits_data[idx]);
456 }
457 result.push(token_logits);
458 }
459
460 Ok(result)
461 }
462
463 fn softmax(logits: &[f32]) -> Vec<f32> {
465 let max_logit = logits.iter().copied().fold(f32::NEG_INFINITY, f32::max);
466 let exp_sum: f32 = logits.iter().map(|&x| (x - max_logit).exp()).sum();
467 logits
468 .iter()
469 .map(|&x| (x - max_logit).exp() / exp_sum)
470 .collect()
471 }
472
473 fn parse_bio_tags(
475 &self,
476 _text: &str,
477 predictions: &[usize],
478 probabilities: &[f32],
479 offsets: &[(usize, usize)],
480 ) -> Vec<RecognizerResult> {
481 let mut results = Vec::new();
482 let mut current_entity: Option<(EntityType, usize, usize, Vec<f32>)> = None;
483
484 for (idx, (&pred_id, &prob)) in predictions.iter().zip(probabilities.iter()).enumerate() {
485 if offsets[idx] == (0, 0) {
487 continue;
488 }
489
490 let label = self
491 .config
492 .id2label
493 .get(&pred_id)
494 .map(|s| s.as_str())
495 .unwrap_or("O");
496
497 if label.starts_with("B-") {
498 if let Some((entity_type, start, end, probs)) = current_entity.take() {
500 let avg_confidence = probs.iter().sum::<f32>() / probs.len() as f32;
501 if avg_confidence >= self.config.min_confidence {
502 results.push(RecognizerResult::new(
503 entity_type,
504 start,
505 end,
506 avg_confidence,
507 self.name(),
508 ));
509 }
510 }
511
512 if let Some(entity_type) = self.map_label_to_entity(label) {
514 let start = offsets[idx].0;
515 let end = offsets[idx].1;
516 current_entity = Some((entity_type, start, end, vec![prob]));
517 }
518 } else if label.starts_with("I-") {
519 if let Some((ref entity_type, start, ref mut end, ref mut probs)) = current_entity {
521 if let Some(label_entity) = self.map_label_to_entity(label) {
523 if label_entity == *entity_type {
524 *end = offsets[idx].1;
525 probs.push(prob);
526 } else {
527 let avg_confidence = probs.iter().sum::<f32>() / probs.len() as f32;
529 if avg_confidence >= self.config.min_confidence {
530 results.push(RecognizerResult::new(
531 entity_type.clone(),
532 start,
533 *end,
534 avg_confidence,
535 self.name(),
536 ));
537 }
538 current_entity = None;
539 }
540 }
541 }
542 } else {
543 if let Some((entity_type, start, end, probs)) = current_entity.take() {
545 let avg_confidence = probs.iter().sum::<f32>() / probs.len() as f32;
546 if avg_confidence >= self.config.min_confidence {
547 results.push(RecognizerResult::new(
548 entity_type,
549 start,
550 end,
551 avg_confidence,
552 self.name(),
553 ));
554 }
555 }
556 }
557 }
558
559 if let Some((entity_type, start, end, probs)) = current_entity {
561 let avg_confidence = probs.iter().sum::<f32>() / probs.len() as f32;
562 if avg_confidence >= self.config.min_confidence {
563 results.push(RecognizerResult::new(
564 entity_type,
565 start,
566 end,
567 avg_confidence,
568 self.name(),
569 ));
570 }
571 }
572
573 results
574 }
575}
576
577impl Recognizer for NerRecognizer {
578 fn name(&self) -> &str {
579 "NerRecognizer"
580 }
581
582 fn supported_entities(&self) -> &[EntityType] {
583 &[
584 EntityType::Person,
585 EntityType::Organization,
586 EntityType::Location,
587 EntityType::DateTime,
588 ]
589 }
590
591 fn analyze(&self, text: &str, _language: &str) -> Result<Vec<RecognizerResult>> {
592 if !self.is_available() {
594 return Ok(vec![]);
595 }
596
597 let tokenizer = self.tokenizer.as_ref().unwrap();
598
599 let mut encoding = tokenizer.encode(text, true)?;
601
602 let pad_id = tokenizer.get_padding_id().unwrap_or(0);
604
605 encoding.pad_to_length(self.config.max_seq_length, pad_id);
607
608 let logits = self.infer(&encoding.ids, &encoding.attention_mask)?;
610
611 let mut predictions = Vec::new();
613 let mut probabilities = Vec::new();
614
615 for token_logits in &logits {
616 let probs = Self::softmax(token_logits);
617 let (pred_id, &max_prob) = probs
618 .iter()
619 .enumerate()
620 .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
621 .unwrap();
622 predictions.push(pred_id);
623 probabilities.push(max_prob);
624 }
625
626 let results = self.parse_bio_tags(text, &predictions, &probabilities, &encoding.offsets);
628
629 Ok(results)
630 }
631
632 fn supports_language(&self, language: &str) -> bool {
633 matches!(
635 language,
636 "en" | "es" | "fr" | "de" | "it" | "pt" | "nl" | "pl" | "ru" | "zh" | "ja" | "ko"
637 )
638 }
639}
640
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Write;

    /// Builds a recognizer from the default configuration (no model path).
    fn unconfigured_recognizer() -> NerRecognizer {
        NerRecognizer::from_config(NerConfig::default()).unwrap()
    }

    /// Writes `contents` to a new temporary file and returns its handle.
    fn write_temp_config(contents: &str) -> tempfile::NamedTempFile {
        let mut file = tempfile::NamedTempFile::new().unwrap();
        file.write_all(contents.as_bytes()).unwrap();
        file.flush().unwrap();
        file
    }

    /// Parses `json` as a config file, using `model_path` as the model location.
    fn load_from_json(json: &str, model_path: &str) -> Result<NerConfig> {
        let file = write_temp_config(json);
        NerRecognizer::load_config_from_file(file.path(), model_path)
    }

    #[test]
    fn test_default_config() {
        let defaults = NerConfig::default();
        assert_eq!(defaults.min_confidence, 0.7);
        assert_eq!(defaults.max_seq_length, 512);
        assert!(!defaults.label_mappings.is_empty());
    }

    #[test]
    fn test_label_mapping() {
        let recognizer = unconfigured_recognizer();

        let person = recognizer.map_label_to_entity("B-PER");
        assert_eq!(person, Some(EntityType::Person));

        let organization = recognizer.map_label_to_entity("B-ORG");
        assert_eq!(organization, Some(EntityType::Organization));

        assert_eq!(recognizer.map_label_to_entity("O"), None);
    }

    #[test]
    fn test_recognizer_without_model() {
        let recognizer = unconfigured_recognizer();

        assert!(!recognizer.is_available());

        // Without a model, analysis succeeds and finds nothing.
        let found = recognizer.analyze("John Doe", "en").unwrap();
        assert!(found.is_empty());
    }

    #[test]
    fn test_recognizer_without_model_has_no_token_type_ids() {
        let recognizer = unconfigured_recognizer();

        assert!(!recognizer.needs_token_type_ids);
    }

    #[test]
    fn test_load_config_valid_with_both_id2label_and_label_mappings() {
        let json = r#"{
            "id2label": {
                "0": "O",
                "1": "B-MISC",
                "2": "I-MISC",
                "3": "B-PER",
                "4": "I-PER",
                "5": "B-ORG",
                "6": "I-ORG",
                "7": "B-LOC",
                "8": "I-LOC"
            },
            "label_mappings": {
                "B-PER": "Person",
                "I-PER": "Person",
                "B-ORG": "Organization",
                "I-ORG": "Organization",
                "B-LOC": "Location",
                "I-LOC": "Location"
            },
            "min_confidence": 0.8,
            "max_seq_length": 256,
            "tokenizer_path": "/build/time/tokenizer.json"
        }"#;

        let loaded = load_from_json(json, "/runtime/model.onnx").unwrap();

        // id2label comes straight from the file.
        assert_eq!(loaded.id2label.len(), 9);
        assert_eq!(loaded.id2label[&3], "B-PER");
        assert_eq!(loaded.id2label[&5], "B-ORG");

        // The explicit label_mappings table is used as given.
        assert_eq!(loaded.label_mappings.len(), 6);
        assert_eq!(loaded.label_mappings["B-PER"], EntityType::Person);
        assert_eq!(loaded.label_mappings["B-ORG"], EntityType::Organization);
        assert_eq!(loaded.label_mappings["B-LOC"], EntityType::Location);

        // Scalar settings are read from the file.
        assert_eq!(loaded.min_confidence, 0.8);
        assert_eq!(loaded.max_seq_length, 256);

        // The model path is the argument, not anything in the file.
        assert_eq!(loaded.model_path, "/runtime/model.onnx");

        // The file's tokenizer_path is ignored.
        assert!(loaded.tokenizer_path.is_none());
    }

    #[test]
    fn test_load_config_fallback_derives_label_mappings_from_id2label() {
        let json = r#"{
            "id2label": {
                "0": "O",
                "1": "B-MISC",
                "2": "I-MISC",
                "3": "B-PER",
                "4": "I-PER",
                "5": "B-ORG",
                "6": "I-ORG",
                "7": "B-LOC",
                "8": "I-LOC"
            }
        }"#;

        let loaded = load_from_json(json, "/m.onnx").unwrap();
        let mappings = &loaded.label_mappings;

        // Recognized suffixes are mapped.
        assert_eq!(mappings.get("B-PER"), Some(&EntityType::Person));
        assert_eq!(mappings.get("I-PER"), Some(&EntityType::Person));
        assert_eq!(mappings.get("B-ORG"), Some(&EntityType::Organization));
        assert_eq!(mappings.get("B-LOC"), Some(&EntityType::Location));

        // MISC has no entity type, so it stays unmapped.
        assert!(!mappings.contains_key("B-MISC"));
        assert!(!mappings.contains_key("I-MISC"));

        // Labels absent from id2label are not added from the defaults.
        assert!(!mappings.contains_key("B-DATE"));
        assert!(!mappings.contains_key("I-DATE"));
    }

    #[test]
    fn test_load_config_tokenizer_path_always_none() {
        let json = r#"{
            "tokenizer_path": "/out/models/tokenizer.json",
            "id2label": { "0": "O", "1": "B-PER" }
        }"#;

        let loaded = load_from_json(json, "/m.onnx").unwrap();
        assert!(loaded.tokenizer_path.is_none());
    }

    #[test]
    fn test_load_config_malformed_json_returns_err() {
        let outcome = load_from_json("{ this is not valid json }}}", "/m.onnx");
        assert!(outcome.is_err());
    }

    #[test]
    fn test_load_config_empty_json_uses_defaults() {
        let loaded = load_from_json("{}", "/m.onnx").unwrap();

        let defaults = NerConfig::default();
        assert_eq!(loaded.min_confidence, defaults.min_confidence);
        assert_eq!(loaded.max_seq_length, defaults.max_seq_length);
        assert_eq!(loaded.id2label.len(), defaults.id2label.len());
    }
}