Skip to main content

anno/backends/
nuner.rs

1//! NuNER - Token-based zero-shot NER from NuMind.
2//!
3//! NuNER is a family of zero-shot NER models built on the GLiNER architecture
4//! with a token classifier design (vs span classifier). Key advantages:
5//!
6//! - **Arbitrary-length entities**: No hard limit on entity span length
7//! - **Efficient training**: Trained on NuNER v2.0 dataset (Pile + C4)
8//! - **MIT Licensed**: Open weights from NuMind
9//!
10//! # Architecture
11//!
12//! NuNER uses the same bi-encoder architecture as GLiNER but with token classification:
13//!
14//! ```text
15//! Input: "James Bond works at MI6"
16//!        Labels: ["person", "organization"]
17//!
18//!        ┌──────────────────────┐
19//!        │   Shared Encoder     │
20//!        │  (DeBERTa/BERT)      │
21//!        └──────────────────────┘
22//!               │         │
23//!        ┌──────┴──┐   ┌──┴─────┐
24//!        │  Token  │   │ Label  │
25//!        │  Embeds │   │ Embeds │
26//!        └─────────┘   └────────┘
27//!               │         │
28//!        ┌──────┴─────────┴──────┐
//!        │  Token Classification │  (BIO tags per token)
30//!        └───────────────────────┘
31//!               │
32//!               ▼
33//!        B-PER I-PER  O    O   B-ORG
34//!        James Bond works at  MI6
35//! ```
36//!
37//! # Differences from GLiNER (Span Mode)
38//!
39//! | Aspect | GLiNER (Span) | NuNER (Token) |
40//! |--------|---------------|---------------|
41//! | Output | Span classification | Token classification (BIO) |
42//! | Entity length | Limited by span window (12) | Arbitrary |
//! | ONNX inputs | 6 tensors (incl. span_idx, span_mask) | 4 tensors; some exports also require span_idx/span_mask |
44//! | Decoding | Span scores → entities | BIO tags → entities |
45//!
46//! # Model Variants
47//!
48//! | Model | Context | Notes |
49//! |-------|---------|-------|
50//! | `numind/NuNER_Zero` | 512 | General zero-shot |
51//! | `numind/NuNER_Zero_4k` | 4096 | Long context variant |
52//! | `deepanwa/NuNerZero_onnx` | 512 | Pre-converted ONNX |
53//!
54//! # Usage
55//!
56//! ```rust,ignore
57//! use anno::NuNER;
58//!
59//! // Load NuNER model (requires `onnx` feature)
60//! let ner = NuNER::from_pretrained("deepanwa/NuNerZero_onnx")?;
61//!
62//! // Zero-shot extraction with custom labels
63//! let entities = ner.extract("Apple CEO Tim Cook announced...",
64//!                            &["person", "organization", "product"], 0.5)?;
65//! ```
66//!
67//! # References
68//!
69//! - [NuNER Zero on HuggingFace](https://huggingface.co/numind/NuNER_Zero)
70//! - [NuNER ONNX](https://huggingface.co/deepanwa/NuNerZero_onnx)
71//! - GLiNER paper (for span-based prompting inspiration)
72
73use crate::{Entity, EntityType, Model, Result};
74
75use crate::Error;
76
/// Encoded prompt produced by `NuNER::encode_prompt`:
/// `(input_ids, attention_mask, words_mask, text_lengths)`.
///
/// - `input_ids`: `[START]`, then `<<ENT>>` + label tokens per entity type, `<<SEP>>`,
///   the text-word tokens, and `[END]`.
/// - `attention_mask`: all ones, same length as `input_ids`.
/// - `words_mask`: same length as `input_ids`; the first sub-token of the i-th text
///   word holds `i` (1-based), every other position holds 0.
/// - `text_lengths`: number of text words encoded (not the number of entity types).
#[cfg(feature = "onnx")]
type EncodedPrompt = (Vec<i64>, Vec<i64>, Vec<i64>, i64);
80
// Special token IDs for GLiNER/NuNER models (shared architecture).
// NOTE(review): these must match the added-token IDs in each export's tokenizer.json;
// confirm when adding support for a new export.

/// `[START]` token that opens every encoded prompt.
#[cfg(feature = "onnx")]
const TOKEN_START: u32 = 1;
/// `[END]` token appended after the last text word.
#[cfg(feature = "onnx")]
const TOKEN_END: u32 = 2;
/// `<<ENT>>` marker placed before each entity-type label in the prompt.
#[cfg(feature = "onnx")]
const TOKEN_ENT: u32 = 128002;
/// `<<SEP>>` marker separating the entity-type labels from the text words.
#[cfg(feature = "onnx")]
const TOKEN_SEP: u32 = 128003;

