1use crate::{Entity, Error, Result};
87use ndarray::{Array2, Array3};
88use std::collections::HashMap;
89use std::sync::Arc;
90
91use hf_hub::api::sync::Api;
92use ort::{
93 execution_providers::CPUExecutionProvider, session::builder::GraphOptimizationLevel,
94 session::Session,
95};
96use tokenizers::Tokenizer;
97
/// A coreference cluster: all textual mentions resolved to the same entity.
#[derive(Debug, Clone)]
pub struct CorefCluster {
    /// Cluster identifier, stable within a single resolution pass.
    pub id: u32,
    /// Mention surface forms, in document order.
    pub mentions: Vec<String>,
    /// `(start, end)` offsets for each mention, parallel to `mentions`.
    /// NOTE(review): the offset basis differs by resolution path — heuristic
    /// paths use byte offsets into the input text, while the T5 path uses
    /// offsets into a reconstruction of the decoded output; confirm before
    /// slicing the original text with these.
    pub spans: Vec<(usize, usize)>,
    /// Representative mention for the cluster (the longest mention form).
    pub canonical: String,
}
110
/// Configuration for [`T5Coref`] ONNX inference.
#[derive(Debug, Clone)]
pub struct T5CorefConfig {
    /// Maximum number of input tokens; longer inputs are truncated (right side).
    pub max_input_length: usize,
    /// Maximum number of tokens to generate during decoding.
    pub max_output_length: usize,
    /// Beam width. NOTE(review): not currently consulted — decoding is always
    /// greedy; kept for forward/API compatibility.
    pub num_beams: usize,
    /// ONNX graph optimization level: 1 and 2 map to Level1/Level2; any other
    /// value maps to Level3.
    pub optimization_level: u8,
    /// Intra-op thread count for the ONNX sessions.
    pub num_threads: usize,
}
125
126impl Default for T5CorefConfig {
127 fn default() -> Self {
128 Self {
129 max_input_length: 512,
130 max_output_length: 512,
131 num_beams: 1, optimization_level: 3,
133 num_threads: 4,
134 }
135 }
136}
137
/// T5-based coreference resolver backed by ONNX Runtime encoder/decoder
/// sessions, with a heuristic string-matching fallback when inference fails
/// or yields no clusters.
pub struct T5Coref {
    // Encoder session; `Session::run` needs `&mut self`, hence the mutex.
    encoder: crate::sync::Mutex<Session>,
    // Decoder session, guarded for the same reason as `encoder`.
    decoder: crate::sync::Mutex<Session>,
    // Shared tokenizer; read-only, so an `Arc` without a lock suffices.
    tokenizer: Arc<Tokenizer>,
    // Inference settings captured at construction time.
    config: T5CorefConfig,
    // Local directory path or HF model id this resolver was loaded from.
    model_path: String,
}
158
159impl T5Coref {
160 pub fn from_path(model_path: &str, config: T5CorefConfig) -> Result<Self> {
166 let encoder_path = format!("{}/encoder_model.onnx", model_path);
167 let decoder_path = format!("{}/decoder_model.onnx", model_path);
168 let tokenizer_path = format!("{}/tokenizer.json", model_path);
169
170 if !std::path::Path::new(&encoder_path).exists() {
172 return Err(Error::Retrieval(format!(
173 "Encoder not found at {}. Export with: optimum-cli export onnx --model <model> --task text2text-generation-with-past {}",
174 encoder_path, model_path
175 )));
176 }
177
178 let get_opt_level = || match config.optimization_level {
180 1 => GraphOptimizationLevel::Level1,
181 2 => GraphOptimizationLevel::Level2,
182 _ => GraphOptimizationLevel::Level3,
183 };
184
185 let encoder = Session::builder()
187 .map_err(|e| Error::Retrieval(format!("Encoder builder: {}", e)))?
188 .with_optimization_level(get_opt_level())
189 .map_err(|e| Error::Retrieval(format!("Encoder opt: {}", e)))?
190 .with_execution_providers([CPUExecutionProvider::default().build()])
191 .map_err(|e| Error::Retrieval(format!("Encoder provider: {}", e)))?
192 .with_intra_threads(config.num_threads)
193 .map_err(|e| Error::Retrieval(format!("Encoder threads: {}", e)))?
194 .commit_from_file(&encoder_path)
195 .map_err(|e| Error::Retrieval(format!("Encoder load: {}", e)))?;
196
197 let decoder = Session::builder()
199 .map_err(|e| Error::Retrieval(format!("Decoder builder: {}", e)))?
200 .with_optimization_level(get_opt_level())
201 .map_err(|e| Error::Retrieval(format!("Decoder opt: {}", e)))?
202 .with_execution_providers([CPUExecutionProvider::default().build()])
203 .map_err(|e| Error::Retrieval(format!("Decoder provider: {}", e)))?
204 .with_intra_threads(config.num_threads)
205 .map_err(|e| Error::Retrieval(format!("Decoder threads: {}", e)))?
206 .commit_from_file(&decoder_path)
207 .map_err(|e| Error::Retrieval(format!("Decoder load: {}", e)))?;
208
209 let tokenizer = Tokenizer::from_file(&tokenizer_path)
211 .map_err(|e| Error::Retrieval(format!("Tokenizer: {}", e)))?;
212
213 log::info!("[T5-Coref] Loaded model from {}", model_path);
214
215 Ok(Self {
216 encoder: crate::sync::Mutex::new(encoder),
217 decoder: crate::sync::Mutex::new(decoder),
218 tokenizer: Arc::new(tokenizer),
219 config,
220 model_path: model_path.to_string(),
221 })
222 }
223
224 pub fn from_pretrained(model_id: &str) -> Result<Self> {
228 Self::from_pretrained_with_config(model_id, T5CorefConfig::default())
229 }
230
231 pub fn from_pretrained_with_config(model_id: &str, config: T5CorefConfig) -> Result<Self> {
233 let api = Api::new().map_err(|e| Error::Retrieval(format!("HuggingFace API: {}", e)))?;
234
235 let repo = api.model(model_id.to_string());
236
237 let encoder_path = repo
239 .get("encoder_model.onnx")
240 .or_else(|_| repo.get("onnx/encoder_model.onnx"))
241 .map_err(|e| Error::Retrieval(format!("Encoder download: {}", e)))?;
242
243 let decoder_path = repo
244 .get("decoder_model.onnx")
245 .or_else(|_| repo.get("onnx/decoder_model.onnx"))
246 .or_else(|_| repo.get("decoder_with_past_model.onnx"))
247 .map_err(|e| Error::Retrieval(format!("Decoder download: {}", e)))?;
248
249 let tokenizer_path = repo
250 .get("tokenizer.json")
251 .map_err(|e| Error::Retrieval(format!("Tokenizer download: {}", e)))?;
252
253 let get_opt_level = || match config.optimization_level {
255 1 => GraphOptimizationLevel::Level1,
256 2 => GraphOptimizationLevel::Level2,
257 _ => GraphOptimizationLevel::Level3,
258 };
259
260 let encoder = Session::builder()
262 .map_err(|e| Error::Retrieval(format!("Encoder builder: {}", e)))?
263 .with_optimization_level(get_opt_level())
264 .map_err(|e| Error::Retrieval(format!("Encoder opt: {}", e)))?
265 .with_execution_providers([CPUExecutionProvider::default().build()])
266 .map_err(|e| Error::Retrieval(format!("Encoder provider: {}", e)))?
267 .commit_from_file(&encoder_path)
268 .map_err(|e| Error::Retrieval(format!("Encoder load: {}", e)))?;
269
270 let decoder = Session::builder()
272 .map_err(|e| Error::Retrieval(format!("Decoder builder: {}", e)))?
273 .with_optimization_level(get_opt_level())
274 .map_err(|e| Error::Retrieval(format!("Decoder opt: {}", e)))?
275 .with_execution_providers([CPUExecutionProvider::default().build()])
276 .map_err(|e| Error::Retrieval(format!("Decoder provider: {}", e)))?
277 .commit_from_file(&decoder_path)
278 .map_err(|e| Error::Retrieval(format!("Decoder load: {}", e)))?;
279
280 let tokenizer = Tokenizer::from_file(&tokenizer_path)
282 .map_err(|e| Error::Retrieval(format!("Tokenizer: {}", e)))?;
283
284 log::info!("[T5-Coref] Loaded model from {}", model_id);
285
286 Ok(Self {
287 encoder: crate::sync::Mutex::new(encoder),
288 decoder: crate::sync::Mutex::new(decoder),
289 tokenizer: Arc::new(tokenizer),
290 config,
291 model_path: model_id.to_string(),
292 })
293 }
294
295 pub fn resolve(&self, text: &str) -> Result<Vec<CorefCluster>> {
301 if text.is_empty() {
302 return Ok(vec![]);
303 }
304 match self.resolve_t5(text) {
305 Ok(clusters) if !clusters.is_empty() => Ok(clusters),
306 Ok(_) => {
307 log::debug!("[T5-Coref] inference produced no clusters, using heuristic fallback");
308 self.resolve_simple(text)
309 }
310 Err(e) => {
311 log::warn!(
312 "[T5-Coref] inference failed ({}), using heuristic fallback",
313 e
314 );
315 self.resolve_simple(text)
316 }
317 }
318 }
319
320 fn resolve_t5(&self, text: &str) -> Result<Vec<CorefCluster>> {
326 let marked = self.mark_mentions(text);
327 let (input_ids, attention_mask) = self.tokenize_input(&marked)?;
328 let (enc_hidden, enc_seq_len, hidden_size) =
329 self.run_encoder(&input_ids, &attention_mask)?;
330 let output_ids =
331 self.greedy_decode(&enc_hidden, enc_seq_len, hidden_size, &attention_mask)?;
332 let decoded = self.decode_tokens(&output_ids)?;
333 Ok(self.parse_coref_output(&decoded))
334 }
335
336 fn mark_mentions(&self, text: &str) -> String {
339 const PRONOUNS: &[&str] = &[
340 "he", "she", "they", "it", "him", "her", "them", "his", "hers", "their", "its",
341 ];
342 let mut out = String::with_capacity(text.len() + 64);
343 for (i, word) in text.split_whitespace().enumerate() {
344 if i > 0 {
345 out.push(' ');
346 }
347 let lower = word
348 .trim_matches(|c: char| !c.is_alphabetic())
349 .to_lowercase();
350 let is_pronoun = PRONOUNS.contains(&lower.as_str());
351 let is_cap = word.chars().next().map(|c| c.is_uppercase()).unwrap_or(false);
352 if is_pronoun || is_cap {
353 out.push_str("<m> ");
354 out.push_str(word);
355 out.push_str(" </m>");
356 } else {
357 out.push_str(word);
358 }
359 }
360 out
361 }
362
363 fn tokenize_input(&self, text: &str) -> Result<(Vec<i64>, Vec<i64>)> {
367 let mut enc = self
368 .tokenizer
369 .encode(text, true)
370 .map_err(|e| Error::Parse(format!("T5Coref tokenizer encode: {e}")))?;
371 enc.truncate(self.config.max_input_length, 0, tokenizers::TruncationDirection::Right);
373 let input_ids: Vec<i64> = enc.get_ids().iter().map(|&x| x as i64).collect();
374 let attention_mask: Vec<i64> = enc
375 .get_attention_mask()
376 .iter()
377 .map(|&x| x as i64)
378 .collect();
379 Ok((input_ids, attention_mask))
380 }
381
    /// Runs the encoder over a single-sequence batch.
    ///
    /// Returns `(hidden_flat, enc_seq_len, hidden_size)` where `hidden_flat`
    /// is the `[1, seq_len, hidden_size]` output flattened row-major.
    fn run_encoder(
        &self,
        input_ids: &[i64],
        attention_mask: &[i64],
    ) -> Result<(Vec<f32>, usize, usize)> {
        // Batch size is fixed to 1 throughout this module.
        let batch = 1usize;
        let seq_len = input_ids.len();

        let ids_arr = Array2::<i64>::from_shape_vec((batch, seq_len), input_ids.to_vec())
            .map_err(|e| Error::Parse(format!("encoder ids shape: {e}")))?;
        let mask_arr =
            Array2::<i64>::from_shape_vec((batch, seq_len), attention_mask.to_vec())
                .map_err(|e| Error::Parse(format!("encoder mask shape: {e}")))?;

        let ids_t = super::ort_compat::tensor_from_ndarray(ids_arr)
            .map_err(|e| Error::Parse(format!("encoder ids tensor: {e}")))?;
        let mask_t = super::ort_compat::tensor_from_ndarray(mask_arr)
            .map_err(|e| Error::Parse(format!("encoder mask tensor: {e}")))?;

        // Inner scope bounds the session lock: it is released before we
        // return, so the encoder is free for other callers.
        let (hidden_flat, hidden_size) = {
            let mut enc = crate::sync::lock(&self.encoder);
            let outputs = enc
                .run(ort::inputs![
                    "input_ids" => ids_t.into_dyn(),
                    "attention_mask" => mask_t.into_dyn(),
                ])
                .map_err(|e| Error::Parse(format!("T5Coref encoder run: {e}")))?;
            let hidden_val = outputs.get("last_hidden_state").ok_or_else(|| {
                Error::Parse(
                    "T5 encoder output 'last_hidden_state' not found; check ONNX export".into(),
                )
            })?;
            let (shape, data) = hidden_val
                .try_extract_tensor::<f32>()
                .map_err(|e| Error::Parse(format!("encoder extract tensor: {e}")))?;
            // Expect exactly [1, seq_len, hidden_size].
            if shape.len() != 3 || shape[0] != 1 {
                return Err(Error::Parse(format!(
                    "T5 encoder: unexpected hidden-state shape {:?}",
                    shape
                )));
            }
            (data.to_vec(), shape[2] as usize)
        };
        Ok((hidden_flat, seq_len, hidden_size))
    }
430
    /// Runs one decoder forward pass over the entire generated prefix and
    /// returns the argmax token id for the next position.
    ///
    /// NOTE(review): no KV cache is used — every step re-feeds the whole
    /// prefix, so decoding is quadratic in generated length. Acceptable for
    /// short outputs; consider the `with_past` decoder export if this becomes
    /// a bottleneck.
    fn decoder_step(
        &self,
        encoder_hidden: &[f32],
        enc_seq_len: usize,
        hidden_size: usize,
        attention_mask: &[i64],
        decoder_input_ids: &[i64],
    ) -> Result<i64> {
        let batch = 1usize;
        let dec_len = decoder_input_ids.len();

        // Rebuild the [1, enc_seq_len, hidden_size] encoder output tensor
        // from the flattened buffer produced by `run_encoder`.
        let enc_h =
            Array3::<f32>::from_shape_vec((batch, enc_seq_len, hidden_size), encoder_hidden.to_vec())
                .map_err(|e| Error::Parse(format!("decoder enc_hidden shape: {e}")))?;
        let attn = Array2::<i64>::from_shape_vec((batch, enc_seq_len), attention_mask.to_vec())
            .map_err(|e| Error::Parse(format!("decoder attn shape: {e}")))?;
        let dec_ids =
            Array2::<i64>::from_shape_vec((batch, dec_len), decoder_input_ids.to_vec())
                .map_err(|e| Error::Parse(format!("decoder_ids shape: {e}")))?;

        let enc_h_t = super::ort_compat::tensor_from_ndarray(enc_h)
            .map_err(|e| Error::Parse(format!("enc_h tensor: {e}")))?;
        let attn_t = super::ort_compat::tensor_from_ndarray(attn)
            .map_err(|e| Error::Parse(format!("attn tensor: {e}")))?;
        let dec_ids_t = super::ort_compat::tensor_from_ndarray(dec_ids)
            .map_err(|e| Error::Parse(format!("dec_ids tensor: {e}")))?;

        // Inner scope bounds the decoder lock to the session call.
        let next_token = {
            let mut dec = crate::sync::lock(&self.decoder);
            let outputs = dec
                .run(ort::inputs![
                    "encoder_hidden_states" => enc_h_t.into_dyn(),
                    "attention_mask" => attn_t.into_dyn(),
                    "decoder_input_ids" => dec_ids_t.into_dyn(),
                ])
                .map_err(|e| Error::Parse(format!("T5Coref decoder run: {e}")))?;
            let logits_val = outputs.get("logits").ok_or_else(|| {
                Error::Parse("T5 decoder output 'logits' not found; check ONNX export".into())
            })?;
            let (shape, logits_data) = logits_val
                .try_extract_tensor::<f32>()
                .map_err(|e| Error::Parse(format!("decoder logits extract: {e}")))?;
            // Expect logits shaped [1, dec_len, vocab_size].
            if shape.len() != 3 || shape[0] != 1 {
                return Err(Error::Parse(format!(
                    "T5 decoder: unexpected logits shape {:?}",
                    shape
                )));
            }
            let vocab_size = shape[2] as usize;
            // Greedy pick: argmax over the logits of the LAST position only.
            let last_offset = (dec_len - 1) * vocab_size;
            let last_logits = &logits_data[last_offset..last_offset + vocab_size];
            last_logits
                .iter()
                .enumerate()
                .max_by(|a, b| a.1.partial_cmp(b.1).unwrap_or(std::cmp::Ordering::Equal))
                .map(|(i, _)| i as i64)
                // Empty vocab fallback: 1 (EOS) — terminates decoding cleanly.
                .unwrap_or(1)
        };
        Ok(next_token)
    }
497
498 fn greedy_decode(
501 &self,
502 encoder_hidden: &[f32],
503 enc_seq_len: usize,
504 hidden_size: usize,
505 attention_mask: &[i64],
506 ) -> Result<Vec<i64>> {
507 const T5_PAD: i64 = 0;
509 const T5_EOS: i64 = 1;
510 let mut generated = vec![T5_PAD];
511
512 for _ in 0..self.config.max_output_length {
513 let next = self.decoder_step(
514 encoder_hidden,
515 enc_seq_len,
516 hidden_size,
517 attention_mask,
518 &generated,
519 )?;
520 if next == T5_EOS {
521 break;
522 }
523 generated.push(next);
524 }
525
526 Ok(generated[1..].to_vec()) }
528
529 fn decode_tokens(&self, token_ids: &[i64]) -> Result<String> {
531 let ids: Vec<u32> = token_ids.iter().map(|&x| x as u32).collect();
532 self.tokenizer
533 .decode(&ids, true)
534 .map_err(|e| Error::Parse(format!("T5Coref decode_tokens: {e}")))
535 }
536
    /// Parses the model's generated text into clusters.
    ///
    /// Expected annotation per mention: `word | <id>` — a mention token
    /// followed by a literal pipe token and a numeric cluster-id token.
    /// Clusters with fewer than two mentions are dropped, `canonical` is the
    /// longest mention, and results are sorted by cluster id.
    ///
    /// NOTE(review): spans are offsets into a reconstruction of the decoded
    /// text with the `| <id>` annotation tokens removed, assuming single
    /// spaces between tokens — confirm against downstream consumers before
    /// using them to index any text.
    fn parse_coref_output(&self, decoded: &str) -> Vec<CorefCluster> {
        let mut clusters: HashMap<u32, CorefCluster> = HashMap::new();
        let tokens: Vec<&str> = decoded.split_whitespace().collect();
        // Running offset into the reconstructed (annotation-free) text.
        let mut offset: usize = 0;
        let mut i = 0;

        while i < tokens.len() {
            let tok = tokens[i];
            // Look ahead for the "| <id>" annotation pair.
            let is_pipe = tokens.get(i + 1).map(|&t| t == "|").unwrap_or(false);
            let cluster_id: Option<u32> = if is_pipe {
                tokens
                    .get(i + 2)
                    .and_then(|t| t.trim_matches(|c: char| !c.is_ascii_digit()).parse().ok())
            } else {
                None
            };

            if let Some(cid) = cluster_id {
                // Strip surrounding punctuation from the mention token.
                let mention = tok.trim_matches(|c: char| !c.is_alphanumeric()).to_string();
                if !mention.is_empty() {
                    let start = offset;
                    let end = offset + mention.len();
                    let entry = clusters.entry(cid).or_insert_with(|| CorefCluster {
                        id: cid,
                        mentions: Vec::new(),
                        spans: Vec::new(),
                        canonical: String::new(),
                    });
                    entry.mentions.push(mention);
                    entry.spans.push((start, end));
                }
                // Skip past "<tok> | <id>"; only the mention token (plus one
                // space) advances the offset — the annotation is elided.
                offset += tok.len() + 1;
                i += 3;
                continue;
            }

            // Plain token: advance offset by token plus a single space.
            offset += tok.len() + 1;
            i += 1;
        }

        // Keep only genuine clusters (two or more mentions).
        let mut result: Vec<CorefCluster> = clusters
            .into_values()
            .filter(|c| c.mentions.len() > 1)
            .collect();

        // Canonical mention = longest surface form in the cluster.
        for c in &mut result {
            c.canonical = c
                .mentions
                .iter()
                .max_by_key(|m| m.len())
                .cloned()
                .unwrap_or_default();
        }
        result.sort_by_key(|c| c.id);
        result
    }
603
604 pub fn resolve_marked(&self, marked_text: &str) -> Result<Vec<CorefCluster>> {
609 let (plain_text, mentions) = self.extract_mentions(marked_text)?;
611
612 if mentions.is_empty() {
613 return Ok(vec![]);
614 }
615
616 self.cluster_mentions(&plain_text, &mentions)
623 }
624
625 pub fn resolve_entities(&self, text: &str, entities: &[Entity]) -> Result<Vec<CorefCluster>> {
629 if entities.is_empty() {
630 return Ok(vec![]);
631 }
632
633 let mentions: Vec<(String, usize, usize)> = entities
635 .iter()
636 .map(|e| (e.text.clone(), e.start, e.end))
637 .collect();
638
639 self.cluster_mentions(text, &mentions)
640 }
641
642 fn resolve_simple(&self, text: &str) -> Result<Vec<CorefCluster>> {
644 let pronouns = ["he", "she", "they", "it", "his", "her", "their", "its"];
646
647 let words: Vec<(String, usize, usize)> = {
648 let mut result = Vec::new();
649 let mut pos = 0;
650 for word in text.split_whitespace() {
651 if let Some(start) = text[pos..].find(word) {
652 let abs_start = pos + start;
653 result.push((word.to_string(), abs_start, abs_start + word.len()));
654 pos = abs_start + word.len();
655 }
656 }
657 result
658 };
659
660 let antecedents: Vec<&(String, usize, usize)> = words
662 .iter()
663 .filter(|(w, _, _)| {
664 w.chars().next().map(|c| c.is_uppercase()).unwrap_or(false)
665 && !pronouns.contains(&w.to_lowercase().as_str())
666 })
667 .collect();
668
669 let pronoun_mentions: Vec<&(String, usize, usize)> = words
671 .iter()
672 .filter(|(w, _, _)| pronouns.contains(&w.to_lowercase().as_str()))
673 .collect();
674
675 let mut clusters: Vec<CorefCluster> = Vec::new();
677 let mut assigned: HashMap<usize, u32> = HashMap::new();
678
679 for (ant_text, ant_start, ant_end) in &antecedents {
680 if assigned.contains_key(ant_start) {
682 continue;
683 }
684
685 let cluster_id = clusters.len() as u32;
686 let mut mentions = vec![ant_text.clone()];
687 let mut spans = vec![(*ant_start, *ant_end)];
688
689 assigned.insert(*ant_start, cluster_id);
690
691 for (pro_text, pro_start, pro_end) in &pronoun_mentions {
693 if *pro_start > *ant_end && !assigned.contains_key(pro_start) {
694 let compatible = match pro_text.to_lowercase().as_str() {
696 "he" | "him" | "his" => true, "she" | "her" | "hers" => true,
698 "they" | "them" | "their" | "theirs" => true,
699 "it" | "its" => true,
700 _ => true,
701 };
702
703 if compatible {
704 mentions.push(pro_text.clone());
705 spans.push((*pro_start, *pro_end));
706 assigned.insert(*pro_start, cluster_id);
707 break; }
709 }
710 }
711
712 if mentions.len() > 1 {
713 clusters.push(CorefCluster {
714 id: cluster_id,
715 canonical: ant_text.clone(),
716 mentions,
717 spans,
718 });
719 }
720 }
721
722 Ok(clusters)
723 }
724
725 #[allow(clippy::type_complexity)] fn extract_mentions(&self, marked_text: &str) -> Result<(String, Vec<(String, usize, usize)>)> {
728 let mut plain_text = String::new();
729 let mut mentions = Vec::new();
730 let mut offset = 0;
731
732 let mut remaining = marked_text;
733 while !remaining.is_empty() {
734 if let Some(start_pos) = remaining.find("<m>") {
735 plain_text.push_str(&remaining[..start_pos]);
737 offset += start_pos;
738
739 let after_start = &remaining[start_pos + 3..];
741 if let Some(end_pos) = after_start.find("</m>") {
742 let mention_text = after_start[..end_pos].trim();
743 let mention_start = offset;
744 plain_text.push_str(mention_text);
745 let mention_end = offset + mention_text.len();
746 offset = mention_end;
747
748 mentions.push((mention_text.to_string(), mention_start, mention_end));
749
750 remaining = &after_start[end_pos + 4..];
751 } else {
752 plain_text.push_str(remaining);
754 break;
755 }
756 } else {
757 plain_text.push_str(remaining);
759 break;
760 }
761 }
762
763 Ok((plain_text, mentions))
764 }
765
    /// Groups pre-extracted mentions into clusters using string similarity.
    ///
    /// Pronoun mentions attach to the cluster of the nearest preceding
    /// non-pronoun mention; other mentions cluster greedily (in document
    /// order) by case-insensitive equality, substring containment, or a
    /// shared final word longer than two characters (head-word heuristic,
    /// e.g. "Marie Curie" / "Curie"). Singleton clusters are filtered out.
    /// `_text` is currently unused — kept for signature stability.
    fn cluster_mentions(
        &self,
        _text: &str,
        mentions: &[(String, usize, usize)],
    ) -> Result<Vec<CorefCluster>> {
        let mut clusters: Vec<CorefCluster> = Vec::new();
        // Mention index -> cluster id (also serves as cluster index, since
        // ids are assigned as clusters.len() right before pushing).
        let mut assigned: HashMap<usize, u32> = HashMap::new();

        let pronouns = [
            "he", "she", "they", "it", "him", "her", "them", "his", "hers", "their", "its",
        ];

        for (i, (text_i, start_i, end_i)) in mentions.iter().enumerate() {
            if assigned.contains_key(&i) {
                continue;
            }

            let lower_i = text_i.to_lowercase();
            let is_pronoun_i = pronouns.contains(&lower_i.as_str());

            if is_pronoun_i {
                // Attach the pronoun to the nearest PRECEDING non-pronoun
                // mention's cluster; if it has none, leave it unclustered.
                for j in (0..i).rev() {
                    let (text_j, _, _) = &mentions[j];
                    let lower_j = text_j.to_lowercase();
                    if !pronouns.contains(&lower_j.as_str()) {
                        if let Some(&cluster_id) = assigned.get(&j) {
                            assigned.insert(i, cluster_id);
                            clusters[cluster_id as usize].mentions.push(text_i.clone());
                            clusters[cluster_id as usize].spans.push((*start_i, *end_i));
                        }
                        break;
                    }
                }
                continue;
            }

            // Seed a new cluster with this non-pronoun mention.
            let cluster_id = clusters.len() as u32;
            let mut cluster_mentions = vec![text_i.clone()];
            let mut cluster_spans = vec![(*start_i, *end_i)];
            assigned.insert(i, cluster_id);

            // Greedily absorb later unassigned mentions naming the same entity.
            for (j, (text_j, start_j, end_j)) in mentions.iter().enumerate().skip(i + 1) {
                if assigned.contains_key(&j) {
                    continue;
                }

                let lower_j = text_j.to_lowercase();

                // Equality, substring containment, or shared final word
                // (> 2 chars, to avoid spurious matches on short words).
                let matches = lower_i == lower_j
                    || lower_i.contains(&lower_j)
                    || lower_j.contains(&lower_i)
                    || {
                        let last_i = lower_i.split_whitespace().last();
                        let last_j = lower_j.split_whitespace().last();
                        last_i.is_some() && last_i == last_j && last_i.map(|w| w.len() > 2).unwrap_or(false)
                    };

                if matches {
                    cluster_mentions.push(text_j.clone());
                    cluster_spans.push((*start_j, *end_j));
                    assigned.insert(j, cluster_id);
                }
            }

            // Longest mention in the cluster becomes the canonical form.
            let canonical = cluster_mentions
                .iter()
                .max_by_key(|m| m.len())
                .cloned()
                .unwrap_or_else(|| text_i.clone());

            clusters.push(CorefCluster {
                id: cluster_id,
                mentions: cluster_mentions,
                spans: cluster_spans,
                canonical,
            });
        }

        // Only clusters with at least two mentions are meaningful.
        let multi_clusters: Vec<CorefCluster> = clusters
            .into_iter()
            .filter(|c| c.mentions.len() > 1)
            .collect();

        Ok(multi_clusters)
    }
861
862 pub fn model_path(&self) -> &str {
864 &self.model_path
865 }
866}
867
#[cfg(test)]
mod tests {
    use super::*;

    /// Default config carries the documented CPU defaults.
    #[test]
    fn test_coref_config_default() {
        let config = T5CorefConfig::default();
        assert_eq!(config.max_input_length, 512);
        assert_eq!(config.num_beams, 1);
    }

    /// CorefCluster is a plain data carrier; field round-trip sanity check.
    #[test]
    fn test_cluster_struct() {
        let mentions = vec!["Marie Curie".to_string(), "She".to_string()];
        let cluster = CorefCluster {
            id: 0,
            spans: vec![(0, 11), (50, 53)],
            canonical: "Marie Curie".to_string(),
            mentions,
        };

        assert_eq!(cluster.mentions.len(), 2);
        assert_eq!(cluster.canonical, "Marie Curie");
    }
}