// anno/backends/graph_coref.rs
//! Graph-Based Coreference Resolution with Iterative Refinement.
//!
//! This module implements a graph-based approach to coreference resolution,
//! inspired by the Graph-to-Graph Transformer (G2GT) of Mohammadshahi & Henderson
//! (2021) as applied to coreference by Miculicich & Henderson (2022). The key
//! insight: model coreference as a graph whose nodes are mentions and whose edges
//! are coreference links, then iteratively refine the predictions until they converge.
//!
//! # Historical Context
//!
//! Coreference resolution evolved through distinct paradigms:
//!
//! ```text
//! 1978-2010  Rule-based: Hobbs algorithm, centering theory
//! 2010-2016  Mention-pair: classify (m_i, m_j) independently
//! 2017       Lee et al.: end-to-end span-based, O(N⁴) complexity
//! 2018       Lee et al.: higher-order with representation refinement
//! 2022       G2GT: graph refinement with global decisions, O(N² × T)
//! ```
//!
//! **The core problem with pairwise models**: decisions are independent.
//! If P(A~B)=0.9 and P(B~C)=0.9, transitivity implies A~C, yet a pairwise
//! model can still output P(A~C)=0.1. The G2GT approach addresses this by
//! conditioning each iteration on the full graph predicted in the previous
//! iteration, which enables globally consistent decisions.
//!
//! # Architecture
//!
//! ```text
//! Input: Detected mentions M = [m₁, m₂, ..., mₙ]
//!     ↓
//! ┌─────────────────────────────────────────────────────────┐
//! │ Iteration 0: Initialize empty graph G₀                  │
//! │   - Nodes: all mentions                                 │
//! │   - Edges: none                                         │
//! └─────────────────────────────────────────────────────────┘
//!     ↓
//! ┌─────────────────────────────────────────────────────────┐
//! │ Iteration t: Refine graph Gₜ₋₁ → Gₜ                     │
//! │ For each mention pair (mᵢ, mⱼ) where j < i:             │
//! │   1. Compute pairwise score s(mᵢ, mⱼ)                   │
//! │   2. Add graph context from Gₜ₋₁ (transitivity bonus)   │
//! │   3. Update edge if score exceeds threshold             │
//! └─────────────────────────────────────────────────────────┘
//!     ↓ (repeat until Gₜ = Gₜ₋₁ or t = max_iterations)
//! ┌─────────────────────────────────────────────────────────┐
//! │ Extract clusters via connected components               │
//! └─────────────────────────────────────────────────────────┘
//!     ↓
//! Output: Coreference chains (clusters)
//! ```
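//!
//! The loop above can be sketched in a few lines. This is a simplified,
//! standalone illustration rather than this module's code: `refine` is a
//! hypothetical helper, the base scores are a made-up symmetric matrix, and
//! only the shared-neighbor bonus is modeled (the transitive-connection bonus
//! is omitted). With scores of 0.6 for (0,1) and (1,2) but only 0.45 for (0,2),
//! the first pass links 0-1 and 1-2. The second pass gives (0,2) a shared
//! neighbor, which lifts its score to 0.55 > 0.5. The third pass reproduces the
//! same graph, so the loop stops at a fixed point.
//!
//! ```rust
//! use std::collections::HashSet;
//!
//! // Canonical undirected edge key (lo, hi).
//! fn key(i: usize, j: usize) -> (usize, usize) {
//!     if i < j { (i, j) } else { (j, i) }
//! }
//!
//! // Neighbors of `i` in an edge set of canonical pairs.
//! fn neighbors(edges: &HashSet<(usize, usize)>, i: usize) -> HashSet<usize> {
//!     edges
//!         .iter()
//!         .filter_map(|&(a, b)| {
//!             if a == i {
//!                 Some(b)
//!             } else if b == i {
//!                 Some(a)
//!             } else {
//!                 None
//!             }
//!         })
//!         .collect()
//! }
//!
//! // Hypothetical refinement loop: rebuild the edge set each pass from the base
//! // score plus a bonus per neighbor shared in the previous graph. Returns the
//! // final edges and the pass on which the fixed point was detected.
//! fn refine(
//!     base: &[Vec<f64>],
//!     threshold: f64,
//!     per_shared: f64,
//!     max_iter: usize,
//! ) -> (HashSet<(usize, usize)>, usize) {
//!     let mut graph: HashSet<(usize, usize)> = HashSet::new();
//!     for pass in 1..=max_iter {
//!         let mut next = HashSet::new();
//!         for i in 0..base.len() {
//!             for j in 0..i {
//!                 let shared = neighbors(&graph, i)
//!                     .intersection(&neighbors(&graph, j))
//!                     .count();
//!                 if base[i][j] + shared as f64 * per_shared > threshold {
//!                     next.insert(key(i, j));
//!                 }
//!             }
//!         }
//!         if next == graph {
//!             return (graph, pass); // Gₜ == Gₜ₋₁: converged
//!         }
//!         graph = next;
//!     }
//!     (graph, max_iter)
//! }
//! ```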
//!
//! # Approximation vs. Full G2GT
//!
//! This implementation is a **heuristic approximation**, not a full neural model.
//!
//! | Aspect | G2GT (Miculicich 2022) | This Implementation |
//! |--------|------------------------|---------------------|
//! | **Graph nodes** | Tokens | Mentions (pre-detected) |
//! | **Graph encoding** | Attention modification: `Lk = E(G)·Wk` | Explicit transitivity bonus |
//! | **Pairwise scoring** | Learned neural scorer | String/head heuristics |
//! | **Refinement** | Full neural re-prediction | Score adjustment |
//! | **Training** | End-to-end backprop | None (heuristic) |
//!
//! ## What We Preserve
//!
//! The key insight from G2GT: **iterative refinement with graph conditioning**
//! enables global consistency that independent pairwise models lack. Even with
//! heuristic scoring, the refinement loop propagates transitivity evidence across
//! iterations as soft score bonuses.
//!
//! ## What We Lose
//!
//! - **Learned representations**: G2GT embeds graph structure directly into
//!   transformer attention. We approximate this with explicit bonuses.
//! - **End-to-end optimization**: G2GT trains the full system. We use fixed heuristics.
//! - **Token-level granularity**: G2GT operates on tokens; we operate on mentions.
//!
//! For production use with high accuracy requirements, consider a full neural
//! implementation or the T5-based coreference in `crate::backends::coref_t5`.
//!
//! # Usage with MentionType
//!
//! For best results, provide mentions with `mention_type` set:
//!
//! ```rust,ignore
//! use anno::backends::graph_coref::GraphCoref;
//! use anno::{Mention, MentionType};
//!
//! // Explicitly typed mentions avoid falling back to text-based type inference
//! let mut john = Mention::new("John", 0, 4);
//! john.mention_type = Some(MentionType::Proper);
//!
//! let mut he = Mention::new("he", 20, 22);
//! he.mention_type = Some(MentionType::Pronominal);
//!
//! let coref = GraphCoref::new();
//! let chains = coref.resolve(&[john, he]);
//! ```
//!
//! # Graph Initialization: Syntactic vs Semantic
//!
//! SpanEIT (Hossain et al. 2025) constructs a combined graph:
//! - **Syntactic edges** (`E_syn`): from the dependency parse (adjectival modifiers, etc.)
//! - **Semantic edges** (`E_sem`): from co-occurrence statistics
//!
//! Use [`CorefGraph::seed_cooccurrence_edges`] to initialize with proximity-based
//! priors before running iterative refinement.
//!
//! # References
//!
//! - Miculicich & Henderson (2022): "Graph Refinement for Coreference Resolution"
//!   [arXiv:2203.16574](https://arxiv.org/abs/2203.16574)
//! - Lee et al. (2017): "End-to-end Neural Coreference Resolution"
//! - Lee et al. (2018): "Higher-Order Coreference Resolution"
//! - Mohammadshahi & Henderson (2021): "Graph-to-Graph Transformer for Dependency Parsing"
//! - Hossain et al. (2025): "SpanEIT: Dynamic Span Interaction and Graph-Aware Memory"
//!   [arXiv:2509.11604](https://arxiv.org/abs/2509.11604)
//!
//! # Future Direction: Sheaf Neural Networks
//!
//! This implementation uses explicit transitivity bonuses to approximate global consistency.
//! A more principled approach: **Sheaf Neural Networks** replace scalar edge weights with
//! learned linear maps (restriction maps) and minimize the sheaf Dirichlet energy:
//!
//! ```text
//! E(x) = Σ_{(u,v) ∈ E} || F(u→v) · x_u - F(v→u) · x_v ||²
//! ```
//!
//! Minimizing this energy would build consistency into the learned model itself,
//! rather than adding it post hoc through score bonuses. See:
//! - `archive/geometric-2024-12/sheaf.rs` for a stub implementation and trait definitions
//! - Bodnar et al. (2022): "Neural Sheaf Diffusion", NeurIPS
//! - twitter-research/neural-sheaf-diffusion (Apache 2.0): reference implementation
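//!
//! As a toy illustration of the energy above (not code from this crate), take
//! one-dimensional stalks, so every restriction map F(u→v) is a scalar. The
//! hypothetical `sheaf_energy` below sums the squared disagreement on each
//! edge; with every map equal to 1 it reduces to the ordinary graph Dirichlet
//! energy Σ (x_u - x_v)².
//!
//! ```rust
//! // Each edge is (u, v, F(u→v), F(v→u)) with scalar restriction maps.
//! fn sheaf_energy(x: &[f64], edges: &[(usize, usize, f64, f64)]) -> f64 {
//!     edges
//!         .iter()
//!         .map(|&(u, v, f_uv, f_vu)| {
//!             // Disagreement after mapping both endpoints into the edge's space
//!             let diff = f_uv * x[u] - f_vu * x[v];
//!             diff * diff
//!         })
//!         .sum()
//! }
//! ```
//!
//! For example, with x = [1, 2], F(0→1) = 2 and F(1→0) = 1, the edge is
//! consistent (2·1 - 1·2 = 0) and the energy is 0 even though x₀ ≠ x₁.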

use anno_core::{CorefChain, Mention, MentionType};
use std::collections::{HashMap, HashSet, VecDeque};
use std::hash::{Hash, Hasher};

// =============================================================================
// Types
// =============================================================================

/// Edge type in the coreference graph.
///
/// Following G2GT's three-way classification:
/// - 0 (None): No relationship
/// - 1 (Mention): Within-mention link (not used here since we operate on mentions, not tokens)
/// - 2 (Coref): Coreference link between mentions
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[repr(u8)]
pub enum EdgeType {
    /// No link between mentions.
    None = 0,
    /// Coreference link (both mentions refer to the same entity).
    Coref = 2,
}

/// A coreference graph representing mention relationships.
///
/// This is the core data structure for iterative refinement. Nodes are mention
/// indices and edges are coreference links. Edges are stored as a set of
/// canonical `(lo, hi)` pairs, which gives expected O(1) edge lookup during
/// refinement.
///
/// # Invariants
///
/// - Graph is symmetric: if edge(i,j) exists, edge(j,i) exists
/// - No self-loops: edge(i,i) is always None
/// - Indices are valid mention indices: 0 <= i,j < num_mentions
///
/// # Example
///
/// ```rust
/// use anno::backends::graph_coref::CorefGraph;
///
/// let mut graph = CorefGraph::new(3);
/// graph.add_edge(0, 1);
/// graph.add_edge(1, 2);
///
/// assert!(graph.has_edge(0, 1));
/// assert!(graph.transitively_connected(0, 2)); // via 0-1-2
///
/// let clusters = graph.extract_clusters();
/// assert_eq!(clusters.len(), 1); // All connected
/// ```
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CorefGraph {
    /// Number of mentions (nodes).
    num_mentions: usize,
    /// Edge set: (i, j) with i < j (canonical representation).
    edges: HashSet<(usize, usize)>,
}

impl CorefGraph {
    /// Create an empty coreference graph with the given number of mentions.
    #[must_use]
    pub fn new(num_mentions: usize) -> Self {
        Self {
            num_mentions,
            edges: HashSet::new(),
        }
    }

    /// Get the number of mentions (nodes) in the graph.
    #[must_use]
    pub fn num_mentions(&self) -> usize {
        self.num_mentions
    }

    /// Add a coreference edge between two mentions.
    ///
    /// The edge is stored in canonical form (i < j) for consistency.
    /// Self-loops and out-of-bounds indices are silently ignored.
    pub fn add_edge(&mut self, i: usize, j: usize) {
        if i == j || i >= self.num_mentions || j >= self.num_mentions {
            return;
        }
        let (lo, hi) = if i < j { (i, j) } else { (j, i) };
        self.edges.insert((lo, hi));
    }

    /// Remove a coreference edge between two mentions (no-op if absent).
    pub fn remove_edge(&mut self, i: usize, j: usize) {
        let (lo, hi) = if i < j { (i, j) } else { (j, i) };
        self.edges.remove(&(lo, hi));
    }

    /// Check if two mentions are directly linked.
    #[must_use]
    pub fn has_edge(&self, i: usize, j: usize) -> bool {
        if i == j {
            return false;
        }
        let (lo, hi) = if i < j { (i, j) } else { (j, i) };
        self.edges.contains(&(lo, hi))
    }

    /// Get all neighbors (directly linked mentions) of a mention.
    ///
    /// This scans the whole edge set, so each call costs O(E).
    #[must_use]
    pub fn neighbors(&self, i: usize) -> Vec<usize> {
        let mut result = Vec::new();
        for &(lo, hi) in &self.edges {
            if lo == i {
                result.push(hi);
            } else if hi == i {
                result.push(lo);
            }
        }
        result
    }

    /// Count shared neighbors between two mentions.
    ///
    /// This is the basis for the shared-neighbor bonus (`per_shared_neighbor_bonus`):
    /// if mentions i and j share many neighbors in the current graph, they're
    /// likely coreferent.
    ///
    /// # G2GT Connection
    ///
    /// In the full G2GT model, shared structure is captured via graph-conditioned
    /// attention. Here we approximate it by explicitly counting shared neighbors
    /// and adding a proportional bonus to the pairwise score.
    #[must_use]
    pub fn shared_neighbors(&self, i: usize, j: usize) -> usize {
        let neighbors_i: HashSet<usize> = self.neighbors(i).into_iter().collect();
        let neighbors_j: HashSet<usize> = self.neighbors(j).into_iter().collect();
        neighbors_i.intersection(&neighbors_j).count()
    }

    /// Check if two mentions are transitively connected.
    ///
    /// Uses BFS to find if there's a path from i to j through coreference links.
    /// This is the closure property that ensures consistency: if A~B and B~C,
    /// then A and C are transitively connected even without a direct edge.
    #[must_use]
    pub fn transitively_connected(&self, i: usize, j: usize) -> bool {
        if i == j {
            return true;
        }
        if self.has_edge(i, j) {
            return true;
        }

        // BFS from i (FIFO queue), stopping as soon as j is reached
        let mut visited = HashSet::new();
        let mut queue = VecDeque::from([i]);
        visited.insert(i);

        while let Some(current) = queue.pop_front() {
            for neighbor in self.neighbors(current) {
                if neighbor == j {
                    return true;
                }
                if visited.insert(neighbor) {
                    queue.push_back(neighbor);
                }
            }
        }

        false
    }

    /// Extract connected components as clusters.
    ///
    /// Each connected component in the graph becomes a coreference chain.
    /// Singleton mentions (no edges) are included as single-mention clusters.
    /// Members of each cluster are sorted, and clusters are ordered by their
    /// smallest mention index.
    #[must_use]
    pub fn extract_clusters(&self) -> Vec<Vec<usize>> {
        let mut visited = vec![false; self.num_mentions];
        let mut clusters = Vec::new();

        for start in 0..self.num_mentions {
            if visited[start] {
                continue;
            }

            // BFS (FIFO queue) to find all members of this component
            let mut cluster = Vec::new();
            let mut queue = VecDeque::from([start]);
            visited[start] = true;

            while let Some(current) = queue.pop_front() {
                cluster.push(current);
                for neighbor in self.neighbors(current) {
                    if !visited[neighbor] {
                        visited[neighbor] = true;
                        queue.push_back(neighbor);
                    }
                }
            }

            cluster.sort_unstable();
            clusters.push(cluster);
        }

        clusters
    }

    /// Get the number of edges in the graph.
    #[must_use]
    pub fn edge_count(&self) -> usize {
        self.edges.len()
    }

    /// Check if graph is empty (no edges).
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.edges.is_empty()
    }

    /// Seed the graph with co-occurrence priors based on mention proximity.
    ///
    /// This is inspired by SpanEIT (Hossain et al. 2025), which constructs a
    /// semantic co-occurrence graph `G_sem` alongside the syntactic graph. The
    /// insight: mentions that appear close together are more likely coreferent.
    ///
    /// # Arguments
    ///
    /// * `mention_positions` - Position of each mention (e.g., character offset).
    ///   Mentions whose index is beyond the end of this slice are skipped.
    /// * `window_size` - Maximum distance, inclusive, for co-occurrence (e.g., 100 chars)
    /// * `scorer` - Optional predicate `f(i, j)` that decides whether an in-window
    ///   pair gets an edge; if `None`, every pair within the window is linked
    ///
    /// # Example
    ///
    /// ```rust
    /// use anno::backends::graph_coref::CorefGraph;
    ///
    /// let mut graph = CorefGraph::new(3);
    /// let positions = vec![0, 50, 200]; // Character offsets
    ///
    /// // Seed edges for mentions within 100 chars of each other.
    /// // A type annotation is needed when passing None for the scorer.
    /// graph.seed_cooccurrence_edges::<fn(usize, usize) -> bool>(&positions, 100, None);
    ///
    /// assert!(graph.has_edge(0, 1)); // 50 <= 100
    /// assert!(!graph.has_edge(0, 2)); // 200 > 100
    /// ```
    ///
    /// # Research Background
    ///
    /// SpanEIT constructs `G = (V, E_syn ∪ E_sem)` where:
    /// - `E_syn` = syntactic dependency edges
    /// - `E_sem` = co-occurrence edges (this method)
    ///
    /// The combined graph is processed by GAT layers for context-aware embeddings.
    /// For anno's heuristic approach, we add these edges as initial priors before
    /// iterative refinement.
    pub fn seed_cooccurrence_edges<F>(
        &mut self,
        mention_positions: &[usize],
        window_size: usize,
        scorer: Option<F>,
    ) where
        F: Fn(usize, usize) -> bool,
    {
        for i in 0..self.num_mentions {
            for j in (i + 1)..self.num_mentions {
                if i >= mention_positions.len() || j >= mention_positions.len() {
                    continue;
                }

                let pos_i = mention_positions[i];
                let pos_j = mention_positions[j];
                let distance = pos_i.abs_diff(pos_j);

                if distance <= window_size {
                    let should_add = match scorer.as_ref() {
                        None => true,
                        Some(f) => f(i, j),
                    };
                    if should_add {
                        self.add_edge(i, j);
                    }
                }
            }
        }
    }
}

// =============================================================================
// Configuration
// =============================================================================

/// Configuration for graph-based coreference resolution.
///
/// These parameters control the iterative refinement process. The
/// `max_iterations` default follows Miculicich & Henderson (2022); the other
/// defaults are heuristic choices:
///
/// - `max_iterations = 4`: Paper found T=4 optimal; more iterations don't help
/// - `link_threshold = 0.5`: Standard classification threshold
/// - `transitivity_bonus = 0.15`: Reward for transitive consistency
///
/// # Tuning Guide
///
/// | Parameter | Higher Value | Lower Value |
/// |-----------|--------------|-------------|
/// | `link_threshold` | Fewer, more confident links | More links, potential noise |
/// | `transitivity_bonus` | Stronger clustering effect | More independent decisions |
/// | `max_iterations` | More refinement passes | Faster, less propagation |
/// | `head_match_weight` | Trust head matches more | Rely more on full string |
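///
/// # Worked Example
///
/// With the defaults, consider the `John` (0..4) / `he` (20..22) pair from the
/// [`GraphCoref::resolve`] example. There is no exact, containment, or head
/// match, so the score is the Proper/Pronominal bonus minus the log-distance
/// penalty over the 16-character gap: 0.3 - 0.05 * ln(16) ≈ 0.16, which stays
/// below `link_threshold = 0.5`. The function below is a hypothetical standalone
/// restatement of just those two terms, not the private scorer itself:
///
/// ```rust
/// // Hypothetical helper: type bonus minus log-distance penalty, for a pair
/// // with no exact, containment, or head-word match.
/// fn pronoun_proper_score(pronoun_proper_bonus: f64, distance_weight: f64, gap: usize) -> f64 {
///     let mut score = pronoun_proper_bonus;
///     if gap > 0 {
///         score -= distance_weight * (gap as f64).ln();
///     }
///     score
/// }
/// ```
///
/// Lowering `link_threshold` below about 0.16 (for example to 0.1) is enough
/// for this pair to be linked on the first iteration.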
#[derive(Debug, Clone)]
pub struct GraphCorefConfig {
    /// Maximum refinement iterations before stopping.
    ///
    /// The G2GT paper found T=4 to be optimal on CoNLL 2012. Fewer iterations
    /// leave potential coreference links undiscovered; more iterations don't
    /// improve results and waste computation.
    pub max_iterations: usize,

    /// Minimum score to create a coreference link.
    ///
    /// A pair (mᵢ, mⱼ) is linked if score(mᵢ, mⱼ) + context_bonus > threshold.
    pub link_threshold: f64,

    /// Bonus added when a pair is already transitively connected.
    ///
    /// If mentions A and B are joined by a path of coreference links in the
    /// previous iteration's graph (e.g. A~C and C~B), this bonus is added to
    /// their score to encourage a direct A~B link. Shared neighbors are
    /// rewarded separately via `per_shared_neighbor_bonus`.
    ///
    /// **Note**: This is our heuristic approximation of G2GT's graph-conditioned
    /// attention. The full G2GT model encodes graph structure as:
    /// ```text
    /// Attention(Q,K,V,Lk,Lv) = softmax(Q·(K+Lk)ᵀ/√d)·(V+Lv)
    /// where Lk = E(G^{t-1})·Wk
    /// ```
    /// We approximate this by explicit score adjustment rather than attention modification.
    pub transitivity_bonus: f64,

    /// Bonus for each shared neighbor.
    ///
    /// Scaled by the number of shared neighbors:
    /// total_bonus = shared_count * per_shared_neighbor_bonus
    pub per_shared_neighbor_bonus: f64,

    /// Weight for string similarity in pairwise scoring.
    ///
    /// An exact (case-insensitive) match adds this weight in full; containment
    /// of one mention's text in the other's adds 0.6 × this weight.
    pub string_similarity_weight: f64,

    /// Weight for head word matching.
    ///
    /// The G2GT paper emphasizes head-based matching. This weight applies when
    /// neither an exact nor a containment match fired and the (case-insensitive)
    /// head words match. Heads currently come from the last-word heuristic;
    /// `head_start`/`head_end` are not yet consulted.
    pub head_match_weight: f64,

    /// Weight for the log-scale distance penalty in pairwise scoring.
    ///
    /// For a character distance d > 0, the score is reduced by `distance_weight * ln(d)`.
    pub distance_weight: f64,

    /// Maximum character gap between two mentions; pairs further apart are never linked.
    pub max_distance: Option<usize>,

    /// Include singletons (mentions with no coreference) in output.
    ///
    /// Default: false (only return multi-mention chains).
    /// Set to true for evaluation against datasets that include singletons.
    pub include_singletons: bool,

    /// Bonus when a pronoun links to a proper noun.
    ///
    /// Pronouns are weak signals alone but should link to antecedents.
    /// Pronoun-nominal pairs receive half of this bonus.
    pub pronoun_proper_bonus: f64,

    /// Optional early-stop controls for iterative refinement.
    ///
    /// GraphCoref already stops when it reaches a fixed point (Gₜ == Gₜ₋₁).
    /// This option additionally stops on:
    /// - **Cycle detection** (the graph returns to an earlier state, e.g. G₁ → G₂ → G₁)
    /// - **Stagnation** (edge count stops changing for N iterations)
    ///
    /// This is an analogue of “overthinking” / redundancy detection (CoRE-Eval),
    /// implemented using observable signals (graph structure) rather than hidden states.
    pub early_stop: Option<GraphCorefEarlyStopConfig>,
}

/// Configuration for early stopping in iterative graph refinement.
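///
/// Cycle detection compares order-independent fingerprints of the edge set.
/// The sketch below mirrors the idea behind the private `graph_fingerprint`
/// helper using a hypothetical standalone `fingerprint` function: hash each
/// edge, sort the per-edge hashes (so `HashSet` iteration order cannot leak
/// in), then fold them into one hasher together with the node count.
///
/// ```rust
/// use std::collections::hash_map::DefaultHasher;
/// use std::collections::HashSet;
/// use std::hash::{Hash, Hasher};
///
/// // Hypothetical helper; edges are canonical (lo, hi) pairs.
/// fn fingerprint(num_nodes: usize, edges: &HashSet<(usize, usize)>) -> u64 {
///     let mut per_edge: Vec<u64> = edges
///         .iter()
///         .map(|e| {
///             let mut h = DefaultHasher::new();
///             e.hash(&mut h);
///             h.finish()
///         })
///         .collect();
///     // Sorting makes the result independent of set iteration order.
///     per_edge.sort_unstable();
///
///     let mut h = DefaultHasher::new();
///     num_nodes.hash(&mut h);
///     for eh in per_edge {
///         eh.hash(&mut h);
///     }
///     h.finish()
/// }
/// ```
///
/// Two sets holding the same edges, inserted in different orders, get the same
/// fingerprint, which is what lets a revisited graph state be recognized.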
#[derive(Debug, Clone)]
pub struct GraphCorefEarlyStopConfig {
    /// Stop if we detect a repeated graph state (cycle) within the configured history.
    pub detect_cycles: bool,
    /// How many past graph fingerprints to remember (0 = unbounded).
    pub cycle_history: usize,
    /// Stop if the edge count hasn't changed for this many consecutive iterations
    /// (0 = disabled).
    pub stagnation_patience: usize,
}

impl Default for GraphCorefEarlyStopConfig {
    fn default() -> Self {
        Self {
            detect_cycles: true,
            cycle_history: 8,
            stagnation_patience: 2,
        }
    }
}

impl Default for GraphCorefConfig {
    fn default() -> Self {
        Self {
            max_iterations: 4,
            link_threshold: 0.5,
            transitivity_bonus: 0.15,
            per_shared_neighbor_bonus: 0.1,
            string_similarity_weight: 1.0,
            head_match_weight: 0.5,
            distance_weight: 0.05,
            max_distance: Some(1000),
            include_singletons: false,
            pronoun_proper_bonus: 0.3,
            early_stop: None,
        }
    }
}

// =============================================================================
// Main Implementation
// =============================================================================

/// Graph-based coreference resolver with iterative refinement.
///
/// This implements a heuristic version of the G2GT architecture, preserving
/// the key insight that iterative graph refinement enables global consistency
/// in coreference decisions.
///
/// # Algorithm
///
/// 1. **Initialize**: Empty graph with mentions as nodes
/// 2. **Iterate**: For each mention pair, compute score with graph context
/// 3. **Update**: Rebuild the edge set, keeping pairs whose score exceeds the threshold
/// 4. **Converge**: Stop when graph unchanged or max iterations reached
/// 5. **Extract**: Connected components become coreference chains
///
/// # Complexity
///
/// - Time: O(N² × T) pair scorings, where N = mentions and T = iterations
///   (at most `max_iterations`, default 4). Each scoring also queries the
///   previous graph: `neighbors` scans all E edges and the transitivity check
///   is a BFS, so the worst case is higher than O(N² × T).
/// - Space: O(N + E) for the edge set, i.e. O(N²) in the worst case
///
/// For comparison, exhaustive span-pair scoring in Lee et al. (2017) is O(n⁴)
/// in document length n, before pruning.
///
/// # Feature Usage
///
/// The resolver uses available `Mention` fields when present:
///
/// | Field | Used For | Fallback |
/// |-------|----------|----------|
/// | `mention_type` | Pronoun detection, type compatibility | Heuristic detection |
/// | `head_start`/`head_end` | Reserved for head word matching (not yet consulted) | Last word of mention |
/// | `entity_type` | Chain `entity_type`, copied from the first mention explicitly typed `Proper` | Chain `entity_type` left unset |
#[derive(Debug, Clone)]
pub struct GraphCoref {
    config: GraphCorefConfig,
}

impl GraphCoref {
    /// Create a new graph coref resolver with default configuration.
    #[must_use]
    pub fn new() -> Self {
        Self::with_config(GraphCorefConfig::default())
    }

    /// Create with custom configuration.
    #[must_use]
    pub fn with_config(config: GraphCorefConfig) -> Self {
        Self { config }
    }

    fn graph_fingerprint(graph: &CorefGraph) -> u64 {
        // Order-independent fingerprint of the edge set.
        // We sort per-edge hashes to avoid HashSet iteration nondeterminism.
        let mut edge_hashes: Vec<u64> = graph
            .edges
            .iter()
            .map(|e| {
                let mut h = std::collections::hash_map::DefaultHasher::new();
                e.hash(&mut h);
                h.finish()
            })
            .collect();
        edge_hashes.sort_unstable();

        let mut h = std::collections::hash_map::DefaultHasher::new();
        graph.num_mentions.hash(&mut h);
        for eh in edge_hashes {
            eh.hash(&mut h);
        }
        h.finish()
    }

    /// Resolve coreferences among mentions using iterative graph refinement.
    ///
    /// # Arguments
    ///
    /// * `mentions` - Pre-detected mentions (from NER or mention detector).
    ///   For best results, set `mention_type` on each mention. Mentions with
    ///   empty or whitespace-only text, or with `start >= end`, are dropped
    ///   before resolution.
    ///
    /// # Returns
    ///
    /// Coreference chains (clusters) where each chain contains mentions
    /// referring to the same entity. By default, singletons are filtered;
    /// set `config.include_singletons = true` to include them.
    ///
    /// # Panics
    ///
    /// Does not panic. Empty input returns empty output.
    ///
    /// # Example
    ///
    /// ```rust
    /// use anno::backends::graph_coref::GraphCoref;
    /// use anno::{Mention, MentionType};
    ///
    /// let coref = GraphCoref::new();
    ///
    /// let mut john = Mention::new("John", 0, 4);
    /// john.mention_type = Some(MentionType::Proper);
    ///
    /// let mut he = Mention::new("he", 20, 22);
    /// he.mention_type = Some(MentionType::Pronominal);
    ///
    /// // With the default weights this isolated pair scores below
    /// // `link_threshold`, so `chains` is empty; a lower threshold links it.
    /// let chains = coref.resolve(&[john, he]);
    /// ```
    #[must_use]
    pub fn resolve(&self, mentions: &[Mention]) -> Vec<CorefChain> {
        if mentions.is_empty() {
            return vec![];
        }

        // Validate mentions (filter empty/invalid)
        let valid_mentions: Vec<&Mention> = mentions
            .iter()
            .filter(|m| !m.text.trim().is_empty() && m.start < m.end)
            .collect();

        if valid_mentions.is_empty() {
            return vec![];
        }

        // Initialize empty graph
        let mut graph = CorefGraph::new(valid_mentions.len());

        // Iterative refinement
        let mut stagnation: usize = 0;
        let mut last_edge_count = graph.edge_count();
        let mut seen: HashMap<u64, usize> = HashMap::new();
        let mut history: VecDeque<u64> = VecDeque::new();
        if let Some(cfg) = &self.config.early_stop {
            if cfg.detect_cycles {
                let fp0 = Self::graph_fingerprint(&graph);
                seen.insert(fp0, 0);
                history.push_back(fp0);
            }
        }

        for iteration in 0..self.config.max_iterations {
            let prev_graph = graph.clone();
            graph = self.refine_iteration(&valid_mentions, &graph);

            // Check convergence
            if graph == prev_graph {
                break;
            }

            // Optional early stop: stagnation / cycles.
            if let Some(cfg) = &self.config.early_stop {
                // Stagnation: edge count stops changing.
                let ec = graph.edge_count();
                if ec == last_edge_count {
                    stagnation += 1;
                } else {
                    stagnation = 0;
                    last_edge_count = ec;
                }
                if cfg.stagnation_patience > 0 && stagnation >= cfg.stagnation_patience {
                    break;
                }

                if cfg.detect_cycles {
                    let fp = Self::graph_fingerprint(&graph);
                    if seen.contains_key(&fp) {
                        break;
                    }
                    seen.insert(fp, iteration + 1);
                    history.push_back(fp);
                    if cfg.cycle_history > 0 {
                        while history.len() > cfg.cycle_history {
                            if let Some(old) = history.pop_front() {
                                seen.remove(&old);
                            }
                        }
                    }
                }
            }
        }

        // Extract clusters and convert to CorefChains
        self.graph_to_chains(&graph, &valid_mentions)
    }

    /// Perform one iteration of graph refinement.
    ///
    /// For each mention pair, computes a score incorporating:
    /// 1. Base pairwise similarity (string match, head match, type compatibility)
    /// 2. Graph context (bonuses derived from the previous iteration's graph)
    ///
    /// The graph is rebuilt from scratch: a pair is linked only if its total score
    /// strictly exceeds `link_threshold`, so an edge from the previous iteration is
    /// dropped if it no longer clears the threshold.
    fn refine_iteration(&self, mentions: &[&Mention], prev_graph: &CorefGraph) -> CorefGraph {
        let mut new_graph = CorefGraph::new(mentions.len());

        for i in 0..mentions.len() {
            for j in 0..i {
                // Distance filter: character gap between the two spans, regardless
                // of their order in `mentions` (0 if they overlap).
                if let Some(max_dist) = self.config.max_distance {
                    let (a, b) = (mentions[i], mentions[j]);
                    let dist = a.start.saturating_sub(b.end).max(b.start.saturating_sub(a.end));
                    if dist > max_dist {
                        continue;
                    }
                }

                // Compute score with graph context
                let base_score = self.pairwise_score(mentions[i], mentions[j]);
                let context_bonus = self.graph_context_bonus(i, j, prev_graph);
                let total_score = base_score + context_bonus;

                if total_score > self.config.link_threshold {
                    new_graph.add_edge(i, j);
                }
            }
        }

        new_graph
    }

    /// Compute base pairwise similarity score between two mentions.
    ///
    /// Uses multiple signals:
    /// - Exact (case-insensitive) string match (highest weight)
    /// - Containment of one mention's text in the other's
    /// - Head word match (currently the last-word heuristic; see `get_head_text`)
    /// - Mention type compatibility, e.g. pronoun-to-proper linking (uses `mention_type` if available)
    /// - Log-scale distance penalty
    fn pairwise_score(&self, m1: &Mention, m2: &Mention) -> f64 {
        let mut score = 0.0;

        let t1 = m1.text.to_lowercase();
        let t2 = m2.text.to_lowercase();

        // True if `needle`'s words appear as a contiguous run of `hay`'s words.
        // Matching whole words avoids spurious hits such as "he" inside "she" or "the".
        let contains_words = |hay: &str, needle: &str| {
            let hay: Vec<&str> = hay.split_whitespace().collect();
            let needle: Vec<&str> = needle.split_whitespace().collect();
            !needle.is_empty() && hay.windows(needle.len()).any(|w| w == needle.as_slice())
        };

        // Exact match (strongest signal)
        if t1 == t2 {
            score += self.config.string_similarity_weight * 1.0;
        }
        // Word-level containment (e.g. "obama" within "barack obama")
        else if contains_words(&t1, &t2) || contains_words(&t2, &t1) {
            score += self.config.string_similarity_weight * 0.6;
        }
        // Head word match
        else {
            let h1 = self.get_head_text(m1);
            let h2 = self.get_head_text(m2);
            if !h1.is_empty() && h1.to_lowercase() == h2.to_lowercase() {
                score += self.config.head_match_weight;
            }
        }

        // Mention type compatibility
        score += self.type_compatibility_score(m1, m2);

        // Distance penalty (log scale)
        let distance = m1.start.abs_diff(m2.end).min(m2.start.abs_diff(m1.end));
        if distance > 0 {
            score -= self.config.distance_weight * (distance as f64).ln();
        }

        score
    }

    /// Get head text for a mention.
    ///
    /// Currently always uses the last-word heuristic (common in English NPs,
    /// where the head is rightmost). `head_start`/`head_end` are document-relative
    /// offsets; resolving them would need the document text, which is not
    /// available here, so they are not yet used.
    fn get_head_text<'a>(&self, mention: &'a Mention) -> &'a str {
        // Explicit head spans are document-relative. Without the document text
        // they cannot be sliced reliably, so they are intentionally ignored for now.
        if let (Some(head_start), Some(head_end)) = (mention.head_start, mention.head_end) {
            let _ = (head_start, head_end);
        }

        // Fallback: last word (head-final assumption for English NPs)
        mention.text.split_whitespace().last().unwrap_or("")
    }

    /// Compute type compatibility score between two mentions.
    ///
    /// Uses `MentionType` field if available, otherwise uses heuristics.
    fn type_compatibility_score(&self, m1: &Mention, m2: &Mention) -> f64 {
        let type1 = m1
            .mention_type
            .unwrap_or_else(|| self.infer_mention_type(m1));
        let type2 = m2
            .mention_type
            .unwrap_or_else(|| self.infer_mention_type(m2));

        match (type1, type2) {
            // Pronoun linking to proper noun: boost
            (MentionType::Pronominal, MentionType::Proper)
            | (MentionType::Proper, MentionType::Pronominal) => self.config.pronoun_proper_bonus,

            // Pronoun linking to nominal: smaller boost
            (MentionType::Pronominal, MentionType::Nominal)
            | (MentionType::Nominal, MentionType::Pronominal) => {
                self.config.pronoun_proper_bonus * 0.5
            }

            // Same type: neutral
            _ if type1 == type2 => 0.0,

            // Different non-pronoun types: slight penalty
            _ => -0.1,
        }
    }

    /// Infer mention type from text when not explicitly set.
    ///
    /// This is a fallback heuristic. For best results, set `mention_type`
    /// on mentions before calling `resolve()`.
    fn infer_mention_type(&self, mention: &Mention) -> MentionType {
        let text_lower = mention.text.to_lowercase();

        // Check for pronouns (personal, possessive, reflexive, relative, demonstrative)
        const PRONOUNS: &[&str] = &[
            "i", "me", "my", "mine", "myself",
            "you", "your", "yours", "yourself", "yourselves",
            "he", "him", "his", "himself",
            "she", "her", "hers", "herself",
            "it", "its", "itself",
            "we", "us", "our", "ours", "ourselves",
            "they", "them", "their", "theirs", "themselves",
            "who", "whom", "whose", "which", "that",
            "this", "these", "those",
        ];

        if PRONOUNS.contains(&text_lower.as_str()) {
            return MentionType::Pronominal;
        }

        // Capitalized mentions are treated as proper nouns. Sentence-initial
        // capitalization is not accounted for, so e.g. "The company" is classified as Proper.
        let first_char = mention.text.chars().next();
        if first_char.is_some_and(|c| c.is_uppercase()) {
            // Exclude bare determiners such as "The" or "A"
            let common_words = ["the", "a", "an", "this", "that", "these", "those"];
            if !common_words.contains(&text_lower.as_str()) {
                return MentionType::Proper;
            }
        }

        // Default to nominal
        MentionType::Nominal
    }

926 /// Compute graph context bonus based on previous iteration's structure.
927 ///
928 /// This is our approximation of G2GT's graph-conditioned attention.
929 ///
930 /// # Transitivity Bonus
931 ///
932 /// If A~C and B~C in the previous graph, A and B should likely be linked.
933 /// We add a bonus proportional to the number of shared neighbors.
934 ///
935 /// # Already Connected Bonus
936 ///
937 /// If A and B are already transitively connected (through a chain of
938 /// coreference links), add a bonus to preserve and strengthen the connection.
939 fn graph_context_bonus(&self, i: usize, j: usize, prev_graph: &CorefGraph) -> f64 {
940 let mut bonus = 0.0;
941
942 // Bonus for shared neighbors (transitivity signal)
943 let shared = prev_graph.shared_neighbors(i, j);
944 bonus += (shared as f64) * self.config.per_shared_neighbor_bonus;
945
946 // Bonus if already transitively connected
947 if prev_graph.transitively_connected(i, j) {
948 bonus += self.config.transitivity_bonus;
949 }
950
951 bonus
952 }
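/*
A self-contained sketch of the two signals, assuming the graph is stored as
adjacency sets and connectivity is checked with a breadth-first search.
`Graph` and `context_bonus` are hypothetical; `CorefGraph`'s actual storage
and traversal may differ.

```rust
use std::collections::{BTreeSet, VecDeque};

/// Hypothetical undirected graph over mention indices (adjacency sets).
pub struct Graph {
    adj: Vec<BTreeSet<usize>>,
}

impl Graph {
    pub fn new(n: usize) -> Self {
        Graph { adj: vec![BTreeSet::new(); n] }
    }

    /// Add a symmetric edge; self-loops and out-of-range indices are ignored.
    pub fn add_edge(&mut self, i: usize, j: usize) {
        if i != j && i < self.adj.len() && j < self.adj.len() {
            self.adj[i].insert(j);
            self.adj[j].insert(i);
        }
    }

    /// Number of nodes adjacent to both `i` and `j`.
    pub fn shared_neighbors(&self, i: usize, j: usize) -> usize {
        self.adj[i].intersection(&self.adj[j]).count()
    }

    /// Breadth-first search: is there any path from `i` to `j`?
    pub fn transitively_connected(&self, i: usize, j: usize) -> bool {
        let mut seen = vec![false; self.adj.len()];
        let mut queue = VecDeque::from([i]);
        seen[i] = true;
        while let Some(u) = queue.pop_front() {
            if u == j {
                return true;
            }
            for &v in &self.adj[u] {
                if !seen[v] {
                    seen[v] = true;
                    queue.push_back(v);
                }
            }
        }
        false
    }
}

/// per_shared for each shared neighbor, plus transitivity if a path exists.
pub fn context_bonus(g: &Graph, i: usize, j: usize, per_shared: f64, transitivity: f64) -> f64 {
    let mut bonus = g.shared_neighbors(i, j) as f64 * per_shared;
    if g.transitively_connected(i, j) {
        bonus += transitivity;
    }
    bonus
}
```

With edges 0-2 and 1-2, the pair (0, 1) shares one neighbor and is connected
through 2, so with weights 0.2 and 0.3 it receives 0.2 + 0.3 = 0.5, while an
isolated node 3 receives nothing when paired with 0.
*/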

    /// Convert graph clusters to `CorefChain`s, dropping singletons unless
    /// `include_singletons` is set. Chain ids are assigned after filtering,
    /// so they are consecutive.
    fn graph_to_chains(&self, graph: &CorefGraph, mentions: &[&Mention]) -> Vec<CorefChain> {
        let clusters = graph.extract_clusters();

        clusters
            .into_iter()
            .filter(|cluster| self.config.include_singletons || cluster.len() > 1)
            .enumerate()
            .map(|(id, indices)| {
                let chain_mentions: Vec<Mention> = indices
                    .into_iter()
                    .map(|i| (*mentions[i]).clone())
                    .collect();

                let mut chain = CorefChain::new(chain_mentions);
                chain.cluster_id = Some((id as u64).into());

                // Set entity type from the first mention typed as Proper. Only
                // `mention_type` values present on the input are consulted; types
                // inferred by `infer_mention_type` are not written back.
                chain.entity_type = chain
                    .mentions
                    .iter()
                    .find(|m| m.mention_type == Some(MentionType::Proper))
                    .and_then(|m| m.entity_type.clone());

                chain
            })
            .collect()
    }
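/*
A sketch of the cluster-to-chain step on plain index lists, assuming clusters
are the connected components of the final graph (as `test_coref_graph_basics`
suggests: edges 0-1 and 1-2 give {0, 1, 2}, {3}, {4}). `components` and
`to_chains` are hypothetical helpers, and `extract_clusters` may order its
output differently.

```rust
/// Connected components of an undirected graph given as adjacency lists,
/// each component sorted by node index.
pub fn components(adj: &[Vec<usize>]) -> Vec<Vec<usize>> {
    let mut seen = vec![false; adj.len()];
    let mut out = Vec::new();
    for start in 0..adj.len() {
        if seen[start] {
            continue;
        }
        // Iterative depth-first search collecting one component.
        let mut stack = vec![start];
        let mut comp = Vec::new();
        seen[start] = true;
        while let Some(u) = stack.pop() {
            comp.push(u);
            for &v in &adj[u] {
                if !seen[v] {
                    seen[v] = true;
                    stack.push(v);
                }
            }
        }
        comp.sort_unstable();
        out.push(comp);
    }
    out
}

/// Drop singletons unless requested, then number the survivors 0..k.
pub fn to_chains(clusters: Vec<Vec<usize>>, include_singletons: bool) -> Vec<(u64, Vec<usize>)> {
    clusters
        .into_iter()
        .filter(|c| include_singletons || c.len() > 1)
        .enumerate()
        .map(|(id, members)| (id as u64, members))
        .collect()
}
```

Because `filter` runs before `enumerate`, ids stay consecutive (0, 1, 2, ...)
even when singletons are dropped, which mirrors `graph_to_chains`.
*/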

    /// Get configuration.
    #[must_use]
    pub fn config(&self) -> &GraphCorefConfig {
        &self.config
    }
}

impl Default for GraphCoref {
    fn default() -> Self {
        Self::new()
    }
}

// =============================================================================
// Metrics and Diagnostics
// =============================================================================

/// Statistics from a graph coref run for debugging and analysis.
///
/// Use [`GraphCoref::resolve_with_stats`] to get these alongside results.
#[derive(Debug, Clone, Default)]
pub struct GraphCorefStats {
    /// Number of refinement iterations run (at most `max_iterations`;
    /// 0 if there were no valid mentions).
    pub iterations: usize,
    /// Number of edges in the final graph.
    pub final_edges: usize,
    /// Number of clusters (including singletons).
    pub num_clusters: usize,
    /// Number of non-singleton clusters.
    pub num_chains: usize,
    /// Edge count of the initial empty graph (0), followed by the count after
    /// each iteration, so its length is `iterations + 1`. Empty if there were
    /// no valid mentions.
    pub edge_history: Vec<usize>,
    /// Whether an iteration left the graph unchanged (a fixed point) within
    /// `max_iterations`.
    pub converged: bool,
    /// Whether we stopped early for a non-fixed-point reason (cycle/stagnation).
    pub early_stopped: bool,
    /// Cycle detected (a graph fingerprint repeated within the
    /// `cycle_history` window).
    pub cycle_detected: bool,
    /// Stagnation detected (edge count unchanged for `stagnation_patience`
    /// consecutive iterations).
    pub stagnation_detected: bool,
}

impl GraphCoref {
    /// Resolve coreferences and return detailed statistics.
    ///
    /// Useful for debugging, tuning parameters, and understanding convergence.
    ///
    /// # Example
    ///
    /// ```rust
    /// use anno::backends::graph_coref::GraphCoref;
    /// use anno::Mention;
    ///
    /// let coref = GraphCoref::new();
    /// let mentions = vec![
    ///     Mention::new("John", 0, 4),
    ///     Mention::new("John", 50, 54),
    /// ];
    ///
    /// let (chains, stats) = coref.resolve_with_stats(&mentions);
    /// println!("Ran {} iterations (converged: {})", stats.iterations, stats.converged);
    /// println!("Edge history: {:?}", stats.edge_history);
    /// ```
    #[must_use]
    pub fn resolve_with_stats(&self, mentions: &[Mention]) -> (Vec<CorefChain>, GraphCorefStats) {
        let mut stats = GraphCorefStats::default();

        if mentions.is_empty() {
            return (vec![], stats);
        }

        // Drop blank mentions and empty or inverted spans.
        let valid_mentions: Vec<&Mention> = mentions
            .iter()
            .filter(|m| !m.text.trim().is_empty() && m.start < m.end)
            .collect();

        if valid_mentions.is_empty() {
            return (vec![], stats);
        }

        // Iteration 0: one node per mention, no edges.
        let mut graph = CorefGraph::new(valid_mentions.len());
        stats.edge_history.push(0);

        // Early-stop bookkeeping: consecutive iterations without an edge-count
        // change, plus a bounded window of recent graph fingerprints.
        let mut stagnation: usize = 0;
        let mut last_edge_count = graph.edge_count();
        let mut seen: HashMap<u64, usize> = HashMap::new();
        let mut history: VecDeque<u64> = VecDeque::new();
        if let Some(cfg) = &self.config.early_stop {
            if cfg.detect_cycles {
                let fp0 = Self::graph_fingerprint(&graph);
                seen.insert(fp0, 0);
                history.push_back(fp0);
            }
        }

        for iteration in 0..self.config.max_iterations {
            let prev_graph = graph.clone();
            graph = self.refine_iteration(&valid_mentions, &graph);
            stats.edge_history.push(graph.edge_count());
            stats.iterations = iteration + 1;

            // Fixed point: this iteration changed nothing.
            if graph == prev_graph {
                stats.converged = true;
                break;
            }

            if let Some(cfg) = &self.config.early_stop {
                // Stagnation: edge count unchanged, even if individual edges moved.
                let ec = graph.edge_count();
                if ec == last_edge_count {
                    stagnation += 1;
                } else {
                    stagnation = 0;
                    last_edge_count = ec;
                }
                if cfg.stagnation_patience > 0 && stagnation >= cfg.stagnation_patience {
                    stats.early_stopped = true;
                    stats.stagnation_detected = true;
                    break;
                }

                // Cycle: a graph with the same fingerprint was seen within the
                // last `cycle_history` states (0 keeps every state).
                if cfg.detect_cycles {
                    let fp = Self::graph_fingerprint(&graph);
                    if seen.contains_key(&fp) {
                        stats.early_stopped = true;
                        stats.cycle_detected = true;
                        break;
                    }
                    seen.insert(fp, iteration + 1);
                    history.push_back(fp);
                    if cfg.cycle_history > 0 {
                        while history.len() > cfg.cycle_history {
                            if let Some(old) = history.pop_front() {
                                seen.remove(&old);
                            }
                        }
                    }
                }
            }
        }

        let clusters = graph.extract_clusters();
        stats.final_edges = graph.edge_count();
        stats.num_clusters = clusters.len();
        stats.num_chains = clusters.iter().filter(|c| c.len() > 1).count();

        let chains = self.graph_to_chains(&graph, &valid_mentions);
        (chains, stats)
    }
}
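/*
The stopping order in `resolve_with_stats` (fixed point first, then
stagnation, then a cycle check over a sliding window of fingerprints) can be
replayed on plain edge lists. This is a sketch under stated assumptions:
`stop_reason`, `Stop`, `normalize`, and `fingerprint` are hypothetical (the
real `graph_fingerprint` is not shown here and may hash differently), and the
input is a precomputed sequence of graphs rather than live refinement.
`patience` and `window` play the roles of `stagnation_patience` and
`cycle_history`.

```rust
use std::collections::hash_map::DefaultHasher;
use std::collections::{HashSet, VecDeque};
use std::hash::{Hash, Hasher};

/// Normalize an undirected edge list: (min, max) pairs, sorted, deduplicated.
fn normalize(edges: &[(usize, usize)]) -> Vec<(usize, usize)> {
    let mut v: Vec<_> = edges.iter().map(|&(a, b)| (a.min(b), a.max(b))).collect();
    v.sort_unstable();
    v.dedup();
    v
}

/// Hash of the normalized edge set, standing in for `graph_fingerprint`.
fn fingerprint(edges: &[(usize, usize)]) -> u64 {
    let mut h = DefaultHasher::new();
    normalize(edges).hash(&mut h);
    h.finish()
}

#[derive(Debug, PartialEq, Eq)]
pub enum Stop {
    FixedPoint,
    Stagnation,
    Cycle,
    Budget,
}

/// Replay graphs (`states[t]` = graph after iteration t + 1) through the
/// stopping rules. Returns the reason and the number of iterations consumed.
pub fn stop_reason(states: &[Vec<(usize, usize)>], patience: usize, window: usize) -> (Stop, usize) {
    let mut prev: Vec<(usize, usize)> = Vec::new(); // iteration 0: empty graph
    let (mut stagnation, mut last_count) = (0, 0);
    let mut seen = HashSet::from([fingerprint(&prev)]);
    let mut history = VecDeque::from([fingerprint(&prev)]);
    for (t, raw) in states.iter().enumerate() {
        let cur = normalize(raw);
        // 1. Fixed point: nothing changed.
        if cur == prev {
            return (Stop::FixedPoint, t + 1);
        }
        // 2. Stagnation: edge count unchanged for `patience` iterations.
        if cur.len() == last_count {
            stagnation += 1;
        } else {
            stagnation = 0;
            last_count = cur.len();
        }
        if patience > 0 && stagnation >= patience {
            return (Stop::Stagnation, t + 1);
        }
        // 3. Cycle: fingerprint already in the window (0 = unbounded window).
        let fp = fingerprint(&cur);
        if !seen.insert(fp) {
            return (Stop::Cycle, t + 1);
        }
        history.push_back(fp);
        while window > 0 && history.len() > window {
            if let Some(old) = history.pop_front() {
                seen.remove(&old);
            }
        }
        prev = cur;
    }
    (Stop::Budget, states.len())
}
```

An A, B, A oscillation is caught as a cycle with an unbounded window, caught
one iteration earlier as stagnation with `patience = 1` (A and B have the same
edge count), and missed entirely with `window = 1`: a cycle whose period
exceeds the window is never detected, and the run ends only when the
iteration budget is spent.
*/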

// =============================================================================
// Evaluation Helpers
// =============================================================================

/// Wrap `GraphCoref` output in a `CorefDocument` for CoNLL-style evaluation.
///
/// The resulting document can be evaluated with standard coreference
/// metrics (MUC, B³, CEAF, LEA).
///
/// # Example
///
/// ```rust
/// use anno::backends::graph_coref::{GraphCoref, chains_to_document};
/// use anno::Mention;
///
/// let coref = GraphCoref::new();
/// let mentions = vec![
///     Mention::new("John", 0, 4),
///     Mention::new("He", 19, 21),
/// ];
///
/// let chains = coref.resolve(&mentions);
/// let doc = chains_to_document("John went to work. He was late.", chains);
/// ```
pub fn chains_to_document(
    text: impl Into<String>,
    chains: Vec<CorefChain>,
) -> anno_core::CorefDocument {
    anno_core::CorefDocument::new(text, chains)
}
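/*
For orientation, a minimal sketch of MUC, the first of the metrics listed
above, computed on clusters of mention ids. `muc_recall` and `muc_f1` are
hypothetical and independent of `anno_core`; this is not the scorer used with
`CorefDocument`. Recall sums, over key clusters K, |K| minus the number of
parts the response splits K into (key mentions missing from the response each
count as their own part), divided by the sum of |K| - 1; precision is the same
computation with key and response swapped.

```rust
use std::collections::{HashMap, HashSet};

/// MUC recall of `response` against `key` (clusters assumed disjoint).
pub fn muc_recall(key: &[Vec<u32>], response: &[Vec<u32>]) -> f64 {
    // Map each mention to the index of the response cluster containing it.
    let owner: HashMap<u32, usize> = response
        .iter()
        .enumerate()
        .flat_map(|(ci, c)| c.iter().map(move |&m| (m, ci)))
        .collect();
    let (mut num, mut den) = (0usize, 0usize);
    for k in key {
        // Count the parts of k: distinct response clusters plus unassigned mentions.
        let mut parts: HashSet<usize> = HashSet::new();
        let mut unassigned = 0;
        for m in k {
            match owner.get(m) {
                Some(&ci) => {
                    parts.insert(ci);
                }
                None => unassigned += 1,
            }
        }
        num += k.len() - (parts.len() + unassigned);
        den += k.len().saturating_sub(1);
    }
    // Convention used here: no links to recover means a score of 0.
    if den == 0 { 0.0 } else { num as f64 / den as f64 }
}

/// MUC F1: precision is recall with key and response swapped.
pub fn muc_f1(key: &[Vec<u32>], response: &[Vec<u32>]) -> f64 {
    let r = muc_recall(key, response);
    let p = muc_recall(response, key);
    if p + r == 0.0 { 0.0 } else { 2.0 * p * r / (p + r) }
}
```

For key {1, 2, 3} and response {1, 2}, {3}, recall is (3 - 2) / (3 - 1) = 0.5
and precision is 1.0, so F1 is 2/3.
*/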

// =============================================================================
// Tests
// =============================================================================

#[cfg(test)]
mod tests {
    use super::*;

    fn make_mention(text: &str, start: usize) -> Mention {
        Mention::new(text, start, start + text.chars().count())
    }

    fn make_typed_mention(text: &str, start: usize, mention_type: MentionType) -> Mention {
        Mention::with_type(text, start, start + text.chars().count(), mention_type)
    }

    // -------------------------------------------------------------------------
    // Basic functionality
    // -------------------------------------------------------------------------

    #[test]
    fn test_empty_input() {
        let coref = GraphCoref::new();
        let chains = coref.resolve(&[]);
        assert!(chains.is_empty());
    }

    #[test]
    fn test_single_mention() {
        let coref = GraphCoref::new();
        let mentions = vec![make_mention("John", 0)];
        let chains = coref.resolve(&mentions);
        assert!(
            chains.is_empty(),
            "Single mention should be filtered as singleton"
        );
    }

    #[test]
    fn test_single_mention_with_singletons() {
        let config = GraphCorefConfig {
            include_singletons: true,
            ..Default::default()
        };
        let coref = GraphCoref::with_config(config);
        let mentions = vec![make_mention("John", 0)];
        let chains = coref.resolve(&mentions);
        assert_eq!(chains.len(), 1, "Should include singleton when configured");
    }

    #[test]
    fn test_exact_match_linking() {
        let coref = GraphCoref::new();
        let mentions = vec![make_mention("John", 0), make_mention("John", 50)];

        let chains = coref.resolve(&mentions);
        assert_eq!(chains.len(), 1);
        assert_eq!(chains[0].mentions.len(), 2);
    }

    #[test]
    fn test_substring_linking() {
        let config = GraphCorefConfig {
            link_threshold: 0.4,
            ..Default::default()
        };
        let coref = GraphCoref::with_config(config);
        let mentions = vec![make_mention("Marie Curie", 0), make_mention("Curie", 50)];

        let chains = coref.resolve(&mentions);
        assert_eq!(chains.len(), 1);
        assert_eq!(chains[0].mentions.len(), 2);
    }

    // -------------------------------------------------------------------------
    // MentionType usage
    // -------------------------------------------------------------------------

    #[test]
    fn test_typed_pronoun_linking() {
        let config = GraphCorefConfig {
            link_threshold: 0.2,
            distance_weight: 0.0, // Disable distance penalty to isolate type signal
            ..Default::default()
        };
        let coref = GraphCoref::with_config(config);

        let mentions = vec![
            make_typed_mention("Marie", 0, MentionType::Proper),
            make_typed_mention("she", 20, MentionType::Pronominal),
        ];

        let chains = coref.resolve(&mentions);
        assert_eq!(chains.len(), 1, "Typed pronoun should link to proper noun");
    }

    #[test]
    fn test_inferred_pronoun_detection() {
        let coref = GraphCoref::new();

        // Create mention without type - should be inferred
        let he = make_mention("he", 0);
        assert_eq!(
            coref.infer_mention_type(&he),
            MentionType::Pronominal,
            "Should detect 'he' as pronoun"
        );

        let john = make_mention("John", 0);
        assert_eq!(
            coref.infer_mention_type(&john),
            MentionType::Proper,
            "Should detect 'John' as proper"
        );

        let dog = make_mention("the dog", 0);
        assert_eq!(
            coref.infer_mention_type(&dog),
            MentionType::Nominal,
            "Should detect 'the dog' as nominal"
        );
    }

    // -------------------------------------------------------------------------
    // Transitivity and graph refinement
    // -------------------------------------------------------------------------

    #[test]
    fn test_transitivity() {
        let config = GraphCorefConfig {
            max_iterations: 4,
            link_threshold: 0.3,
            transitivity_bonus: 0.3,
            per_shared_neighbor_bonus: 0.2,
            ..Default::default()
        };
        let coref = GraphCoref::with_config(config);

        let mentions = vec![
            make_mention("John Smith", 0),
            make_mention("Smith", 30),
            make_mention("John Smith", 60),
        ];

        let chains = coref.resolve(&mentions);
        assert_eq!(chains.len(), 1);
        assert_eq!(chains[0].mentions.len(), 3);
    }

    #[test]
    fn test_convergence() {
        let coref = GraphCoref::new();
        // Read the iteration budget from the config instead of hard-coding it
        let max_iterations = coref.config().max_iterations;
        let mentions = vec![
            make_mention("Apple", 0),
            make_mention("Apple", 50),
            make_mention("Microsoft", 100),
        ];

        let (chains, stats) = coref.resolve_with_stats(&mentions);

        assert!(stats.iterations <= max_iterations);
        assert!(stats.converged || stats.iterations == max_iterations);
        assert_eq!(chains.len(), 1);
        assert_eq!(stats.num_chains, 1);
    }

    // -------------------------------------------------------------------------
    // CorefGraph tests
    // -------------------------------------------------------------------------

    #[test]
    fn test_coref_graph_basics() {
        let mut graph = CorefGraph::new(5);

        graph.add_edge(0, 1);
        graph.add_edge(1, 2);

        assert!(graph.has_edge(0, 1));
        assert!(graph.has_edge(1, 0)); // Symmetric
        assert!(graph.has_edge(1, 2));
        assert!(!graph.has_edge(0, 2));

        assert!(graph.transitively_connected(0, 2));

        let clusters = graph.extract_clusters();
        assert_eq!(clusters.len(), 3); // {0,1,2}, {3}, {4}

        let main_cluster = clusters.iter().find(|c| c.len() == 3).unwrap();
        assert!(main_cluster.contains(&0));
        assert!(main_cluster.contains(&1));
        assert!(main_cluster.contains(&2));
    }

    #[test]
    fn test_shared_neighbors() {
        let mut graph = CorefGraph::new(4);
        graph.add_edge(0, 2);
        graph.add_edge(1, 2);

        assert_eq!(graph.shared_neighbors(0, 1), 1);
        assert_eq!(graph.shared_neighbors(0, 3), 0);
    }

    #[test]
    fn test_graph_self_loop_ignored() {
        let mut graph = CorefGraph::new(3);
        graph.add_edge(0, 0); // Self-loop
        assert!(!graph.has_edge(0, 0));
        assert_eq!(graph.edge_count(), 0);
    }

    #[test]
    fn test_graph_out_of_bounds_ignored() {
        let mut graph = CorefGraph::new(3);
        graph.add_edge(0, 10); // Out of bounds
        assert_eq!(graph.edge_count(), 0);
    }

    // -------------------------------------------------------------------------
    // Edge cases
    // -------------------------------------------------------------------------

    #[test]
    fn test_empty_mention_filtered() {
        let coref = GraphCoref::new();
        let mentions = vec![
            make_mention("John", 0),
            Mention::new("", 10, 10),    // Empty
            Mention::new("   ", 20, 23), // Whitespace only
            make_mention("John", 50),
        ];

        let chains = coref.resolve(&mentions);
        assert_eq!(chains.len(), 1);
        assert_eq!(chains[0].mentions.len(), 2);
    }

    #[test]
    fn test_distance_filter() {
        let config = GraphCorefConfig {
            max_distance: Some(100),
            ..Default::default()
        };
        let coref = GraphCoref::with_config(config);

        let mentions = vec![make_mention("John", 0), make_mention("John", 200)];

        let chains = coref.resolve(&mentions);
        assert!(chains.is_empty());
    }

    #[test]
    fn test_stats_edge_history() {
        let coref = GraphCoref::new();
        let mentions = vec![
            make_mention("A", 0),
            make_mention("A", 10),
            make_mention("A", 20),
        ];

        let (_, stats) = coref.resolve_with_stats(&mentions);

        assert!(!stats.edge_history.is_empty());
        assert_eq!(stats.edge_history[0], 0);
    }

    // -------------------------------------------------------------------------
    // Unicode / multilingual
    // -------------------------------------------------------------------------

    #[test]
    fn test_unicode_cjk() {
        let coref = GraphCoref::new();
        let mentions = vec![
            make_mention("北京", 0),
            make_mention("北京", 20),
            make_mention("東京", 40),
        ];

        let chains = coref.resolve(&mentions);
        assert_eq!(chains.len(), 1);
        assert!(chains[0].mentions.iter().all(|m| m.text == "北京"));
    }

    #[test]
    fn test_unicode_diacritics() {
        let coref = GraphCoref::new();
        let mentions = vec![make_mention("François", 0), make_mention("François", 50)];

        let chains = coref.resolve(&mentions);
        assert_eq!(chains.len(), 1);
    }

    #[test]
    fn test_unicode_arabic_rtl() {
        let coref = GraphCoref::new();
        // Arabic: "Muhammad" repeated
        let mentions = vec![make_mention("محمد", 0), make_mention("محمد", 20)];

        let chains = coref.resolve(&mentions);
        assert_eq!(chains.len(), 1);
    }

    // -------------------------------------------------------------------------
    // Evaluation helper
    // -------------------------------------------------------------------------

    #[test]
    fn test_chains_to_document() {
        // "He" starts at character 16 of the text below
        let chain = CorefChain::new(vec![make_mention("John", 0), make_mention("He", 16)]);

        let doc = chains_to_document("John went home. He slept.", vec![chain]);

        assert_eq!(doc.chain_count(), 1);
        assert_eq!(doc.mention_count(), 2);
    }

    // -------------------------------------------------------------------------
    // Co-occurrence seeding (SpanEIT-inspired)
    // -------------------------------------------------------------------------

    #[test]
    fn test_cooccurrence_seeding_basic() {
        let mut graph = CorefGraph::new(3);
        let positions = vec![0, 50, 200]; // Character offsets

        // Window of 100: connects 0-1 (distance 50) but neither 0-2
        // (distance 200) nor 1-2 (distance 150)
        graph.seed_cooccurrence_edges(&positions, 100, None::<fn(usize, usize) -> bool>);

        assert!(graph.has_edge(0, 1), "Close mentions should be connected");
        assert!(
            !graph.has_edge(0, 2),
            "Distant mentions should not be connected"
        );
        assert!(
            !graph.has_edge(1, 2),
            "Distant mentions should not be connected"
        );
    }

    #[test]
    fn test_cooccurrence_seeding_with_scorer() {
        let mut graph = CorefGraph::new(3);
        let positions = vec![0, 50, 80];

        // Every pair is within the window, so the scorer alone decides:
        // only connect if both indices are even
        let scorer = |i: usize, j: usize| i.is_multiple_of(2) && j.is_multiple_of(2);
        graph.seed_cooccurrence_edges(&positions, 100, Some(scorer));

        assert!(graph.has_edge(0, 2), "0 and 2 are both even");
        assert!(!graph.has_edge(0, 1), "1 is odd");
        assert!(!graph.has_edge(1, 2), "1 is odd");
    }

    #[test]
    fn test_cooccurrence_seeding_empty() {
        let mut graph = CorefGraph::new(3);
        let positions: Vec<usize> = vec![];

        graph.seed_cooccurrence_edges(&positions, 100, None::<fn(usize, usize) -> bool>);

        assert!(graph.is_empty(), "Empty positions should create no edges");
    }
}
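
/*
The co-occurrence tests above pin down the expected behavior of
`seed_cooccurrence_edges`, whose body is not shown here. A minimal sketch
consistent with them: link every pair of mentions whose character offsets lie
within `window` of each other, optionally gated by a scorer. `seed_edges` is
hypothetical, and treating the bound as inclusive is an assumption.

```rust
use std::collections::BTreeSet;

/// Hypothetical stand-in for `CorefGraph::seed_cooccurrence_edges`:
/// returns the seeded edges as (i, j) pairs with i < j.
pub fn seed_edges<F>(positions: &[usize], window: usize, accept: Option<F>) -> BTreeSet<(usize, usize)>
where
    F: Fn(usize, usize) -> bool,
{
    let mut edges = BTreeSet::new();
    for i in 0..positions.len() {
        for j in (i + 1)..positions.len() {
            // Close enough in the text?
            let close = positions[i].abs_diff(positions[j]) <= window;
            // No scorer means every close pair is accepted.
            let ok = accept.as_ref().map_or(true, |f| f(i, j));
            if close && ok {
                edges.insert((i, j));
            }
        }
    }
    edges
}
```

The pairwise loop is quadratic in the number of mentions; sorting by offset
and sliding a window would bound the work by the number of nearby pairs on
long documents.
*/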