use crate::Result;
/// A group of mentions that a coreference backend has linked together.
///
/// Backend-agnostic result type returned by [`CorefBackend::resolve`].
#[derive(Debug, Clone)]
pub struct CorefCluster {
    /// Identifier assigned to this cluster by the producing backend.
    pub id: u32,
    /// Surface text of every mention in the cluster.
    pub mentions: Vec<String>,
    /// One `(start, end)` offset pair per mention.
    ///
    /// NOTE(review): presumably index-aligned with `mentions`; the offset unit
    /// (bytes vs. chars) and whether `end` is exclusive are not fixed by this
    /// type — confirm against each backend.
    pub spans: Vec<(usize, usize)>,
    /// Representative text chosen for the whole cluster.
    pub canonical: String,
}
/// Common interface over coreference-resolution implementations.
///
/// The `Send + Sync` supertraits let implementors be shared across threads,
/// and the trait is usable as a trait object (`Box<dyn CorefBackend>`).
pub trait CorefBackend: Send + Sync {
    /// Runs coreference resolution over `text` and returns the clusters found.
    ///
    /// # Errors
    ///
    /// Returns whatever error the underlying implementation reports.
    fn resolve(&self, text: &str) -> Result<Vec<CorefCluster>>;

    /// Static identifier for this backend (for example `"mention-ranking"`).
    fn name(&self) -> &'static str;

    /// Reports whether this backend can currently be used.
    ///
    /// Callers may check this before invoking [`CorefBackend::resolve`].
    fn is_available(&self) -> bool;
}
/// Adapter exposing `FCoref` through [`CorefBackend`].
///
/// Only compiled when the `onnx` feature is enabled.
#[cfg(feature = "onnx")]
impl CorefBackend for super::fcoref::FCoref {
    /// Forwards to `FCoref`'s own `resolve` and returns its result unchanged.
    fn resolve(&self, text: &str) -> Result<Vec<CorefCluster>> {
        // NOTE(review): relies on an inherent `FCoref::resolve(&self, &str)`;
        // method-call lookup prefers inherent methods over trait methods, so
        // this does not recurse into the trait method — confirm it exists.
        self.resolve(text)
    }

    /// Backend identifier: `"fcoref"`.
    fn name(&self) -> &'static str {
        "fcoref"
    }

    /// Always reports the backend as available.
    fn is_available(&self) -> bool {
        true
    }
}
/// Adapter exposing `T5Coref` through [`CorefBackend`].
///
/// Only compiled when the `onnx` feature is enabled.
#[cfg(feature = "onnx")]
impl CorefBackend for super::t5::T5Coref {
    /// Forwards to `T5Coref`'s own `resolve` and returns its result unchanged.
    fn resolve(&self, text: &str) -> Result<Vec<CorefCluster>> {
        // NOTE(review): relies on an inherent `T5Coref::resolve(&self, &str)`;
        // method-call lookup prefers inherent methods over trait methods, so
        // this does not recurse into the trait method — confirm it exists.
        self.resolve(text)
    }

