use crate::{Entity, Error, Result};
type MentionList = (String, Vec<(String, usize, usize)>);
use ndarray::{Array2, Array3};
use std::collections::HashMap;
use std::sync::Arc;
use ort::session::Session;
use tokenizers::Tokenizer;
pub use super::resolve::CorefCluster;
/// Configuration for the T5-based coreference resolver.
#[derive(Debug, Clone)]
pub struct T5CorefConfig {
    /// Maximum number of tokens fed to the encoder; longer inputs are truncated.
    pub max_input_length: usize,
    /// Maximum number of autoregressive decoder steps during generation.
    pub max_output_length: usize,
    /// Beam count. NOTE(review): decoding is greedy (`greedy_decode`) and this
    /// field is never read in this file — confirm whether beam search is planned.
    pub num_beams: usize,
    /// ONNX Runtime graph optimization level forwarded to the session builder.
    pub optimization_level: u8,
    /// Number of ONNX Runtime threads per session.
    pub num_threads: usize,
}
impl Default for T5CorefConfig {
fn default() -> Self {
Self {
max_input_length: 512,
max_output_length: 512,
num_beams: 1, optimization_level: 3,
num_threads: 4,
}
}
}
/// T5 encoder/decoder coreference resolver backed by two ONNX Runtime sessions.
pub struct T5Coref {
    /// Encoder session; wrapped in a Mutex because `run` is called through a
    /// `&mut` binding (see `run_encoder`).
    encoder: std::sync::Mutex<Session>,
    /// Decoder session, locked per step in `decoder_step`.
    decoder: std::sync::Mutex<Session>,
    /// Shared HF tokenizer used for both encoding input and decoding output ids.
    tokenizer: Arc<Tokenizer>,
    /// Settings captured at construction time.
    config: T5CorefConfig,
    /// Local directory (`from_path`) or Hub model id (`from_pretrained*`).
    model_path: String,
}
impl T5Coref {
/// Load encoder/decoder ONNX sessions and the tokenizer from a local directory.
///
/// Expects `encoder_model.onnx`, `decoder_model.onnx` and `tokenizer.json`
/// inside `model_path` (layout produced by `optimum-cli export onnx`).
///
/// # Errors
/// Returns `Error::Retrieval` with an actionable export hint when any of the
/// three files is missing; propagates session/tokenizer construction errors.
pub fn from_path(model_path: &str, config: T5CorefConfig) -> Result<Self> {
    let encoder_path = format!("{}/encoder_model.onnx", model_path);
    let decoder_path = format!("{}/decoder_model.onnx", model_path);
    let tokenizer_path = format!("{}/tokenizer.json", model_path);
    // Fail fast for every required file, not just the encoder (previously a
    // missing decoder/tokenizer surfaced as an opaque loader error).
    for path in [&encoder_path, &decoder_path, &tokenizer_path] {
        if !std::path::Path::new(path).exists() {
            return Err(Error::Retrieval(format!(
                "Model file not found at {}. Export with: optimum-cli export onnx --model <model> --task text2text-generation-with-past {}",
                path, model_path
            )));
        }
    }
    use crate::backends::hf_loader;
    let sess_config = hf_loader::OnnxSessionConfig {
        optimization_level: config.optimization_level,
        num_threads: config.num_threads,
        use_cpu_provider: true,
        ..Default::default()
    };
    let encoder = hf_loader::create_onnx_session(
        std::path::Path::new(&encoder_path),
        sess_config.clone(),
    )?;
    let decoder =
        hf_loader::create_onnx_session(std::path::Path::new(&decoder_path), sess_config)?;
    let tokenizer = hf_loader::load_tokenizer(std::path::Path::new(&tokenizer_path))?;
    log::info!("[T5-Coref] Loaded model from {}", model_path);
    Ok(Self {
        encoder: std::sync::Mutex::new(encoder),
        decoder: std::sync::Mutex::new(decoder),
        tokenizer: Arc::new(tokenizer),
        config,
        model_path: model_path.to_string(),
    })
}
/// Download and load a model from the Hugging Face Hub with default settings.
pub fn from_pretrained(model_id: &str) -> Result<Self> {
    let config = T5CorefConfig::default();
    Self::from_pretrained_with_config(model_id, config)
}
/// Download the ONNX encoder/decoder and tokenizer for `model_id` from the
/// Hugging Face Hub and build the resolver with `config`.
///
/// Tries both root-level and `onnx/`-prefixed file layouts; the decoder also
/// falls back to `decoder_with_past_model.onnx`.
///
/// # Errors
/// Propagates Hub download, ONNX session creation, and tokenizer load failures.
pub fn from_pretrained_with_config(model_id: &str, config: T5CorefConfig) -> Result<Self> {
    use crate::backends::hf_loader;
    let api = hf_loader::hf_api()?;
    let repo = api.model(model_id.to_string());
    // Repos differ in layout: ONNX files may live at the root or under onnx/.
    let encoder_path = hf_loader::download_model_file(
        &repo,
        &["encoder_model.onnx", "onnx/encoder_model.onnx"],
    )?;
    let decoder_path = hf_loader::download_model_file(
        &repo,
        &[
            "decoder_model.onnx",
            "onnx/decoder_model.onnx",
            "decoder_with_past_model.onnx",
        ],
    )?;
    let tokenizer_path = hf_loader::download_model_file(&repo, &["tokenizer.json"])?;
    // NOTE(review): sessions are pinned to CPU (use_cpu_provider: true) —
    // presumably for portability/determinism; confirm against hf_loader docs.
    let sess_config = hf_loader::OnnxSessionConfig {
        optimization_level: config.optimization_level,
        num_threads: config.num_threads,
        use_cpu_provider: true,
        ..Default::default()
    };
    let encoder = hf_loader::create_onnx_session(&encoder_path, sess_config.clone())?;
    let decoder = hf_loader::create_onnx_session(&decoder_path, sess_config)?;
    let tokenizer = hf_loader::load_tokenizer(&tokenizer_path)?;
    log::info!("[T5-Coref] Loaded model from {}", model_id);
    // `model_path` stores the Hub id here (a local directory in `from_path`).
    Ok(Self {
        encoder: std::sync::Mutex::new(encoder),
        decoder: std::sync::Mutex::new(decoder),
        tokenizer: Arc::new(tokenizer),
        config,
        model_path: model_id.to_string(),
    })
}
pub fn resolve(&self, text: &str) -> Result<Vec<CorefCluster>> {
if text.is_empty() {
return Ok(vec![]);
}
match self.resolve_t5(text) {
Ok(clusters) if !clusters.is_empty() => Ok(clusters),
Ok(_) => {
log::debug!("[T5-Coref] inference produced no clusters, using heuristic fallback");
self.resolve_simple(text)
}
Err(e) => {
log::warn!(
"[T5-Coref] inference failed ({}), using heuristic fallback",
e
);
self.resolve_simple(text)
}
}
}
/// Full T5 pipeline on raw text: mark candidate mentions, then run
/// encode → greedy decode → detokenize → parse.
fn resolve_t5(&self, text: &str) -> Result<Vec<CorefCluster>> {
    let marked = self.mark_mentions(text);
    let (ids, mask) = self.tokenize_input(&marked)?;
    let (hidden, seq_len, dim) = self.run_encoder(&ids, &mask)?;
    let generated = self.greedy_decode(&hidden, seq_len, dim, &mask)?;
    let decoded = self.decode_tokens(&generated)?;
    Ok(self.parse_coref_output(&decoded))
}
/// Wrap candidate mentions (pronouns / capitalized words) in `<m> … </m>`
/// markers; thin wrapper over the free function [`mark_mentions_for_t5`].
fn mark_mentions(&self, text: &str) -> String {
    mark_mentions_for_t5(text)
}
/// Tokenize `text` (with special tokens), truncate to `max_input_length`,
/// and return `(input_ids, attention_mask)` as `i64` vectors for ONNX.
///
/// # Errors
/// `Error::Parse` if the tokenizer fails to encode.
fn tokenize_input(&self, text: &str) -> Result<(Vec<i64>, Vec<i64>)> {
    let mut encoding = self
        .tokenizer
        .encode(text, true)
        .map_err(|e| Error::Parse(format!("T5Coref tokenizer encode: {e}")))?;
    encoding.truncate(
        self.config.max_input_length,
        0,
        tokenizers::TruncationDirection::Right,
    );
    // ONNX Runtime expects i64 ids; widen the tokenizer's u32 output.
    let widen = |values: &[u32]| -> Vec<i64> { values.iter().map(|&v| i64::from(v)).collect() };
    let ids = widen(encoding.get_ids());
    let mask = widen(encoding.get_attention_mask());
    Ok((ids, mask))
}
/// Run the encoder ONNX session on a single (batch = 1) sequence.
///
/// Returns `(flattened last_hidden_state, seq_len, hidden_size)` ready to be
/// fed into `decoder_step`.
///
/// # Errors
/// `Error::Parse` on ndarray/tensor conversion failure, a session-run error,
/// a missing `last_hidden_state` output, or a hidden-state shape that is not
/// `(1, seq_len, hidden)`.
fn run_encoder(
    &self,
    input_ids: &[i64],
    attention_mask: &[i64],
) -> Result<(Vec<f32>, usize, usize)> {
    let batch = 1usize;
    let seq_len = input_ids.len();
    // Lay out ids/mask as (1, seq_len) tensors.
    let ids_arr = Array2::<i64>::from_shape_vec((batch, seq_len), input_ids.to_vec())
        .map_err(|e| Error::Parse(format!("encoder ids shape: {e}")))?;
    let mask_arr = Array2::<i64>::from_shape_vec((batch, seq_len), attention_mask.to_vec())
        .map_err(|e| Error::Parse(format!("encoder mask shape: {e}")))?;
    let ids_t = super::super::ort_compat::tensor_from_ndarray(ids_arr)
        .map_err(|e| Error::Parse(format!("encoder ids tensor: {e}")))?;
    let mask_t = super::super::ort_compat::tensor_from_ndarray(mask_arr)
        .map_err(|e| Error::Parse(format!("encoder mask tensor: {e}")))?;
    let (hidden_flat, hidden_size) = {
        // Recover from a poisoned lock rather than panicking; the session is
        // run through a `&mut` binding, hence the Mutex.
        let mut enc = self.encoder.lock().unwrap_or_else(|e| e.into_inner());
        let outputs = enc
            .run(ort::inputs![
                "input_ids" => ids_t.into_dyn(),
                "attention_mask" => mask_t.into_dyn(),
            ])
            .map_err(|e| Error::Parse(format!("T5Coref encoder run: {e}")))?;
        let hidden_val = outputs.get("last_hidden_state").ok_or_else(|| {
            Error::Parse(
                "T5 encoder output 'last_hidden_state' not found; check ONNX export".into(),
            )
        })?;
        let (shape, data) = hidden_val
            .try_extract_tensor::<f32>()
            .map_err(|e| Error::Parse(format!("encoder extract tensor: {e}")))?;
        // Expect (batch=1, seq_len, hidden_size).
        if shape.len() != 3 || shape[0] != 1 {
            return Err(Error::Parse(format!(
                "T5 encoder: unexpected hidden-state shape {:?}",
                shape
            )));
        }
        (data.to_vec(), shape[2] as usize)
    }; // encoder lock released here
    Ok((hidden_flat, seq_len, hidden_size))
}
/// Run one decoder step: feed the encoder hidden states plus the tokens
/// generated so far, returning the argmax token id from the final position's
/// logits.
///
/// Returns token id 1 (T5 EOS) if the logits row is empty, which terminates
/// generation in `greedy_decode`.
///
/// # Errors
/// `Error::Parse` on tensor conversion failure, a session-run error, a
/// missing `logits` output, or a logits shape that is not `(1, dec_len, vocab)`.
fn decoder_step(
    &self,
    encoder_hidden: &[f32],
    enc_seq_len: usize,
    hidden_size: usize,
    attention_mask: &[i64],
    decoder_input_ids: &[i64],
) -> Result<i64> {
    let batch = 1usize;
    let dec_len = decoder_input_ids.len();
    // Rebuild the (1, enc_seq_len, hidden_size) encoder output tensor from
    // its flattened form on every step (no KV cache is used here).
    let enc_h = Array3::<f32>::from_shape_vec(
        (batch, enc_seq_len, hidden_size),
        encoder_hidden.to_vec(),
    )
    .map_err(|e| Error::Parse(format!("decoder enc_hidden shape: {e}")))?;
    let attn = Array2::<i64>::from_shape_vec((batch, enc_seq_len), attention_mask.to_vec())
        .map_err(|e| Error::Parse(format!("decoder attn shape: {e}")))?;
    let dec_ids = Array2::<i64>::from_shape_vec((batch, dec_len), decoder_input_ids.to_vec())
        .map_err(|e| Error::Parse(format!("decoder_ids shape: {e}")))?;
    let enc_h_t = super::super::ort_compat::tensor_from_ndarray(enc_h)
        .map_err(|e| Error::Parse(format!("enc_h tensor: {e}")))?;
    let attn_t = super::super::ort_compat::tensor_from_ndarray(attn)
        .map_err(|e| Error::Parse(format!("attn tensor: {e}")))?;
    let dec_ids_t = super::super::ort_compat::tensor_from_ndarray(dec_ids)
        .map_err(|e| Error::Parse(format!("dec_ids tensor: {e}")))?;
    let next_token = {
        // Recover from a poisoned lock rather than panicking.
        let mut dec = self.decoder.lock().unwrap_or_else(|e| e.into_inner());
        let outputs = dec
            .run(ort::inputs![
                "encoder_hidden_states" => enc_h_t.into_dyn(),
                "attention_mask" => attn_t.into_dyn(),
                "decoder_input_ids" => dec_ids_t.into_dyn(),
            ])
            .map_err(|e| Error::Parse(format!("T5Coref decoder run: {e}")))?;
        let logits_val = outputs.get("logits").ok_or_else(|| {
            Error::Parse("T5 decoder output 'logits' not found; check ONNX export".into())
        })?;
        let (shape, logits_data) = logits_val
            .try_extract_tensor::<f32>()
            .map_err(|e| Error::Parse(format!("decoder logits extract: {e}")))?;
        // Expect (batch=1, dec_len, vocab_size).
        if shape.len() != 3 || shape[0] != 1 {
            return Err(Error::Parse(format!(
                "T5 decoder: unexpected logits shape {:?}",
                shape
            )));
        }
        let vocab_size = shape[2] as usize;
        // Only the final position's logits determine the next token.
        let last_offset = (dec_len - 1) * vocab_size;
        let last_logits = &logits_data[last_offset..last_offset + vocab_size];
        // Greedy argmax; NaN comparisons are treated as Equal.
        last_logits
            .iter()
            .enumerate()
            .max_by(|a, b| a.1.partial_cmp(b.1).unwrap_or(std::cmp::Ordering::Equal))
            .map(|(i, _)| i as i64)
            .unwrap_or(1)
    };
    Ok(next_token)
}
/// Greedy (argmax) autoregressive decoding, seeded with the T5 pad token.
///
/// Stops at EOS or after `max_output_length` steps; the pad seed is stripped
/// from the returned ids.
fn greedy_decode(
    &self,
    encoder_hidden: &[f32],
    enc_seq_len: usize,
    hidden_size: usize,
    attention_mask: &[i64],
) -> Result<Vec<i64>> {
    const T5_PAD: i64 = 0;
    const T5_EOS: i64 = 1;
    let mut generated = vec![T5_PAD];
    let mut remaining_steps = self.config.max_output_length;
    while remaining_steps > 0 {
        remaining_steps -= 1;
        let next = self.decoder_step(
            encoder_hidden,
            enc_seq_len,
            hidden_size,
            attention_mask,
            &generated,
        )?;
        if next == T5_EOS {
            break;
        }
        generated.push(next);
    }
    // Drop the leading pad seed before returning.
    generated.remove(0);
    Ok(generated)
}
fn decode_tokens(&self, token_ids: &[i64]) -> Result<String> {
let ids: Vec<u32> = token_ids.iter().map(|&x| x as u32).collect();
self.tokenizer
.decode(&ids, true)
.map_err(|e| Error::Parse(format!("T5Coref decode_tokens: {e}")))
}
/// Parse the model's `mention | cluster-id` output into clusters; thin
/// wrapper over the free function [`parse_t5_coref_output`].
fn parse_coref_output(&self, decoded: &str) -> Vec<CorefCluster> {
    parse_t5_coref_output(decoded)
}
/// Resolve coreference for text whose mentions are pre-marked with
/// `<m> … </m>`.
///
/// Falls back to heuristic clustering of the extracted mentions when
/// inference fails or produces no clusters; returns empty when the input
/// contains no marked mentions.
pub fn resolve_marked(&self, marked_text: &str) -> Result<Vec<CorefCluster>> {
    let (plain_text, mentions) = self.extract_mentions(marked_text)?;
    if mentions.is_empty() {
        return Ok(vec![]);
    }
    match self.resolve_t5_raw(marked_text) {
        Err(e) => {
            log::warn!(
                "[T5-Coref] resolve_marked inference failed ({}), using fallback",
                e
            );
            self.cluster_mentions(&plain_text, &mentions)
        }
        Ok(clusters) => {
            if clusters.is_empty() {
                self.cluster_mentions(&plain_text, &mentions)
            } else {
                Ok(clusters)
            }
        }
    }
}
/// Resolve coreference among pre-extracted `entities` in `text`.
///
/// Entity spans are wrapped in `<m> … </m>` and fed to the model; when
/// inference errors out or yields nothing, the entities themselves are
/// clustered heuristically.
pub fn resolve_entities(&self, text: &str, entities: &[Entity]) -> Result<Vec<CorefCluster>> {
    if entities.is_empty() {
        return Ok(vec![]);
    }
    // Build the fallback mention list in one place instead of duplicating the
    // same map/collect in both the empty-result and error arms.
    let as_mentions = || -> Vec<(String, usize, usize)> {
        entities
            .iter()
            .map(|e| (e.text.clone(), e.start(), e.end()))
            .collect()
    };
    let marked = self.mark_entity_spans(text, entities);
    match self.resolve_t5_raw(&marked) {
        Ok(clusters) if !clusters.is_empty() => Ok(clusters),
        Ok(_) => self.cluster_mentions(text, &as_mentions()),
        Err(e) => {
            log::warn!(
                "[T5-Coref] resolve_entities inference failed ({}), using fallback",
                e
            );
            self.cluster_mentions(text, &as_mentions())
        }
    }
}
/// Encode → decode → parse for text that already contains `<m>` markers
/// (no additional mention marking is applied).
fn resolve_t5_raw(&self, marked_text: &str) -> Result<Vec<CorefCluster>> {
    let (ids, mask) = self.tokenize_input(marked_text)?;
    let (hidden, seq_len, dim) = self.run_encoder(&ids, &mask)?;
    let generated = self.greedy_decode(&hidden, seq_len, dim, &mask)?;
    let decoded = self.decode_tokens(&generated)?;
    Ok(self.parse_coref_output(&decoded))
}
/// Insert `<m> … </m>` around each entity span.
///
/// Spans are interpreted as char indices; entities are processed in span
/// order, and zero-length, overlapping, or out-of-range spans are skipped.
fn mark_entity_spans(&self, text: &str, entities: &[Entity]) -> String {
    let chars: Vec<char> = text.chars().collect();
    let char_len = chars.len();
    let mut ordered: Vec<&Entity> = entities.iter().collect();
    ordered.sort_by_key(|e| e.start());
    let mut marked = String::with_capacity(text.len() + entities.len() * 10);
    let mut cursor = 0usize;
    for ent in &ordered {
        let (start, end) = (ent.start(), ent.end());
        // Skip degenerate spans and anything overlapping already-marked text.
        if start >= end || start < cursor || end > char_len {
            continue;
        }
        marked.extend(&chars[cursor..start]);
        marked.push_str("<m> ");
        marked.extend(&chars[start..end]);
        marked.push_str(" </m>");
        cursor = end;
    }
    marked.extend(&chars[cursor..]);
    marked
}
/// Heuristic fallback: link each capitalized non-pronoun word to the first
/// unassigned pronoun that follows it.
///
/// Intentionally crude — one pronoun per antecedent, no gender/number
/// agreement — and only used when T5 inference yields nothing. Singleton
/// clusters are dropped. Always returns `Ok`.
fn resolve_simple(&self, text: &str) -> Result<Vec<CorefCluster>> {
    let pronouns = ["he", "she", "they", "it", "his", "her", "their", "its"];
    // Tokenize into (word, byte_start, byte_end), tracking byte positions in
    // the original string.
    let words: Vec<(String, usize, usize)> = {
        let mut result = Vec::new();
        let mut pos = 0;
        for word in text.split_whitespace() {
            if let Some(start) = text[pos..].find(word) {
                let abs_start = pos + start;
                result.push((word.to_string(), abs_start, abs_start + word.len()));
                pos = abs_start + word.len();
            }
        }
        result
    };
    // Candidate antecedents: capitalized words that are not pronouns.
    let antecedents: Vec<&(String, usize, usize)> = words
        .iter()
        .filter(|(w, _, _)| {
            w.chars().next().map(|c| c.is_uppercase()).unwrap_or(false)
                && !pronouns.contains(&w.to_lowercase().as_str())
        })
        .collect();
    let pronoun_mentions: Vec<&(String, usize, usize)> = words
        .iter()
        .filter(|(w, _, _)| pronouns.contains(&w.to_lowercase().as_str()))
        .collect();
    let mut clusters: Vec<CorefCluster> = Vec::new();
    // byte start -> cluster id, for both antecedents and pronouns.
    let mut assigned: HashMap<usize, u32> = HashMap::new();
    for (ant_text, ant_start, ant_end) in &antecedents {
        if assigned.contains_key(ant_start) {
            continue;
        }
        let cluster_id = clusters.len() as u32;
        let mut mentions = vec![ant_text.clone()];
        let mut spans = vec![(*ant_start, *ant_end)];
        assigned.insert(*ant_start, cluster_id);
        // Attach the first unassigned pronoun appearing after the antecedent.
        // (The previous per-pronoun compatibility `match` returned `true` from
        // every arm, so it was dead code and has been removed; gender/number
        // agreement would go here if ever implemented.)
        for (pro_text, pro_start, pro_end) in &pronoun_mentions {
            if *pro_start > *ant_end && !assigned.contains_key(pro_start) {
                mentions.push(pro_text.clone());
                spans.push((*pro_start, *pro_end));
                assigned.insert(*pro_start, cluster_id);
                break;
            }
        }
        // Singletons are not coreference clusters.
        if mentions.len() > 1 {
            clusters.push(CorefCluster {
                id: cluster_id,
                canonical: ant_text.clone(),
                mentions,
                spans,
            });
        }
    }
    Ok(clusters)
}
/// Strip `<m> … </m>` markers from `marked_text`, returning the plain text
/// plus `(mention, start, end)` byte spans into that plain text; thin
/// wrapper over the free function [`extract_t5_mentions`].
fn extract_mentions(&self, marked_text: &str) -> Result<MentionList> {
    extract_t5_mentions(marked_text)
}
/// Heuristic fallback clustering of extracted mentions (no model involved).
///
/// Greedy, order-dependent pass over `mentions`:
/// - a pronoun joins the cluster of the nearest *preceding* non-pronoun
///   mention, but only if that mention is already assigned to a cluster;
/// - a non-pronoun mention seeds a new cluster and absorbs later unassigned
///   mentions that match by case-insensitive equality, substring containment,
///   or a shared final word longer than two characters.
///
/// The longest mention becomes the canonical form; singleton clusters are
/// dropped at the end (their ids are not reused, so returned ids may be
/// non-contiguous). Always returns `Ok`.
fn cluster_mentions(
    &self,
    _text: &str,
    mentions: &[(String, usize, usize)],
) -> Result<Vec<CorefCluster>> {
    let mut clusters: Vec<CorefCluster> = Vec::new();
    // mention index -> cluster id
    let mut assigned: HashMap<usize, u32> = HashMap::new();
    let pronouns = [
        "he", "she", "they", "it", "him", "her", "them", "his", "hers", "their", "its",
    ];
    for (i, (text_i, start_i, end_i)) in mentions.iter().enumerate() {
        if assigned.contains_key(&i) {
            continue;
        }
        let lower_i = text_i.to_lowercase();
        let is_pronoun_i = pronouns.contains(&lower_i.as_str());
        if is_pronoun_i {
            // Walk backwards to the closest non-pronoun antecedent; attach
            // only if that antecedent already belongs to a cluster.
            for j in (0..i).rev() {
                let (text_j, _, _) = &mentions[j];
                let lower_j = text_j.to_lowercase();
                if !pronouns.contains(&lower_j.as_str()) {
                    if let Some(&cluster_id) = assigned.get(&j) {
                        assigned.insert(i, cluster_id);
                        clusters[cluster_id as usize].mentions.push(text_i.clone());
                        clusters[cluster_id as usize].spans.push((*start_i, *end_i));
                    }
                    break;
                }
            }
            continue;
        }
        // Seed a new cluster with this non-pronoun mention.
        let cluster_id = clusters.len() as u32;
        let mut cluster_mentions = vec![text_i.clone()];
        let mut cluster_spans = vec![(*start_i, *end_i)];
        assigned.insert(i, cluster_id);
        for (j, (text_j, start_j, end_j)) in mentions.iter().enumerate().skip(i + 1) {
            if assigned.contains_key(&j) {
                continue;
            }
            let lower_j = text_j.to_lowercase();
            // Match on equality, containment, or a shared surname-like last
            // word (> 2 chars, e.g. "Barack Obama" / "Obama").
            let matches = lower_i == lower_j
                || lower_i.contains(&lower_j)
                || lower_j.contains(&lower_i)
                || {
                    let last_i = lower_i.split_whitespace().last();
                    let last_j = lower_j.split_whitespace().last();
                    last_i.is_some() && last_i == last_j && last_i.map(|w| w.len() > 2).unwrap_or(false)
                };
            if matches {
                cluster_mentions.push(text_j.clone());
                cluster_spans.push((*start_j, *end_j));
                assigned.insert(j, cluster_id);
            }
        }
        // The longest mention is assumed the most specific; use it as canonical.
        let canonical = cluster_mentions
            .iter()
            .max_by_key(|m| m.len())
            .cloned()
            .unwrap_or_else(|| text_i.clone());
        clusters.push(CorefCluster {
            id: cluster_id,
            mentions: cluster_mentions,
            spans: cluster_spans,
            canonical,
        });
    }
    // Singleton "clusters" carry no coreference information.
    let multi_clusters: Vec<CorefCluster> = clusters
        .into_iter()
        .filter(|c| c.mentions.len() > 1)
        .collect();
    Ok(multi_clusters)
}
/// Local directory or Hugging Face model id this resolver was loaded from.
pub fn model_path(&self) -> &str {
    &self.model_path
}
}
/// Wrap likely mentions in `<m> … </m>` markers for the T5 coref input format.
///
/// A word is marked when it is a personal/possessive pronoun (compared after
/// stripping non-alphabetic edge characters and lowercasing) or when its
/// first character is uppercase. Words are re-joined with single spaces, so
/// original whitespace runs are collapsed.
pub fn mark_mentions_for_t5(text: &str) -> String {
    const PRONOUNS: &[&str] = &[
        "he", "she", "they", "it", "him", "her", "them", "his", "hers", "their", "its",
    ];
    let pieces: Vec<String> = text
        .split_whitespace()
        .map(|word| {
            let stripped = word
                .trim_matches(|c: char| !c.is_alphabetic())
                .to_lowercase();
            let is_pronoun = PRONOUNS.contains(&stripped.as_str());
            let is_capitalized = word
                .chars()
                .next()
                .map(|c| c.is_uppercase())
                .unwrap_or(false);
            if is_pronoun || is_capitalized {
                format!("<m> {word} </m>")
            } else {
                word.to_string()
            }
        })
        .collect();
    pieces.join(" ")
}
/// Parse decoded T5 output of the form `mention | <id>` into clusters.
///
/// Tokens are whitespace-split; a mention is any token immediately followed
/// by `|` and a token containing a numeric cluster id. Only clusters with at
/// least two mentions are returned, sorted by id, with the longest mention as
/// the canonical form.
///
/// NOTE(review): `offset` advances past mention/plain tokens but not past the
/// consumed `|` and id tokens, so spans index into the output *with the
/// cluster annotations removed* — confirm this matches downstream span use.
pub fn parse_t5_coref_output(decoded: &str) -> Vec<CorefCluster> {
    let tokens: Vec<&str> = decoded.split_whitespace().collect();
    let mut clusters: HashMap<u32, CorefCluster> = HashMap::new();
    let mut offset = 0usize;
    let mut i = 0usize;
    while let Some(&tok) = tokens.get(i) {
        let followed_by_pipe = tokens.get(i + 1) == Some(&"|");
        let cid: Option<u32> = if followed_by_pipe {
            tokens
                .get(i + 2)
                .and_then(|t| t.trim_matches(|c: char| !c.is_ascii_digit()).parse().ok())
        } else {
            None
        };
        if let Some(id) = cid {
            let mention = tok.trim_matches(|c: char| !c.is_alphanumeric());
            if !mention.is_empty() {
                let cluster = clusters.entry(id).or_insert_with(|| CorefCluster {
                    id,
                    mentions: Vec::new(),
                    spans: Vec::new(),
                    canonical: String::new(),
                });
                cluster.mentions.push(mention.to_string());
                cluster.spans.push((offset, offset + mention.len()));
            }
            // Skip past the consumed `|` and id tokens.
            i += 3;
        } else {
            i += 1;
        }
        offset += tok.len() + 1;
    }
    let mut result: Vec<CorefCluster> = clusters
        .into_values()
        .filter(|c| c.mentions.len() > 1)
        .collect();
    for cluster in &mut result {
        cluster.canonical = cluster
            .mentions
            .iter()
            .max_by_key(|m| m.len())
            .cloned()
            .unwrap_or_default();
    }
    result.sort_by_key(|c| c.id);
    result
}
/// Strip `<m> … </m>` markers, returning the plain text and the
/// `(mention, start, end)` byte spans of each marked mention within it.
///
/// An opening `<m>` without a matching `</m>` causes the rest of the input
/// (markers included) to be copied through verbatim. Always returns `Ok`.
pub fn extract_t5_mentions(marked_text: &str) -> Result<MentionList> {
    let mut plain_text = String::new();
    let mut mentions = Vec::new();
    let mut offset = 0usize;
    let mut remaining = marked_text;
    loop {
        let Some(start_pos) = remaining.find("<m>") else {
            plain_text.push_str(remaining);
            break;
        };
        plain_text.push_str(&remaining[..start_pos]);
        offset += start_pos;
        let after_start = &remaining[start_pos + 3..];
        let Some(end_pos) = after_start.find("</m>") else {
            // Unterminated marker: emit the remainder as-is. Previously the
            // full `remaining` was appended even though its prefix up to the
            // marker had already been pushed, duplicating that text.
            plain_text.push_str(&remaining[start_pos..]);
            break;
        };
        let mention_text = after_start[..end_pos].trim();
        let mention_start = offset;
        plain_text.push_str(mention_text);
        let mention_end = offset + mention_text.len();
        offset = mention_end;
        mentions.push((mention_text.to_string(), mention_start, mention_end));
        remaining = &after_start[end_pos + 4..];
    }
    Ok((plain_text, mentions))
}
#[cfg(test)]
mod tests;