anno_core/core/coref.rs
1//! Coreference resolution data structures.
2//!
3//! Provides types for representing coreference chains (clusters of mentions
4//! that refer to the same entity) and utilities for working with them.
5//!
6//! # Terminology
7//!
8//! - **Mention**: A span of text referring to an entity (e.g., "John", "he", "the CEO")
9//! - **Chain/Cluster**: A set of mentions that corefer (refer to the same entity)
10//! - **Singleton**: A chain with only one mention (entity mentioned only once)
11//! - **Antecedent**: An earlier mention that a pronoun/noun phrase refers to
12//!
13//! # Example
14//!
15//! ```rust
16//! use anno_core::core::coref::{Mention, CorefChain, CorefDocument};
17//!
18//! // "John went to the store. He bought milk."
19//! let john = Mention::new("John", 0, 4);
//! let he = Mention::new("He", 24, 26);
21//!
22//! let chain = CorefChain::new(vec![john, he]);
23//! assert_eq!(chain.len(), 2);
24//! assert!(!chain.is_singleton());
25//! ```
26
27use super::Entity;
28use serde::{Deserialize, Serialize};
29use std::collections::{HashMap, HashSet};
30
31// Re-export MentionType for convenience
32pub use super::types::MentionType;
33
34// =============================================================================
35// Mention
36// =============================================================================
37
/// A single mention (text span) that may corefer with other mentions.
///
/// # Equality
///
/// `PartialEq`, `Eq` and `Hash` are derived, so they take **every** field into
/// account (text, offsets, head span, entity type, mention type). Two mentions
/// with identical text at different positions are therefore distinct. To
/// compare by position only, use `span_matches` or compare `span_id` values.
///
/// # Character vs Byte Offsets
///
/// `start` and `end` are **character** offsets, not byte offsets.
/// For "北京 Beijing", the character offsets are:
/// - "北" = 0..1 (but 3 bytes in UTF-8)
/// - "京" = 1..2 (but 3 bytes)
/// - " " = 2..3
/// - "Beijing" = 3..10
///
/// Use `text.chars().skip(start).take(end - start).collect()` to extract a
/// span, where `text` is the string the offsets index into.
///
/// # Head Span
///
/// The `head_start`/`head_end` fields mark the syntactic head for head-match
/// evaluation (used in CEAF-e, LEA metrics). In "the former president of France",
/// the head is "president" - the noun that determines agreement.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct Mention {
    /// The mention text (surface form).
    pub text: String,
    /// Start character offset (inclusive, 0-indexed).
    pub start: usize,
    /// End character offset (exclusive).
    pub end: usize,
    /// Head word start (for head-match metrics like CEAF).
    /// `None` when no head span has been annotated.
    pub head_start: Option<usize>,
    /// Head word end (exclusive, like `end`).
    /// `None` when no head span has been annotated.
    pub head_end: Option<usize>,
    /// Entity type if known (e.g., "PER", "ORG").
    pub entity_type: Option<String>,
    /// Mention category: Pronominal, Proper, Nominal, Zero.
    pub mention_type: Option<MentionType>,
}
76
77impl Mention {
78 /// `Mention::new("John", 0, 4)` creates a mention for "John" at characters 0..4.
79 ///
80 /// Offsets are character positions, not byte positions.
81 ///
82 /// ```
83 /// use anno_core::Mention;
84 ///
85 /// let m = Mention::new("John", 0, 4);
86 /// assert_eq!(m.text, "John");
87 /// assert_eq!(m.len(), 4);
88 /// assert_eq!(m.span_id(), (0, 4));
89 /// ```
90 #[must_use]
91 pub fn new(text: impl Into<String>, start: usize, end: usize) -> Self {
92 Self {
93 text: text.into(),
94 start,
95 end,
96 head_start: None,
97 head_end: None,
98 entity_type: None,
99 mention_type: None,
100 }
101 }
102
103 /// Mention with head span for head-match evaluation.
104 ///
105 /// The head is the syntactic nucleus: in "the former president", head is "president".
106 ///
107 /// ```
108 /// # use anno_core::core::coref::Mention;
109 /// let m = Mention::with_head("the former president", 0, 20, 11, 20);
110 /// assert_eq!(m.head_start, Some(11)); // "president" starts at 11
111 /// ```
112 #[must_use]
113 pub fn with_head(
114 text: impl Into<String>,
115 start: usize,
116 end: usize,
117 head_start: usize,
118 head_end: usize,
119 ) -> Self {
120 Self {
121 text: text.into(),
122 start,
123 end,
124 head_start: Some(head_start),
125 head_end: Some(head_end),
126 entity_type: None,
127 mention_type: None,
128 }
129 }
130
131 /// Mention with type annotation for type-aware evaluation.
132 ///
133 /// ```
134 /// # use anno_core::core::coref::Mention;
135 /// # use anno_core::core::types::MentionType;
136 /// let pronoun = Mention::with_type("he", 25, 27, MentionType::Pronominal);
137 /// let proper = Mention::with_type("John Smith", 0, 10, MentionType::Proper);
138 /// ```
139 #[must_use]
140 pub fn with_type(
141 text: impl Into<String>,
142 start: usize,
143 end: usize,
144 mention_type: MentionType,
145 ) -> Self {
146 Self {
147 text: text.into(),
148 start,
149 end,
150 head_start: None,
151 head_end: None,
152 entity_type: None,
153 mention_type: Some(mention_type),
154 }
155 }
156
157 /// True if spans share any characters: `[0,5)` overlaps `[3,8)`.
158 #[must_use]
159 pub fn overlaps(&self, other: &Mention) -> bool {
160 self.start < other.end && other.start < self.end
161 }
162
163 /// True if spans are identical: same start AND end.
164 #[must_use]
165 pub fn span_matches(&self, other: &Mention) -> bool {
166 self.start == other.start && self.end == other.end
167 }
168
169 /// Span length in characters. Returns 0 if `end <= start`.
170 #[must_use]
171 pub fn len(&self) -> usize {
172 self.end.saturating_sub(self.start)
173 }
174
175 /// True if span has zero length.
176 #[must_use]
177 pub fn is_empty(&self) -> bool {
178 self.len() == 0
179 }
180
181 /// `(start, end)` tuple for use in hash sets and comparisons.
182 #[must_use]
183 pub fn span_id(&self) -> (usize, usize) {
184 (self.start, self.end)
185 }
186
187 /// Head-span tuple `(head_start, head_end)` for head-match scoring.
188 ///
189 /// Returns the head span if both `head_start` and `head_end` are set,
190 /// otherwise falls back to the full span `(start, end)`.
191 ///
192 /// Head-match evaluation is standard in CRAC shared tasks: two mentions
193 /// match if their syntactic heads overlap, even when full spans differ
194 /// (e.g., "the president" vs "the former president of France" both have
195 /// head "president").
196 ///
197 /// ```
198 /// # use anno_core::core::coref::Mention;
199 /// let m = Mention::with_head("the former president", 0, 20, 11, 20);
200 /// assert_eq!(m.span_id_head(), (11, 20));
201 ///
202 /// let m2 = Mention::new("John", 0, 4);
203 /// assert_eq!(m2.span_id_head(), (0, 4)); // falls back to full span
204 /// ```
205 #[must_use]
206 pub fn span_id_head(&self) -> (usize, usize) {
207 match (self.head_start, self.head_end) {
208 (Some(hs), Some(he)) => (hs, he),
209 _ => (self.start, self.end),
210 }
211 }
212}
213
214impl std::fmt::Display for Mention {
215 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
216 write!(f, "\"{}\" [{}-{})", self.text, self.start, self.end)
217 }
218}
219
220// =============================================================================
221// CorefChain (Cluster)
222// =============================================================================
223
/// A coreference chain: mentions that all refer to the same entity.
///
/// ```
/// # use anno_core::core::coref::{CorefChain, Mention};
/// // "John went to the store. He bought milk."
/// //  ^^^^                    ^^
/// let john = Mention::new("John", 0, 4);
/// let he = Mention::new("He", 24, 26);
///
/// let chain = CorefChain::new(vec![john, he]);
/// assert_eq!(chain.len(), 2);
/// assert!(!chain.is_singleton());
/// ```
///
/// # Note
///
/// This type is for **evaluation and intermediate processing**. For production pipelines,
/// use [`Track`](super::grounded::Track) which integrates with the Signal/Track/Identity hierarchy.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct CorefChain {
    /// Mentions in document order (sorted by start position).
    ///
    /// `CorefChain::new` and `CorefChain::with_id` sort by `(start, end)`.
    /// The field is public, so code that mutates it afterwards is responsible
    /// for keeping that order.
    pub mentions: Vec<Mention>,
    /// Cluster ID from the source data, if any.
    pub cluster_id: Option<super::types::CanonicalId>,
    /// Entity type shared by all mentions (e.g., "PERSON").
    pub entity_type: Option<String>,
}
251
252impl CorefChain {
253 /// Build a chain from mentions. Sorts by position automatically.
254 ///
255 /// ```
256 /// # use anno_core::core::coref::{CorefChain, Mention};
257 /// let chain = CorefChain::new(vec![
258 /// Mention::new("she", 50, 53),
259 /// Mention::new("Dr. Smith", 0, 9), // out of order
260 /// ]);
261 /// assert_eq!(chain.mentions[0].text, "Dr. Smith"); // sorted
262 /// ```
263 #[must_use]
264 pub fn new(mut mentions: Vec<Mention>) -> Self {
265 mentions.sort_by_key(|m| (m.start, m.end));
266 Self {
267 mentions,
268 cluster_id: None,
269 entity_type: None,
270 }
271 }
272
273 /// Build a chain with an explicit cluster ID.
274 #[must_use]
275 pub fn with_id(
276 mut mentions: Vec<Mention>,
277 cluster_id: impl Into<super::types::CanonicalId>,
278 ) -> Self {
279 mentions.sort_by_key(|m| (m.start, m.end));
280 Self {
281 mentions,
282 cluster_id: Some(cluster_id.into()),
283 entity_type: None,
284 }
285 }
286
287 /// A chain with exactly one mention (entity mentioned only once).
288 #[must_use]
289 pub fn singleton(mention: Mention) -> Self {
290 Self {
291 mentions: vec![mention],
292 cluster_id: None,
293 entity_type: None,
294 }
295 }
296
297 /// Number of mentions. A chain with 3 mentions has 2 implicit "links".
298 #[must_use]
299 pub fn len(&self) -> usize {
300 self.mentions.len()
301 }
302
303 /// True if chain has no mentions. Shouldn't happen in valid data.
304 #[must_use]
305 pub fn is_empty(&self) -> bool {
306 self.mentions.is_empty()
307 }
308
309 /// True if chain has exactly one mention (singleton entity).
310 #[must_use]
311 pub fn is_singleton(&self) -> bool {
312 self.mentions.len() == 1
313 }
314
315 /// All pairwise links. For MUC: `n` mentions = `n*(n-1)/2` links.
316 ///
317 /// ```
318 /// # use anno_core::core::coref::{CorefChain, Mention};
319 /// let chain = CorefChain::new(vec![
320 /// Mention::new("A", 0, 1),
321 /// Mention::new("B", 2, 3),
322 /// Mention::new("C", 4, 5),
323 /// ]);
324 /// assert_eq!(chain.links().len(), 3); // A-B, A-C, B-C
325 /// ```
326 #[must_use]
327 pub fn links(&self) -> Vec<(&Mention, &Mention)> {
328 let mut links = Vec::new();
329 for i in 0..self.mentions.len() {
330 for j in (i + 1)..self.mentions.len() {
331 links.push((&self.mentions[i], &self.mentions[j]));
332 }
333 }
334 links
335 }
336
337 /// Number of coreference links.
338 ///
339 /// For a chain of n mentions: n*(n-1)/2 pairs, but only n-1 links needed
340 /// to connect all mentions (spanning tree).
341 #[must_use]
342 pub fn link_count(&self) -> usize {
343 if self.mentions.len() <= 1 {
344 0
345 } else {
346 self.mentions.len() - 1
347 }
348 }
349
350 /// Get all pairwise mention combinations (for B³, CEAF).
351 #[must_use]
352 pub fn all_pairs(&self) -> Vec<(&Mention, &Mention)> {
353 self.links() // Same as links for non-directed pairs
354 }
355
356 /// Check if chain contains a mention with given span.
357 #[must_use]
358 pub fn contains_span(&self, start: usize, end: usize) -> bool {
359 self.mentions
360 .iter()
361 .any(|m| m.start == start && m.end == end)
362 }
363
364 /// Get first mention (usually the most salient/representative).
365 #[must_use]
366 pub fn first(&self) -> Option<&Mention> {
367 self.mentions.first()
368 }
369
370 /// Get set of mention span IDs for set operations.
371 #[must_use]
372 pub fn mention_spans(&self) -> HashSet<(usize, usize)> {
373 self.mentions.iter().map(|m| m.span_id()).collect()
374 }
375
376 /// Get the canonical (representative) mention for this chain.
377 ///
378 /// Prefers proper nouns over other mention types, then longest mention.
379 /// Falls back to first mention if no proper noun exists.
380 #[must_use]
381 pub fn canonical_mention(&self) -> Option<&Mention> {
382 // Prefer proper noun mentions
383 let proper = self
384 .mentions
385 .iter()
386 .filter(|m| m.mention_type == Some(MentionType::Proper))
387 .max_by_key(|m| m.text.len());
388
389 if proper.is_some() {
390 return proper;
391 }
392
393 // Fall back to longest mention (likely most informative)
394 self.mentions.iter().max_by_key(|m| m.text.len())
395 }
396
397 /// Get the canonical ID for this chain (cluster_id if set).
398 #[must_use]
399 pub fn canonical_id(&self) -> Option<super::types::CanonicalId> {
400 self.cluster_id
401 }
402}
403
404impl std::fmt::Display for CorefChain {
405 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
406 let mentions: Vec<String> = self
407 .mentions
408 .iter()
409 .map(|m| format!("\"{}\"", m.text))
410 .collect();
411 write!(f, "[{}]", mentions.join(", "))
412 }
413}
414
415// =============================================================================
416// CorefDocument
417// =============================================================================
418
419/// A document with coreference annotations.
420///
421/// Contains the source text and all coreference chains.
422#[derive(Debug, Clone, Serialize, Deserialize)]
423pub struct CorefDocument {
424 /// Document text.
425 pub text: String,
426 /// Document identifier.
427 pub doc_id: Option<String>,
428 /// Coreference chains (clusters).
429 pub chains: Vec<CorefChain>,
430 /// Whether singletons are included.
431 pub includes_singletons: bool,
432}
433
434impl CorefDocument {
435 /// Create a new document with chains.
436 ///
437 /// ```
438 /// use anno_core::core::coref::{CorefDocument, CorefChain, Mention};
439 ///
440 /// let chain = CorefChain::new(vec![
441 /// Mention::new("John", 0, 4),
442 /// Mention::new("He", 24, 26),
443 /// ]);
444 /// let doc = CorefDocument::new("John went to the store. He bought milk.", vec![chain]);
445 /// assert_eq!(doc.mention_count(), 2);
446 /// assert_eq!(doc.chain_count(), 1);
447 /// ```
448 #[must_use]
449 pub fn new(text: impl Into<String>, chains: Vec<CorefChain>) -> Self {
450 Self {
451 text: text.into(),
452 doc_id: None,
453 chains,
454 includes_singletons: false,
455 }
456 }
457
458 /// Create document with ID.
459 #[must_use]
460 pub fn with_id(
461 text: impl Into<String>,
462 doc_id: impl Into<String>,
463 chains: Vec<CorefChain>,
464 ) -> Self {
465 Self {
466 text: text.into(),
467 doc_id: Some(doc_id.into()),
468 chains,
469 includes_singletons: false,
470 }
471 }
472
473 /// Total number of mentions across all chains.
474 #[must_use]
475 pub fn mention_count(&self) -> usize {
476 self.chains.iter().map(|c| c.len()).sum()
477 }
478
479 /// Number of chains (clusters).
480 #[must_use]
481 pub fn chain_count(&self) -> usize {
482 self.chains.len()
483 }
484
485 /// Number of non-singleton chains.
486 #[must_use]
487 pub fn non_singleton_count(&self) -> usize {
488 self.chains.iter().filter(|c| !c.is_singleton()).count()
489 }
490
491 /// Get all mentions in document order.
492 #[must_use]
493 pub fn all_mentions(&self) -> Vec<&Mention> {
494 let mut mentions: Vec<&Mention> = self.chains.iter().flat_map(|c| &c.mentions).collect();
495 mentions.sort_by_key(|m| (m.start, m.end));
496 mentions
497 }
498
499 /// Find which chain contains a mention span.
500 #[must_use]
501 pub fn find_chain(&self, start: usize, end: usize) -> Option<&CorefChain> {
502 self.chains.iter().find(|c| c.contains_span(start, end))
503 }
504
505 /// Build mention-to-chain index for fast lookup.
506 #[must_use]
507 pub fn mention_to_chain_index(&self) -> HashMap<(usize, usize), usize> {
508 let mut index = HashMap::new();
509 for (chain_idx, chain) in self.chains.iter().enumerate() {
510 for mention in &chain.mentions {
511 index.insert(mention.span_id(), chain_idx);
512 }
513 }
514 index
515 }
516
517 /// Filter to only non-singleton chains.
518 #[must_use]
519 pub fn without_singletons(&self) -> Self {
520 Self {
521 text: self.text.clone(),
522 doc_id: self.doc_id.clone(),
523 chains: self
524 .chains
525 .iter()
526 .filter(|c| !c.is_singleton())
527 .cloned()
528 .collect(),
529 includes_singletons: false,
530 }
531 }
532}
533
534// =============================================================================
535// Conversion from Entity to Mention
536// =============================================================================
537
538impl From<&Entity> for Mention {
539 fn from(entity: &Entity) -> Self {
540 Self {
541 text: entity.text.clone(),
542 start: entity.start(),
543 end: entity.end(),
544 head_start: None,
545 head_end: None,
546 entity_type: Some(entity.entity_type.as_label().to_string()),
547 mention_type: entity.mention_type,
548 }
549 }
550}
551
552/// Convert entities with canonical_id to coreference chains.
553///
554/// Entities sharing the same `canonical_id` are grouped into a chain.
555#[must_use]
556pub fn entities_to_chains(entities: &[Entity]) -> Vec<CorefChain> {
557 let mut clusters: HashMap<u64, Vec<Mention>> = HashMap::new();
558 let mut singletons: Vec<Mention> = Vec::new();
559
560 for entity in entities {
561 let mention = Mention::from(entity);
562 if let Some(canonical_id) = entity.canonical_id {
563 clusters
564 .entry(canonical_id.get())
565 .or_default()
566 .push(mention);
567 } else {
568 singletons.push(mention);
569 }
570 }
571
572 let mut chains: Vec<CorefChain> = clusters
573 .into_iter()
574 .map(|(id, mentions)| CorefChain::with_id(mentions, id))
575 .collect();
576
577 // Add singletons as individual chains
578 for mention in singletons {
579 chains.push(CorefChain::singleton(mention));
580 }
581
582 chains
583}
584
585// =============================================================================
586// CoreferenceResolver Trait
587// =============================================================================
588
589/// Trait for coreference resolution algorithms.
590///
591/// Implementors take a set of entity mentions and cluster them into
592/// coreference chains (groups of mentions referring to the same entity).
593///
594/// # Design Philosophy
595///
596/// This trait lives in `anno::core` because:
597/// 1. It depends only on core types (`Entity`, `CorefChain`)
598/// 2. Multiple crates need to implement it (backends, eval)
599/// 3. Keeping it here prevents circular dependencies
600///
601/// # Relationship to the Grounded Pipeline
602///
603/// `CoreferenceResolver` operates on the **evaluation/convenience layer** (`Entity`),
604/// not the canonical **grounded pipeline** (`Signal` → `Track` → `Identity`).
605///
606/// | Layer | Type | `CoreferenceResolver` role |
607/// |-------|------|----------------------------|
608/// | Detection (L1) | `Entity` | Input: mentions to cluster |
609/// | Coref (L2) | `Entity.canonical_id` | Output: cluster assignment |
610/// | Linking (L3) | `Identity` | (not covered by this trait) |
611///
612/// For integration with `GroundedDocument`, use backends that produce
613/// `Signal` + `Track` directly (e.g., `anno::backends::MentionRankingCoref`).
614///
615/// # Example Implementation
616///
617/// ```rust
618/// use anno_core::{CoreferenceResolver, Entity, EntityType};
619///
620/// struct ExactMatchResolver;
621///
622/// impl CoreferenceResolver for ExactMatchResolver {
623/// fn resolve(&self, entities: &[Entity]) -> Vec<Entity> {
624/// // Trivially return entities unchanged for this example
625/// entities.to_vec()
626/// }
627///
628/// fn name(&self) -> &'static str {
629/// "exact-match"
630/// }
631/// }
632///
633/// let resolver = ExactMatchResolver;
634/// assert_eq!(resolver.name(), "exact-match");
635/// ```
636pub trait CoreferenceResolver: Send + Sync {
637 /// Resolve coreference, assigning canonical IDs to entities.
638 ///
639 /// Each entity in the output will have a `canonical_id` field set.
640 /// Entities with the same `canonical_id` are coreferent (refer to the
641 /// same real-world entity).
642 ///
643 /// # Invariants
644 ///
645 /// - Every output entity has `canonical_id.is_some()`
646 /// - Coreferent entities share the same `canonical_id`
647 /// - Singleton mentions get unique `canonical_id` values
648 fn resolve(&self, entities: &[Entity]) -> Vec<Entity>;
649
650 /// Resolve directly to chains.
651 ///
652 /// A chain groups all mentions of the same entity together.
653 /// This is often the desired output format for evaluation and
654 /// downstream tasks.
655 fn resolve_to_chains(&self, entities: &[Entity]) -> Vec<CorefChain> {
656 let resolved = self.resolve(entities);
657 entities_to_chains(&resolved)
658 }
659
660 /// Get resolver name.
661 ///
662 /// Used for logging, metrics, and result attribution.
663 fn name(&self) -> &'static str;
664}
665
666// =============================================================================
667// Tests
668// =============================================================================
669
#[cfg(test)]
mod tests {
    //! Unit tests for mentions, chains, documents and entity-to-chain conversion.
    use super::*;