    /// Backend identifier: `"coref-t5"`.
    fn name(&self) -> &'static str {
        "coref-t5"
    }

    /// Always reports the backend as available.
    fn is_available(&self) -> bool {
        true
    }
}
/// Adapter exposing `MentionRankingCoref` through [`CorefBackend`].
impl CorefBackend for super::mention_ranking::MentionRankingCoref {
    /// Resolves `text` with the mention-ranking model and converts each of its
    /// clusters into a [`CorefCluster`].
    ///
    /// For every cluster:
    /// * `id` is the cluster's position in the model output, cast to `u32`;
    /// * `mentions` and `spans` are built from the same mentions in the same
    ///   order, so they are index-aligned;
    /// * `canonical` is the mention text with the greatest `String::len`
    ///   (byte length). On a tie the later mention wins; with no mentions it
    ///   is the empty string.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by the underlying `resolve` call.
    fn resolve(&self, text: &str) -> Result<Vec<CorefCluster>> {
        // NOTE(review): relies on an inherent `MentionRankingCoref::resolve`;
        // method-call lookup prefers it over this trait method — confirm.
        let raw_clusters = self.resolve(text)?;

        let mut converted = Vec::new();
        for (index, cluster) in raw_clusters.into_iter().enumerate() {
            // Single pass over the mentions yields both parallel vectors.
            let (mentions, spans): (Vec<String>, Vec<(usize, usize)>) = cluster
                .mentions
                .iter()
                .map(|mention| (mention.text.clone(), (mention.start, mention.end)))
                .unzip();

            // `max_by_key` keeps the last of several equally long mentions.
            let canonical = match mentions.iter().max_by_key(|surface| surface.len()) {
                Some(longest) => longest.clone(),
                None => String::new(),
            };

            converted.push(CorefCluster {
                id: index as u32,
                mentions,
                spans,
                canonical,
            });
        }
        Ok(converted)
    }

    /// Backend identifier: `"mention-ranking"`.
    fn name(&self) -> &'static str {
        "mention-ranking"
    }

    /// Always reports the backend as available.
    fn is_available(&self) -> bool {
        true
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal backend used to exercise [`CorefBackend`] through `dyn` dispatch.
    struct StubCoref {
        available: bool,
    }

    impl CorefBackend for StubCoref {
        /// Yields no clusters for empty input, otherwise one fixed cluster.
        ///
        /// The mentions and spans are canned fixture values; they are not
        /// derived from `text`.
        fn resolve(&self, text: &str) -> Result<Vec<CorefCluster>> {
            let clusters = if text.is_empty() {
                Vec::new()
            } else {
                vec![CorefCluster {
                    id: 0,
                    mentions: vec![String::from("John"), String::from("He")],
                    spans: vec![(0, 4), (30, 32)],
                    canonical: String::from("John"),
                }]
            };
            Ok(clusters)
        }

        fn name(&self) -> &'static str {
            "stub-coref"
        }

        /// Mirrors the flag the stub was constructed with.
        fn is_available(&self) -> bool {
            self.available
        }
    }

    /// Builds a boxed trait object around a stub with the given availability.
    fn boxed_stub(available: bool) -> Box<dyn CorefBackend> {
        Box::new(StubCoref { available })
    }

    /// `resolve` called through a trait object reaches the stub implementation.
    #[test]
    fn trait_object_dispatch() {
        let backend = boxed_stub(true);
        let resolved = backend
            .resolve("John went to the store. He bought milk.")
            .unwrap();

        assert_eq!(resolved.len(), 1);
        let first = &resolved[0];
        assert_eq!(first.canonical, "John");
        assert_eq!(first.mentions.len(), 2);
    }

    /// Empty text resolves to an empty cluster list rather than an error.
    #[test]
    fn trait_object_empty_input() {
        let backend = boxed_stub(true);
        let resolved = backend.resolve("").unwrap();
        assert!(
            resolved.is_empty(),
            "empty input should produce no clusters"
        );
    }

    /// `name` and `is_available` are reachable through a trait object.
    #[test]
    fn trait_object_name_and_availability() {
        let ready = boxed_stub(true);
        let not_ready = boxed_stub(false);

        assert_eq!(ready.name(), "stub-coref");
        assert!(ready.is_available());
        assert!(!not_ready.is_available());
    }

    /// Boxed backends can be stored together and filtered by availability.
    #[test]
    fn heterogeneous_vec_of_trait_objects() {
        let backends: Vec<Box<dyn CorefBackend>> = vec![boxed_stub(true), boxed_stub(false)];

        let available_count = backends
            .iter()
            .filter(|backend| backend.is_available())
            .count();
        assert_eq!(available_count, 1);

        // Only available backends are asked to resolve.
        for backend in backends.iter().filter(|backend| backend.is_available()) {
            let resolved = backend.resolve("test text").unwrap();
            assert!(!resolved.is_empty());
        }
    }
}