Skip to main content

anno_core/core/
coref.rs

//! Coreference resolution data structures.
//!
//! Provides types for representing coreference chains (clusters of mentions
//! that refer to the same entity) and utilities for working with them.
//!
//! # Terminology
//!
//! - **Mention**: A span of text referring to an entity (e.g., "John", "he", "the CEO")
//! - **Chain/Cluster**: A set of mentions that corefer (refer to the same entity)
//! - **Singleton**: A chain with only one mention (entity mentioned only once)
//! - **Antecedent**: An earlier mention that a pronoun/noun phrase refers to
//!
//! # Example
//!
//! ```rust
//! use anno_core::core::coref::{Mention, CorefChain, CorefDocument};
//!
//! // "John went to the store. He bought milk."
//! let john = Mention::new("John", 0, 4);
//! let he = Mention::new("He", 24, 26);
//!
//! let chain = CorefChain::new(vec![john, he]);
//! assert_eq!(chain.len(), 2);
//! assert!(!chain.is_singleton());
//! ```
26
27use super::Entity;
28use serde::{Deserialize, Serialize};
29use std::collections::{HashMap, HashSet};
30
31// Re-export MentionType for convenience
32pub use super::types::MentionType;
33
34// =============================================================================
35// Mention
36// =============================================================================
37
38/// A single mention (text span) that may corefer with other mentions.
39///
40/// Mentions are comparable by span position, not by text content.
41/// Two mentions with identical text at different positions are distinct.
42///
43/// # Character vs Byte Offsets
44///
45/// `start` and `end` are **character** offsets, not byte offsets.
46/// For "北京 Beijing", the character offsets are:
47/// - "北" = 0..1 (but 3 bytes in UTF-8)
48/// - "京" = 1..2 (but 3 bytes)
49/// - " " = 2..3
50/// - "Beijing" = 3..10
51///
52/// Use `text.chars().skip(start).take(end - start).collect()` to extract.
53///
54/// # Head Span
55///
56/// The `head_start`/`head_end` fields mark the syntactic head for head-match
57/// evaluation (used in CEAF-e, LEA metrics). In "the former president of France",
58/// the head is "president" - the noun that determines agreement.
59#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
60pub struct Mention {
61    /// The mention text (surface form).
62    pub text: String,
63    /// Start character offset (inclusive, 0-indexed).
64    pub start: usize,
65    /// End character offset (exclusive).
66    pub end: usize,
67    /// Head word start (for head-match metrics like CEAF).
68    pub head_start: Option<usize>,
69    /// Head word end.
70    pub head_end: Option<usize>,
71    /// Entity type if known (e.g., "PER", "ORG").
72    pub entity_type: Option<String>,
73    /// Mention category: Pronominal, Proper, Nominal, Zero.
74    pub mention_type: Option<MentionType>,
75}
76
77impl Mention {
78    /// `Mention::new("John", 0, 4)` creates a mention for "John" at characters 0..4.
79    ///
80    /// Offsets are character positions, not byte positions.
81    #[must_use]
82    pub fn new(text: impl Into<String>, start: usize, end: usize) -> Self {
83        Self {
84            text: text.into(),
85            start,
86            end,
87            head_start: None,
88            head_end: None,
89            entity_type: None,
90            mention_type: None,
91        }
92    }
93
94    /// Mention with head span for head-match evaluation.
95    ///
96    /// The head is the syntactic nucleus: in "the former president", head is "president".
97    ///
98    /// ```
99    /// # use anno_core::core::coref::Mention;
100    /// let m = Mention::with_head("the former president", 0, 20, 11, 20);
101    /// assert_eq!(m.head_start, Some(11)); // "president" starts at 11
102    /// ```
103    #[must_use]
104    pub fn with_head(
105        text: impl Into<String>,
106        start: usize,
107        end: usize,
108        head_start: usize,
109        head_end: usize,
110    ) -> Self {
111        Self {
112            text: text.into(),
113            start,
114            end,
115            head_start: Some(head_start),
116            head_end: Some(head_end),
117            entity_type: None,
118            mention_type: None,
119        }
120    }
121
122    /// Mention with type annotation for type-aware evaluation.
123    ///
124    /// ```
125    /// # use anno_core::core::coref::Mention;
126    /// # use anno_core::core::types::MentionType;
127    /// let pronoun = Mention::with_type("he", 25, 27, MentionType::Pronominal);
128    /// let proper = Mention::with_type("John Smith", 0, 10, MentionType::Proper);
129    /// ```
130    #[must_use]
131    pub fn with_type(
132        text: impl Into<String>,
133        start: usize,
134        end: usize,
135        mention_type: MentionType,
136    ) -> Self {
137        Self {
138            text: text.into(),
139            start,
140            end,
141            head_start: None,
142            head_end: None,
143            entity_type: None,
144            mention_type: Some(mention_type),
145        }
146    }
147
148    /// True if spans share any characters: `[0,5)` overlaps `[3,8)`.
149    #[must_use]
150    pub fn overlaps(&self, other: &Mention) -> bool {
151        self.start < other.end && other.start < self.end
152    }
153
154    /// True if spans are identical: same start AND end.
155    #[must_use]
156    pub fn span_matches(&self, other: &Mention) -> bool {
157        self.start == other.start && self.end == other.end
158    }
159
160    /// Span length in characters. Returns 0 if `end <= start`.
161    #[must_use]
162    pub fn len(&self) -> usize {
163        self.end.saturating_sub(self.start)
164    }
165
166    /// True if span has zero length.
167    #[must_use]
168    pub fn is_empty(&self) -> bool {
169        self.len() == 0
170    }
171
172    /// `(start, end)` tuple for use in hash sets and comparisons.
173    #[must_use]
174    pub fn span_id(&self) -> (usize, usize) {
175        (self.start, self.end)
176    }
177}
178
179impl std::fmt::Display for Mention {
180    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
181        write!(f, "\"{}\" [{}-{})", self.text, self.start, self.end)
182    }
183}
184
185// =============================================================================
186// CorefChain (Cluster)
187// =============================================================================
188
189/// A coreference chain: mentions that all refer to the same entity.
190///
191/// ```
192/// # use anno_core::core::coref::{CorefChain, Mention};
193/// // "John went to the store. He bought milk."
194/// //  ^^^^                    ^^
195/// let john = Mention::new("John", 0, 4);
196/// let he = Mention::new("He", 25, 27);
197///
198/// let chain = CorefChain::new(vec![john, he]);
199/// assert_eq!(chain.len(), 2);
200/// assert!(!chain.is_singleton());
201/// ```
202///
203/// # Note
204///
205/// This type is for **evaluation and intermediate processing**. For production pipelines,
206/// use [`Track`](super::grounded::Track) which integrates with the Signal/Track/Identity hierarchy.
207#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
208pub struct CorefChain {
209    /// Mentions in document order (sorted by start position).
210    pub mentions: Vec<Mention>,
211    /// Cluster ID from the source data, if any.
212    pub cluster_id: Option<super::types::CanonicalId>,
213    /// Entity type shared by all mentions (e.g., "PERSON").
214    pub entity_type: Option<String>,
215}
216
217impl CorefChain {
218    /// Build a chain from mentions. Sorts by position automatically.
219    ///
220    /// ```
221    /// # use anno_core::core::coref::{CorefChain, Mention};
222    /// let chain = CorefChain::new(vec![
223    ///     Mention::new("she", 50, 53),
224    ///     Mention::new("Dr. Smith", 0, 9),  // out of order
225    /// ]);
226    /// assert_eq!(chain.mentions[0].text, "Dr. Smith"); // sorted
227    /// ```
228    #[must_use]
229    pub fn new(mut mentions: Vec<Mention>) -> Self {
230        mentions.sort_by_key(|m| (m.start, m.end));
231        Self {
232            mentions,
233            cluster_id: None,
234            entity_type: None,
235        }
236    }
237
238    /// Build a chain with an explicit cluster ID.
239    #[must_use]
240    pub fn with_id(
241        mut mentions: Vec<Mention>,
242        cluster_id: impl Into<super::types::CanonicalId>,
243    ) -> Self {
244        mentions.sort_by_key(|m| (m.start, m.end));
245        Self {
246            mentions,
247            cluster_id: Some(cluster_id.into()),
248            entity_type: None,
249        }
250    }
251
252    /// A chain with exactly one mention (entity mentioned only once).
253    #[must_use]
254    pub fn singleton(mention: Mention) -> Self {
255        Self {
256            mentions: vec![mention],
257            cluster_id: None,
258            entity_type: None,
259        }
260    }
261
262    /// Number of mentions. A chain with 3 mentions has 2 implicit "links".
263    #[must_use]
264    pub fn len(&self) -> usize {
265        self.mentions.len()
266    }
267
268    /// True if chain has no mentions. Shouldn't happen in valid data.
269    #[must_use]
270    pub fn is_empty(&self) -> bool {
271        self.mentions.is_empty()
272    }
273
274    /// True if chain has exactly one mention (singleton entity).
275    #[must_use]
276    pub fn is_singleton(&self) -> bool {
277        self.mentions.len() == 1
278    }
279
280    /// All pairwise links. For MUC: `n` mentions = `n*(n-1)/2` links.
281    ///
282    /// ```
283    /// # use anno_core::core::coref::{CorefChain, Mention};
284    /// let chain = CorefChain::new(vec![
285    ///     Mention::new("A", 0, 1),
286    ///     Mention::new("B", 2, 3),
287    ///     Mention::new("C", 4, 5),
288    /// ]);
289    /// assert_eq!(chain.links().len(), 3); // A-B, A-C, B-C
290    /// ```
291    #[must_use]
292    pub fn links(&self) -> Vec<(&Mention, &Mention)> {
293        let mut links = Vec::new();
294        for i in 0..self.mentions.len() {
295            for j in (i + 1)..self.mentions.len() {
296                links.push((&self.mentions[i], &self.mentions[j]));
297            }
298        }
299        links
300    }
301
302    /// Number of coreference links.
303    ///
304    /// For a chain of n mentions: n*(n-1)/2 pairs, but only n-1 links needed
305    /// to connect all mentions (spanning tree).
306    #[must_use]
307    pub fn link_count(&self) -> usize {
308        if self.mentions.len() <= 1 {
309            0
310        } else {
311            self.mentions.len() - 1
312        }
313    }
314
315    /// Get all pairwise mention combinations (for B³, CEAF).
316    #[must_use]
317    pub fn all_pairs(&self) -> Vec<(&Mention, &Mention)> {
318        self.links() // Same as links for non-directed pairs
319    }
320
321    /// Check if chain contains a mention with given span.
322    #[must_use]
323    pub fn contains_span(&self, start: usize, end: usize) -> bool {
324        self.mentions
325            .iter()
326            .any(|m| m.start == start && m.end == end)
327    }
328
329    /// Get first mention (usually the most salient/representative).
330    #[must_use]
331    pub fn first(&self) -> Option<&Mention> {
332        self.mentions.first()
333    }
334
335    /// Get set of mention span IDs for set operations.
336    #[must_use]
337    pub fn mention_spans(&self) -> HashSet<(usize, usize)> {
338        self.mentions.iter().map(|m| m.span_id()).collect()
339    }
340
341    /// Get the canonical (representative) mention for this chain.
342    ///
343    /// Prefers proper nouns over other mention types, then longest mention.
344    /// Falls back to first mention if no proper noun exists.
345    #[must_use]
346    pub fn canonical_mention(&self) -> Option<&Mention> {
347        // Prefer proper noun mentions
348        let proper = self
349            .mentions
350            .iter()
351            .filter(|m| m.mention_type == Some(MentionType::Proper))
352            .max_by_key(|m| m.text.len());
353
354        if proper.is_some() {
355            return proper;
356        }
357
358        // Fall back to longest mention (likely most informative)
359        self.mentions.iter().max_by_key(|m| m.text.len())
360    }
361
362    /// Get the canonical ID for this chain (cluster_id if set).
363    #[must_use]
364    pub fn canonical_id(&self) -> Option<super::types::CanonicalId> {
365        self.cluster_id
366    }
367}
368
369impl std::fmt::Display for CorefChain {
370    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
371        let mentions: Vec<String> = self
372            .mentions
373            .iter()
374            .map(|m| format!("\"{}\"", m.text))
375            .collect();
376        write!(f, "[{}]", mentions.join(", "))
377    }
378}
379
380// =============================================================================
381// CorefDocument
382// =============================================================================
383
384/// A document with coreference annotations.
385///
386/// Contains the source text and all coreference chains.
387#[derive(Debug, Clone, Serialize, Deserialize)]
388pub struct CorefDocument {
389    /// Document text.
390    pub text: String,
391    /// Document identifier.
392    pub doc_id: Option<String>,
393    /// Coreference chains (clusters).
394    pub chains: Vec<CorefChain>,
395    /// Whether singletons are included.
396    pub includes_singletons: bool,
397}
398
399impl CorefDocument {
400    /// Create a new document with chains.
401    #[must_use]
402    pub fn new(text: impl Into<String>, chains: Vec<CorefChain>) -> Self {
403        Self {
404            text: text.into(),
405            doc_id: None,
406            chains,
407            includes_singletons: false,
408        }
409    }
410
411    /// Create document with ID.
412    #[must_use]
413    pub fn with_id(
414        text: impl Into<String>,
415        doc_id: impl Into<String>,
416        chains: Vec<CorefChain>,
417    ) -> Self {
418        Self {
419            text: text.into(),
420            doc_id: Some(doc_id.into()),
421            chains,
422            includes_singletons: false,
423        }
424    }
425
426    /// Total number of mentions across all chains.
427    #[must_use]
428    pub fn mention_count(&self) -> usize {
429        self.chains.iter().map(|c| c.len()).sum()
430    }
431
432    /// Number of chains (clusters).
433    #[must_use]
434    pub fn chain_count(&self) -> usize {
435        self.chains.len()
436    }
437
438    /// Number of non-singleton chains.
439    #[must_use]
440    pub fn non_singleton_count(&self) -> usize {
441        self.chains.iter().filter(|c| !c.is_singleton()).count()
442    }
443
444    /// Get all mentions in document order.
445    #[must_use]
446    pub fn all_mentions(&self) -> Vec<&Mention> {
447        let mut mentions: Vec<&Mention> = self.chains.iter().flat_map(|c| &c.mentions).collect();
448        mentions.sort_by_key(|m| (m.start, m.end));
449        mentions
450    }
451
452    /// Find which chain contains a mention span.
453    #[must_use]
454    pub fn find_chain(&self, start: usize, end: usize) -> Option<&CorefChain> {
455        self.chains.iter().find(|c| c.contains_span(start, end))
456    }
457
458    /// Build mention-to-chain index for fast lookup.
459    #[must_use]
460    pub fn mention_to_chain_index(&self) -> HashMap<(usize, usize), usize> {
461        let mut index = HashMap::new();
462        for (chain_idx, chain) in self.chains.iter().enumerate() {
463            for mention in &chain.mentions {
464                index.insert(mention.span_id(), chain_idx);
465            }
466        }
467        index
468    }
469
470    /// Filter to only non-singleton chains.
471    #[must_use]
472    pub fn without_singletons(&self) -> Self {
473        Self {
474            text: self.text.clone(),
475            doc_id: self.doc_id.clone(),
476            chains: self
477                .chains
478                .iter()
479                .filter(|c| !c.is_singleton())
480                .cloned()
481                .collect(),
482            includes_singletons: false,
483        }
484    }
485}
486
487// =============================================================================
488// Conversion from Entity to Mention
489// =============================================================================
490
491impl From<&Entity> for Mention {
492    fn from(entity: &Entity) -> Self {
493        Self {
494            text: entity.text.clone(),
495            start: entity.start,
496            end: entity.end,
497            head_start: None,
498            head_end: None,
499            entity_type: Some(entity.entity_type.as_label().to_string()),
500            mention_type: None,
501        }
502    }
503}
504
505/// Convert entities with canonical_id to coreference chains.
506///
507/// Entities sharing the same `canonical_id` are grouped into a chain.
508#[must_use]
509pub fn entities_to_chains(entities: &[Entity]) -> Vec<CorefChain> {
510    let mut clusters: HashMap<u64, Vec<Mention>> = HashMap::new();
511    let mut singletons: Vec<Mention> = Vec::new();
512
513    for entity in entities {
514        let mention = Mention::from(entity);
515        if let Some(canonical_id) = entity.canonical_id {
516            clusters
517                .entry(canonical_id.get())
518                .or_default()
519                .push(mention);
520        } else {
521            singletons.push(mention);
522        }
523    }
524
525    let mut chains: Vec<CorefChain> = clusters
526        .into_iter()
527        .map(|(id, mentions)| CorefChain::with_id(mentions, id))
528        .collect();
529
530    // Add singletons as individual chains
531    for mention in singletons {
532        chains.push(CorefChain::singleton(mention));
533    }
534
535    chains
536}
537
538// =============================================================================
539// CoreferenceResolver Trait
540// =============================================================================
541
542/// Trait for coreference resolution algorithms.
543///
544/// Implementors take a set of entity mentions and cluster them into
545/// coreference chains (groups of mentions referring to the same entity).
546///
547/// # Design Philosophy
548///
549/// This trait lives in `anno::core` because:
550/// 1. It depends only on core types (`Entity`, `CorefChain`)
551/// 2. Multiple crates need to implement it (backends, eval)
552/// 3. Keeping it here prevents circular dependencies
553///
554/// # Relationship to the Grounded Pipeline
555///
556/// `CoreferenceResolver` operates on the **evaluation/convenience layer** (`Entity`),
557/// not the canonical **grounded pipeline** (`Signal` → `Track` → `Identity`).
558///
559/// | Layer | Type | `CoreferenceResolver` role |
560/// |-------|------|----------------------------|
561/// | Detection (L1) | `Entity` | Input: mentions to cluster |
562/// | Coref (L2) | `Entity.canonical_id` | Output: cluster assignment |
563/// | Linking (L3) | `Identity` | (not covered by this trait) |
564///
565/// For integration with `GroundedDocument`, use backends that produce
566/// `Signal` + `Track` directly (e.g., `anno::backends::MentionRankingCoref`).
567///
568/// # Example Implementation
569///
570/// ```rust,ignore
571/// use anno_core::{CoreferenceResolver, Entity, CorefChain};
572///
573/// struct ExactMatchResolver;
574///
575/// impl CoreferenceResolver for ExactMatchResolver {
576///     fn resolve(&self, entities: &[Entity]) -> Vec<Entity> {
577///         // Cluster entities with identical text
578///         // ... implementation ...
579///     }
580///
581///     fn name(&self) -> &'static str {
582///         "exact-match"
583///     }
584/// }
585/// ```
586pub trait CoreferenceResolver: Send + Sync {
587    /// Resolve coreference, assigning canonical IDs to entities.
588    ///
589    /// Each entity in the output will have a `canonical_id` field set.
590    /// Entities with the same `canonical_id` are coreferent (refer to the
591    /// same real-world entity).
592    ///
593    /// # Invariants
594    ///
595    /// - Every output entity has `canonical_id.is_some()`
596    /// - Coreferent entities share the same `canonical_id`
597    /// - Singleton mentions get unique `canonical_id` values
598    fn resolve(&self, entities: &[Entity]) -> Vec<Entity>;
599
600    /// Resolve directly to chains.
601    ///
602    /// A chain groups all mentions of the same entity together.
603    /// This is often the desired output format for evaluation and
604    /// downstream tasks.
605    fn resolve_to_chains(&self, entities: &[Entity]) -> Vec<CorefChain> {
606        let resolved = self.resolve(entities);
607        entities_to_chains(&resolved)
608    }
609
610    /// Get resolver name.
611    ///
612    /// Used for logging, metrics, and result attribution.
613    fn name(&self) -> &'static str;
614}
615
616// =============================================================================
617// Tests
618// =============================================================================
619
#[cfg(test)]
mod tests {
    use super::*;

