use crate::{Entity, EntityCategory, EntityType, Language, Model, Result};
use crate::Error;
/// Encoded model prompt produced during tokenization.
///
/// NOTE(review): three `i64` id/mask vectors plus a scalar length; the exact
/// meaning of each element is defined by the `inference` module — confirm there.
#[cfg(feature = "onnx")]
type EncodedPrompt = (Vec<i64>, Vec<i64>, Vec<i64>, i64);

/// Token id emitted at the start of a sequence (presumably the tokenizer's BOS id).
#[cfg(feature = "onnx")]
const TOKEN_START: u32 = 1;

/// Token id emitted at the end of a sequence (presumably the tokenizer's EOS id).
#[cfg(feature = "onnx")]
const TOKEN_END: u32 = 2;

/// Special-token id used as an entity-label marker in the prompt
/// (presumably `<<ENT>>` in the model vocabulary — verify against the tokenizer).
#[cfg(feature = "onnx")]
const TOKEN_ENT: u32 = 128_002;

/// Special-token id separating the label prompt from the input text
/// (presumably `<<SEP>>` in the model vocabulary — verify against the tokenizer).
#[cfg(feature = "onnx")]
const TOKEN_SEP: u32 = 128_003;

/// Maximum span width considered during decoding; `1` means single-unit spans,
/// consistent with a token-based (rather than span-enumerating) model.
#[cfg(feature = "onnx")]
const MAX_SPAN_WIDTH: usize = 1;
/// NuNER Zero zero-shot named-entity recognizer.
///
/// Inference only works when built with the `onnx` feature and after a model
/// session has been loaded; otherwise `Model::extract_entities` returns an error.
pub struct NuNER {
    /// Identifier of the model; embedded in the string returned by `version()`.
    model_id: String,
    /// Default confidence threshold; narrowed to `f32` before inference.
    threshold: f64,
    /// Labels used by `extract_entities` and reported by `supported_types`.
    default_labels: Vec<String>,
    /// Inputs with more characters than this are split into overlapping chunks.
    max_input_chars: usize,

    /// Runtime flag about whether the ONNX graph needs span tensors.
    /// NOTE(review): presumably discovered/updated lazily by `inference` — confirm.
    #[cfg(feature = "onnx")]
    requires_span_tensors: std::sync::atomic::AtomicBool,
    /// Mutex-guarded ONNX Runtime session; `None` means no model is loaded.
    #[cfg(feature = "onnx")]
    session: Option<std::sync::Mutex<ort::session::Session>>,
    /// Tokenizer paired with the loaded model, if any.
    #[cfg(feature = "onnx")]
    tokenizer: Option<tokenizers::Tokenizer>,
}
mod inference;
impl Default for NuNER {
fn default() -> Self {
Self::new()
}
}
/// Character budget per inference call, presumably for a model with a
/// 512-token context window.
const MAX_INPUT_CHARS_512: usize = 2_000;
/// Character budget per inference call, presumably for a model with a
/// 4K-token context window.
const MAX_INPUT_CHARS_4K: usize = 16_000;
impl Model for NuNER {
    /// Runs zero-shot NER over `text` using this model's configured default labels.
    ///
    /// Whitespace-only input returns an empty vector without touching the model.
    /// When the `onnx` feature is enabled and a session is loaded, inputs longer
    /// than `max_input_chars` characters are split into overlapping chunks, and
    /// each chunk's entity offsets are shifted by that chunk's `char_offset` so
    /// they refer to positions in the original `text`.
    ///
    /// # Errors
    ///
    /// - [`Error::ModelInit`] if the `onnx` feature is enabled but no session has
    ///   been loaded.
    /// - [`Error::FeatureNotAvailable`] if the crate was built without `onnx`.
    /// - Any error returned by the underlying inference call.
    fn extract_entities(&self, text: &str, _language: Option<Language>) -> Result<Vec<Entity>> {
        if text.trim().is_empty() {
            return Ok(vec![]);
        }
        #[cfg(feature = "onnx")]
        {
            if self.session.is_some() {
                let labels: Vec<&str> = self.default_labels.iter().map(|s| s.as_str()).collect();
                let threshold = self.threshold as f32;
                let max_chars = self.max_input_chars;
                if text.chars().count() > max_chars {
                    use crate::backends::chunking::{extract_chunked_parallel, ChunkConfig};
                    // Preferred overlap, in characters, between consecutive chunks so an
                    // entity straddling a boundary is seen whole by at least one chunk.
                    const CHUNK_OVERLAP_CHARS: usize = 200;
                    let config = ChunkConfig {
                        chunk_size: max_chars,
                        // Cap the overlap at half the chunk size. A fixed 200-char
                        // overlap would equal or exceed `chunk_size` whenever
                        // `max_input_chars` is configured at or below 200. That could
                        // leave a sliding-window chunker with no forward step. For
                        // budgets of 400+ chars (including the 2000/16000 presets)
                        // the overlap is unchanged at 200.
                        overlap: CHUNK_OVERLAP_CHARS.min(max_chars / 2),
                        respect_sentences: true,
                        buffer_size: 1000,
                    };
                    return extract_chunked_parallel(text, &config, |chunk_text, char_offset| {
                        let mut entities = self.extract(chunk_text, &labels, threshold)?;
                        // Re-base chunk-relative offsets onto the full input.
                        for e in &mut entities {
                            e.set_start(e.start() + char_offset);
                            e.set_end(e.end() + char_offset);
                        }
                        Ok(entities)
                    });
                }
                return self.extract(text, &labels, threshold);
            }
            Err(Error::ModelInit(
                "NuNER model not loaded. Call `NuNER::from_pretrained(...)` (requires `onnx` feature) before calling `extract_entities`.".to_string(),
            ))
        }
        #[cfg(not(feature = "onnx"))]
        {
            Err(Error::FeatureNotAvailable(
                "NuNER requires the 'onnx' feature. Build with: cargo build --features onnx"
                    .to_string(),
            ))
        }
    }

