#![allow(unused_imports)]
#[cfg(feature = "onnx")]
mod inference;
use crate::backends::inference::ZeroShotNER;
use crate::EntityCategory;
use crate::{Entity, EntityType, Error, Language, Result};
/// Default entity labels used when the caller does not supply a label set.
/// `Model::extract_entities` runs these at a fixed 0.5 threshold, and
/// `ZeroShotNER::default_types` exposes them directly.
const DEFAULT_POLY_LABELS: &[&str] = &[
    "person",
    "organization",
    "location",
    "date",
    "time",
    "money",
    "percent",
    "product",
    "event",
    "facility",
];
/// Candidate on-disk locations where a previously downloaded gliner-poly
/// model may live, in priority order: the crate's configured cache dir
/// first, then a conventional per-user fallback under `~/.cache`.
#[cfg(feature = "onnx")]
fn local_model_cache_candidates() -> [std::path::PathBuf; 2] {
    let primary = crate::env::cache_dir().join("models/gliner-poly");
    // `home_dir()` can be None (e.g. odd container setups); fall back to
    // an empty base path rather than panicking.
    let home_fallback = dirs::home_dir()
        .unwrap_or_default()
        .join(".cache/anno/models/gliner-poly");
    [primary, home_fallback]
}
#[cfg(feature = "onnx")]
use std::sync::Mutex;
/// Poly-Encoder GLiNER zero-shot NER backend running on ONNX Runtime.
#[cfg(feature = "onnx")]
pub struct GLiNERPoly {
    // ONNX Runtime session. Wrapped in a Mutex so inference can run from
    // `&self` trait methods; NOTE(review): this serializes concurrent
    // inference calls — confirm that is acceptable for callers.
    session: Mutex<ort::session::Session>,
    // Tokenizer applied to the input text.
    tokenizer: std::sync::Arc<tokenizers::Tokenizer>,
    // Separate tokenizer for the entity labels — presumably the label
    // encoder side of the poly-encoder; TODO confirm.
    label_tokenizer: std::sync::Arc<tokenizers::Tokenizer>,
    // Human-readable model identifier, echoed in `version()`.
    model_name: String,
    // Whether the loaded ONNX weights are quantized ("q" vs "fp32" in
    // `version()`).
    is_quantized: bool,
}
#[cfg(feature = "onnx")]
impl std::fmt::Debug for GLiNERPoly {
    /// Hand-written Debug: only the cheap metadata fields are printed; the
    /// session/tokenizer handles are elided via `finish_non_exhaustive`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut dbg = f.debug_struct("GLiNERPoly");
        dbg.field("model_name", &self.model_name);
        dbg.field("is_quantized", &self.is_quantized);
        dbg.finish_non_exhaustive()
    }
}
#[cfg(feature = "onnx")]
impl crate::Model for GLiNERPoly {
    /// Runs zero-shot extraction over the default label set at a fixed
    /// 0.5 confidence threshold; the language hint is ignored.
    fn extract_entities(&self, text: &str, _language: Option<Language>) -> Result<Vec<Entity>> {
        self.extract(text, DEFAULT_POLY_LABELS, 0.5)
    }

    /// The default labels, surfaced as `Custom` entity types in the
    /// `Misc` category.
    fn supported_types(&self) -> Vec<EntityType> {
        let mut types = Vec::with_capacity(DEFAULT_POLY_LABELS.len());
        for label in DEFAULT_POLY_LABELS {
            types.push(EntityType::Custom {
                name: (*label).to_string(),
                category: EntityCategory::Misc,
            });
        }
        types
    }

    // If this struct was constructed, the ONNX model is loaded.
    fn is_available(&self) -> bool {
        true
    }

