#![allow(missing_docs)]
#![allow(clippy::type_complexity)] #![allow(clippy::manual_contains)] #![allow(unused_variables)] #![allow(clippy::items_after_test_module)] #![allow(unused_imports)]
use crate::{Entity, Error, Language, Result};
use crate::{EntityCategory, EntityType};
#[cfg(feature = "onnx")]
use std::collections::HashMap;
#[cfg(feature = "onnx")]
use std::sync::Mutex;
// Special-token ids and span limits used by the GLiNER backend.
// NOTE(review): the meanings below are inferred from the constant names only;
// confirm against the tokenizer / model configuration before relying on them.

/// Presumably the id of the sequence-start special token — TODO confirm.
const TOKEN_START: u32 = 1;
/// Presumably the id of the sequence-end special token — TODO confirm.
const TOKEN_END: u32 = 2;
/// Default id for the entity-marker token; presumably the fallback for
/// `GLiNEROnnx::token_ent` — TODO confirm.
const DEFAULT_TOKEN_ENT: u32 = 128002;
/// Default id for the separator token; presumably the fallback for
/// `GLiNEROnnx::token_sep` — TODO confirm.
const DEFAULT_TOKEN_SEP: u32 = 128003;
/// Upper bound on candidate span width (unit presumably words/tokens — TODO confirm).
const MAX_SPAN_WIDTH: usize = 12;
/// Configuration types for the GLiNER ONNX backend (only built with the `onnx` feature).
#[cfg(feature = "onnx")]
pub mod config;
// The re-export must carry the same gate as the module it re-exports from;
// otherwise builds without `onnx` fail with an unresolved import.
#[cfg(feature = "onnx")]
pub use config::*;
/// GLiNER zero-shot NER model running on ONNX Runtime.
///
/// Gated on the `onnx` feature: the field types come from imports and modules
/// that only exist with that feature (`Mutex`, `HashMap`, `config`), and the
/// feature-off build gets its `GLiNEROnnx` from the `define_feature_stub!`
/// invocation further down in this file.
#[cfg(feature = "onnx")]
pub struct GLiNEROnnx {
    /// ONNX Runtime session for the main model, serialized behind a mutex.
    session: Mutex<ort::session::Session>,
    /// Shared tokenizer handle.
    tokenizer: std::sync::Arc<tokenizers::Tokenizer>,
    /// Model identifier; embedded in the string returned by `version()`.
    model_name: String,
    /// Whether the loaded model is quantized; selects `q` vs `fp32` in `version()`.
    is_quantized: bool,
    /// Optional LRU cache of prompt data (presumably `None` disables caching — TODO confirm).
    prompt_cache: Option<Mutex<lru::LruCache<PromptCacheKey, PromptCacheValue>>>,
    /// Encoder layout (`Uni` or `Bi`); reported as `uni` / `bi` in `version()`.
    encoder_mode: config::EncoderMode,
    /// Cache of label embeddings (presumably keyed by label text — TODO confirm).
    label_cache: Mutex<HashMap<String, config::LabelEmbedding>>,
    /// Token id of the entity marker (presumably defaults to `DEFAULT_TOKEN_ENT` — TODO confirm).
    token_ent: u32,
    /// Token id of the separator (presumably defaults to `DEFAULT_TOKEN_SEP` — TODO confirm).
    token_sep: u32,
    /// Presumably whether the model graph expects span inputs — TODO confirm.
    has_span_inputs: bool,
    /// Optional separate session for encoding labels
    /// (presumably used in `Bi` encoder mode — TODO confirm).
    label_encoder_session: Option<Mutex<ort::session::Session>>,
    /// Optional tokenizer paired with `label_encoder_session`.
    label_tokenizer: Option<std::sync::Arc<tokenizers::Tokenizer>>,
}
/// Inference implementation for the GLiNER ONNX backend (only built with the `onnx` feature).
#[cfg(feature = "onnx")]
mod inference;
// Every import from `inference` carries the same gate as the module itself;
// an ungated `use` is an unresolved-import error when `onnx` is disabled.
#[cfg(feature = "onnx")]
pub(crate) use inference::expand_ner_label;
#[cfg(feature = "onnx")]
pub(crate) use inference::looks_like_company_name;
#[cfg(feature = "onnx")]
use inference::DEFAULT_GLINER_LABELS;
/// Inputs with more characters than this are split into chunks by
/// `Model::extract_entities`; the same value is used as the chunk size.
#[cfg(feature = "onnx")]
const MAX_INPUT_CHARS: usize = 2000;
/// [`crate::Model`] implementation for the real (ONNX-backed) GLiNER model.
///
/// Gated on the `onnx` feature: the body depends on `DEFAULT_GLINER_LABELS`,
/// `MAX_INPUT_CHARS` and the `ZeroShotNER` impl, all of which only exist when
/// that feature is enabled.
#[cfg(feature = "onnx")]
impl crate::Model for GLiNEROnnx {
    /// Extracts entities for the default GLiNER label set at a 0.5 threshold.
    ///
    /// Text longer than `MAX_INPUT_CHARS` characters is processed in
    /// overlapping chunks; entity offsets from each chunk are shifted by the
    /// offset the chunker reports for that chunk. The language hint is ignored.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by `extract` or by the chunking helper.
    fn extract_entities(
        &self,
        text: &str,
        _language: Option<Language>,
    ) -> crate::Result<Vec<Entity>> {
        use crate::backends::chunking::{extract_chunked_parallel, ChunkConfig};

        // Confidence threshold shared by the direct and the chunked path.
        const DEFAULT_THRESHOLD: f32 = 0.5;

        // `nth(MAX_INPUT_CHARS)` yields a value exactly when the text has more
        // than MAX_INPUT_CHARS characters, and stops scanning at that point
        // instead of counting every character of a potentially huge input.
        let exceeds_limit = text.chars().nth(MAX_INPUT_CHARS).is_some();
        if !exceeds_limit {
            return self.extract(text, DEFAULT_GLINER_LABELS, DEFAULT_THRESHOLD);
        }

        let config = ChunkConfig {
            chunk_size: MAX_INPUT_CHARS,
            overlap: 200,
            respect_sentences: true,
            buffer_size: 1000,
        };
        extract_chunked_parallel(text, &config, |chunk_text, char_offset| {
            let mut entities =
                self.extract(chunk_text, DEFAULT_GLINER_LABELS, DEFAULT_THRESHOLD)?;
            // Translate chunk-relative offsets back into offsets of the full text.
            for e in &mut entities {
                e.set_start(e.start() + char_offset);
                e.set_end(e.end() + char_offset);
            }
            Ok(entities)
        })
    }

