Skip to main content

anno_core/core/
coref.rs

1//! Coreference resolution data structures.
2//!
3//! Provides types for representing coreference chains (clusters of mentions
4//! that refer to the same entity) and utilities for working with them.
5//!
6//! # Terminology
7//!
8//! - **Mention**: A span of text referring to an entity (e.g., "John", "he", "the CEO")
9//! - **Chain/Cluster**: A set of mentions that corefer (refer to the same entity)
10//! - **Singleton**: A chain with only one mention (entity mentioned only once)
11//! - **Antecedent**: An earlier mention that a pronoun/noun phrase refers to
12//!
13//! # Example
14//!
15//! ```rust
16//! use anno_core::core::coref::{Mention, CorefChain, CorefDocument};
17//!
18//! // "John went to the store. He bought milk."
19//! let john = Mention::new("John", 0, 4);
//! let he = Mention::new("He", 24, 26);
21//!
22//! let chain = CorefChain::new(vec![john, he]);
23//! assert_eq!(chain.len(), 2);
24//! assert!(!chain.is_singleton());
25//! ```
26
27use super::Entity;
28use serde::{Deserialize, Serialize};
29use std::collections::{HashMap, HashSet};
30
31// Re-export MentionType for convenience
32pub use super::types::MentionType;
33
34// =============================================================================
35// Mention
36// =============================================================================
37
38/// A single mention (text span) that may corefer with other mentions.
39///
40/// Mentions are comparable by span position, not by text content.
41/// Two mentions with identical text at different positions are distinct.
42///
43/// # Character vs Byte Offsets
44///
45/// `start` and `end` are **character** offsets, not byte offsets.
46/// For "北京 Beijing", the character offsets are:
47/// - "北" = 0..1 (but 3 bytes in UTF-8)
48/// - "京" = 1..2 (but 3 bytes)
49/// - " " = 2..3
50/// - "Beijing" = 3..10
51///
52/// Use `text.chars().skip(start).take(end - start).collect()` to extract.
53///
54/// # Head Span
55///
56/// The `head_start`/`head_end` fields mark the syntactic head for head-match
57/// evaluation (used in CEAF-e, LEA metrics). In "the former president of France",
58/// the head is "president" - the noun that determines agreement.
59#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
60pub struct Mention {
61    /// The mention text (surface form).
62    pub text: String,
63    /// Start character offset (inclusive, 0-indexed).
64    pub start: usize,
65    /// End character offset (exclusive).
66    pub end: usize,
67    /// Head word start (for head-match metrics like CEAF).
68    pub head_start: Option<usize>,
69    /// Head word end.
70    pub head_end: Option<usize>,
71    /// Entity type if known (e.g., "PER", "ORG").
72    pub entity_type: Option<String>,
73    /// Mention category: Pronominal, Proper, Nominal, Zero.
74    pub mention_type: Option<MentionType>,
75}
76
77impl Mention {
78    /// `Mention::new("John", 0, 4)` creates a mention for "John" at characters 0..4.
79    ///
80    /// Offsets are character positions, not byte positions.
81    ///
82    /// ```
83    /// use anno_core::Mention;
84    ///
85    /// let m = Mention::new("John", 0, 4);
86    /// assert_eq!(m.text, "John");
87    /// assert_eq!(m.len(), 4);
88    /// assert_eq!(m.span_id(), (0, 4));
89    /// ```
90    #[must_use]
91    pub fn new(text: impl Into<String>, start: usize, end: usize) -> Self {
92        Self {
93            text: text.into(),
94            start,
95            end,
96            head_start: None,
97            head_end: None,
98            entity_type: None,
99            mention_type: None,
100        }
101    }
102
103    /// Mention with head span for head-match evaluation.
104    ///
105    /// The head is the syntactic nucleus: in "the former president", head is "president".
106    ///
107    /// ```
108    /// # use anno_core::core::coref::Mention;
109    /// let m = Mention::with_head("the former president", 0, 20, 11, 20);
110    /// assert_eq!(m.head_start, Some(11)); // "president" starts at 11
111    /// ```
112    #[must_use]
113    pub fn with_head(
114        text: impl Into<String>,
115        start: usize,
116        end: usize,
117        head_start: usize,
118        head_end: usize,
119    ) -> Self {
120        Self {
121            text: text.into(),
122            start,
123            end,
124            head_start: Some(head_start),
125            head_end: Some(head_end),
126            entity_type: None,
127            mention_type: None,
128        }
129    }
130
131    /// Mention with type annotation for type-aware evaluation.
132    ///
133    /// ```
134    /// # use anno_core::core::coref::Mention;
135    /// # use anno_core::core::types::MentionType;
136    /// let pronoun = Mention::with_type("he", 25, 27, MentionType::Pronominal);
137    /// let proper = Mention::with_type("John Smith", 0, 10, MentionType::Proper);
138    /// ```
139    #[must_use]
140    pub fn with_type(
141        text: impl Into<String>,
142        start: usize,
143        end: usize,
144        mention_type: MentionType,
145    ) -> Self {
146        Self {
147            text: text.into(),
148            start,
149            end,
150            head_start: None,
151            head_end: None,
152            entity_type: None,
153            mention_type: Some(mention_type),
154        }
155    }
156
157    /// True if spans share any characters: `[0,5)` overlaps `[3,8)`.
158    #[must_use]
159    pub fn overlaps(&self, other: &Mention) -> bool {
160        self.start < other.end && other.start < self.end
161    }
162
163    /// True if spans are identical: same start AND end.
164    #[must_use]
165    pub fn span_matches(&self, other: &Mention) -> bool {
166        self.start == other.start && self.end == other.end
167    }
168
169    /// Span length in characters. Returns 0 if `end <= start`.
170    #[must_use]
171    pub fn len(&self) -> usize {
172        self.end.saturating_sub(self.start)
173    }
174
175    /// True if span has zero length.
176    #[must_use]
177    pub fn is_empty(&self) -> bool {
178        self.len() == 0
179    }
180
181    /// `(start, end)` tuple for use in hash sets and comparisons.
182    #[must_use]
183    pub fn span_id(&self) -> (usize, usize) {
184        (self.start, self.end)
185    }
186}
187
188impl std::fmt::Display for Mention {
189    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
190        write!(f, "\"{}\" [{}-{})", self.text, self.start, self.end)
191    }
192}
193
194// =============================================================================
195// CorefChain (Cluster)
196// =============================================================================
197
198/// A coreference chain: mentions that all refer to the same entity.
199///
200/// ```
201/// # use anno_core::core::coref::{CorefChain, Mention};
202/// // "John went to the store. He bought milk."
203/// //  ^^^^                    ^^
204/// let john = Mention::new("John", 0, 4);
205/// let he = Mention::new("He", 25, 27);
206///
207/// let chain = CorefChain::new(vec![john, he]);
208/// assert_eq!(chain.len(), 2);
209/// assert!(!chain.is_singleton());
210/// ```
211///
212/// # Note
213///
214/// This type is for **evaluation and intermediate processing**. For production pipelines,
215/// use [`Track`](super::grounded::Track) which integrates with the Signal/Track/Identity hierarchy.
216#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
217pub struct CorefChain {
218    /// Mentions in document order (sorted by start position).
219    pub mentions: Vec<Mention>,
220    /// Cluster ID from the source data, if any.
221    pub cluster_id: Option<super::types::CanonicalId>,
222    /// Entity type shared by all mentions (e.g., "PERSON").
223    pub entity_type: Option<String>,
224}
225
226impl CorefChain {
227    /// Build a chain from mentions. Sorts by position automatically.
228    ///
229    /// ```
230    /// # use anno_core::core::coref::{CorefChain, Mention};
231    /// let chain = CorefChain::new(vec![
232    ///     Mention::new("she", 50, 53),
233    ///     Mention::new("Dr. Smith", 0, 9),  // out of order
234    /// ]);
235    /// assert_eq!(chain.mentions[0].text, "Dr. Smith"); // sorted
236    /// ```
237    #[must_use]
238    pub fn new(mut mentions: Vec<Mention>) -> Self {
239        mentions.sort_by_key(|m| (m.start, m.end));
240        Self {
241            mentions,
242            cluster_id: None,
243            entity_type: None,
244        }
245    }
246
247    /// Build a chain with an explicit cluster ID.
248    #[must_use]
249    pub fn with_id(
250        mut mentions: Vec<Mention>,
251        cluster_id: impl Into<super::types::CanonicalId>,
252    ) -> Self {
253        mentions.sort_by_key(|m| (m.start, m.end));
254        Self {
255            mentions,
256            cluster_id: Some(cluster_id.into()),
257            entity_type: None,
258        }
259    }
260
261    /// A chain with exactly one mention (entity mentioned only once).
262    #[must_use]
263    pub fn singleton(mention: Mention) -> Self {
264        Self {
265            mentions: vec![mention],
266            cluster_id: None,
267            entity_type: None,
268        }
269    }
270
271    /// Number of mentions. A chain with 3 mentions has 2 implicit "links".
272    #[must_use]
273    pub fn len(&self) -> usize {
274        self.mentions.len()
275    }
276
277    /// True if chain has no mentions. Shouldn't happen in valid data.
278    #[must_use]
279    pub fn is_empty(&self) -> bool {
280        self.mentions.is_empty()
281    }
282
283    /// True if chain has exactly one mention (singleton entity).
284    #[must_use]
285    pub fn is_singleton(&self) -> bool {
286        self.mentions.len() == 1
287    }
288
289    /// All pairwise links. For MUC: `n` mentions = `n*(n-1)/2` links.
290    ///
291    /// ```
292    /// # use anno_core::core::coref::{CorefChain, Mention};
293    /// let chain = CorefChain::new(vec![
294    ///     Mention::new("A", 0, 1),
295    ///     Mention::new("B", 2, 3),
296    ///     Mention::new("C", 4, 5),
297    /// ]);
298    /// assert_eq!(chain.links().len(), 3); // A-B, A-C, B-C
299    /// ```
300    #[must_use]
301    pub fn links(&self) -> Vec<(&Mention, &Mention)> {
302        let mut links = Vec::new();
303        for i in 0..self.mentions.len() {
304            for j in (i + 1)..self.mentions.len() {
305                links.push((&self.mentions[i], &self.mentions[j]));
306            }
307        }
308        links
309    }
310
311    /// Number of coreference links.
312    ///
313    /// For a chain of n mentions: n*(n-1)/2 pairs, but only n-1 links needed
314    /// to connect all mentions (spanning tree).
315    #[must_use]
316    pub fn link_count(&self) -> usize {
317        if self.mentions.len() <= 1 {
318            0
319        } else {
320            self.mentions.len() - 1
321        }
322    }
323
324    /// Get all pairwise mention combinations (for B³, CEAF).
325    #[must_use]
326    pub fn all_pairs(&self) -> Vec<(&Mention, &Mention)> {
327        self.links() // Same as links for non-directed pairs
328    }
329
330    /// Check if chain contains a mention with given span.
331    #[must_use]
332    pub fn contains_span(&self, start: usize, end: usize) -> bool {
333        self.mentions
334            .iter()
335            .any(|m| m.start == start && m.end == end)
336    }
337
338    /// Get first mention (usually the most salient/representative).
339    #[must_use]
340    pub fn first(&self) -> Option<&Mention> {
341        self.mentions.first()
342    }
343
344    /// Get set of mention span IDs for set operations.
345    #[must_use]
346    pub fn mention_spans(&self) -> HashSet<(usize, usize)> {
347        self.mentions.iter().map(|m| m.span_id()).collect()
348    }
349
350    /// Get the canonical (representative) mention for this chain.
351    ///
352    /// Prefers proper nouns over other mention types, then longest mention.
353    /// Falls back to first mention if no proper noun exists.
354    #[must_use]
355    pub fn canonical_mention(&self) -> Option<&Mention> {
356        // Prefer proper noun mentions
357        let proper = self
358            .mentions
359            .iter()
360            .filter(|m| m.mention_type == Some(MentionType::Proper))
361            .max_by_key(|m| m.text.len());
362
363        if proper.is_some() {
364            return proper;
365        }
366
367        // Fall back to longest mention (likely most informative)
368        self.mentions.iter().max_by_key(|m| m.text.len())
369    }
370
371    /// Get the canonical ID for this chain (cluster_id if set).
372    #[must_use]
373    pub fn canonical_id(&self) -> Option<super::types::CanonicalId> {
374        self.cluster_id
375    }
376}
377
378impl std::fmt::Display for CorefChain {
379    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
380        let mentions: Vec<String> = self
381            .mentions
382            .iter()
383            .map(|m| format!("\"{}\"", m.text))
384            .collect();
385        write!(f, "[{}]", mentions.join(", "))
386    }
387}
388
389// =============================================================================
390// CorefDocument
391// =============================================================================
392
393/// A document with coreference annotations.
394///
395/// Contains the source text and all coreference chains.
396#[derive(Debug, Clone, Serialize, Deserialize)]
397pub struct CorefDocument {
398    /// Document text.
399    pub text: String,
400    /// Document identifier.
401    pub doc_id: Option<String>,
402    /// Coreference chains (clusters).
403    pub chains: Vec<CorefChain>,
404    /// Whether singletons are included.
405    pub includes_singletons: bool,
406}
407
408impl CorefDocument {
409    /// Create a new document with chains.
410    ///
411    /// ```
412    /// use anno_core::core::coref::{CorefDocument, CorefChain, Mention};
413    ///
414    /// let chain = CorefChain::new(vec![
415    ///     Mention::new("John", 0, 4),
416    ///     Mention::new("He", 24, 26),
417    /// ]);
418    /// let doc = CorefDocument::new("John went to the store. He bought milk.", vec![chain]);
419    /// assert_eq!(doc.mention_count(), 2);
420    /// assert_eq!(doc.chain_count(), 1);
421    /// ```
422    #[must_use]
423    pub fn new(text: impl Into<String>, chains: Vec<CorefChain>) -> Self {
424        Self {
425            text: text.into(),
426            doc_id: None,
427            chains,
428            includes_singletons: false,
429        }
430    }
431
432    /// Create document with ID.
433    #[must_use]
434    pub fn with_id(
435        text: impl Into<String>,
436        doc_id: impl Into<String>,
437        chains: Vec<CorefChain>,
438    ) -> Self {
439        Self {
440            text: text.into(),
441            doc_id: Some(doc_id.into()),
442            chains,
443            includes_singletons: false,
444        }
445    }
446
447    /// Total number of mentions across all chains.
448    #[must_use]
449    pub fn mention_count(&self) -> usize {
450        self.chains.iter().map(|c| c.len()).sum()
451    }
452
453    /// Number of chains (clusters).
454    #[must_use]
455    pub fn chain_count(&self) -> usize {
456        self.chains.len()
457    }
458
459    /// Number of non-singleton chains.
460    #[must_use]
461    pub fn non_singleton_count(&self) -> usize {
462        self.chains.iter().filter(|c| !c.is_singleton()).count()
463    }
464
465    /// Get all mentions in document order.
466    #[must_use]
467    pub fn all_mentions(&self) -> Vec<&Mention> {
468        let mut mentions: Vec<&Mention> = self.chains.iter().flat_map(|c| &c.mentions).collect();
469        mentions.sort_by_key(|m| (m.start, m.end));
470        mentions
471    }
472
473    /// Find which chain contains a mention span.
474    #[must_use]
475    pub fn find_chain(&self, start: usize, end: usize) -> Option<&CorefChain> {
476        self.chains.iter().find(|c| c.contains_span(start, end))
477    }
478
479    /// Build mention-to-chain index for fast lookup.
480    #[must_use]
481    pub fn mention_to_chain_index(&self) -> HashMap<(usize, usize), usize> {
482        let mut index = HashMap::new();
483        for (chain_idx, chain) in self.chains.iter().enumerate() {
484            for mention in &chain.mentions {
485                index.insert(mention.span_id(), chain_idx);
486            }
487        }
488        index
489    }
490
491    /// Filter to only non-singleton chains.
492    #[must_use]
493    pub fn without_singletons(&self) -> Self {
494        Self {
495            text: self.text.clone(),
496            doc_id: self.doc_id.clone(),
497            chains: self
498                .chains
499                .iter()
500                .filter(|c| !c.is_singleton())
501                .cloned()
502                .collect(),
503            includes_singletons: false,
504        }
505    }
506}
507
508// =============================================================================
509// Conversion from Entity to Mention
510// =============================================================================
511
512impl From<&Entity> for Mention {
513    fn from(entity: &Entity) -> Self {
514        Self {
515            text: entity.text.clone(),
516            start: entity.start,
517            end: entity.end,
518            head_start: None,
519            head_end: None,
520            entity_type: Some(entity.entity_type.as_label().to_string()),
521            mention_type: None,
522        }
523    }
524}
525
526/// Convert entities with canonical_id to coreference chains.
527///
528/// Entities sharing the same `canonical_id` are grouped into a chain.
529#[must_use]
530pub fn entities_to_chains(entities: &[Entity]) -> Vec<CorefChain> {
531    let mut clusters: HashMap<u64, Vec<Mention>> = HashMap::new();
532    let mut singletons: Vec<Mention> = Vec::new();
533
534    for entity in entities {
535        let mention = Mention::from(entity);
536        if let Some(canonical_id) = entity.canonical_id {
537            clusters
538                .entry(canonical_id.get())
539                .or_default()
540                .push(mention);
541        } else {
542            singletons.push(mention);
543        }
544    }
545
546    let mut chains: Vec<CorefChain> = clusters
547        .into_iter()
548        .map(|(id, mentions)| CorefChain::with_id(mentions, id))
549        .collect();
550
551    // Add singletons as individual chains
552    for mention in singletons {
553        chains.push(CorefChain::singleton(mention));
554    }
555
556    chains
557}
558
559// =============================================================================
560// CoreferenceResolver Trait
561// =============================================================================
562
563/// Trait for coreference resolution algorithms.
564///
565/// Implementors take a set of entity mentions and cluster them into
566/// coreference chains (groups of mentions referring to the same entity).
567///
568/// # Design Philosophy
569///
570/// This trait lives in `anno::core` because:
571/// 1. It depends only on core types (`Entity`, `CorefChain`)
572/// 2. Multiple crates need to implement it (backends, eval)
573/// 3. Keeping it here prevents circular dependencies
574///
575/// # Relationship to the Grounded Pipeline
576///
577/// `CoreferenceResolver` operates on the **evaluation/convenience layer** (`Entity`),
578/// not the canonical **grounded pipeline** (`Signal` → `Track` → `Identity`).
579///
580/// | Layer | Type | `CoreferenceResolver` role |
581/// |-------|------|----------------------------|
582/// | Detection (L1) | `Entity` | Input: mentions to cluster |
583/// | Coref (L2) | `Entity.canonical_id` | Output: cluster assignment |
584/// | Linking (L3) | `Identity` | (not covered by this trait) |
585///
586/// For integration with `GroundedDocument`, use backends that produce
587/// `Signal` + `Track` directly (e.g., `anno::backends::MentionRankingCoref`).
588///
589/// # Example Implementation
590///
591/// ```rust
592/// use anno_core::{CoreferenceResolver, Entity, EntityType};
593///
594/// struct ExactMatchResolver;
595///
596/// impl CoreferenceResolver for ExactMatchResolver {
597///     fn resolve(&self, entities: &[Entity]) -> Vec<Entity> {
598///         // Trivially return entities unchanged for this example
599///         entities.to_vec()
600///     }
601///
602///     fn name(&self) -> &'static str {
603///         "exact-match"
604///     }
605/// }
606///
607/// let resolver = ExactMatchResolver;
608/// assert_eq!(resolver.name(), "exact-match");
609/// ```
610pub trait CoreferenceResolver: Send + Sync {
611    /// Resolve coreference, assigning canonical IDs to entities.
612    ///
613    /// Each entity in the output will have a `canonical_id` field set.
614    /// Entities with the same `canonical_id` are coreferent (refer to the
615    /// same real-world entity).
616    ///
617    /// # Invariants
618    ///
619    /// - Every output entity has `canonical_id.is_some()`
620    /// - Coreferent entities share the same `canonical_id`
621    /// - Singleton mentions get unique `canonical_id` values
622    fn resolve(&self, entities: &[Entity]) -> Vec<Entity>;
623
624    /// Resolve directly to chains.
625    ///
626    /// A chain groups all mentions of the same entity together.
627    /// This is often the desired output format for evaluation and
628    /// downstream tasks.
629    fn resolve_to_chains(&self, entities: &[Entity]) -> Vec<CorefChain> {
630        let resolved = self.resolve(entities);
631        entities_to_chains(&resolved)
632    }
633
634    /// Get resolver name.
635    ///
636    /// Used for logging, metrics, and result attribution.
637    fn name(&self) -> &'static str;
638}
639
640// =============================================================================
641// Tests
642// =============================================================================
643
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_mention_creation() {
        let m = Mention::new("John", 0, 4);
        assert_eq!(m.text, "John");
        assert_eq!(m.start, 0);
        assert_eq!(m.end, 4);
        assert_eq!(m.len(), 4);
    }

    #[test]
    fn test_mention_overlap() {
        let m1 = Mention::new("John Smith", 0, 10);
        let m2 = Mention::new("Smith", 5, 10);
        let m3 = Mention::new("works", 11, 16);

        assert!(m1.overlaps(&m2));
        assert!(!m1.overlaps(&m3));
        assert!(!m2.overlaps(&m3));
    }

    #[test]
    fn test_mention_span_matches_ignores_text() {
        let a = Mention::new("John", 0, 4);
        let b = Mention::new("JOHN", 0, 4);
        let c = Mention::new("John", 0, 5);
        assert!(a.span_matches(&b));
        assert!(!a.span_matches(&c));
    }

    #[test]
    fn test_mention_display() {
        let m = Mention::new("John", 0, 4);
        assert_eq!(m.to_string(), "\"John\" [0-4)");
    }

    #[test]
    fn test_chain_creation() {
        let mentions = vec![
            Mention::new("John", 0, 4),
            Mention::new("he", 20, 22),
            Mention::new("him", 40, 43),
        ];
        let chain = CorefChain::new(mentions);

        assert_eq!(chain.len(), 3);
        assert!(!chain.is_singleton());
        assert_eq!(chain.link_count(), 2); // Minimum links to connect
    }

    #[test]
    fn test_chain_links() {
        let mentions = vec![
            Mention::new("a", 0, 1),
            Mention::new("b", 2, 3),
            Mention::new("c", 4, 5),
        ];
        let chain = CorefChain::new(mentions);

        // All pairs: (a,b), (a,c), (b,c) = 3 pairs
        assert_eq!(chain.all_pairs().len(), 3);
    }

    #[test]
    fn test_chain_display() {
        let chain = CorefChain::new(vec![Mention::new("he", 20, 22), Mention::new("John", 0, 4)]);
        // Display follows the sorted (document) order.
        assert_eq!(chain.to_string(), "[\"John\", \"he\"]");
        assert_eq!(CorefChain::new(vec![]).to_string(), "[]");
    }

    #[test]
    fn test_chain_contains_span() {
        let chain = CorefChain::new(vec![Mention::new("John", 0, 4), Mention::new("he", 20, 22)]);
        assert!(chain.contains_span(20, 22));
        assert!(!chain.contains_span(20, 21));
        assert_eq!(chain.mention_spans().len(), 2);
    }

    #[test]
    fn test_canonical_mention_prefers_proper() {
        let chain = CorefChain::new(vec![
            Mention::with_type("John", 0, 4, MentionType::Proper),
            Mention::new("the man who sold the world", 10, 36),
            Mention::with_type("he", 40, 42, MentionType::Pronominal),
        ]);
        assert_eq!(
            chain.canonical_mention().map(|m| m.text.as_str()),
            Some("John")
        );
    }

    #[test]
    fn test_canonical_mention_falls_back_to_longest() {
        let chain = CorefChain::new(vec![
            Mention::new("he", 0, 2),
            Mention::new("the CEO", 5, 12),
            Mention::new("him", 20, 23),
        ]);
        assert_eq!(
            chain.canonical_mention().map(|m| m.text.as_str()),
            Some("the CEO")
        );
    }

    #[test]
    fn test_singleton_chain() {
        let m = Mention::new("entity", 0, 6);
        let chain = CorefChain::singleton(m);

        assert!(chain.is_singleton());
        assert_eq!(chain.link_count(), 0);
        assert!(chain.all_pairs().is_empty());
    }

    #[test]
    fn test_document() {
        let text = "John went to the store. He bought milk.";
        let chain = CorefChain::new(vec![Mention::new("John", 0, 4), Mention::new("He", 24, 26)]);
        let doc = CorefDocument::new(text, vec![chain]);

        assert_eq!(doc.mention_count(), 2);
        assert_eq!(doc.chain_count(), 1);
        assert_eq!(doc.non_singleton_count(), 1);
    }

    #[test]
    fn test_document_all_mentions_sorted() {
        let chain1 = CorefChain::new(vec![Mention::new("John", 0, 4), Mention::new("he", 20, 22)]);
        let chain2 = CorefChain::new(vec![Mention::new("Mary", 5, 9)]);
        let doc = CorefDocument::new("text", vec![chain1, chain2]);

        let order: Vec<&str> = doc
            .all_mentions()
            .into_iter()
            .map(|m| m.text.as_str())
            .collect();
        assert_eq!(order, vec!["John", "Mary", "he"]);
    }

    #[test]
    fn test_find_chain() {
        let chain1 = CorefChain::new(vec![Mention::new("John", 0, 4), Mention::new("he", 20, 22)]);
        let chain2 = CorefChain::new(vec![Mention::new("Mary", 5, 9)]);
        let doc = CorefDocument::new("text", vec![chain1, chain2]);

        let found = doc.find_chain(20, 22).expect("span (20, 22) is annotated");
        assert_eq!(found.mentions[0].text, "John");
        assert!(doc.find_chain(20, 21).is_none());
    }

    #[test]
    fn test_mention_to_chain_index() {
        let chain1 = CorefChain::new(vec![Mention::new("John", 0, 4), Mention::new("he", 20, 22)]);
        let chain2 = CorefChain::new(vec![Mention::new("Mary", 5, 9), Mention::new("she", 30, 33)]);
        let doc = CorefDocument::new("text", vec![chain1, chain2]);

        let index = doc.mention_to_chain_index();
        assert_eq!(index.get(&(0, 4)), Some(&0));
        assert_eq!(index.get(&(20, 22)), Some(&0));
        assert_eq!(index.get(&(5, 9)), Some(&1));
        assert_eq!(index.get(&(30, 33)), Some(&1));
    }

    // =========================================================================
    // Edge case tests
    // =========================================================================

    #[test]
    fn test_unicode_mention_offsets() {
        // "北京 Beijing" — character offsets, not byte offsets.
        // "北" is 3 bytes in UTF-8 but 1 character.
        let m = Mention::new("北京", 0, 2); // 2 characters, not 6 bytes
        assert_eq!(m.len(), 2);
        assert_eq!(m.span_id(), (0, 2));
        assert!(!m.is_empty());
    }

    #[test]
    fn test_zero_length_mention() {
        // Zero anaphora / empty mention at position 5.
        let m = Mention::new("", 5, 5);
        assert!(m.is_empty());
        assert_eq!(m.len(), 0);
        assert_eq!(m.span_id(), (5, 5));
    }

    #[test]
    fn test_empty_chain() {
        let chain = CorefChain::new(vec![]);
        assert!(chain.is_empty());
        assert_eq!(chain.link_count(), 0);
        assert!(chain.all_pairs().is_empty());
        assert!(chain.first().is_none());
        assert!(chain.canonical_mention().is_none());
    }

    #[test]
    fn test_chain_sorting_out_of_order() {
        // Mentions given out of document order should be sorted by (start, end).
        let chain = CorefChain::new(vec![
            Mention::new("c", 20, 21),
            Mention::new("a", 0, 1),
            Mention::new("b", 10, 11),
        ]);
        assert_eq!(chain.mentions[0].text, "a");
        assert_eq!(chain.mentions[1].text, "b");
        assert_eq!(chain.mentions[2].text, "c");
    }

    #[test]
    fn test_chain_sorting_ties_broken_by_end() {
        // Same start, different end: shorter span first.
        let chain = CorefChain::new(vec![
            Mention::new("John Smith", 0, 10),
            Mention::new("John", 0, 4),
        ]);
        assert_eq!(chain.mentions[0].text, "John");
        assert_eq!(chain.mentions[1].text, "John Smith");
    }

    #[test]
    fn test_entities_to_chains_grouped() {
        use super::super::entity::EntityType;
        use super::super::types::CanonicalId;

        let e1 = super::super::Entity::new("John", EntityType::Person, 0, 4, 0.9)
            .with_canonical_id(1_u64);
        let e2 = super::super::Entity::new("he", EntityType::Person, 20, 22, 0.8)
            .with_canonical_id(1_u64);
        let e3 = super::super::Entity::new("Mary", EntityType::Person, 5, 9, 0.95)
            .with_canonical_id(2_u64);

        let chains = entities_to_chains(&[e1, e2, e3]);

        // Two canonical_ids -> two chains
        assert_eq!(chains.len(), 2);

        // Find the chain with cluster_id=1 (John + he)
        let chain1 = chains
            .iter()
            .find(|c| c.cluster_id == Some(CanonicalId::new(1)))
            .expect("chain with id=1");
        assert_eq!(chain1.len(), 2);

        // Find the chain with cluster_id=2 (Mary)
        let chain2 = chains
            .iter()
            .find(|c| c.cluster_id == Some(CanonicalId::new(2)))
            .expect("chain with id=2");
        assert_eq!(chain2.len(), 1);
    }

    #[test]
    fn test_entities_to_chains_singletons() {
        use super::super::entity::EntityType;

        // Entities without canonical_id become individual singleton chains.
        let e1 = super::super::Entity::new("Paris", EntityType::Location, 0, 5, 0.9);
        let e2 = super::super::Entity::new("London", EntityType::Location, 10, 16, 0.85);

        let chains = entities_to_chains(&[e1, e2]);
        assert_eq!(chains.len(), 2);
        assert!(chains.iter().all(|c| c.is_singleton()));
    }

    #[test]
    fn test_entities_to_chains_empty() {
        let chains = entities_to_chains(&[]);
        assert!(chains.is_empty());
    }

    #[test]
    fn test_without_singletons_filters() {
        let singleton = CorefChain::singleton(Mention::new("solo", 0, 4));
        let multi = CorefChain::new(vec![
            Mention::new("John", 10, 14),
            Mention::new("he", 20, 22),
        ]);
        let doc = CorefDocument::new("text", vec![singleton, multi]);

        let filtered = doc.without_singletons();
        assert_eq!(filtered.chain_count(), 1);
        assert_eq!(filtered.chains[0].len(), 2);
        assert!(!filtered.includes_singletons);
    }

    #[test]
    fn test_without_singletons_preserves_non_singletons() {
        let c1 = CorefChain::new(vec![Mention::new("a", 0, 1), Mention::new("b", 2, 3)]);
        let c2 = CorefChain::new(vec![
            Mention::new("x", 10, 11),
            Mention::new("y", 12, 13),
            Mention::new("z", 14, 15),
        ]);
        let doc = CorefDocument::new("text", vec![c1, c2]);

        let filtered = doc.without_singletons();
        assert_eq!(filtered.chain_count(), 2);
        assert_eq!(filtered.chains[0].len(), 2);
        assert_eq!(filtered.chains[1].len(), 3);
    }

    #[test]
    fn test_without_singletons_all_singletons() {
        let s1 = CorefChain::singleton(Mention::new("a", 0, 1));
        let s2 = CorefChain::singleton(Mention::new("b", 2, 3));
        let doc = CorefDocument::new("text", vec![s1, s2]);

        let filtered = doc.without_singletons();
        assert!(filtered.chains.is_empty());
    }

    #[test]
    fn test_overlaps_adjacent_non_overlapping() {
        // [0,5) and [5,10) are adjacent but NOT overlapping (half-open intervals).
        let m1 = Mention::new("hello", 0, 5);
        let m2 = Mention::new("world", 5, 10);
        assert!(!m1.overlaps(&m2));
        assert!(!m2.overlaps(&m1));
    }

    #[test]
    fn test_overlaps_nested() {
        // [0,10) fully contains [2,5).
        let outer = Mention::new("the big dog", 0, 10);
        let inner = Mention::new("big", 2, 5);
        assert!(outer.overlaps(&inner));
        assert!(inner.overlaps(&outer));
    }

    #[test]
    fn test_chain_with_id() {
        // Given out of order to check that `with_id` sorts like `new`.
        let chain = CorefChain::with_id(
            vec![Mention::new("he", 10, 12), Mention::new("John", 0, 4)],
            42_u64,
        );
        assert_eq!(
            chain.canonical_id(),
            Some(super::super::types::CanonicalId::new(42))
        );
        assert_eq!(
            chain.cluster_id,
            Some(super::super::types::CanonicalId::new(42))
        );
        assert_eq!(chain.mentions[0].text, "John");
    }
}
915
#[cfg(test)]
mod proptests {
    #![allow(clippy::unwrap_used)]
    use super::*;
    use proptest::prelude::*;