    fn name(&self) -> &'static str {
        "gliner_poly"
    }

    fn description(&self) -> &'static str {
        "Poly-Encoder GLiNER for zero-shot NER with inter-label interactions (ONNX)"
    }

    fn capabilities(&self) -> crate::ModelCapabilities {
        crate::ModelCapabilities {
            zero_shot: true,
            ..Default::default()
        }
    }

    /// Version string of the form `gliner-poly-<model>-<q|fp32>`.
    fn version(&self) -> String {
        let precision = if self.is_quantized { "q" } else { "fp32" };
        format!("gliner-poly-{}-{}", self.model_name, precision)
    }
}
#[cfg(feature = "onnx")]
impl ZeroShotNER for GLiNERPoly {
    /// Zero-shot extraction using caller-supplied entity type names.
    fn extract_with_types(
        &self,
        text: &str,
        entity_types: &[&str],
        threshold: f32,
    ) -> Result<Vec<Entity>> {
        self.extract(text, entity_types, threshold)
    }

    /// Zero-shot extraction using free-form label descriptions.
    /// NOTE(review): descriptions are forwarded to the same `extract` path
    /// as plain type names — confirm the model handles prose labels.
    fn extract_with_descriptions(
        &self,
        text: &str,
        descriptions: &[&str],
        threshold: f32,
    ) -> Result<Vec<Entity>> {
        self.extract(text, descriptions, threshold)
    }

    /// Labels used when the caller provides none.
    fn default_types(&self) -> &[&'static str] {
        DEFAULT_POLY_LABELS
    }
}
/// Stub used when the crate is built without the `onnx` feature; every
/// operation reports `Error::FeatureNotAvailable`.
#[cfg(not(feature = "onnx"))]
#[derive(Debug)]
pub struct GLiNERPoly {
    // Zero-sized private field: prevents construction outside this module
    // (tests in this file build it directly as `GLiNERPoly { _private: () }`).
    _private: (),
}
#[cfg(not(feature = "onnx"))]
impl GLiNERPoly {
pub fn new(_model_name: &str) -> Result<Self> {
Err(Error::FeatureNotAvailable(
"GLiNERPoly requires the 'onnx' feature. \
Build with: cargo build --features onnx"
.to_string(),
))
}
}
#[cfg(not(feature = "onnx"))]
impl crate::Model for GLiNERPoly {
fn extract_entities(&self, _text: &str, _language: Option<Language>) -> Result<Vec<Entity>> {
Err(Error::FeatureNotAvailable(
"GLiNERPoly requires the 'onnx' feature".to_string(),
))
}
fn supported_types(&self) -> Vec<EntityType> {
vec![]
}
fn is_available(&self) -> bool {
false
}
fn name(&self) -> &'static str {
"gliner_poly"
}
fn description(&self) -> &'static str {
"Poly-Encoder GLiNER (requires 'onnx' feature)"
}
}
#[cfg(not(feature = "onnx"))]
impl ZeroShotNER for GLiNERPoly {
fn extract_with_types(
&self,
_text: &str,
_entity_types: &[&str],
_threshold: f32,
) -> Result<Vec<Entity>> {
Err(Error::FeatureNotAvailable(
"GLiNERPoly requires the 'onnx' feature".to_string(),
))
}
fn extract_with_descriptions(
&self,
_text: &str,
_descriptions: &[&str],
_threshold: f32,
) -> Result<Vec<Entity>> {
Err(Error::FeatureNotAvailable(
"GLiNERPoly requires the 'onnx' feature".to_string(),
))
}
fn default_types(&self) -> &[&'static str] {
DEFAULT_POLY_LABELS
}
}
#[cfg(test)]
mod tests {
    use super::*;

    // --- non-onnx stub behavior ---

    #[test]
    #[cfg(not(feature = "onnx"))]
    fn test_gliner_poly_creation_no_onnx() {
        let err = GLiNERPoly::new("knowledgator/gliner-bi-large-v1.0").unwrap_err();
        assert!(
            matches!(err, Error::FeatureNotAvailable(_)),
            "expected FeatureNotAvailable, got: {err:?}"
        );
    }

