Skip to main content

anno_core/core/
coref.rs

//! Coreference resolution data structures.
//!
//! Provides types for representing coreference chains (clusters of mentions
//! that refer to the same entity) and utilities for working with them.
//!
//! # Terminology
//!
//! - **Mention**: A span of text referring to an entity (e.g., "John", "he", "the CEO")
//! - **Chain/Cluster**: A set of mentions that corefer (refer to the same entity)
//! - **Singleton**: A chain with only one mention (entity mentioned only once)
//! - **Antecedent**: An earlier mention that a pronoun/noun phrase refers to
//!
//! # Example
//!
//! ```rust
//! use anno_core::core::coref::{Mention, CorefChain, CorefDocument};
//!
//! // "John went to the store. He bought milk."
//! let john = Mention::new("John", 0, 4);
//! let he = Mention::new("He", 24, 26);
//!
//! let chain = CorefChain::new(vec![john, he]);
//! assert_eq!(chain.len(), 2);
//! assert!(!chain.is_singleton());
//! ```
26
27use super::Entity;
28use serde::{Deserialize, Serialize};
29use std::collections::{HashMap, HashSet};
30
31// Re-export MentionType for convenience
32pub use super::types::MentionType;
33
// =============================================================================
// Mention
// =============================================================================
37
38/// A single mention (text span) that may corefer with other mentions.
39///
40/// Mentions are comparable by span position, not by text content.
41/// Two mentions with identical text at different positions are distinct.
42///
43/// # Character vs Byte Offsets
44///
45/// `start` and `end` are **character** offsets, not byte offsets.
46/// For "北京 Beijing", the character offsets are:
47/// - "北" = 0..1 (but 3 bytes in UTF-8)
48/// - "京" = 1..2 (but 3 bytes)
49/// - " " = 2..3
50/// - "Beijing" = 3..10
51///
52/// Use `text.chars().skip(start).take(end - start).collect()` to extract.
53///
54/// # Head Span
55///
56/// The `head_start`/`head_end` fields mark the syntactic head for head-match
57/// evaluation (used in CEAF-e, LEA metrics). In "the former president of France",
58/// the head is "president" - the noun that determines agreement.
59#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
60pub struct Mention {
61    /// The mention text (surface form).
62    pub text: String,
63    /// Start character offset (inclusive, 0-indexed).
64    pub start: usize,
65    /// End character offset (exclusive).
66    pub end: usize,
67    /// Head word start (for head-match metrics like CEAF).
68    pub head_start: Option<usize>,
69    /// Head word end.
70    pub head_end: Option<usize>,
71    /// Entity type if known (e.g., "PER", "ORG").
72    pub entity_type: Option<String>,
73    /// Mention category: Pronominal, Proper, Nominal, Zero.
74    pub mention_type: Option<MentionType>,
75}
76
77impl Mention {
78    /// `Mention::new("John", 0, 4)` creates a mention for "John" at characters 0..4.
79    ///
80    /// Offsets are character positions, not byte positions.
81    ///
82    /// ```
83    /// use anno_core::Mention;
84    ///
85    /// let m = Mention::new("John", 0, 4);
86    /// assert_eq!(m.text, "John");
87    /// assert_eq!(m.len(), 4);
88    /// assert_eq!(m.span_id(), (0, 4));
89    /// ```
90    #[must_use]
91    pub fn new(text: impl Into<String>, start: usize, end: usize) -> Self {
92        Self {
93            text: text.into(),
94            start,
95            end,
96            head_start: None,
97            head_end: None,
98            entity_type: None,
99            mention_type: None,
100        }
101    }
102
103    /// Mention with head span for head-match evaluation.
104    ///
105    /// The head is the syntactic nucleus: in "the former president", head is "president".
106    ///
107    /// ```
108    /// # use anno_core::core::coref::Mention;
109    /// let m = Mention::with_head("the former president", 0, 20, 11, 20);
110    /// assert_eq!(m.head_start, Some(11)); // "president" starts at 11
111    /// ```
112    #[must_use]
113    pub fn with_head(
114        text: impl Into<String>,
115        start: usize,
116        end: usize,
117        head_start: usize,
118        head_end: usize,
119    ) -> Self {
120        Self {
121            text: text.into(),
122            start,
123            end,
124            head_start: Some(head_start),
125            head_end: Some(head_end),
126            entity_type: None,
127            mention_type: None,
128        }
129    }
130
131    /// Mention with type annotation for type-aware evaluation.
132    ///
133    /// ```
134    /// # use anno_core::core::coref::Mention;
135    /// # use anno_core::core::types::MentionType;
136    /// let pronoun = Mention::with_type("he", 25, 27, MentionType::Pronominal);
137    /// let proper = Mention::with_type("John Smith", 0, 10, MentionType::Proper);
138    /// ```
139    #[must_use]
140    pub fn with_type(
141        text: impl Into<String>,
142        start: usize,
143        end: usize,
144        mention_type: MentionType,
145    ) -> Self {
146        Self {
147            text: text.into(),
148            start,
149            end,
150            head_start: None,
151            head_end: None,
152            entity_type: None,
153            mention_type: Some(mention_type),
154        }
155    }
156
157    /// True if spans share any characters: `[0,5)` overlaps `[3,8)`.
158    #[must_use]
159    pub fn overlaps(&self, other: &Mention) -> bool {
160        self.start < other.end && other.start < self.end
161    }
162
163    /// True if spans are identical: same start AND end.
164    #[must_use]
165    pub fn span_matches(&self, other: &Mention) -> bool {
166        self.start == other.start && self.end == other.end
167    }
168
169    /// Span length in characters. Returns 0 if `end <= start`.
170    #[must_use]
171    pub fn len(&self) -> usize {
172        self.end.saturating_sub(self.start)
173    }
174
175    /// True if span has zero length.
176    #[must_use]
177    pub fn is_empty(&self) -> bool {
178        self.len() == 0
179    }
180
181    /// `(start, end)` tuple for use in hash sets and comparisons.
182    #[must_use]
183    pub fn span_id(&self) -> (usize, usize) {
184        (self.start, self.end)
185    }
186
187    /// Head-span tuple `(head_start, head_end)` for head-match scoring.
188    ///
189    /// Returns the head span if both `head_start` and `head_end` are set,
190    /// otherwise falls back to the full span `(start, end)`.
191    ///
192    /// Head-match evaluation is standard in CRAC shared tasks: two mentions
193    /// match if their syntactic heads overlap, even when full spans differ
194    /// (e.g., "the president" vs "the former president of France" both have
195    /// head "president").
196    ///
197    /// ```
198    /// # use anno_core::core::coref::Mention;
199    /// let m = Mention::with_head("the former president", 0, 20, 11, 20);
200    /// assert_eq!(m.span_id_head(), (11, 20));
201    ///
202    /// let m2 = Mention::new("John", 0, 4);
203    /// assert_eq!(m2.span_id_head(), (0, 4)); // falls back to full span
204    /// ```
205    #[must_use]
206    pub fn span_id_head(&self) -> (usize, usize) {
207        match (self.head_start, self.head_end) {
208            (Some(hs), Some(he)) => (hs, he),
209            _ => (self.start, self.end),
210        }
211    }
212}
213
214impl std::fmt::Display for Mention {
215    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
216        write!(f, "\"{}\" [{}-{})", self.text, self.start, self.end)
217    }
218}
219
// =============================================================================
// CorefChain (Cluster)
// =============================================================================
223
224/// A coreference chain: mentions that all refer to the same entity.
225///
226/// ```
227/// # use anno_core::core::coref::{CorefChain, Mention};
228/// // "John went to the store. He bought milk."
229/// //  ^^^^                    ^^
230/// let john = Mention::new("John", 0, 4);
231/// let he = Mention::new("He", 25, 27);
232///
233/// let chain = CorefChain::new(vec![john, he]);
234/// assert_eq!(chain.len(), 2);
235/// assert!(!chain.is_singleton());
236/// ```
237///
238/// # Note
239///
240/// This type is for **evaluation and intermediate processing**. For production pipelines,
241/// use [`Track`](super::grounded::Track) which integrates with the Signal/Track/Identity hierarchy.
242#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
243pub struct CorefChain {
244    /// Mentions in document order (sorted by start position).
245    pub mentions: Vec<Mention>,
246    /// Cluster ID from the source data, if any.
247    pub cluster_id: Option<super::types::CanonicalId>,
248    /// Entity type shared by all mentions (e.g., "PERSON").
249    pub entity_type: Option<String>,
250}
251
252impl CorefChain {
253    /// Build a chain from mentions. Sorts by position automatically.
254    ///
255    /// ```
256    /// # use anno_core::core::coref::{CorefChain, Mention};
257    /// let chain = CorefChain::new(vec![
258    ///     Mention::new("she", 50, 53),
259    ///     Mention::new("Dr. Smith", 0, 9),  // out of order
260    /// ]);
261    /// assert_eq!(chain.mentions[0].text, "Dr. Smith"); // sorted
262    /// ```
263    #[must_use]
264    pub fn new(mut mentions: Vec<Mention>) -> Self {
265        mentions.sort_by_key(|m| (m.start, m.end));
266        Self {
267            mentions,
268            cluster_id: None,
269            entity_type: None,
270        }
271    }
272
273    /// Build a chain with an explicit cluster ID.
274    #[must_use]
275    pub fn with_id(
276        mut mentions: Vec<Mention>,
277        cluster_id: impl Into<super::types::CanonicalId>,
278    ) -> Self {
279        mentions.sort_by_key(|m| (m.start, m.end));
280        Self {
281            mentions,
282            cluster_id: Some(cluster_id.into()),
283            entity_type: None,
284        }
285    }
286
287    /// A chain with exactly one mention (entity mentioned only once).
288    #[must_use]
289    pub fn singleton(mention: Mention) -> Self {
290        Self {
291            mentions: vec![mention],
292            cluster_id: None,
293            entity_type: None,
294        }
295    }
296
297    /// Number of mentions. A chain with 3 mentions has 2 implicit "links".
298    #[must_use]
299    pub fn len(&self) -> usize {
300        self.mentions.len()
301    }
302
303    /// True if chain has no mentions. Shouldn't happen in valid data.
304    #[must_use]
305    pub fn is_empty(&self) -> bool {
306        self.mentions.is_empty()
307    }
308
309    /// True if chain has exactly one mention (singleton entity).
310    #[must_use]
311    pub fn is_singleton(&self) -> bool {
312        self.mentions.len() == 1
313    }
314
315    /// All pairwise links. For MUC: `n` mentions = `n*(n-1)/2` links.
316    ///
317    /// ```
318    /// # use anno_core::core::coref::{CorefChain, Mention};
319    /// let chain = CorefChain::new(vec![
320    ///     Mention::new("A", 0, 1),
321    ///     Mention::new("B", 2, 3),
322    ///     Mention::new("C", 4, 5),
323    /// ]);
324    /// assert_eq!(chain.links().len(), 3); // A-B, A-C, B-C
325    /// ```
326    #[must_use]
327    pub fn links(&self) -> Vec<(&Mention, &Mention)> {
328        let mut links = Vec::new();
329        for i in 0..self.mentions.len() {
330            for j in (i + 1)..self.mentions.len() {
331                links.push((&self.mentions[i], &self.mentions[j]));
332            }
333        }
334        links
335    }
336
337    /// Number of coreference links.
338    ///
339    /// For a chain of n mentions: n*(n-1)/2 pairs, but only n-1 links needed
340    /// to connect all mentions (spanning tree).
341    #[must_use]
342    pub fn link_count(&self) -> usize {
343        if self.mentions.len() <= 1 {
344            0
345        } else {
346            self.mentions.len() - 1
347        }
348    }
349
350    /// Get all pairwise mention combinations (for B³, CEAF).
351    #[must_use]
352    pub fn all_pairs(&self) -> Vec<(&Mention, &Mention)> {
353        self.links() // Same as links for non-directed pairs
354    }
355
356    /// Check if chain contains a mention with given span.
357    #[must_use]
358    pub fn contains_span(&self, start: usize, end: usize) -> bool {
359        self.mentions
360            .iter()
361            .any(|m| m.start == start && m.end == end)
362    }
363
364    /// Get first mention (usually the most salient/representative).
365    #[must_use]
366    pub fn first(&self) -> Option<&Mention> {
367        self.mentions.first()
368    }
369
370    /// Get set of mention span IDs for set operations.
371    #[must_use]
372    pub fn mention_spans(&self) -> HashSet<(usize, usize)> {
373        self.mentions.iter().map(|m| m.span_id()).collect()
374    }
375
376    /// Get the canonical (representative) mention for this chain.
377    ///
378    /// Prefers proper nouns over other mention types, then longest mention.
379    /// Falls back to first mention if no proper noun exists.
380    #[must_use]
381    pub fn canonical_mention(&self) -> Option<&Mention> {
382        // Prefer proper noun mentions
383        let proper = self
384            .mentions
385            .iter()
386            .filter(|m| m.mention_type == Some(MentionType::Proper))
387            .max_by_key(|m| m.text.len());
388
389        if proper.is_some() {
390            return proper;
391        }
392
393        // Fall back to longest mention (likely most informative)
394        self.mentions.iter().max_by_key(|m| m.text.len())
395    }
396
397    /// Get the canonical ID for this chain (cluster_id if set).
398    #[must_use]
399    pub fn canonical_id(&self) -> Option<super::types::CanonicalId> {
400        self.cluster_id
401    }
402}
403
404impl std::fmt::Display for CorefChain {
405    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
406        let mentions: Vec<String> = self
407            .mentions
408            .iter()
409            .map(|m| format!("\"{}\"", m.text))
410            .collect();
411        write!(f, "[{}]", mentions.join(", "))
412    }
413}
414
// =============================================================================
// CorefDocument
// =============================================================================
418
419/// A document with coreference annotations.
420///
421/// Contains the source text and all coreference chains.
422#[derive(Debug, Clone, Serialize, Deserialize)]
423pub struct CorefDocument {
424    /// Document text.
425    pub text: String,
426    /// Document identifier.
427    pub doc_id: Option<String>,
428    /// Coreference chains (clusters).
429    pub chains: Vec<CorefChain>,
430    /// Whether singletons are included.
431    pub includes_singletons: bool,
432}
433
434impl CorefDocument {
435    /// Create a new document with chains.
436    ///
437    /// ```
438    /// use anno_core::core::coref::{CorefDocument, CorefChain, Mention};
439    ///
440    /// let chain = CorefChain::new(vec![
441    ///     Mention::new("John", 0, 4),
442    ///     Mention::new("He", 24, 26),
443    /// ]);
444    /// let doc = CorefDocument::new("John went to the store. He bought milk.", vec![chain]);
445    /// assert_eq!(doc.mention_count(), 2);
446    /// assert_eq!(doc.chain_count(), 1);
447    /// ```
448    #[must_use]
449    pub fn new(text: impl Into<String>, chains: Vec<CorefChain>) -> Self {
450        Self {
451            text: text.into(),
452            doc_id: None,
453            chains,
454            includes_singletons: false,
455        }
456    }
457
458    /// Create document with ID.
459    #[must_use]
460    pub fn with_id(
461        text: impl Into<String>,
462        doc_id: impl Into<String>,
463        chains: Vec<CorefChain>,
464    ) -> Self {
465        Self {
466            text: text.into(),
467            doc_id: Some(doc_id.into()),
468            chains,
469            includes_singletons: false,
470        }
471    }
472
473    /// Total number of mentions across all chains.
474    #[must_use]
475    pub fn mention_count(&self) -> usize {
476        self.chains.iter().map(|c| c.len()).sum()
477    }
478
479    /// Number of chains (clusters).
480    #[must_use]
481    pub fn chain_count(&self) -> usize {
482        self.chains.len()
483    }
484
485    /// Number of non-singleton chains.
486    #[must_use]
487    pub fn non_singleton_count(&self) -> usize {
488        self.chains.iter().filter(|c| !c.is_singleton()).count()
489    }
490
491    /// Get all mentions in document order.
492    #[must_use]
493    pub fn all_mentions(&self) -> Vec<&Mention> {
494        let mut mentions: Vec<&Mention> = self.chains.iter().flat_map(|c| &c.mentions).collect();
495        mentions.sort_by_key(|m| (m.start, m.end));
496        mentions
497    }
498
499    /// Find which chain contains a mention span.
500    #[must_use]
501    pub fn find_chain(&self, start: usize, end: usize) -> Option<&CorefChain> {
502        self.chains.iter().find(|c| c.contains_span(start, end))
503    }
504
505    /// Build mention-to-chain index for fast lookup.
506    #[must_use]
507    pub fn mention_to_chain_index(&self) -> HashMap<(usize, usize), usize> {
508        let mut index = HashMap::new();
509        for (chain_idx, chain) in self.chains.iter().enumerate() {
510            for mention in &chain.mentions {
511                index.insert(mention.span_id(), chain_idx);
512            }
513        }
514        index
515    }
516
517    /// Filter to only non-singleton chains.
518    #[must_use]
519    pub fn without_singletons(&self) -> Self {
520        Self {
521            text: self.text.clone(),
522            doc_id: self.doc_id.clone(),
523            chains: self
524                .chains
525                .iter()
526                .filter(|c| !c.is_singleton())
527                .cloned()
528                .collect(),
529            includes_singletons: false,
530        }
531    }
532}
533
// =============================================================================
// Conversion from Entity to Mention
// =============================================================================
537
538impl From<&Entity> for Mention {
539    fn from(entity: &Entity) -> Self {
540        Self {
541            text: entity.text.clone(),
542            start: entity.start,
543            end: entity.end,
544            head_start: None,
545            head_end: None,
546            entity_type: Some(entity.entity_type.as_label().to_string()),
547            mention_type: entity.mention_type,
548        }
549    }
550}
551
552/// Convert entities with canonical_id to coreference chains.
553///
554/// Entities sharing the same `canonical_id` are grouped into a chain.
555#[must_use]
556pub fn entities_to_chains(entities: &[Entity]) -> Vec<CorefChain> {
557    let mut clusters: HashMap<u64, Vec<Mention>> = HashMap::new();
558    let mut singletons: Vec<Mention> = Vec::new();
559
560    for entity in entities {
561        let mention = Mention::from(entity);
562        if let Some(canonical_id) = entity.canonical_id {
563            clusters
564                .entry(canonical_id.get())
565                .or_default()
566                .push(mention);
567        } else {
568            singletons.push(mention);
569        }
570    }
571
572    let mut chains: Vec<CorefChain> = clusters
573        .into_iter()
574        .map(|(id, mentions)| CorefChain::with_id(mentions, id))
575        .collect();
576
577    // Add singletons as individual chains
578    for mention in singletons {
579        chains.push(CorefChain::singleton(mention));
580    }
581
582    chains
583}
584
// =============================================================================
// CoreferenceResolver Trait
// =============================================================================
588
589/// Trait for coreference resolution algorithms.
590///
591/// Implementors take a set of entity mentions and cluster them into
592/// coreference chains (groups of mentions referring to the same entity).
593///
594/// # Design Philosophy
595///
596/// This trait lives in `anno::core` because:
597/// 1. It depends only on core types (`Entity`, `CorefChain`)
598/// 2. Multiple crates need to implement it (backends, eval)
599/// 3. Keeping it here prevents circular dependencies
600///
601/// # Relationship to the Grounded Pipeline
602///
603/// `CoreferenceResolver` operates on the **evaluation/convenience layer** (`Entity`),
604/// not the canonical **grounded pipeline** (`Signal` → `Track` → `Identity`).
605///
606/// | Layer | Type | `CoreferenceResolver` role |
607/// |-------|------|----------------------------|
608/// | Detection (L1) | `Entity` | Input: mentions to cluster |
609/// | Coref (L2) | `Entity.canonical_id` | Output: cluster assignment |
610/// | Linking (L3) | `Identity` | (not covered by this trait) |
611///
612/// For integration with `GroundedDocument`, use backends that produce
613/// `Signal` + `Track` directly (e.g., `anno::backends::MentionRankingCoref`).
614///
615/// # Example Implementation
616///
617/// ```rust
618/// use anno_core::{CoreferenceResolver, Entity, EntityType};
619///
620/// struct ExactMatchResolver;
621///
622/// impl CoreferenceResolver for ExactMatchResolver {
623///     fn resolve(&self, entities: &[Entity]) -> Vec<Entity> {
624///         // Trivially return entities unchanged for this example
625///         entities.to_vec()
626///     }
627///
628///     fn name(&self) -> &'static str {
629///         "exact-match"
630///     }
631/// }
632///
633/// let resolver = ExactMatchResolver;
634/// assert_eq!(resolver.name(), "exact-match");
635/// ```
636pub trait CoreferenceResolver: Send + Sync {
637    /// Resolve coreference, assigning canonical IDs to entities.
638    ///
639    /// Each entity in the output will have a `canonical_id` field set.
640    /// Entities with the same `canonical_id` are coreferent (refer to the
641    /// same real-world entity).
642    ///
643    /// # Invariants
644    ///
645    /// - Every output entity has `canonical_id.is_some()`
646    /// - Coreferent entities share the same `canonical_id`
647    /// - Singleton mentions get unique `canonical_id` values
648    fn resolve(&self, entities: &[Entity]) -> Vec<Entity>;
649
650    /// Resolve directly to chains.
651    ///
652    /// A chain groups all mentions of the same entity together.
653    /// This is often the desired output format for evaluation and
654    /// downstream tasks.
655    fn resolve_to_chains(&self, entities: &[Entity]) -> Vec<CorefChain> {
656        let resolved = self.resolve(entities);
657        entities_to_chains(&resolved)
658    }
659
660    /// Get resolver name.
661    ///
662    /// Used for logging, metrics, and result attribution.
663    fn name(&self) -> &'static str;
664}
665
// =============================================================================
// Tests
// =============================================================================
669
#[cfg(test)]
mod tests {
    use super::super::entity::EntityType;
    use super::super::types::CanonicalId;
    use super::super::Entity;
    use super::*;

