1use crate::{Entity, EntityType, Model, Result};
74
75use crate::Error;
76
/// Output of `encode_prompt`: `(input_ids, attention_mask, words_mask,
/// word_count)`. The three vectors share one length (the full prompt
/// sequence); `word_count` is the number of whitespace-split text words.
#[cfg(feature = "onnx")]
type EncodedPrompt = (Vec<i64>, Vec<i64>, Vec<i64>, i64);

// Special token ids used when assembling the GLiNER-style prompt.
// NOTE(review): these look like the DeBERTa-style vocabulary used by
// NuNER Zero ([CLS]=1, [SEP]=2, plus added <<ENT>>/<<SEP>> tokens at
// 128002/128003) — confirm against the model's tokenizer.json if the
// default model id ever changes.
#[cfg(feature = "onnx")]
const TOKEN_START: u32 = 1;
#[cfg(feature = "onnx")]
const TOKEN_END: u32 = 2;
#[cfg(feature = "onnx")]
const TOKEN_ENT: u32 = 128002;
#[cfg(feature = "onnx")]
const TOKEN_SEP: u32 = 128003;

/// Maximum candidate-span width (in words) for span-based exports.
/// Set to 1, so only single-word spans are generated by
/// `make_span_tensors`; multi-word entities are still produced by
/// merging adjacent words during decoding.
#[cfg(feature = "onnx")]
const MAX_SPAN_WIDTH: usize = 1;
96
/// Zero-shot named-entity recognizer backed by NuNER Zero (NuMind).
///
/// Construct with [`NuNER::new`] / builder methods for configuration, then
/// load weights with `from_pretrained` (requires the `onnx` feature)
/// before extraction.
pub struct NuNER {
    // HuggingFace model id, e.g. "numind/NuNER_Zero".
    model_id: String,
    // Confidence threshold in [0.0, 1.0] applied to predictions.
    threshold: f64,
    #[cfg(feature = "onnx")]
    // Set to true once an inference failure reveals the ONNX export
    // expects span_idx/span_mask inputs; cached so later calls skip the
    // failing token-only attempt.
    requires_span_tensors: std::sync::atomic::AtomicBool,
    // Labels used by `extract_entities` when the caller supplies none.
    default_labels: Vec<String>,
    #[cfg(feature = "onnx")]
    // ONNX Runtime session; `None` until `from_pretrained` succeeds.
    // Mutex because `ort::Session::run` needs exclusive access.
    session: Option<crate::sync::Mutex<ort::session::Session>>,
    #[cfg(feature = "onnx")]
    // HuggingFace tokenizer; `None` until `from_pretrained` succeeds.
    tokenizer: Option<tokenizers::Tokenizer>,
}
136
137impl NuNER {
138 #[must_use]
143 pub fn new() -> Self {
144 Self {
145 model_id: "numind/NuNER_Zero".to_string(),
146 threshold: 0.5,
147 #[cfg(feature = "onnx")]
148 requires_span_tensors: std::sync::atomic::AtomicBool::new(false),
149 default_labels: vec![
150 "person".to_string(),
151 "organization".to_string(),
152 "location".to_string(),
153 "date".to_string(),
154 "product".to_string(),
155 "event".to_string(),
156 ],
157 #[cfg(feature = "onnx")]
158 session: None,
159 #[cfg(feature = "onnx")]
160 tokenizer: None,
161 }
162 }
163
164 #[cfg(feature = "onnx")]
176 pub fn from_pretrained(model_id: &str) -> Result<Self> {
177 use hf_hub::api::sync::{Api, ApiBuilder};
178 use ort::execution_providers::CPUExecutionProvider;
179 use ort::session::Session;
180
181 crate::env::load_dotenv();
183
184 let api = if let Some(token) = crate::env::hf_token() {
185 ApiBuilder::new()
186 .with_token(Some(token))
187 .build()
188 .map_err(|e| Error::Retrieval(format!("HuggingFace API with token: {}", e)))?
189 } else {
190 Api::new().map_err(|e| {
191 Error::Retrieval(format!("Failed to initialize HuggingFace API: {}", e))
192 })?
193 };
194
195 let repo = api.model(model_id.to_string());
196
197 let model_path = repo
199 .get("onnx/model.onnx")
200 .or_else(|_| repo.get("model.onnx"))
201 .map_err(|e| Error::Retrieval(format!("Failed to download model.onnx: {}", e)))?;
202
203 let tokenizer_path = repo
204 .get("tokenizer.json")
205 .map_err(|e| Error::Retrieval(format!("Failed to download tokenizer.json: {}", e)))?;
206
207 let session = Session::builder()
208 .map_err(|e| Error::Retrieval(format!("Failed to create ONNX session: {}", e)))?
209 .with_execution_providers([CPUExecutionProvider::default().build()])
210 .map_err(|e| Error::Retrieval(format!("Failed to set execution providers: {}", e)))?
211 .commit_from_file(&model_path)
212 .map_err(|e| Error::Retrieval(format!("Failed to load ONNX model: {}", e)))?;
213
214 let tokenizer = tokenizers::Tokenizer::from_file(&tokenizer_path)
215 .map_err(|e| Error::Retrieval(format!("Failed to load tokenizer: {}", e)))?;
216
217 Ok(Self {
218 model_id: model_id.to_string(),
219 threshold: 0.5,
220 requires_span_tensors: std::sync::atomic::AtomicBool::new(false),
221 default_labels: vec![
222 "person".to_string(),
223 "organization".to_string(),
224 "location".to_string(),
225 ],
226 session: Some(crate::sync::Mutex::new(session)),
227 tokenizer: Some(tokenizer),
228 })
229 }
230
231 #[must_use]
233 pub fn with_model(model_id: impl Into<String>) -> Self {
234 let mut new = Self::new();
235 new.model_id = model_id.into();
236 new
237 }
238
239 #[must_use]
241 pub fn with_threshold(mut self, threshold: f64) -> Self {
242 self.threshold = threshold.clamp(0.0, 1.0);
243 self
244 }
245
246 #[must_use]
248 pub fn with_labels(mut self, labels: Vec<String>) -> Self {
249 self.default_labels = labels;
250 self
251 }
252
253 #[must_use]
255 pub fn model_id(&self) -> &str {
256 &self.model_id
257 }
258
    /// The confidence threshold in `[0.0, 1.0]` applied to predictions.
    #[must_use]
    pub fn threshold(&self) -> f64 {
        self.threshold
    }
264
    /// Runs zero-shot NER over `text` for the given `entity_types`.
    ///
    /// Words come from `split_whitespace` and define the word-level
    /// indexing shared by prompt encoding and output decoding. The method
    /// supports two export flavors: a token-only export, and a span-based
    /// export that additionally requires `span_idx`/`span_mask` inputs.
    /// If the token-only run fails with an error mentioning the span
    /// inputs, the call retries once with span tensors and caches that
    /// fact in `requires_span_tensors` for subsequent calls.
    ///
    /// # Errors
    /// - `Error::Retrieval` if the model or tokenizer was never loaded.
    /// - `Error::Parse` on tensor construction or inference failures.
    /// - `Error::InvalidInput` on (practically unreachable) span-count
    ///   overflow.
    #[cfg(feature = "onnx")]
    pub fn extract(
        &self,
        text: &str,
        entity_types: &[&str],
        threshold: f32,
    ) -> Result<Vec<Entity>> {
        // Nothing to do for empty text or an empty label set.
        if text.is_empty() || entity_types.is_empty() {
            return Ok(vec![]);
        }

        // Opt-in debug tracing via environment variable.
        if std::env::var("ANNO_DEBUG_NUNER_EXTRACT").is_ok() {
            eprintln!(
                "DEBUG nuner extract: text.len={} entity_types={:?}",
                text.len(),
                entity_types
            );
        }

        let session = self.session.as_ref().ok_or_else(|| {
            Error::Retrieval("Model not loaded. Call from_pretrained() first.".to_string())
        })?;

        let tokenizer = self
            .tokenizer
            .as_ref()
            .ok_or_else(|| Error::Retrieval("Tokenizer not loaded.".to_string()))?;

        let text_words: Vec<&str> = text.split_whitespace().collect();
        if text_words.is_empty() {
            return Ok(vec![]);
        }

        let (input_ids, attention_mask, words_mask, text_lengths) =
            self.encode_prompt(tokenizer, &text_words, entity_types)?;

        let batch_size = 1;
        let seq_len = input_ids.len();

        // Builds the four token-level input tensors. A closure because the
        // retry path may need to build them a second time (tensors are
        // consumed by `session.run`).
        let make_token_tensors = || -> Result<(_, _, _, _)> {
            use ndarray::Array2;

            let input_ids_array = Array2::from_shape_vec((batch_size, seq_len), input_ids.clone())
                .map_err(|e| Error::Parse(format!("Array error: {}", e)))?;
            let attention_mask_array =
                Array2::from_shape_vec((batch_size, seq_len), attention_mask.clone())
                    .map_err(|e| Error::Parse(format!("Array error: {}", e)))?;
            let words_mask_array =
                Array2::from_shape_vec((batch_size, seq_len), words_mask.clone())
                    .map_err(|e| Error::Parse(format!("Array error: {}", e)))?;
            let text_lengths_array = Array2::from_shape_vec((batch_size, 1), vec![text_lengths])
                .map_err(|e| Error::Parse(format!("Array error: {}", e)))?;

            let input_ids_t = super::ort_compat::tensor_from_ndarray(input_ids_array)
                .map_err(|e| Error::Parse(format!("Tensor error: {}", e)))?;
            let attention_mask_t = super::ort_compat::tensor_from_ndarray(attention_mask_array)
                .map_err(|e| Error::Parse(format!("Tensor error: {}", e)))?;
            let words_mask_t = super::ort_compat::tensor_from_ndarray(words_mask_array)
                .map_err(|e| Error::Parse(format!("Tensor error: {}", e)))?;
            let text_lengths_t = super::ort_compat::tensor_from_ndarray(text_lengths_array)
                .map_err(|e| Error::Parse(format!("Tensor error: {}", e)))?;

            Ok((input_ids_t, attention_mask_t, words_mask_t, text_lengths_t))
        };

        use std::sync::atomic::Ordering;
        // Cached across calls: once an export is known to need span
        // tensors, skip the token-only attempt that would fail.
        let mut needs_span_tensors = self.requires_span_tensors.load(Ordering::Relaxed);

        let mut session_guard = crate::sync::lock(session);

        // At most two iterations: the token-only attempt, then (on a
        // span-related failure) one retry with span tensors.
        let outputs = loop {
            if needs_span_tensors {
                let (input_ids_t, attention_mask_t, words_mask_t, text_lengths_t) =
                    make_token_tensors()?;
                let num_spans = match text_words.len().checked_mul(MAX_SPAN_WIDTH) {
                    Some(v) => v,
                    None => {
                        return Err(Error::InvalidInput(format!(
                            "Span count overflow: {} words * {} MAX_SPAN_WIDTH",
                            text_words.len(),
                            MAX_SPAN_WIDTH
                        )));
                    }
                };
                let (span_idx, span_mask) = NuNER::make_span_tensors(text_words.len());

                use ndarray::Array2;
                use ndarray::Array3;
                // span_idx is (1, num_spans, 2): [start, end] word pairs.
                let span_idx_array = Array3::from_shape_vec((1, num_spans, 2), span_idx)
                    .map_err(|e| Error::Parse(format!("Span idx array error: {}", e)))?;
                let span_mask_array = Array2::from_shape_vec((1, num_spans), span_mask)
                    .map_err(|e| Error::Parse(format!("Span mask array error: {}", e)))?;

                let span_idx_t = super::ort_compat::tensor_from_ndarray(span_idx_array)
                    .map_err(|e| Error::Parse(format!("Span idx tensor error: {}", e)))?;
                let span_mask_t = super::ort_compat::tensor_from_ndarray(span_mask_array)
                    .map_err(|e| Error::Parse(format!("Span mask tensor error: {}", e)))?;

                // Span-path run: failure here is terminal (no further retry).
                break session_guard
                    .run(ort::inputs![
                        "input_ids" => input_ids_t.into_dyn(),
                        "attention_mask" => attention_mask_t.into_dyn(),
                        "words_mask" => words_mask_t.into_dyn(),
                        "text_lengths" => text_lengths_t.into_dyn(),
                        "span_idx" => span_idx_t.into_dyn(),
                        "span_mask" => span_mask_t.into_dyn(),
                    ])
                    .map_err(|e| {
                        Error::Parse(format!(
                            "ONNX inference failed: {}\n\n\
                            NuNER model: {}\n\
                            requires_span_tensors={}\n\
                            input_ids=(1,{seq_len}) attention_mask=(1,{seq_len}) words_mask=(1,{seq_len}) text_lengths=(1,1)\n\
                            span_idx=(1,{num_spans},2) span_mask=(1,{num_spans})\n\n\
                            Hint: If this looks like a shape mismatch, the ONNX export may have fixed span dimensions.\n\
                            Try a different NuNER export (e.g., deepanwa/NuNerZero_onnx) or re-export with dynamic axes.",
                            e,
                            self.model_id,
                            self.requires_span_tensors.load(Ordering::Relaxed)
                        ))
                    })?;
            } else {
                let (input_ids_t, attention_mask_t, words_mask_t, text_lengths_t) =
                    make_token_tensors()?;
                let res = session_guard.run(ort::inputs![
                    "input_ids" => input_ids_t.into_dyn(),
                    "attention_mask" => attention_mask_t.into_dyn(),
                    "words_mask" => words_mask_t.into_dyn(),
                    "text_lengths" => text_lengths_t.into_dyn(),
                ]);

                match res {
                    Ok(o) => break o,
                    Err(e) => {
                        // Heuristic: detect a span-based export from the error
                        // text. NOTE(review): the bare "span_mask"/"span_idx"
                        // substrings match broadly — presumably intentional to
                        // cover differing runtime error formats, but any error
                        // merely mentioning those names triggers the retry.
                        let msg = format!("{e}");
                        let looks_like_missing_span = msg.contains("Missing Input: span_mask")
                            || msg.contains("Missing Input: span_idx")
                            || msg.contains("span_mask")
                            || msg.contains("span_idx");
                        if looks_like_missing_span {
                            // Remember for future calls, then retry once.
                            self.requires_span_tensors.store(true, Ordering::Relaxed);
                            needs_span_tensors = true;
                            continue;
                        }
                        return Err(Error::Parse(format!(
                            "ONNX inference failed: {}\n\n\
                            NuNER model: {}\n\
                            requires_span_tensors={}\n\
                            input_ids=(1,{seq_len}) attention_mask=(1,{seq_len}) words_mask=(1,{seq_len}) text_lengths=(1,1)\n\n\
                            Hint: If this looks like an input-name mismatch, your ONNX export may expect span tensors or different input names.",
                            e,
                            self.model_id,
                            self.requires_span_tensors.load(Ordering::Relaxed),
                        )));
                    }
                }
            }
        };

        // `decode_span_output` falls back to token decoding internally when
        // the output tensor is not 4-dimensional.
        let entities =
            self.decode_span_output(&outputs, text, &text_words, entity_types, threshold)?;

        Ok(entities)
    }
451
    /// Builds the `span_idx`/`span_mask` inputs for span-based exports:
    /// one candidate span per (start word, width) pair, width capped at
    /// `MAX_SPAN_WIDTH`.
    ///
    /// `span_idx` is a flattened `(num_spans, 2)` array of
    /// `[start, start + width]` word indices (end inclusive, since width
    /// is 0-based); `span_mask` marks which rows are real spans. Every
    /// size computation is overflow-checked; overflow degrades to empty
    /// tensors (or skipped spans) with a warning instead of panicking.
    #[cfg(feature = "onnx")]
    pub(crate) fn make_span_tensors(num_words: usize) -> (Vec<i64>, Vec<bool>) {
        let num_spans = match num_words.checked_mul(MAX_SPAN_WIDTH) {
            Some(v) => v,
            None => {
                log::warn!(
                    "Span count overflow: {} words * {} MAX_SPAN_WIDTH, returning empty tensors",
                    num_words,
                    MAX_SPAN_WIDTH
                );
                return (Vec::new(), Vec::new());
            }
        };
        let span_idx_len = match num_spans.checked_mul(2) {
            Some(v) => v,
            None => {
                log::warn!(
                    "Span idx length overflow: {} spans * 2, returning empty tensors",
                    num_spans
                );
                return (Vec::new(), Vec::new());
            }
        };
        let mut span_idx: Vec<i64> = vec![0; span_idx_len];
        let mut span_mask: Vec<bool> = vec![false; num_spans];

        for start in 0..num_words {
            // Spans may not run past the last word.
            let remaining_width = num_words - start;
            let actual_max_width = MAX_SPAN_WIDTH.min(remaining_width);

            for width in 0..actual_max_width {
                // Row index in the (num_spans, 2) layout: start * MAX_SPAN_WIDTH + width.
                let dim = match start.checked_mul(MAX_SPAN_WIDTH) {
                    Some(v) => match v.checked_add(width) {
                        Some(d) => d,
                        None => {
                            log::warn!(
                                "Dim calculation overflow: {} * {} + {}, skipping span",
                                start,
                                MAX_SPAN_WIDTH,
                                width
                            );
                            continue;
                        }
                    },
                    None => {
                        log::warn!(
                            "Dim calculation overflow: {} * {}, skipping span",
                            start,
                            MAX_SPAN_WIDTH
                        );
                        continue;
                    }
                };
                if let Some(dim2) = dim.checked_mul(2) {
                    // Bounds re-checked defensively before writing.
                    if dim2 + 1 < span_idx_len && dim < num_spans {
                        span_idx[dim2] = start as i64;
                        span_idx[dim2 + 1] = (start + width) as i64;
                        span_mask[dim] = true;
                    } else {
                        log::warn!(
                            "Span idx access out of bounds: dim={}, dim*2={}, span_idx_len={}, num_spans={}, skipping",
                            dim, dim2, span_idx_len, num_spans
                        );
                    }
                } else {
                    log::warn!("Dim * 2 overflow: dim={}, skipping span", dim);
                }
            }
        }

        (span_idx, span_mask)
    }
540
541 #[cfg(feature = "onnx")]
543 fn encode_prompt(
544 &self,
545 tokenizer: &tokenizers::Tokenizer,
546 text_words: &[&str],
547 entity_types: &[&str],
548 ) -> Result<EncodedPrompt> {
549 let mut input_ids: Vec<i64> = Vec::with_capacity(128);
552 let mut word_mask: Vec<i64> = Vec::with_capacity(128);
553
554 input_ids.push(TOKEN_START as i64);
556 word_mask.push(0);
557
558 for entity_type in entity_types {
560 input_ids.push(TOKEN_ENT as i64);
561 word_mask.push(0);
562
563 let encoding = tokenizer
564 .encode(entity_type.to_string(), false)
565 .map_err(|e| Error::Parse(format!("Tokenizer error: {}", e)))?;
566 for token_id in encoding.get_ids() {
567 input_ids.push(*token_id as i64);
568 word_mask.push(0);
569 }
570 }
571
572 input_ids.push(TOKEN_SEP as i64);
574 word_mask.push(0);
575
576 let mut word_id: i64 = 0;
578 for word in text_words {
579 let encoding = tokenizer
580 .encode(word.to_string(), false)
581 .map_err(|e| Error::Parse(format!("Tokenizer error: {}", e)))?;
582
583 word_id += 1;
584 for (token_idx, token_id) in encoding.get_ids().iter().enumerate() {
585 input_ids.push(*token_id as i64);
586 word_mask.push(if token_idx == 0 { word_id } else { 0 });
587 }
588 }
589
590 input_ids.push(TOKEN_END as i64);
592 word_mask.push(0);
593
594 let seq_len = input_ids.len();
595 let attention_mask: Vec<i64> = vec![1; seq_len];
596
597 Ok((input_ids, attention_mask, word_mask, word_id))
598 }
599
    /// Decodes a token-classification (BIO-style) output tensor into
    /// entities.
    ///
    /// Reads the first model output as a float tensor of shape
    /// `(batch, num_words, num_classes)`. The decoding below treats class
    /// 0 as "outside", odd classes as "begin" and even (non-zero) classes
    /// as "inside", with entity type `(class - 1) / 2`.
    /// NOTE(review): that layout is inferred from the arithmetic here —
    /// confirm against the model's label ordering.
    ///
    /// # Errors
    /// `Error::Parse` when the output tensor is missing/malformed or a
    /// whitespace-split word cannot be located back in `text`.
    #[cfg(feature = "onnx")]
    fn decode_token_output(
        &self,
        outputs: &ort::session::SessionOutputs,
        text: &str,
        text_words: &[&str],
        entity_types: &[&str],
        threshold: f32,
    ) -> Result<Vec<Entity>> {
        let output = outputs
            .iter()
            .next()
            .map(|(_, v)| v)
            .ok_or_else(|| Error::Parse("No output from NuNER model".to_string()))?;

        let (_, data_slice) = output
            .try_extract_tensor::<f32>()
            .map_err(|e| Error::Parse(format!("Failed to extract output tensor: {}", e)))?;
        let output_data: Vec<f32> = data_slice.to_vec();

        let shape: Vec<i64> = match output.dtype() {
            ort::value::ValueType::Tensor { shape, .. } => shape.iter().copied().collect(),
            _ => return Err(Error::Parse("Expected tensor output".to_string())),
        };

        if std::env::var("ANNO_DEBUG_NUNER_DECODE").is_ok() {
            eprintln!(
                "DEBUG nuner decode: shape={:?} text_words.len={} data.len={}",
                shape,
                text_words.len(),
                output_data.len()
            );
            let sample: Vec<f32> = output_data.iter().take(10).copied().collect();
            eprintln!("DEBUG nuner decode: sample data={:?}", sample);
        }

        if shape.len() < 3 {
            return Err(Error::Parse(format!(
                "Unexpected output shape: {:?}",
                shape
            )));
        }

        let num_words = shape[1] as usize;
        let num_classes = shape[2] as usize;

        if std::env::var("ANNO_DEBUG_NUNER_DECODE").is_ok() {
            eprintln!(
                "DEBUG nuner decode: num_words={} num_classes={} entity_types.len={}",
                num_words,
                num_classes,
                entity_types.len()
            );
        }

        // Recover each word's byte offsets by scanning forward through the
        // original text (words came from split_whitespace, so a forward
        // `find` from the previous word's end is unambiguous).
        let word_positions: Vec<(usize, usize)> = {
            let mut positions = Vec::with_capacity(text_words.len());
            let mut pos = 0;
            for (idx, word) in text_words.iter().enumerate() {
                if let Some(start) = text[pos..].find(word) {
                    let abs_start = pos + start;
                    let abs_end = abs_start + word.len();
                    // Overlap should be impossible given forward scanning;
                    // warn rather than fail if it somehow happens.
                    if !positions.is_empty() {
                        let (_prev_start, prev_end) = positions[positions.len() - 1];
                        if abs_start < prev_end {
                            log::warn!(
                                "Word '{}' at position {} overlaps with previous word ending at {}",
                                word,
                                abs_start,
                                prev_end
                            );
                        }
                    }
                    positions.push((abs_start, abs_end));
                    pos = abs_end;
                } else {
                    return Err(Error::Parse(format!(
                        "Word '{}' (index {}) not found in text starting at position {}",
                        word, idx, pos
                    )));
                }
            }
            positions
        };

        if word_positions.len() != text_words.len() {
            return Err(Error::Parse(format!(
                "Word position mismatch: found {} positions for {} words",
                word_positions.len(),
                text_words.len()
            )));
        }

        let span_converter = crate::offset::SpanConverter::new(text);

        let mut entities = Vec::with_capacity(16);
        // Open entity: (start word, end word exclusive, type index, score).
        let mut current_entity: Option<(usize, usize, usize, f32)> = None;
        for word_idx in 0..num_words.min(text_words.len()) {
            let base_idx = word_idx * num_classes;

            // Argmax over classes for this word.
            let mut best_class = 0;
            let mut best_score = 0.0f32;

            for class_idx in 0..num_classes {
                let score = output_data
                    .get(base_idx + class_idx)
                    .copied()
                    .unwrap_or(0.0);
                if score > best_score {
                    best_score = score;
                    best_class = class_idx;
                }
            }

            // Class 0 = outside; odd = begin; even (non-zero) = inside.
            let is_begin = best_class > 0 && best_class % 2 == 1;
            let is_inside = best_class > 0 && best_class % 2 == 0;
            let type_idx = if best_class > 0 {
                (best_class - 1) / 2
            } else {
                0
            };

            if best_score >= threshold {
                if is_begin {
                    // A new entity starts: flush any open one first.
                    if let Some((start, end, etype, score)) = current_entity.take() {
                        if let Some(e) = self.create_entity(
                            text,
                            &span_converter,
                            &word_positions,
                            start,
                            end,
                            etype,
                            score,
                            entity_types,
                        ) {
                            entities.push(e);
                        }
                    }
                    current_entity = Some((word_idx, word_idx + 1, type_idx, best_score));
                } else if is_inside {
                    // Extend only when the type matches the open entity.
                    // NOTE(review): a type-mismatched inside tag, or an
                    // above-threshold class-0 word, leaves the open entity
                    // neither extended nor flushed — later matching inside
                    // tags can still extend it across the gap. Confirm this
                    // is intentional.
                    if let Some((_start, end, etype, score)) = current_entity.as_mut() {
                        if *etype == type_idx {
                            *end = word_idx + 1;
                            // Running average; weights later words more heavily.
                            *score = (*score + best_score) / 2.0;
                        }
                    }
                }
            } else {
                // Below threshold: close any open entity.
                if let Some((start, end, etype, score)) = current_entity.take() {
                    if let Some(e) = self.create_entity(
                        text,
                        &span_converter,
                        &word_positions,
                        start,
                        end,
                        etype,
                        score,
                        entity_types,
                    ) {
                        entities.push(e);
                    }
                }
            }
        }

        // Flush the entity still open at the end of the text, if any.
        if let Some((start, end, etype, score)) = current_entity.take() {
            if let Some(e) = self.create_entity(
                text,
                &span_converter,
                &word_positions,
                start,
                end,
                etype,
                score,
                entity_types,
            ) {
                entities.push(e);
            }
        }

        Ok(entities)
    }
806
    /// Decodes a span-based output tensor of shape
    /// `(batch, num_words, max_width, num_classes)` into entities,
    /// falling back to [`Self::decode_token_output`] when the output is
    /// not 4-dimensional.
    ///
    /// Logits are converted to probabilities with a sigmoid, and adjacent
    /// words predicted with the same class are merged into one entity.
    /// NOTE(review): only the class slice at width index 0 is read
    /// (`base_idx` skips the width axis entirely) — consistent with
    /// `MAX_SPAN_WIDTH = 1`, but wider spans from a multi-width export
    /// would be ignored.
    ///
    /// # Errors
    /// `Error::Parse` when the output tensor is missing/malformed or a
    /// word cannot be located back in `text`.
    #[cfg(feature = "onnx")]
    fn decode_span_output(
        &self,
        outputs: &ort::session::SessionOutputs,
        text: &str,
        text_words: &[&str],
        entity_types: &[&str],
        threshold: f32,
    ) -> Result<Vec<Entity>> {
        // Prefer an output named "*logits*"; otherwise take the first one.
        let logits_output = outputs
            .iter()
            .find(|(name, _)| name.contains("logits"))
            .map(|(_, v)| v)
            .or_else(|| outputs.iter().next().map(|(_, v)| v))
            .ok_or_else(|| Error::Parse("No logits output from NuNER model".to_string()))?;

        let (_, data_slice) = logits_output
            .try_extract_tensor::<f32>()
            .map_err(|e| Error::Parse(format!("Failed to extract output tensor: {}", e)))?;
        let output_data: Vec<f32> = data_slice.to_vec();

        let shape: Vec<i64> = match logits_output.dtype() {
            ort::value::ValueType::Tensor { shape, .. } => shape.iter().copied().collect(),
            _ => return Err(Error::Parse("Expected tensor output".to_string())),
        };

        // Not a span-shaped output: delegate to the token decoder.
        if shape.len() != 4 {
            return self.decode_token_output(outputs, text, text_words, entity_types, threshold);
        }

        let num_words = shape[1] as usize;
        let max_width = shape[2] as usize;
        let num_classes = shape[3] as usize;

        if std::env::var("ANNO_DEBUG_NUNER_DECODE").is_ok() {
            eprintln!(
                "DEBUG nuner decode_span: shape={:?} num_words={} max_width={} num_classes={} entity_types.len={}",
                shape, num_words, max_width, num_classes, entity_types.len()
            );
        }

        // Recover each word's byte offsets by forward-scanning the text
        // (same technique as in decode_token_output).
        let word_positions: Vec<(usize, usize)> = {
            let mut positions = Vec::with_capacity(text_words.len());
            let mut pos = 0;
            for word in text_words.iter() {
                if let Some(start) = text[pos..].find(word) {
                    let abs_start = pos + start;
                    let abs_end = abs_start + word.len();
                    positions.push((abs_start, abs_end));
                    pos = abs_end;
                } else {
                    return Err(Error::Parse(format!(
                        "Word '{}' not found in text starting at position {}",
                        word, pos
                    )));
                }
            }
            positions
        };

        let span_converter = crate::offset::SpanConverter::new(text);

        let mut entities = Vec::with_capacity(16);
        // Open entity: (start word, end word exclusive, class index, score).
        let mut current_entity: Option<(usize, usize, usize, f32)> = None;
        for word_idx in 0..num_words.min(text_words.len()) {
            // Offset of this word's (width=0) class slice in the flat data.
            let base_idx = word_idx * max_width * num_classes;

            // Best class whose sigmoid probability clears the threshold.
            let mut best_class: Option<usize> = None;
            let mut best_prob = 0.0f32;

            for class_idx in 0..num_classes {
                let logit = output_data
                    .get(base_idx + class_idx)
                    .copied()
                    .unwrap_or(f32::NEG_INFINITY);
                // Sigmoid: logits -> probability.
                let prob = 1.0 / (1.0 + (-logit).exp());

                if prob >= threshold && prob > best_prob {
                    best_prob = prob;
                    best_class = Some(class_idx);
                }
            }

            if let Some(class_idx) = best_class {
                if let Some((start, end, etype, score)) = current_entity.as_mut() {
                    if *etype == class_idx {
                        // Same class as the open entity: extend it.
                        *end = word_idx + 1;
                        // Running average; weights later words more heavily.
                        *score = (*score + best_prob) / 2.0;
                    } else {
                        // Class changed: flush the open entity, start a new one.
                        if let Some(e) = self.create_entity(
                            text,
                            &span_converter,
                            &word_positions,
                            *start,
                            *end,
                            *etype,
                            *score,
                            entity_types,
                        ) {
                            entities.push(e);
                        }
                        current_entity = Some((word_idx, word_idx + 1, class_idx, best_prob));
                    }
                } else {
                    current_entity = Some((word_idx, word_idx + 1, class_idx, best_prob));
                }
            } else {
                // No class above threshold: close any open entity.
                if let Some((start, end, etype, score)) = current_entity.take() {
                    if let Some(e) = self.create_entity(
                        text,
                        &span_converter,
                        &word_positions,
                        start,
                        end,
                        etype,
                        score,
                        entity_types,
                    ) {
                        entities.push(e);
                    }
                }
            }
        }

        // Flush the entity still open at the end of the text, if any.
        if let Some((start, end, etype, score)) = current_entity.take() {
            if let Some(e) = self.create_entity(
                text,
                &span_converter,
                &word_positions,
                start,
                end,
                etype,
                score,
                entity_types,
            ) {
                entities.push(e);
            }
        }

        if std::env::var("ANNO_DEBUG_NUNER_DECODE").is_ok() {
            eprintln!("DEBUG nuner decode_span: found {} entities", entities.len());
        }

        Ok(entities)
    }
976
977 #[cfg(feature = "onnx")]
978 #[allow(clippy::too_many_arguments)]
979 fn create_entity(
980 &self,
981 text: &str,
982 span_converter: &crate::offset::SpanConverter,
983 word_positions: &[(usize, usize)],
984 start_word: usize,
985 end_word: usize,
986 type_idx: usize,
987 score: f32,
988 entity_types: &[&str],
989 ) -> Option<Entity> {
990 if end_word == 0 || end_word > word_positions.len() || start_word >= word_positions.len() {
992 return None;
993 }
994 let start_pos = word_positions.get(start_word)?.0;
995 let end_pos = word_positions.get(end_word.saturating_sub(1))?.1;
996
997 let entity_text = text.get(start_pos..end_pos)?;
998 let label = entity_types.get(type_idx)?;
999 let entity_type = Self::map_label_to_entity_type(label);
1000
1001 let char_start = span_converter.byte_to_char(start_pos);
1002 let char_end = span_converter.byte_to_char(end_pos);
1003
1004 Some(Entity::new(
1005 entity_text,
1006 entity_type,
1007 char_start,
1008 char_end,
1009 score as f64,
1010 ))
1011 }
1012
1013 fn map_label_to_entity_type(label: &str) -> EntityType {
1015 match label.to_lowercase().as_str() {
1016 "person" | "per" => EntityType::Person,
1017 "organization" | "org" | "company" => EntityType::Organization,
1018 "location" | "loc" | "place" | "gpe" => EntityType::Location,
1019 "date" => EntityType::Date,
1020 "time" => EntityType::Time,
1021 "money" | "currency" => EntityType::Money,
1022 "percent" | "percentage" => EntityType::Percent,
1023 _ => EntityType::Other(label.to_string()),
1024 }
1025 }
1026}
1027
1028impl Default for NuNER {
1029 fn default() -> Self {
1030 Self::new()
1031 }
1032}
1033
impl Model for NuNER {
    /// Extracts entities using the configured default labels and
    /// threshold. `_language` is unused here: the zero-shot prompt
    /// carries the label semantics.
    ///
    /// # Errors
    /// - `Error::ModelInit` when built with `onnx` but no model was
    ///   loaded via `from_pretrained`.
    /// - `Error::FeatureNotAvailable` when built without `onnx`.
    fn extract_entities(&self, text: &str, _language: Option<&str>) -> Result<Vec<Entity>> {
        if text.trim().is_empty() {
            return Ok(vec![]);
        }

        #[cfg(feature = "onnx")]
        {
            if self.session.is_some() {
                let labels: Vec<&str> = self.default_labels.iter().map(|s| s.as_str()).collect();
                // f64 -> f32 narrowing is safe: threshold is clamped to [0, 1].
                return self.extract(text, &labels, self.threshold as f32);
            }

            Err(Error::ModelInit(
                "NuNER model not loaded. Call `NuNER::from_pretrained(...)` (requires `onnx` feature) before calling `extract_entities`.".to_string(),
            ))
        }

        #[cfg(not(feature = "onnx"))]
        {
            Err(Error::FeatureNotAvailable(
                "NuNER requires the 'onnx' feature. Build with: cargo build --features onnx"
                    .to_string(),
            ))
        }
    }