    /// The plain constructor stores text and offsets; `len` is the span width.
    #[test]
    fn test_mention_creation() {
        let john = Mention::new("John", 0, 4);
        assert_eq!(john.text, "John");
        assert_eq!(john.start, 0);
        assert_eq!(john.end, 4);
        assert_eq!(john.len(), 4);
    }

    /// Nested spans overlap; spans separated by a gap do not.
    #[test]
    fn test_mention_overlap() {
        let full_name = Mention::new("John Smith", 0, 10);
        let surname = Mention::new("Smith", 5, 10);
        let verb = Mention::new("works", 11, 16);

        assert!(full_name.overlaps(&surname));
        assert!(!full_name.overlaps(&verb));
        assert!(!surname.overlaps(&verb));
    }

    /// A three-mention chain is not a singleton and needs two connecting links.
    #[test]
    fn test_chain_creation() {
        let chain = CorefChain::new(vec![
            Mention::new("John", 0, 4),
            Mention::new("he", 20, 22),
            Mention::new("him", 40, 43),
        ]);

        assert_eq!(chain.len(), 3);
        assert!(!chain.is_singleton());
        // Minimum links to connect
        assert_eq!(chain.link_count(), 2);
    }

    /// Three mentions produce three unordered pairs: (a,b), (a,c), (b,c).
    #[test]
    fn test_chain_links() {
        let chain = CorefChain::new(vec![
            Mention::new("a", 0, 1),
            Mention::new("b", 2, 3),
            Mention::new("c", 4, 5),
        ]);

        let pairs = chain.all_pairs();
        assert_eq!(pairs.len(), 3);
    }