    /// A freshly built mention exposes the text and span it was given.
    #[test]
    fn test_mention_creation() {
        let john = Mention::new("John", 0, 4);
        assert_eq!(john.text, "John");
        assert_eq!(john.start, 0);
        assert_eq!(john.end, 4);
        assert_eq!(john.len(), 4);
    }

    /// A contained span overlaps its container; a later span overlaps neither.
    #[test]
    fn test_mention_overlap() {
        let full_name = Mention::new("John Smith", 0, 10);
        let surname = Mention::new("Smith", 5, 10);
        let verb = Mention::new("works", 11, 16);

        assert!(full_name.overlaps(&surname));
        assert!(!full_name.overlaps(&verb));
        assert!(!surname.overlaps(&verb));
    }

    /// Three mentions form a non-singleton chain with two connecting links.
    #[test]
    fn test_chain_creation() {
        let chain = CorefChain::new(vec![
            Mention::new("John", 0, 4),
            Mention::new("he", 20, 22),
            Mention::new("him", 40, 43),
        ]);

        assert_eq!(chain.len(), 3);
        assert!(!chain.is_singleton());
        assert_eq!(chain.link_count(), 2); // Minimum links to connect
    }

    /// Three mentions produce three unordered pairs.
    #[test]
    fn test_chain_links() {
        let chain = CorefChain::new(vec![
            Mention::new("a", 0, 1),
            Mention::new("b", 2, 3),
            Mention::new("c", 4, 5),
        ]);

        // All pairs: (a,b), (a,c), (b,c) = 3 pairs
        assert_eq!(chain.all_pairs().len(), 3);
    }