    /// Entity types implied by the current default label set.
    fn supported_types(&self) -> Vec<EntityType> {
        self.default_labels
            .iter()
            .map(|l| Self::map_label_to_entity_type(l))
            .collect()
    }

    /// True only when a model has been loaded (always false without `onnx`).
    fn is_available(&self) -> bool {
        #[cfg(feature = "onnx")]
        {
            self.session.is_some()
        }
        #[cfg(not(feature = "onnx"))]
        {
            false
        }
    }

    fn name(&self) -> &'static str {
        "nuner"
    }

    fn description(&self) -> &'static str {
        "NuNER Zero: Token-based zero-shot NER from NuMind (MIT licensed)"
    }

    /// Version string derived from the configured model id.
    fn version(&self) -> String {
        format!("nuner-zero-{}", self.model_id)
    }
}
1091
impl crate::BatchCapable for NuNER {
    /// Suggested batch size for callers that batch documents.
    /// NOTE(review): 8 appears to be a heuristic default — inference here
    /// still runs one document per `session.run` call.
    fn optimal_batch_size(&self) -> Option<usize> {
        Some(8)
    }
}
1101
1102impl crate::StreamingCapable for NuNER {
1107 fn recommended_chunk_size(&self) -> usize {
1108 4096 }
1110}
1111
#[cfg(test)]
mod tests {
    use super::*;