    #[test]
    fn test_mention_creation() {
        let m = Mention::new("John", 0, 4);
        assert_eq!(m.text, "John");
        assert_eq!(m.start, 0);
        assert_eq!(m.end, 4);
        assert_eq!(m.len(), 4);
    }

    #[test]
    fn test_mention_overlap() {
        let m1 = Mention::new("John Smith", 0, 10);
        let m2 = Mention::new("Smith", 5, 10);
        let m3 = Mention::new("works", 11, 16);

        assert!(m1.overlaps(&m2));
        assert!(!m1.overlaps(&m3));
        assert!(!m2.overlaps(&m3));
    }

    #[test]
    fn test_chain_creation() {
        let mentions = vec![
            Mention::new("John", 0, 4),
            Mention::new("he", 20, 22),
            Mention::new("him", 40, 43),
        ];
        let chain = CorefChain::new(mentions);

        assert_eq!(chain.len(), 3);
        assert!(!chain.is_singleton());
        assert_eq!(chain.link_count(), 2); // Minimum links to connect
    }

    #[test]
    fn test_chain_links() {
        let mentions = vec![
            Mention::new("a", 0, 1),
            Mention::new("b", 2, 3),
            Mention::new("c", 4, 5),
        ];
        let chain = CorefChain::new(mentions);

        // All pairs: (a,b), (a,c), (b,c) = 3 pairs
        assert_eq!(chain.all_pairs().len(), 3);
    }

