anno_core/core/coref.rs
1//! Coreference resolution data structures.
2//!
3//! Provides types for representing coreference chains (clusters of mentions
4//! that refer to the same entity) and utilities for working with them.
5//!
6//! # Terminology
7//!
8//! - **Mention**: A span of text referring to an entity (e.g., "John", "he", "the CEO")
9//! - **Chain/Cluster**: A set of mentions that corefer (refer to the same entity)
10//! - **Singleton**: A chain with only one mention (entity mentioned only once)
11//! - **Antecedent**: An earlier mention that a pronoun/noun phrase refers to
12//!
13//! # Example
14//!
15//! ```rust
16//! use anno_core::core::coref::{Mention, CorefChain, CorefDocument};
17//!
18//! // "John went to the store. He bought milk."
19//! let john = Mention::new("John", 0, 4);
//! let he = Mention::new("He", 24, 26);
21//!
22//! let chain = CorefChain::new(vec![john, he]);
23//! assert_eq!(chain.len(), 2);
24//! assert!(!chain.is_singleton());
25//! ```
26
27use super::Entity;
28use serde::{Deserialize, Serialize};
29use std::collections::{HashMap, HashSet};
30
31// Re-export MentionType for convenience
32pub use super::types::MentionType;
33
34// =============================================================================
35// Mention
36// =============================================================================
37
38/// A single mention (text span) that may corefer with other mentions.
39///
40/// Mentions are comparable by span position, not by text content.
41/// Two mentions with identical text at different positions are distinct.
42///
43/// # Character vs Byte Offsets
44///
45/// `start` and `end` are **character** offsets, not byte offsets.
46/// For "北京 Beijing", the character offsets are:
47/// - "北" = 0..1 (but 3 bytes in UTF-8)
48/// - "京" = 1..2 (but 3 bytes)
49/// - " " = 2..3
50/// - "Beijing" = 3..10
51///
52/// Use `text.chars().skip(start).take(end - start).collect()` to extract.
53///
54/// # Head Span
55///
56/// The `head_start`/`head_end` fields mark the syntactic head for head-match
57/// evaluation (used in CEAF-e, LEA metrics). In "the former president of France",
58/// the head is "president" - the noun that determines agreement.
59#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
60pub struct Mention {
61 /// The mention text (surface form).
62 pub text: String,
63 /// Start character offset (inclusive, 0-indexed).
64 pub start: usize,
65 /// End character offset (exclusive).
66 pub end: usize,
67 /// Head word start (for head-match metrics like CEAF).
68 pub head_start: Option<usize>,
69 /// Head word end.
70 pub head_end: Option<usize>,
71 /// Entity type if known (e.g., "PER", "ORG").
72 pub entity_type: Option<String>,
73 /// Mention category: Pronominal, Proper, Nominal, Zero.
74 pub mention_type: Option<MentionType>,
75}
76
77impl Mention {
78 /// `Mention::new("John", 0, 4)` creates a mention for "John" at characters 0..4.
79 ///
80 /// Offsets are character positions, not byte positions.
81 #[must_use]
82 pub fn new(text: impl Into<String>, start: usize, end: usize) -> Self {
83 Self {
84 text: text.into(),
85 start,
86 end,
87 head_start: None,
88 head_end: None,
89 entity_type: None,
90 mention_type: None,
91 }
92 }
93
94 /// Mention with head span for head-match evaluation.
95 ///
96 /// The head is the syntactic nucleus: in "the former president", head is "president".
97 ///
98 /// ```
99 /// # use anno_core::core::coref::Mention;
100 /// let m = Mention::with_head("the former president", 0, 20, 11, 20);
101 /// assert_eq!(m.head_start, Some(11)); // "president" starts at 11
102 /// ```
103 #[must_use]
104 pub fn with_head(
105 text: impl Into<String>,
106 start: usize,
107 end: usize,
108 head_start: usize,
109 head_end: usize,
110 ) -> Self {
111 Self {
112 text: text.into(),
113 start,
114 end,
115 head_start: Some(head_start),
116 head_end: Some(head_end),
117 entity_type: None,
118 mention_type: None,
119 }
120 }
121
122 /// Mention with type annotation for type-aware evaluation.
123 ///
124 /// ```
125 /// # use anno_core::core::coref::Mention;
126 /// # use anno_core::core::types::MentionType;
127 /// let pronoun = Mention::with_type("he", 25, 27, MentionType::Pronominal);
128 /// let proper = Mention::with_type("John Smith", 0, 10, MentionType::Proper);
129 /// ```
130 #[must_use]
131 pub fn with_type(
132 text: impl Into<String>,
133 start: usize,
134 end: usize,
135 mention_type: MentionType,
136 ) -> Self {
137 Self {
138 text: text.into(),
139 start,
140 end,
141 head_start: None,
142 head_end: None,
143 entity_type: None,
144 mention_type: Some(mention_type),
145 }
146 }
147
148 /// True if spans share any characters: `[0,5)` overlaps `[3,8)`.
149 #[must_use]
150 pub fn overlaps(&self, other: &Mention) -> bool {
151 self.start < other.end && other.start < self.end
152 }
153
154 /// True if spans are identical: same start AND end.
155 #[must_use]
156 pub fn span_matches(&self, other: &Mention) -> bool {
157 self.start == other.start && self.end == other.end
158 }
159
160 /// Span length in characters. Returns 0 if `end <= start`.
161 #[must_use]
162 pub fn len(&self) -> usize {
163 self.end.saturating_sub(self.start)
164 }
165
166 /// True if span has zero length.
167 #[must_use]
168 pub fn is_empty(&self) -> bool {
169 self.len() == 0
170 }
171
172 /// `(start, end)` tuple for use in hash sets and comparisons.
173 #[must_use]
174 pub fn span_id(&self) -> (usize, usize) {
175 (self.start, self.end)
176 }
177}
178
179impl std::fmt::Display for Mention {
180 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
181 write!(f, "\"{}\" [{}-{})", self.text, self.start, self.end)
182 }
183}
184
185// =============================================================================
186// CorefChain (Cluster)
187// =============================================================================
188
189/// A coreference chain: mentions that all refer to the same entity.
190///
191/// ```
192/// # use anno_core::core::coref::{CorefChain, Mention};
193/// // "John went to the store. He bought milk."
194/// // ^^^^ ^^
195/// let john = Mention::new("John", 0, 4);
196/// let he = Mention::new("He", 25, 27);
197///
198/// let chain = CorefChain::new(vec![john, he]);
199/// assert_eq!(chain.len(), 2);
200/// assert!(!chain.is_singleton());
201/// ```
202///
203/// # Note
204///
205/// This type is for **evaluation and intermediate processing**. For production pipelines,
206/// use [`Track`](super::grounded::Track) which integrates with the Signal/Track/Identity hierarchy.
207#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
208pub struct CorefChain {
209 /// Mentions in document order (sorted by start position).
210 pub mentions: Vec<Mention>,
211 /// Cluster ID from the source data, if any.
212 pub cluster_id: Option<super::types::CanonicalId>,
213 /// Entity type shared by all mentions (e.g., "PERSON").
214 pub entity_type: Option<String>,
215}
216
217impl CorefChain {
218 /// Build a chain from mentions. Sorts by position automatically.
219 ///
220 /// ```
221 /// # use anno_core::core::coref::{CorefChain, Mention};
222 /// let chain = CorefChain::new(vec![
223 /// Mention::new("she", 50, 53),
224 /// Mention::new("Dr. Smith", 0, 9), // out of order
225 /// ]);
226 /// assert_eq!(chain.mentions[0].text, "Dr. Smith"); // sorted
227 /// ```
228 #[must_use]
229 pub fn new(mut mentions: Vec<Mention>) -> Self {
230 mentions.sort_by_key(|m| (m.start, m.end));
231 Self {
232 mentions,
233 cluster_id: None,
234 entity_type: None,
235 }
236 }
237
238 /// Build a chain with an explicit cluster ID.
239 #[must_use]
240 pub fn with_id(
241 mut mentions: Vec<Mention>,
242 cluster_id: impl Into<super::types::CanonicalId>,
243 ) -> Self {
244 mentions.sort_by_key(|m| (m.start, m.end));
245 Self {
246 mentions,
247 cluster_id: Some(cluster_id.into()),
248 entity_type: None,
249 }
250 }
251
252 /// A chain with exactly one mention (entity mentioned only once).
253 #[must_use]
254 pub fn singleton(mention: Mention) -> Self {
255 Self {
256 mentions: vec![mention],
257 cluster_id: None,
258 entity_type: None,
259 }
260 }
261
262 /// Number of mentions. A chain with 3 mentions has 2 implicit "links".
263 #[must_use]
264 pub fn len(&self) -> usize {
265 self.mentions.len()
266 }
267
268 /// True if chain has no mentions. Shouldn't happen in valid data.
269 #[must_use]
270 pub fn is_empty(&self) -> bool {
271 self.mentions.is_empty()
272 }
273
274 /// True if chain has exactly one mention (singleton entity).
275 #[must_use]
276 pub fn is_singleton(&self) -> bool {
277 self.mentions.len() == 1
278 }
279
280 /// All pairwise links. For MUC: `n` mentions = `n*(n-1)/2` links.
281 ///
282 /// ```
283 /// # use anno_core::core::coref::{CorefChain, Mention};
284 /// let chain = CorefChain::new(vec![
285 /// Mention::new("A", 0, 1),
286 /// Mention::new("B", 2, 3),
287 /// Mention::new("C", 4, 5),
288 /// ]);
289 /// assert_eq!(chain.links().len(), 3); // A-B, A-C, B-C
290 /// ```
291 #[must_use]
292 pub fn links(&self) -> Vec<(&Mention, &Mention)> {
293 let mut links = Vec::new();
294 for i in 0..self.mentions.len() {
295 for j in (i + 1)..self.mentions.len() {
296 links.push((&self.mentions[i], &self.mentions[j]));
297 }
298 }
299 links
300 }
301
302 /// Number of coreference links.
303 ///
304 /// For a chain of n mentions: n*(n-1)/2 pairs, but only n-1 links needed
305 /// to connect all mentions (spanning tree).
306 #[must_use]
307 pub fn link_count(&self) -> usize {
308 if self.mentions.len() <= 1 {
309 0
310 } else {
311 self.mentions.len() - 1
312 }
313 }
314
315 /// Get all pairwise mention combinations (for B³, CEAF).
316 #[must_use]
317 pub fn all_pairs(&self) -> Vec<(&Mention, &Mention)> {
318 self.links() // Same as links for non-directed pairs
319 }
320
321 /// Check if chain contains a mention with given span.
322 #[must_use]
323 pub fn contains_span(&self, start: usize, end: usize) -> bool {
324 self.mentions
325 .iter()
326 .any(|m| m.start == start && m.end == end)
327 }
328
329 /// Get first mention (usually the most salient/representative).
330 #[must_use]
331 pub fn first(&self) -> Option<&Mention> {
332 self.mentions.first()
333 }
334
335 /// Get set of mention span IDs for set operations.
336 #[must_use]
337 pub fn mention_spans(&self) -> HashSet<(usize, usize)> {
338 self.mentions.iter().map(|m| m.span_id()).collect()
339 }
340
341 /// Get the canonical (representative) mention for this chain.
342 ///
343 /// Prefers proper nouns over other mention types, then longest mention.
344 /// Falls back to first mention if no proper noun exists.
345 #[must_use]
346 pub fn canonical_mention(&self) -> Option<&Mention> {
347 // Prefer proper noun mentions
348 let proper = self
349 .mentions
350 .iter()
351 .filter(|m| m.mention_type == Some(MentionType::Proper))
352 .max_by_key(|m| m.text.len());
353
354 if proper.is_some() {
355 return proper;
356 }
357
358 // Fall back to longest mention (likely most informative)
359 self.mentions.iter().max_by_key(|m| m.text.len())
360 }
361
362 /// Get the canonical ID for this chain (cluster_id if set).
363 #[must_use]
364 pub fn canonical_id(&self) -> Option<super::types::CanonicalId> {
365 self.cluster_id
366 }
367}
368
369impl std::fmt::Display for CorefChain {
370 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
371 let mentions: Vec<String> = self
372 .mentions
373 .iter()
374 .map(|m| format!("\"{}\"", m.text))
375 .collect();
376 write!(f, "[{}]", mentions.join(", "))
377 }
378}
379
380// =============================================================================
381// CorefDocument
382// =============================================================================
383
384/// A document with coreference annotations.
385///
386/// Contains the source text and all coreference chains.
387#[derive(Debug, Clone, Serialize, Deserialize)]
388pub struct CorefDocument {
389 /// Document text.
390 pub text: String,
391 /// Document identifier.
392 pub doc_id: Option<String>,
393 /// Coreference chains (clusters).
394 pub chains: Vec<CorefChain>,
395 /// Whether singletons are included.
396 pub includes_singletons: bool,
397}
398
399impl CorefDocument {
400 /// Create a new document with chains.
401 #[must_use]
402 pub fn new(text: impl Into<String>, chains: Vec<CorefChain>) -> Self {
403 Self {
404 text: text.into(),
405 doc_id: None,
406 chains,
407 includes_singletons: false,
408 }
409 }
410
411 /// Create document with ID.
412 #[must_use]
413 pub fn with_id(
414 text: impl Into<String>,
415 doc_id: impl Into<String>,
416 chains: Vec<CorefChain>,
417 ) -> Self {
418 Self {
419 text: text.into(),
420 doc_id: Some(doc_id.into()),
421 chains,
422 includes_singletons: false,
423 }
424 }
425
426 /// Total number of mentions across all chains.
427 #[must_use]
428 pub fn mention_count(&self) -> usize {
429 self.chains.iter().map(|c| c.len()).sum()
430 }
431
432 /// Number of chains (clusters).
433 #[must_use]
434 pub fn chain_count(&self) -> usize {
435 self.chains.len()
436 }
437
438 /// Number of non-singleton chains.
439 #[must_use]
440 pub fn non_singleton_count(&self) -> usize {
441 self.chains.iter().filter(|c| !c.is_singleton()).count()
442 }
443
444 /// Get all mentions in document order.
445 #[must_use]
446 pub fn all_mentions(&self) -> Vec<&Mention> {
447 let mut mentions: Vec<&Mention> = self.chains.iter().flat_map(|c| &c.mentions).collect();
448 mentions.sort_by_key(|m| (m.start, m.end));
449 mentions
450 }
451
452 /// Find which chain contains a mention span.
453 #[must_use]
454 pub fn find_chain(&self, start: usize, end: usize) -> Option<&CorefChain> {
455 self.chains.iter().find(|c| c.contains_span(start, end))
456 }
457
458 /// Build mention-to-chain index for fast lookup.
459 #[must_use]
460 pub fn mention_to_chain_index(&self) -> HashMap<(usize, usize), usize> {
461 let mut index = HashMap::new();
462 for (chain_idx, chain) in self.chains.iter().enumerate() {
463 for mention in &chain.mentions {
464 index.insert(mention.span_id(), chain_idx);
465 }
466 }
467 index
468 }
469
470 /// Filter to only non-singleton chains.
471 #[must_use]
472 pub fn without_singletons(&self) -> Self {
473 Self {
474 text: self.text.clone(),
475 doc_id: self.doc_id.clone(),
476 chains: self
477 .chains
478 .iter()
479 .filter(|c| !c.is_singleton())
480 .cloned()
481 .collect(),
482 includes_singletons: false,
483 }
484 }
485}
486
487// =============================================================================
488// Conversion from Entity to Mention
489// =============================================================================
490
491impl From<&Entity> for Mention {
492 fn from(entity: &Entity) -> Self {
493 Self {
494 text: entity.text.clone(),
495 start: entity.start,
496 end: entity.end,
497 head_start: None,
498 head_end: None,
499 entity_type: Some(entity.entity_type.as_label().to_string()),
500 mention_type: None,
501 }
502 }
503}
504
505/// Convert entities with canonical_id to coreference chains.
506///
507/// Entities sharing the same `canonical_id` are grouped into a chain.
508#[must_use]
509pub fn entities_to_chains(entities: &[Entity]) -> Vec<CorefChain> {
510 let mut clusters: HashMap<u64, Vec<Mention>> = HashMap::new();
511 let mut singletons: Vec<Mention> = Vec::new();
512
513 for entity in entities {
514 let mention = Mention::from(entity);
515 if let Some(canonical_id) = entity.canonical_id {
516 clusters
517 .entry(canonical_id.get())
518 .or_default()
519 .push(mention);
520 } else {
521 singletons.push(mention);
522 }
523 }
524
525 let mut chains: Vec<CorefChain> = clusters
526 .into_iter()
527 .map(|(id, mentions)| CorefChain::with_id(mentions, id))
528 .collect();
529
530 // Add singletons as individual chains
531 for mention in singletons {
532 chains.push(CorefChain::singleton(mention));
533 }
534
535 chains
536}
537
538// =============================================================================
539// CoreferenceResolver Trait
540// =============================================================================
541
542/// Trait for coreference resolution algorithms.
543///
544/// Implementors take a set of entity mentions and cluster them into
545/// coreference chains (groups of mentions referring to the same entity).
546///
547/// # Design Philosophy
548///
549/// This trait lives in `anno::core` because:
550/// 1. It depends only on core types (`Entity`, `CorefChain`)
551/// 2. Multiple crates need to implement it (backends, eval)
552/// 3. Keeping it here prevents circular dependencies
553///
554/// # Relationship to the Grounded Pipeline
555///
556/// `CoreferenceResolver` operates on the **evaluation/convenience layer** (`Entity`),
557/// not the canonical **grounded pipeline** (`Signal` → `Track` → `Identity`).
558///
559/// | Layer | Type | `CoreferenceResolver` role |
560/// |-------|------|----------------------------|
561/// | Detection (L1) | `Entity` | Input: mentions to cluster |
562/// | Coref (L2) | `Entity.canonical_id` | Output: cluster assignment |
563/// | Linking (L3) | `Identity` | (not covered by this trait) |
564///
565/// For integration with `GroundedDocument`, use backends that produce
566/// `Signal` + `Track` directly (e.g., `anno::backends::MentionRankingCoref`).
567///
568/// # Example Implementation
569///
570/// ```rust,ignore
571/// use anno_core::{CoreferenceResolver, Entity, CorefChain};
572///
573/// struct ExactMatchResolver;
574///
575/// impl CoreferenceResolver for ExactMatchResolver {
576/// fn resolve(&self, entities: &[Entity]) -> Vec<Entity> {
577/// // Cluster entities with identical text
578/// // ... implementation ...
579/// }
580///
581/// fn name(&self) -> &'static str {
582/// "exact-match"
583/// }
584/// }
585/// ```
586pub trait CoreferenceResolver: Send + Sync {
587 /// Resolve coreference, assigning canonical IDs to entities.
588 ///
589 /// Each entity in the output will have a `canonical_id` field set.
590 /// Entities with the same `canonical_id` are coreferent (refer to the
591 /// same real-world entity).
592 ///
593 /// # Invariants
594 ///
595 /// - Every output entity has `canonical_id.is_some()`
596 /// - Coreferent entities share the same `canonical_id`
597 /// - Singleton mentions get unique `canonical_id` values
598 fn resolve(&self, entities: &[Entity]) -> Vec<Entity>;
599
600 /// Resolve directly to chains.
601 ///
602 /// A chain groups all mentions of the same entity together.
603 /// This is often the desired output format for evaluation and
604 /// downstream tasks.
605 fn resolve_to_chains(&self, entities: &[Entity]) -> Vec<CorefChain> {
606 let resolved = self.resolve(entities);
607 entities_to_chains(&resolved)
608 }
609
610 /// Get resolver name.
611 ///
612 /// Used for logging, metrics, and result attribution.
613 fn name(&self) -> &'static str;
614}
615
616// =============================================================================
617// Tests
618// =============================================================================
619
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a chain from `(text, start, end)` triples.
    fn chain_of(spans: &[(&str, usize, usize)]) -> CorefChain {
        CorefChain::new(
            spans
                .iter()
                .map(|&(text, start, end)| Mention::new(text, start, end))
                .collect(),
        )
    }

    #[test]
    fn test_mention_creation() {
        let john = Mention::new("John", 0, 4);
        assert_eq!(
            (john.text.as_str(), john.start, john.end, john.len()),
            ("John", 0, 4, 4)
        );
    }

    #[test]
    fn test_mention_overlap() {
        let full_name = Mention::new("John Smith", 0, 10);
        let surname = Mention::new("Smith", 5, 10);
        let verb = Mention::new("works", 11, 16);

        assert!(full_name.overlaps(&surname));
        assert!(!full_name.overlaps(&verb));
        assert!(!surname.overlaps(&verb));
    }

    #[test]
    fn test_chain_creation() {
        let chain = chain_of(&[("John", 0, 4), ("he", 20, 22), ("him", 40, 43)]);

        assert_eq!(chain.len(), 3);
        assert!(!chain.is_singleton());
        // A spanning tree over three mentions needs two links.
        assert_eq!(chain.link_count(), 2);
    }

    #[test]
    fn test_chain_links() {
        let chain = chain_of(&[("a", 0, 1), ("b", 2, 3), ("c", 4, 5)]);

        // Unordered pairs: (a,b), (a,c), (b,c)
        assert_eq!(chain.all_pairs().len(), 3);
    }

    #[test]
    fn test_singleton_chain() {
        let chain = CorefChain::singleton(Mention::new("entity", 0, 6));

        assert!(chain.is_singleton());
        assert_eq!(chain.link_count(), 0);
        assert!(chain.all_pairs().is_empty());
    }

    #[test]
    fn test_document() {
        let text = "John went to the store. He bought milk.";
        let chain = chain_of(&[("John", 0, 4), ("He", 24, 26)]);
        let doc = CorefDocument::new(text, vec![chain]);

        assert_eq!(doc.mention_count(), 2);
        assert_eq!(doc.chain_count(), 1);
        assert_eq!(doc.non_singleton_count(), 1);
    }

    #[test]
    fn test_mention_to_chain_index() {
        let john_chain = chain_of(&[("John", 0, 4), ("he", 20, 22)]);
        let mary_chain = chain_of(&[("Mary", 5, 9), ("she", 30, 33)]);
        let doc = CorefDocument::new("text", vec![john_chain, mary_chain]);

        let index = doc.mention_to_chain_index();
        let expected = [((0, 4), 0), ((20, 22), 0), ((5, 9), 1), ((30, 33), 1)];
        for &(span, chain_idx) in &expected {
            assert_eq!(index.get(&span), Some(&chain_idx), "span {:?}", span);
        }
    }
}