Skip to main content

anno/backends/coref_t5/
mod.rs

1//! T5-based coreference resolution using ONNX Runtime.
2//!
3//! Experimental scaffold for seq2seq coreference.
4//!
5//! The intended approach treats coreference as a text-to-text transformation:
6//!
7//! ```text
8//! Input:  "<m> Elon </m> founded <m> Tesla </m>. <m> He </m> later led SpaceX."
9//! Output: "Elon | 1 founded Tesla | 2. He | 1 later led SpaceX."
10//! ```
11//!
12//! The model learns to assign cluster IDs to mentions, enabling coreference
13//! without explicit pairwise classification.
14//!
//! **Status**: `T5Coref` loads ONNX artifacts and `resolve()` first attempts a greedy
//! encoder/decoder pass; it falls back to a lightweight heuristic when inference fails
//! or yields no clusters.
17//!
18//! # Architecture
19//!
20//! ```text
21//! Text with Marked Mentions
22//!         │
23//!         ▼
24//! ┌───────────────────┐
25//! │   T5 Encoder      │
26//! │   (ONNX)          │
27//! └─────────┬─────────┘
28//!           │
29//!           ▼
30//! ┌───────────────────┐
31//! │   T5 Decoder      │
32//! │   (Autoregressive)│
33//! └─────────┬─────────┘
34//!           │
35//!           ▼
36//! Text with Cluster IDs
37//!         │
38//!         ▼
39//! ┌───────────────────┐
40//! │  Parse Clusters   │
41//! └───────────────────┘
42//!         │
43//!         ▼
44//! CoreferenceCluster[]
45//! ```
46//!
47//! # Model Export (One-Time Setup)
48//!
49//! Export a T5 coreference model to ONNX using Optimum:
50//!
51//! ```bash
52//! pip install optimum[onnxruntime]
53//! optimum-cli export onnx \
54//!     --model "google/flan-t5-base" \
55//!     --task text2text-generation-with-past \
56//!     t5_coref_onnx/
57//! ```
58//!
59//! # Example
60//!
61//! ```rust,ignore
62//! use anno::backends::coref_t5::{T5Coref, T5CorefConfig};
63//!
64//! let coref = T5Coref::from_path("path/to/t5_coref_onnx", T5CorefConfig::default())?
65//!     .with_heuristic_fallback();
66//!
67//! let text = "Sophie Wilson designed the ARM processor. She changed computing.";
68//! let clusters = coref.resolve(text)?;
69//!
70//! // clusters[0] = { members: ["Sophie Wilson", "She"], canonical: "Sophie Wilson" }
71//! ```
72//!
73//! # Research Background
74//!
75//! This approach is based on:
76//! - Seq2seq coref: "Coreference Resolution as Query-based Span Prediction" (Wu et al.)
77//! - FLAN-T5 fine-tuning for coreference tasks
78//! - Entity-centric markup format for mention boundaries
79//!
80//! The seq2seq approach outperforms traditional pairwise classifiers on:
81//! - OntoNotes 5.0 (coreference benchmark)
82//! - GAP (gendered pronoun resolution benchmark)
83
84// Note: This module is feature-gated via `#[cfg(feature = "onnx")]` in mod.rs
85
86use crate::{Entity, Error, Result};
87
/// Return type for mention extraction: `(plain_text, [(mention_text, char_start, char_end)])`.
///
/// The first element is the text with mention markers removed; the second lists each
/// mention with its start/end offsets.
// NOTE(review): the offsets are described as character offsets here, but the helper that
// produces them is defined outside this view — confirm the unit against that helper.
type MentionList = (String, Vec<(String, usize, usize)>);
90use ndarray::{Array2, Array3};
91use std::collections::HashMap;
92use std::sync::Arc;
93
94use hf_hub::api::sync::Api;
95use ort::{
96    execution_providers::CPUExecutionProvider, session::builder::GraphOptimizationLevel,
97    session::Session,
98};
99use tokenizers::Tokenizer;
100
/// A coreference cluster (group of mentions referring to the same entity).
///
/// `mentions` and `spans` are parallel vectors: `spans[i]` is the location of
/// `mentions[i]` in the text the cluster was resolved from.
#[derive(Debug, Clone)]
pub struct CorefCluster {
    /// Cluster ID
    pub id: u32,
    /// Member mention texts
    pub mentions: Vec<String>,
    /// Member mention spans (start, end)
    // NOTE(review): the offset unit (bytes vs. characters) depends on the code path
    // that built the cluster — confirm before slicing text with these values.
    pub spans: Vec<(usize, usize)>,
    /// Canonical name (longest/most informative mention)
    pub canonical: String,
}
113
/// Tunable settings for loading and running the T5 coreference model.
#[derive(Debug, Clone)]
pub struct T5CorefConfig {
    /// Upper bound, in tokens, on the encoder input; longer inputs are truncated.
    pub max_input_length: usize,
    /// Upper bound, in tokens, on the generated output sequence.
    pub max_output_length: usize,
    /// Beam search width (`1` means greedy decoding).
    pub num_beams: usize,
    /// ONNX graph optimization level: `1` and `2` select those levels, any other
    /// value selects level 3.
    pub optimization_level: u8,
    /// Number of inference threads handed to ONNX Runtime.
    pub num_threads: usize,
}