    /// Lists the default GLiNER labels as custom entity types in the `Misc` category.
    fn supported_types(&self) -> Vec<crate::EntityType> {
        DEFAULT_GLINER_LABELS
            .iter()
            .map(|label| crate::EntityType::Custom {
                name: (*label).to_string(),
                category: EntityCategory::Misc,
            })
            .collect()
    }

    /// Always `true` for this implementation.
    fn is_available(&self) -> bool {
        true
    }

    /// Static backend name.
    fn name(&self) -> &'static str {
        "GLiNER-ONNX"
    }

    /// Static human-readable description of the backend.
    fn description(&self) -> &'static str {
        "Zero-shot NER using GLiNER with ONNX Runtime backend"
    }

    /// Builds `gliner-onnx-<model_name>-<q|fp32>-<uni|bi>` from the model's settings.
    fn version(&self) -> String {
        let quant = if self.is_quantized { "q" } else { "fp32" };
        let enc = match self.encoder_mode {
            config::EncoderMode::Uni => "uni",
            config::EncoderMode::Bi => "bi",
        };
        format!("gliner-onnx-{}-{}-{}", self.model_name, quant, enc)
    }

    /// Exposes this model through the zero-shot NER interface.
    fn as_zero_shot(&self) -> Option<&dyn crate::backends::inference::ZeroShotNER> {
        Some(self)
    }
}
/// Zero-shot NER interface for the ONNX-backed GLiNER model.
///
/// Both extraction entry points forward unchanged to the inherent
/// `GLiNEROnnx::extract`.
#[cfg(feature = "onnx")]
impl crate::backends::inference::ZeroShotNER for GLiNEROnnx {
    /// Runs extraction with the caller-supplied entity type names as labels.
    fn extract_with_types(
        &self,
        text: &str,
        entity_types: &[&str],
        threshold: f32,
    ) -> crate::Result<Vec<Entity>> {
        Self::extract(self, text, entity_types, threshold)
    }

    /// Runs extraction with the caller-supplied descriptions passed as labels.
    fn extract_with_descriptions(
        &self,
        text: &str,
        descriptions: &[&str],
        threshold: f32,
    ) -> crate::Result<Vec<Entity>> {
        Self::extract(self, text, descriptions, threshold)
    }

    /// Returns the built-in default label set.
    fn default_types(&self) -> &[&'static str] {
        DEFAULT_GLINER_LABELS
    }
}
// Declares the `GLiNEROnnx` stub tied to the `onnx` feature.
// NOTE(review): presumably the macro emits this stub only when `onnx` is
// disabled, using the name/description/error message given here — confirm
// against the definition of `define_feature_stub!`.
crate::backends::macros::define_feature_stub! {
struct GLiNEROnnx;
feature = "onnx";
name = "GLiNER-ONNX (unavailable)";
description = "GLiNER with ONNX Runtime backend - requires 'onnx' feature";
error_msg = "GLiNER-ONNX requires the 'onnx' feature";
methods {
// Placeholder model name reported by the stub.
pub fn model_name(&self) -> &str {
"gliner-not-enabled"
}
// Same signature as the `extract` call used by the real backend in this
// file; the stub always fails with `Error::FeatureNotAvailable`.
pub fn extract(
&self,
_text: &str,
_entity_types: &[&str],
_threshold: f32,
) -> crate::Result<Vec<crate::Entity>> {
Err(crate::Error::FeatureNotAvailable(
"GLiNER-ONNX requires the 'onnx' feature".to_string(),
))
}
}
// Traits the macro is asked to cover for the stub type.
impls {
ZeroShotNER,
}
}
/// Deduplicates overlapping entity spans and returns the surviving entities.
///
/// Delegates to the shared chunking helper with the `KeepShortest` strategy
/// (presumably the shorter of two overlapping spans is retained — confirm
/// against `OverlapStrategy`).
fn remove_overlapping_spans(entities: Vec<Entity>) -> Vec<Entity> {
    use super::chunking::{deduplicate_overlapping, OverlapStrategy};

    let mut deduped = entities;
    deduplicate_overlapping(&mut deduped, OverlapStrategy::KeepShortest);
    deduped
}
#[cfg(test)]
mod postprocess_tests;