    /// Maps each configured default label to its [`EntityType`].
    fn supported_types(&self) -> Vec<EntityType> {
        self.default_labels
            .iter()
            .map(|l| Self::map_label_to_entity_type(l))
            .collect()
    }

    /// `true` only when built with `onnx` and a session has been loaded.
    fn is_available(&self) -> bool {
        #[cfg(feature = "onnx")]
        {
            self.session.is_some()
        }
        #[cfg(not(feature = "onnx"))]
        {
            false
        }
    }

    /// Stable backend identifier.
    fn name(&self) -> &'static str {
        "nuner"
    }

    /// Human-readable backend description.
    fn description(&self) -> &'static str {
        "NuNER Zero: Token-based zero-shot NER from NuMind (MIT licensed)"
    }

    /// Version string of the form `nuner-zero-{model_id}`.
    fn version(&self) -> String {
        format!("nuner-zero-{}", self.model_id)
    }

    /// Advertises zero-shot support; every other capability uses its default.
    fn capabilities(&self) -> crate::ModelCapabilities {
        crate::ModelCapabilities {
            zero_shot: true,
            ..Default::default()
        }
    }

    /// Exposes the zero-shot interface, which is only implemented with `onnx`.
    fn as_zero_shot(&self) -> Option<&dyn crate::backends::inference::ZeroShotNER> {
        #[cfg(feature = "onnx")]
        {
            Some(self)
        }
        #[cfg(not(feature = "onnx"))]
        {
            None
        }
    }
}
#[cfg(feature = "onnx")]
impl crate::backends::inference::ZeroShotNER for NuNER {
    /// Extracts entities for caller-supplied label names at the given threshold.
    ///
    /// Delegates directly to the model's `extract`; no chunking is applied here.
    fn extract_with_types(
        &self,
        text: &str,
        entity_types: &[&str],
        threshold: f32,
    ) -> crate::Result<Vec<crate::Entity>> {
        self.extract(text, entity_types, threshold)
    }

    /// Extracts entities for free-text label descriptions.
    ///
    /// Descriptions are passed through as-is, in the same way as plain type labels.
    fn extract_with_descriptions(
        &self,
        text: &str,
        descriptions: &[&str],
        threshold: f32,
    ) -> crate::Result<Vec<crate::Entity>> {
        self.extract(text, descriptions, threshold)
    }

    /// Fixed fallback label set for zero-shot callers that supply none.
    fn default_types(&self) -> &[&'static str] {
        const DEFAULT_ZERO_SHOT_TYPES: &[&str] =
            &["person", "organization", "location", "date", "event"];
        DEFAULT_ZERO_SHOT_TYPES
    }
}
#[cfg(test)]
mod tests;