    #[test]
    #[cfg(not(feature = "onnx"))]
    fn test_gliner_poly_name_stable() {
        use crate::Model;
        let stub = GLiNERPoly { _private: () };
        assert_eq!(stub.name(), "gliner_poly");
    }

    #[test]
    #[cfg(not(feature = "onnx"))]
    fn test_gliner_poly_error_mentions_onnx() {
        let err = GLiNERPoly::new("test-model").unwrap_err();
        let msg = err.to_string();
        assert!(
            msg.contains("onnx"),
            "error should mention 'onnx', got: {msg}"
        );
    }

    #[test]
    #[cfg(not(feature = "onnx"))]
    fn test_gliner_poly_supported_types_empty() {
        use crate::Model;
        let stub = GLiNERPoly { _private: () };
        assert!(stub.supported_types().is_empty());
    }

    #[test]
    #[cfg(not(feature = "onnx"))]
    fn test_gliner_poly_is_not_available() {
        use crate::Model;
        let stub = GLiNERPoly { _private: () };
        assert!(!stub.is_available());
    }

    #[test]
    #[cfg(not(feature = "onnx"))]
    fn test_gliner_poly_zero_shot_error() {
        let stub = GLiNERPoly { _private: () };
        let err = stub
            .extract_with_types("hello", &["person"], 0.5)
            .unwrap_err();
        assert!(matches!(err, Error::FeatureNotAvailable(_)));
    }

    // --- onnx-enabled behavior ---

    #[test]
    #[cfg(feature = "onnx")]
    fn test_gliner_poly_name_onnx() {
        // If a local model cache is present, `new` could succeed by loading
        // it; skip rather than produce a false failure.
        for cache in local_model_cache_candidates().iter() {
            if cache.join("model.onnx").exists() && cache.join("tokenizer.json").exists() {
                eprintln!(
                    "skipping: local gliner-poly cache exists at {}",
                    cache.display()
                );
                return;
            }
        }
        let err = GLiNERPoly::new("nonexistent/model-that-does-not-exist").unwrap_err();
        assert!(
            matches!(err, Error::Retrieval(_)),
            "expected Retrieval error, got: {err:?}"
        );
    }

    #[test]
    #[cfg(feature = "onnx")]
    fn test_gliner_poly_capabilities() {
        assert!(!DEFAULT_POLY_LABELS.is_empty());
        assert!(DEFAULT_POLY_LABELS.contains(&"person"));
        assert!(DEFAULT_POLY_LABELS.contains(&"organization"));
    }

    #[test]
    fn test_default_poly_labels_complete() {
        for expected in ["person", "organization", "location", "date"] {
            assert!(DEFAULT_POLY_LABELS.contains(&expected));
        }
        assert!(
            DEFAULT_POLY_LABELS.len() >= 8,
            "expected >= 8 labels, got {}",
            DEFAULT_POLY_LABELS.len()
        );
    }

    // --- span tensor construction ---

    #[cfg(feature = "onnx")]
    #[test]
    fn test_make_span_tensors_empty() {
        let (indices, mask) = GLiNERPoly::make_span_tensors(0);
        assert!(indices.is_empty());
        assert!(mask.is_empty());
    }

    #[cfg(feature = "onnx")]
    #[test]
    fn test_make_span_tensors_single_word() {
        let (indices, mask) = GLiNERPoly::make_span_tensors(1);
        let width = inference::MAX_SPAN_WIDTH;
        assert_eq!(mask.len(), width);
        assert!(mask[0]);
        for slot in &mask[1..] {
            assert!(!slot, "extra span slots should be masked");
        }
        assert_eq!(indices[0], 0);
        assert_eq!(indices[1], 0);
    }

