1#![allow(missing_docs)] #![allow(dead_code)] #![allow(clippy::manual_strip)] use crate::{Entity, Error, Result};
22#[cfg(feature = "onnx")]
23use anno_core::EntityType;
24
25#[cfg(feature = "onnx")]
26use {
27 crate::sync::lock,
28 hf_hub::api::sync::Api,
29 ndarray::Array2,
30 ort::{session::builder::GraphOptimizationLevel, session::Session},
31 std::collections::HashMap,
32 tokenizers::Tokenizer,
33};
34
/// Default HuggingFace Hub repository used when no model name is supplied.
pub const DEFAULT_BERT_NER_MODEL: &str = "protectai/bert-base-NER-onnx";
37
/// Configuration for loading a BERT NER ONNX model.
#[cfg(feature = "onnx")]
#[derive(Debug, Clone)]
pub struct BertNERConfig {
    /// Prefer an INT8-quantized ONNX file when the repository provides one.
    pub prefer_quantized: bool,
    /// ONNX graph optimization level; 1 and 2 map to Level1/Level2, any other
    /// value maps to Level3 (see `with_config`).
    pub optimization_level: u8,
    /// Intra-op thread count for ONNX Runtime; 0 leaves the runtime default.
    pub num_threads: usize,
}
49
#[cfg(feature = "onnx")]
impl Default for BertNERConfig {
    /// Defaults: quantized model preferred, maximum graph optimization,
    /// 4 intra-op threads.
    fn default() -> Self {
        Self {
            prefer_quantized: true,
            optimization_level: 3,
            num_threads: 4,
        }
    }
}
60
/// BERT token-classification NER model executed via ONNX Runtime.
#[cfg(feature = "onnx")]
pub struct BertNEROnnx {
    // Session runs require `&mut`; serialized behind the crate's mutex.
    session: crate::sync::Mutex<Session>,
    // Shared so callers can tokenize without holding the session lock.
    tokenizer: std::sync::Arc<Tokenizer>,
    // Class index -> label string (e.g. 3 -> "B-PER"), read from config.json.
    id_to_label: HashMap<usize, String>,
    // Label string -> crate entity type; covers both "B-PER" and bare "PER".
    label_to_entity_type: HashMap<String, EntityType>,
    // HuggingFace repository the model was loaded from.
    model_name: String,
    // True when a quantized (INT8) ONNX file was selected at load time.
    is_quantized: bool,
}
76
77#[cfg(feature = "onnx")]
78impl BertNEROnnx {
    /// Builds a model from a HuggingFace repo name using [`BertNERConfig::default`].
    ///
    /// # Errors
    /// Propagates download/initialization failures from [`Self::with_config`].
    pub fn new(model_name: &str) -> Result<Self> {
        Self::with_config(model_name, BertNERConfig::default())
    }
89
    /// Downloads the model, tokenizer, and config from the HuggingFace Hub and
    /// builds an ONNX Runtime session according to `config`.
    ///
    /// When `config.prefer_quantized` is set, several known quantized file names
    /// are tried (in order) before falling back to the FP32 `model.onnx`.
    ///
    /// # Errors
    /// Returns [`Error::Retrieval`] for download or session-construction
    /// failures, and [`Error::Parse`] when `config.json` is malformed.
    pub fn with_config(model_name: &str, config: BertNERConfig) -> Result<Self> {
        let api = Api::new().map_err(|e| {
            Error::Retrieval(format!("Failed to initialize HuggingFace API: {}", e))
        })?;

        let repo = api.model(model_name.to_string());

        // Resolve the ONNX file to load, recording whether it is quantized.
        let (model_path, is_quantized) = if config.prefer_quantized {
            // Try the common quantized file layouts first.
            if let Ok(path) = repo.get("model_quantized.onnx") {
                log::info!("[BERT-NER] Using quantized model (INT8)");
                (path, true)
            } else if let Ok(path) = repo.get("onnx/model_quantized.onnx") {
                log::info!("[BERT-NER] Using quantized model (INT8)");
                (path, true)
            } else if let Ok(path) = repo.get("model_int8.onnx") {
                log::info!("[BERT-NER] Using INT8 quantized model");
                (path, true)
            } else {
                // No quantized artifact available — fall back to FP32.
                let path = repo
                    .get("model.onnx")
                    .or_else(|_| repo.get("onnx/model.onnx"))
                    .map_err(|e| {
                        Error::Retrieval(format!("Failed to download model.onnx: {}", e))
                    })?;
                log::info!("[BERT-NER] Using FP32 model (quantized not available)");
                (path, false)
            }
        } else {
            let path = repo
                .get("model.onnx")
                .or_else(|_| repo.get("onnx/model.onnx"))
                .map_err(|e| Error::Retrieval(format!("Failed to download model.onnx: {}", e)))?;
            (path, false)
        };

        let tokenizer_path = repo
            .get("tokenizer.json")
            .map_err(|e| Error::Retrieval(format!("Failed to download tokenizer.json: {}", e)))?;

        let config_path = repo
            .get("config.json")
            .map_err(|e| Error::Retrieval(format!("Failed to download config.json: {}", e)))?;

        let tokenizer = Tokenizer::from_file(&tokenizer_path)
            .map_err(|e| Error::Retrieval(format!("Failed to load tokenizer: {}", e)))?;

        // config.json supplies the id2label mapping used at decode time.
        let config_str = std::fs::read_to_string(&config_path)
            .map_err(|e| Error::Retrieval(format!("Failed to read config.json: {}", e)))?;
        let config_json: serde_json::Value = serde_json::from_str(&config_str)
            .map_err(|e| Error::Parse(format!("Failed to parse config.json: {}", e)))?;

        let id_to_label = Self::build_id_to_label(&config_json);
        let label_to_entity_type = Self::build_label_to_entity_type();

        // Any value other than 1 or 2 (including the default 3) means Level3.
        let opt_level = match config.optimization_level {
            1 => GraphOptimizationLevel::Level1,
            2 => GraphOptimizationLevel::Level2,
            _ => GraphOptimizationLevel::Level3,
        };

        let mut builder = Session::builder()
            .map_err(|e| Error::Retrieval(format!("Failed to create session builder: {}", e)))?
            .with_optimization_level(opt_level)
            .map_err(|e| Error::Retrieval(format!("Failed to set optimization level: {}", e)))?;

        // 0 means "leave ONNX Runtime's default thread count".
        if config.num_threads > 0 {
            builder = builder
                .with_intra_threads(config.num_threads)
                .map_err(|e| Error::Retrieval(format!("Failed to set threads: {}", e)))?;
        }

        let session = builder
            .commit_from_file(&model_path)
            .map_err(|e| Error::Retrieval(format!("Failed to load ONNX model: {}", e)))?;

        Ok(Self {
            session: crate::sync::Mutex::new(session),
            tokenizer: std::sync::Arc::new(tokenizer),
            id_to_label,
            label_to_entity_type,
            model_name: model_name.to_string(),
            is_quantized,
        })
    }
198
    /// Returns `true` when an INT8-quantized ONNX file was loaded.
    #[must_use]
    pub fn is_quantized(&self) -> bool {
        self.is_quantized
    }
204
    /// Returns a cheap (`Arc`-cloned) handle to the loaded tokenizer.
    #[must_use]
    pub fn tokenizer(&self) -> std::sync::Arc<Tokenizer> {
        std::sync::Arc::clone(&self.tokenizer)
    }
210
211 fn build_id_to_label(config_json: &serde_json::Value) -> HashMap<usize, String> {
213 let mut map = HashMap::new();
214 if let Some(id2label) = config_json.get("id2label") {
215 if let Some(obj) = id2label.as_object() {
216 for (id_str, label_value) in obj {
217 if let (Ok(id), Some(label)) = (id_str.parse::<usize>(), label_value.as_str()) {
218 map.insert(id, label.to_string());
219 }
220 }
221 }
222 }
223 if map.is_empty() {
225 map.insert(0, "O".to_string());
226 map.insert(1, "B-MISC".to_string());
227 map.insert(2, "I-MISC".to_string());
228 map.insert(3, "B-PER".to_string());
229 map.insert(4, "I-PER".to_string());
230 map.insert(5, "B-ORG".to_string());
231 map.insert(6, "I-ORG".to_string());
232 map.insert(7, "B-LOC".to_string());
233 map.insert(8, "I-LOC".to_string());
234 }
235 map
236 }
237
238 fn build_label_to_entity_type() -> HashMap<String, EntityType> {
240 let mut map = HashMap::new();
241 map.insert("B-PER".to_string(), EntityType::Person);
243 map.insert("I-PER".to_string(), EntityType::Person);
244 map.insert("B-ORG".to_string(), EntityType::Organization);
245 map.insert("I-ORG".to_string(), EntityType::Organization);
246 map.insert("B-LOC".to_string(), EntityType::Location);
247 map.insert("I-LOC".to_string(), EntityType::Location);
248 map.insert("B-MISC".to_string(), EntityType::Other("misc".to_string()));
249 map.insert("I-MISC".to_string(), EntityType::Other("misc".to_string()));
250 map.insert("PER".to_string(), EntityType::Person);
252 map.insert("ORG".to_string(), EntityType::Organization);
253 map.insert("LOC".to_string(), EntityType::Location);
254 map.insert("MISC".to_string(), EntityType::Other("misc".to_string()));
255 map
256 }
257
258 pub fn extract_entities(&self, text: &str, _language: Option<&str>) -> Result<Vec<Entity>> {
267 if text.is_empty() {
268 return Ok(vec![]);
269 }
270
271 let encoding = self
273 .tokenizer
274 .encode(text, true)
275 .map_err(|e| Error::Parse(format!("Failed to tokenize input: {}", e)))?;
276
277 let input_ids: Vec<i64> = encoding.get_ids().iter().map(|&id| id as i64).collect();
278 let attention_mask: Vec<i64> = encoding
279 .get_attention_mask()
280 .iter()
281 .map(|&mask| mask as i64)
282 .collect();
283 let token_type_ids: Vec<i64> = vec![0i64; input_ids.len()];
286
287 let batch_size = 1;
288 let seq_len = input_ids.len();
289
290 let input_ids_array: Array2<i64> =
292 Array2::from_shape_vec((batch_size, seq_len), input_ids.clone())
293 .map_err(|e| Error::Parse(format!("Failed to create input_ids array: {}", e)))?;
294
295 let attention_mask_array: Array2<i64> =
296 Array2::from_shape_vec((batch_size, seq_len), attention_mask.clone()).map_err(|e| {
297 Error::Parse(format!("Failed to create attention_mask array: {}", e))
298 })?;
299
300 let token_type_ids_array: Array2<i64> =
301 Array2::from_shape_vec((batch_size, seq_len), token_type_ids).map_err(|e| {
302 Error::Parse(format!("Failed to create token_type_ids array: {}", e))
303 })?;
304
305 let input_ids_tensor = super::ort_compat::tensor_from_ndarray(input_ids_array)
306 .map_err(|e| Error::Parse(format!("Failed to create input_ids tensor: {}", e)))?;
307
308 let attention_mask_tensor = super::ort_compat::tensor_from_ndarray(attention_mask_array)
309 .map_err(|e| Error::Parse(format!("Failed to create attention_mask tensor: {}", e)))?;
310
311 let token_type_ids_tensor = super::ort_compat::tensor_from_ndarray(token_type_ids_array)
312 .map_err(|e| Error::Parse(format!("Failed to create token_type_ids tensor: {}", e)))?;
313
314 let mut session = lock(&self.session);
316
317 let outputs = session
318 .run(ort::inputs![
319 "input_ids" => input_ids_tensor.into_dyn(),
320 "attention_mask" => attention_mask_tensor.into_dyn(),
321 "token_type_ids" => token_type_ids_tensor.into_dyn(),
322 ])
323 .map_err(|e| Error::Parse(format!("ONNX inference failed: {}", e)))?;
324
325 let logits = outputs.get("logits").ok_or_else(|| {
327 Error::Parse("ONNX model output does not contain 'logits' key".to_string())
328 })?;
329
330 self.decode_output(logits, text, &encoding)
332 }
333
334 fn decode_output(
336 &self,
337 output: &ort::value::DynValue,
338 text: &str,
339 encoding: &tokenizers::Encoding,
340 ) -> Result<Vec<Entity>> {
341 let (shape, logits_data) = output
343 .try_extract_tensor::<f32>()
344 .map_err(|e| Error::Parse(format!("Failed to extract logits tensor: {}", e)))?;
345
346 if shape.len() != 3 || shape[0] != 1 {
348 return Err(Error::Parse(format!(
349 "Unexpected logits shape: {:?}",
350 shape
351 )));
352 }
353
354 let seq_len = shape[1] as usize;
355 let num_labels = shape[2] as usize;
356
357 let offsets = encoding.get_offsets();
359
360 let span_converter = crate::offset::SpanConverter::new(text);
363
364 let get_logit = |token_idx: usize, label_idx: usize| -> f32 {
366 logits_data[token_idx * num_labels + label_idx]
367 };
368
369 let mut entities = Vec::with_capacity(16);
372 let mut current_entity: Option<(usize, usize, EntityType, f64)> = None; for token_idx in 0..seq_len {
375 if token_idx >= offsets.len() {
377 continue;
378 }
379 let (byte_start, byte_end) = offsets[token_idx];
380 if byte_start == byte_end {
381 if let Some((start, end, entity_type, conf)) = current_entity.take() {
383 if start < end && end <= text.len() {
384 if let Some(entity_text) = text.get(start..end) {
385 let entity_text = entity_text.trim();
386 if !entity_text.is_empty() {
387 entities.push(Entity::new(
388 entity_text.to_string(),
389 entity_type,
390 span_converter.byte_to_char(start),
391 span_converter.byte_to_char(end),
392 conf,
393 ));
394 }
395 }
396 }
397 }
398 continue;
399 }
400
401 let mut max_idx = 0;
403 let mut max_val = f32::NEG_INFINITY;
404 for label_idx in 0..num_labels {
405 let val = get_logit(token_idx, label_idx);
406 if val > max_val {
407 max_val = val;
408 max_idx = label_idx;
409 }
410 }
411
412 let exp_sum: f32 = (0..num_labels)
414 .map(|i| (get_logit(token_idx, i) - max_val).exp())
415 .sum();
416 let confidence = if exp_sum > 0.0 && num_labels > 0 {
418 (1.0_f32 / exp_sum) as f64 } else {
420 0.0 };
422
423 let label = self
424 .id_to_label
425 .get(&max_idx)
426 .cloned()
427 .unwrap_or_else(|| format!("LABEL_{}", max_idx));
428
429 if label == "O" {
431 if let Some((start, end, entity_type, conf)) = current_entity.take() {
432 if start < end && end <= text.len() {
433 if let Some(entity_text) = text.get(start..end) {
434 let entity_text = entity_text.trim();
435 if !entity_text.is_empty() {
436 entities.push(Entity::new(
437 entity_text.to_string(),
438 entity_type,
439 span_converter.byte_to_char(start),
440 span_converter.byte_to_char(end),
441 conf,
442 ));
443 }
444 }
445 }
446 }
447 continue;
448 }
449
450 let (bio, entity_label) = if label.starts_with("B-") {
452 ("B", label[2..].to_string())
453 } else if label.starts_with("I-") {
454 ("I", label[2..].to_string())
455 } else {
456 ("B", label.clone())
457 };
458
459 let entity_type = self
460 .label_to_entity_type
461 .get(&format!("B-{}", entity_label))
462 .or_else(|| self.label_to_entity_type.get(&entity_label))
463 .cloned()
464 .unwrap_or_else(|| EntityType::Other(entity_label.clone()));
465
466 match bio {
467 "B" => {
468 let should_merge = if let Some((_, prev_end, ref prev_type, _)) = current_entity
472 {
473 std::mem::discriminant(prev_type) == std::mem::discriminant(&entity_type)
475 && byte_start <= prev_end + 1 } else {
477 false
478 };
479
480 if should_merge {
481 if let Some((start, _, prev_type, conf)) = current_entity.take() {
483 current_entity = Some((start, byte_end, prev_type, conf));
484 }
485 } else {
486 if let Some((start, end, prev_type, conf)) = current_entity.take() {
488 if start < end && end <= text.len() {
489 if let Some(entity_text) = text.get(start..end) {
490 let entity_text = entity_text.trim();
491 if !entity_text.is_empty() {
492 entities.push(Entity::new(
493 entity_text.to_string(),
494 prev_type,
495 span_converter.byte_to_char(start),
496 span_converter.byte_to_char(end),
497 conf,
498 ));
499 }
500 }
501 }
502 }
503 current_entity = Some((byte_start, byte_end, entity_type, confidence));
505 }
506 }
507 "I" => {
508 if let Some((start, _end, ref prev_type, conf)) = current_entity {
510 if std::mem::discriminant(prev_type) == std::mem::discriminant(&entity_type)
511 {
512 current_entity = Some((start, byte_end, entity_type, conf));
513 } else {
514 if start < _end && _end <= text.len() {
516 if let Some(entity_text) = text.get(start.._end) {
517 let entity_text = entity_text.trim();
518 if !entity_text.is_empty() {
519 entities.push(Entity::new(
520 entity_text.to_string(),
521 prev_type.clone(),
522 span_converter.byte_to_char(start),
523 span_converter.byte_to_char(_end),
524 conf,
525 ));
526 }
527 }
528 }
529 current_entity = Some((byte_start, byte_end, entity_type, confidence));
530 }
531 } else {
532 current_entity = Some((byte_start, byte_end, entity_type, confidence));
534 }
535 }
536 _ => {}
537 }
538 }
539
540 if let Some((start, end, entity_type, conf)) = current_entity {
542 if start < end && end <= text.len() {
543 if let Some(entity_text) = text.get(start..end) {
544 let entity_text = entity_text.trim();
545 if !entity_text.is_empty() {
546 entities.push(Entity::new(
547 entity_text.to_string(),
548 entity_type,
549 span_converter.byte_to_char(start),
550 span_converter.byte_to_char(end),
551 conf,
552 ));
553 }
554 }
555 }
556 }
557
558 Ok(entities)
559 }
560
    /// The HuggingFace repository name this model was constructed from.
    pub fn model_name(&self) -> &str {
        &self.model_name
    }
565}
566
#[cfg(feature = "onnx")]
impl crate::Model for BertNEROnnx {
    fn extract_entities(&self, text: &str, language: Option<&str>) -> Result<Vec<Entity>> {
        // Delegates to the inherent method of the same name.
        self.extract_entities(text, language)
    }