/// Maximum span width for span-based inference.
/// NuNER uses max_width=1 (single-word spans only) per its gliner_config.json.
/// This matches the Python GLiNER implementation's prepare_span_idx function.
#[cfg(feature = "onnx")]
const MAX_SPAN_WIDTH: usize = 1;
96
/// NuNER Zero-shot NER model.
///
/// Token-based variant of GLiNER that uses BIO tagging instead of span classification.
/// This enables arbitrary-length entity extraction without the span window limitation.
///
/// NOTE(review): `extract` currently decodes through the span-output path
/// (single-word spans, `MAX_SPAN_WIDTH = 1`); confirm this description against the
/// decoder actually in use.
///
/// # Feature Requirements
///
/// Requires the `onnx` feature for actual inference. Without it, configuration
/// methods work but extraction returns empty results.
///
/// # Example
///
/// ```rust,ignore
/// use anno::NuNER;
///
/// let ner = NuNER::from_pretrained("deepanwa/NuNerZero_onnx")?;
/// let entities = ner.extract(
///     "The CRISPR-Cas9 system was developed by Jennifer Doudna",
///     &["technology", "scientist"],
///     0.5
/// )?;
/// ```
pub struct NuNER {
    /// Model path or identifier
    model_id: String,
    /// Confidence threshold (0.0-1.0), clamped by `with_threshold`.
    /// `extract` takes its own `threshold` argument and does not read this field.
    threshold: f64,
    /// Whether the loaded ONNX export also expects `span_idx`/`span_mask` inputs.
    ///
    /// Starts as `false` (also after `from_pretrained`) and is set to `true` the first
    /// time a token-only inference fails with an error mentioning a span input; later
    /// calls then go straight to span mode. Atomic so it can be updated through `&self`.
    #[cfg(feature = "onnx")]
    requires_span_tensors: std::sync::atomic::AtomicBool,
    /// Default entity labels for Model trait
    default_labels: Vec<String>,
    /// ONNX session (when feature enabled); locked for the duration of each inference.
    #[cfg(feature = "onnx")]
    session: Option<crate::sync::Mutex<ort::session::Session>>,
    /// Tokenizer (when feature enabled)
    #[cfg(feature = "onnx")]
    tokenizer: Option<tokenizers::Tokenizer>,
}
136
137impl NuNER {
138    /// Create NuNER with default configuration.
139    ///
140    /// Uses standard NER labels. Call `from_pretrained` (requires `onnx` feature)
141    /// to load actual model weights.
142    #[must_use]
143    pub fn new() -> Self {
144        Self {
145            model_id: "numind/NuNER_Zero".to_string(),
146            threshold: 0.5,
147            #[cfg(feature = "onnx")]
148            requires_span_tensors: std::sync::atomic::AtomicBool::new(false),
149            default_labels: vec![
150                "person".to_string(),
151                "organization".to_string(),
152                "location".to_string(),
153                "date".to_string(),
154                "product".to_string(),
155                "event".to_string(),
156            ],
157            #[cfg(feature = "onnx")]
158            session: None,
159            #[cfg(feature = "onnx")]
160            tokenizer: None,
161        }
162    }
163
164    /// Load NuNER model from HuggingFace.
165    ///
166    /// Automatically loads `.env` for HF_TOKEN if present.
167    ///
168    /// # Arguments
169    /// * `model_id` - HuggingFace model ID (e.g., "deepanwa/NuNerZero_onnx")
170    ///
171    /// # Example
172    /// ```rust,ignore
173    /// let ner = NuNER::from_pretrained("deepanwa/NuNerZero_onnx")?;
174    /// ```
175    #[cfg(feature = "onnx")]
176    pub fn from_pretrained(model_id: &str) -> Result<Self> {
177        use hf_hub::api::sync::{Api, ApiBuilder};
178        use ort::execution_providers::CPUExecutionProvider;
179        use ort::session::Session;
180
181        // Load .env if present (for HF_TOKEN)
182        crate::env::load_dotenv();
183
184        let api = if let Some(token) = crate::env::hf_token() {
185            ApiBuilder::new()
186                .with_token(Some(token))
187                .build()
188                .map_err(|e| Error::Retrieval(format!("HuggingFace API with token: {}", e)))?
189        } else {
190            Api::new().map_err(|e| {
191                Error::Retrieval(format!("Failed to initialize HuggingFace API: {}", e))
192            })?
193        };
194
195        let repo = api.model(model_id.to_string());
196
197        // Download model and tokenizer
198        let model_path = repo
199            .get("onnx/model.onnx")
200            .or_else(|_| repo.get("model.onnx"))
201            .map_err(|e| Error::Retrieval(format!("Failed to download model.onnx: {}", e)))?;
202
203        let tokenizer_path = repo
204            .get("tokenizer.json")
205            .map_err(|e| Error::Retrieval(format!("Failed to download tokenizer.json: {}", e)))?;
206
207        let session = Session::builder()
208            .map_err(|e| Error::Retrieval(format!("Failed to create ONNX session: {}", e)))?
209            .with_execution_providers([CPUExecutionProvider::default().build()])
210            .map_err(|e| Error::Retrieval(format!("Failed to set execution providers: {}", e)))?
211            .commit_from_file(&model_path)
212            .map_err(|e| Error::Retrieval(format!("Failed to load ONNX model: {}", e)))?;
213
214        let tokenizer = tokenizers::Tokenizer::from_file(&tokenizer_path)
215            .map_err(|e| Error::Retrieval(format!("Failed to load tokenizer: {}", e)))?;
216
217        Ok(Self {
218            model_id: model_id.to_string(),
219            threshold: 0.5,
220            requires_span_tensors: std::sync::atomic::AtomicBool::new(false),
221            default_labels: vec![
222                "person".to_string(),
223                "organization".to_string(),
224                "location".to_string(),
225            ],
226            session: Some(crate::sync::Mutex::new(session)),
227            tokenizer: Some(tokenizer),
228        })
229    }
230
231    /// Create with custom model identifier (for configuration only).
232    #[must_use]
233    pub fn with_model(model_id: impl Into<String>) -> Self {
234        let mut new = Self::new();
235        new.model_id = model_id.into();
236        new
237    }
238
239    /// Set confidence threshold.
240    #[must_use]
241    pub fn with_threshold(mut self, threshold: f64) -> Self {
242        self.threshold = threshold.clamp(0.0, 1.0);
243        self
244    }
245
246    /// Set default entity labels for Model trait.
247    #[must_use]
248    pub fn with_labels(mut self, labels: Vec<String>) -> Self {
249        self.default_labels = labels;
250        self
251    }
252
253    /// Get the model identifier.
254    #[must_use]
255    pub fn model_id(&self) -> &str {
256        &self.model_id
257    }
258
259    /// Get the confidence threshold.
260    #[must_use]
261    pub fn threshold(&self) -> f64 {
262        self.threshold
263    }
264
265    /// Extract entities with custom labels.
266    ///
267    /// Unlike the `Model` trait which uses default labels, this method
268    /// allows specifying arbitrary entity types at runtime.
269    ///
270    /// # Arguments
271    /// * `text` - Text to extract from
272    /// * `entity_types` - Entity type labels (e.g., ["person", "company"])
273    /// * `threshold` - Confidence threshold (0.0-1.0)
274    #[cfg(feature = "onnx")]
275    pub fn extract(
276        &self,
277        text: &str,
278        entity_types: &[&str],
279        threshold: f32,
280    ) -> Result<Vec<Entity>> {
281        if text.is_empty() || entity_types.is_empty() {
282            return Ok(vec![]);
283        }
284
285        // Debug tracing
286        if std::env::var("ANNO_DEBUG_NUNER_EXTRACT").is_ok() {
287            eprintln!(
288                "DEBUG nuner extract: text.len={} entity_types={:?}",
289                text.len(),
290                entity_types
291            );
292        }
293
294        let session = self.session.as_ref().ok_or_else(|| {
295            Error::Retrieval("Model not loaded. Call from_pretrained() first.".to_string())
296        })?;
297
298        let tokenizer = self
299            .tokenizer
300            .as_ref()
301            .ok_or_else(|| Error::Retrieval("Tokenizer not loaded.".to_string()))?;
302
303        // Split text into words
304        let text_words: Vec<&str> = text.split_whitespace().collect();
305        if text_words.is_empty() {
306            return Ok(vec![]);
307        }
308
309        // Encode input (token mode - no span tensors)
310        let (input_ids, attention_mask, words_mask, text_lengths) =
311            self.encode_prompt(tokenizer, &text_words, entity_types)?;
312
313        let batch_size = 1;
314        let seq_len = input_ids.len();
315
316        // Note: ort tensors are consumed by `into_dyn()`, so we rebuild them for retries.
317        let make_token_tensors = || -> Result<(_, _, _, _)> {
318            use ndarray::Array2;
319
320            let input_ids_array = Array2::from_shape_vec((batch_size, seq_len), input_ids.clone())
321                .map_err(|e| Error::Parse(format!("Array error: {}", e)))?;
322            let attention_mask_array =
323                Array2::from_shape_vec((batch_size, seq_len), attention_mask.clone())
324                    .map_err(|e| Error::Parse(format!("Array error: {}", e)))?;
325            let words_mask_array =
326                Array2::from_shape_vec((batch_size, seq_len), words_mask.clone())
327                    .map_err(|e| Error::Parse(format!("Array error: {}", e)))?;
328            let text_lengths_array = Array2::from_shape_vec((batch_size, 1), vec![text_lengths])
329                .map_err(|e| Error::Parse(format!("Array error: {}", e)))?;
330
331            let input_ids_t = super::ort_compat::tensor_from_ndarray(input_ids_array)
332                .map_err(|e| Error::Parse(format!("Tensor error: {}", e)))?;
333            let attention_mask_t = super::ort_compat::tensor_from_ndarray(attention_mask_array)
334                .map_err(|e| Error::Parse(format!("Tensor error: {}", e)))?;
335            let words_mask_t = super::ort_compat::tensor_from_ndarray(words_mask_array)
336                .map_err(|e| Error::Parse(format!("Tensor error: {}", e)))?;
337            let text_lengths_t = super::ort_compat::tensor_from_ndarray(text_lengths_array)
338                .map_err(|e| Error::Parse(format!("Tensor error: {}", e)))?;
339
340            Ok((input_ids_t, attention_mask_t, words_mask_t, text_lengths_t))
341        };
342
343        // Some NuNER ONNX exports require span tensors (span_idx/span_mask), others are token-only.
344        // We default to token-only, and flip to span tensors on the first "missing input" error.
345        use std::sync::atomic::Ordering;
346        let mut needs_span_tensors = self.requires_span_tensors.load(Ordering::Relaxed);
347
348        // Use blocking lock for thread-safe parallel access
349        let mut session_guard = crate::sync::lock(session);
350
351        let outputs = loop {
352            if needs_span_tensors {
353                let (input_ids_t, attention_mask_t, words_mask_t, text_lengths_t) =
354                    make_token_tensors()?;
355                // Generate span tensors similar to GLiNER
356                // Use checked_mul to prevent overflow (same as gliner2.rs:2388)
357                let num_spans = match text_words.len().checked_mul(MAX_SPAN_WIDTH) {
358                    Some(v) => v,
359                    None => {
360                        return Err(Error::InvalidInput(format!(
361                            "Span count overflow: {} words * {} MAX_SPAN_WIDTH",
362                            text_words.len(),
363                            MAX_SPAN_WIDTH
364                        )));
365                    }
366                };
367                let (span_idx, span_mask) = NuNER::make_span_tensors(text_words.len());
368
369                use ndarray::Array2;
370                use ndarray::Array3;
371                let span_idx_array = Array3::from_shape_vec((1, num_spans, 2), span_idx)
372                    .map_err(|e| Error::Parse(format!("Span idx array error: {}", e)))?;
373                let span_mask_array = Array2::from_shape_vec((1, num_spans), span_mask)
374                    .map_err(|e| Error::Parse(format!("Span mask array error: {}", e)))?;
375
376                let span_idx_t = super::ort_compat::tensor_from_ndarray(span_idx_array)
377                    .map_err(|e| Error::Parse(format!("Span idx tensor error: {}", e)))?;
378                let span_mask_t = super::ort_compat::tensor_from_ndarray(span_mask_array)
379                    .map_err(|e| Error::Parse(format!("Span mask tensor error: {}", e)))?;
380
381                break session_guard
382                    .run(ort::inputs![
383                        "input_ids" => input_ids_t.into_dyn(),
384                        "attention_mask" => attention_mask_t.into_dyn(),
385                        "words_mask" => words_mask_t.into_dyn(),
386                        "text_lengths" => text_lengths_t.into_dyn(),
387                        "span_idx" => span_idx_t.into_dyn(),
388                        "span_mask" => span_mask_t.into_dyn(),
389                    ])
390                    .map_err(|e| {
391                        Error::Parse(format!(
392                            "ONNX inference failed: {}\n\n\
393                             NuNER model: {}\n\
394                             requires_span_tensors={}\n\
395                             input_ids=(1,{seq_len}) attention_mask=(1,{seq_len}) words_mask=(1,{seq_len}) text_lengths=(1,1)\n\
396                             span_idx=(1,{num_spans},2) span_mask=(1,{num_spans})\n\n\
397                             Hint: If this looks like a shape mismatch, the ONNX export may have fixed span dimensions.\n\
398                             Try a different NuNER export (e.g., deepanwa/NuNerZero_onnx) or re-export with dynamic axes.",
399                            e,
400                            self.model_id,
401                            self.requires_span_tensors.load(Ordering::Relaxed)
402                        ))
403                    })?;
404            } else {
405                let (input_ids_t, attention_mask_t, words_mask_t, text_lengths_t) =
406                    make_token_tensors()?;
407                // Token mode - only 4 inputs
408                let res = session_guard.run(ort::inputs![
409                    "input_ids" => input_ids_t.into_dyn(),
410                    "attention_mask" => attention_mask_t.into_dyn(),
411                    "words_mask" => words_mask_t.into_dyn(),
412                    "text_lengths" => text_lengths_t.into_dyn(),
413                ]);
414
415                match res {
416                    Ok(o) => break o,
417                    Err(e) => {
418                        let msg = format!("{e}");
419                        let looks_like_missing_span = msg.contains("Missing Input: span_mask")
420                            || msg.contains("Missing Input: span_idx")
421                            || msg.contains("span_mask")
422                            || msg.contains("span_idx");
423                        if looks_like_missing_span {
424                            // Memoize and retry once in span mode (same request).
425                            self.requires_span_tensors.store(true, Ordering::Relaxed);
426                            needs_span_tensors = true;
427                            continue;
428                        }
429                        return Err(Error::Parse(format!(
430                            "ONNX inference failed: {}\n\n\
431                             NuNER model: {}\n\
432                             requires_span_tensors={}\n\
433                             input_ids=(1,{seq_len}) attention_mask=(1,{seq_len}) words_mask=(1,{seq_len}) text_lengths=(1,1)\n\n\
434                             Hint: If this looks like an input-name mismatch, your ONNX export may expect span tensors or different input names.",
435                            e,
436                            self.model_id,
437                            self.requires_span_tensors.load(Ordering::Relaxed),
438                        )));
439                    }
440                }
441            }
442        };
443
444        // Decode span-level output to entities
445        // NuNER with span_mode=marker and max_width=1 outputs: [batch, num_words, max_width, num_classes]
446        let entities =
447            self.decode_span_output(&outputs, text, &text_words, entity_types, threshold)?;
448
449        Ok(entities)
450    }
451
452    /// Generate span tensors for span-based inference (if model requires it).
453    ///
454    /// Matches Python GLiNER's prepare_span_idx function:
455    /// `span_idx = [(i, i + j) for i in range(num_tokens) for j in range(max_width)]`
456    ///
457    /// With MAX_SPAN_WIDTH=1, generates single-word spans only: (0,0), (1,1), etc.
458    /// Span indices use INCLUSIVE end positions (matching Python GLiNER).
459    ///
460    /// Returns: (span_idx, span_mask)
461    /// - span_idx: [num_spans, 2] - (start, end) word indices (both 0-indexed, inclusive)
462    /// - span_mask: [num_spans] - boolean mask indicating valid spans
463    #[cfg(feature = "onnx")]
464    pub(crate) fn make_span_tensors(num_words: usize) -> (Vec<i64>, Vec<bool>) {
465        // Use checked_mul to prevent overflow (same as gliner2.rs:2388)
466        let num_spans = match num_words.checked_mul(MAX_SPAN_WIDTH) {
467            Some(v) => v,
468            None => {
469                // Overflow - return empty tensors (shouldn't happen in practice)
470                log::warn!(
471                    "Span count overflow: {} words * {} MAX_SPAN_WIDTH, returning empty tensors",
472                    num_words,
473                    MAX_SPAN_WIDTH
474                );
475                return (Vec::new(), Vec::new());
476            }
477        };
478        // Check for overflow in num_spans * 2
479        let span_idx_len = match num_spans.checked_mul(2) {
480            Some(v) => v,
481            None => {
482                log::warn!(
483                    "Span idx length overflow: {} spans * 2, returning empty tensors",
484                    num_spans
485                );
486                return (Vec::new(), Vec::new());
487            }
488        };
489        let mut span_idx: Vec<i64> = vec![0; span_idx_len];
490        let mut span_mask: Vec<bool> = vec![false; num_spans];
491
492        for start in 0..num_words {
493            let remaining_width = num_words - start;
494            let actual_max_width = MAX_SPAN_WIDTH.min(remaining_width);
495
496            for width in 0..actual_max_width {
497                // Check for overflow in dim calculation
498                let dim = match start.checked_mul(MAX_SPAN_WIDTH) {
499                    Some(v) => match v.checked_add(width) {
500                        Some(d) => d,
501                        None => {
502                            log::warn!(
503                                "Dim calculation overflow: {} * {} + {}, skipping span",
504                                start,
505                                MAX_SPAN_WIDTH,
506                                width
507                            );
508                            continue;
509                        }
510                    },
511                    None => {
512                        log::warn!(
513                            "Dim calculation overflow: {} * {}, skipping span",
514                            start,
515                            MAX_SPAN_WIDTH
516                        );
517                        continue;
518                    }
519                };
520                // Check bounds before array access (dim * 2 could overflow or exceed span_idx_len)
521                if let Some(dim2) = dim.checked_mul(2) {
522                    if dim2 + 1 < span_idx_len && dim < num_spans {
523                        span_idx[dim2] = start as i64; // start offset (0-indexed, inclusive)
524                        span_idx[dim2 + 1] = (start + width) as i64; // end offset (0-indexed, INCLUSIVE per Python GLiNER)
525                        span_mask[dim] = true;
526                    } else {
527                        log::warn!(
528                            "Span idx access out of bounds: dim={}, dim*2={}, span_idx_len={}, num_spans={}, skipping",
529                            dim, dim2, span_idx_len, num_spans
530                        );
531                    }
532                } else {
533                    log::warn!("Dim * 2 overflow: dim={}, skipping span", dim);
534                }
535            }
536        }
537
538        (span_idx, span_mask)
539    }
540
541    /// Encode prompt for token mode (no span tensors).
542    #[cfg(feature = "onnx")]
543    fn encode_prompt(
544        &self,
545        tokenizer: &tokenizers::Tokenizer,
546        text_words: &[&str],
547        entity_types: &[&str],
548    ) -> Result<EncodedPrompt> {
549        // Performance: Pre-allocate vectors with estimated capacity
550        // Most prompts have 50-200 tokens
551        let mut input_ids: Vec<i64> = Vec::with_capacity(128);
552        let mut word_mask: Vec<i64> = Vec::with_capacity(128);
553
554        // [START]
555        input_ids.push(TOKEN_START as i64);
556        word_mask.push(0);
557
558        // <<ENT>> type1 <<ENT>> type2 ...
559        for entity_type in entity_types {
560            input_ids.push(TOKEN_ENT as i64);
561            word_mask.push(0);
562
563            let encoding = tokenizer
564                .encode(entity_type.to_string(), false)
565                .map_err(|e| Error::Parse(format!("Tokenizer error: {}", e)))?;
566            for token_id in encoding.get_ids() {
567                input_ids.push(*token_id as i64);
568                word_mask.push(0);
569            }
570        }
571
572        // <<SEP>>
573        input_ids.push(TOKEN_SEP as i64);
574        word_mask.push(0);
575
576        // Text words (word_mask starts from 1)
577        let mut word_id: i64 = 0;
578        for word in text_words {
579            let encoding = tokenizer
580                .encode(word.to_string(), false)
581                .map_err(|e| Error::Parse(format!("Tokenizer error: {}", e)))?;
582
583            word_id += 1;
584            for (token_idx, token_id) in encoding.get_ids().iter().enumerate() {
585                input_ids.push(*token_id as i64);
586                word_mask.push(if token_idx == 0 { word_id } else { 0 });
587            }
588        }
589
590        // [END]
591        input_ids.push(TOKEN_END as i64);
592        word_mask.push(0);
593
594        let seq_len = input_ids.len();
595        let attention_mask: Vec<i64> = vec![1; seq_len];
596
597        Ok((input_ids, attention_mask, word_mask, word_id))
598    }
599
600    /// Decode token classification output to entities.
601    ///
602    /// Token mode output shape: [batch, seq_len, num_entity_types]
603    /// Each position has scores for each entity type (BIO-style).
604    #[cfg(feature = "onnx")]
605    fn decode_token_output(
606        &self,
607        outputs: &ort::session::SessionOutputs,
608        text: &str,
609        text_words: &[&str],
610        entity_types: &[&str],
611        threshold: f32,
612    ) -> Result<Vec<Entity>> {
613        let output = outputs
614            .iter()
615            .next()
616            .map(|(_, v)| v)
617            .ok_or_else(|| Error::Parse("No output from NuNER model".to_string()))?;
618
619        let (_, data_slice) = output
620            .try_extract_tensor::<f32>()
621            .map_err(|e| Error::Parse(format!("Failed to extract output tensor: {}", e)))?;
622        let output_data: Vec<f32> = data_slice.to_vec();
623
624        // Get shape: [batch, num_words, num_classes]
625        let shape: Vec<i64> = match output.dtype() {
626            ort::value::ValueType::Tensor { shape, .. } => shape.iter().copied().collect(),
627            _ => return Err(Error::Parse("Expected tensor output".to_string())),
628        };
629
630        // Debug output shape
631        if std::env::var("ANNO_DEBUG_NUNER_DECODE").is_ok() {
632            eprintln!(
633                "DEBUG nuner decode: shape={:?} text_words.len={} data.len={}",
634                shape,
635                text_words.len(),
636                output_data.len()
637            );
638            // Sample first few values
639            let sample: Vec<f32> = output_data.iter().take(10).copied().collect();
640            eprintln!("DEBUG nuner decode: sample data={:?}", sample);
641        }
642
643        if shape.len() < 3 {
644            return Err(Error::Parse(format!(
645                "Unexpected output shape: {:?}",
646                shape
647            )));
648        }
649
650        let num_words = shape[1] as usize;
651        let num_classes = shape[2] as usize;
652
653        if std::env::var("ANNO_DEBUG_NUNER_DECODE").is_ok() {
654            eprintln!(
655                "DEBUG nuner decode: num_words={} num_classes={} entity_types.len={}",
656                num_words,
657                num_classes,
658                entity_types.len()
659            );
660        }
661
662        // Calculate word positions in original text
663        // Validate that all words are found to prevent silent failures
664        let word_positions: Vec<(usize, usize)> = {
665            // Performance: Pre-allocate positions vec with known size
666            let mut positions = Vec::with_capacity(text_words.len());
667            let mut pos = 0;
668            for (idx, word) in text_words.iter().enumerate() {
669                if let Some(start) = text[pos..].find(word) {
670                    let abs_start = pos + start;
671                    let abs_end = abs_start + word.len();
672                    // Validate position is after previous word (words should be in order)
673                    if !positions.is_empty() {
674                        let (_prev_start, prev_end) = positions[positions.len() - 1];
675                        if abs_start < prev_end {
676                            log::warn!(
677                                "Word '{}' at position {} overlaps with previous word ending at {}",
678                                word,
679                                abs_start,
680                                prev_end
681                            );
682                        }
683                    }
684                    positions.push((abs_start, abs_end));
685                    pos = abs_end;
686                } else {
687                    // Word not found - return error to prevent silent entity skipping
688                    return Err(Error::Parse(format!(
689                        "Word '{}' (index {}) not found in text starting at position {}",
690                        word, idx, pos
691                    )));
692                }
693            }
694            positions
695        };
696
697        // Validate that we found positions for all words
698        if word_positions.len() != text_words.len() {
699            return Err(Error::Parse(format!(
700                "Word position mismatch: found {} positions for {} words",
701                word_positions.len(),
702                text_words.len()
703            )));
704        }
705
706        // Word positions are byte offsets; `Entity` requires character offsets.
707        let span_converter = crate::offset::SpanConverter::new(text);
708
709        // Performance: Pre-allocate entities vec with estimated capacity
710        let mut entities = Vec::with_capacity(16);
711        let mut current_entity: Option<(usize, usize, usize, f32)> = None; // (start_word, end_word, type_idx, score)
712
713        // Process each word position
714        for word_idx in 0..num_words.min(text_words.len()) {
715            let base_idx = word_idx * num_classes;
716
717            // Find best class for this word
718            let mut best_class = 0;
719            let mut best_score = 0.0f32;
720
721            for class_idx in 0..num_classes {
722                let score = output_data
723                    .get(base_idx + class_idx)
724                    .copied()
725                    .unwrap_or(0.0);
726                if score > best_score {
727                    best_score = score;
728                    best_class = class_idx;
729                }
730            }
731
732            // BIO decoding: class 0 = O, odd = B-type, even = I-type
733            let is_begin = best_class > 0 && best_class % 2 == 1;
734            let is_inside = best_class > 0 && best_class % 2 == 0;
735            let type_idx = if best_class > 0 {
736                (best_class - 1) / 2
737            } else {
738                0
739            };
740
741            if best_score >= threshold {
742                if is_begin {
743                    // Flush previous entity
744                    if let Some((start, end, etype, score)) = current_entity.take() {
745                        if let Some(e) = self.create_entity(
746                            text,
747                            &span_converter,
748                            &word_positions,
749                            start,
750                            end,
751                            etype,
752                            score,
753                            entity_types,
754                        ) {
755                            entities.push(e);
756                        }
757                    }
758                    // Start new entity
759                    current_entity = Some((word_idx, word_idx + 1, type_idx, best_score));
760                } else if is_inside {
761                    // Extend current entity if same type
762                    if let Some((_start, end, etype, score)) = current_entity.as_mut() {
763                        if *etype == type_idx {
764                            *end = word_idx + 1;
765                            *score = (*score + best_score) / 2.0; // Average confidence
766                        }
767                    }
768                }
769            } else {
770                // Low confidence or O tag - flush current entity
771                if let Some((start, end, etype, score)) = current_entity.take() {
772                    if let Some(e) = self.create_entity(
773                        text,
774                        &span_converter,
775                        &word_positions,
776                        start,
777                        end,
778                        etype,
779                        score,
780                        entity_types,
781                    ) {
782                        entities.push(e);
783                    }
784                }
785            }
786        }
787
788        // Flush final entity
789        if let Some((start, end, etype, score)) = current_entity.take() {
790            if let Some(e) = self.create_entity(
791                text,
792                &span_converter,
793                &word_positions,
794                start,
795                end,
796                etype,
797                score,
798                entity_types,
799            ) {
800                entities.push(e);
801            }
802        }
803
804        Ok(entities)
805    }
806
807    /// Decode span classification output to entities.
808    ///
809    /// Span mode output shape: [batch, num_words, max_width, num_classes]
810    /// With max_width=1, each word has logits for each entity type.
811    /// We apply sigmoid and compare to threshold.
812    #[cfg(feature = "onnx")]
813    fn decode_span_output(
814        &self,
815        outputs: &ort::session::SessionOutputs,
816        text: &str,
817        text_words: &[&str],
818        entity_types: &[&str],
819        threshold: f32,
820    ) -> Result<Vec<Entity>> {
821        // Find the logits output
822        let logits_output = outputs
823            .iter()
824            .find(|(name, _)| name.contains("logits"))
825            .map(|(_, v)| v)
826            .or_else(|| outputs.iter().next().map(|(_, v)| v))
827            .ok_or_else(|| Error::Parse("No logits output from NuNER model".to_string()))?;
828
829        let (_, data_slice) = logits_output
830            .try_extract_tensor::<f32>()
831            .map_err(|e| Error::Parse(format!("Failed to extract output tensor: {}", e)))?;
832        let output_data: Vec<f32> = data_slice.to_vec();
833
834        // Get shape: [batch, num_words, max_width, num_classes]
835        let shape: Vec<i64> = match logits_output.dtype() {
836            ort::value::ValueType::Tensor { shape, .. } => shape.iter().copied().collect(),
837            _ => return Err(Error::Parse("Expected tensor output".to_string())),
838        };
839
840        if shape.len() != 4 {
841            // Fall back to token decoding if shape doesn't match span format
842            return self.decode_token_output(outputs, text, text_words, entity_types, threshold);
843        }
844
845        let num_words = shape[1] as usize;
846        let max_width = shape[2] as usize; // Should be 1 for NuNER
847        let num_classes = shape[3] as usize;
848
849        // Debug
850        if std::env::var("ANNO_DEBUG_NUNER_DECODE").is_ok() {
851            eprintln!(
852                "DEBUG nuner decode_span: shape={:?} num_words={} max_width={} num_classes={} entity_types.len={}",
853                shape, num_words, max_width, num_classes, entity_types.len()
854            );
855        }
856
857        // Calculate word positions in original text
858        let word_positions: Vec<(usize, usize)> = {
859            let mut positions = Vec::with_capacity(text_words.len());
860            let mut pos = 0;
861            for word in text_words.iter() {
862                if let Some(start) = text[pos..].find(word) {
863                    let abs_start = pos + start;
864                    let abs_end = abs_start + word.len();
865                    positions.push((abs_start, abs_end));
866                    pos = abs_end;
867                } else {
868                    // Word not found - this shouldn't happen with whitespace split
869                    return Err(Error::Parse(format!(
870                        "Word '{}' not found in text starting at position {}",
871                        word, pos
872                    )));
873                }
874            }
875            positions
876        };
877
878        // Word positions are byte offsets; `Entity` requires character offsets.
879        let span_converter = crate::offset::SpanConverter::new(text);
880
881        let mut entities = Vec::with_capacity(16);
882        let mut current_entity: Option<(usize, usize, usize, f32)> = None; // (start_word, end_word, type_idx, score)
883
884        // Process each word
885        for word_idx in 0..num_words.min(text_words.len()) {
886            // For span mode with max_width=1, each word has one set of class logits
887            // Index: [batch=0, word_idx, width=0, class_idx]
888            let base_idx = word_idx * max_width * num_classes;
889
890            // Find best class above threshold
891            let mut best_class: Option<usize> = None;
892            let mut best_prob = 0.0f32;
893
894            for class_idx in 0..num_classes {
895                let logit = output_data
896                    .get(base_idx + class_idx)
897                    .copied()
898                    .unwrap_or(f32::NEG_INFINITY);
899                // Apply sigmoid: prob = 1 / (1 + exp(-logit))
900                let prob = 1.0 / (1.0 + (-logit).exp());
901
902                if prob >= threshold && prob > best_prob {
903                    best_prob = prob;
904                    best_class = Some(class_idx);
905                }
906            }
907
908            if let Some(class_idx) = best_class {
909                // We found an entity at this word
910                if let Some((start, end, etype, score)) = current_entity.as_mut() {
911                    if *etype == class_idx {
912                        // Extend current entity (same type)
913                        *end = word_idx + 1;
914                        *score = (*score + best_prob) / 2.0;
915                    } else {
916                        // Different type - flush and start new
917                        if let Some(e) = self.create_entity(
918                            text,
919                            &span_converter,
920                            &word_positions,
921                            *start,
922                            *end,
923                            *etype,
924                            *score,
925                            entity_types,
926                        ) {
927                            entities.push(e);
928                        }
929                        current_entity = Some((word_idx, word_idx + 1, class_idx, best_prob));
930                    }
931                } else {
932                    // Start new entity
933                    current_entity = Some((word_idx, word_idx + 1, class_idx, best_prob));
934                }
935            } else {
936                // No entity at this word - flush current
937                if let Some((start, end, etype, score)) = current_entity.take() {
938                    if let Some(e) = self.create_entity(
939                        text,
940                        &span_converter,
941                        &word_positions,
942                        start,
943                        end,
944                        etype,
945                        score,
946                        entity_types,
947                    ) {
948                        entities.push(e);
949                    }
950                }
951            }
952        }
953
954        // Flush final entity
955        if let Some((start, end, etype, score)) = current_entity.take() {
956            if let Some(e) = self.create_entity(
957                text,
958                &span_converter,
959                &word_positions,
960                start,
961                end,
962                etype,
963                score,
964                entity_types,
965            ) {
966                entities.push(e);
967            }
968        }
969
970        if std::env::var("ANNO_DEBUG_NUNER_DECODE").is_ok() {
971            eprintln!("DEBUG nuner decode_span: found {} entities", entities.len());
972        }
973
974        Ok(entities)
975    }
976
977    #[cfg(feature = "onnx")]
978    #[allow(clippy::too_many_arguments)]
979    fn create_entity(
980        &self,
981        text: &str,
982        span_converter: &crate::offset::SpanConverter,
983        word_positions: &[(usize, usize)],
984        start_word: usize,
985        end_word: usize,
986        type_idx: usize,
987        score: f32,
988        entity_types: &[&str],
989    ) -> Option<Entity> {
990        // Validate indices to prevent underflow
991        if end_word == 0 || end_word > word_positions.len() || start_word >= word_positions.len() {
992            return None;
993        }
994        let start_pos = word_positions.get(start_word)?.0;
995        let end_pos = word_positions.get(end_word.saturating_sub(1))?.1;
996
997        let entity_text = text.get(start_pos..end_pos)?;
998        let label = entity_types.get(type_idx)?;
999        let entity_type = Self::map_label_to_entity_type(label);
1000
1001        let char_start = span_converter.byte_to_char(start_pos);
1002        let char_end = span_converter.byte_to_char(end_pos);
1003
1004        Some(Entity::new(
1005            entity_text,
1006            entity_type,
1007            char_start,
1008            char_end,
1009            score as f64,
1010        ))
1011    }
1012
1013    /// Map label string to EntityType.
1014    fn map_label_to_entity_type(label: &str) -> EntityType {
1015        match label.to_lowercase().as_str() {
1016            "person" | "per" => EntityType::Person,
1017            "organization" | "org" | "company" => EntityType::Organization,
1018            "location" | "loc" | "place" | "gpe" => EntityType::Location,
1019            "date" => EntityType::Date,
1020            "time" => EntityType::Time,
1021            "money" | "currency" => EntityType::Money,
1022            "percent" | "percentage" => EntityType::Percent,
1023            _ => EntityType::Other(label.to_string()),
1024        }
1025    }
1026}
1027
1028impl Default for NuNER {
1029    fn default() -> Self {
1030        Self::new()
1031    }
1032}
1033
1034impl Model for NuNER {
1035    fn extract_entities(&self, text: &str, _language: Option<&str>) -> Result<Vec<Entity>> {
1036        if text.trim().is_empty() {
1037            return Ok(vec![]);
1038        }
1039
1040        #[cfg(feature = "onnx")]
1041        {
1042            if self.session.is_some() {
1043                let labels: Vec<&str> = self.default_labels.iter().map(|s| s.as_str()).collect();
1044                return self.extract(text, &labels, self.threshold as f32);
1045            }
1046
1047            Err(Error::ModelInit(
1048                "NuNER model not loaded. Call `NuNER::from_pretrained(...)` (requires `onnx` feature) before calling `extract_entities`.".to_string(),
1049            ))
1050        }
1051
1052        #[cfg(not(feature = "onnx"))]
1053        {
1054            Err(Error::FeatureNotAvailable(
1055                "NuNER requires the 'onnx' feature. Build with: cargo build --features onnx"
1056                    .to_string(),
1057            ))
1058        }
1059    }
1060
1061    fn supported_types(&self) -> Vec<EntityType> {
1062        self.default_labels
1063            .iter()
1064            .map(|l| Self::map_label_to_entity_type(l))
1065            .collect()
1066    }
1067
1068    fn is_available(&self) -> bool {
1069        #[cfg(feature = "onnx")]
1070        {
1071            self.session.is_some()
1072        }
1073        #[cfg(not(feature = "onnx"))]
1074        {
1075            false
1076        }
1077    }
1078
1079    fn name(&self) -> &'static str {
1080        "nuner"
1081    }
1082
1083    fn description(&self) -> &'static str {
1084        "NuNER Zero: Token-based zero-shot NER from NuMind (MIT licensed)"
1085    }
1086
1087    fn version(&self) -> String {
1088        format!("nuner-zero-{}", self.model_id)
1089    }
1090}
1091
1092// =============================================================================
1093// BatchCapable Trait Implementation
1094// =============================================================================
1095
1096impl crate::BatchCapable for NuNER {
1097    fn optimal_batch_size(&self) -> Option<usize> {
1098        Some(8)
1099    }
1100}
1101
1102// =============================================================================
1103// StreamingCapable Trait Implementation
1104// =============================================================================
1105
1106impl crate::StreamingCapable for NuNER {
1107    fn recommended_chunk_size(&self) -> usize {
1108        4096 // Characters
1109    }
1110}
1111
#[cfg(test)]
mod tests {
    use super::*;