    /// Strategy producing a non-empty mention (length 1..500) whose start is
    /// below `max_offset`.
    fn arb_mention(max_offset: usize) -> impl Strategy<Value = Mention> {
        (0usize..max_offset, 1usize..500)
            .prop_map(|(start, len)| Mention::new(format!("m_{}", start), start, start + len))
    }

    proptest! {
        /// Mentions sort by (start, end) consistently after CorefChain::new.
        #[test]
        fn mention_ordering_after_chain_construction(
            mentions in proptest::collection::vec(arb_mention(10000), 1..20),
        ) {
            let chain = CorefChain::new(mentions);
            for w in chain.mentions.windows(2) {
                prop_assert!(
                    (w[0].start, w[0].end) <= (w[1].start, w[1].end),
                    "mentions must be sorted by (start, end): ({},{}) vs ({},{})",
                    w[0].start, w[0].end, w[1].start, w[1].end
                );
            }
        }

        /// CorefChain constructed with at least one mention is never empty.
        #[test]
        fn coref_chain_non_empty(
            mentions in proptest::collection::vec(arb_mention(10000), 1..20),
        ) {
            let n = mentions.len();
            let chain = CorefChain::new(mentions);
            prop_assert!(!chain.is_empty());
            prop_assert_eq!(chain.len(), n);
        }

        /// CorefChain::singleton always produces a chain with exactly one mention.
        #[test]
        fn coref_chain_singleton_has_one(start in 0usize..10000, len in 1usize..500) {
            let m = Mention::new("x", start, start + len);
            let chain = CorefChain::singleton(m);
            prop_assert!(chain.is_singleton());
            prop_assert_eq!(chain.len(), 1);
            prop_assert_eq!(chain.link_count(), 0);
        }

        /// `links()` enumerates all n*(n-1)/2 unordered pairs, while
        /// `link_count()` reports the n-1 spanning-tree links (0 when empty).
        #[test]
        fn link_counts_match_formulas(
            mentions in proptest::collection::vec(arb_mention(10000), 0..20),
        ) {
            let n = mentions.len();
            let chain = CorefChain::new(mentions);
            prop_assert_eq!(chain.links().len(), n * n.saturating_sub(1) / 2);
            prop_assert_eq!(chain.link_count(), n.saturating_sub(1));
        }

        /// Mention::overlaps is symmetric.
        #[test]
        fn mention_overlap_symmetric(
            s1 in 0usize..10000, len1 in 1usize..500,
            s2 in 0usize..10000, len2 in 1usize..500,
        ) {
            let m1 = Mention::new("a", s1, s1 + len1);
            let m2 = Mention::new("b", s2, s2 + len2);
            prop_assert_eq!(m1.overlaps(&m2), m2.overlaps(&m1));
        }

        /// Mention serde roundtrip preserves all fields.
        #[test]
        fn mention_serde_roundtrip(
            start in 0usize..10000, len in 1usize..500,
        ) {
            let m = Mention::new(format!("mention_{}", start), start, start + len);
            let json = serde_json::to_string(&m).unwrap();
            let m2: Mention = serde_json::from_str(&json).unwrap();
            prop_assert_eq!(&m, &m2);
        }

        /// CorefChain serde roundtrip preserves mentions and metadata.
        #[test]
        fn coref_chain_serde_roundtrip(
            mentions in proptest::collection::vec(arb_mention(10000), 1..10),
        ) {
            let chain = CorefChain::new(mentions);
            let json = serde_json::to_string(&chain).unwrap();
            let back: CorefChain = serde_json::from_str(&json).unwrap();
            prop_assert_eq!(&chain, &back);
        }
    }
}