use super::coref::{CorefChain, Mention, MentionType};
use anno::{Entity, EntityType};
use serde::{Deserialize, Serialize};
use std::collections::{HashMap, VecDeque};
/// Strategy for bounding how many coreference clusters are kept alive
/// while streaming through a document.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
pub enum MemoryPolicy {
    /// Keep every cluster; memory grows with the number of entities seen.
    #[default]
    Unbounded,
    /// Evict the cluster with the oldest `last_accessed_window` whenever the
    /// cluster count exceeds `max_clusters`.
    LeastRecentlyUsed {
        max_clusters: usize,
    },
    /// Evict the cluster with the smallest `access_count` whenever the
    /// cluster count exceeds `max_clusters`.
    LeastFrequentlyUsed {
        max_clusters: usize,
    },
    /// Two cooperating caches: `l_cache` tracks recency (capped at
    /// `l_cache_size`), `g_cache` tracks frequency (capped at
    /// `g_cache_size`); a cluster is only deleted when it falls off the
    /// frequency cache and is not protected by the recency cache.
    DualCache {
        l_cache_size: usize,
        g_cache_size: usize,
    },
}
/// Tunables for incremental coreference resolution.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct IncrementalConfig {
    /// Window length, in tokens when `token_based` is true, otherwise in
    /// characters.
    pub window_size: usize,
    /// How much consecutive windows overlap (same unit as `window_size`).
    pub window_overlap: usize,
    /// Cluster retention/eviction strategy.
    pub memory_policy: MemoryPolicy,
    /// Minimum `compute_match_score` value for a mention to join a cluster.
    pub similarity_threshold: f64,
    /// Whether windows are measured in whitespace tokens or characters.
    pub token_based: bool,
    /// NOTE(review): not consulted anywhere in this file — confirm whether
    /// a caller uses it or it is dead configuration.
    pub max_pronoun_search_windows: usize,
    /// Enable the exact representative-string match tier (score 1.0).
    pub use_exact_match: bool,
    /// Enable the substring/token-overlap match tiers (scores 0.85-0.75).
    pub use_substring_match: bool,
    /// Run a pairwise cluster-merge pass every this many windows; 0 disables.
    pub grouped_window_size: usize,
}
impl Default for IncrementalConfig {
    /// Defaults: ~1500-token windows with 200 overlap, unbounded memory,
    /// 0.7 similarity threshold, all match tiers enabled, and a merge pass
    /// every 10 windows.
    fn default() -> Self {
        Self {
            window_size: 1500,
            window_overlap: 200,
            memory_policy: MemoryPolicy::Unbounded,
            similarity_threshold: 0.7,
            token_based: true,
            max_pronoun_search_windows: 3,
            use_exact_match: true,
            use_substring_match: true,
            grouped_window_size: 10,
        }
    }
}
/// Internal per-cluster state tracked by [`EntityMemory`].
#[derive(Debug, Clone)]
struct ClusterMetadata {
    /// Stable id assigned at creation; exported as the chain's cluster id.
    id: u64,
    /// Canonical surface form; replaced by longer non-pronoun mentions.
    representative: String,
    /// Every mention attached to this cluster so far.
    mentions: Vec<MentionRecord>,
    /// Window counter at last access; drives LRU eviction.
    last_accessed_window: usize,
    /// Number of accesses (create + each added mention); drives LFU/g_cache.
    access_count: usize,
    /// Entity type, when known; constrains which mentions may attach.
    entity_type: Option<EntityType>,
    /// Char offset of the first mention; currently unused.
    #[allow(dead_code)]
    first_mention_offset: usize,
}
/// A single mention occurrence as stored inside a cluster.
#[derive(Debug, Clone)]
struct MentionRecord {
    /// Raw surface text of the mention.
    text: String,
    /// Start offset (character-based, document-relative).
    start: usize,
    /// End offset (character-based, exclusive).
    end: usize,
    /// Window in which the mention was seen; currently unused.
    #[allow(dead_code)]
    window_index: usize,
    /// Heuristic classification (proper/pronominal/nominal/unknown).
    mention_type: MentionType,
}
/// Cluster store with policy-driven eviction used during streaming
/// resolution.
#[derive(Debug)]
pub struct EntityMemory {
    /// Live clusters keyed by id.
    clusters: HashMap<u64, ClusterMetadata>,
    /// Monotonically increasing id source for new clusters.
    next_cluster_id: u64,
    /// Retention policy applied on access and after creation.
    policy: MemoryPolicy,
    /// Current window index; advanced by `advance_window`.
    current_window: usize,
    /// Recency list for `DualCache` (most recent at the front).
    l_cache: VecDeque<u64>,
    /// Frequency list for `DualCache`, sorted by descending access count.
    g_cache: Vec<u64>,
}
impl EntityMemory {
/// Construct an empty memory governed by `policy`.
pub fn new(policy: MemoryPolicy) -> Self {
    Self {
        policy,
        clusters: HashMap::new(),
        next_cluster_id: 0,
        current_window: 0,
        l_cache: VecDeque::new(),
        g_cache: Vec::new(),
    }
}
/// Mint a fresh cluster seeded with `mention`, register it in the caches,
/// and apply eviction if the policy's bound was exceeded. Returns the new
/// cluster's id.
fn create_cluster(&mut self, mention: &MentionRecord, entity_type: Option<EntityType>) -> u64 {
    let id = self.next_cluster_id;
    self.next_cluster_id += 1;
    self.clusters.insert(
        id,
        ClusterMetadata {
            id,
            representative: mention.text.clone(),
            mentions: vec![mention.clone()],
            last_accessed_window: self.current_window,
            access_count: 1,
            entity_type,
            first_mention_offset: mention.start,
        },
    );
    self.update_cache_on_access(id);
    self.maybe_evict();
    id
}
/// Append `mention` to an existing cluster, refreshing recency/frequency
/// bookkeeping and promoting a longer non-pronoun surface form to be the
/// cluster's representative.
///
/// If `cluster_id` no longer exists (e.g. it was evicted), this is a
/// complete no-op: in particular the caches are NOT touched, so stale ids
/// can never enter `l_cache`/`g_cache`. (The original updated the caches
/// unconditionally, even for a missing cluster.)
fn add_to_cluster(&mut self, cluster_id: u64, mention: &MentionRecord) {
    if let Some(cluster) = self.clusters.get_mut(&cluster_id) {
        cluster.mentions.push(mention.clone());
        cluster.last_accessed_window = self.current_window;
        cluster.access_count += 1;
        // Prefer the longest non-pronominal mention as the representative.
        if mention.mention_type != MentionType::Pronominal
            && mention.text.len() > cluster.representative.len()
        {
            cluster.representative = mention.text.clone();
        }
        self.update_cache_on_access(cluster_id);
    }
}
/// Return the id of the best-scoring type-compatible cluster for
/// `mention_text`, or `None` when no cluster reaches
/// `similarity_threshold`.
///
/// Scores come from `compute_match_score`. Ties on score are broken toward
/// the lower cluster id so the result does not depend on `HashMap`
/// iteration order (the original was nondeterministic on ties).
pub fn find_best_match(
    &self,
    mention_text: &str,
    mention_type: MentionType,
    similarity_threshold: f64,
    use_exact_match: bool,
    use_substring_match: bool,
) -> Option<u64> {
    let mention_key = name_key(mention_text);
    let mut best_match: Option<(u64, f64)> = None;
    for (&id, cluster) in &self.clusters {
        if !self.types_compatible(mention_type, cluster) {
            continue;
        }
        let score = self.compute_match_score(
            &mention_key,
            mention_type,
            cluster,
            use_exact_match,
            use_substring_match,
        );
        if score < similarity_threshold {
            continue;
        }
        let should_update = match best_match {
            None => true,
            // Exact float comparison is intentional here: equal scores fall
            // back to the deterministic id tie-break.
            Some((best_id, best_score)) => {
                score > best_score || (score == best_score && id < best_id)
            }
        };
        if should_update {
            best_match = Some((id, score));
        }
    }
    best_match.map(|(id, _)| id)
}
/// Score how well a normalized `mention_key` matches `cluster`.
///
/// Tiered: 1.0 exact representative match, 0.95 exact match against any
/// stored mention, 0.85 substring containment, 0.8 shared last token,
/// 0.75 shared first token, 0.6 flat baseline for pronouns, otherwise the
/// trigram similarity scaled by 0.7.
fn compute_match_score(
    &self,
    mention_key: &str,
    mention_type: MentionType,
    cluster: &ClusterMetadata,
    use_exact_match: bool,
    use_substring_match: bool,
) -> f64 {
    let rep_key = name_key(&cluster.representative);
    if use_exact_match && mention_key == rep_key {
        return 1.0;
    }
    if cluster
        .mentions
        .iter()
        .any(|m| name_key(&m.text) == mention_key)
    {
        return 0.95;
    }
    if use_substring_match && mention_type != MentionType::Pronominal {
        if rep_key.contains(mention_key) || mention_key.contains(&rep_key) {
            return 0.85;
        }
        // Token-level heuristics: a shared last token (e.g. a surname)
        // outranks a shared first token.
        let mention_last = mention_key.split_whitespace().last();
        if mention_last.is_some() && mention_last == rep_key.split_whitespace().last() {
            return 0.8;
        }
        let mention_first = mention_key.split_whitespace().next();
        if mention_first.is_some() && mention_first == rep_key.split_whitespace().next() {
            return 0.75;
        }
    }
    if mention_type == MentionType::Pronominal {
        // Pronouns cannot match lexically; give a flat baseline so they can
        // attach to any type-compatible cluster past a low threshold.
        return 0.6;
    }
    trigram_similarity(mention_key, &rep_key) * 0.7
}
/// Whether a mention of `mention_type` may attach to `cluster`.
///
/// The only restriction: pronouns may not attach to organization or
/// location clusters. Person-typed, untyped, and all other clusters accept
/// any mention type.
fn types_compatible(&self, mention_type: MentionType, cluster: &ClusterMetadata) -> bool {
    let org_or_loc = matches!(
        cluster.entity_type,
        Some(EntityType::Organization) | Some(EntityType::Location)
    );
    !(org_or_loc && mention_type == MentionType::Pronominal)
}
/// Record an access to `cluster_id` in the policy's live bookkeeping.
///
/// Only `DualCache` maintains structures here: `l_cache` is a recency list
/// (most recent at the front, trimmed to `l_cache_size`) and `g_cache` is
/// kept sorted by descending `access_count`. The LRU/LFU policies rely
/// solely on the per-cluster metadata fields and do nothing on access;
/// their eviction happens in `maybe_evict`.
fn update_cache_on_access(&mut self, cluster_id: u64) {
    match self.policy {
        MemoryPolicy::DualCache { l_cache_size, .. } => {
            // Move-to-front in the recency cache.
            self.l_cache.retain(|&id| id != cluster_id);
            self.l_cache.push_front(cluster_id);
            if let Some(cluster) = self.clusters.get(&cluster_id) {
                let access_count = cluster.access_count;
                // Remove then re-insert at the first position whose cluster
                // has a strictly smaller access count (missing clusters sort
                // as smallest), keeping g_cache sorted descending.
                self.g_cache.retain(|&id| id != cluster_id);
                let pos = self.g_cache.iter().position(|&id| {
                    self.clusters
                        .get(&id)
                        .map(|c| c.access_count < access_count)
                        .unwrap_or(true)
                });
                match pos {
                    Some(p) => self.g_cache.insert(p, cluster_id),
                    None => self.g_cache.push(cluster_id),
                }
            }
            // Trim recency cache to its configured capacity.
            while self.l_cache.len() > l_cache_size {
                self.l_cache.pop_back();
            }
        }
        MemoryPolicy::LeastRecentlyUsed { .. } => {
        }
        MemoryPolicy::LeastFrequentlyUsed { .. } => {
        }
        MemoryPolicy::Unbounded => {}
    }
}
/// Enforce the configured memory bound (called after cluster creation).
///
/// - LRU: repeatedly drop the cluster with the smallest
///   `last_accessed_window` until at most `max_clusters` remain.
/// - LFU: same, keyed on `access_count`.
/// - DualCache: trim `g_cache` (sorted descending by access count) from
///   the tail; a popped cluster is only deleted when it is not protected
///   by membership in `l_cache`.
/// - Unbounded: never evicts.
fn maybe_evict(&mut self) {
    match self.policy {
        MemoryPolicy::LeastRecentlyUsed { max_clusters } => {
            while self.clusters.len() > max_clusters {
                // min_by_key over a HashMap: ties broken by iteration order.
                let lru_id = self
                    .clusters
                    .iter()
                    .min_by_key(|(_, c)| c.last_accessed_window)
                    .map(|(&id, _)| id);
                if let Some(id) = lru_id {
                    self.clusters.remove(&id);
                } else {
                    break;
                }
            }
        }
        MemoryPolicy::LeastFrequentlyUsed { max_clusters } => {
            while self.clusters.len() > max_clusters {
                let lfu_id = self
                    .clusters
                    .iter()
                    .min_by_key(|(_, c)| c.access_count)
                    .map(|(&id, _)| id);
                if let Some(id) = lfu_id {
                    self.clusters.remove(&id);
                } else {
                    break;
                }
            }
        }
        MemoryPolicy::DualCache { g_cache_size, .. } => {
            while self.g_cache.len() > g_cache_size {
                if let Some(id) = self.g_cache.pop() {
                    // NOTE(review): a cluster spared here because it sits in
                    // l_cache still loses its g_cache slot until its next
                    // access re-inserts it — confirm this is intended.
                    if !self.l_cache.contains(&id) {
                        self.clusters.remove(&id);
                    }
                }
            }
        }
        MemoryPolicy::Unbounded => {}
    }
}
/// Move to the next window; LRU recency bookkeeping keys off this counter.
pub fn advance_window(&mut self) {
    self.current_window += 1;
}
/// Export every surviving cluster as a [`CorefChain`].
///
/// Order follows `HashMap` iteration and is therefore unspecified. The
/// entity type, when present, is serialized via its `Debug` formatting.
pub fn to_chains(&self) -> Vec<CorefChain> {
    let mut chains = Vec::with_capacity(self.clusters.len());
    for cluster in self.clusters.values() {
        let mentions: Vec<Mention> = cluster
            .mentions
            .iter()
            .map(|m| {
                let mut mention = Mention::new(&m.text, m.start, m.end);
                mention.mention_type = Some(m.mention_type);
                mention
            })
            .collect();
        let mut chain = CorefChain::new(mentions);
        chain.cluster_id = Some(cluster.id.into());
        chain.entity_type = cluster.entity_type.as_ref().map(|t| format!("{:?}", t));
        chains.push(chain);
    }
    chains
}
/// Number of clusters currently held in memory.
pub fn cluster_count(&self) -> usize {
    self.clusters.len()
}
/// Total mentions recorded across all live clusters.
pub fn mention_count(&self) -> usize {
    self.clusters
        .values()
        .fold(0, |total, cluster| total + cluster.mentions.len())
}
}
/// Streaming coreference resolver: processes a document window by window,
/// keeping cluster state in an [`EntityMemory`].
#[derive(Debug)]
pub struct IncrementalCorefResolver {
    /// Windowing, matching, and memory-policy settings.
    config: IncrementalConfig,
}
impl Default for IncrementalCorefResolver {
    /// Resolver using [`IncrementalConfig::default`] settings.
    fn default() -> Self {
        Self::new(IncrementalConfig::default())
    }
}
impl IncrementalCorefResolver {
/// Build a resolver from an explicit configuration.
pub fn new(config: IncrementalConfig) -> Self {
    Self { config }
}
/// Resolve coreference over raw `text` by streaming fixed-size windows.
///
/// Each window's candidate mentions are matched against memory, then the
/// window counter advances; every `grouped_window_size` windows (when
/// non-zero) a pairwise cluster-merge pass runs. Returns the surviving
/// clusters as chains.
pub fn resolve_document(&self, text: &str) -> Vec<CorefChain> {
    let mut memory = EntityMemory::new(self.config.memory_policy);
    let windows = self.split_into_windows(text);
    for (window_idx, (window_text, window_offset)) in windows.into_iter().enumerate() {
        for mention in self.extract_mentions(&window_text, window_offset) {
            self.process_mention(&mut memory, &mention);
        }
        memory.advance_window();
        let group = self.config.grouped_window_size;
        if group > 0 && (window_idx + 1) % group == 0 {
            self.expand_grouped_window(&mut memory, window_idx);
        }
    }
    memory.to_chains()
}
/// Assign a `canonical_id` to each entity by streaming them through the
/// entity memory in order.
///
/// Window advancement is positional: whenever an entity starts beyond the
/// current window span, the window counter advances once and the window
/// start is re-anchored just before that entity (minus the overlap).
pub fn resolve_entities(&self, entities: &[Entity]) -> Vec<Entity> {
    let mut memory = EntityMemory::new(self.config.memory_policy);
    let mut resolved = entities.to_vec();
    let mut window_start = 0usize;
    let mut window_idx = 0usize;
    for (entity, slot) in entities.iter().zip(resolved.iter_mut()) {
        if entity.start() >= window_start + self.config.window_size {
            window_idx += 1;
            window_start = entity.start().saturating_sub(self.config.window_overlap);
            memory.advance_window();
        }
        let record = MentionRecord {
            text: entity.text.clone(),
            start: entity.start(),
            end: entity.end(),
            window_index: window_idx,
            mention_type: self.classify_mention_type(&entity.text),
        };
        let cluster_id = self.process_mention(&mut memory, &record);
        slot.canonical_id = Some(cluster_id.into());
    }
    resolved
}
/// Split `text` into overlapping windows, returning
/// `(window_text, start_char_offset)` pairs.
///
/// With `token_based`, `window_size`/`window_overlap` count whitespace
/// tokens; otherwise they count characters and each cut is nudged back to
/// the last whitespace boundary so words are not split mid-token. All
/// offsets are character offsets; `anno::offset::TextSpan::from_chars`
/// presumably maps them to byte spans — TODO confirm against that crate.
fn split_into_windows(&self, text: &str) -> Vec<(String, usize)> {
    let mut windows = Vec::new();
    let mut offset = 0;
    if self.config.token_based {
        #[derive(Debug, Clone, Copy)]
        struct TokenSpan {
            start_char: usize,
            end_char: usize,
        }
        // Whitespace tokenizer that records character (not byte) spans.
        fn tokenize_with_char_offsets(text: &str) -> Vec<TokenSpan> {
            let mut tokens = Vec::new();
            let mut in_word = false;
            let mut word_start_char = 0;
            let mut char_pos = 0;
            for c in text.chars() {
                if c.is_whitespace() {
                    if in_word {
                        tokens.push(TokenSpan {
                            start_char: word_start_char,
                            end_char: char_pos,
                        });
                        in_word = false;
                    }
                } else if !in_word {
                    in_word = true;
                    word_start_char = char_pos;
                }
                char_pos += 1;
            }
            // Flush a word running to end-of-text.
            if in_word {
                tokens.push(TokenSpan {
                    start_char: word_start_char,
                    end_char: char_pos,
                });
            }
            tokens
        }
        let tokens = tokenize_with_char_offsets(text);
        // Stride between window starts, in tokens; overlap >= size gives 0,
        // clamped to 1 below to guarantee progress.
        let step = self
            .config
            .window_size
            .saturating_sub(self.config.window_overlap);
        while offset < tokens.len() {
            let end = (offset + self.config.window_size).min(tokens.len());
            if end == 0 || offset >= end {
                break;
            }
            let char_start = tokens[offset].start_char;
            let char_end = tokens[end - 1].end_char;
            let window_text = anno::offset::TextSpan::from_chars(text, char_start, char_end)
                .extract(text)
                .to_string();
            windows.push((window_text, char_start));
            if end >= tokens.len() {
                break;
            }
            offset += step.max(1);
        }
    } else {
        // Character-based windows.
        let text_char_len = text.chars().count();
        let step = self
            .config
            .window_size
            .saturating_sub(self.config.window_overlap);
        while offset < text_char_len {
            let end = (offset + self.config.window_size).min(text_char_len);
            let adjusted_end = if end < text_char_len {
                // Convert the char cut point to a byte offset, then back up
                // to just after the last whitespace before it, consuming any
                // whitespace run up to the original cut.
                let end_byte = anno::offset::TextSpan::from_chars(text, end, end).byte_start;
                let mut adjusted_end_byte = end_byte;
                if let Some(ws_byte_pos) = text[..end_byte].rfind(char::is_whitespace) {
                    let ws_len = text[ws_byte_pos..]
                        .chars()
                        .next()
                        .map(|c| c.len_utf8())
                        .unwrap_or(1);
                    adjusted_end_byte = ws_byte_pos + ws_len;
                    while adjusted_end_byte < end_byte {
                        match text[adjusted_end_byte..].chars().next() {
                            Some(c) if c.is_whitespace() => {
                                adjusted_end_byte += c.len_utf8();
                            }
                            _ => break,
                        }
                    }
                }
                let (adjusted_end_char, _) =
                    anno::offset::bytes_to_chars(text, adjusted_end_byte, adjusted_end_byte);
                // Only accept the whitespace-aligned end if the window still
                // advances; otherwise fall back to the hard cut.
                if adjusted_end_char > offset {
                    adjusted_end_char.min(text_char_len)
                } else {
                    end
                }
            } else {
                end
            };
            let window_text = anno::offset::TextSpan::from_chars(text, offset, adjusted_end)
                .extract(text)
                .to_string();
            windows.push((window_text, offset));
            if adjusted_end >= text_char_len {
                break;
            }
            offset += step.max(1);
        }
    }
    windows
}
/// Scan `text` for candidate mention words, shifting recorded positions by
/// the window's starting char `offset`.
///
/// Words are whitespace-delimited. A word is kept when it classifies as a
/// pronoun or a proper mention, or (fallback) starts with an uppercase
/// character. Recorded spans are character offsets; `window_index` is
/// always 0 here.
/// NOTE(review): trailing punctuation is kept ("John," stays "John,") —
/// confirm the name_key scrubbing downstream tolerates this.
fn extract_mentions(&self, text: &str, offset: usize) -> Vec<MentionRecord> {
    let mut mentions = Vec::new();
    let mut in_word = false;
    let mut word_start_byte = 0;
    let mut word_start_char = 0;
    let mut char_pos = 0;
    // Shared push logic for mid-text and end-of-text words.
    let mut maybe_push_mention = |word: &str, local_start: usize, local_end: usize| {
        let mention_type = self.classify_mention_type(word);
        let should_track = match mention_type {
            MentionType::Pronominal => true,
            MentionType::Proper => true,
            _ => word
                .chars()
                .next()
                .map(|c| c.is_uppercase())
                .unwrap_or(false),
        };
        if should_track {
            mentions.push(MentionRecord {
                text: word.to_string(),
                start: offset + local_start,
                end: offset + local_end,
                window_index: 0,
                mention_type,
            });
        }
    };
    // Track both byte (for slicing) and char (for offsets) positions.
    for (byte_idx, c) in text.char_indices() {
        if c.is_whitespace() {
            if in_word {
                let word = &text[word_start_byte..byte_idx];
                maybe_push_mention(word, word_start_char, char_pos);
                in_word = false;
            }
        } else if !in_word {
            in_word = true;
            word_start_byte = byte_idx;
            word_start_char = char_pos;
        }
        char_pos += 1;
    }
    // Flush a word running to end-of-text.
    if in_word {
        let word = &text[word_start_byte..];
        maybe_push_mention(word, word_start_char, char_pos);
    }
    mentions
}
/// Heuristically classify a surface string's mention type.
///
/// Order matters: pronoun lookup first, then capitalization => Proper,
/// then a lowercase "the " prefix => Nominal, else Unknown.
/// NOTE(review): a capitalized "The ..." phrase is classified Proper
/// before the Nominal check can fire — confirm this is intended.
fn classify_mention_type(&self, text: &str) -> MentionType {
    // Personal, reflexive, and common neopronoun forms (lowercase).
    const PRONOUNS: [&str; 43] = [
        "he", "him", "his", "himself", "she", "her", "hers", "herself",
        "they", "them", "their", "theirs", "themself", "themselves", "it",
        "its", "itself", "i", "me", "my", "mine", "myself", "we", "us",
        "our", "ours", "ourselves", "you", "your", "yours", "yourself",
        "yourselves", "xe", "xem", "xyr", "xyrs", "ze", "zir", "zirs",
        "ey", "em", "eir", "eirs",
    ];
    let lower = text.to_lowercase();
    if PRONOUNS.contains(&lower.as_str()) {
        return MentionType::Pronominal;
    }
    let capitalized = text
        .chars()
        .next()
        .map(|c| c.is_uppercase())
        .unwrap_or(false);
    if capitalized {
        return MentionType::Proper;
    }
    if lower.starts_with("the ") {
        return MentionType::Nominal;
    }
    MentionType::Unknown
}
/// Route one mention into memory: attach it to the best-matching cluster
/// or seed a new one. Returns the cluster id the mention landed in.
fn process_mention(&self, memory: &mut EntityMemory, mention: &MentionRecord) -> u64 {
    let existing = memory.find_best_match(
        &mention.text,
        mention.mention_type,
        self.config.similarity_threshold,
        self.config.use_exact_match,
        self.config.use_substring_match,
    );
    match existing {
        Some(cluster_id) => {
            memory.add_to_cluster(cluster_id, mention);
            cluster_id
        }
        None => {
            // Heuristic: a proper-name mention seeds a Person-typed cluster.
            // NOTE(review): assumes most proper mentions are people —
            // confirm against the NER pipeline feeding this.
            let entity_type =
                (mention.mention_type == MentionType::Proper).then_some(EntityType::Person);
            memory.create_cluster(mention, entity_type)
        }
    }
}
/// Pairwise cluster-merge pass, run every `grouped_window_size` windows.
///
/// For each ordered pair (i, j) with i < j, merge cluster j into cluster i
/// when `should_merge_clusters` approves; `should_merge_clusters` returns
/// false for ids already merged away, so stale ids are skipped safely.
/// O(k^2) in the number of live clusters.
fn expand_grouped_window(&self, memory: &mut EntityMemory, _window_idx: usize) {
    let cluster_ids: Vec<u64> = memory.clusters.keys().copied().collect();
    for i in 0..cluster_ids.len() {
        for j in (i + 1)..cluster_ids.len() {
            let id_i = cluster_ids[i];
            let id_j = cluster_ids[j];
            if !self.should_merge_clusters(memory, id_i, id_j) {
                continue;
            }
            if let Some(cluster_j) = memory.clusters.remove(&id_j) {
                if let Some(cluster_i) = memory.clusters.get_mut(&id_i) {
                    cluster_i.mentions.extend(cluster_j.mentions);
                    cluster_i.access_count += cluster_j.access_count;
                } else {
                    // Defensive: never drop mentions if the merge target
                    // vanished — restore cluster_j instead. (The original
                    // silently lost cluster_j's mentions on this path.)
                    memory.clusters.insert(id_j, cluster_j);
                }
            }
        }
    }
}
/// Decide whether two clusters refer to the same entity and should merge.
///
/// Returns false when either id is missing or both clusters carry known
/// but differing entity types. Otherwise merge on: identical normalized
/// representatives, substring containment, or trigram similarity > 0.8.
fn should_merge_clusters(&self, memory: &EntityMemory, id_a: u64, id_b: u64) -> bool {
    let cluster_a = match memory.clusters.get(&id_a) {
        Some(c) => c,
        None => return false,
    };
    let cluster_b = match memory.clusters.get(&id_b) {
        Some(c) => c,
        None => return false,
    };
    let type_conflict = match (&cluster_a.entity_type, &cluster_b.entity_type) {
        (Some(a), Some(b)) => a != b,
        _ => false,
    };
    if type_conflict {
        return false;
    }
    let rep_a = name_key(&cluster_a.representative);
    let rep_b = name_key(&cluster_b.representative);
    rep_a == rep_b
        || rep_a.contains(&rep_b)
        || rep_b.contains(&rep_a)
        || trigram_similarity(&rep_a, &rep_b) > 0.8
}
}
/// Normalize a surface string into a comparison key: U+FEFF removed, NFC
/// normalized, lowercased, whitespace collapsed, bidi controls stripped
/// (other zero-width characters and diacritics are preserved).
fn name_key(s: &str) -> String {
    use std::borrow::Cow;
    // Drop U+FEFF (BOM / zero-width no-break space) without allocating when
    // the input contains none.
    let cleaned: Cow<'_, str> = if s.contains('\u{FEFF}') {
        Cow::Owned(s.chars().filter(|&c| c != '\u{FEFF}').collect())
    } else {
        Cow::Borrowed(s)
    };
    let cfg = textprep::ScrubConfig {
        normalize_newlines: true,
        remove_zero_width: false,
        remove_bidi_controls: true,
        collapse_whitespace: true,
        normalization: textprep::ScrubNormalization::Nfc,
        case: textprep::ScrubCase::Lower,
        strip_diacritics: false,
    };
    textprep::scrub_with(cleaned.as_ref(), &cfg)
}
/// Jaccard similarity over character trigrams, in [0.0, 1.0].
///
/// Empty inputs score 0.0 (so `"" vs ""` is 0.0, matching the original
/// contract). Identical non-empty strings score 1.0 even when shorter than
/// three characters — the original returned 0.0 for short identical
/// strings like "Li" vs "Li". Otherwise, strings with fewer than three
/// characters cannot form a trigram and score 0.0.
fn trigram_similarity(a: &str, b: &str) -> f64 {
    if a.is_empty() || b.is_empty() {
        return 0.0;
    }
    if a == b {
        return 1.0;
    }
    if a.chars().count() < 3 || b.chars().count() < 3 {
        return 0.0;
    }
    textprep::similarity::trigram_jaccard(a, b)
}
/// Summary statistics for one incremental-resolution run.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct IncrementalStats {
    /// Number of windows the document was split into.
    pub windows_processed: usize,
    /// Total mentions still recorded across surviving clusters.
    pub mentions_processed: usize,
    /// Clusters alive at the end of the run.
    pub final_clusters: usize,
    /// NOTE(review): never incremented by `from_memory` — always 0 there.
    pub clusters_evicted: usize,
    /// mentions_processed / final_clusters (0.0 when no clusters).
    pub avg_mentions_per_cluster: f64,
    /// Size of the largest surviving cluster.
    pub max_cluster_size: usize,
}
impl IncrementalStats {
    /// Snapshot statistics from `memory` after processing `windows` windows.
    ///
    /// `clusters_evicted` is not tracked by [`EntityMemory`] and is
    /// reported as 0.
    pub fn from_memory(memory: &EntityMemory, windows: usize) -> Self {
        let clusters = memory.cluster_count();
        let mentions = memory.mention_count();
        let max_cluster_size = memory
            .clusters
            .values()
            .map(|c| c.mentions.len())
            .max()
            .unwrap_or(0);
        let avg_mentions_per_cluster = if clusters > 0 {
            mentions as f64 / clusters as f64
        } else {
            0.0
        };
        Self {
            windows_processed: windows,
            mentions_processed: mentions,
            final_clusters: clusters,
            clusters_evicted: 0,
            avg_mentions_per_cluster,
            max_cluster_size,
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    // Windowing and mention offsets must be character-based so multi-byte
    // text (emoji, CJK) round-trips through TextSpan::from_chars.
    #[test]
    fn test_windowing_and_mentions_use_character_offsets_on_unicode() {
        use anno::offset::TextSpan;
        let config = IncrementalConfig {
            token_based: false,
            window_size: 12,
            window_overlap: 3,
            ..Default::default()
        };
        let resolver = IncrementalCorefResolver::new(config);
        let text = "🎉 Dr. John went to 東京. He waved.";
        let windows = resolver.split_into_windows(text);
        assert!(!windows.is_empty());
        // Every (offset, text) pair must re-extract to the same text.
        for (window_text, window_offset) in &windows {
            let span = TextSpan::from_chars(
                text,
                *window_offset,
                *window_offset + window_text.chars().count(),
            );
            assert_eq!(span.extract(text), window_text);
        }
        let mentions = resolver.extract_mentions(text, 0);
        let john = mentions
            .iter()
            .find(|m| m.text == "John")
            .expect("expected to detect 'John'");
        let extracted = TextSpan::from_chars(text, john.start, john.end).extract(text);
        assert_eq!(extracted, "John");
        assert_eq!(
            john.start, 6,
            "🎉(1) + space(1) + Dr.(2) + .(1) + space(1) = 6"
        );
        assert_eq!(john.end, 10);
    }

    // Sanity checks on the trigram similarity wrapper.
    #[test]
    fn test_trigram_similarity() {
        assert!((trigram_similarity("hello", "hello") - 1.0).abs() < 0.001);
        assert!(trigram_similarity("hello", "world") < 0.3);
        assert!(trigram_similarity("john", "johnson") > 0.3);
        assert!(trigram_similarity("", "hello") == 0.0);
    }

    // Creating a cluster and appending a mention updates the counts.
    #[test]
    fn test_entity_memory_basic() {
        let mut memory = EntityMemory::new(MemoryPolicy::Unbounded);
        let mention1 = MentionRecord {
            text: "John Smith".to_string(),
            start: 0,
            end: 10,
            window_index: 0,
            mention_type: MentionType::Proper,
        };
        let cluster_id = memory.create_cluster(&mention1, Some(EntityType::Person));
        assert_eq!(memory.cluster_count(), 1);
        let mention2 = MentionRecord {
            text: "Smith".to_string(),
            start: 50,
            end: 55,
            window_index: 0,
            mention_type: MentionType::Proper,
        };
        memory.add_to_cluster(cluster_id, &mention2);
        assert_eq!(memory.mention_count(), 2);
    }

    // With max_clusters = 2, creating a third cluster evicts the least
    // recently used one.
    #[test]
    fn test_entity_memory_lru_eviction() {
        let mut memory = EntityMemory::new(MemoryPolicy::LeastRecentlyUsed { max_clusters: 2 });
        let mention1 = MentionRecord {
            text: "John".to_string(),
            start: 0,
            end: 4,
            window_index: 0,
            mention_type: MentionType::Proper,
        };
        memory.create_cluster(&mention1, None);
        memory.advance_window();
        let mention2 = MentionRecord {
            text: "Mary".to_string(),
            start: 10,
            end: 14,
            window_index: 1,
            mention_type: MentionType::Proper,
        };
        memory.create_cluster(&mention2, None);
        memory.advance_window();
        let mention3 = MentionRecord {
            text: "Bob".to_string(),
            start: 20,
            end: 23,
            window_index: 2,
            mention_type: MentionType::Proper,
        };
        memory.create_cluster(&mention3, None);
        assert_eq!(memory.cluster_count(), 2);
    }

    // A surname should match an existing full-name cluster via the
    // substring/last-token tiers.
    #[test]
    fn test_find_best_match() {
        let mut memory = EntityMemory::new(MemoryPolicy::Unbounded);
        let mention1 = MentionRecord {
            text: "John Smith".to_string(),
            start: 0,
            end: 10,
            window_index: 0,
            mention_type: MentionType::Proper,
        };
        memory.create_cluster(&mention1, Some(EntityType::Person));
        let match_result = memory.find_best_match("Smith", MentionType::Proper, 0.7, true, true);
        assert!(match_result.is_some());
    }

    // End-to-end: a short document should yield at least one chain.
    #[test]
    fn test_incremental_resolver_basic() {
        let config = IncrementalConfig {
            window_size: 100,
            window_overlap: 20,
            token_based: false,
            ..Default::default()
        };
        let resolver = IncrementalCorefResolver::new(config);
        let text = "John went to the store. He bought milk. John came home.";
        let chains = resolver.resolve_document(text);
        assert!(!chains.is_empty());
    }

    // Entities sharing a surname should resolve to the same canonical id.
    #[test]
    fn test_resolve_entities() {
        let resolver = IncrementalCorefResolver::default();
        let entities = vec![
            Entity::new("John Smith", EntityType::Person, 0, 10, 0.9),
            Entity::new("Smith", EntityType::Person, 50, 55, 0.85),
            Entity::new("he", EntityType::Person, 100, 102, 0.7),
        ];
        let resolved = resolver.resolve_entities(&entities);
        assert_eq!(resolved[0].canonical_id, resolved[1].canonical_id);
    }

    // Classifier covers standard and neopronouns plus capitalized names.
    #[test]
    fn test_mention_type_classification() {
        let resolver = IncrementalCorefResolver::default();
        assert_eq!(
            resolver.classify_mention_type("he"),
            MentionType::Pronominal
        );
        assert_eq!(
            resolver.classify_mention_type("she"),
            MentionType::Pronominal
        );
        assert_eq!(
            resolver.classify_mention_type("they"),
            MentionType::Pronominal
        );
        assert_eq!(
            resolver.classify_mention_type("xe"),
            MentionType::Pronominal
        );
        assert_eq!(resolver.classify_mention_type("John"), MentionType::Proper);
    }

    // Constructing a dual-cache memory starts empty.
    #[test]
    fn test_dual_cache_policy() {
        let memory = EntityMemory::new(MemoryPolicy::DualCache {
            l_cache_size: 5,
            g_cache_size: 10,
        });
        assert_eq!(memory.cluster_count(), 0);
    }
}