    // Constructor defaults: model id and threshold.
    #[test]
    fn test_nuner_creation() {
        let ner = NuNER::new();
        assert_eq!(ner.model_id(), "numind/NuNER_Zero");
        assert!((ner.threshold() - 0.5).abs() < f64::EPSILON);
    }

    // Builder methods compose and override the defaults.
    #[test]
    fn test_nuner_with_custom_model() {
        let ner = NuNER::with_model("custom/model")
            .with_threshold(0.7)
            .with_labels(vec!["technology".to_string()]);

        assert_eq!(ner.model_id(), "custom/model");
        assert!((ner.threshold() - 0.7).abs() < f64::EPSILON);
        assert_eq!(ner.default_labels.len(), 1);
    }

    // Label mapping is case-insensitive; unknown labels become Other.
    #[test]
    fn test_label_mapping() {
        assert_eq!(
            NuNER::map_label_to_entity_type("person"),
            EntityType::Person
        );
        assert_eq!(NuNER::map_label_to_entity_type("PER"), EntityType::Person);
        assert_eq!(
            NuNER::map_label_to_entity_type("organization"),
            EntityType::Organization
        );
        assert_eq!(
            NuNER::map_label_to_entity_type("custom"),
            EntityType::Other("custom".to_string())
        );
    }

