use crate::backends::hf_loader;
use crate::backends::inference::RelationTriple;
use crate::{Confidence, Entity, Error, Result};
use ndarray::Array2;
use std::path::{Path, PathBuf};
use std::sync::Mutex;
/// Fallback directory for locally exported GLiREL models.
///
/// Resolves to `<user cache dir>/anno/models/glirel`. When no user cache
/// directory is known, it resolves to the relative path `.cache/anno/models/glirel`.
fn default_model_dir() -> PathBuf {
    let mut dir = match dirs::cache_dir() {
        Some(cache) => cache,
        None => PathBuf::from(".cache"),
    };
    for segment in ["anno", "models", "glirel"] {
        dir.push(segment);
    }
    dir
}
/// GLiREL relation extractor.
///
/// Scores caller-supplied relation labels between pairs of caller-supplied
/// entities, using an ONNX Runtime session plus a Hugging Face tokenizer.
#[derive(Debug)]
pub struct GLiREL {
/// ONNX Runtime session. `extract_relations` locks this mutex to obtain a
/// mutable session for `run`, while the public API only takes `&self`.
session: Mutex<ort::session::Session>,
/// Tokenizer loaded from `tokenizer.json`. It encodes both the input text and
/// the relation labels.
tokenizer: tokenizers::Tokenizer,
/// Model settings. At inference time only `max_width` is read, in the code
/// visible here.
config: GLiRELConfig,
}
/// Model settings, normally read from `glirel_config.json`.
///
/// When deserialized, missing fields fall back to serde defaults:
/// - `model_name` becomes an empty string;
/// - `hidden_size` uses `default_hidden_size`;
/// - `max_width` uses `default_max_width`.
///
/// This differs from `GLiRELConfig::default()`, which also fills in a `model_name`.
#[derive(Debug, Clone, serde::Deserialize)]
pub struct GLiRELConfig {
/// Model identifier. In the visible code it is informational only (logged on load).
#[serde(default)]
pub model_name: String,
/// Encoder hidden size. In the visible code it is only reported in the load log.
#[serde(default = "default_hidden_size")]
pub hidden_size: usize,
/// Maximum entity span width, in whitespace-separated words. Wider entity spans
/// are truncated when mapped to word indices; a value of 0 is treated as 1.
#[serde(default = "default_max_width")]
pub max_width: usize,
}
/// Serde default for `GLiRELConfig::hidden_size`.
fn default_hidden_size() -> usize {
    const HIDDEN_SIZE: usize = 1024;
    HIDDEN_SIZE
}
/// Serde default for `GLiRELConfig::max_width` (span width in words).
fn default_max_width() -> usize {
    const MAX_SPAN_WIDTH: usize = 12;
    MAX_SPAN_WIDTH
}
impl Default for GLiRELConfig {
fn default() -> Self {
Self {
model_name: "jackboyla/glirel-large-v0".to_string(),
hidden_size: 1024,
max_width: 12,
}
}
}
/// A scored relation between two entities.
///
/// Convertible into a `RelationTriple` via `GLiREL::scored_to_triples`, which
/// copies every field unchanged.
#[derive(Debug, Clone)]
pub struct ScoredRelation {
/// Index of the head entity. Presumably this indexes the entity slice, as
/// `RelationTriple::head_idx` does in `extract_relations`; confirm with callers.
pub head_idx: usize,
/// Index of the tail entity (same indexing as `head_idx`).
pub tail_idx: usize,
/// Relation label.
pub relation_type: String,
/// Score attached to this relation.
pub confidence: Confidence,
}
impl GLiREL {
pub fn from_pretrained(model_id: &str) -> Result<Self> {
let api = hf_loader::hf_api()?;
let repo = api.model(model_id.to_string());
let model_path = hf_loader::download_model_file(&repo, &["onnx/model.onnx", "model.onnx"])?;
let tokenizer_path = hf_loader::download_model_file(&repo, &["tokenizer.json"])?;
let config = match repo.get("glirel_config.json") {
Ok(config_path) => {
let data = std::fs::read_to_string(&config_path)
.map_err(|e| Error::Retrieval(format!("glirel config read: {e}")))?;
serde_json::from_str(&data)
.map_err(|e| Error::Parse(format!("glirel config parse: {e}")))?
}
Err(_) => GLiRELConfig {
model_name: model_id.to_string(),
..GLiRELConfig::default()
},
};
let tokenizer = hf_loader::load_tokenizer(&tokenizer_path)?;
let session =
hf_loader::create_onnx_session(&model_path, hf_loader::OnnxSessionConfig::default())?;
log::info!(
"[GLiREL] Loaded {} (config.model_name={}, hidden={}, max_width={})",
model_id,
config.model_name,
config.hidden_size,
config.max_width
);
Ok(Self {
session: Mutex::new(session),
tokenizer,
config,
})
}
pub fn from_local(dir: &Path) -> Result<Self> {
let model_path = dir.join("model.onnx");
if !model_path.exists() {
let default_dir = default_model_dir();
let alt_path = default_dir.join("model.onnx");
if alt_path.exists() {
return Self::from_local(&default_dir);
}
return Err(Error::Retrieval(format!(
"GLiREL model not found at {}. Export it with: uv run scripts/export_glirel_onnx.py",
model_path.display()
)));
}
let tokenizer_path = dir.join("tokenizer.json");
if !tokenizer_path.exists() {
return Err(Error::Retrieval(format!(
"Tokenizer not found at {}",
tokenizer_path.display()
)));
}
let config = {
let config_path = dir.join("glirel_config.json");
if config_path.exists() {
let data = std::fs::read_to_string(&config_path)
.map_err(|e| Error::Retrieval(format!("glirel config read: {e}")))?;
serde_json::from_str(&data)
.map_err(|e| Error::Parse(format!("glirel config parse: {e}")))?
} else {
GLiRELConfig::default()
}
};
let tokenizer = hf_loader::load_tokenizer(&tokenizer_path)?;
let session =
hf_loader::create_onnx_session(&model_path, hf_loader::OnnxSessionConfig::default())?;
log::info!("[GLiREL] Loaded from {}", dir.display());
Ok(Self {
session: Mutex::new(session),
tokenizer,
config,
})
}
pub fn extract_relations(
&self,
text: &str,
entities: &[Entity],
relation_types: &[&str],
threshold: f32,
) -> Result<Vec<RelationTriple>> {
if entities.len() < 2 || relation_types.is_empty() || text.is_empty() {
return Ok(Vec::new());
}
let words: Vec<&str> = text.split_whitespace().collect();
if words.is_empty() {
return Ok(Vec::new());
}
let encoding = self
.tokenizer
.encode(text, true)
.map_err(|e| Error::Inference(format!("GLiREL tokenize: {e}")))?;
let input_ids: Vec<i64> = encoding.get_ids().iter().map(|&id| id as i64).collect();
let attention_mask: Vec<i64> = encoding
.get_attention_mask()
.iter()
.map(|&m| m as i64)
.collect();
let seq_len = input_ids.len();
let words_mask = self.build_words_mask(&encoding, &words);
let text_lengths = vec![words.len() as i64];
let span_idx = self.entities_to_word_spans(text, &words, entities);
let num_spans = span_idx.len();
let span_mask: Vec<bool> = vec![true; num_spans];
let (rel_input_ids, rel_attention_mask, rel_seq_len) =
self.encode_relation_labels(relation_types)?;
let num_relations = relation_types.len();
let input_ids_arr = Array2::from_shape_vec((1, seq_len), input_ids)
.map_err(|e| Error::Parse(format!("input_ids array: {e}")))?;
let attention_mask_arr = Array2::from_shape_vec((1, seq_len), attention_mask)
.map_err(|e| Error::Parse(format!("attention_mask array: {e}")))?;
let words_mask_arr = Array2::from_shape_vec((1, seq_len), words_mask)
.map_err(|e| Error::Parse(format!("words_mask array: {e}")))?;
let text_lengths_arr = Array2::from_shape_vec((1, 1), text_lengths)
.map_err(|e| Error::Parse(format!("text_lengths array: {e}")))?;
let span_flat: Vec<i64> = span_idx.iter().flat_map(|&(s, e)| [s, e]).collect();
let span_idx_arr = ndarray::Array3::from_shape_vec((1, num_spans, 2), span_flat)
.map_err(|e| Error::Parse(format!("span_idx array: {e}")))?;
let span_mask_i64: Vec<i64> = span_mask.iter().map(|&b| if b { 1 } else { 0 }).collect();
let span_mask_arr = Array2::from_shape_vec((1, num_spans), span_mask_i64)
.map_err(|e| Error::Parse(format!("span_mask array: {e}")))?;
let rel_ids_arr = Array2::from_shape_vec((num_relations, rel_seq_len), rel_input_ids)
.map_err(|e| Error::Parse(format!("rel_input_ids array: {e}")))?;
let rel_mask_arr = Array2::from_shape_vec((num_relations, rel_seq_len), rel_attention_mask)
.map_err(|e| Error::Parse(format!("rel_attention_mask array: {e}")))?;
use super::super::ort_compat::tensor_from_ndarray;
let t_input_ids = tensor_from_ndarray(input_ids_arr)
.map_err(|e| Error::Inference(format!("tensor input_ids: {e}")))?;
let t_attention_mask = tensor_from_ndarray(attention_mask_arr)
.map_err(|e| Error::Inference(format!("tensor attention_mask: {e}")))?;
let t_words_mask = tensor_from_ndarray(words_mask_arr)
.map_err(|e| Error::Inference(format!("tensor words_mask: {e}")))?;
let t_text_lengths = tensor_from_ndarray(text_lengths_arr)
.map_err(|e| Error::Inference(format!("tensor text_lengths: {e}")))?;
let t_span_idx = tensor_from_ndarray(span_idx_arr)
.map_err(|e| Error::Inference(format!("tensor span_idx: {e}")))?;
let t_span_mask = tensor_from_ndarray(span_mask_arr)
.map_err(|e| Error::Inference(format!("tensor span_mask: {e}")))?;
let t_rel_ids = tensor_from_ndarray(rel_ids_arr)
.map_err(|e| Error::Inference(format!("tensor rel_input_ids: {e}")))?;
let t_rel_mask = tensor_from_ndarray(rel_mask_arr)
.map_err(|e| Error::Inference(format!("tensor rel_attention_mask: {e}")))?;
let mut session = self.session.lock().unwrap_or_else(|e| e.into_inner());
let outputs = session
.run(ort::inputs![
"input_ids" => t_input_ids.into_dyn(),
"attention_mask" => t_attention_mask.into_dyn(),
"words_mask" => t_words_mask.into_dyn(),
"text_lengths" => t_text_lengths.into_dyn(),
"span_idx" => t_span_idx.into_dyn(),
"span_mask" => t_span_mask.into_dyn(),
"rel_label_input_ids" => t_rel_ids.into_dyn(),
"rel_label_attention_mask" => t_rel_mask.into_dyn(),
])
.map_err(|e| Error::Inference(format!("GLiREL ONNX run: {e}")))?;
let scores_output = outputs
.get("relation_scores")
.ok_or_else(|| Error::Inference("Missing relation_scores output".to_string()))?;
let (shape, scores_data) = scores_output
.try_extract_tensor::<f32>()
.map_err(|e| Error::Inference(format!("extract relation_scores: {e}")))?;
if shape.len() != 4 {
return Err(Error::Inference(format!(
"Unexpected relation_scores shape: {:?}",
shape
)));
}
let stride_head = num_spans * num_relations;
let mut relations = Vec::new();
for head_idx in 0..num_spans {
for tail_idx in 0..num_spans {
if head_idx == tail_idx {
continue;
}
for (rel_idx, rel_type) in relation_types.iter().enumerate() {
let flat_idx = head_idx * stride_head + tail_idx * num_relations + rel_idx;
let raw_score = scores_data[flat_idx];
let conf_f32 = sigmoid(raw_score);
if conf_f32 >= threshold {
relations.push(RelationTriple {
head_idx,
tail_idx,
relation_type: rel_type.to_string(),
confidence: Confidence::new(conf_f32 as f64),
});
}
}
}
}
relations.sort_by(|a, b| {
b.confidence
.partial_cmp(&a.confidence)
.unwrap_or(std::cmp::Ordering::Equal)
});
let mut seen = std::collections::HashSet::new();
relations.retain(|r| seen.insert((r.head_idx, r.tail_idx)));
Ok(relations)
}
fn build_words_mask(&self, encoding: &tokenizers::Encoding, words: &[&str]) -> Vec<i64> {
let seq_len = encoding.get_ids().len();
let mut mask = vec![0i64; seq_len];
for (token_idx, word_id) in encoding.get_word_ids().iter().enumerate() {
if let Some(wid) = word_id {
if (*wid as usize) < words.len() {
mask[token_idx] = (*wid as i64) + 1;
}
}
}
mask
}
fn entities_to_word_spans(
&self,
text: &str,
words: &[&str],
entities: &[Entity],
) -> Vec<(i64, i64)> {
let mut word_starts = Vec::with_capacity(words.len());
let mut byte_pos = 0;
let chars: Vec<char> = text.chars().collect();
for word in words {
if let Some(pos) = text[byte_pos..].find(word) {
let abs_byte = byte_pos + pos;
let char_offset = text[..abs_byte].chars().count();
let char_end = char_offset + word.chars().count();
word_starts.push((char_offset, char_end));
byte_pos = abs_byte + word.len();
} else {
let char_offset = if word_starts.is_empty() {
0
} else {
word_starts.last().map(|&(_, e)| e).unwrap_or(0)
};
word_starts.push((char_offset, char_offset + word.chars().count()));
}
}
let _ = chars;
let max_width = self.config.max_width.max(1);
entities
.iter()
.map(|ent| {
let mut best_start = 0i64;
let mut best_end = 0i64;
let mut found = false;
for (word_idx, &(ws, we)) in word_starts.iter().enumerate() {
if we > ent.start() && ws < ent.end() {
if !found {
best_start = word_idx as i64;
found = true;
}
best_end = word_idx as i64;
}
}
if found {
let width = (best_end - best_start + 1) as usize;
if width > max_width {
best_end = best_start + (max_width as i64) - 1;
}
}
(best_start, best_end)
})
.collect()
}
fn encode_relation_labels(&self, labels: &[&str]) -> Result<(Vec<i64>, Vec<i64>, usize)> {
let encodings: Vec<_> = labels
.iter()
.map(|label| {
self.tokenizer
.encode(*label, true)
.map_err(|e| Error::Inference(format!("GLiREL encode label '{label}': {e}")))
})
.collect::<Result<Vec<_>>>()?;
let max_len = encodings
.iter()
.map(|e| e.get_ids().len())
.max()
.unwrap_or(1);
let mut all_ids = Vec::with_capacity(labels.len() * max_len);
let mut all_masks = Vec::with_capacity(labels.len() * max_len);
for enc in &encodings {
let ids = enc.get_ids();
let masks = enc.get_attention_mask();
for &id in ids {
all_ids.push(id as i64);
}
for &m in masks {
all_masks.push(m as i64);
}
let pad = max_len - ids.len();
all_ids.extend(std::iter::repeat_n(0i64, pad));
all_masks.extend(std::iter::repeat_n(0i64, pad));
}
Ok((all_ids, all_masks, max_len))
}
pub fn scored_to_triples(scored: Vec<ScoredRelation>) -> Vec<RelationTriple> {
scored
.into_iter()
.map(|sr| RelationTriple {
head_idx: sr.head_idx,
tail_idx: sr.tail_idx,
relation_type: sr.relation_type,
confidence: sr.confidence,
})
.collect()
}
}
/// Logistic function `1 / (1 + e^-x)`; maps a raw logit to a score in `[0, 1]`.
fn sigmoid(x: f32) -> f32 {
    let denominator = 1.0 + (-x).exp();
    denominator.recip()
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_sigmoid() {
        assert!((sigmoid(0.0) - 0.5).abs() < 1e-6);
        assert!(sigmoid(10.0) > 0.99);
        assert!(sigmoid(-10.0) < 0.01);
    }

    #[test]
    fn test_config_defaults() {
        let config = GLiRELConfig::default();
        assert_eq!(config.hidden_size, 1024);
        assert_eq!(config.max_width, 12);
        assert_eq!(config.model_name, "jackboyla/glirel-large-v0");
    }

    #[test]
    fn test_config_deserialization_with_defaults() {
        let json = r#"{"model_name": "test-model"}"#;
        let config: GLiRELConfig = serde_json::from_str(json).unwrap();
        assert_eq!(config.model_name, "test-model");
        assert_eq!(config.hidden_size, 1024);
        assert_eq!(config.max_width, 12);
    }

    #[test]
    fn test_config_deserialization_full() {
        let json = r#"{"model_name": "custom", "hidden_size": 768, "max_width": 8}"#;
        let config: GLiRELConfig = serde_json::from_str(json).unwrap();
        assert_eq!(config.model_name, "custom");
        assert_eq!(config.hidden_size, 768);
        assert_eq!(config.max_width, 8);
    }

    #[test]
    fn test_config_deserialization_empty() {
        let json = r#"{}"#;
        let config: GLiRELConfig = serde_json::from_str(json).unwrap();
        // Serde's default for `model_name` is the empty string, unlike `GLiRELConfig::default()`.
        assert_eq!(config.model_name, "");
        assert_eq!(config.hidden_size, 1024);
        assert_eq!(config.max_width, 12);
    }

    #[test]
    fn test_sigmoid_boundary_values() {
        assert!((sigmoid(100.0) - 1.0).abs() < 1e-6);
        assert!(sigmoid(-100.0).abs() < 1e-6);
        let s_pos = sigmoid(2.5);
        let s_neg = sigmoid(-2.5);
        assert!((s_pos + s_neg - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_scored_to_triples() {
        let scored = vec![
            ScoredRelation {
                head_idx: 0,
                tail_idx: 1,
                relation_type: "founded".to_string(),
                confidence: Confidence::new(0.95),
            },
            ScoredRelation {
                head_idx: 1,
                tail_idx: 0,
                relation_type: "works_for".to_string(),
                confidence: Confidence::new(0.6),
            },
        ];
        let triples = GLiREL::scored_to_triples(scored);
        assert_eq!(triples.len(), 2);
        assert_eq!(triples[0].head_idx, 0);
        assert_eq!(triples[0].tail_idx, 1);
        assert_eq!(triples[0].relation_type, "founded");
        assert_eq!(triples[1].relation_type, "works_for");
    }

    #[test]
    fn test_scored_to_triples_empty() {
        let triples = GLiREL::scored_to_triples(vec![]);
        assert!(triples.is_empty());
    }

    /// Checks that `ScoredRelation` stores the indices it is built with.
    /// (Previously misnamed `test_extract_relations_empty_inputs`; it never
    /// called `extract_relations`.)
    #[test]
    fn test_scored_relation_fields() {
        let sr = ScoredRelation {
            head_idx: 0,
            tail_idx: 1,
            relation_type: "test".to_string(),
            confidence: Confidence::new(0.0),
        };
        assert_eq!(sr.head_idx, 0);
        assert_eq!(sr.tail_idx, 1);
    }

    #[test]
    fn test_from_local_nonexistent() {
        // `from_local` falls back to the default model dir when the requested dir
        // has no model. If a real model is installed there, loading can succeed,
        // so skip rather than fail spuriously.
        if default_model_dir().join("model.onnx").exists() {
            return;
        }
        let result = GLiREL::from_local(std::path::Path::new("/nonexistent/path"));
        assert!(result.is_err());
        let err = result.unwrap_err().to_string();
        assert!(err.contains("model not found") || err.contains("not found"));
    }

    #[test]
    fn test_from_local_missing_tokenizer() {
        // Per-process directory so concurrent test runs do not collide.
        let dir = std::env::temp_dir().join(format!(
            "anno_test_glirel_no_tok_{}",
            std::process::id()
        ));
        std::fs::create_dir_all(&dir).unwrap();
        std::fs::write(dir.join("model.onnx"), b"dummy").unwrap();
        let result = GLiREL::from_local(&dir);
        // Clean up before asserting so a failure does not leave the directory behind.
        let _ = std::fs::remove_dir_all(&dir);
        assert!(result.is_err());
        let err = result.unwrap_err().to_string();
        assert!(
            err.contains("Tokenizer not found"),
            "expected tokenizer error, got: {err}"
        );
    }
}