use crate::{Confidence, Entity};
use std::collections::HashMap;
/// A group of entity mentions that were resolved to the same referent.
///
/// Produced by `resolve_coreferences`; all indices refer to positions in the
/// entity slice that was passed to that function.
#[derive(Debug, Clone)]
pub struct CoreferenceCluster {
/// Identifier assigned by `resolve_coreferences`, counting up from 0 in the
/// order the clusters are emitted.
pub id: u64,
/// Indices of the entities that belong to this cluster.
pub members: Vec<usize>,
/// Index of the member chosen to represent the cluster (the member whose
/// text has the greatest `len()`).
pub representative: usize,
/// Copy of the representative entity's text.
pub canonical_name: String,
}
/// Options controlling how `resolve_coreferences` decides that two entity
/// mentions belong together.
#[derive(Debug, Clone)]
pub struct CoreferenceConfig {
/// Minimum cosine similarity between the embeddings of two mentions for
/// them to be merged.
pub similarity_threshold: Confidence,
/// Largest allowed gap between two mentions, measured with
/// `Entity::start()` / `Entity::end()`, for the embedding comparison to be
/// attempted. `None` disables the limit. String matches are evaluated
/// before this check and are therefore not limited by it.
pub max_distance: Option<usize>,
/// When `true`, two mentions of the same entity type are merged if their
/// lowercased texts are equal or one contains the other.
pub use_string_match: bool,
}
impl Default for CoreferenceConfig {
fn default() -> Self {
Self {
similarity_threshold: Confidence::new(0.85),
max_distance: Some(500),
use_string_match: true,
}
}
}
/// Groups entity mentions that refer to the same thing into clusters.
///
/// Every pair of entities with the same `entity_type` is considered, and the
/// pair is merged when either of the following holds:
///
/// 1. `config.use_string_match` is set and the lowercased texts are equal or
///    one contains the other. This check ignores `config.max_distance`.
/// 2. The gap between the two mentions does not exceed `config.max_distance`
///    (when it is `Some`), both embedding rows are available, and their cosine
///    similarity is at least `config.similarity_threshold`.
///
/// Merging is transitive (union-find), so two mentions may end up in the same
/// cluster through a chain of pairwise matches.
///
/// # Parameters
///
/// * `entities` - the mentions to cluster.
/// * `embeddings` - flat buffer in which entity `k` occupies the elements
///   `k * hidden_dim .. (k + 1) * hidden_dim`. Pairs whose rows are not fully
///   contained in the buffer skip the embedding comparison.
/// * `hidden_dim` - number of `f32` values per embedding row.
/// * `config` - matching options.
///
/// # Returns
///
/// Only clusters with at least two members. Clusters are ordered by their
/// earliest member and `id` counts up from 0 in that order, so the output is
/// deterministic for a given input. Within a cluster, `members` is ascending.
/// The representative is the member with the greatest text `len()`; when
/// several members tie, the one with the highest index is chosen.
pub fn resolve_coreferences(
    entities: &[Entity],
    embeddings: &[f32],
    hidden_dim: usize,
    config: &CoreferenceConfig,
) -> Vec<CoreferenceCluster> {
    // Union-find root lookup with path compression. Written iteratively so a
    // long parent chain cannot exhaust the call stack.
    fn find(parent: &mut [usize], start: usize) -> usize {
        let mut root = start;
        while parent[root] != root {
            root = parent[root];
        }
        let mut current = start;
        while parent[current] != root {
            let next = parent[current];
            parent[current] = root;
            current = next;
        }
        root
    }

    // Merges the sets containing `a` and `b`.
    fn union(parent: &mut [usize], a: usize, b: usize) {
        let root_a = find(parent, a);
        let root_b = find(parent, b);
        if root_a != root_b {
            parent[root_a] = root_b;
        }
    }

    // Returns the embedding row for entity `index`, or `None` when the row is
    // not fully inside the buffer (or its bounds would overflow `usize`).
    fn embedding_row(embeddings: &[f32], index: usize, hidden_dim: usize) -> Option<&[f32]> {
        let start = index.checked_mul(hidden_dim)?;
        let end = start.checked_add(hidden_dim)?;
        embeddings.get(start..end)
    }

    // Cosine similarity of two rows; defined as 0.0 when either norm is zero.
    fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
        let dot: f32 = a.iter().zip(b).map(|(x, y)| x * y).sum();
        let norm_a: f32 = a.iter().map(|x| x * x).sum::<f32>().sqrt();
        let norm_b: f32 = b.iter().map(|x| x * x).sum::<f32>().sqrt();
        if norm_a > 0.0 && norm_b > 0.0 {
            dot / (norm_a * norm_b)
        } else {
            0.0
        }
    }

    let n = entities.len();
    if n == 0 {
        return Vec::new();
    }

    let mut parent: Vec<usize> = (0..n).collect();

    // Lowercase each text once instead of once per pair.
    let lowered: Vec<String> = if config.use_string_match {
        entities.iter().map(|e| e.text.to_lowercase()).collect()
    } else {
        Vec::new()
    };
    let threshold = config.similarity_threshold.value() as f32;

    for i in 0..n {
        for j in (i + 1)..n {
            // Both merge rules require matching types, so a mismatch can
            // never lead to a union.
            if entities[i].entity_type != entities[j].entity_type {
                continue;
            }

            if config.use_string_match {
                let (text_i, text_j) = (lowered[i].as_str(), lowered[j].as_str());
                if text_i == text_j || text_i.contains(text_j) || text_j.contains(text_i) {
                    union(&mut parent, i, j);
                    continue;
                }
            }

            if let Some(max_dist) = config.max_distance {
                let (first, second) = (&entities[i], &entities[j]);
                // Gap between the two spans; overlapping spans yield 0.
                let gap = if first.end() <= second.start() {
                    second.start() - first.end()
                } else {
                    first.start().saturating_sub(second.end())
                };
                if gap > max_dist {
                    continue;
                }
            }

            let row_i = embedding_row(embeddings, i, hidden_dim);
            let row_j = embedding_row(embeddings, j, hidden_dim);
            if let (Some(row_i), Some(row_j)) = (row_i, row_j) {
                if cosine_similarity(row_i, row_j) >= threshold {
                    union(&mut parent, i, j);
                }
            }
        }
    }

    // Group members by root while remembering the order in which roots are
    // first seen. Iterating the map itself would give an arbitrary order and
    // therefore unstable cluster ids.
    let mut slot_of_root: HashMap<usize, usize> = HashMap::new();
    let mut groups: Vec<Vec<usize>> = Vec::new();
    for i in 0..n {
        let root = find(&mut parent, i);
        let slot = *slot_of_root.entry(root).or_insert_with(|| {
            groups.push(Vec::new());
            groups.len() - 1
        });
        groups[slot].push(i);
    }

    let mut clusters: Vec<CoreferenceCluster> = Vec::new();
    for members in groups {
        if members.len() < 2 {
            continue;
        }
        // `max_by_key` returns the last maximum, so ties go to the later mention.
        let representative = members
            .iter()
            .copied()
            .max_by_key(|&idx| entities[idx].text.len())
            .unwrap_or(members[0]);
        let id = clusters.len() as u64;
        clusters.push(CoreferenceCluster {
            id,
            canonical_name: entities[representative].text.clone(),
            representative,
            members,
        });
    }
    clusters
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{Entity, EntityType};

    // The default configuration exposes the documented stock values.
    #[test]
    fn test_coreference_config_default() {
        let config = CoreferenceConfig::default();
        assert!((config.similarity_threshold.value() - 0.85).abs() < f64::EPSILON);
        assert_eq!(config.max_distance, Some(500));
        assert!(config.use_string_match);
    }

    // Cloning a configuration preserves every field.
    #[test]
    fn test_coreference_config_clone() {
        let config = CoreferenceConfig {
            similarity_threshold: Confidence::new(0.7),
            max_distance: None,
            use_string_match: false,
        };
        let cloned = config.clone();
        assert!((cloned.similarity_threshold.value() - 0.7).abs() < f64::EPSILON);
        assert!(cloned.max_distance.is_none());
        assert!(!cloned.use_string_match);
    }

    // Resolving an empty entity list yields no clusters.
    #[test]
    fn test_coreference_empty_input() {
        let entities: Vec<Entity> = Vec::new();
        let embeddings: Vec<f32> = Vec::new();
        let clusters =
            resolve_coreferences(&entities, &embeddings, 4, &CoreferenceConfig::default());
        assert!(clusters.is_empty(), "no entities should give no clusters");
    }

    // With string matching off, identical embeddings are enough to merge.
    #[test]
    fn test_coreference_embedding_similarity_clusters() {
        let entities = vec![
            Entity::new("She", EntityType::Person, 0, 3, 0.9),
            Entity::new("Her", EntityType::Person, 10, 13, 0.9),
        ];
        let hidden_dim = 4;
        let embeddings = vec![1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0];
        let config = CoreferenceConfig {
            similarity_threshold: Confidence::new(0.85),
            max_distance: Some(500),
            use_string_match: false,
        };
        let clusters = resolve_coreferences(&entities, &embeddings, hidden_dim, &config);
        assert_eq!(clusters.len(), 1, "identical embeddings should cluster");
        assert_eq!(clusters[0].members.len(), 2);
    }

    // Orthogonal embeddings have similarity 0 and must stay apart.
    #[test]
    fn test_coreference_embedding_below_threshold_no_cluster() {
        let entities = vec![
            Entity::new("Alice", EntityType::Person, 0, 5, 0.9),
            Entity::new("Bob", EntityType::Person, 10, 13, 0.9),
        ];
        let hidden_dim = 4;
        let embeddings = vec![1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0];
        let config = CoreferenceConfig {
            similarity_threshold: Confidence::new(0.85),
            max_distance: Some(500),
            use_string_match: false,
        };
        let clusters = resolve_coreferences(&entities, &embeddings, hidden_dim, &config);
        assert!(
            clusters.is_empty(),
            "orthogonal embeddings should not cluster"
        );
    }

    // Mentions of different entity types never merge, even when both the
    // text and the embeddings are identical.
    #[test]
    fn test_coreference_type_mismatch_no_cluster() {
        let entities = vec![
            Entity::new("Apple", EntityType::Person, 0, 5, 0.9),
            Entity::new("Apple", EntityType::Organization, 10, 15, 0.9),
        ];
        let embeddings = vec![1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0];
        let config = CoreferenceConfig {
            similarity_threshold: Confidence::new(0.85),
            max_distance: Some(500),
            use_string_match: true,
        };
        let clusters = resolve_coreferences(&entities, &embeddings, 4, &config);
        assert!(
            clusters.is_empty(),
            "different entity types should not cluster"
        );
    }

    // The distance limit prevents an embedding match between far-apart mentions.
    #[test]
    fn test_coreference_distance_limit_blocks_embedding_match() {
        let entities = vec![
            Entity::new("She", EntityType::Person, 0, 3, 0.9),
            Entity::new("Her", EntityType::Person, 100, 103, 0.9),
        ];
        let embeddings = vec![1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0];
        let config = CoreferenceConfig {
            similarity_threshold: Confidence::new(0.85),
            max_distance: Some(5),
            use_string_match: false,
        };
        let clusters = resolve_coreferences(&entities, &embeddings, 4, &config);
        assert!(
            clusters.is_empty(),
            "mentions beyond max_distance should not cluster"
        );
    }

    // String matching is case-insensitive and accepts containment.
    #[test]
    fn test_coreference_string_containment_is_case_insensitive() {
        let entities = vec![
            Entity::new("Robert Johnson", EntityType::Person, 0, 14, 0.9),
            Entity::new("johnson", EntityType::Person, 30, 37, 0.9),
        ];
        let embeddings = vec![0.0f32; 2 * 4];
        let clusters =
            resolve_coreferences(&entities, &embeddings, 4, &CoreferenceConfig::default());
        assert_eq!(clusters.len(), 1, "contained mention should cluster");
        assert_eq!(clusters[0].canonical_name, "Robert Johnson");
    }

    // The longest mention in a cluster becomes its representative.
    #[test]
    fn test_coreference_representative_is_longest_mention() {
        let entities = vec![
            Entity::new("Dr. Robert Johnson", EntityType::Person, 0, 18, 0.9),
            Entity::new("Johnson", EntityType::Person, 30, 37, 0.9),
            Entity::new("He", EntityType::Person, 50, 52, 0.9),
        ];
        let hidden_dim = 4;
        let embeddings = vec![
            1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0,
        ];
        let config = CoreferenceConfig {
            similarity_threshold: Confidence::new(0.5),
            max_distance: Some(500),
            use_string_match: false,
        };
        let clusters = resolve_coreferences(&entities, &embeddings, hidden_dim, &config);
        assert_eq!(clusters.len(), 1);
        assert_eq!(clusters[0].members, vec![0, 1, 2]);
        assert_eq!(
            clusters[0].representative, 0,
            "representative should index the longest mention"
        );
        assert_eq!(
            clusters[0].canonical_name, "Dr. Robert Johnson",
            "representative should be the longest mention"
        );
    }

    // Without a distance limit, far-apart identical mentions still merge.
    #[test]
    fn test_coreference_no_distance_limit() {
        let entities = vec![
            Entity::new("Alice", EntityType::Person, 0, 5, 0.9),
            Entity::new("Alice", EntityType::Person, 10000, 10005, 0.9),
        ];
        let embeddings = vec![0.0f32; 2 * 4];
        let config = CoreferenceConfig {
            similarity_threshold: Confidence::new(0.85),
            max_distance: None,
            use_string_match: true,
        };
        let clusters = resolve_coreferences(&entities, &embeddings, 4, &config);
        assert_eq!(
            clusters.len(),
            1,
            "no distance limit should allow clustering"
        );
    }

    // Unrelated names of different types form independent clusters.
    #[test]
    fn test_coreference_two_separate_clusters() {
        let entities = vec![
            Entity::new("Alice", EntityType::Person, 0, 5, 0.9),
            Entity::new("Alice", EntityType::Person, 20, 25, 0.9),
            Entity::new("Acme", EntityType::Organization, 40, 44, 0.9),
            Entity::new("Acme", EntityType::Organization, 60, 64, 0.9),
        ];
        let embeddings = vec![0.0f32; 4 * 768];
        let clusters =
            resolve_coreferences(&entities, &embeddings, 768, &CoreferenceConfig::default());
        assert_eq!(clusters.len(), 2, "should form two separate clusters");
        for cluster in &clusters {
            assert_eq!(cluster.members.len(), 2);
        }
        // Compare names order-independently.
        let mut names: Vec<&str> = clusters
            .iter()
            .map(|c| c.canonical_name.as_str())
            .collect();
        names.sort_unstable();
        assert_eq!(names, vec!["Acme", "Alice"]);
    }

    // Cluster ids count up from zero without gaps.
    #[test]
    fn test_coreference_cluster_ids_are_sequential() {
        let entities = vec![
            Entity::new("A", EntityType::Person, 0, 1, 0.9),
            Entity::new("A", EntityType::Person, 5, 6, 0.9),
            Entity::new("B", EntityType::Organization, 10, 11, 0.9),
            Entity::new("B", EntityType::Organization, 15, 16, 0.9),
        ];
        let embeddings = vec![0.0f32; 4 * 4];
        let clusters =
            resolve_coreferences(&entities, &embeddings, 4, &CoreferenceConfig::default());
        // Guard against the id check below passing vacuously on an empty result.
        assert_eq!(clusters.len(), 2, "expected one cluster per name");
        let mut ids: Vec<u64> = clusters.iter().map(|c| c.id).collect();
        ids.sort_unstable();
        assert_eq!(ids, vec![0, 1], "cluster IDs should be sequential");
    }

    // `CoreferenceCluster` supports `Clone` and `Debug`.
    #[test]
    fn test_coreference_cluster_debug_and_clone() {
        let cluster = CoreferenceCluster {
            id: 0,
            members: vec![0, 1, 2],
            representative: 0,
            canonical_name: "Test Entity".to_string(),
        };
        let cloned = cluster.clone();
        assert_eq!(cloned.id, 0);
        assert_eq!(cloned.members, vec![0, 1, 2]);
        assert_eq!(cloned.canonical_name, "Test Entity");
        let debug = format!("{:?}", cluster);
        assert!(debug.contains("Test Entity"));
    }
}