    #[test]
    fn test_singleton_chain() {
        let m = Mention::new("entity", 0, 6);
        let chain = CorefChain::singleton(m);

        assert!(chain.is_singleton());
        assert_eq!(chain.link_count(), 0);
        assert!(chain.all_pairs().is_empty());
    }

    #[test]
    fn test_document() {
        let text = "John went to the store. He bought milk.";
        let chain = CorefChain::new(vec![Mention::new("John", 0, 4), Mention::new("He", 24, 26)]);
        let doc = CorefDocument::new(text, vec![chain]);

        assert_eq!(doc.mention_count(), 2);
        assert_eq!(doc.chain_count(), 1);
        assert_eq!(doc.non_singleton_count(), 1);
    }

    #[test]
    fn test_mention_to_chain_index() {
        let chain1 = CorefChain::new(vec![Mention::new("John", 0, 4), Mention::new("he", 20, 22)]);
        let chain2 = CorefChain::new(vec![
            Mention::new("Mary", 5, 9),
            Mention::new("she", 30, 33),
        ]);
        let doc = CorefDocument::new("text", vec![chain1, chain2]);

        let index = doc.mention_to_chain_index();
        assert_eq!(index.get(&(0, 4)), Some(&0));
        assert_eq!(index.get(&(20, 22)), Some(&0));
        assert_eq!(index.get(&(5, 9)), Some(&1));
        assert_eq!(index.get(&(30, 33)), Some(&1));
    }