    /// A default-constructed model uses the NuNER_Zero checkpoint and a 0.5 threshold.
    #[test]
    fn test_nuner_creation() {
        let model = NuNER::new();
        assert_eq!(model.model_id(), "numind/NuNER_Zero");
        let threshold_error = (model.threshold() - 0.5).abs();
        assert!(threshold_error < f64::EPSILON);
    }

    /// Builder methods override model id, threshold and labels.
    #[test]
    fn test_nuner_with_custom_model() {
        let labels = vec!["technology".to_string()];
        let model = NuNER::with_model("custom/model")
            .with_threshold(0.7)
            .with_labels(labels);

        assert_eq!(model.model_id(), "custom/model");
        let threshold_error = (model.threshold() - 0.7).abs();
        assert!(threshold_error < f64::EPSILON);
        assert_eq!(model.default_labels.len(), 1);
    }

    /// Known labels (any case) map to built-ins; unknown labels become `Other`.
    #[test]
    fn test_label_mapping() {
        let cases = [
            ("person", EntityType::Person),
            ("PER", EntityType::Person),
            ("organization", EntityType::Organization),
            ("custom", EntityType::Other("custom".to_string())),
        ];
        for (label, expected) in cases {
            assert_eq!(
                NuNER::map_label_to_entity_type(label),
                expected,
                "label {:?}",
                label
            );
        }
    }

