1use anyhow::{anyhow, Result};
7use ort::session::builder::GraphOptimizationLevel;
8use ort::session::Session;
9use ort::value::Value;
10use redact_core::{EntityType, Recognizer, RecognizerResult};
11use serde::{Deserialize, Serialize};
12use std::collections::HashMap;
13use std::path::Path;
14use std::sync::Mutex;
15use tracing::{debug, info, warn};
16
17use crate::tokenizer_wrapper::TokenizerWrapper;
18
19#[derive(Debug, Clone, Serialize, Deserialize)]
21pub struct NerConfig {
22 pub model_path: String,
24
25 #[serde(skip_serializing_if = "Option::is_none")]
27 pub tokenizer_path: Option<String>,
28
29 #[serde(default = "default_confidence")]
31 pub min_confidence: f32,
32
33 #[serde(default = "default_max_length")]
35 pub max_seq_length: usize,
36
37 #[serde(default)]
39 pub label_mappings: HashMap<String, EntityType>,
40
41 #[serde(default)]
43 pub id2label: HashMap<usize, String>,
44}
45
/// Default minimum mean token probability an entity span must reach to be reported.
fn default_confidence() -> f32 {
    const DEFAULT_MIN_CONFIDENCE: f32 = 0.7;
    DEFAULT_MIN_CONFIDENCE
}
49
/// Default length, in tokens, that inputs are padded to before inference.
fn default_max_length() -> usize {
    const DEFAULT_MAX_SEQ_LENGTH: usize = 512;
    DEFAULT_MAX_SEQ_LENGTH
}
53
54impl Default for NerConfig {
55 fn default() -> Self {
56 let mut label_mappings = HashMap::new();
57 let mut id2label = HashMap::new();
58
59 label_mappings.insert("B-PER".to_string(), EntityType::Person);
61 label_mappings.insert("I-PER".to_string(), EntityType::Person);
62 label_mappings.insert("B-ORG".to_string(), EntityType::Organization);
63 label_mappings.insert("I-ORG".to_string(), EntityType::Organization);
64 label_mappings.insert("B-LOC".to_string(), EntityType::Location);
65 label_mappings.insert("I-LOC".to_string(), EntityType::Location);
66 label_mappings.insert("B-DATE".to_string(), EntityType::DateTime);
67 label_mappings.insert("I-DATE".to_string(), EntityType::DateTime);
68 label_mappings.insert("B-TIME".to_string(), EntityType::DateTime);
69 label_mappings.insert("I-TIME".to_string(), EntityType::DateTime);
70
71 id2label.insert(0, "O".to_string());
73 id2label.insert(1, "B-PER".to_string());
74 id2label.insert(2, "I-PER".to_string());
75 id2label.insert(3, "B-ORG".to_string());
76 id2label.insert(4, "I-ORG".to_string());
77 id2label.insert(5, "B-LOC".to_string());
78 id2label.insert(6, "I-LOC".to_string());
79 id2label.insert(7, "B-MISC".to_string());
80 id2label.insert(8, "I-MISC".to_string());
81
82 Self {
83 model_path: String::new(),
84 tokenizer_path: None,
85 min_confidence: default_confidence(),
86 max_seq_length: default_max_length(),
87 label_mappings,
88 id2label,
89 }
90 }
91}
92
/// Named-entity recognizer backed by an ONNX token-classification model.
///
/// NER is only active when both the tokenizer and the ONNX session loaded
/// successfully. Otherwise `analyze` returns no results.
pub struct NerRecognizer {
    /// Model/tokenizer paths, thresholds and label tables.
    config: NerConfig,
    /// Tokenizer, or `None` if none was found or it failed to load.
    tokenizer: Option<TokenizerWrapper>,
    /// ONNX Runtime session, or `None` if the model was missing or failed to load.
    /// Kept behind a `Mutex` because `infer` locks it to obtain the mutable guard
    /// it runs the session through while holding only `&self`.
    session: Option<Mutex<Session>>,
}
116
117impl std::fmt::Debug for NerRecognizer {
118 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
119 f.debug_struct("NerRecognizer")
120 .field("config", &self.config)
121 .field("tokenizer", &self.tokenizer)
122 .field("session", &self.session.as_ref().map(|_| "Session"))
123 .finish()
124 }
125}
126
127impl NerRecognizer {
128 pub fn from_file<P: AsRef<Path>>(model_path: P) -> Result<Self> {
130 let config = NerConfig {
131 model_path: model_path.as_ref().to_string_lossy().to_string(),
132 ..Default::default()
133 };
134 Self::from_config(config)
135 }
136
137 pub fn from_config(config: NerConfig) -> Result<Self> {
139 let tokenizer = if let Some(ref tokenizer_path) = config.tokenizer_path {
141 debug!("Loading tokenizer from: {}", tokenizer_path);
142 match TokenizerWrapper::from_file(tokenizer_path) {
143 Ok(t) => {
144 info!("✓ Tokenizer loaded successfully from: {}", tokenizer_path);
145 Some(t)
146 }
147 Err(e) => {
148 warn!(
149 "Failed to load tokenizer: {}. NER will not be available.",
150 e
151 );
152 None
153 }
154 }
155 } else if !config.model_path.is_empty() {
156 let model_dir = Path::new(&config.model_path).parent();
158 if let Some(dir) = model_dir {
159 let tokenizer_json = dir.join("tokenizer.json");
160 if tokenizer_json.exists() {
161 debug!("Loading tokenizer from: {}", tokenizer_json.display());
162 match TokenizerWrapper::from_file(&tokenizer_json) {
163 Ok(t) => {
164 info!("✓ Tokenizer loaded successfully from model directory");
165 Some(t)
166 }
167 Err(e) => {
168 warn!("Failed to load tokenizer from model directory: {}", e);
169 None
170 }
171 }
172 } else {
173 debug!("No tokenizer.json found in model directory");
174 None
175 }
176 } else {
177 None
178 }
179 } else {
180 None
181 };
182
183 let session = if !config.model_path.is_empty() {
185 let model_path = Path::new(&config.model_path);
186 if model_path.exists() {
187 debug!("Loading ONNX model from: {}", config.model_path);
188 match Session::builder()?
189 .with_optimization_level(GraphOptimizationLevel::Level3)?
190 .with_intra_threads(4)?
191 .commit_from_file(&config.model_path)
192 {
193 Ok(s) => {
194 info!("✓ ONNX model loaded successfully: {}", config.model_path);
195 Some(Mutex::new(s))
196 }
197 Err(e) => {
198 warn!(
199 "Failed to load ONNX model: {}. NER will not be available.",
200 e
201 );
202 None
203 }
204 }
205 } else {
206 debug!(
207 "Model path provided but file does not exist: {}",
208 config.model_path
209 );
210 None
211 }
212 } else {
213 debug!("No model path provided, NER will not be available");
214 None
215 };
216
217 let is_available = tokenizer.is_some() && session.is_some();
218 if is_available {
219 info!("✓ NER is fully operational with ONNX Runtime");
220 } else {
221 info!("⚠ NER not available - using pattern-based detection (36+ entity types)");
222 if tokenizer.is_none() {
223 debug!(" Missing: tokenizer");
224 }
225 if session.is_none() {
226 debug!(" Missing: ONNX model");
227 }
228 }
229
230 Ok(Self {
231 config,
232 tokenizer,
233 session,
234 })
235 }
236
237 pub fn config(&self) -> &NerConfig {
239 &self.config
240 }
241
242 pub fn is_available(&self) -> bool {
244 self.tokenizer.is_some() && self.session.is_some()
245 }
246
247 fn map_label_to_entity(&self, label: &str) -> Option<EntityType> {
249 self.config.label_mappings.get(label).cloned()
250 }
251
252 fn infer(&self, input_ids: &[u32], attention_mask: &[u32]) -> Result<Vec<Vec<f32>>> {
254 let session_mutex = self
255 .session
256 .as_ref()
257 .ok_or_else(|| anyhow!("ONNX session not loaded"))?;
258
259 let mut session = session_mutex
260 .lock()
261 .map_err(|e| anyhow!("Failed to lock session: {}", e))?;
262
263 let seq_len = input_ids.len();
265 let input_ids_i64: Vec<i64> = input_ids.iter().map(|&x| x as i64).collect();
266 let attention_mask_i64: Vec<i64> = attention_mask.iter().map(|&x| x as i64).collect();
267
268 let input_ids_value = Value::from_array(([1, seq_len], input_ids_i64))?;
270 let attention_mask_value = Value::from_array(([1, seq_len], attention_mask_i64))?;
271
272 let outputs = session.run(ort::inputs![
274 "input_ids" => input_ids_value,
275 "attention_mask" => attention_mask_value,
276 ])?;
277
278 let (shape, logits_data) = outputs["logits"].try_extract_tensor::<f32>()?;
280 let shape_dims = shape.as_ref();
281
282 if shape_dims.len() != 3 || shape_dims[0] != 1 {
283 return Err(anyhow!("Unexpected logits shape: {:?}", shape_dims));
284 }
285
286 let seq_len_out = shape_dims[1] as usize;
287 let num_labels = shape_dims[2] as usize;
288
289 let mut result = Vec::new();
291 for i in 0..seq_len_out {
292 let mut token_logits = Vec::new();
293 for j in 0..num_labels {
294 let idx = i * num_labels + j;
295 token_logits.push(logits_data[idx]);
296 }
297 result.push(token_logits);
298 }
299
300 Ok(result)
301 }
302
303 fn softmax(logits: &[f32]) -> Vec<f32> {
305 let max_logit = logits.iter().copied().fold(f32::NEG_INFINITY, f32::max);
306 let exp_sum: f32 = logits.iter().map(|&x| (x - max_logit).exp()).sum();
307 logits
308 .iter()
309 .map(|&x| (x - max_logit).exp() / exp_sum)
310 .collect()
311 }
312
313 fn parse_bio_tags(
315 &self,
316 _text: &str,
317 predictions: &[usize],
318 probabilities: &[f32],
319 offsets: &[(usize, usize)],
320 ) -> Vec<RecognizerResult> {
321 let mut results = Vec::new();
322 let mut current_entity: Option<(EntityType, usize, usize, Vec<f32>)> = None;
323
324 for (idx, (&pred_id, &prob)) in predictions.iter().zip(probabilities.iter()).enumerate() {
325 if offsets[idx] == (0, 0) {
327 continue;
328 }
329
330 let label = self
331 .config
332 .id2label
333 .get(&pred_id)
334 .map(|s| s.as_str())
335 .unwrap_or("O");
336
337 if label.starts_with("B-") {
338 if let Some((entity_type, start, end, probs)) = current_entity.take() {
340 let avg_confidence = probs.iter().sum::<f32>() / probs.len() as f32;
341 if avg_confidence >= self.config.min_confidence {
342 results.push(RecognizerResult::new(
343 entity_type,
344 start,
345 end,
346 avg_confidence,
347 self.name(),
348 ));
349 }
350 }
351
352 if let Some(entity_type) = self.map_label_to_entity(label) {
354 let start = offsets[idx].0;
355 let end = offsets[idx].1;
356 current_entity = Some((entity_type, start, end, vec![prob]));
357 }
358 } else if label.starts_with("I-") {
359 if let Some((ref entity_type, start, ref mut end, ref mut probs)) = current_entity {
361 if let Some(label_entity) = self.map_label_to_entity(label) {
363 if label_entity == *entity_type {
364 *end = offsets[idx].1;
365 probs.push(prob);
366 } else {
367 let avg_confidence = probs.iter().sum::<f32>() / probs.len() as f32;
369 if avg_confidence >= self.config.min_confidence {
370 results.push(RecognizerResult::new(
371 entity_type.clone(),
372 start,
373 *end,
374 avg_confidence,
375 self.name(),
376 ));
377 }
378 current_entity = None;
379 }
380 }
381 }
382 } else {
383 if let Some((entity_type, start, end, probs)) = current_entity.take() {
385 let avg_confidence = probs.iter().sum::<f32>() / probs.len() as f32;
386 if avg_confidence >= self.config.min_confidence {
387 results.push(RecognizerResult::new(
388 entity_type,
389 start,
390 end,
391 avg_confidence,
392 self.name(),
393 ));
394 }
395 }
396 }
397 }
398
399 if let Some((entity_type, start, end, probs)) = current_entity {
401 let avg_confidence = probs.iter().sum::<f32>() / probs.len() as f32;
402 if avg_confidence >= self.config.min_confidence {
403 results.push(RecognizerResult::new(
404 entity_type,
405 start,
406 end,
407 avg_confidence,
408 self.name(),
409 ));
410 }
411 }
412
413 results
414 }
415}
416
417impl Recognizer for NerRecognizer {
418 fn name(&self) -> &str {
419 "NerRecognizer"
420 }
421
422 fn supported_entities(&self) -> &[EntityType] {
423 &[
424 EntityType::Person,
425 EntityType::Organization,
426 EntityType::Location,
427 EntityType::DateTime,
428 ]
429 }
430
431 fn analyze(&self, text: &str, _language: &str) -> Result<Vec<RecognizerResult>> {
432 if !self.is_available() {
434 return Ok(vec![]);
435 }
436
437 let tokenizer = self.tokenizer.as_ref().unwrap();
438
439 let mut encoding = tokenizer.encode(text, true)?;
441
442 let pad_id = tokenizer.get_padding_id().unwrap_or(0);
444
445 encoding.pad_to_length(self.config.max_seq_length, pad_id);
447
448 let logits = self.infer(&encoding.ids, &encoding.attention_mask)?;
450
451 let mut predictions = Vec::new();
453 let mut probabilities = Vec::new();
454
455 for token_logits in &logits {
456 let probs = Self::softmax(token_logits);
457 let (pred_id, &max_prob) = probs
458 .iter()
459 .enumerate()
460 .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
461 .unwrap();
462 predictions.push(pred_id);
463 probabilities.push(max_prob);
464 }
465
466 let results = self.parse_bio_tags(text, &predictions, &probabilities, &encoding.offsets);
468
469 Ok(results)
470 }
471
472 fn supports_language(&self, language: &str) -> bool {
473 matches!(
475 language,
476 "en" | "es" | "fr" | "de" | "it" | "pt" | "nl" | "pl" | "ru" | "zh" | "ja" | "ko"
477 )
478 }
479}
480
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a recognizer from the default config, which has no model path,
    /// so no tokenizer or session is loaded.
    fn recognizer_without_model() -> NerRecognizer {
        NerRecognizer::from_config(NerConfig::default())
            .expect("default config has no model path, so construction cannot fail")
    }

    #[test]
    fn test_default_config() {
        let NerConfig {
            min_confidence,
            max_seq_length,
            label_mappings,
            ..
        } = NerConfig::default();

        assert_eq!(min_confidence, 0.7);
        assert_eq!(max_seq_length, 512);
        assert!(!label_mappings.is_empty());
    }

    #[test]
    fn test_label_mapping() {
        let recognizer = recognizer_without_model();

        let cases = [
            ("B-PER", Some(EntityType::Person)),
            ("B-ORG", Some(EntityType::Organization)),
            ("O", None),
        ];
        for (label, expected) in cases {
            assert_eq!(recognizer.map_label_to_entity(label), expected);
        }
    }

    #[test]
    fn test_recognizer_without_model() {
        let recognizer = recognizer_without_model();
        assert!(!recognizer.is_available());

        let results = recognizer.analyze("John Doe", "en").unwrap();
        assert!(results.is_empty());
    }
}