    // =========================================================================
    // Edge case tests
    // =========================================================================

    #[test]
    fn test_unicode_mention_offsets() {
        // "北京 Beijing" — character offsets, not byte offsets.
        // "北" is 3 bytes in UTF-8 but 1 character.
        let m = Mention::new("北京", 0, 2); // 2 characters, not 6 bytes
        assert_eq!(m.len(), 2);
        assert_eq!(m.span_id(), (0, 2));
        assert!(!m.is_empty());
    }

    #[test]
    fn test_zero_length_mention() {
        // Zero anaphora / empty mention at position 5.
        let m = Mention::new("", 5, 5);
        assert!(m.is_empty());
        assert_eq!(m.len(), 0);
        assert_eq!(m.span_id(), (5, 5));
    }

    #[test]
    fn test_empty_chain() {
        let chain = CorefChain::new(vec![]);
        assert!(chain.is_empty());
        assert_eq!(chain.link_count(), 0);
        assert!(chain.all_pairs().is_empty());
        assert!(chain.first().is_none());
        assert!(chain.canonical_mention().is_none());
    }

    #[test]
    fn test_chain_sorting_out_of_order() {
        // Mentions given out of document order should be sorted by (start, end).
        let chain = CorefChain::new(vec![
            Mention::new("c", 20, 21),
            Mention::new("a", 0, 1),
            Mention::new("b", 10, 11),
        ]);
        assert_eq!(chain.mentions[0].text, "a");
        assert_eq!(chain.mentions[1].text, "b");
        assert_eq!(chain.mentions[2].text, "c");
    }

