Skip to main content

anno/backends/nuner/
mod.rs

1//! NuNER - Token-based zero-shot NER from NuMind.
2//!
3//! NuNER is a family of zero-shot NER models built on the GLiNER architecture
4//! with a token classifier design (vs span classifier). Key advantages:
5//!
6//! - **Arbitrary-length entities**: No hard limit on entity span length
7//! - **Efficient training**: Trained on NuNER v2.0 dataset (Pile + C4)
8//! - **MIT Licensed**: Open weights from NuMind
9//!
10//! # Architecture
11//!
12//! NuNER uses the same bi-encoder architecture as GLiNER but with token classification:
13//!
14//! ```text
15//! Input: "James Bond works at MI6"
16//!        Labels: ["person", "organization"]
17//!
18//!        ┌──────────────────────┐
19//!        │   Shared Encoder     │
20//!        │  (DeBERTa/BERT)      │
21//!        └──────────────────────┘
22//!               │         │
23//!        ┌──────┴──┐   ┌──┴─────┐
24//!        │  Token  │   │ Label  │
25//!        │  Embeds │   │ Embeds │
26//!        └─────────┘   └────────┘
27//!               │         │
//!        ┌──────┴─────────┴─────┐
//!        │ Token Classification │  (BIO tags per token)
//!        └──────────────────────┘
31//!               │
32//!               ▼
33//!        B-PER I-PER  O    O   B-ORG
34//!        James Bond works at  MI6
35//! ```
36//!
37//! # Differences from GLiNER (Span Mode)
38//!
39//! | Aspect | GLiNER (Span) | NuNER (Token) |
40//! |--------|---------------|---------------|
41//! | Output | Span classification | Token classification (BIO) |
42//! | Entity length | Limited by span window (12) | Arbitrary |
43//! | ONNX inputs | 6 tensors (incl span_idx) | 4 tensors (no span tensors) |
44//! | Decoding | Span scores → entities | BIO tags → entities |
45//!
46//! # Model Variants
47//!
48//! | Model | Context | Notes |
49//! |-------|---------|-------|
50//! | `numind/NuNER_Zero` | 512 | General zero-shot |
51//! | `numind/NuNER_Zero_4k` | 4096 | Long context variant |
52//! | `deepanwa/NuNerZero_onnx` | 512 | Pre-converted ONNX |
53//!
54//! # Usage
55//!
56//! ```rust,ignore
57//! use anno::NuNER;
58//!
59//! // Load NuNER model (requires `onnx` feature)
60//! let ner = NuNER::from_pretrained("deepanwa/NuNerZero_onnx")?;
61//!
62//! // Zero-shot extraction with custom labels
63//! let entities = ner.extract("Apple CEO Tim Cook announced...",
64//!                            &["person", "organization", "product"], 0.5)?;
65//! ```
66//!
67//! # References
68//!
69//! - [NuNER Zero on HuggingFace](https://huggingface.co/numind/NuNER_Zero)
70//! - [NuNER ONNX](https://huggingface.co/deepanwa/NuNerZero_onnx)
71//! - GLiNER paper (for span-based prompting inspiration)
72
73use crate::{Entity, EntityType, Model, Result};
74
75use crate::Error;
76
/// Output of prompt encoding, as a 4-tuple in this order:
///
/// 1. `input_ids`
/// 2. `attention_mask`
/// 3. `word_mask`
/// 4. `num_entity_types`
///
/// Only compiled when the `onnx` feature is enabled.
#[cfg(feature = "onnx")]
type EncodedPrompt = (Vec<i64>, Vec<i64>, Vec<i64>, i64);
80
// Special token IDs for GLiNER/NuNER models (the two share one architecture).
// NOTE(review): the exact role of each ID is inferred from its name only —
// confirm against the tokenizer vocabulary before changing any value.