    /// The default label set covers the core person/organization/location types.
    #[test]
    fn test_supported_types() {
        let types = NuNER::new().supported_types();
        for expected in [
            EntityType::Person,
            EntityType::Organization,
            EntityType::Location,
        ] {
            assert!(types.contains(&expected), "missing {:?}", expected);
        }
    }

    /// Empty input short-circuits before any model is needed.
    #[test]
    fn test_empty_input() {
        let found = NuNER::new().extract_entities("", None).unwrap();
        assert!(found.is_empty());
    }

    /// Without a loaded session the backend reports itself unavailable.
    #[test]
    fn test_not_available_without_model() {
        assert!(!NuNER::new().is_available());
    }

    /// Byte offsets from word positions must be converted to char offsets.
    #[test]
    #[cfg(feature = "onnx")]
    fn test_create_entity_converts_byte_offsets_to_char_offsets() {
        let model = NuNER::new();
        let text = "北京 Beijing";
        // "北京" occupies bytes 0..6 (2 chars); "Beijing" occupies bytes 7..14.
        let word_positions = vec![(0usize, 6usize), (7usize, 14usize)];
        let entity_types = ["loc"];
        let span_converter = crate::offset::SpanConverter::new(text);

        // Second word only: start_word = 1, end_word = 2 (exclusive).
        let entity = model
            .create_entity(
                text,
                &span_converter,
                &word_positions,
                1,
                2,
                0,
                0.9,
                &entity_types,
            )
            .expect("expected entity");

        assert_eq!(entity.text, "Beijing");
        assert_eq!(
            (entity.start, entity.end),
            (3, 10),
            "expected char offsets for Beijing"
        );
    }
}