    /// A singleton chain has no links and no pairs.
    #[test]
    fn test_singleton_chain() {
        let chain = CorefChain::singleton(Mention::new("entity", 0, 6));

        assert!(chain.is_singleton());
        assert_eq!(chain.link_count(), 0);
        assert!(chain.all_pairs().is_empty());
    }

    /// Document-level counters reflect the single two-mention chain.
    #[test]
    fn test_document() {
        let text = "John went to the store. He bought milk.";
        let john_chain =
            CorefChain::new(vec![Mention::new("John", 0, 4), Mention::new("He", 24, 26)]);
        let doc = CorefDocument::new(text, vec![john_chain]);

        assert_eq!(doc.mention_count(), 2);
        assert_eq!(doc.chain_count(), 1);
        assert_eq!(doc.non_singleton_count(), 1);
    }

    /// Each mention span maps to the index of the chain that owns it.
    #[test]
    fn test_mention_to_chain_index() {
        let john_chain =
            CorefChain::new(vec![Mention::new("John", 0, 4), Mention::new("he", 20, 22)]);
        let mary_chain = CorefChain::new(vec![
            Mention::new("Mary", 5, 9),
            Mention::new("she", 30, 33),
        ]);
        let doc = CorefDocument::new("text", vec![john_chain, mary_chain]);

        let by_span = doc.mention_to_chain_index();
        assert_eq!(by_span.get(&(0, 4)), Some(&0));
        assert_eq!(by_span.get(&(20, 22)), Some(&0));
        assert_eq!(by_span.get(&(5, 9)), Some(&1));
        assert_eq!(by_span.get(&(30, 33)), Some(&1));
    }

