1use crate::{Entity, Error, Result};
87
/// Plain (marker-stripped) text paired with the `(mention, start, end)`
/// byte spans extracted from it.
type MentionList = (String, Vec<(String, usize, usize)>);
90use ndarray::{Array2, Array3};
91use std::collections::HashMap;
92use std::sync::Arc;
93
94use hf_hub::api::sync::Api;
95use ort::{
96 execution_providers::CPUExecutionProvider, session::builder::GraphOptimizationLevel,
97 session::Session,
98};
99use tokenizers::Tokenizer;
100
/// A single coreference cluster: a set of textual mentions that all refer
/// to the same underlying entity.
#[derive(Debug, Clone)]
pub struct CorefCluster {
    /// Identifier of this cluster within one resolution result.
    pub id: u32,
    /// Surface forms of every mention, in discovery order.
    pub mentions: Vec<String>,
    /// (start, end) offsets of each mention, parallel to `mentions`.
    /// NOTE(review): producers disagree — `extract_t5_mentions` records
    /// byte offsets while `mark_entity_spans` consumes char indices;
    /// confirm which space each call site expects.
    pub spans: Vec<(usize, usize)>,
    /// Representative mention for the cluster (the longest mention in the
    /// heuristic code paths).
    pub canonical: String,
}
113
/// Tuning knobs for [`T5Coref`] ONNX inference.
#[derive(Debug, Clone)]
pub struct T5CorefConfig {
    /// Maximum number of input tokens; longer inputs are right-truncated.
    pub max_input_length: usize,
    /// Maximum number of tokens generated by the decoder loop.
    pub max_output_length: usize,
    /// Beam width. NOTE(review): only greedy decoding is implemented in
    /// this file, so values other than 1 currently have no effect.
    pub num_beams: usize,
    /// ONNX Runtime graph optimization level (1, 2; anything else => 3).
    pub optimization_level: u8,
    /// Intra-op thread count for the ONNX sessions.
    pub num_threads: usize,
}
128
129impl Default for T5CorefConfig {
130 fn default() -> Self {
131 Self {
132 max_input_length: 512,
133 max_output_length: 512,
134 num_beams: 1, optimization_level: 3,
136 num_threads: 4,
137 }
138 }
139}
140
/// Coreference resolver backed by an ONNX-exported T5 encoder/decoder pair.
///
/// Falls back to lightweight string heuristics when model inference fails
/// or produces no clusters.
pub struct T5Coref {
    // Encoder session; behind a Mutex because running it requires a
    // mutable binding (see `run_encoder`).
    encoder: crate::sync::Mutex<Session>,
    // Decoder session, locked independently of the encoder.
    decoder: crate::sync::Mutex<Session>,
    // Shared tokenizer used both to encode input and decode output ids.
    tokenizer: Arc<Tokenizer>,
    // Inference configuration captured at construction time.
    config: T5CorefConfig,
    // Local directory or HF model id the sessions were loaded from.
    model_path: String,
}
161
162impl T5Coref {
163 pub fn from_path(model_path: &str, config: T5CorefConfig) -> Result<Self> {
169 let encoder_path = format!("{}/encoder_model.onnx", model_path);
170 let decoder_path = format!("{}/decoder_model.onnx", model_path);
171 let tokenizer_path = format!("{}/tokenizer.json", model_path);
172
173 if !std::path::Path::new(&encoder_path).exists() {
175 return Err(Error::Retrieval(format!(
176 "Encoder not found at {}. Export with: optimum-cli export onnx --model <model> --task text2text-generation-with-past {}",
177 encoder_path, model_path
178 )));
179 }
180
181 let get_opt_level = || match config.optimization_level {
183 1 => GraphOptimizationLevel::Level1,
184 2 => GraphOptimizationLevel::Level2,
185 _ => GraphOptimizationLevel::Level3,
186 };
187
188 let encoder = Session::builder()
190 .map_err(|e| Error::Retrieval(format!("Encoder builder: {}", e)))?
191 .with_optimization_level(get_opt_level())
192 .map_err(|e| Error::Retrieval(format!("Encoder opt: {}", e)))?
193 .with_execution_providers([CPUExecutionProvider::default().build()])
194 .map_err(|e| Error::Retrieval(format!("Encoder provider: {}", e)))?
195 .with_intra_threads(config.num_threads)
196 .map_err(|e| Error::Retrieval(format!("Encoder threads: {}", e)))?
197 .commit_from_file(&encoder_path)
198 .map_err(|e| Error::Retrieval(format!("Encoder load: {}", e)))?;
199
200 let decoder = Session::builder()
202 .map_err(|e| Error::Retrieval(format!("Decoder builder: {}", e)))?
203 .with_optimization_level(get_opt_level())
204 .map_err(|e| Error::Retrieval(format!("Decoder opt: {}", e)))?
205 .with_execution_providers([CPUExecutionProvider::default().build()])
206 .map_err(|e| Error::Retrieval(format!("Decoder provider: {}", e)))?
207 .with_intra_threads(config.num_threads)
208 .map_err(|e| Error::Retrieval(format!("Decoder threads: {}", e)))?
209 .commit_from_file(&decoder_path)
210 .map_err(|e| Error::Retrieval(format!("Decoder load: {}", e)))?;
211
212 let tokenizer = Tokenizer::from_file(&tokenizer_path)
214 .map_err(|e| Error::Retrieval(format!("Tokenizer: {}", e)))?;
215
216 log::info!("[T5-Coref] Loaded model from {}", model_path);
217
218 Ok(Self {
219 encoder: crate::sync::Mutex::new(encoder),
220 decoder: crate::sync::Mutex::new(decoder),
221 tokenizer: Arc::new(tokenizer),
222 config,
223 model_path: model_path.to_string(),
224 })
225 }
226
    /// Download and load a model from the HuggingFace Hub with the
    /// default [`T5CorefConfig`].
    pub fn from_pretrained(model_id: &str) -> Result<Self> {
        Self::from_pretrained_with_config(model_id, T5CorefConfig::default())
    }
233
234 pub fn from_pretrained_with_config(model_id: &str, config: T5CorefConfig) -> Result<Self> {
236 let api = Api::new().map_err(|e| Error::Retrieval(format!("HuggingFace API: {}", e)))?;
237
238 let repo = api.model(model_id.to_string());
239
240 let encoder_path = repo
242 .get("encoder_model.onnx")
243 .or_else(|_| repo.get("onnx/encoder_model.onnx"))
244 .map_err(|e| Error::Retrieval(format!("Encoder download: {}", e)))?;
245
246 let decoder_path = repo
247 .get("decoder_model.onnx")
248 .or_else(|_| repo.get("onnx/decoder_model.onnx"))
249 .or_else(|_| repo.get("decoder_with_past_model.onnx"))
250 .map_err(|e| Error::Retrieval(format!("Decoder download: {}", e)))?;
251
252 let tokenizer_path = repo
253 .get("tokenizer.json")
254 .map_err(|e| Error::Retrieval(format!("Tokenizer download: {}", e)))?;
255
256 let get_opt_level = || match config.optimization_level {
258 1 => GraphOptimizationLevel::Level1,
259 2 => GraphOptimizationLevel::Level2,
260 _ => GraphOptimizationLevel::Level3,
261 };
262
263 let encoder = Session::builder()
265 .map_err(|e| Error::Retrieval(format!("Encoder builder: {}", e)))?
266 .with_optimization_level(get_opt_level())
267 .map_err(|e| Error::Retrieval(format!("Encoder opt: {}", e)))?
268 .with_execution_providers([CPUExecutionProvider::default().build()])
269 .map_err(|e| Error::Retrieval(format!("Encoder provider: {}", e)))?
270 .commit_from_file(&encoder_path)
271 .map_err(|e| Error::Retrieval(format!("Encoder load: {}", e)))?;
272
273 let decoder = Session::builder()
275 .map_err(|e| Error::Retrieval(format!("Decoder builder: {}", e)))?
276 .with_optimization_level(get_opt_level())
277 .map_err(|e| Error::Retrieval(format!("Decoder opt: {}", e)))?
278 .with_execution_providers([CPUExecutionProvider::default().build()])
279 .map_err(|e| Error::Retrieval(format!("Decoder provider: {}", e)))?
280 .commit_from_file(&decoder_path)
281 .map_err(|e| Error::Retrieval(format!("Decoder load: {}", e)))?;
282
283 let tokenizer = Tokenizer::from_file(&tokenizer_path)
285 .map_err(|e| Error::Retrieval(format!("Tokenizer: {}", e)))?;
286
287 log::info!("[T5-Coref] Loaded model from {}", model_id);
288
289 Ok(Self {
290 encoder: crate::sync::Mutex::new(encoder),
291 decoder: crate::sync::Mutex::new(decoder),
292 tokenizer: Arc::new(tokenizer),
293 config,
294 model_path: model_id.to_string(),
295 })
296 }
297
298 pub fn resolve(&self, text: &str) -> Result<Vec<CorefCluster>> {
304 if text.is_empty() {
305 return Ok(vec![]);
306 }
307 match self.resolve_t5(text) {
308 Ok(clusters) if !clusters.is_empty() => Ok(clusters),
309 Ok(_) => {
310 log::debug!("[T5-Coref] inference produced no clusters, using heuristic fallback");
311 self.resolve_simple(text)
312 }
313 Err(e) => {
314 log::warn!(
315 "[T5-Coref] inference failed ({}), using heuristic fallback",
316 e
317 );
318 self.resolve_simple(text)
319 }
320 }
321 }
322
323 fn resolve_t5(&self, text: &str) -> Result<Vec<CorefCluster>> {
329 let marked = self.mark_mentions(text);
330 let (input_ids, attention_mask) = self.tokenize_input(&marked)?;
331 let (enc_hidden, enc_seq_len, hidden_size) =
332 self.run_encoder(&input_ids, &attention_mask)?;
333 let output_ids =
334 self.greedy_decode(&enc_hidden, enc_seq_len, hidden_size, &attention_mask)?;
335 let decoded = self.decode_tokens(&output_ids)?;
336 Ok(self.parse_coref_output(&decoded))
337 }
338
    /// Wrap likely mentions (capitalized words and pronouns) in
    /// `<m> ... </m>` markers; see [`mark_mentions_for_t5`].
    fn mark_mentions(&self, text: &str) -> String {
        mark_mentions_for_t5(text)
    }
344
    /// Tokenize `text`, right-truncate to `max_input_length`, and return
    /// `(input_ids, attention_mask)` as `i64` vectors ready for ONNX.
    fn tokenize_input(&self, text: &str) -> Result<(Vec<i64>, Vec<i64>)> {
        let mut enc = self
            .tokenizer
            .encode(text, true)
            .map_err(|e| Error::Parse(format!("T5Coref tokenizer encode: {e}")))?;
        // stride 0: no overlapping windows, just a hard cut on the right.
        enc.truncate(
            self.config.max_input_length,
            0,
            tokenizers::TruncationDirection::Right,
        );
        let input_ids: Vec<i64> = enc.get_ids().iter().map(|&x| x as i64).collect();
        let attention_mask: Vec<i64> = enc.get_attention_mask().iter().map(|&x| x as i64).collect();
        Ok((input_ids, attention_mask))
    }
363
364 fn run_encoder(
366 &self,
367 input_ids: &[i64],
368 attention_mask: &[i64],
369 ) -> Result<(Vec<f32>, usize, usize)> {
370 let batch = 1usize;
371 let seq_len = input_ids.len();
372
373 let ids_arr = Array2::<i64>::from_shape_vec((batch, seq_len), input_ids.to_vec())
374 .map_err(|e| Error::Parse(format!("encoder ids shape: {e}")))?;
375 let mask_arr = Array2::<i64>::from_shape_vec((batch, seq_len), attention_mask.to_vec())
376 .map_err(|e| Error::Parse(format!("encoder mask shape: {e}")))?;
377
378 let ids_t = super::ort_compat::tensor_from_ndarray(ids_arr)
379 .map_err(|e| Error::Parse(format!("encoder ids tensor: {e}")))?;
380 let mask_t = super::ort_compat::tensor_from_ndarray(mask_arr)
381 .map_err(|e| Error::Parse(format!("encoder mask tensor: {e}")))?;
382
383 let (hidden_flat, hidden_size) = {
386 let mut enc = crate::sync::lock(&self.encoder);
387 let outputs = enc
388 .run(ort::inputs![
389 "input_ids" => ids_t.into_dyn(),
390 "attention_mask" => mask_t.into_dyn(),
391 ])
392 .map_err(|e| Error::Parse(format!("T5Coref encoder run: {e}")))?;
393 let hidden_val = outputs.get("last_hidden_state").ok_or_else(|| {
394 Error::Parse(
395 "T5 encoder output 'last_hidden_state' not found; check ONNX export".into(),
396 )
397 })?;
398 let (shape, data) = hidden_val
399 .try_extract_tensor::<f32>()
400 .map_err(|e| Error::Parse(format!("encoder extract tensor: {e}")))?;
401 if shape.len() != 3 || shape[0] != 1 {
402 return Err(Error::Parse(format!(
403 "T5 encoder: unexpected hidden-state shape {:?}",
404 shape
405 )));
406 }
407 (data.to_vec(), shape[2] as usize)
408 }; Ok((hidden_flat, seq_len, hidden_size))
410 }
411
    /// Run one decoder forward pass and greedily pick the next token id.
    ///
    /// `decoder_input_ids` is the full prefix generated so far — there is
    /// no KV cache, so every step re-runs the whole prefix.
    fn decoder_step(
        &self,
        encoder_hidden: &[f32],
        enc_seq_len: usize,
        hidden_size: usize,
        attention_mask: &[i64],
        decoder_input_ids: &[i64],
    ) -> Result<i64> {
        let batch = 1usize;
        let dec_len = decoder_input_ids.len();

        // Rebuild the [1, enc_seq_len, hidden] encoder-output tensor from
        // the flattened buffer produced by `run_encoder`.
        let enc_h = Array3::<f32>::from_shape_vec(
            (batch, enc_seq_len, hidden_size),
            encoder_hidden.to_vec(),
        )
        .map_err(|e| Error::Parse(format!("decoder enc_hidden shape: {e}")))?;
        let attn = Array2::<i64>::from_shape_vec((batch, enc_seq_len), attention_mask.to_vec())
            .map_err(|e| Error::Parse(format!("decoder attn shape: {e}")))?;
        let dec_ids = Array2::<i64>::from_shape_vec((batch, dec_len), decoder_input_ids.to_vec())
            .map_err(|e| Error::Parse(format!("decoder_ids shape: {e}")))?;

        let enc_h_t = super::ort_compat::tensor_from_ndarray(enc_h)
            .map_err(|e| Error::Parse(format!("enc_h tensor: {e}")))?;
        let attn_t = super::ort_compat::tensor_from_ndarray(attn)
            .map_err(|e| Error::Parse(format!("attn tensor: {e}")))?;
        let dec_ids_t = super::ort_compat::tensor_from_ndarray(dec_ids)
            .map_err(|e| Error::Parse(format!("dec_ids tensor: {e}")))?;

        // Scope the decoder lock to the ONNX call and logits extraction.
        let next_token = {
            let mut dec = crate::sync::lock(&self.decoder);
            let outputs = dec
                .run(ort::inputs![
                    "encoder_hidden_states" => enc_h_t.into_dyn(),
                    "attention_mask" => attn_t.into_dyn(),
                    "decoder_input_ids" => dec_ids_t.into_dyn(),
                ])
                .map_err(|e| Error::Parse(format!("T5Coref decoder run: {e}")))?;
            let logits_val = outputs.get("logits").ok_or_else(|| {
                Error::Parse("T5 decoder output 'logits' not found; check ONNX export".into())
            })?;
            let (shape, logits_data) = logits_val
                .try_extract_tensor::<f32>()
                .map_err(|e| Error::Parse(format!("decoder logits extract: {e}")))?;
            if shape.len() != 3 || shape[0] != 1 {
                return Err(Error::Parse(format!(
                    "T5 decoder: unexpected logits shape {:?}",
                    shape
                )));
            }
            let vocab_size = shape[2] as usize;
            // Argmax over the logits of the LAST prefix position only.
            // NOTE(review): assumes dec_len >= 1 — `greedy_decode` always
            // seeds the prefix with the pad token, but `dec_len - 1` would
            // underflow on an empty prefix; verify no other caller exists.
            let last_offset = (dec_len - 1) * vocab_size;
            let last_logits = &logits_data[last_offset..last_offset + vocab_size];
            last_logits
                .iter()
                .enumerate()
                // NaN-safe comparison: NaN pairs compare Equal, so max_by
                // still returns a deterministic winner.
                .max_by(|a, b| a.1.partial_cmp(b.1).unwrap_or(std::cmp::Ordering::Equal))
                .map(|(i, _)| i as i64)
                // Empty logits row: fall back to EOS (id 1).
                .unwrap_or(1) }; Ok(next_token)
    }
479
480 fn greedy_decode(
483 &self,
484 encoder_hidden: &[f32],
485 enc_seq_len: usize,
486 hidden_size: usize,
487 attention_mask: &[i64],
488 ) -> Result<Vec<i64>> {
489 const T5_PAD: i64 = 0;
491 const T5_EOS: i64 = 1;
492 let mut generated = vec![T5_PAD];
493
494 for _ in 0..self.config.max_output_length {
495 let next = self.decoder_step(
496 encoder_hidden,
497 enc_seq_len,
498 hidden_size,
499 attention_mask,
500 &generated,
501 )?;
502 if next == T5_EOS {
503 break;
504 }
505 generated.push(next);
506 }
507
508 Ok(generated[1..].to_vec()) }
510
511 fn decode_tokens(&self, token_ids: &[i64]) -> Result<String> {
513 let ids: Vec<u32> = token_ids.iter().map(|&x| x as u32).collect();
514 self.tokenizer
515 .decode(&ids, true)
516 .map_err(|e| Error::Parse(format!("T5Coref decode_tokens: {e}")))
517 }
518
    /// Parse the model's `mention | id` output format into clusters; see
    /// [`parse_t5_coref_output`].
    fn parse_coref_output(&self, decoded: &str) -> Vec<CorefCluster> {
        parse_t5_coref_output(decoded)
    }
531
532 pub fn resolve_marked(&self, marked_text: &str) -> Result<Vec<CorefCluster>> {
545 let (plain_text, mentions) = self.extract_mentions(marked_text)?;
546 if mentions.is_empty() {
547 return Ok(vec![]);
548 }
549 match self.resolve_t5_raw(marked_text) {
551 Ok(clusters) if !clusters.is_empty() => Ok(clusters),
552 Ok(_) => self.cluster_mentions(&plain_text, &mentions),
553 Err(e) => {
554 log::warn!(
555 "[T5-Coref] resolve_marked inference failed ({}), using fallback",
556 e
557 );
558 self.cluster_mentions(&plain_text, &mentions)
559 }
560 }
561 }
562
563 pub fn resolve_entities(&self, text: &str, entities: &[Entity]) -> Result<Vec<CorefCluster>> {
568 if entities.is_empty() {
569 return Ok(vec![]);
570 }
571
572 let marked = self.mark_entity_spans(text, entities);
574 match self.resolve_t5_raw(&marked) {
575 Ok(clusters) if !clusters.is_empty() => Ok(clusters),
576 Ok(_) => {
577 let mentions: Vec<(String, usize, usize)> = entities
578 .iter()
579 .map(|e| (e.text.clone(), e.start, e.end))
580 .collect();
581 self.cluster_mentions(text, &mentions)
582 }
583 Err(e) => {
584 log::warn!(
585 "[T5-Coref] resolve_entities inference failed ({}), using fallback",
586 e
587 );
588 let mentions: Vec<(String, usize, usize)> = entities
589 .iter()
590 .map(|e| (e.text.clone(), e.start, e.end))
591 .collect();
592 self.cluster_mentions(text, &mentions)
593 }
594 }
595 }
596
    /// Shared tokenize -> encode -> greedy-decode -> detokenize -> parse
    /// pipeline over text that already contains mention markers.
    fn resolve_t5_raw(&self, marked_text: &str) -> Result<Vec<CorefCluster>> {
        let (input_ids, attention_mask) = self.tokenize_input(marked_text)?;
        let (enc_hidden, enc_seq_len, hidden_size) =
            self.run_encoder(&input_ids, &attention_mask)?;
        let output_ids =
            self.greedy_decode(&enc_hidden, enc_seq_len, hidden_size, &attention_mask)?;
        let decoded = self.decode_tokens(&output_ids)?;
        Ok(self.parse_coref_output(&decoded))
    }
610
611 fn mark_entity_spans(&self, text: &str, entities: &[Entity]) -> String {
615 let chars: Vec<char> = text.chars().collect();
616 let char_len = chars.len();
617
618 let mut sorted: Vec<&Entity> = entities.iter().collect();
619 sorted.sort_by_key(|e| e.start);
620
621 let mut out = String::with_capacity(text.len() + entities.len() * 10);
622 let mut cursor = 0usize; for e in &sorted {
625 if e.start >= e.end || e.start < cursor || e.end > char_len {
626 continue;
627 }
628 for &ch in &chars[cursor..e.start] {
630 out.push(ch);
631 }
632 out.push_str("<m> ");
633 for &ch in &chars[e.start..e.end] {
634 out.push(ch);
635 }
636 out.push_str(" </m>");
637 cursor = e.end;
638 }
639
640 for &ch in &chars[cursor..] {
642 out.push(ch);
643 }
644 out
645 }
646
647 fn resolve_simple(&self, text: &str) -> Result<Vec<CorefCluster>> {
649 let pronouns = ["he", "she", "they", "it", "his", "her", "their", "its"];
651
652 let words: Vec<(String, usize, usize)> = {
653 let mut result = Vec::new();
654 let mut pos = 0;
655 for word in text.split_whitespace() {
656 if let Some(start) = text[pos..].find(word) {
657 let abs_start = pos + start;
658 result.push((word.to_string(), abs_start, abs_start + word.len()));
659 pos = abs_start + word.len();
660 }
661 }
662 result
663 };
664
665 let antecedents: Vec<&(String, usize, usize)> = words
667 .iter()
668 .filter(|(w, _, _)| {
669 w.chars().next().map(|c| c.is_uppercase()).unwrap_or(false)
670 && !pronouns.contains(&w.to_lowercase().as_str())
671 })
672 .collect();
673
674 let pronoun_mentions: Vec<&(String, usize, usize)> = words
676 .iter()
677 .filter(|(w, _, _)| pronouns.contains(&w.to_lowercase().as_str()))
678 .collect();
679
680 let mut clusters: Vec<CorefCluster> = Vec::new();
682 let mut assigned: HashMap<usize, u32> = HashMap::new();
683
684 for (ant_text, ant_start, ant_end) in &antecedents {
685 if assigned.contains_key(ant_start) {
687 continue;
688 }
689
690 let cluster_id = clusters.len() as u32;
691 let mut mentions = vec![ant_text.clone()];
692 let mut spans = vec![(*ant_start, *ant_end)];
693
694 assigned.insert(*ant_start, cluster_id);
695
696 for (pro_text, pro_start, pro_end) in &pronoun_mentions {
698 if *pro_start > *ant_end && !assigned.contains_key(pro_start) {
699 let compatible = match pro_text.to_lowercase().as_str() {
701 "he" | "him" | "his" => true, "she" | "her" | "hers" => true,
703 "they" | "them" | "their" | "theirs" => true,
704 "it" | "its" => true,
705 _ => true,
706 };
707
708 if compatible {
709 mentions.push(pro_text.clone());
710 spans.push((*pro_start, *pro_end));
711 assigned.insert(*pro_start, cluster_id);
712 break; }
714 }
715 }
716
717 if mentions.len() > 1 {
718 clusters.push(CorefCluster {
719 id: cluster_id,
720 canonical: ant_text.clone(),
721 mentions,
722 spans,
723 });
724 }
725 }
726
727 Ok(clusters)
728 }
729
    /// Strip `<m> ... </m>` markers, returning the plain text plus the
    /// mention list; see [`extract_t5_mentions`].
    fn extract_mentions(&self, marked_text: &str) -> Result<MentionList> {
        extract_t5_mentions(marked_text)
    }
733
734 fn cluster_mentions(
736 &self,
737 _text: &str,
738 mentions: &[(String, usize, usize)],
739 ) -> Result<Vec<CorefCluster>> {
740 let mut clusters: Vec<CorefCluster> = Vec::new();
742 let mut assigned: HashMap<usize, u32> = HashMap::new();
743
744 let pronouns = [
745 "he", "she", "they", "it", "him", "her", "them", "his", "hers", "their", "its",
746 ];
747
748 for (i, (text_i, start_i, end_i)) in mentions.iter().enumerate() {
749 if assigned.contains_key(&i) {
750 continue;
751 }
752
753 let lower_i = text_i.to_lowercase();
754 let is_pronoun_i = pronouns.contains(&lower_i.as_str());
755
756 if is_pronoun_i {
757 for j in (0..i).rev() {
759 let (text_j, _, _) = &mentions[j];
760 let lower_j = text_j.to_lowercase();
761 if !pronouns.contains(&lower_j.as_str()) {
762 if let Some(&cluster_id) = assigned.get(&j) {
763 assigned.insert(i, cluster_id);
764 clusters[cluster_id as usize].mentions.push(text_i.clone());
765 clusters[cluster_id as usize].spans.push((*start_i, *end_i));
766 }
767 break;
768 }
769 }
770 continue;
771 }
772
773 let cluster_id = clusters.len() as u32;
775 let mut cluster_mentions = vec![text_i.clone()];
776 let mut cluster_spans = vec![(*start_i, *end_i)];
777 assigned.insert(i, cluster_id);
778
779 for (j, (text_j, start_j, end_j)) in mentions.iter().enumerate().skip(i + 1) {
781 if assigned.contains_key(&j) {
782 continue;
783 }
784
785 let lower_j = text_j.to_lowercase();
786
787 let matches = lower_i == lower_j
789 || lower_i.contains(&lower_j)
791 || lower_j.contains(&lower_i)
792 || {
794 let last_i = lower_i.split_whitespace().last();
795 let last_j = lower_j.split_whitespace().last();
796 last_i.is_some() && last_i == last_j && last_i.map(|w| w.len() > 2).unwrap_or(false)
797 };
798
799 if matches {
800 cluster_mentions.push(text_j.clone());
801 cluster_spans.push((*start_j, *end_j));
802 assigned.insert(j, cluster_id);
803 }
804 }
805
806 let canonical = cluster_mentions
808 .iter()
809 .max_by_key(|m| m.len())
810 .cloned()
811 .unwrap_or_else(|| text_i.clone());
812
813 clusters.push(CorefCluster {
814 id: cluster_id,
815 mentions: cluster_mentions,
816 spans: cluster_spans,
817 canonical,
818 });
819 }
820
821 let multi_clusters: Vec<CorefCluster> = clusters
823 .into_iter()
824 .filter(|c| c.mentions.len() > 1)
825 .collect();
826
827 Ok(multi_clusters)
828 }
829
    /// Source the model was loaded from: a local directory (via
    /// [`Self::from_path`]) or a HuggingFace model id.
    pub fn model_path(&self) -> &str {
        &self.model_path
    }
834}
835
/// Wrap likely mentions in `<m> ... </m>` markers for T5 input.
///
/// A whitespace-separated word is marked when it is a personal pronoun
/// (checked after stripping surrounding punctuation) or starts with an
/// uppercase character. Output words are single-space joined, so runs of
/// whitespace in the input are collapsed.
pub fn mark_mentions_for_t5(text: &str) -> String {
    const PRONOUNS: &[&str] = &[
        "he", "she", "they", "it", "him", "her", "them", "his", "hers", "their", "its",
    ];

    let marked_words: Vec<String> = text
        .split_whitespace()
        .map(|word| {
            // Compare the alphabetic core, lowercased, against the list.
            let core = word
                .trim_matches(|c: char| !c.is_alphabetic())
                .to_lowercase();
            let pronoun = PRONOUNS.contains(&core.as_str());
            let capitalized = word
                .chars()
                .next()
                .map(|c| c.is_uppercase())
                .unwrap_or(false);
            if pronoun || capitalized {
                format!("<m> {} </m>", word)
            } else {
                word.to_string()
            }
        })
        .collect();

    marked_words.join(" ")
}
872
/// Parse decoder output of the form `mention | <id> mention | <id> …`
/// into coreference clusters.
///
/// A token counts as a mention when the next token is a literal `|` and
/// the token after that contains an integer cluster id. Singleton
/// clusters are dropped; each surviving cluster's canonical form is its
/// longest mention, and results are sorted by cluster id.
pub fn parse_t5_coref_output(decoded: &str) -> Vec<CorefCluster> {
    let mut clusters: HashMap<u32, CorefCluster> = HashMap::new();
    let tokens: Vec<&str> = decoded.split_whitespace().collect();
    // Running character offset used to synthesize mention spans.
    // NOTE(review): it advances by `tok.len() + 1` for the mention token
    // only (the consumed "|" and id tokens are skipped), so spans are
    // positions in an annotation-stripped, single-space-joined rendering
    // of the output — confirm downstream consumers expect that.
    let mut offset: usize = 0;
    let mut i = 0;

    while i < tokens.len() {
        let tok = tokens[i];
        let is_pipe = tokens.get(i + 1).map(|&t| t == "|").unwrap_or(false);
        // Strip any non-digit decoration around the id before parsing.
        let cluster_id: Option<u32> = if is_pipe {
            tokens
                .get(i + 2)
                .and_then(|t| t.trim_matches(|c: char| !c.is_ascii_digit()).parse().ok())
        } else {
            None
        };

        if let Some(cid) = cluster_id {
            // Punctuation is trimmed from the stored surface form.
            let mention = tok.trim_matches(|c: char| !c.is_alphanumeric()).to_string();
            if !mention.is_empty() {
                let start = offset;
                let end = offset + mention.len();
                let entry = clusters.entry(cid).or_insert_with(|| CorefCluster {
                    id: cid,
                    mentions: Vec::new(),
                    spans: Vec::new(),
                    canonical: String::new(),
                });
                entry.mentions.push(mention);
                entry.spans.push((start, end));
            }
            offset += tok.len() + 1;
            // Consume the whole "mention | id" triple.
            i += 3;
            continue;
        }

        offset += tok.len() + 1;
        i += 1;
    }

    // Drop singletons; pick the longest mention as canonical.
    let mut result: Vec<CorefCluster> = clusters
        .into_values()
        .filter(|c| c.mentions.len() > 1)
        .collect();
    for c in &mut result {
        c.canonical = c
            .mentions
            .iter()
            .max_by_key(|m| m.len())
            .cloned()
            .unwrap_or_default();
    }
    result.sort_by_key(|c| c.id);
    result
}
932
933pub fn extract_t5_mentions(marked_text: &str) -> Result<MentionList> {
939 let mut plain_text = String::new();
940 let mut mentions = Vec::new();
941 let mut offset = 0;
942
943 let mut remaining = marked_text;
944 while !remaining.is_empty() {
945 if let Some(start_pos) = remaining.find("<m>") {
946 plain_text.push_str(&remaining[..start_pos]);
947 offset += start_pos;
948
949 let after_start = &remaining[start_pos + 3..];
950 if let Some(end_pos) = after_start.find("</m>") {
951 let mention_text = after_start[..end_pos].trim();
952 let mention_start = offset;
953 plain_text.push_str(mention_text);
954 let mention_end = offset + mention_text.len();
955 offset = mention_end;
956
957 mentions.push((mention_text.to_string(), mention_start, mention_end));
958 remaining = &after_start[end_pos + 4..];
959 } else {
960 plain_text.push_str(remaining);
961 break;
962 }
963 } else {
964 plain_text.push_str(remaining);
965 break;
966 }
967 }
968
969 Ok((plain_text, mentions))
970}
971
972#[cfg(test)]
977mod tests;