    fn supported_types(&self) -> Vec<EntityType> {
        vec![
            EntityType::Person,
            EntityType::Organization,
            EntityType::Location,
            EntityType::Other("MISC".to_string()),
        ]
    }

    fn is_available(&self) -> bool {
        // If construction succeeded, the session is loaded and usable.
        true
    }

    fn name(&self) -> &'static str {
        "bert-onnx"
    }

    fn description(&self) -> &'static str {
        "BERT-based NER using ONNX Runtime (PER/ORG/LOC/MISC)"
    }

    fn version(&self) -> String {
        // e.g. "bert-onnx-protectai/bert-base-NER-onnx-q" for a quantized load.
        format!(
            "bert-onnx-{}-{}",
            self.model_name,
            if self.is_quantized { "q" } else { "fp32" }
        )
    }

    fn capabilities(&self) -> crate::ModelCapabilities {
        crate::ModelCapabilities {
            batch_capable: true,
            streaming_capable: true,
            ..Default::default()
        }
    }
}
610
// NOTE(review): deliberately not gated on the "onnx" feature — `BertNEROnnx`
// exists in both cfg branches, so this marker-style impl is valid either way.
// Confirm the trait has no required methods before adding a gate.
impl crate::NamedEntityCapable for BertNEROnnx {}
612
#[cfg(feature = "onnx")]
impl crate::BatchCapable for BertNEROnnx {
    /// Suggested batch size for callers that batch inputs.
    // NOTE(review): 8 appears to be a tuned heuristic — nothing in this file
    // derives it; confirm against benchmarking notes.
    fn optimal_batch_size(&self) -> Option<usize> {
        Some(8)
    }
}
623
#[cfg(feature = "onnx")]
impl crate::StreamingCapable for BertNEROnnx {
    /// Chunk size (in characters/tokens, per the trait's contract) for
    /// streaming callers.
    // NOTE(review): 512 presumably matches BERT's maximum sequence length —
    // confirm the unit expected by `StreamingCapable`.
    fn recommended_chunk_size(&self) -> usize {
        512
    }
}
634
/// Stub standing in for the ONNX-backed model when the "onnx" feature is
/// disabled; every operation fails with an explanatory error.
#[cfg(not(feature = "onnx"))]
pub struct BertNEROnnx;
638
639#[cfg(not(feature = "onnx"))]
640impl BertNEROnnx {
641 pub fn new(_model_name: &str) -> Result<Self> {
642 Err(Error::Parse(
643 "BERT NER ONNX support requires 'onnx' feature".to_string(),
644 ))
645 }
646
647 pub fn extract_entities(&self, _text: &str, _language: Option<&str>) -> Result<Vec<Entity>> {
648 Err(Error::Parse(
649 "BERT NER ONNX support requires 'onnx' feature".to_string(),
650 ))
651 }
652
653 pub fn model_name(&self) -> &str {
654 "onnx-not-enabled"
655 }
656}
657
658#[cfg(not(feature = "onnx"))]
659impl crate::Model for BertNEROnnx {
660 fn extract_entities(&self, _text: &str, _language: Option<&str>) -> Result<Vec<Entity>> {
661 Err(Error::Parse(
662 "BERT NER ONNX support requires 'onnx' feature".to_string(),
663 ))
664 }
665
666 fn supported_types(&self) -> Vec<anno_core::EntityType> {
667 vec![]
668 }
669
670 fn is_available(&self) -> bool {
671 false
672 }
673}