    // =========================================================================
    // Edge case tests
    // =========================================================================

    #[test]
    fn test_unicode_mention_offsets() {
        // "北京 Beijing" — character offsets, not byte offsets.
        // "北" is 3 bytes in UTF-8 but 1 character.
        let beijing = Mention::new("北京", 0, 2); // 2 characters, not 6 bytes
        assert_eq!(beijing.len(), 2);
        assert_eq!(beijing.span_id(), (0, 2));
        assert!(!beijing.is_empty());
    }

    #[test]
    fn test_zero_length_mention() {
        // Zero anaphora / empty mention at position 5.
        let zero = Mention::new("", 5, 5);
        assert!(zero.is_empty());
        assert_eq!(zero.len(), 0);
        assert_eq!(zero.span_id(), (5, 5));
    }

    /// An empty chain is valid to query: no links, pairs, first or canonical mention.
    #[test]
    fn test_empty_chain() {
        let empty = CorefChain::new(vec![]);
        assert!(empty.is_empty());
        assert_eq!(empty.link_count(), 0);
        assert!(empty.all_pairs().is_empty());
        assert!(empty.first().is_none());
        assert!(empty.canonical_mention().is_none());
    }

    #[test]
    fn test_chain_sorting_out_of_order() {
        // Mentions given out of document order should be sorted by (start, end).
        let chain = CorefChain::new(vec![
            Mention::new("c", 20, 21),
            Mention::new("a", 0, 1),
            Mention::new("b", 10, 11),
        ]);
        assert_eq!(chain.mentions[0].text, "a");
        assert_eq!(chain.mentions[1].text, "b");
        assert_eq!(chain.mentions[2].text, "c");
    }