    #[test]
    fn test_chain_sorting_ties_broken_by_end() {
        // Same start, different end: shorter span first.
        let chain = CorefChain::new(vec![
            Mention::new("John Smith", 0, 10),
            Mention::new("John", 0, 4),
        ]);
        assert_eq!(chain.mentions[0].text, "John");
        assert_eq!(chain.mentions[1].text, "John Smith");
    }

    #[test]
    fn test_entities_to_chains_grouped() {
        use super::super::entity::EntityType;
        use super::super::types::CanonicalId;

        let e1 = super::super::Entity::new("John", EntityType::Person, 0, 4, 0.9)
            .with_canonical_id(1_u64);
        let e2 = super::super::Entity::new("he", EntityType::Person, 20, 22, 0.8)
            .with_canonical_id(1_u64);
        let e3 = super::super::Entity::new("Mary", EntityType::Person, 5, 9, 0.95)
            .with_canonical_id(2_u64);

        let chains = entities_to_chains(&[e1, e2, e3]);

        // Two canonical_ids -> two chains
        assert_eq!(chains.len(), 2);

        // Find the chain with cluster_id=1 (John + he)
        let chain1 = chains
            .iter()
            .find(|c| c.cluster_id == Some(CanonicalId::new(1)))
            .expect("chain with id=1");
        assert_eq!(chain1.len(), 2);

        // Find the chain with cluster_id=2 (Mary)
        let chain2 = chains
            .iter()
            .find(|c| c.cluster_id == Some(CanonicalId::new(2)))
            .expect("chain with id=2");
        assert_eq!(chain2.len(), 1);
    }