/// ID of the sequence-start special token.
#[cfg(feature = "onnx")]
const TOKEN_START: u32 = 1;
/// ID of the sequence-end special token.
#[cfg(feature = "onnx")]
const TOKEN_END: u32 = 2;
/// ID of the entity-label marker token (presumably precedes each label in the prompt).
#[cfg(feature = "onnx")]
const TOKEN_ENT: u32 = 128002;
/// ID of the separator token (presumably divides the label prompt from the text).
#[cfg(feature = "onnx")]
const TOKEN_SEP: u32 = 128003;
90
/// Widest span, in words, considered during span-based inference.
///
/// NuNER's `gliner_config.json` sets `max_width=1`, i.e. only single-word spans
/// are enumerated. This mirrors `prepare_span_idx` in the Python GLiNER
/// implementation.
#[cfg(feature = "onnx")]
const MAX_SPAN_WIDTH: usize = 1;
96
/// NuNER zero-shot NER model.
///
/// Token-based variant of GLiNER that uses BIO tagging instead of span classification.
/// This enables arbitrary-length entity extraction without the span window limitation.
///
/// # Feature Requirements
///
/// Requires the `onnx` feature for actual inference. Without it the struct holds
/// only its configuration fields (model id, threshold, default labels),
/// `is_available()` reports `false`, and `Model::extract_entities` returns
/// `Error::FeatureNotAvailable` for non-empty input — it does not silently
/// return an empty result.
///
/// # Example
///
/// ```rust,ignore
/// use anno::NuNER;
///
/// let ner = NuNER::from_pretrained("deepanwa/NuNerZero_onnx")?;
/// let entities = ner.extract(
///     "The CRISPR-Cas9 system was developed by Jennifer Doudna",
///     &["technology", "scientist"],
///     0.5
/// )?;
/// ```
pub struct NuNER {
    /// Model path or identifier
    model_id: String,
    /// Confidence threshold (0.0-1.0); stored as `f64` and narrowed to `f32`
    /// when handed to the inference routine.
    threshold: f64,
    /// Whether model requires span tensors (detected on load)
    #[cfg(feature = "onnx")]
    requires_span_tensors: std::sync::atomic::AtomicBool,
    /// Default entity labels for Model trait
    default_labels: Vec<String>,
    /// ONNX session (when feature enabled); `None` until a model is loaded.
    #[cfg(feature = "onnx")]
    session: Option<crate::sync::Mutex<ort::session::Session>>,
    /// Tokenizer (when feature enabled)
    #[cfg(feature = "onnx")]
    tokenizer: Option<tokenizers::Tokenizer>,
}
136
137mod inference;
138// NuNER ONNX inference: see inference.rs
139impl Default for NuNER {
140    fn default() -> Self {
141        Self::new()
142    }
143}
144
145impl Model for NuNER {
146    fn extract_entities(&self, text: &str, _language: Option<&str>) -> Result<Vec<Entity>> {
147        if text.trim().is_empty() {
148            return Ok(vec![]);
149        }
150
151        #[cfg(feature = "onnx")]
152        {
153            if self.session.is_some() {
154                let labels: Vec<&str> = self.default_labels.iter().map(|s| s.as_str()).collect();
155                return self.extract(text, &labels, self.threshold as f32);
156            }
157
158            Err(Error::ModelInit(
159                "NuNER model not loaded. Call `NuNER::from_pretrained(...)` (requires `onnx` feature) before calling `extract_entities`.".to_string(),
160            ))
161        }
162
163        #[cfg(not(feature = "onnx"))]
164        {
165            Err(Error::FeatureNotAvailable(
166                "NuNER requires the 'onnx' feature. Build with: cargo build --features onnx"
167                    .to_string(),
168            ))
169        }
170    }
171
172    fn supported_types(&self) -> Vec<EntityType> {
173        self.default_labels
174            .iter()
175            .map(|l| Self::map_label_to_entity_type(l))
176            .collect()
177    }
178
179    fn is_available(&self) -> bool {
180        #[cfg(feature = "onnx")]
181        {
182            self.session.is_some()
183        }
184        #[cfg(not(feature = "onnx"))]
185        {
186            false
187        }
188    }
189
190    fn name(&self) -> &'static str {
191        "nuner"
192    }
193
194    fn description(&self) -> &'static str {
195        "NuNER Zero: Token-based zero-shot NER from NuMind (MIT licensed)"
196    }
197
198    fn version(&self) -> String {
199        format!("nuner-zero-{}", self.model_id)
200    }
201
202    fn capabilities(&self) -> crate::ModelCapabilities {
203        crate::ModelCapabilities {
204            batch_capable: true,
205            streaming_capable: true,
206            dynamic_labels: true,
207            ..Default::default()
208        }
209    }
210}
211
212impl crate::NamedEntityCapable for NuNER {}
213
#[cfg(feature = "onnx")]
impl crate::DynamicLabels for NuNER {
    /// Runs extraction against caller-supplied `labels` instead of the defaults.
    ///
    /// The configured `f64` threshold is narrowed to `f32` before delegating to
    /// `extract`. `_language` is accepted for trait compatibility and unused.
    fn extract_with_labels(
        &self,
        text: &str,
        labels: &[&str],
        _language: Option<&str>,
    ) -> crate::Result<Vec<crate::Entity>> {
        let min_confidence = self.threshold as f32;
        self.extract(text, labels, min_confidence)
    }
}
225
226// =============================================================================
227// BatchCapable Trait Implementation
228// =============================================================================
229
230impl crate::BatchCapable for NuNER {
231    fn optimal_batch_size(&self) -> Option<usize> {
232        Some(8)
233    }
234}
235
236// =============================================================================
237// StreamingCapable Trait Implementation
238// =============================================================================
239
240impl crate::StreamingCapable for NuNER {
241    fn recommended_chunk_size(&self) -> usize {
242        4096 // Characters
243    }
244}
245
246#[cfg(test)]
247mod tests;