    #[test]
    fn test_chain_sorting_ties_broken_by_end() {
        // Same start, different end: shorter span first.
        let chain = CorefChain::new(vec![
            Mention::new("John Smith", 0, 10),
            Mention::new("John", 0, 4),
        ]);
        assert_eq!(chain.mentions[0].text, "John");
        assert_eq!(chain.mentions[1].text, "John Smith");
    }

    /// Entities sharing a canonical ID end up in the same chain.
    #[test]
    fn test_entities_to_chains_grouped() {
        let john = Entity::new("John", EntityType::Person, 0, 4, 0.9).with_canonical_id(1_u64);
        let he = Entity::new("he", EntityType::Person, 20, 22, 0.8).with_canonical_id(1_u64);
        let mary = Entity::new("Mary", EntityType::Person, 5, 9, 0.95).with_canonical_id(2_u64);

        let chains = entities_to_chains(&[john, he, mary]);

        // Two canonical_ids -> two chains
        assert_eq!(chains.len(), 2);

        // Find the chain with cluster_id=1 (John + he)
        let john_chain = chains
            .iter()
            .find(|c| c.cluster_id == Some(CanonicalId::new(1)))
            .expect("chain with id=1");
        assert_eq!(john_chain.len(), 2);

        // Find the chain with cluster_id=2 (Mary)
        let mary_chain = chains
            .iter()
            .find(|c| c.cluster_id == Some(CanonicalId::new(2)))
            .expect("chain with id=2");
        assert_eq!(mary_chain.len(), 1);
    }