    /// A singleton chain has no links and no pairs.
    #[test]
    fn test_singleton_chain() {
        let chain = CorefChain::singleton(Mention::new("entity", 0, 6));

        assert!(chain.is_singleton());
        assert_eq!(chain.link_count(), 0);
        assert!(chain.all_pairs().is_empty());
    }

    /// Document-level counters reflect the single two-mention chain.
    #[test]
    fn test_document() {
        let john = Mention::new("John", 0, 4);
        let he = Mention::new("He", 24, 26);
        let doc = CorefDocument::new(
            "John went to the store. He bought milk.",
            vec![CorefChain::new(vec![john, he])],
        );

        assert_eq!(doc.mention_count(), 2);
        assert_eq!(doc.chain_count(), 1);
        assert_eq!(doc.non_singleton_count(), 1);
    }

    /// Every span maps to the index of the chain that owns it.
    #[test]
    fn test_mention_to_chain_index() {
        let john_chain =
            CorefChain::new(vec![Mention::new("John", 0, 4), Mention::new("he", 20, 22)]);
        let mary_chain = CorefChain::new(vec![
            Mention::new("Mary", 5, 9),
            Mention::new("she", 30, 33),
        ]);
        let doc = CorefDocument::new("text", vec![john_chain, mary_chain]);

        let index = doc.mention_to_chain_index();
        assert_eq!(index.get(&(0, 4)), Some(&0));
        assert_eq!(index.get(&(20, 22)), Some(&0));
        assert_eq!(index.get(&(5, 9)), Some(&1));
        assert_eq!(index.get(&(30, 33)), Some(&1));
    }
}