Skip to main content

anno/backends/
coref_t5.rs

1//! T5-based coreference resolution using ONNX Runtime.
2//!
3//! Experimental scaffold for seq2seq coreference.
4//!
5//! The intended approach treats coreference as a text-to-text transformation:
6//!
7//! ```text
8//! Input:  "<m> Elon </m> founded <m> Tesla </m>. <m> He </m> later led SpaceX."
9//! Output: "Elon | 1 founded Tesla | 2. He | 1 later led SpaceX."
10//! ```
11//!
12//! The model learns to assign cluster IDs to mentions, enabling coreference
13//! without explicit pairwise classification.
14//!
//! **Status**: experimental. `T5Coref` loads ONNX artifacts and `resolve()` runs a greedy
//! encoder/decoder loop (no KV-cache, no beam search), falling back to a lightweight
//! heuristic when inference fails or produces no clusters.
17//!
18//! # Architecture
19//!
20//! ```text
21//! Text with Marked Mentions
22//!         │
23//!         ▼
24//! ┌───────────────────┐
25//! │   T5 Encoder      │
26//! │   (ONNX)          │
27//! └─────────┬─────────┘
28//!           │
29//!           ▼
30//! ┌───────────────────┐
31//! │   T5 Decoder      │
32//! │   (Autoregressive)│
33//! └─────────┬─────────┘
34//!           │
35//!           ▼
36//! Text with Cluster IDs
37//!         │
38//!         ▼
39//! ┌───────────────────┐
40//! │  Parse Clusters   │
41//! └───────────────────┘
42//!         │
43//!         ▼
44//! CoreferenceCluster[]
45//! ```
46//!
47//! # Model Export (One-Time Setup)
48//!
49//! Export a T5 coreference model to ONNX using Optimum:
50//!
51//! ```bash
52//! pip install optimum[onnxruntime]
53//! optimum-cli export onnx \
54//!     --model "google/flan-t5-base" \
55//!     --task text2text-generation-with-past \
56//!     t5_coref_onnx/
57//! ```
58//!
59//! # Example
60//!
61//! ```rust,ignore
62//! use anno::backends::coref_t5::{T5Coref, T5CorefConfig};
63//!
64//! let coref = T5Coref::from_path("path/to/t5_coref_onnx", T5CorefConfig::default())?
65//!     .with_heuristic_fallback();
66//!
67//! let text = "Sophie Wilson designed the ARM processor. She changed computing.";
68//! let clusters = coref.resolve(text)?;
69//!
//! // illustrative: clusters[0] = { mentions: ["Sophie Wilson", "She"], canonical: "Sophie Wilson" }
71//! ```
72//!
73//! # Research Background
74//!
75//! This approach is based on:
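//! - Seq2seq coref: "Seq2seq is All You Need for Coreference Resolution" (Zhang et al., 2023)
//! - Related query-based formulation: "CorefQA: Coreference Resolution as Query-based Span
//!   Prediction" (Wu et al., 2020)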
77//! - FLAN-T5 fine-tuning for coreference tasks
78//! - Entity-centric markup format for mention boundaries
79//!
//! Fine-tuned seq2seq coreference systems (not this scaffold) have reported results
//! competitive with pairwise and span-ranking models. Standard evaluation benchmarks:
//! - OntoNotes 5.0 (coreference benchmark)
//! - GAP (gendered pronoun resolution benchmark)
83
84// Note: This module is feature-gated via `#[cfg(feature = "onnx")]` in mod.rs
85
86use crate::{Entity, Error, Result};
87use ndarray::{Array2, Array3};
88use std::collections::HashMap;
89use std::sync::Arc;
90
91use hf_hub::api::sync::Api;
92use ort::{
93    execution_providers::CPUExecutionProvider, session::builder::GraphOptimizationLevel,
94    session::Session,
95};
96use tokenizers::Tokenizer;
97
/// A coreference cluster (group of mentions referring to the same entity).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CorefCluster {
    /// Cluster ID.
    pub id: u32,
    /// Member mention texts, in the order they were added to the cluster.
    pub mentions: Vec<String>,
    /// Member mention spans as `(start, end)` byte offsets, parallel to
    /// `mentions` (entry `k` is the span of `mentions[k]`).
    pub spans: Vec<(usize, usize)>,
    /// Canonical name for the cluster (e.g. the longest or the antecedent
    /// mention, depending on which resolver produced the cluster).
    pub canonical: String,
}
110
/// Configuration for the T5 coreference model.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct T5CorefConfig {
    /// Maximum input length in tokens; longer tokenized inputs are truncated.
    pub max_input_length: usize,
    /// Maximum number of tokens the greedy decoder generates.
    pub max_output_length: usize,
    /// Beam search width (1 = greedy).
    ///
    /// NOTE: not read by the greedy decoder in this module; decoding is greedy.
    pub num_beams: usize,
    /// ONNX graph optimization level: `1` selects Level1, `2` selects Level2,
    /// any other value (including `0`) selects Level3.
    pub optimization_level: u8,
    /// Number of intra-op inference threads, applied when building the ONNX
    /// sessions in `T5Coref::from_path`.
    pub num_threads: usize,
}

impl Default for T5CorefConfig {
    /// 512-token input/output limits, greedy decoding, optimization level 3 and
    /// 4 inference threads.
    fn default() -> Self {
        Self {
            max_input_length: 512,
            max_output_length: 512,
            num_beams: 1, // Greedy for speed
            optimization_level: 3,
            num_threads: 4,
        }
    }
}
137
138/// T5-based coreference resolution.
139///
140/// Uses a seq2seq model to assign cluster IDs to marked mentions.
141///
142/// # Note
143///
144/// Currently uses a simplified rule-based fallback. Full seq2seq inference
145/// is planned for a future release when encoder-decoder ONNX support matures.
146pub struct T5Coref {
147    /// Encoder ONNX session.
148    encoder: crate::sync::Mutex<Session>,
149    /// Decoder ONNX session.
150    decoder: crate::sync::Mutex<Session>,
151    /// HuggingFace tokenizer for input encoding and output decoding.
152    tokenizer: Arc<Tokenizer>,
153    /// Inference configuration.
154    config: T5CorefConfig,
155    /// Model path or HuggingFace model ID.
156    model_path: String,
157}
158
159impl T5Coref {
160    /// Create a new T5 coreference model from a local ONNX export.
161    ///
162    /// # Arguments
163    ///
164    /// * `model_path` - Path to directory containing encoder.onnx and decoder_model.onnx
165    pub fn from_path(model_path: &str, config: T5CorefConfig) -> Result<Self> {
166        let encoder_path = format!("{}/encoder_model.onnx", model_path);
167        let decoder_path = format!("{}/decoder_model.onnx", model_path);
168        let tokenizer_path = format!("{}/tokenizer.json", model_path);
169
170        // Check files exist
171        if !std::path::Path::new(&encoder_path).exists() {
172            return Err(Error::Retrieval(format!(
173                "Encoder not found at {}. Export with: optimum-cli export onnx --model <model> --task text2text-generation-with-past {}",
174                encoder_path, model_path
175            )));
176        }
177
178        // Helper to create opt level
179        let get_opt_level = || match config.optimization_level {
180            1 => GraphOptimizationLevel::Level1,
181            2 => GraphOptimizationLevel::Level2,
182            _ => GraphOptimizationLevel::Level3,
183        };
184
185        // Load encoder
186        let encoder = Session::builder()
187            .map_err(|e| Error::Retrieval(format!("Encoder builder: {}", e)))?
188            .with_optimization_level(get_opt_level())
189            .map_err(|e| Error::Retrieval(format!("Encoder opt: {}", e)))?
190            .with_execution_providers([CPUExecutionProvider::default().build()])
191            .map_err(|e| Error::Retrieval(format!("Encoder provider: {}", e)))?
192            .with_intra_threads(config.num_threads)
193            .map_err(|e| Error::Retrieval(format!("Encoder threads: {}", e)))?
194            .commit_from_file(&encoder_path)
195            .map_err(|e| Error::Retrieval(format!("Encoder load: {}", e)))?;
196
197        // Load decoder
198        let decoder = Session::builder()
199            .map_err(|e| Error::Retrieval(format!("Decoder builder: {}", e)))?
200            .with_optimization_level(get_opt_level())
201            .map_err(|e| Error::Retrieval(format!("Decoder opt: {}", e)))?
202            .with_execution_providers([CPUExecutionProvider::default().build()])
203            .map_err(|e| Error::Retrieval(format!("Decoder provider: {}", e)))?
204            .with_intra_threads(config.num_threads)
205            .map_err(|e| Error::Retrieval(format!("Decoder threads: {}", e)))?
206            .commit_from_file(&decoder_path)
207            .map_err(|e| Error::Retrieval(format!("Decoder load: {}", e)))?;
208
209        // Load tokenizer
210        let tokenizer = Tokenizer::from_file(&tokenizer_path)
211            .map_err(|e| Error::Retrieval(format!("Tokenizer: {}", e)))?;
212
213        log::info!("[T5-Coref] Loaded model from {}", model_path);
214
215        Ok(Self {
216            encoder: crate::sync::Mutex::new(encoder),
217            decoder: crate::sync::Mutex::new(decoder),
218            tokenizer: Arc::new(tokenizer),
219            config,
220            model_path: model_path.to_string(),
221        })
222    }
223
224    /// Create from HuggingFace model ID.
225    ///
226    /// Downloads ONNX-exported T5 model from HuggingFace Hub.
227    pub fn from_pretrained(model_id: &str) -> Result<Self> {
228        Self::from_pretrained_with_config(model_id, T5CorefConfig::default())
229    }
230
231    /// Create from HuggingFace with custom config.
232    pub fn from_pretrained_with_config(model_id: &str, config: T5CorefConfig) -> Result<Self> {
233        let api = Api::new().map_err(|e| Error::Retrieval(format!("HuggingFace API: {}", e)))?;
234
235        let repo = api.model(model_id.to_string());
236
237        // Download ONNX files
238        let encoder_path = repo
239            .get("encoder_model.onnx")
240            .or_else(|_| repo.get("onnx/encoder_model.onnx"))
241            .map_err(|e| Error::Retrieval(format!("Encoder download: {}", e)))?;
242
243        let decoder_path = repo
244            .get("decoder_model.onnx")
245            .or_else(|_| repo.get("onnx/decoder_model.onnx"))
246            .or_else(|_| repo.get("decoder_with_past_model.onnx"))
247            .map_err(|e| Error::Retrieval(format!("Decoder download: {}", e)))?;
248
249        let tokenizer_path = repo
250            .get("tokenizer.json")
251            .map_err(|e| Error::Retrieval(format!("Tokenizer download: {}", e)))?;
252
253        // Helper to create opt level
254        let get_opt_level = || match config.optimization_level {
255            1 => GraphOptimizationLevel::Level1,
256            2 => GraphOptimizationLevel::Level2,
257            _ => GraphOptimizationLevel::Level3,
258        };
259
260        // Load encoder
261        let encoder = Session::builder()
262            .map_err(|e| Error::Retrieval(format!("Encoder builder: {}", e)))?
263            .with_optimization_level(get_opt_level())
264            .map_err(|e| Error::Retrieval(format!("Encoder opt: {}", e)))?
265            .with_execution_providers([CPUExecutionProvider::default().build()])
266            .map_err(|e| Error::Retrieval(format!("Encoder provider: {}", e)))?
267            .commit_from_file(&encoder_path)
268            .map_err(|e| Error::Retrieval(format!("Encoder load: {}", e)))?;
269
270        // Load decoder
271        let decoder = Session::builder()
272            .map_err(|e| Error::Retrieval(format!("Decoder builder: {}", e)))?
273            .with_optimization_level(get_opt_level())
274            .map_err(|e| Error::Retrieval(format!("Decoder opt: {}", e)))?
275            .with_execution_providers([CPUExecutionProvider::default().build()])
276            .map_err(|e| Error::Retrieval(format!("Decoder provider: {}", e)))?
277            .commit_from_file(&decoder_path)
278            .map_err(|e| Error::Retrieval(format!("Decoder load: {}", e)))?;
279
280        // Load tokenizer
281        let tokenizer = Tokenizer::from_file(&tokenizer_path)
282            .map_err(|e| Error::Retrieval(format!("Tokenizer: {}", e)))?;
283
284        log::info!("[T5-Coref] Loaded model from {}", model_id);
285
286        Ok(Self {
287            encoder: crate::sync::Mutex::new(encoder),
288            decoder: crate::sync::Mutex::new(decoder),
289            tokenizer: Arc::new(tokenizer),
290            config,
291            model_path: model_id.to_string(),
292        })
293    }
294
295    /// Resolve coreference in text.
296    ///
297    /// Runs the T5 encoder-decoder loop when ONNX weights are available.
298    /// Falls back to the rule-based heuristic if inference fails or produces
299    /// no clusters (e.g. model not fine-tuned for coref, or GPU OOM).
300    pub fn resolve(&self, text: &str) -> Result<Vec<CorefCluster>> {
301        if text.is_empty() {
302            return Ok(vec![]);
303        }
304        match self.resolve_t5(text) {
305            Ok(clusters) if !clusters.is_empty() => Ok(clusters),
306            Ok(_) => {
307                log::debug!("[T5-Coref] inference produced no clusters, using heuristic fallback");
308                self.resolve_simple(text)
309            }
310            Err(e) => {
311                log::warn!(
312                    "[T5-Coref] inference failed ({}), using heuristic fallback",
313                    e
314                );
315                self.resolve_simple(text)
316            }
317        }
318    }
319
320    // -------------------------------------------------------------------------
321    // T5 encoder-decoder inference
322    // -------------------------------------------------------------------------
323
324    /// Full T5 inference path: mark → encode → greedy-decode → parse.
325    fn resolve_t5(&self, text: &str) -> Result<Vec<CorefCluster>> {
326        let marked = self.mark_mentions(text);
327        let (input_ids, attention_mask) = self.tokenize_input(&marked)?;
328        let (enc_hidden, enc_seq_len, hidden_size) =
329            self.run_encoder(&input_ids, &attention_mask)?;
330        let output_ids =
331            self.greedy_decode(&enc_hidden, enc_seq_len, hidden_size, &attention_mask)?;
332        let decoded = self.decode_tokens(&output_ids)?;
333        Ok(self.parse_coref_output(&decoded))
334    }
335
336    /// Heuristically mark pronouns and capitalised tokens with `<m>…</m>` so the
337    /// T5 model sees explicit mention boundaries.
338    fn mark_mentions(&self, text: &str) -> String {
339        const PRONOUNS: &[&str] = &[
340            "he", "she", "they", "it", "him", "her", "them", "his", "hers", "their", "its",
341        ];
342        let mut out = String::with_capacity(text.len() + 64);
343        for (i, word) in text.split_whitespace().enumerate() {
344            if i > 0 {
345                out.push(' ');
346            }
347            let lower = word
348                .trim_matches(|c: char| !c.is_alphabetic())
349                .to_lowercase();
350            let is_pronoun = PRONOUNS.contains(&lower.as_str());
351            let is_cap = word.chars().next().map(|c| c.is_uppercase()).unwrap_or(false);
352            if is_pronoun || is_cap {
353                out.push_str("<m> ");
354                out.push_str(word);
355                out.push_str(" </m>");
356            } else {
357                out.push_str(word);
358            }
359        }
360        out
361    }
362
363    /// Tokenize `text` with the HuggingFace tokenizer.
364    /// Returns `(input_ids, attention_mask)` as `i64` vecs, truncated to
365    /// `config.max_input_length`.
366    fn tokenize_input(&self, text: &str) -> Result<(Vec<i64>, Vec<i64>)> {
367        let mut enc = self
368            .tokenizer
369            .encode(text, true)
370            .map_err(|e| Error::Parse(format!("T5Coref tokenizer encode: {e}")))?;
371        // Encoding::truncate returns () — not a Result.
372        enc.truncate(self.config.max_input_length, 0, tokenizers::TruncationDirection::Right);
373        let input_ids: Vec<i64> = enc.get_ids().iter().map(|&x| x as i64).collect();
374        let attention_mask: Vec<i64> = enc
375            .get_attention_mask()
376            .iter()
377            .map(|&x| x as i64)
378            .collect();
379        Ok((input_ids, attention_mask))
380    }
381
382    /// Run the T5 encoder.  Returns `(flat_hidden_states, seq_len, hidden_size)`.
383    fn run_encoder(
384        &self,
385        input_ids: &[i64],
386        attention_mask: &[i64],
387    ) -> Result<(Vec<f32>, usize, usize)> {
388        let batch = 1usize;
389        let seq_len = input_ids.len();
390
391        let ids_arr = Array2::<i64>::from_shape_vec((batch, seq_len), input_ids.to_vec())
392            .map_err(|e| Error::Parse(format!("encoder ids shape: {e}")))?;
393        let mask_arr =
394            Array2::<i64>::from_shape_vec((batch, seq_len), attention_mask.to_vec())
395                .map_err(|e| Error::Parse(format!("encoder mask shape: {e}")))?;
396
397        let ids_t = super::ort_compat::tensor_from_ndarray(ids_arr)
398            .map_err(|e| Error::Parse(format!("encoder ids tensor: {e}")))?;
399        let mask_t = super::ort_compat::tensor_from_ndarray(mask_arr)
400            .map_err(|e| Error::Parse(format!("encoder mask tensor: {e}")))?;
401
402        // Scope the mutex guard: `outputs` borrows from the session; extract owned
403        // data before the guard drops.
404        let (hidden_flat, hidden_size) = {
405            let mut enc = crate::sync::lock(&self.encoder);
406            let outputs = enc
407                .run(ort::inputs![
408                    "input_ids" => ids_t.into_dyn(),
409                    "attention_mask" => mask_t.into_dyn(),
410                ])
411                .map_err(|e| Error::Parse(format!("T5Coref encoder run: {e}")))?;
412            let hidden_val = outputs.get("last_hidden_state").ok_or_else(|| {
413                Error::Parse(
414                    "T5 encoder output 'last_hidden_state' not found; check ONNX export".into(),
415                )
416            })?;
417            let (shape, data) = hidden_val
418                .try_extract_tensor::<f32>()
419                .map_err(|e| Error::Parse(format!("encoder extract tensor: {e}")))?;
420            if shape.len() != 3 || shape[0] != 1 {
421                return Err(Error::Parse(format!(
422                    "T5 encoder: unexpected hidden-state shape {:?}",
423                    shape
424                )));
425            }
426            (data.to_vec(), shape[2] as usize)
427        }; // enc guard drops here
428        Ok((hidden_flat, seq_len, hidden_size))
429    }
430
431    /// Run one greedy decoder step and return the next token ID.
432    ///
433    /// The full `decoder_input_ids` sequence is fed each time (no KV-cache),
434    /// which is O(n²) but correct and avoids managing past key-values.
435    fn decoder_step(
436        &self,
437        encoder_hidden: &[f32],
438        enc_seq_len: usize,
439        hidden_size: usize,
440        attention_mask: &[i64],
441        decoder_input_ids: &[i64],
442    ) -> Result<i64> {
443        let batch = 1usize;
444        let dec_len = decoder_input_ids.len();
445
446        let enc_h =
447            Array3::<f32>::from_shape_vec((batch, enc_seq_len, hidden_size), encoder_hidden.to_vec())
448                .map_err(|e| Error::Parse(format!("decoder enc_hidden shape: {e}")))?;
449        let attn = Array2::<i64>::from_shape_vec((batch, enc_seq_len), attention_mask.to_vec())
450            .map_err(|e| Error::Parse(format!("decoder attn shape: {e}")))?;
451        let dec_ids =
452            Array2::<i64>::from_shape_vec((batch, dec_len), decoder_input_ids.to_vec())
453                .map_err(|e| Error::Parse(format!("decoder_ids shape: {e}")))?;
454
455        let enc_h_t = super::ort_compat::tensor_from_ndarray(enc_h)
456            .map_err(|e| Error::Parse(format!("enc_h tensor: {e}")))?;
457        let attn_t = super::ort_compat::tensor_from_ndarray(attn)
458            .map_err(|e| Error::Parse(format!("attn tensor: {e}")))?;
459        let dec_ids_t = super::ort_compat::tensor_from_ndarray(dec_ids)
460            .map_err(|e| Error::Parse(format!("dec_ids tensor: {e}")))?;
461
462        // Scope the mutex guard: extract owned data before the guard drops.
463        let next_token = {
464            let mut dec = crate::sync::lock(&self.decoder);
465            let outputs = dec
466                .run(ort::inputs![
467                    "encoder_hidden_states" => enc_h_t.into_dyn(),
468                    "attention_mask"        => attn_t.into_dyn(),
469                    "decoder_input_ids"     => dec_ids_t.into_dyn(),
470                ])
471                .map_err(|e| Error::Parse(format!("T5Coref decoder run: {e}")))?;
472            let logits_val = outputs.get("logits").ok_or_else(|| {
473                Error::Parse("T5 decoder output 'logits' not found; check ONNX export".into())
474            })?;
475            let (shape, logits_data) = logits_val
476                .try_extract_tensor::<f32>()
477                .map_err(|e| Error::Parse(format!("decoder logits extract: {e}")))?;
478            // Expected shape: [1, dec_len, vocab_size]
479            if shape.len() != 3 || shape[0] != 1 {
480                return Err(Error::Parse(format!(
481                    "T5 decoder: unexpected logits shape {:?}",
482                    shape
483                )));
484            }
485            let vocab_size = shape[2] as usize;
486            let last_offset = (dec_len - 1) * vocab_size;
487            let last_logits = &logits_data[last_offset..last_offset + vocab_size];
488            last_logits
489                .iter()
490                .enumerate()
491                .max_by(|a, b| a.1.partial_cmp(b.1).unwrap_or(std::cmp::Ordering::Equal))
492                .map(|(i, _)| i as i64)
493                .unwrap_or(1) // EOS as fallback
494        }; // dec guard drops here
495        Ok(next_token)
496    }
497
498    /// Greedy decode from encoder output.  Returns generated token IDs (excluding the
499    /// leading pad/decoder-start token).
500    fn greedy_decode(
501        &self,
502        encoder_hidden: &[f32],
503        enc_seq_len: usize,
504        hidden_size: usize,
505        attention_mask: &[i64],
506    ) -> Result<Vec<i64>> {
507        // T5 uses pad token (0) as the decoder start token; EOS is 1.
508        const T5_PAD: i64 = 0;
509        const T5_EOS: i64 = 1;
510        let mut generated = vec![T5_PAD];
511
512        for _ in 0..self.config.max_output_length {
513            let next = self.decoder_step(
514                encoder_hidden,
515                enc_seq_len,
516                hidden_size,
517                attention_mask,
518                &generated,
519            )?;
520            if next == T5_EOS {
521                break;
522            }
523            generated.push(next);
524        }
525
526        Ok(generated[1..].to_vec()) // drop the leading pad start token
527    }
528
529    /// Decode output token IDs to a string with the tokenizer.
530    fn decode_tokens(&self, token_ids: &[i64]) -> Result<String> {
531        let ids: Vec<u32> = token_ids.iter().map(|&x| x as u32).collect();
532        self.tokenizer
533            .decode(&ids, true)
534            .map_err(|e| Error::Parse(format!("T5Coref decode_tokens: {e}")))
535    }
536
537    /// Parse T5 cluster-ID output format (`"word | N"`) into `CorefCluster`s.
538    ///
539    /// The expected output format is:
540    /// ```text
541    /// "Elon | 1 founded Tesla | 2. He | 1 later led SpaceX."
542    /// ```
543    /// where ` | N` immediately follows a mention token and assigns it to cluster `N`.
544    ///
545    /// Singletons (clusters with only one mention) are filtered out.
546    fn parse_coref_output(&self, decoded: &str) -> Vec<CorefCluster> {
547        let mut clusters: HashMap<u32, CorefCluster> = HashMap::new();
548        let tokens: Vec<&str> = decoded.split_whitespace().collect();
549        let mut offset: usize = 0;
550        let mut i = 0;
551
552        while i < tokens.len() {
553            let tok = tokens[i];
554            // Look for the pattern: <word> "|" <digits>
555            let is_pipe = tokens.get(i + 1).map(|&t| t == "|").unwrap_or(false);
556            let cluster_id: Option<u32> = if is_pipe {
557                tokens
558                    .get(i + 2)
559                    .and_then(|t| t.trim_matches(|c: char| !c.is_ascii_digit()).parse().ok())
560            } else {
561                None
562            };
563
564            if let Some(cid) = cluster_id {
565                let mention = tok.trim_matches(|c: char| !c.is_alphanumeric()).to_string();
566                if !mention.is_empty() {
567                    let start = offset;
568                    let end = offset + mention.len();
569                    let entry = clusters.entry(cid).or_insert_with(|| CorefCluster {
570                        id: cid,
571                        mentions: Vec::new(),
572                        spans: Vec::new(),
573                        canonical: String::new(),
574                    });
575                    entry.mentions.push(mention);
576                    entry.spans.push((start, end));
577                }
578                offset += tok.len() + 1;
579                i += 3; // skip: word, "|", number
580                continue;
581            }
582
583            offset += tok.len() + 1;
584            i += 1;
585        }
586
587        let mut result: Vec<CorefCluster> = clusters
588            .into_values()
589            .filter(|c| c.mentions.len() > 1)
590            .collect();
591
592        for c in &mut result {
593            c.canonical = c
594                .mentions
595                .iter()
596                .max_by_key(|m| m.len())
597                .cloned()
598                .unwrap_or_default();
599        }
600        result.sort_by_key(|c| c.id);
601        result
602    }
603
604    /// Resolve coreference with pre-marked mentions.
605    ///
606    /// Expects mentions marked with `<m>` and `</m>` tags:
607    /// `"<m> Sophie Wilson </m> designed ARM. <m> She </m> changed computing."`
608    pub fn resolve_marked(&self, marked_text: &str) -> Result<Vec<CorefCluster>> {
609        // Extract mentions from marked text
610        let (plain_text, mentions) = self.extract_mentions(marked_text)?;
611
612        if mentions.is_empty() {
613            return Ok(vec![]);
614        }
615
616        // For full T5 inference, we would:
617        // 1. Encode the marked text
618        // 2. Run autoregressive decoding
619        // 3. Parse cluster IDs from output
620
621        // Simplified: cluster by string similarity
622        self.cluster_mentions(&plain_text, &mentions)
623    }
624
625    /// Resolve coreference for a set of entities.
626    ///
627    /// Takes entities from NER and groups coreferent mentions.
628    pub fn resolve_entities(&self, text: &str, entities: &[Entity]) -> Result<Vec<CorefCluster>> {
629        if entities.is_empty() {
630            return Ok(vec![]);
631        }
632
633        // Convert entities to mentions
634        let mentions: Vec<(String, usize, usize)> = entities
635            .iter()
636            .map(|e| (e.text.clone(), e.start, e.end))
637            .collect();
638
639        self.cluster_mentions(text, &mentions)
640    }
641
642    /// Simple rule-based coreference (fallback).
643    fn resolve_simple(&self, text: &str) -> Result<Vec<CorefCluster>> {
644        // Simple heuristic: find pronouns and link to nearest compatible noun
645        let pronouns = ["he", "she", "they", "it", "his", "her", "their", "its"];
646
647        let words: Vec<(String, usize, usize)> = {
648            let mut result = Vec::new();
649            let mut pos = 0;
650            for word in text.split_whitespace() {
651                if let Some(start) = text[pos..].find(word) {
652                    let abs_start = pos + start;
653                    result.push((word.to_string(), abs_start, abs_start + word.len()));
654                    pos = abs_start + word.len();
655                }
656            }
657            result
658        };
659
660        // Find potential antecedents (capitalized words, likely names)
661        let antecedents: Vec<&(String, usize, usize)> = words
662            .iter()
663            .filter(|(w, _, _)| {
664                w.chars().next().map(|c| c.is_uppercase()).unwrap_or(false)
665                    && !pronouns.contains(&w.to_lowercase().as_str())
666            })
667            .collect();
668
669        // Find pronouns
670        let pronoun_mentions: Vec<&(String, usize, usize)> = words
671            .iter()
672            .filter(|(w, _, _)| pronouns.contains(&w.to_lowercase().as_str()))
673            .collect();
674
675        // Build clusters
676        let mut clusters: Vec<CorefCluster> = Vec::new();
677        let mut assigned: HashMap<usize, u32> = HashMap::new();
678
679        for (ant_text, ant_start, ant_end) in &antecedents {
680            // Check if already assigned
681            if assigned.contains_key(ant_start) {
682                continue;
683            }
684
685            let cluster_id = clusters.len() as u32;
686            let mut mentions = vec![ant_text.clone()];
687            let mut spans = vec![(*ant_start, *ant_end)];
688
689            assigned.insert(*ant_start, cluster_id);
690
691            // Find pronouns after this antecedent that could refer to it
692            for (pro_text, pro_start, pro_end) in &pronoun_mentions {
693                if *pro_start > *ant_end && !assigned.contains_key(pro_start) {
694                    // Check gender compatibility (simplified)
695                    let compatible = match pro_text.to_lowercase().as_str() {
696                        "he" | "him" | "his" => true, // Could be anyone
697                        "she" | "her" | "hers" => true,
698                        "they" | "them" | "their" | "theirs" => true,
699                        "it" | "its" => true,
700                        _ => true,
701                    };
702
703                    if compatible {
704                        mentions.push(pro_text.clone());
705                        spans.push((*pro_start, *pro_end));
706                        assigned.insert(*pro_start, cluster_id);
707                        break; // Only link nearest pronoun
708                    }
709                }
710            }
711
712            if mentions.len() > 1 {
713                clusters.push(CorefCluster {
714                    id: cluster_id,
715                    canonical: ant_text.clone(),
716                    mentions,
717                    spans,
718                });
719            }
720        }
721
722        Ok(clusters)
723    }
724
725    /// Extract mentions from marked text.
726    #[allow(clippy::type_complexity)] // Return type is clear in context
727    fn extract_mentions(&self, marked_text: &str) -> Result<(String, Vec<(String, usize, usize)>)> {
728        let mut plain_text = String::new();
729        let mut mentions = Vec::new();
730        let mut offset = 0;
731
732        let mut remaining = marked_text;
733        while !remaining.is_empty() {
734            if let Some(start_pos) = remaining.find("<m>") {
735                // Add text before marker
736                plain_text.push_str(&remaining[..start_pos]);
737                offset += start_pos;
738
739                // Find end marker
740                let after_start = &remaining[start_pos + 3..];
741                if let Some(end_pos) = after_start.find("</m>") {
742                    let mention_text = after_start[..end_pos].trim();
743                    let mention_start = offset;
744                    plain_text.push_str(mention_text);
745                    let mention_end = offset + mention_text.len();
746                    offset = mention_end;
747
748                    mentions.push((mention_text.to_string(), mention_start, mention_end));
749
750                    remaining = &after_start[end_pos + 4..];
751                } else {
752                    // No end marker, add rest as-is
753                    plain_text.push_str(remaining);
754                    break;
755                }
756            } else {
757                // No more markers
758                plain_text.push_str(remaining);
759                break;
760            }
761        }
762
763        Ok((plain_text, mentions))
764    }
765
766    /// Cluster mentions by similarity.
767    fn cluster_mentions(
768        &self,
769        _text: &str,
770        mentions: &[(String, usize, usize)],
771    ) -> Result<Vec<CorefCluster>> {
772        // Simple clustering: exact match + substring match + pronoun resolution
773        let mut clusters: Vec<CorefCluster> = Vec::new();
774        let mut assigned: HashMap<usize, u32> = HashMap::new();
775
776        let pronouns = [
777            "he", "she", "they", "it", "him", "her", "them", "his", "hers", "their", "its",
778        ];
779
780        for (i, (text_i, start_i, end_i)) in mentions.iter().enumerate() {
781            if assigned.contains_key(&i) {
782                continue;
783            }
784
785            let lower_i = text_i.to_lowercase();
786            let is_pronoun_i = pronouns.contains(&lower_i.as_str());
787
788            if is_pronoun_i {
789                // Find nearest preceding non-pronoun
790                for j in (0..i).rev() {
791                    let (text_j, _, _) = &mentions[j];
792                    let lower_j = text_j.to_lowercase();
793                    if !pronouns.contains(&lower_j.as_str()) {
794                        if let Some(&cluster_id) = assigned.get(&j) {
795                            assigned.insert(i, cluster_id);
796                            clusters[cluster_id as usize].mentions.push(text_i.clone());
797                            clusters[cluster_id as usize].spans.push((*start_i, *end_i));
798                        }
799                        break;
800                    }
801                }
802                continue;
803            }
804
805            // Start new cluster
806            let cluster_id = clusters.len() as u32;
807            let mut cluster_mentions = vec![text_i.clone()];
808            let mut cluster_spans = vec![(*start_i, *end_i)];
809            assigned.insert(i, cluster_id);
810
811            // Find matches
812            for (j, (text_j, start_j, end_j)) in mentions.iter().enumerate().skip(i + 1) {
813                if assigned.contains_key(&j) {
814                    continue;
815                }
816
817                let lower_j = text_j.to_lowercase();
818
819                // Exact match
820                let matches = lower_i == lower_j
821                    // Substring match
822                    || lower_i.contains(&lower_j)
823                    || lower_j.contains(&lower_i)
824                    // Last word match (surname)
825                    || {
826                        let last_i = lower_i.split_whitespace().last();
827                        let last_j = lower_j.split_whitespace().last();
828                        last_i.is_some() && last_i == last_j && last_i.map(|w| w.len() > 2).unwrap_or(false)
829                    };
830
831                if matches {
832                    cluster_mentions.push(text_j.clone());
833                    cluster_spans.push((*start_j, *end_j));
834                    assigned.insert(j, cluster_id);
835                }
836            }
837
838            // Determine canonical (longest mention)
839            let canonical = cluster_mentions
840                .iter()
841                .max_by_key(|m| m.len())
842                .cloned()
843                .unwrap_or_else(|| text_i.clone());
844
845            clusters.push(CorefCluster {
846                id: cluster_id,
847                mentions: cluster_mentions,
848                spans: cluster_spans,
849                canonical,
850            });
851        }
852
853        // Filter to only multi-mention clusters
854        let multi_clusters: Vec<CorefCluster> = clusters
855            .into_iter()
856            .filter(|c| c.mentions.len() > 1)
857            .collect();
858
859        Ok(multi_clusters)
860    }
861
    /// Returns the model path stored on this instance, borrowed from `self`.
    ///
    /// Clone the returned `&str` if an owned value is needed.
    // NOTE(review): presumably the directory passed to `from_path`; confirm at the constructor.
    pub fn model_path(&self) -> &str {
        &self.model_path
    }
866}
867
868// =============================================================================
869// Tests
870// =============================================================================
871
#[cfg(test)]
mod tests {
    use super::*;

    // Building a real `T5Coref` needs downloaded ONNX model files, which is too
    // costly for unit tests; model-backed paths are covered by integration tests.

    #[test]
    fn test_coref_config_default() {
        let T5CorefConfig {
            max_input_length,
            num_beams,
            ..
        } = T5CorefConfig::default();

        assert_eq!(max_input_length, 512);
        assert_eq!(num_beams, 1);
    }

    #[test]
    fn test_cluster_struct() {
        let head = String::from("Marie Curie");
        let cluster = CorefCluster {
            id: 0,
            mentions: vec![head.clone(), String::from("She")],
            spans: vec![(0, 11), (50, 53)],
            canonical: head,
        };

        let CorefCluster {
            mentions,
            canonical,
            ..
        } = cluster;
        assert_eq!(mentions.len(), 2);
        assert_eq!(canonical.as_str(), "Marie Curie");
    }
}