    #[test]
    fn test_entities_to_chains_singletons() {
        // Entities without canonical_id become individual singleton chains.
        let paris = Entity::new("Paris", EntityType::Location, 0, 5, 0.9);
        let london = Entity::new("London", EntityType::Location, 10, 16, 0.85);

        let chains = entities_to_chains(&[paris, london]);
        assert_eq!(chains.len(), 2);
        assert!(chains.iter().all(|c| c.is_singleton()));
    }

    /// No entities, no chains.
    #[test]
    fn test_entities_to_chains_empty() {
        assert!(entities_to_chains(&[]).is_empty());
    }

    /// Singleton chains are dropped; multi-mention chains are kept.
    #[test]
    fn test_without_singletons_filters() {
        let solo = CorefChain::singleton(Mention::new("solo", 0, 4));
        let pair = CorefChain::new(vec![
            Mention::new("John", 10, 14),
            Mention::new("he", 20, 22),
        ]);
        let doc = CorefDocument::new("text", vec![solo, pair]);

        let filtered = doc.without_singletons();
        assert_eq!(filtered.chain_count(), 1);
        assert_eq!(filtered.chains[0].len(), 2);
        assert!(!filtered.includes_singletons);
    }

    /// With no singletons present, filtering keeps every chain.
    #[test]
    fn test_without_singletons_preserves_non_singletons() {
        let pair = CorefChain::new(vec![Mention::new("a", 0, 1), Mention::new("b", 2, 3)]);
        let triple = CorefChain::new(vec![
            Mention::new("x", 10, 11),
            Mention::new("y", 12, 13),
            Mention::new("z", 14, 15),
        ]);
        let doc = CorefDocument::new("text", vec![pair, triple]);

        let filtered = doc.without_singletons();
        assert_eq!(filtered.chain_count(), 2);
    }