    #[test]
    fn test_entities_to_chains_singletons() {
        use super::super::entity::EntityType;

        // Entities without canonical_id become individual singleton chains.
        let e1 = super::super::Entity::new("Paris", EntityType::Location, 0, 5, 0.9);
        let e2 = super::super::Entity::new("London", EntityType::Location, 10, 16, 0.85);

        let chains = entities_to_chains(&[e1, e2]);
        assert_eq!(chains.len(), 2);
        assert!(chains.iter().all(|c| c.is_singleton()));
    }

    #[test]
    fn test_entities_to_chains_empty() {
        let chains = entities_to_chains(&[]);
        assert!(chains.is_empty());
    }

    #[test]
    fn test_without_singletons_filters() {
        let singleton = CorefChain::singleton(Mention::new("solo", 0, 4));
        let multi = CorefChain::new(vec![
            Mention::new("John", 10, 14),
            Mention::new("he", 20, 22),
        ]);
        let doc = CorefDocument::new("text", vec![singleton, multi]);

        let filtered = doc.without_singletons();
        assert_eq!(filtered.chain_count(), 1);
        assert_eq!(filtered.chains[0].len(), 2);
        assert!(!filtered.includes_singletons);
    }

    #[test]
    fn test_without_singletons_preserves_non_singletons() {
        let c1 = CorefChain::new(vec![Mention::new("a", 0, 1), Mention::new("b", 2, 3)]);
        let c2 = CorefChain::new(vec![
            Mention::new("x", 10, 11),
            Mention::new("y", 12, 13),
            Mention::new("z", 14, 15),
        ]);
        // The chains are not used again, so they are moved rather than cloned.
        let doc = CorefDocument::new("text", vec![c1, c2]);

        let filtered = doc.without_singletons();
        assert_eq!(filtered.chain_count(), 2);
    }