impl Default for T5CorefConfig {
    /// Defaults: 512-token input and output limits, greedy decoding, full graph
    /// optimization and four inference threads.
    fn default() -> Self {
        // Input and output share the same token budget by default.
        const TOKEN_LIMIT: usize = 512;

        T5CorefConfig {
            max_input_length: TOKEN_LIMIT,
            max_output_length: TOKEN_LIMIT,
            // A single beam keeps decoding greedy, which favours speed.
            num_beams: 1,
            optimization_level: 3,
            num_threads: 4,
        }
    }
}
140
/// T5-based coreference resolution.
///
/// Uses a seq2seq model to assign cluster IDs to marked mentions.
///
/// # Note
///
/// `resolve` attempts greedy encoder/decoder inference first and falls back to a
/// simplified rule-based heuristic when inference fails or produces no clusters.
/// Both ONNX sessions sit behind a mutex because running a session requires
/// mutable access while the public methods take `&self`.
pub struct T5Coref {
    /// Encoder ONNX session.
    encoder: crate::sync::Mutex<Session>,
    /// Decoder ONNX session.
    decoder: crate::sync::Mutex<Session>,
    /// HuggingFace tokenizer for input encoding and output decoding.
    tokenizer: Arc<Tokenizer>,
    /// Inference configuration.
    config: T5CorefConfig,
    /// Model path or HuggingFace model ID.
    model_path: String,
}
161
162impl T5Coref {
163    /// Create a new T5 coreference model from a local ONNX export.
164    ///
165    /// # Arguments
166    ///
167    /// * `model_path` - Path to directory containing encoder.onnx and decoder_model.onnx
168    pub fn from_path(model_path: &str, config: T5CorefConfig) -> Result<Self> {
169        let encoder_path = format!("{}/encoder_model.onnx", model_path);
170        let decoder_path = format!("{}/decoder_model.onnx", model_path);
171        let tokenizer_path = format!("{}/tokenizer.json", model_path);
172
173        // Check files exist
174        if !std::path::Path::new(&encoder_path).exists() {
175            return Err(Error::Retrieval(format!(
176                "Encoder not found at {}. Export with: optimum-cli export onnx --model <model> --task text2text-generation-with-past {}",
177                encoder_path, model_path
178            )));
179        }
180
181        // Helper to create opt level
182        let get_opt_level = || match config.optimization_level {
183            1 => GraphOptimizationLevel::Level1,
184            2 => GraphOptimizationLevel::Level2,
185            _ => GraphOptimizationLevel::Level3,
186        };
187
188        // Load encoder
189        let encoder = Session::builder()
190            .map_err(|e| Error::Retrieval(format!("Encoder builder: {}", e)))?
191            .with_optimization_level(get_opt_level())
192            .map_err(|e| Error::Retrieval(format!("Encoder opt: {}", e)))?
193            .with_execution_providers([CPUExecutionProvider::default().build()])
194            .map_err(|e| Error::Retrieval(format!("Encoder provider: {}", e)))?
195            .with_intra_threads(config.num_threads)
196            .map_err(|e| Error::Retrieval(format!("Encoder threads: {}", e)))?
197            .commit_from_file(&encoder_path)
198            .map_err(|e| Error::Retrieval(format!("Encoder load: {}", e)))?;
199
200        // Load decoder
201        let decoder = Session::builder()
202            .map_err(|e| Error::Retrieval(format!("Decoder builder: {}", e)))?
203            .with_optimization_level(get_opt_level())
204            .map_err(|e| Error::Retrieval(format!("Decoder opt: {}", e)))?
205            .with_execution_providers([CPUExecutionProvider::default().build()])
206            .map_err(|e| Error::Retrieval(format!("Decoder provider: {}", e)))?
207            .with_intra_threads(config.num_threads)
208            .map_err(|e| Error::Retrieval(format!("Decoder threads: {}", e)))?
209            .commit_from_file(&decoder_path)
210            .map_err(|e| Error::Retrieval(format!("Decoder load: {}", e)))?;
211
212        // Load tokenizer
213        let tokenizer = Tokenizer::from_file(&tokenizer_path)
214            .map_err(|e| Error::Retrieval(format!("Tokenizer: {}", e)))?;
215
216        log::info!("[T5-Coref] Loaded model from {}", model_path);
217
218        Ok(Self {
219            encoder: crate::sync::Mutex::new(encoder),
220            decoder: crate::sync::Mutex::new(decoder),
221            tokenizer: Arc::new(tokenizer),
222            config,
223            model_path: model_path.to_string(),
224        })
225    }
226
    /// Create from HuggingFace model ID.
    ///
    /// Downloads ONNX-exported T5 model from HuggingFace Hub.
    ///
    /// Convenience wrapper around `from_pretrained_with_config` that uses
    /// `T5CorefConfig::default()`; any error from that call is returned unchanged.
    pub fn from_pretrained(model_id: &str) -> Result<Self> {
        Self::from_pretrained_with_config(model_id, T5CorefConfig::default())
    }
233
234    /// Create from HuggingFace with custom config.
235    pub fn from_pretrained_with_config(model_id: &str, config: T5CorefConfig) -> Result<Self> {
236        let api = Api::new().map_err(|e| Error::Retrieval(format!("HuggingFace API: {}", e)))?;
237
238        let repo = api.model(model_id.to_string());
239
240        // Download ONNX files
241        let encoder_path = repo
242            .get("encoder_model.onnx")
243            .or_else(|_| repo.get("onnx/encoder_model.onnx"))
244            .map_err(|e| Error::Retrieval(format!("Encoder download: {}", e)))?;
245
246        let decoder_path = repo
247            .get("decoder_model.onnx")
248            .or_else(|_| repo.get("onnx/decoder_model.onnx"))
249            .or_else(|_| repo.get("decoder_with_past_model.onnx"))
250            .map_err(|e| Error::Retrieval(format!("Decoder download: {}", e)))?;
251
252        let tokenizer_path = repo
253            .get("tokenizer.json")
254            .map_err(|e| Error::Retrieval(format!("Tokenizer download: {}", e)))?;
255
256        // Helper to create opt level
257        let get_opt_level = || match config.optimization_level {
258            1 => GraphOptimizationLevel::Level1,
259            2 => GraphOptimizationLevel::Level2,
260            _ => GraphOptimizationLevel::Level3,
261        };
262
263        // Load encoder
264        let encoder = Session::builder()
265            .map_err(|e| Error::Retrieval(format!("Encoder builder: {}", e)))?
266            .with_optimization_level(get_opt_level())
267            .map_err(|e| Error::Retrieval(format!("Encoder opt: {}", e)))?
268            .with_execution_providers([CPUExecutionProvider::default().build()])
269            .map_err(|e| Error::Retrieval(format!("Encoder provider: {}", e)))?
270            .commit_from_file(&encoder_path)
271            .map_err(|e| Error::Retrieval(format!("Encoder load: {}", e)))?;
272
273        // Load decoder
274        let decoder = Session::builder()
275            .map_err(|e| Error::Retrieval(format!("Decoder builder: {}", e)))?
276            .with_optimization_level(get_opt_level())
277            .map_err(|e| Error::Retrieval(format!("Decoder opt: {}", e)))?
278            .with_execution_providers([CPUExecutionProvider::default().build()])
279            .map_err(|e| Error::Retrieval(format!("Decoder provider: {}", e)))?
280            .commit_from_file(&decoder_path)
281            .map_err(|e| Error::Retrieval(format!("Decoder load: {}", e)))?;
282
283        // Load tokenizer
284        let tokenizer = Tokenizer::from_file(&tokenizer_path)
285            .map_err(|e| Error::Retrieval(format!("Tokenizer: {}", e)))?;
286
287        log::info!("[T5-Coref] Loaded model from {}", model_id);
288
289        Ok(Self {
290            encoder: crate::sync::Mutex::new(encoder),
291            decoder: crate::sync::Mutex::new(decoder),
292            tokenizer: Arc::new(tokenizer),
293            config,
294            model_path: model_id.to_string(),
295        })
296    }
297
298    /// Resolve coreference in text.
299    ///
300    /// Runs the T5 encoder-decoder loop when ONNX weights are available.
301    /// Falls back to the rule-based heuristic if inference fails or produces
302    /// no clusters (e.g. model not fine-tuned for coref, or GPU OOM).
303    pub fn resolve(&self, text: &str) -> Result<Vec<CorefCluster>> {
304        if text.is_empty() {
305            return Ok(vec![]);
306        }
307        match self.resolve_t5(text) {
308            Ok(clusters) if !clusters.is_empty() => Ok(clusters),
309            Ok(_) => {
310                log::debug!("[T5-Coref] inference produced no clusters, using heuristic fallback");
311                self.resolve_simple(text)
312            }
313            Err(e) => {
314                log::warn!(
315                    "[T5-Coref] inference failed ({}), using heuristic fallback",
316                    e
317                );
318                self.resolve_simple(text)
319            }
320        }
321    }
322
323    // -------------------------------------------------------------------------
324    // T5 encoder-decoder inference
325    // -------------------------------------------------------------------------
326
327    /// Full T5 inference path: mark → encode → greedy-decode → parse.
328    fn resolve_t5(&self, text: &str) -> Result<Vec<CorefCluster>> {
329        let marked = self.mark_mentions(text);
330        let (input_ids, attention_mask) = self.tokenize_input(&marked)?;
331        let (enc_hidden, enc_seq_len, hidden_size) =
332            self.run_encoder(&input_ids, &attention_mask)?;
333        let output_ids =
334            self.greedy_decode(&enc_hidden, enc_seq_len, hidden_size, &attention_mask)?;
335        let decoded = self.decode_tokens(&output_ids)?;
336        Ok(self.parse_coref_output(&decoded))
337    }
338
    /// Heuristically mark pronouns and capitalised tokens with `<m>…</m>` so the
    /// T5 model sees explicit mention boundaries.
    ///
    /// Thin wrapper: the marking itself is done by the module-level
    /// `mark_mentions_for_t5` helper, which is defined outside this method.
    fn mark_mentions(&self, text: &str) -> String {
        mark_mentions_for_t5(text)
    }
344
345    /// Tokenize `text` with the HuggingFace tokenizer.
346    /// Returns `(input_ids, attention_mask)` as `i64` vecs, truncated to
347    /// `config.max_input_length`.
348    fn tokenize_input(&self, text: &str) -> Result<(Vec<i64>, Vec<i64>)> {
349        let mut enc = self
350            .tokenizer
351            .encode(text, true)
352            .map_err(|e| Error::Parse(format!("T5Coref tokenizer encode: {e}")))?;
353        // Encoding::truncate returns () — not a Result.
354        enc.truncate(
355            self.config.max_input_length,
356            0,
357            tokenizers::TruncationDirection::Right,
358        );
359        let input_ids: Vec<i64> = enc.get_ids().iter().map(|&x| x as i64).collect();
360        let attention_mask: Vec<i64> = enc.get_attention_mask().iter().map(|&x| x as i64).collect();
361        Ok((input_ids, attention_mask))
362    }
363
364    /// Run the T5 encoder.  Returns `(flat_hidden_states, seq_len, hidden_size)`.
365    fn run_encoder(
366        &self,
367        input_ids: &[i64],
368        attention_mask: &[i64],
369    ) -> Result<(Vec<f32>, usize, usize)> {
370        let batch = 1usize;
371        let seq_len = input_ids.len();
372
373        let ids_arr = Array2::<i64>::from_shape_vec((batch, seq_len), input_ids.to_vec())
374            .map_err(|e| Error::Parse(format!("encoder ids shape: {e}")))?;
375        let mask_arr = Array2::<i64>::from_shape_vec((batch, seq_len), attention_mask.to_vec())
376            .map_err(|e| Error::Parse(format!("encoder mask shape: {e}")))?;
377
378        let ids_t = super::ort_compat::tensor_from_ndarray(ids_arr)
379            .map_err(|e| Error::Parse(format!("encoder ids tensor: {e}")))?;
380        let mask_t = super::ort_compat::tensor_from_ndarray(mask_arr)
381            .map_err(|e| Error::Parse(format!("encoder mask tensor: {e}")))?;
382
383        // Scope the mutex guard: `outputs` borrows from the session; extract owned
384        // data before the guard drops.
385        let (hidden_flat, hidden_size) = {
386            let mut enc = crate::sync::lock(&self.encoder);
387            let outputs = enc
388                .run(ort::inputs![
389                    "input_ids" => ids_t.into_dyn(),
390                    "attention_mask" => mask_t.into_dyn(),
391                ])
392                .map_err(|e| Error::Parse(format!("T5Coref encoder run: {e}")))?;
393            let hidden_val = outputs.get("last_hidden_state").ok_or_else(|| {
394                Error::Parse(
395                    "T5 encoder output 'last_hidden_state' not found; check ONNX export".into(),
396                )
397            })?;
398            let (shape, data) = hidden_val
399                .try_extract_tensor::<f32>()
400                .map_err(|e| Error::Parse(format!("encoder extract tensor: {e}")))?;
401            if shape.len() != 3 || shape[0] != 1 {
402                return Err(Error::Parse(format!(
403                    "T5 encoder: unexpected hidden-state shape {:?}",
404                    shape
405                )));
406            }
407            (data.to_vec(), shape[2] as usize)
408        }; // enc guard drops here
409        Ok((hidden_flat, seq_len, hidden_size))
410    }
411
412    /// Run one greedy decoder step and return the next token ID.
413    ///
414    /// The full `decoder_input_ids` sequence is fed each time (no KV-cache),
415    /// which is O(n²) but correct and avoids managing past key-values.
416    fn decoder_step(
417        &self,
418        encoder_hidden: &[f32],
419        enc_seq_len: usize,
420        hidden_size: usize,
421        attention_mask: &[i64],
422        decoder_input_ids: &[i64],
423    ) -> Result<i64> {
424        let batch = 1usize;
425        let dec_len = decoder_input_ids.len();
426
427        let enc_h = Array3::<f32>::from_shape_vec(
428            (batch, enc_seq_len, hidden_size),
429            encoder_hidden.to_vec(),
430        )
431        .map_err(|e| Error::Parse(format!("decoder enc_hidden shape: {e}")))?;
432        let attn = Array2::<i64>::from_shape_vec((batch, enc_seq_len), attention_mask.to_vec())
433            .map_err(|e| Error::Parse(format!("decoder attn shape: {e}")))?;
434        let dec_ids = Array2::<i64>::from_shape_vec((batch, dec_len), decoder_input_ids.to_vec())
435            .map_err(|e| Error::Parse(format!("decoder_ids shape: {e}")))?;
436
437        let enc_h_t = super::ort_compat::tensor_from_ndarray(enc_h)
438            .map_err(|e| Error::Parse(format!("enc_h tensor: {e}")))?;
439        let attn_t = super::ort_compat::tensor_from_ndarray(attn)
440            .map_err(|e| Error::Parse(format!("attn tensor: {e}")))?;
441        let dec_ids_t = super::ort_compat::tensor_from_ndarray(dec_ids)
442            .map_err(|e| Error::Parse(format!("dec_ids tensor: {e}")))?;
443
444        // Scope the mutex guard: extract owned data before the guard drops.
445        let next_token = {
446            let mut dec = crate::sync::lock(&self.decoder);
447            let outputs = dec
448                .run(ort::inputs![
449                    "encoder_hidden_states" => enc_h_t.into_dyn(),
450                    "attention_mask"        => attn_t.into_dyn(),
451                    "decoder_input_ids"     => dec_ids_t.into_dyn(),
452                ])
453                .map_err(|e| Error::Parse(format!("T5Coref decoder run: {e}")))?;
454            let logits_val = outputs.get("logits").ok_or_else(|| {
455                Error::Parse("T5 decoder output 'logits' not found; check ONNX export".into())
456            })?;
457            let (shape, logits_data) = logits_val
458                .try_extract_tensor::<f32>()
459                .map_err(|e| Error::Parse(format!("decoder logits extract: {e}")))?;
460            // Expected shape: [1, dec_len, vocab_size]
461            if shape.len() != 3 || shape[0] != 1 {
462                return Err(Error::Parse(format!(
463                    "T5 decoder: unexpected logits shape {:?}",
464                    shape
465                )));
466            }
467            let vocab_size = shape[2] as usize;
468            let last_offset = (dec_len - 1) * vocab_size;
469            let last_logits = &logits_data[last_offset..last_offset + vocab_size];
470            last_logits
471                .iter()
472                .enumerate()
473                .max_by(|a, b| a.1.partial_cmp(b.1).unwrap_or(std::cmp::Ordering::Equal))
474                .map(|(i, _)| i as i64)
475                .unwrap_or(1) // EOS as fallback
476        }; // dec guard drops here
477        Ok(next_token)
478    }
479
480    /// Greedy decode from encoder output.  Returns generated token IDs (excluding the
481    /// leading pad/decoder-start token).
482    fn greedy_decode(
483        &self,
484        encoder_hidden: &[f32],
485        enc_seq_len: usize,
486        hidden_size: usize,
487        attention_mask: &[i64],
488    ) -> Result<Vec<i64>> {
489        // T5 uses pad token (0) as the decoder start token; EOS is 1.
490        const T5_PAD: i64 = 0;
491        const T5_EOS: i64 = 1;
492        let mut generated = vec![T5_PAD];
493
494        for _ in 0..self.config.max_output_length {
495            let next = self.decoder_step(
496                encoder_hidden,
497                enc_seq_len,
498                hidden_size,
499                attention_mask,
500                &generated,
501            )?;
502            if next == T5_EOS {
503                break;
504            }
505            generated.push(next);
506        }
507
508        Ok(generated[1..].to_vec()) // drop the leading pad start token
509    }
510
511    /// Decode output token IDs to a string with the tokenizer.
512    fn decode_tokens(&self, token_ids: &[i64]) -> Result<String> {
513        let ids: Vec<u32> = token_ids.iter().map(|&x| x as u32).collect();
514        self.tokenizer
515            .decode(&ids, true)
516            .map_err(|e| Error::Parse(format!("T5Coref decode_tokens: {e}")))
517    }
518
    /// Parse T5 cluster-ID output format (`"word | N"`) into `CorefCluster`s.
    ///
    /// The expected output format is:
    /// ```text
    /// "Elon | 1 founded Tesla | 2. He | 1 later led SpaceX."
    /// ```
    /// where ` | N` immediately follows a mention token and assigns it to cluster `N`.
    ///
    /// Singletons (clusters with only one mention) are filtered out.
    ///
    /// Thin wrapper: the parsing itself is done by the module-level
    /// `parse_t5_coref_output` helper, which is defined outside this method.
    // NOTE(review): the format and singleton filtering described above are properties
    // of that helper — confirm against its implementation.
    fn parse_coref_output(&self, decoded: &str) -> Vec<CorefCluster> {
        parse_t5_coref_output(decoded)
    }
531
532    /// Resolve coreference with pre-marked mentions.
533    ///
534    /// Expects mentions marked with `<m>` and `</m>` tags:
535    /// `"<m> Sophie Wilson </m> designed ARM. <m> She </m> changed computing."`
536    /// Resolve coreference with pre-marked mentions.
537    ///
538    /// Expects mentions marked with `<m>` and `</m>` tags:
539    /// `"<m> Sophie Wilson </m> designed ARM. <m> She </m> changed computing."`
540    ///
541    /// Runs T5 inference directly on the marked text (skipping the auto-marking
542    /// step used by [`resolve`]).  Falls back to similarity-based clustering when
543    /// inference fails or produces no clusters.
544    pub fn resolve_marked(&self, marked_text: &str) -> Result<Vec<CorefCluster>> {
545        let (plain_text, mentions) = self.extract_mentions(marked_text)?;
546        if mentions.is_empty() {
547            return Ok(vec![]);
548        }
549        // The text is already marked — feed it directly to T5 without re-marking.
550        match self.resolve_t5_raw(marked_text) {
551            Ok(clusters) if !clusters.is_empty() => Ok(clusters),
552            Ok(_) => self.cluster_mentions(&plain_text, &mentions),
553            Err(e) => {
554                log::warn!(
555                    "[T5-Coref] resolve_marked inference failed ({}), using fallback",
556                    e
557                );
558                self.cluster_mentions(&plain_text, &mentions)
559            }
560        }
561    }
562
563    /// Resolve coreference for a set of entities from NER.
564    ///
565    /// Reconstructs `<m>…</m>` markers from entity spans, then runs T5 inference.
566    /// Falls back to similarity-based clustering when inference fails.
567    pub fn resolve_entities(&self, text: &str, entities: &[Entity]) -> Result<Vec<CorefCluster>> {
568        if entities.is_empty() {
569            return Ok(vec![]);
570        }
571
572        // Rebuild marked text from entity spans so T5 sees explicit mention boundaries.
573        let marked = self.mark_entity_spans(text, entities);
574        match self.resolve_t5_raw(&marked) {
575            Ok(clusters) if !clusters.is_empty() => Ok(clusters),
576            Ok(_) => {
577                let mentions: Vec<(String, usize, usize)> = entities
578                    .iter()
579                    .map(|e| (e.text.clone(), e.start, e.end))
580                    .collect();
581                self.cluster_mentions(text, &mentions)
582            }
583            Err(e) => {
584                log::warn!(
585                    "[T5-Coref] resolve_entities inference failed ({}), using fallback",
586                    e
587                );
588                let mentions: Vec<(String, usize, usize)> = entities
589                    .iter()
590                    .map(|e| (e.text.clone(), e.start, e.end))
591                    .collect();
592                self.cluster_mentions(text, &mentions)
593            }
594        }
595    }
596
597    /// Run T5 on pre-marked text (already has `<m>…</m>` tags).
598    ///
599    /// This is the shared inner path for [`resolve_marked`] and [`resolve_entities`];
600    /// unlike [`resolve_t5`] it does **not** call `mark_mentions`.
601    fn resolve_t5_raw(&self, marked_text: &str) -> Result<Vec<CorefCluster>> {
602        let (input_ids, attention_mask) = self.tokenize_input(marked_text)?;
603        let (enc_hidden, enc_seq_len, hidden_size) =
604            self.run_encoder(&input_ids, &attention_mask)?;
605        let output_ids =
606            self.greedy_decode(&enc_hidden, enc_seq_len, hidden_size, &attention_mask)?;
607        let decoded = self.decode_tokens(&output_ids)?;
608        Ok(self.parse_coref_output(&decoded))
609    }
610
611    /// Reconstruct a `<m>…</m>`-marked string from entity spans.
612    ///
613    /// Entities are sorted by start offset; overlapping spans are skipped.
614    fn mark_entity_spans(&self, text: &str, entities: &[Entity]) -> String {
615        let chars: Vec<char> = text.chars().collect();
616        let char_len = chars.len();
617
618        let mut sorted: Vec<&Entity> = entities.iter().collect();
619        sorted.sort_by_key(|e| e.start);
620
621        let mut out = String::with_capacity(text.len() + entities.len() * 10);
622        let mut cursor = 0usize; // char offset
623
624        for e in &sorted {
625            if e.start >= e.end || e.start < cursor || e.end > char_len {
626                continue;
627            }
628            // Text before this entity
629            for &ch in &chars[cursor..e.start] {
630                out.push(ch);
631            }
632            out.push_str("<m> ");
633            for &ch in &chars[e.start..e.end] {
634                out.push(ch);
635            }
636            out.push_str(" </m>");
637            cursor = e.end;
638        }
639
640        // Remaining text
641        for &ch in &chars[cursor..] {
642            out.push(ch);
643        }
644        out
645    }
646
647    /// Simple rule-based coreference (fallback).
648    fn resolve_simple(&self, text: &str) -> Result<Vec<CorefCluster>> {
649        // Simple heuristic: find pronouns and link to nearest compatible noun
650        let pronouns = ["he", "she", "they", "it", "his", "her", "their", "its"];
651
652        let words: Vec<(String, usize, usize)> = {
653            let mut result = Vec::new();
654            let mut pos = 0;
655            for word in text.split_whitespace() {
656                if let Some(start) = text[pos..].find(word) {
657                    let abs_start = pos + start;
658                    result.push((word.to_string(), abs_start, abs_start + word.len()));
659                    pos = abs_start + word.len();
660                }
661            }
662            result
663        };
664
665        // Find potential antecedents (capitalized words, likely names)
666        let antecedents: Vec<&(String, usize, usize)> = words
667            .iter()
668            .filter(|(w, _, _)| {
669                w.chars().next().map(|c| c.is_uppercase()).unwrap_or(false)
670                    && !pronouns.contains(&w.to_lowercase().as_str())
671            })
672            .collect();
673
674        // Find pronouns
675        let pronoun_mentions: Vec<&(String, usize, usize)> = words
676            .iter()
677            .filter(|(w, _, _)| pronouns.contains(&w.to_lowercase().as_str()))
678            .collect();
679
680        // Build clusters
681        let mut clusters: Vec<CorefCluster> = Vec::new();
682        let mut assigned: HashMap<usize, u32> = HashMap::new();
683
684        for (ant_text, ant_start, ant_end) in &antecedents {
685            // Check if already assigned
686            if assigned.contains_key(ant_start) {
687                continue;
688            }
689
690            let cluster_id = clusters.len() as u32;
691            let mut mentions = vec![ant_text.clone()];
692            let mut spans = vec![(*ant_start, *ant_end)];
693
694            assigned.insert(*ant_start, cluster_id);
695
696            // Find pronouns after this antecedent that could refer to it
697            for (pro_text, pro_start, pro_end) in &pronoun_mentions {
698                if *pro_start > *ant_end && !assigned.contains_key(pro_start) {
699                    // Check gender compatibility (simplified)
700                    let compatible = match pro_text.to_lowercase().as_str() {
701                        "he" | "him" | "his" => true, // Could be anyone
702                        "she" | "her" | "hers" => true,
703                        "they" | "them" | "their" | "theirs" => true,
704                        "it" | "its" => true,
705                        _ => true,
706                    };
707
708                    if compatible {
709                        mentions.push(pro_text.clone());
710                        spans.push((*pro_start, *pro_end));
711                        assigned.insert(*pro_start, cluster_id);
712                        break; // Only link nearest pronoun
713                    }
714                }
715            }
716
717            if mentions.len() > 1 {
718                clusters.push(CorefCluster {
719                    id: cluster_id,
720                    canonical: ant_text.clone(),
721                    mentions,
722                    spans,
723                });
724            }
725        }
726
727        Ok(clusters)
728    }
729
    /// Split `<m>…</m>`-marked text into its plain text and mention list.
    ///
    /// Thin wrapper: the extraction itself is done by the module-level
    /// `extract_t5_mentions` helper, which is defined outside this method, and its
    /// result (including any error) is returned unchanged.
    fn extract_mentions(&self, marked_text: &str) -> Result<MentionList> {
        extract_t5_mentions(marked_text)
    }
733
734    /// Cluster mentions by similarity.
735    fn cluster_mentions(
736        &self,
737        _text: &str,
738        mentions: &[(String, usize, usize)],
739    ) -> Result<Vec<CorefCluster>> {
740        // Simple clustering: exact match + substring match + pronoun resolution
741        let mut clusters: Vec<CorefCluster> = Vec::new();
742        let mut assigned: HashMap<usize, u32> = HashMap::new();
743
744        let pronouns = [
745            "he", "she", "they", "it", "him", "her", "them", "his", "hers", "their", "its",
746        ];
747
748        for (i, (text_i, start_i, end_i)) in mentions.iter().enumerate() {
749            if assigned.contains_key(&i) {
750                continue;
751            }
752
753            let lower_i = text_i.to_lowercase();
754            let is_pronoun_i = pronouns.contains(&lower_i.as_str());
755
756            if is_pronoun_i {
757                // Find nearest preceding non-pronoun
758                for j in (0..i).rev() {
759                    let (text_j, _, _) = &mentions[j];
760                    let lower_j = text_j.to_lowercase();
761                    if !pronouns.contains(&lower_j.as_str()) {
762                        if let Some(&cluster_id) = assigned.get(&j) {
763                            assigned.insert(i, cluster_id);
764                            clusters[cluster_id as usize].mentions.push(text_i.clone());
765                            clusters[cluster_id as usize].spans.push((*start_i, *end_i));
766                        }
767                        break;
768                    }
769                }
770                continue;
771            }
772
773            // Start new cluster
774            let cluster_id = clusters.len() as u32;
775            let mut cluster_mentions = vec![text_i.clone()];
776            let mut cluster_spans = vec![(*start_i, *end_i)];
777            assigned.insert(i, cluster_id);
778
779            // Find matches
780            for (j, (text_j, start_j, end_j)) in mentions.iter().enumerate().skip(i + 1) {
781                if assigned.contains_key(&j) {
782                    continue;
783                }
784
785                let lower_j = text_j.to_lowercase();
786
787                // Exact match
788                let matches = lower_i == lower_j
789                    // Substring match
790                    || lower_i.contains(&lower_j)
791                    || lower_j.contains(&lower_i)
792                    // Last word match (surname)
793                    || {
794                        let last_i = lower_i.split_whitespace().last();
795                        let last_j = lower_j.split_whitespace().last();
796                        last_i.is_some() && last_i == last_j && last_i.map(|w| w.len() > 2).unwrap_or(false)
797                    };
798
799                if matches {
800                    cluster_mentions.push(text_j.clone());
801                    cluster_spans.push((*start_j, *end_j));
802                    assigned.insert(j, cluster_id);
803                }
804            }
805
806            // Determine canonical (longest mention)
807            let canonical = cluster_mentions
808                .iter()
809                .max_by_key(|m| m.len())
810                .cloned()
811                .unwrap_or_else(|| text_i.clone());
812
813            clusters.push(CorefCluster {
814                id: cluster_id,
815                mentions: cluster_mentions,
816                spans: cluster_spans,
817                canonical,
818            });
819        }
820
821        // Filter to only multi-mention clusters
822        let multi_clusters: Vec<CorefCluster> = clusters
823            .into_iter()
824            .filter(|c| c.mentions.len() > 1)
825            .collect();
826
827        Ok(multi_clusters)
828    }
829
830    /// Get model path.
831    pub fn model_path(&self) -> &str {
832        &self.model_path
833    }
834}
835
836// =============================================================================
837// Free-function helpers (pure parsing — no ONNX, directly testable)
838// =============================================================================
839
/// Heuristically mark pronouns and capitalised tokens in `text` with `<m>…</m>` tags.
///
/// The input is split on whitespace and re-joined with single spaces. For each
/// word, leading and trailing non-alphanumeric characters (quotes, brackets,
/// sentence punctuation) are set aside; the remaining core is tagged when it is a
/// known pronoun (case-insensitive) or starts with an uppercase letter. The
/// punctuation is emitted *outside* the tags, e.g. `Tesla.` becomes
/// `<m> Tesla </m>.` and `(She)` becomes `(<m> She </m>)`, so the text inside a
/// tag is the bare mention and can be compared against pronoun lists directly.
///
/// Words that are not tagged are copied through unchanged. This is a free
/// function so it can be tested and reused without an ONNX session.
pub fn mark_mentions_for_t5(text: &str) -> String {
    const PRONOUNS: &[&str] = &[
        "he", "she", "they", "it", "him", "her", "them", "his", "hers", "their", "its",
    ];

    /// Characters that may surround a mention but are not part of it.
    fn is_edge_punct(c: char) -> bool {
        !c.is_alphanumeric()
    }

    let mut out = String::with_capacity(text.len() + 64);
    for (i, word) in text.split_whitespace().enumerate() {
        if i > 0 {
            out.push(' ');
        }

        // Split the word into `lead` + `core` + `trail`.
        let rest = word.trim_start_matches(is_edge_punct);
        let core = rest.trim_end_matches(is_edge_punct);
        let lead = &word[..word.len() - rest.len()];
        let trail = &rest[core.len()..];

        let is_pronoun = PRONOUNS.contains(&core.to_lowercase().as_str());
        let is_cap = core
            .chars()
            .next()
            .map(|c| c.is_uppercase())
            .unwrap_or(false);

        if is_pronoun || is_cap {
            out.push_str(lead);
            out.push_str("<m> ");
            out.push_str(core);
            out.push_str(" </m>");
            out.push_str(trail);
        } else {
            // Also covers punctuation-only words, whose core is empty.
            out.push_str(word);
        }
    }
    out
}
872
873/// Parse T5 cluster-ID output format (`"word | N"`) into [`CorefCluster`]s.
874///
875/// Singletons are filtered out.  This is the same logic as
876/// `T5Coref::parse_coref_output` and is exposed as a free function for testing.
877pub fn parse_t5_coref_output(decoded: &str) -> Vec<CorefCluster> {
878    let mut clusters: HashMap<u32, CorefCluster> = HashMap::new();
879    let tokens: Vec<&str> = decoded.split_whitespace().collect();
880    let mut offset: usize = 0;
881    let mut i = 0;
882
883    while i < tokens.len() {
884        let tok = tokens[i];
885        let is_pipe = tokens.get(i + 1).map(|&t| t == "|").unwrap_or(false);
886        let cluster_id: Option<u32> = if is_pipe {
887            tokens
888                .get(i + 2)
889                .and_then(|t| t.trim_matches(|c: char| !c.is_ascii_digit()).parse().ok())
890        } else {
891            None
892        };
893
894        if let Some(cid) = cluster_id {
895            let mention = tok.trim_matches(|c: char| !c.is_alphanumeric()).to_string();
896            if !mention.is_empty() {
897                let start = offset;
898                let end = offset + mention.len();
899                let entry = clusters.entry(cid).or_insert_with(|| CorefCluster {
900                    id: cid,
901                    mentions: Vec::new(),
902                    spans: Vec::new(),
903                    canonical: String::new(),
904                });
905                entry.mentions.push(mention);
906                entry.spans.push((start, end));
907            }
908            offset += tok.len() + 1;
909            i += 3;
910            continue;
911        }
912
913        offset += tok.len() + 1;
914        i += 1;
915    }
916
917    let mut result: Vec<CorefCluster> = clusters
918        .into_values()
919        .filter(|c| c.mentions.len() > 1)
920        .collect();
921    for c in &mut result {
922        c.canonical = c
923            .mentions
924            .iter()
925            .max_by_key(|m| m.len())
926            .cloned()
927            .unwrap_or_default();
928    }
929    result.sort_by_key(|c| c.id);
930    result
931}
932
933/// Extract `<m>…</m>` spans from `marked_text`, returning `(plain_text, mentions)`.
934///
935/// Each mention is `(text, char_start, char_end)` in the plain text.
936/// This is the same logic as `T5Coref::extract_mentions` and is exposed as a
937/// free function for testing.
938pub fn extract_t5_mentions(marked_text: &str) -> Result<MentionList> {
939    let mut plain_text = String::new();
940    let mut mentions = Vec::new();
941    let mut offset = 0;
942
943    let mut remaining = marked_text;
944    while !remaining.is_empty() {
945        if let Some(start_pos) = remaining.find("<m>") {
946            plain_text.push_str(&remaining[..start_pos]);
947            offset += start_pos;
948
949            let after_start = &remaining[start_pos + 3..];
950            if let Some(end_pos) = after_start.find("</m>") {
951                let mention_text = after_start[..end_pos].trim();
952                let mention_start = offset;
953                plain_text.push_str(mention_text);
954                let mention_end = offset + mention_text.len();
955                offset = mention_end;
956
957                mentions.push((mention_text.to_string(), mention_start, mention_end));
958                remaining = &after_start[end_pos + 4..];
959            } else {
960                plain_text.push_str(remaining);
961                break;
962            }
963        } else {
964            plain_text.push_str(remaining);
965            break;
966        }
967    }
968
969    Ok((plain_text, mentions))
970}
971
972// =============================================================================
973// Tests
974// =============================================================================
975
976#[cfg(test)]
977mod tests;