    /// With only singletons present, filtering leaves no chains.
    #[test]
    fn test_without_singletons_all_singletons() {
        let first = CorefChain::singleton(Mention::new("a", 0, 1));
        let second = CorefChain::singleton(Mention::new("b", 2, 3));
        let doc = CorefDocument::new("text", vec![first, second]);

        let filtered = doc.without_singletons();
        assert!(filtered.chains.is_empty());
    }

    #[test]
    fn test_overlaps_adjacent_non_overlapping() {
        // [0,5) and [5,10) are adjacent but NOT overlapping (half-open intervals).
        let hello = Mention::new("hello", 0, 5);
        let world = Mention::new("world", 5, 10);
        assert!(!hello.overlaps(&world));
        assert!(!world.overlaps(&hello));
    }

    #[test]
    fn test_overlaps_nested() {
        // [0,10) fully contains [2,5).
        let outer = Mention::new("the big dog", 0, 10);
        let inner = Mention::new("big", 2, 5);
        assert!(outer.overlaps(&inner));
        assert!(inner.overlaps(&outer));
    }

    /// `with_id` stores the cluster ID and still sorts the mentions.
    #[test]
    fn test_chain_with_id() {
        let chain = CorefChain::with_id(
            vec![Mention::new("John", 0, 4), Mention::new("he", 10, 12)],
            42_u64,
        );
        let expected_id = Some(CanonicalId::new(42));
        assert_eq!(chain.canonical_id(), expected_id);
        assert_eq!(chain.cluster_id, expected_id);
        // Mentions should still be sorted.
        assert_eq!(chain.mentions[0].text, "John");
    }
}
941
#[cfg(test)]
mod proptests {
    #![allow(clippy::unwrap_used)]
    use super::*;
    use proptest::prelude::*;