    #[test]
    fn test_without_singletons_all_singletons() {
        let s1 = CorefChain::singleton(Mention::new("a", 0, 1));
        let s2 = CorefChain::singleton(Mention::new("b", 2, 3));
        let doc = CorefDocument::new("text", vec![s1, s2]);

        let filtered = doc.without_singletons();
        assert!(filtered.chains.is_empty());
    }

    #[test]
    fn test_overlaps_adjacent_non_overlapping() {
        // [0,5) and [5,10) are adjacent but NOT overlapping (half-open intervals).
        let m1 = Mention::new("hello", 0, 5);
        let m2 = Mention::new("world", 5, 10);
        assert!(!m1.overlaps(&m2));
        assert!(!m2.overlaps(&m1));
    }

    #[test]
    fn test_overlaps_nested() {
        // [0,11) fully contains [4,7): "big" sits inside "the big dog".
        let outer = Mention::new("the big dog", 0, 11);
        let inner = Mention::new("big", 4, 7);
        assert!(outer.overlaps(&inner));
        assert!(inner.overlaps(&outer));
    }

    #[test]
    fn test_chain_with_id() {
        let chain = CorefChain::with_id(
            vec![Mention::new("John", 0, 4), Mention::new("he", 10, 12)],
            42_u64,
        );
        assert_eq!(
            chain.canonical_id(),
            Some(super::super::types::CanonicalId::new(42))
        );
        assert_eq!(
            chain.cluster_id,
            Some(super::super::types::CanonicalId::new(42))
        );
        // Mentions should still be sorted.
        assert_eq!(chain.mentions[0].text, "John");
    }
}
941
#[cfg(test)]
mod proptests {
    //! Property-based tests (via `proptest`) for mention and chain invariants.