    #[cfg(feature = "onnx")]
    #[test]
    fn test_make_span_tensors_three_words() {
        let (indices, mask) = GLiNERPoly::make_span_tensors(3);
        let width = inference::MAX_SPAN_WIDTH;
        let total_spans = 3 * width;
        assert_eq!(mask.len(), total_spans);
        assert_eq!(indices.len(), total_spans * 2);

        // Spans starting at word 0: widths 1..=3 are all in range.
        assert!(mask[0]);
        assert!(mask[1]);
        assert!(mask[2]);
        assert_eq!((indices[0], indices[1]), (0, 0));
        assert_eq!((indices[2], indices[3]), (0, 1));
        assert_eq!((indices[4], indices[5]), (0, 2));

        // Spans starting at word 1 can cover at most two words.
        let row1 = width;
        assert!(mask[row1]);
        assert!(mask[row1 + 1]);
        if width > 2 {
            assert!(!mask[row1 + 2]);
        }

        // Spans starting at word 2 can only be a single word.
        let row2 = 2 * width;
        assert!(mask[row2]);
        if width > 1 {
            assert!(!mask[row2 + 1]);
        }
    }

    // --- word span -> char offset mapping ---

    #[cfg(feature = "onnx")]
    #[test]
    fn test_word_span_to_char_offsets_ascii() {
        let text = "New York City is great";
        let words: Vec<&str> = text.split_whitespace().collect();
        let (start, end) = GLiNERPoly::word_span_to_char_offsets(text, &words, 0, 2);
        let extracted: String = text.chars().skip(start).take(end - start).collect();
        assert_eq!(extracted, "New York City");
    }

    #[cfg(feature = "onnx")]
    #[test]
    fn test_word_span_to_char_offsets_unicode() {
        let text = "Visit 北京 for tourism";
        let words: Vec<&str> = text.split_whitespace().collect();
        let (start, end) = GLiNERPoly::word_span_to_char_offsets(text, &words, 1, 1);
        let extracted: String = text.chars().skip(start).take(end - start).collect();
        assert_eq!(extracted, "北京");
    }

    #[cfg(feature = "onnx")]
    #[test]
    fn test_word_span_to_char_offsets_empty_words() {
        let (start, end) = GLiNERPoly::word_span_to_char_offsets("hello", &[], 0, 0);
        assert_eq!((start, end), (0, 0));
    }

    #[cfg(feature = "onnx")]
    #[test]
    fn test_word_span_to_char_offsets_out_of_bounds() {
        let text = "hello world";
        let words: Vec<&str> = text.split_whitespace().collect();
        let (start, end) = GLiNERPoly::word_span_to_char_offsets(text, &words, 5, 10);
        assert_eq!((start, end), (0, 0));
    }

    #[cfg(feature = "onnx")]
    #[test]
    fn test_word_span_to_char_offsets_inverted() {
        let text = "hello world";
        let words: Vec<&str> = text.split_whitespace().collect();
        let (start, end) = GLiNERPoly::word_span_to_char_offsets(text, &words, 1, 0);
        assert_eq!((start, end), (0, 0));
    }

    #[cfg(feature = "onnx")]
    #[test]
    fn test_word_span_to_char_offsets_single_word() {
        let text = "Steve Jobs founded Apple in California";
        let words: Vec<&str> = text.split_whitespace().collect();
        let (start, end) = GLiNERPoly::word_span_to_char_offsets(text, &words, 3, 3);
        let extracted: String = text.chars().skip(start).take(end - start).collect();
        assert_eq!(extracted, "Apple");
    }

    // --- cache path candidates ---

    #[cfg(feature = "onnx")]
    #[test]
    fn test_local_model_cache_candidates() {
        let candidates = local_model_cache_candidates();
        assert!(candidates.len() >= 2, "should have at least 2 candidate paths");
        for path in &candidates {
            let rendered = path.to_string_lossy();
            assert!(
                rendered.contains("gliner-poly") || rendered.contains("anno"),
                "cache path should reference gliner-poly or anno: {:?}",
                path
            );
        }
    }
}