    /// Strategy producing a non-empty `Mention` whose start is below `max_offset`
    /// and whose length is in `1..500`.
    fn arb_mention(max_offset: usize) -> impl Strategy<Value = Mention> {
        let starts = 0usize..max_offset;
        let lengths = 1usize..500;
        (starts, lengths).prop_map(|(start, len)| {
            let text = format!("m_{}", start);
            Mention::new(text, start, start + len)
        })
    }

    proptest! {
        /// After `CorefChain::new`, mentions are ordered by (start, end).
        #[test]
        fn mention_ordering_after_chain_construction(
            mentions in proptest::collection::vec(arb_mention(10000), 1..20),
        ) {
            let chain = CorefChain::new(mentions);
            for pair in chain.mentions.windows(2) {
                let (prev, next) = (&pair[0], &pair[1]);
                prop_assert!(
                    prev.span_id() <= next.span_id(),
                    "mentions must be sorted by (start, end): ({},{}) vs ({},{})",
                    prev.start, prev.end, next.start, next.end
                );
            }
        }

        /// A chain built from at least one mention is never empty and keeps them all.
        #[test]
        fn coref_chain_non_empty(
            mentions in proptest::collection::vec(arb_mention(10000), 1..20),
        ) {
            let expected_len = mentions.len();
            let chain = CorefChain::new(mentions);
            prop_assert!(!chain.is_empty());
            prop_assert_eq!(chain.len(), expected_len);
        }

        /// `CorefChain::singleton` always yields exactly one mention and no links.
        #[test]
        fn coref_chain_singleton_has_one(start in 0usize..10000, len in 1usize..500) {
            let chain = CorefChain::singleton(Mention::new("x", start, start + len));
            prop_assert!(chain.is_singleton());
            prop_assert_eq!(chain.len(), 1);
            prop_assert_eq!(chain.link_count(), 0);
        }

        /// `Mention::overlaps` gives the same answer in both directions.
        #[test]
        fn mention_overlap_symmetric(
            s1 in 0usize..10000, len1 in 1usize..500,
            s2 in 0usize..10000, len2 in 1usize..500,
        ) {
            let left = Mention::new("a", s1, s1 + len1);
            let right = Mention::new("b", s2, s2 + len2);
            prop_assert_eq!(left.overlaps(&right), right.overlaps(&left));
        }

        /// Serializing a mention to JSON and back preserves all fields.
        #[test]
        fn mention_serde_roundtrip(
            start in 0usize..10000, len in 1usize..500,
        ) {
            let original = Mention::new(format!("mention_{}", start), start, start + len);
            let json = serde_json::to_string(&original).unwrap();
            let restored: Mention = serde_json::from_str(&json).unwrap();
            prop_assert_eq!(&original, &restored);
        }
    }
}