    // supported_types reflects the default label set.
    #[test]
    fn test_supported_types() {
        let ner = NuNER::new();
        let types = ner.supported_types();
        assert!(types.contains(&EntityType::Person));
        assert!(types.contains(&EntityType::Organization));
        assert!(types.contains(&EntityType::Location));
    }

    // Empty text short-circuits before the (absent) model is needed.
    #[test]
    fn test_empty_input() {
        let ner = NuNER::new();
        let entities = ner.extract_entities("", None).unwrap();
        assert!(entities.is_empty());
    }

    // new() alone never yields a usable model.
    #[test]
    fn test_not_available_without_model() {
        let ner = NuNER::new();
        assert!(!ner.is_available());
    }

    // create_entity must report char offsets, not byte offsets, for
    // multi-byte text: "北京" is 6 bytes / 2 chars, so "Beijing" spans
    // bytes 7..14 but chars 3..10.
    #[test]
    #[cfg(feature = "onnx")]
    fn test_create_entity_converts_byte_offsets_to_char_offsets() {
        let ner = NuNER::new();
        let text = "北京 Beijing";
        // Byte offsets of the two whitespace-split words.
        let word_positions = vec![(0usize, 6usize), (7usize, 14usize)];
        let entity_types = ["loc"];
        let span_converter = crate::offset::SpanConverter::new(text);

        let e = ner
            .create_entity(
                text,
                &span_converter,
                &word_positions,
                1,
                2,
                0,
                0.9,
                &entity_types,
            )
            .expect("expected entity");

        assert_eq!(e.text, "Beijing");
        assert_eq!(
            (e.start, e.end),
            (3, 10),
            "expected char offsets for Beijing"
        );
    }
}