    // `unwrap` is acceptable here: a failed (de)serialization should simply
    // fail the property under test.
    #![allow(clippy::unwrap_used)]
    use super::*;
    use proptest::prelude::*;

    /// Strategy to generate a Mention with bounded offsets.
    ///
    /// `start` is drawn from `0..max_offset` and the span length from `1..500`,
    /// so generated mentions are never empty. The text is `m_<start>`.
    fn arb_mention(max_offset: usize) -> impl Strategy<Value = Mention> {
        (0usize..max_offset, 1usize..500)
            .prop_map(|(start, len)| Mention::new(format!("m_{}", start), start, start + len))
    }

    proptest! {
        /// Mentions sort by (start, end) consistently after CorefChain::new.
        #[test]
        fn mention_ordering_after_chain_construction(
            mentions in proptest::collection::vec(arb_mention(10000), 1..20),
        ) {
            let chain = CorefChain::new(mentions);
            // Every adjacent pair must be in non-decreasing (start, end) order.
            for w in chain.mentions.windows(2) {
                prop_assert!(
                    (w[0].start, w[0].end) <= (w[1].start, w[1].end),
                    "mentions must be sorted by (start, end): ({},{}) vs ({},{})",
                    w[0].start, w[0].end, w[1].start, w[1].end
                );
            }
        }

        /// CorefChain constructed with at least one mention is never empty,
        /// and construction neither drops nor duplicates mentions.
        #[test]
        fn coref_chain_non_empty(
            mentions in proptest::collection::vec(arb_mention(10000), 1..20),
        ) {
            let n = mentions.len();
            let chain = CorefChain::new(mentions);
            prop_assert!(!chain.is_empty());
            prop_assert_eq!(chain.len(), n);
        }

        /// CorefChain::singleton always produces a chain with exactly one mention.
        #[test]
        fn coref_chain_singleton_has_one(start in 0usize..10000, len in 1usize..500) {
            let m = Mention::new("x", start, start + len);
            let chain = CorefChain::singleton(m);
            prop_assert!(chain.is_singleton());
            prop_assert_eq!(chain.len(), 1);
            prop_assert_eq!(chain.link_count(), 0);
        }

        /// Mention::overlaps is symmetric.
        #[test]
        fn mention_overlap_symmetric(
            s1 in 0usize..10000, len1 in 1usize..500,
            s2 in 0usize..10000, len2 in 1usize..500,
        ) {
            let m1 = Mention::new("a", s1, s1 + len1);
            let m2 = Mention::new("b", s2, s2 + len2);
            prop_assert_eq!(m1.overlaps(&m2), m2.overlaps(&m1));
        }

        /// Mention serde (JSON) roundtrip yields an equal value.
        ///
        /// Only mentions built with `Mention::new` are exercised, so the
        /// optional fields are always `None` in this property.
        #[test]
        fn mention_serde_roundtrip(
            start in 0usize..10000, len in 1usize..500,
        ) {
            let m = Mention::new(format!("mention_{}", start), start, start + len);
            let json = serde_json::to_string(&m).unwrap();
            let m2: Mention = serde_json::from_str(&json).unwrap();
            prop_assert_eq!(&m, &m2);
        }
    }
}
1013}