//! Graph-Based Coreference Resolution with Iterative Refinement.
//!
//! This module implements a graph-based approach to coreference resolution,
//! inspired by the Graph-to-Graph Transformer (G2GT) architecture from
//! Miculicich & Henderson (2022). The key insight: model coreference as a
//! graph where nodes are mentions and edges are coref links, then iteratively
//! refine predictions until convergence.
//!
//! # Historical Context
//!
//! Coreference resolution evolved through distinct paradigms:
//!
//! ```text
//! 1978-2010  Rule-based: Hobbs algorithm (1978), centering theory (1995)
//! 2001-2016  Mention-pair: Classify (m_i, m_j) independently
//! 2017       Lee et al.: End-to-end span-based, O(N⁴) complexity
//! 2018       Lee et al.: Higher-order with representation refinement
//! 2022       G2GT: Graph refinement with global decisions, O(N² × T)
//! ```
//!
//! **The core problem with pairwise models**: Decisions are independent.
//! If P(A~B)=0.9 and P(B~C)=0.9, transitivity implies A~C, but pairwise
//! models can output P(A~C)=0.1. The G2GT approach addresses this by
//! conditioning each iteration on the full predicted graph from the previous
//! iteration, enabling global consistency.
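//!
//! A toy illustration of the problem, using hypothetical helpers (`pairwise_links`,
//! `cluster_ids`, `find`) that are not part of this crate: thresholding each pair
//! independently rejects A~C, yet the accepted links A~B and B~C still place A and C
//! in one cluster once transitive closure is taken. The sketch uses a small
//! union-find for the closure; this module itself extracts clusters by graph traversal.
//!
//! ```rust
//! // Independent pairwise decisions: keep every pair whose probability clears the threshold.
//! fn pairwise_links(probs: &[((usize, usize), f64)], threshold: f64) -> Vec<(usize, usize)> {
//!     probs
//!         .iter()
//!         .filter(|&&(_, p)| p > threshold)
//!         .map(|&(pair, _)| pair)
//!         .collect()
//! }
//!
//! // Union-find root lookup with path compression.
//! fn find(parent: &mut [usize], x: usize) -> usize {
//!     let mut root = x;
//!     while parent[root] != root {
//!         root = parent[root];
//!     }
//!     let mut cur = x;
//!     while parent[cur] != root {
//!         let next = parent[cur];
//!         parent[cur] = root;
//!         cur = next;
//!     }
//!     root
//! }
//!
//! // Cluster id (root) per mention after merging every accepted link.
//! fn cluster_ids(n: usize, links: &[(usize, usize)]) -> Vec<usize> {
//!     let mut parent: Vec<usize> = (0..n).collect();
//!     for &(a, b) in links {
//!         let ra = find(&mut parent, a);
//!         let rb = find(&mut parent, b);
//!         if ra != rb {
//!             parent[ra] = rb;
//!         }
//!     }
//!     (0..n).map(|x| find(&mut parent, x)).collect()
//! }
//! ```
//!
//! With A=0, B=1, C=2 and probabilities 0.9, 0.9, 0.1, `pairwise_links` at 0.5
//! returns only (0, 1) and (1, 2), while `cluster_ids` assigns 0 and 2 the same id.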
//!
//! # Architecture
//!
//! ```text
//! Input: Detected mentions M = [m₁, m₂, ..., mₙ]
//!    ↓
//! ┌─────────────────────────────────────────────────────────┐
//! │ Iteration 0: Initialize empty graph G₀                  │
//! │    - Nodes: all mentions                                │
//! │    - Edges: none                                        │
//! └─────────────────────────────────────────────────────────┘
//!    ↓
//! ┌─────────────────────────────────────────────────────────┐
//! │ Iteration t: Refine graph Gₜ₋₁ → Gₜ                     │
//! │    For each mention pair (mᵢ, mⱼ) where j < i:          │
//! │    1. Compute pairwise score s(mᵢ, mⱼ)                  │
//! │    2. Add graph context from Gₜ₋₁ (transitivity bonus)  │
//! │    3. Update edge if score exceeds threshold            │
//! └─────────────────────────────────────────────────────────┘
//!    ↓ (repeat until Gₜ = Gₜ₋₁ or t = max_iterations)
//! ┌─────────────────────────────────────────────────────────┐
//! │ Extract clusters via connected components               │
//! └─────────────────────────────────────────────────────────┘
//!    ↓
//! Output: Coreference chains (clusters)
//! ```
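//!
//! The loop in the diagram can be sketched in std-only Rust. This is an illustration
//! under simplifying assumptions, not this module's code: `refine` and
//! `shared_neighbors` are hypothetical helpers, the base scorer is supplied by the
//! caller, and the graph-context bonus is assumed to be
//! `bonus_per_shared × (number of shared neighbors in Gₜ₋₁)`.
//!
//! ```rust
//! use std::collections::HashSet;
//!
//! // Undirected edges stored once as (lo, hi) with lo < hi.
//! type Graph = HashSet<(usize, usize)>;
//!
//! // Number of nodes adjacent to both `i` and `j` in `g`.
//! fn shared_neighbors(g: &Graph, i: usize, j: usize) -> usize {
//!     let nbrs = |x: usize| -> HashSet<usize> {
//!         g.iter()
//!             .filter_map(|&(a, b)| {
//!                 if a == x {
//!                     Some(b)
//!                 } else if b == x {
//!                     Some(a)
//!                 } else {
//!                     None
//!                 }
//!             })
//!             .collect()
//!     };
//!     nbrs(i).intersection(&nbrs(j)).count()
//! }
//!
//! // Rebuild the graph from scratch each pass, conditioned on the previous graph,
//! // until a fixed point or `max_iters`. Returns the graph and the passes executed.
//! fn refine(
//!     n: usize,
//!     max_iters: usize,
//!     threshold: f64,
//!     base: impl Fn(usize, usize) -> f64,
//!     bonus_per_shared: f64,
//! ) -> (Graph, usize) {
//!     let mut graph = Graph::new(); // G₀: no edges
//!     let mut passes = 0;
//!     for _ in 0..max_iters {
//!         passes += 1;
//!         let mut next = Graph::new();
//!         for i in 0..n {
//!             for j in 0..i {
//!                 let bonus = bonus_per_shared * shared_neighbors(&graph, j, i) as f64;
//!                 if base(j, i) + bonus > threshold {
//!                     next.insert((j, i));
//!                 }
//!             }
//!         }
//!         if next == graph {
//!             break; // fixed point: Gₜ == Gₜ₋₁
//!         }
//!         graph = next;
//!     }
//!     (graph, passes)
//! }
//! ```
//!
//! With base scores 0.6 for (0,1) and (1,2), 0.4 for (0,2), threshold 0.5 and a
//! bonus of 0.15 per shared neighbor, pass 1 links 0-1 and 1-2, pass 2 adds 0-2
//! (0.4 + 0.15 > 0.5) because 0 and 2 share neighbor 1, and pass 3 reproduces
//! pass 2, so the loop stops at a fixed point.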
//!
//! # Approximation vs. Full G2GT
//!
//! This implementation is a **heuristic approximation**, not a full neural model.
//!
//! | Aspect | G2GT (Miculicich 2022) | This Implementation |
//! |--------|------------------------|---------------------|
//! | **Graph nodes** | Tokens | Mentions (pre-detected) |
//! | **Graph encoding** | Attention modification: `Lk = E(G)·Wk` | Explicit transitivity bonus |
//! | **Pairwise scoring** | Learned neural scorer | String/head heuristics |
//! | **Refinement** | Full neural re-prediction | Score adjustment |
//! | **Training** | End-to-end backprop | None (heuristic) |
//!
//! ## What We Preserve
//!
//! The key insight from G2GT: **iterative refinement with graph conditioning**
//! enables global consistency that independent pairwise models lack. Even with
//! heuristic scoring, the refinement loop propagates transitivity constraints.
//!
//! ## What We Lose
//!
//! - **Learned representations**: G2GT embeds graph structure directly into
//!   transformer attention. We approximate this with explicit bonuses.
//! - **End-to-end optimization**: G2GT trains the full system. We use fixed heuristics.
//! - **Token-level granularity**: G2GT operates on tokens; we operate on mentions.
//!
//! For production use with high accuracy requirements, consider a full neural
//! implementation or the T5-based coreference in `crate::backends::coref_t5`.
//!
//! # Usage with MentionType
//!
//! For best results, provide mentions with `mention_type` set:
//!
//! ```rust,ignore
//! use anno::backends::graph_coref::GraphCoref;
//! use anno::{Mention, MentionType};
//!
//! // Properly annotated mentions work better
//! let mut john = Mention::new("John", 0, 4);
//! john.mention_type = Some(MentionType::Proper);
//!
//! let mut he = Mention::new("he", 20, 22);
//! he.mention_type = Some(MentionType::Pronominal);
//!
//! let coref = GraphCoref::new();
//! let chains = coref.resolve(&[john, he]);
//! ```
//!
//! # Graph Initialization: Syntactic vs Semantic
//!
//! SpanEIT (Hossain et al. 2025) constructs a combined graph:
//! - **Syntactic edges** (`E_syn`): From dependency parse (adjectival modifiers, etc.)
//! - **Semantic edges** (`E_sem`): From co-occurrence statistics
//!
//! Use [`CorefGraph::seed_cooccurrence_edges`] to initialize with proximity-based
//! priors before running iterative refinement.
//!
//! # References
//!
//! - Miculicich & Henderson (2022): "Graph Refinement for Coreference Resolution"
//!   [arXiv:2203.16574](https://arxiv.org/abs/2203.16574)
//! - Lee et al. (2017): "End-to-end Neural Coreference Resolution"
//! - Lee et al. (2018): "Higher-Order Coreference Resolution"
//! - Mohammadshahi & Henderson (2021): "Graph-to-Graph Transformer for Dependency Parsing"
//! - Hossain et al. (2025): "SpanEIT: Dynamic Span Interaction and Graph-Aware Memory"
//!   [arXiv:2509.11604](https://arxiv.org/abs/2509.11604)
//!
//! # Future Direction: Sheaf Neural Networks
//!
//! This implementation uses explicit transitivity bonuses to approximate global consistency.
//! A more principled approach: **Sheaf Neural Networks** replace scalar edge weights with
//! learned linear maps (restriction maps) and minimize the sheaf Dirichlet energy:
//!
//! ```text
//! E(x) = Σ_{(u,v) ∈ E} || F(u→v) · x_u - F(v→u) · x_v ||²
//! ```
//!
//! Minimizing this energy during training pushes linked mentions toward representations
//! that agree under the restriction maps, so consistency is learned rather than imposed
//! post hoc through score bonuses. See:
//! - `archive/geometric-2024-12/sheaf.rs` for stub implementation and trait definitions
//! - Bodnar et al. (2022): "Neural Sheaf Diffusion" (NeurIPS)
//! - twitter-research/neural-sheaf-diffusion (Apache 2.0): reference implementation
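//!
//! To make the energy concrete, here is a std-only sketch with fixed (not learned)
//! restriction maps stored as dense row-major matrices. `SheafEdge`, `mat_vec` and
//! `sheaf_dirichlet_energy` are hypothetical names chosen for this example, not the
//! trait definitions in the archived stub.
//!
//! ```rust
//! // Dense matrix-vector product: `m` is rows × cols, `x` has `cols` entries.
//! fn mat_vec(m: &[Vec<f64>], x: &[f64]) -> Vec<f64> {
//!     m.iter()
//!         .map(|row| row.iter().zip(x).map(|(a, b)| a * b).sum::<f64>())
//!         .collect()
//! }
//!
//! // One edge (u, v) with restriction maps F(u→v) and F(v→u).
//! struct SheafEdge {
//!     u: usize,
//!     v: usize,
//!     f_uv: Vec<Vec<f64>>,
//!     f_vu: Vec<Vec<f64>>,
//! }
//!
//! // E(x) = Σ over edges of || F(u→v)·x_u - F(v→u)·x_v ||²
//! fn sheaf_dirichlet_energy(edges: &[SheafEdge], x: &[Vec<f64>]) -> f64 {
//!     edges
//!         .iter()
//!         .map(|e| {
//!             let a = mat_vec(&e.f_uv, &x[e.u]);
//!             let b = mat_vec(&e.f_vu, &x[e.v]);
//!             a.iter().zip(&b).map(|(p, q)| (p - q).powi(2)).sum::<f64>()
//!         })
//!         .sum()
//! }
//! ```
//!
//! With identity maps, an edge contributes zero only when its endpoint vectors are
//! equal; with a swap map on one side, `[1, 0]` and `[0, 1]` already agree and
//! contribute zero. The restriction maps define what "consistent" means on each edge.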

use anno_core::{CorefChain, Mention, MentionType};
use std::collections::{HashMap, HashSet, VecDeque};
use std::hash::{Hash, Hasher};

// =============================================================================
// Types
// =============================================================================

/// Edge type in the coreference graph.
///
/// Following G2GT's three-way classification:
/// - 0 (None): No relationship
/// - 1 (Mention): Within-mention link (not used here since we operate on mentions, not tokens)
/// - 2 (Coref): Coreference link between mentions
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[repr(u8)]
pub enum EdgeType {
    /// No link between mentions.
    None = 0,
    /// Coreference link (both mentions refer to the same entity).
    Coref = 2,
}

/// A coreference graph representing mention relationships.
///
/// This is the core data structure for iterative refinement. Nodes are mention
/// indices, edges are coreference links. Edges are stored as a hash set of
/// canonical `(i, j)` pairs, which gives expected O(1) edge lookup in
/// `has_edge`; enumerating a node's neighbors scans the whole edge set (O(E)).
///
/// # Invariants
///
/// - Graph is symmetric: if edge(i,j) exists, edge(j,i) exists (each edge is
///   stored once in canonical order)
/// - No self-loops: `has_edge(i, i)` is always false
/// - Indices are valid mention indices: 0 <= i,j < num_mentions
///
/// # Example
///
/// ```rust
/// use anno::backends::graph_coref::CorefGraph;
///
/// let mut graph = CorefGraph::new(3);
/// graph.add_edge(0, 1);
/// graph.add_edge(1, 2);
///
/// assert!(graph.has_edge(0, 1));
/// assert!(graph.transitively_connected(0, 2));  // via 0-1-2
///
/// let clusters = graph.extract_clusters();
/// assert_eq!(clusters.len(), 1);  // All connected
/// ```
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CorefGraph {
    /// Number of mentions (nodes).
    num_mentions: usize,
    /// Edge set: (i, j) where i < j for canonical representation.
    edges: HashSet<(usize, usize)>,
}

impl CorefGraph {
    /// Create an empty coreference graph with the given number of mentions.
    #[must_use]
    pub fn new(num_mentions: usize) -> Self {
        Self {
            num_mentions,
            edges: HashSet::new(),
        }
    }

    /// Get the number of mentions (nodes) in the graph.
    #[must_use]
    pub fn num_mentions(&self) -> usize {
        self.num_mentions
    }

    /// Add a coreference edge between two mentions.
    ///
    /// The edge is stored in canonical form (i < j) for consistency.
    /// Self-loops and out-of-bounds indices are silently ignored.
    pub fn add_edge(&mut self, i: usize, j: usize) {
        if i == j || i >= self.num_mentions || j >= self.num_mentions {
            return;
        }
        let (lo, hi) = if i < j { (i, j) } else { (j, i) };
        self.edges.insert((lo, hi));
    }

    /// Remove a coreference edge between two mentions.
    ///
    /// Removing an edge that does not exist is a no-op.
    pub fn remove_edge(&mut self, i: usize, j: usize) {
        let (lo, hi) = if i < j { (i, j) } else { (j, i) };
        self.edges.remove(&(lo, hi));
    }

    /// Check if two mentions are directly linked.
    #[must_use]
    pub fn has_edge(&self, i: usize, j: usize) -> bool {
        if i == j {
            return false;
        }
        let (lo, hi) = if i < j { (i, j) } else { (j, i) };
        self.edges.contains(&(lo, hi))
    }

    /// Get all neighbors (directly linked mentions) of a mention.
    ///
    /// Scans the whole edge set, so each call is O(E).
    #[must_use]
    pub fn neighbors(&self, i: usize) -> Vec<usize> {
        let mut result = Vec::new();
        for &(lo, hi) in &self.edges {
            if lo == i {
                result.push(hi);
            } else if hi == i {
                result.push(lo);
            }
        }
        result
    }

    /// Count shared neighbors between two mentions.
    ///
    /// This is the basis for the transitivity bonus: if mentions i and j
    /// share many neighbors in the current graph, they're likely coreferent.
    ///
    /// # G2GT Connection
    ///
    /// In the full G2GT model, shared structure is captured via graph-conditioned
    /// attention. Here we approximate it by explicitly counting shared neighbors
    /// and adding a proportional bonus to the pairwise score.
    #[must_use]
    pub fn shared_neighbors(&self, i: usize, j: usize) -> usize {
        let neighbors_i: HashSet<usize> = self.neighbors(i).into_iter().collect();
        let neighbors_j: HashSet<usize> = self.neighbors(j).into_iter().collect();
        neighbors_i.intersection(&neighbors_j).count()
    }

    /// Check if two mentions are transitively connected.
    ///
    /// Uses a depth-first search (explicit stack) to find whether there is a path
    /// from i to j through coreference links. This is the closure property that
    /// ensures consistency: if A~B and B~C, then A and C are transitively
    /// connected even without a direct edge.
    #[must_use]
    pub fn transitively_connected(&self, i: usize, j: usize) -> bool {
        if i == j {
            return true;
        }
        if self.has_edge(i, j) {
            return true;
        }

        // Depth-first search from i looking for j
        let mut visited = HashSet::new();
        let mut stack = vec![i];
        visited.insert(i);

        while let Some(current) = stack.pop() {
            for neighbor in self.neighbors(current) {
                if neighbor == j {
                    return true;
                }
                if visited.insert(neighbor) {
                    stack.push(neighbor);
                }
            }
        }

        false
    }

    /// Extract connected components as clusters.
    ///
    /// Each connected component in the graph becomes a coreference chain.
    /// Singleton mentions (no edges) are included as single-mention clusters.
    /// Clusters are returned in order of their smallest mention index, and the
    /// indices inside each cluster are sorted.
    #[must_use]
    pub fn extract_clusters(&self) -> Vec<Vec<usize>> {
        let mut visited = vec![false; self.num_mentions];
        let mut clusters = Vec::new();

        for start in 0..self.num_mentions {
            if visited[start] {
                continue;
            }

            // Depth-first traversal to collect all members of this component
            let mut cluster = Vec::new();
            let mut stack = vec![start];
            visited[start] = true;

            while let Some(current) = stack.pop() {
                cluster.push(current);
                for neighbor in self.neighbors(current) {
                    if !visited[neighbor] {
                        visited[neighbor] = true;
                        stack.push(neighbor);
                    }
                }
            }

            cluster.sort_unstable();
            clusters.push(cluster);
        }

        clusters
    }

    /// Get the number of edges in the graph.
    #[must_use]
    pub fn edge_count(&self) -> usize {
        self.edges.len()
    }

    /// Check if graph is empty (no edges).
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.edges.is_empty()
    }

    /// Seed the graph with co-occurrence priors based on mention proximity.
    ///
    /// This is inspired by SpanEIT (Hossain et al. 2025), which constructs a
    /// semantic co-occurrence graph `G_sem` alongside the syntactic graph. The
    /// intuition: mentions that appear close together are more likely coreferent.
    ///
    /// # Arguments
    ///
    /// * `mention_positions` - Position of each mention (e.g., character offset).
    ///   Mentions without a position (index past the end of the slice) are skipped.
    /// * `window_size` - Maximum distance for co-occurrence, inclusive (e.g., 100 chars)
    /// * `scorer` - Optional predicate `f(i, j)` deciding whether an in-window pair
    ///   gets an edge; if `None`, every in-window pair is linked
    ///
    /// # Example
    ///
    /// ```rust
    /// use anno::backends::graph_coref::CorefGraph;
    ///
    /// let mut graph = CorefGraph::new(3);
    /// let positions = vec![0, 50, 200];  // Character offsets
    ///
    /// // Seed edges for mentions within 100 chars of each other
    /// // Type annotation needed when passing None for the scorer
    /// graph.seed_cooccurrence_edges::<fn(usize, usize) -> bool>(&positions, 100, None);
    ///
    /// assert!(graph.has_edge(0, 1));   // 50 <= 100
    /// assert!(!graph.has_edge(0, 2));  // 200 > 100
    /// ```
    ///
    /// # Research Background
    ///
    /// SpanEIT constructs `G = (V, E_syn ∪ E_sem)` where:
    /// - `E_syn` = syntactic dependency edges
    /// - `E_sem` = co-occurrence edges (this method)
    ///
    /// The combined graph is processed by GAT layers for context-aware embeddings.
    /// For anno's heuristic approach, we add these edges as initial priors before
    /// iterative refinement.
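    ///
    /// This method checks all N² mention pairs. For long documents, an alternative
    /// (not what this method does) is to sort mentions by position and sweep the
    /// window, visiting only pairs that can fall inside it. A minimal std-only
    /// sketch with a hypothetical `cooccurrence_pairs` helper and no `scorer`
    /// predicate:
    ///
    /// ```rust
    /// // Sort-and-sweep: canonical (lo, hi) pairs whose positions differ by at
    /// // most `window`; after an O(N log N) sort, only in-window pairs are visited
    /// // (plus one early-exit comparison per mention).
    /// fn cooccurrence_pairs(positions: &[usize], window: usize) -> Vec<(usize, usize)> {
    ///     let mut order: Vec<usize> = (0..positions.len()).collect();
    ///     order.sort_by_key(|&i| positions[i]);
    ///     let mut pairs = Vec::new();
    ///     for (a, &i) in order.iter().enumerate() {
    ///         for &j in &order[a + 1..] {
    ///             // Sorted order guarantees positions[j] >= positions[i].
    ///             if positions[j] - positions[i] > window {
    ///                 break; // every later mention is even further away
    ///             }
    ///             pairs.push((i.min(j), i.max(j)));
    ///         }
    ///     }
    ///     pairs.sort_unstable();
    ///     pairs
    /// }
    /// ```
    ///
    /// The pairs are reported in original mention indices, so unsorted input such
    /// as `[200, 0, 50]` yields `(1, 2)`.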
    pub fn seed_cooccurrence_edges<F>(
        &mut self,
        mention_positions: &[usize],
        window_size: usize,
        scorer: Option<F>,
    ) where
        F: Fn(usize, usize) -> bool,
    {
        // Only mentions that have a recorded position can be seeded.
        let n = self.num_mentions.min(mention_positions.len());
        for i in 0..n {
            for j in (i + 1)..n {
                let distance = mention_positions[i].abs_diff(mention_positions[j]);

                if distance <= window_size {
                    let should_add = match scorer.as_ref() {
                        None => true,
                        Some(f) => f(i, j),
                    };
                    if should_add {
                        self.add_edge(i, j);
                    }
                }
            }
        }
    }
}

// =============================================================================
// Configuration
// =============================================================================

/// Configuration for graph-based coreference resolution.
///
/// These parameters control the iterative refinement process. The iteration
/// budget follows Miculicich & Henderson (2022); the threshold and score weights
/// are heuristic defaults for this approximation:
///
/// - `max_iterations = 4`: Paper found T=4 optimal; more iterations don't help
/// - `link_threshold = 0.5`: Standard classification threshold
/// - `transitivity_bonus = 0.15`: Reward for transitive consistency
///
/// # Tuning Guide
///
/// | Parameter | Higher Value | Lower Value |
/// |-----------|--------------|-------------|
/// | `link_threshold` | Fewer, more confident links | More links, potential noise |
/// | `transitivity_bonus` | Stronger clustering effect | More independent decisions |
/// | `max_iterations` | More refinement passes | Faster, less propagation |
/// | `head_match_weight` | Trust head matches more | Rely more on full string |
#[derive(Debug, Clone)]
pub struct GraphCorefConfig {
    /// Maximum refinement iterations before stopping.
    ///
    /// The G2GT paper found T=4 to be optimal on CoNLL 2012. Fewer iterations
    /// leave potential coreference links undiscovered; more iterations don't
    /// improve results and waste computation.
    pub max_iterations: usize,

    /// Minimum score to create a coreference link.
    ///
    /// A pair (mᵢ, mⱼ) is linked if score(mᵢ, mⱼ) + context_bonus > threshold.
    pub link_threshold: f64,

    /// Bonus added for transitive consistency.
    ///
    /// If mentions A and B share neighbors in the current graph (i.e., both
    /// are already linked to some common mention C), this bonus is added to
    /// encourage the model to also link A and B directly.
    ///
    /// **Note**: This is our heuristic approximation of G2GT's graph-conditioned
    /// attention. The full G2GT model encodes graph structure as:
    /// ```text
    /// Attention(Q,K,V,Lk,Lv) = softmax(Q·(K+Lk)/√d)·(V+Lv)
    /// where Lk = E(G^{t-1})·Wk and Lv = E(G^{t-1})·Wv
    /// ```
    /// We approximate this by explicit score adjustment rather than attention modification.
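    ///
    /// For intuition, a toy single-query version of that formula as a std-only
    /// sketch. The helper names are hypothetical, the per-pair label embeddings
    /// `lk`/`lv` are passed in directly (i.e., `E(G)·Wk` and `E(G)·Wv` are assumed
    /// precomputed), and nothing here is used by this module.
    ///
    /// ```rust
    /// fn dot(a: &[f64], b: &[f64]) -> f64 {
    ///     a.iter().zip(b).map(|(x, y)| x * y).sum()
    /// }
    ///
    /// // One query position: out = Σ_j softmax_j(q·(k_j + lk_j)/√d) · (v_j + lv_j),
    /// // where lk_j / lv_j embed the edge label between the query and position j.
    /// // Assumes at least one key and consistent dimensions throughout.
    /// fn graph_conditioned_attention(
    ///     q: &[f64],
    ///     keys: &[Vec<f64>],
    ///     values: &[Vec<f64>],
    ///     lk: &[Vec<f64>],
    ///     lv: &[Vec<f64>],
    /// ) -> Vec<f64> {
    ///     let scale = (q.len() as f64).sqrt();
    ///     let logits: Vec<f64> = keys
    ///         .iter()
    ///         .zip(lk)
    ///         .map(|(k, l)| {
    ///             let shifted: Vec<f64> = k.iter().zip(l).map(|(a, b)| a + b).collect();
    ///             dot(q, &shifted) / scale
    ///         })
    ///         .collect();
    ///     // Numerically stable softmax over positions j.
    ///     let max = logits.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
    ///     let exps: Vec<f64> = logits.iter().map(|z| (z - max).exp()).collect();
    ///     let total: f64 = exps.iter().sum();
    ///     let mut out = vec![0.0; values[0].len()];
    ///     for ((v, l), w) in values.iter().zip(lv).zip(&exps) {
    ///         for (o, (vt, lt)) in out.iter_mut().zip(v.iter().zip(l)) {
    ///             *o += (w / total) * (vt + lt);
    ///         }
    ///     }
    ///     out
    /// }
    /// ```
    ///
    /// With identical keys and zero labels the attention is uniform; adding a
    /// label embedding to one key (standing in for a coref edge in Gₜ₋₁) shifts
    /// attention toward that position.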
    pub transitivity_bonus: f64,

    /// Bonus for each shared neighbor.
    ///
    /// Scaled by the number of shared neighbors:
    /// `total_bonus = shared_count * per_shared_neighbor_bonus`
    pub per_shared_neighbor_bonus: f64,

    /// Weight for string similarity in pairwise scoring.
    ///
    /// An exact (case-insensitive) match adds this weight; containment of one
    /// mention in the other adds 0.6 × this weight.
    pub string_similarity_weight: f64,

    /// Weight for head word matching.
    ///
    /// The G2GT paper emphasizes head-based matching. `head_start`/`head_end` are
    /// not read yet, so the head is approximated by the last word of each mention.
    pub head_match_weight: f64,

    /// Coefficient of the distance penalty in pairwise scoring.
    ///
    /// `distance_weight * ln(distance)` is subtracted from the pairwise score,
    /// where `distance` is the character gap between the two mentions
    /// (no penalty when it is 0).
    pub distance_weight: f64,

    /// Maximum character distance to consider (mentions further apart are not linked).
    pub max_distance: Option<usize>,

    /// Include singletons (mentions with no coreference) in output.
    ///
    /// Default: false (only return multi-mention chains).
    /// Set to true for evaluation against datasets that include singletons.
    pub include_singletons: bool,

    /// Bonus when a pronoun links to a proper noun.
    ///
    /// Pronouns are weak signals alone but should link to antecedents.
    /// Pronoun/nominal pairs receive half of this bonus.
    pub pronoun_proper_bonus: f64,

    /// Optional early-stop controls for iterative refinement.
    ///
    /// GraphCoref already stops when it reaches a fixed point (Gₜ == Gₜ₋₁).
    /// This option additionally stops on:
    /// - **Cycle detection** (a previously seen graph state recurs, e.g. the
    ///   graph oscillates G₁ → G₂ → G₁ across iterations)
    /// - **Stagnation** (edge count unchanged for `stagnation_patience` consecutive iterations)
    ///
    /// This is an analogue of “overthinking” / redundancy detection (CoRE-Eval),
    /// implemented using observable signals (graph structure) rather than hidden states.
    pub early_stop: Option<GraphCorefEarlyStopConfig>,
}

/// Configuration for early stopping in iterative graph refinement.
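///
/// The bookkeeping in `GraphCoref::resolve` can be pictured with the standalone
/// sketch below. `EarlyStopTracker` is hypothetical (not part of this crate): it
/// always enables cycle detection, uses a `HashSet` where `resolve` keeps a
/// `HashMap` from fingerprint to iteration, and takes plain `u64` fingerprints
/// supplied by the caller.
///
/// ```rust
/// use std::collections::{HashSet, VecDeque};
///
/// struct EarlyStopTracker {
///     history: VecDeque<u64>,
///     seen: HashSet<u64>,
///     cycle_history: usize,       // 0 = unbounded
///     stagnation_patience: usize, // 0 = disabled
///     stagnation: usize,
///     last_edge_count: usize,
/// }
///
/// impl EarlyStopTracker {
///     // Start from the empty graph G₀ (fingerprint `initial_fp`, zero edges).
///     fn new(initial_fp: u64, cycle_history: usize, stagnation_patience: usize) -> Self {
///         let mut history = VecDeque::new();
///         let mut seen = HashSet::new();
///         history.push_back(initial_fp);
///         seen.insert(initial_fp);
///         Self { history, seen, cycle_history, stagnation_patience, stagnation: 0, last_edge_count: 0 }
///     }
///
///     // Called after each pass that changed the graph; returns true to stop.
///     fn observe(&mut self, fp: u64, edge_count: usize) -> bool {
///         // Stagnation: edge count unchanged for `stagnation_patience` passes.
///         if edge_count == self.last_edge_count {
///             self.stagnation += 1;
///         } else {
///             self.stagnation = 0;
///             self.last_edge_count = edge_count;
///         }
///         if self.stagnation_patience > 0 && self.stagnation >= self.stagnation_patience {
///             return true;
///         }
///         // Cycle: this graph state is still inside the remembered window.
///         if !self.seen.insert(fp) {
///             return true;
///         }
///         self.history.push_back(fp);
///         if self.cycle_history > 0 {
///             while self.history.len() > self.cycle_history {
///                 if let Some(old) = self.history.pop_front() {
///                     self.seen.remove(&old);
///                 }
///             }
///         }
///         false
///     }
/// }
/// ```
///
/// With `cycle_history = 1` only the latest state is remembered, so a two-state
/// oscillation escapes cycle detection; stagnation or `max_iterations` still
/// bound the loop.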
#[derive(Debug, Clone)]
pub struct GraphCorefEarlyStopConfig {
    /// Stop if we detect a repeated graph state (cycle) within the configured history.
    pub detect_cycles: bool,
    /// How many past graph fingerprints to remember (0 = unbounded).
    pub cycle_history: usize,
    /// Stop if the edge count hasn't changed for this many consecutive iterations (0 disables).
    pub stagnation_patience: usize,
}

impl Default for GraphCorefEarlyStopConfig {
    fn default() -> Self {
        Self {
            detect_cycles: true,
            cycle_history: 8,
            stagnation_patience: 2,
        }
    }
}

impl Default for GraphCorefConfig {
    fn default() -> Self {
        Self {
            max_iterations: 4,
            link_threshold: 0.5,
            transitivity_bonus: 0.15,
            per_shared_neighbor_bonus: 0.1,
            string_similarity_weight: 1.0,
            head_match_weight: 0.5,
            distance_weight: 0.05,
            max_distance: Some(1000),
            include_singletons: false,
            pronoun_proper_bonus: 0.3,
            early_stop: None,
        }
    }
}

// =============================================================================
// Main Implementation
// =============================================================================

/// Graph-based coreference resolver with iterative refinement.
///
/// This implements a heuristic version of the G2GT architecture, preserving
/// the key insight that iterative graph refinement enables global consistency
/// in coreference decisions.
///
/// # Algorithm
///
/// 1. **Initialize**: Empty graph with mentions as nodes
/// 2. **Iterate**: For each mention pair, compute score with graph context
/// 3. **Update**: Rebuild the edge set, keeping pairs whose score exceeds the threshold
/// 4. **Converge**: Stop when the graph is unchanged, max iterations are reached,
///    or an optional early-stop check fires
/// 5. **Extract**: Connected components become coreference chains
///
/// # Complexity
///
/// - Time: O(N² × T) pair scorings, where N = mentions and T = iterations
///   (typically 4); `CorefGraph::shared_neighbors` scans the edge set, so
///   graph-context lookups add to this
/// - Space: O(N²) worst case for the edge set
///
/// Compare to Lee et al. (2017), which scores pairs over all candidate spans:
/// O(L⁴) in document length L before pruning. N here counts pre-detected
/// mentions, so the two bounds are not directly comparable.
///
/// # Feature Usage
///
/// The resolver uses available `Mention` fields when present:
///
/// | Field | Used For | Fallback |
/// |-------|----------|----------|
/// | `mention_type` | Pronoun detection, type compatibility | Heuristic detection |
/// | `head_start`/`head_end` | Head word matching (not read yet) | Last word of mention |
/// | `entity_type` | Type compatibility check | Ignored |
#[derive(Debug, Clone)]
pub struct GraphCoref {
    config: GraphCorefConfig,
}

impl GraphCoref {
    /// Create a new graph coref resolver with default configuration.
    #[must_use]
    pub fn new() -> Self {
        Self::with_config(GraphCorefConfig::default())
    }

    /// Create with custom configuration.
    #[must_use]
    pub fn with_config(config: GraphCorefConfig) -> Self {
        Self { config }
    }

    fn graph_fingerprint(graph: &CorefGraph) -> u64 {
        // Order-independent fingerprint of the edge set.
        // We sort per-edge hashes to avoid HashSet iteration nondeterminism.
        let mut edge_hashes: Vec<u64> = graph
            .edges
            .iter()
            .map(|e| {
                let mut h = std::collections::hash_map::DefaultHasher::new();
                e.hash(&mut h);
                h.finish()
            })
            .collect();
        edge_hashes.sort_unstable();

        let mut h = std::collections::hash_map::DefaultHasher::new();
        graph.num_mentions.hash(&mut h);
        for eh in edge_hashes {
            eh.hash(&mut h);
        }
        h.finish()
    }

    /// Resolve coreferences among mentions using iterative graph refinement.
    ///
    /// # Arguments
    ///
    /// * `mentions` - Pre-detected mentions (from NER or mention detector).
    ///   For best results, set `mention_type` on each mention.
    ///
    /// # Returns
    ///
    /// Coreference chains (clusters) where each chain contains mentions
    /// referring to the same entity. By default, singletons are filtered;
    /// set `config.include_singletons = true` to include them.
    ///
    /// # Panics
    ///
    /// Does not panic. Empty input returns empty output.
    ///
    /// # Example
    ///
    /// ```rust
    /// use anno::backends::graph_coref::GraphCoref;
    /// use anno::{Mention, MentionType};
    ///
    /// let coref = GraphCoref::new();
    ///
    /// let mut john = Mention::new("John", 0, 4);
    /// john.mention_type = Some(MentionType::Proper);
    ///
    /// let mut he = Mention::new("he", 20, 22);
    /// he.mention_type = Some(MentionType::Pronominal);
    ///
    /// let chains = coref.resolve(&[john, he]);
    /// ```
    #[must_use]
    pub fn resolve(&self, mentions: &[Mention]) -> Vec<CorefChain> {
        if mentions.is_empty() {
            return vec![];
        }

        // Validate mentions (filter empty/invalid)
        let valid_mentions: Vec<&Mention> = mentions
            .iter()
            .filter(|m| !m.text.trim().is_empty() && m.start < m.end)
            .collect();

        if valid_mentions.is_empty() {
            return vec![];
        }

        // Initialize empty graph
        let mut graph = CorefGraph::new(valid_mentions.len());

        // Iterative refinement
        let mut stagnation: usize = 0;
        let mut last_edge_count = graph.edge_count();
        let mut seen: HashMap<u64, usize> = HashMap::new();
        let mut history: VecDeque<u64> = VecDeque::new();
        if let Some(cfg) = &self.config.early_stop {
            if cfg.detect_cycles {
                let fp0 = Self::graph_fingerprint(&graph);
                seen.insert(fp0, 0);
                history.push_back(fp0);
            }
        }

        for iteration in 0..self.config.max_iterations {
            let prev_graph = graph.clone();
            graph = self.refine_iteration(&valid_mentions, &graph);

            // Check convergence
            if graph == prev_graph {
                break;
            }

            // Optional early stop: stagnation / cycles.
            if let Some(cfg) = &self.config.early_stop {
                // Stagnation: edge count stops changing.
                let ec = graph.edge_count();
                if ec == last_edge_count {
                    stagnation += 1;
                } else {
                    stagnation = 0;
                    last_edge_count = ec;
                }
                if cfg.stagnation_patience > 0 && stagnation >= cfg.stagnation_patience {
                    break;
                }

                if cfg.detect_cycles {
                    let fp = Self::graph_fingerprint(&graph);
                    if seen.contains_key(&fp) {
                        break;
                    }
                    seen.insert(fp, iteration + 1);
                    history.push_back(fp);
                    if cfg.cycle_history > 0 {
                        while history.len() > cfg.cycle_history {
                            if let Some(old) = history.pop_front() {
                                seen.remove(&old);
                            }
                        }
                    }
                }
            }
        }

        // Extract clusters and convert to CorefChains
        self.graph_to_chains(&graph, &valid_mentions)
    }

    /// Perform one iteration of graph refinement.
    ///
    /// For each mention pair, computes a score incorporating:
    /// 1. Base pairwise similarity (string match, head match, type compatibility)
    /// 2. Graph context (transitivity bonus from shared neighbors in `prev_graph`)
    ///
    /// The new graph is built from scratch: a pair is linked only if its score
    /// exceeds the threshold in this iteration, so edges can also disappear.
    fn refine_iteration(&self, mentions: &[&Mention], prev_graph: &CorefGraph) -> CorefGraph {
        let mut new_graph = CorefGraph::new(mentions.len());

        for i in 0..mentions.len() {
            for j in 0..i {
                // Distance filter: character gap between the two spans in either
                // order (mentions are not assumed to be sorted by position).
                if let Some(max_dist) = self.config.max_distance {
                    let (a, b) = (mentions[i], mentions[j]);
                    let dist = if a.start >= b.end {
                        a.start - b.end
                    } else {
                        b.start.saturating_sub(a.end)
                    };
                    if dist > max_dist {
                        continue;
                    }
                }

                // Compute score with graph context
                let base_score = self.pairwise_score(mentions[i], mentions[j]);
                let context_bonus = self.graph_context_bonus(i, j, prev_graph);
                let total_score = base_score + context_bonus;

                if total_score > self.config.link_threshold {
                    new_graph.add_edge(i, j);
                }
            }
        }

        new_graph
    }

    /// Compute base pairwise similarity score between two mentions.
    ///
    /// Uses multiple signals:
    /// - Exact string match (highest weight)
    /// - Substring containment
    /// - Head word match (last word of each mention; `head_start`/`head_end` are not read yet)
    /// - Pronoun-to-proper linking (uses `mention_type` if available)
    /// - Distance penalty
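    ///
    /// A worked example of the type bonus and log-distance penalty, as a std-only
    /// sketch with hypothetical helpers (`span_distance`, `no_overlap_score`) and the
    /// default weights (`pronoun_proper_bonus = 0.3`, `distance_weight = 0.05`). For
    /// "John" at 0..4 and "he" at 20..22 no string or head signal fires, so the base
    /// score is 0.3 - 0.05 · ln(16) ≈ 0.161. That is below the default
    /// `link_threshold` of 0.5, so on its own the pair stays unlinked unless the
    /// graph-context bonus lifts it.
    ///
    /// ```rust
    /// // Distance as computed in `pairwise_score`: min(|s1 - e2|, |s2 - e1|).
    /// fn span_distance(a: (usize, usize), b: (usize, usize)) -> usize {
    ///     a.0.abs_diff(b.1).min(b.0.abs_diff(a.1))
    /// }
    ///
    /// // Base score when no string or head signal fires: type bonus minus
    /// // `distance_weight * ln(distance)` (no penalty at distance 0).
    /// fn no_overlap_score(type_bonus: f64, distance: usize, distance_weight: f64) -> f64 {
    ///     let penalty = if distance > 0 {
    ///         distance_weight * (distance as f64).ln()
    ///     } else {
    ///         0.0
    ///     };
    ///     type_bonus - penalty
    /// }
    /// ```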
    fn pairwise_score(&self, m1: &Mention, m2: &Mention) -> f64 {
        let mut score = 0.0;

        let t1 = m1.text.to_lowercase();
        let t2 = m2.text.to_lowercase();

        // Whole-word containment: the words of `short` appear contiguously in `long`.
        // Splitting on non-alphanumerics keeps "obama" inside "obama's" while
        // avoiding character-level false hits such as "he" inside "the".
        let contains_words = |long: &str, short: &str| {
            let lw: Vec<&str> = long
                .split(|c: char| !c.is_alphanumeric())
                .filter(|w| !w.is_empty())
                .collect();
            let sw: Vec<&str> = short
                .split(|c: char| !c.is_alphanumeric())
                .filter(|w| !w.is_empty())
                .collect();
            !sw.is_empty() && lw.windows(sw.len()).any(|w| w == sw.as_slice())
        };

        // Exact match (strongest signal)
        if t1 == t2 {
            score += self.config.string_similarity_weight * 1.0;
        }
        // Containment (word level)
        else if contains_words(&t1, &t2) || contains_words(&t2, &t1) {
            score += self.config.string_similarity_weight * 0.6;
        }
        // Head word match
        else {
            let h1 = self.get_head_text(m1);
            let h2 = self.get_head_text(m2);
            if !h1.is_empty() && h1.to_lowercase() == h2.to_lowercase() {
                score += self.config.head_match_weight;
            }
        }

        // Mention type compatibility
        score += self.type_compatibility_score(m1, m2);

        // Distance penalty (log scale)
        let distance = m1.start.abs_diff(m2.end).min(m2.start.abs_diff(m1.end));
        if distance > 0 {
            score -= self.config.distance_weight * (distance as f64).ln();
        }

        score
    }

    /// Get the head text for a mention.
    ///
    /// Currently always uses the last-word heuristic (English NPs are mostly
    /// head-final). Explicit `head_start`/`head_end` offsets are not consulted
    /// yet, because mapping document offsets into `mention.text` is not
    /// implemented here.
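    ///
    /// # Example
    ///
    /// A minimal sketch of the last-word fallback and where it breaks down.
    /// `last_word_head` is a hypothetical free-function stand-in for this
    /// method's fallback path.
    ///
    /// ```rust
    /// // Head-final heuristic: the rightmost whitespace-separated token.
    /// fn last_word_head(text: &str) -> &str {
    ///     text.split_whitespace().last().unwrap_or("")
    /// }
    ///
    /// fn main() {
    ///     assert_eq!(last_word_head("the former president"), "president");
    ///     // Post-modified NPs defeat the heuristic: the real head is "president".
    ///     assert_eq!(last_word_head("the president of France"), "France");
    ///     assert_eq!(last_word_head("   "), "");
    /// }
    /// ```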
    fn get_head_text<'a>(&self, mention: &'a Mention) -> &'a str {
        // Not yet implemented: `head_start`/`head_end` are document offsets,
        // so they would first have to be mapped into `mention.text`. Until
        // then the explicit head span is ignored.
        if let (Some(head_start), Some(head_end)) = (mention.head_start, mention.head_end) {
            let _ = (head_start, head_end);
        }

        // Fallback: last word (head-final assumption for English NPs)
        mention.text.split_whitespace().last().unwrap_or("")
    }

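    /// Compute the type compatibility score between two mentions.
    ///
    /// Uses each mention's `mention_type` if set, otherwise `infer_mention_type`.
    /// Pronoun/proper pairs get `pronoun_proper_bonus`, pronoun/nominal pairs
    /// get half of it, same-type pairs get 0.0, and other mixed pairs get a
    /// -0.1 penalty.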
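    ///
    /// # Example
    ///
    /// A self-contained sketch of the same compatibility table. `Kind` is a
    /// hypothetical three-variant stand-in for `MentionType`, and `bonus`
    /// plays the role of `pronoun_proper_bonus`.
    ///
    /// ```rust
    /// #[derive(Clone, Copy, Debug, PartialEq)]
    /// enum Kind {
    ///     Pronominal,
    ///     Proper,
    ///     Nominal,
    /// }
    ///
    /// fn compat(a: Kind, b: Kind, bonus: f64) -> f64 {
    ///     match (a, b) {
    ///         // Pronoun with proper noun: full bonus.
    ///         (Kind::Pronominal, Kind::Proper) | (Kind::Proper, Kind::Pronominal) => bonus,
    ///         // Pronoun with nominal: half bonus.
    ///         (Kind::Pronominal, Kind::Nominal) | (Kind::Nominal, Kind::Pronominal) => bonus * 0.5,
    ///         // Same type: neutral.
    ///         _ if a == b => 0.0,
    ///         // Other mixed pairs: slight penalty.
    ///         _ => -0.1,
    ///     }
    /// }
    ///
    /// fn main() {
    ///     // Hypothetical bonus of 0.5.
    ///     assert_eq!(compat(Kind::Proper, Kind::Pronominal, 0.5), 0.5);
    ///     assert_eq!(compat(Kind::Pronominal, Kind::Nominal, 0.5), 0.25);
    ///     assert_eq!(compat(Kind::Nominal, Kind::Nominal, 0.5), 0.0);
    ///     assert_eq!(compat(Kind::Proper, Kind::Nominal, 0.5), -0.1);
    /// }
    /// ```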
    fn type_compatibility_score(&self, m1: &Mention, m2: &Mention) -> f64 {
        let type1 = m1
            .mention_type
            .unwrap_or_else(|| self.infer_mention_type(m1));
        let type2 = m2
            .mention_type
            .unwrap_or_else(|| self.infer_mention_type(m2));

        match (type1, type2) {
            // Pronoun linking to proper noun: boost
            (MentionType::Pronominal, MentionType::Proper)
            | (MentionType::Proper, MentionType::Pronominal) => self.config.pronoun_proper_bonus,

            // Pronoun linking to nominal: smaller boost
            (MentionType::Pronominal, MentionType::Nominal)
            | (MentionType::Nominal, MentionType::Pronominal) => {
                self.config.pronoun_proper_bonus * 0.5
            }

            // Same type: neutral
            _ if type1 == type2 => 0.0,

            // Different non-pronoun types: slight penalty
            _ => -0.1,
        }
    }

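    /// Infer a mention type from its text when `mention_type` is not set.
    ///
    /// This is a fallback heuristic. It checks a closed list of pronouns
    /// (whole-text, case-insensitive match), then applies a capitalization
    /// check for proper nouns, and otherwise returns nominal. Capitalization
    /// cannot tell proper nouns from sentence-initial common nouns, so for
    /// best results set `mention_type` on mentions before calling `resolve()`.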
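    ///
    /// # Example
    ///
    /// A simplified, self-contained sketch of the three-way heuristic. `Kind`
    /// and the shortened pronoun list are hypothetical, and the sketch omits
    /// the extra checks this function applies to capitalized text.
    ///
    /// ```rust
    /// #[derive(Debug, PartialEq)]
    /// enum Kind {
    ///     Pronominal,
    ///     Proper,
    ///     Nominal,
    /// }
    ///
    /// // Shortened, hypothetical pronoun list.
    /// const PRONOUNS: &[&str] = &["he", "him", "his", "she", "her", "it", "they", "them"];
    ///
    /// fn infer_kind(text: &str) -> Kind {
    ///     let lower = text.to_lowercase();
    ///     if PRONOUNS.contains(&lower.as_str()) {
    ///         Kind::Pronominal
    ///     } else if text.chars().next().is_some_and(char::is_uppercase) {
    ///         Kind::Proper
    ///     } else {
    ///         Kind::Nominal
    ///     }
    /// }
    ///
    /// fn main() {
    ///     assert_eq!(infer_kind("She"), Kind::Pronominal); // case-insensitive match
    ///     assert_eq!(infer_kind("Marie Curie"), Kind::Proper);
    ///     assert_eq!(infer_kind("the dog"), Kind::Nominal);
    /// }
    /// ```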
    fn infer_mention_type(&self, mention: &Mention) -> MentionType {
        let text_lower = mention.text.to_lowercase();

        // Check for pronouns
        const PRONOUNS: &[&str] = &[
            "i",
            "me",
            "my",
            "mine",
            "myself",
            "you",
            "your",
            "yours",
            "yourself",
            "yourselves",
            "he",
            "him",
            "his",
            "himself",
            "she",
            "her",
            "hers",
            "herself",
            "it",
            "its",
            "itself",
            "we",
            "us",
            "our",
            "ours",
            "ourselves",
            "they",
            "them",
            "their",
            "theirs",
            "themselves",
            "who",
            "whom",
            "whose",
            "which",
            "that",
            "this",
            "these",
            "those",
        ];

        if PRONOUNS.contains(&text_lower.as_str()) {
            return MentionType::Pronominal;
        }

        // Proper noun: the first content word is capitalized. A leading
        // determiner is skipped so that "The dog" is not treated as proper
        // while "The Beatles" is. Sentence-initial capitalization is not
        // detected, so a capitalized common noun can still be misread.
        const DETERMINERS: &[&str] = &["the", "a", "an", "this", "that", "these", "those"];
        let mut words = mention.text.split_whitespace();
        let content_word = match words.next() {
            Some(w) if DETERMINERS.contains(&w.to_lowercase().as_str()) => words.next(),
            other => other,
        };
        if content_word
            .and_then(|w| w.chars().next())
            .is_some_and(char::is_uppercase)
        {
            return MentionType::Proper;
        }

        // Default to nominal
        MentionType::Nominal
    }

    /// Compute a context bonus from the previous iteration's graph.
    ///
    /// This is a heuristic stand-in for G2GT's graph-conditioned attention:
    /// rather than attending over the previous graph, we add fixed bonuses
    /// derived from its structure.
    ///
    /// # Transitivity Bonus
    ///
    /// If A~C and B~C in the previous graph, A and B are likely coreferent.
    /// Each shared neighbor adds `per_shared_neighbor_bonus`.
    ///
    /// # Already Connected Bonus
    ///
    /// If A and B are already transitively connected (through a chain of
    /// coreference links), add `transitivity_bonus` so the connection tends to
    /// persist. The two bonuses stack: a pair with a shared neighbor is also
    /// connected, so it receives both.
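    ///
    /// # Example
    ///
    /// A self-contained sketch of how the two bonuses combine. `Graph`,
    /// `context_bonus`, and the bonus values are hypothetical. Connectivity is
    /// checked with breadth-first search here, which may differ from how
    /// `CorefGraph::transitively_connected` is implemented.
    ///
    /// ```rust
    /// use std::collections::{HashSet, VecDeque};
    ///
    /// // Undirected adjacency sets; a stand-in for `CorefGraph`.
    /// struct Graph {
    ///     adj: Vec<HashSet<usize>>,
    /// }
    ///
    /// impl Graph {
    ///     fn new(n: usize) -> Self {
    ///         Graph { adj: vec![HashSet::new(); n] }
    ///     }
    ///     fn add_edge(&mut self, i: usize, j: usize) {
    ///         self.adj[i].insert(j);
    ///         self.adj[j].insert(i);
    ///     }
    ///     fn shared_neighbors(&self, i: usize, j: usize) -> usize {
    ///         self.adj[i].intersection(&self.adj[j]).count()
    ///     }
    ///     // Breadth-first search from `i`, looking for `j`.
    ///     fn connected(&self, i: usize, j: usize) -> bool {
    ///         let mut seen = vec![false; self.adj.len()];
    ///         let mut queue = VecDeque::from([i]);
    ///         seen[i] = true;
    ///         while let Some(u) = queue.pop_front() {
    ///             if u == j {
    ///                 return true;
    ///             }
    ///             for &v in &self.adj[u] {
    ///                 if !seen[v] {
    ///                     seen[v] = true;
    ///                     queue.push_back(v);
    ///                 }
    ///             }
    ///         }
    ///         false
    ///     }
    /// }
    ///
    /// fn context_bonus(g: &Graph, i: usize, j: usize, per_shared: f64, transitivity: f64) -> f64 {
    ///     let mut bonus = g.shared_neighbors(i, j) as f64 * per_shared;
    ///     if g.connected(i, j) {
    ///         bonus += transitivity;
    ///     }
    ///     bonus
    /// }
    ///
    /// fn main() {
    ///     // 0~2 and 1~2: one shared neighbor, and 0 and 1 are connected via 2.
    ///     let mut g = Graph::new(4);
    ///     g.add_edge(0, 2);
    ///     g.add_edge(1, 2);
    ///     assert_eq!(context_bonus(&g, 0, 1, 0.25, 0.5), 0.75); // 1 * 0.25 + 0.5
    ///     assert_eq!(context_bonus(&g, 0, 3, 0.25, 0.5), 0.0);
    /// }
    /// ```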
    fn graph_context_bonus(&self, i: usize, j: usize, prev_graph: &CorefGraph) -> f64 {
        let mut bonus = 0.0;

        // Bonus for shared neighbors (transitivity signal)
        let shared = prev_graph.shared_neighbors(i, j);
        bonus += (shared as f64) * self.config.per_shared_neighbor_bonus;

        // Bonus if already transitively connected
        if prev_graph.transitively_connected(i, j) {
            bonus += self.config.transitivity_bonus;
        }

        bonus
    }

    /// Convert graph clusters (connected components) into `CorefChain`s,
    /// dropping singletons unless `include_singletons` is set.
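    ///
    /// # Example
    ///
    /// `CorefGraph::extract_clusters` is not shown in this section. One
    /// standard way to compute connected components is union-find, and this
    /// self-contained sketch assumes that approach (it may not match the
    /// crate's implementation). It also applies the same singleton filter.
    /// `find` and `clusters` are hypothetical helpers.
    ///
    /// ```rust
    /// // Union-find root lookup with path halving.
    /// fn find(parent: &mut [usize], mut x: usize) -> usize {
    ///     while parent[x] != x {
    ///         parent[x] = parent[parent[x]];
    ///         x = parent[x];
    ///     }
    ///     x
    /// }
    ///
    /// // Connected components of an undirected edge list over `n` nodes, in
    /// // first-seen order, dropping singletons unless requested.
    /// fn clusters(n: usize, edges: &[(usize, usize)], include_singletons: bool) -> Vec<Vec<usize>> {
    ///     let mut parent: Vec<usize> = (0..n).collect();
    ///     for &(a, b) in edges {
    ///         let (ra, rb) = (find(&mut parent, a), find(&mut parent, b));
    ///         if ra != rb {
    ///             parent[ra] = rb;
    ///         }
    ///     }
    ///     // Map each root to an output slot the first time it is seen.
    ///     let mut slot: Vec<Option<usize>> = vec![None; n];
    ///     let mut out: Vec<Vec<usize>> = Vec::new();
    ///     for x in 0..n {
    ///         let r = find(&mut parent, x);
    ///         let idx = *slot[r].get_or_insert_with(|| {
    ///             out.push(Vec::new());
    ///             out.len() - 1
    ///         });
    ///         out[idx].push(x);
    ///     }
    ///     out.retain(|c| include_singletons || c.len() > 1);
    ///     out
    /// }
    ///
    /// fn main() {
    ///     // Edges 0-1 and 1-2 form one chain; 3 and 4 are singletons.
    ///     assert_eq!(clusters(5, &[(0, 1), (1, 2)], false), vec![vec![0, 1, 2]]);
    ///     assert_eq!(clusters(5, &[(0, 1), (1, 2)], true).len(), 3);
    /// }
    /// ```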
    fn graph_to_chains(&self, graph: &CorefGraph, mentions: &[&Mention]) -> Vec<CorefChain> {
        let clusters = graph.extract_clusters();

        clusters
            .into_iter()
            .filter(|cluster| self.config.include_singletons || cluster.len() > 1)
            .enumerate()
            .map(|(id, indices)| {
                let chain_mentions: Vec<Mention> = indices
                    .into_iter()
                    .map(|i| (*mentions[i]).clone())
                    .collect();

                let mut chain = CorefChain::new(chain_mentions);
                chain.cluster_id = Some((id as u64).into());

                // Entity type: taken from the first mention explicitly typed
                // as `Proper` that has one (inferred types are not considered).
                chain.entity_type = chain
                    .mentions
                    .iter()
                    .filter(|m| m.mention_type == Some(MentionType::Proper))
                    .find_map(|m| m.entity_type.clone());

                chain
            })
            .collect()
    }

    /// Get configuration.
    #[must_use]
    pub fn config(&self) -> &GraphCorefConfig {
        &self.config
    }
}

impl Default for GraphCoref {
    fn default() -> Self {
        Self::new()
    }
}

// =============================================================================
// Metrics and Diagnostics
// =============================================================================

/// Statistics from a graph coref run, for debugging and analysis.
///
/// Use [`GraphCoref::resolve_with_stats`] to get these alongside results.
#[derive(Debug, Clone, Default)]
pub struct GraphCorefStats {
    /// Number of refinement iterations actually run (at most `max_iterations`;
    /// 0 when there are no valid mentions).
    pub iterations: usize,
    /// Number of edges in the final graph.
    pub final_edges: usize,
    /// Number of clusters (including singletons).
    pub num_clusters: usize,
    /// Number of non-singleton clusters.
    pub num_chains: usize,
    /// Edge count after each iteration. Index 0 is the initial empty graph, so
    /// `edge_history.len() == iterations + 1` whenever any valid mention is given.
    pub edge_history: Vec<usize>,
    /// Whether a fixed point was reached (an iteration returned its input graph unchanged).
    pub converged: bool,
    /// Whether the loop stopped for a non-fixed-point reason (cycle or stagnation).
    pub early_stopped: bool,
    /// Cycle detected: a graph fingerprint repeated within the tracked history.
    pub cycle_detected: bool,
    /// Stagnation detected: the edge count stayed the same for
    /// `stagnation_patience` consecutive iterations without reaching a fixed point.
    pub stagnation_detected: bool,
}

impl GraphCoref {
    /// Resolve coreferences and return detailed statistics.
    ///
    /// Useful for debugging, tuning parameters, and understanding convergence.
    ///
    /// # Example
    ///
    /// ```rust
    /// use anno::backends::graph_coref::GraphCoref;
    /// use anno::Mention;
    ///
    /// let coref = GraphCoref::new();
    /// let mentions = vec![
    ///     Mention::new("John", 0, 4),
    ///     Mention::new("John", 50, 54),
    /// ];
    ///
    /// let (chains, stats) = coref.resolve_with_stats(&mentions);
    /// println!("Ran {} iterations (converged: {})", stats.iterations, stats.converged);
    /// println!("Edge history: {:?}", stats.edge_history);
    /// ```
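    ///
    /// # Stopping rules
    ///
    /// The loop in this method ends at a fixed point, on stagnation, on a
    /// repeated state, or at the iteration cap. The self-contained sketch below
    /// illustrates those rules over sorted edge lists. `run`, `Stop`, and the
    /// `DefaultHasher` fingerprint are assumptions for illustration, because the
    /// crate's `graph_fingerprint` is not shown here. The sketch keeps an
    /// unbounded history, whereas the real loop can cap it with `cycle_history`.
    /// As with any hash-based check, a fingerprint collision would be reported
    /// as a cycle.
    ///
    /// ```rust
    /// use std::collections::hash_map::DefaultHasher;
    /// use std::collections::HashSet;
    /// use std::hash::{Hash, Hasher};
    ///
    /// #[derive(Debug, PartialEq)]
    /// enum Stop {
    ///     FixedPoint,
    ///     Stagnation,
    ///     Cycle,
    ///     MaxIterations,
    /// }
    ///
    /// fn fingerprint(edges: &[(usize, usize)]) -> u64 {
    ///     let mut h = DefaultHasher::new();
    ///     edges.hash(&mut h);
    ///     h.finish()
    /// }
    ///
    /// // Apply `step` until one of the stopping rules fires; returns the rule
    /// // and the number of iterations run.
    /// fn run(
    ///     mut graph: Vec<(usize, usize)>,
    ///     step: impl Fn(&[(usize, usize)]) -> Vec<(usize, usize)>,
    ///     max_iterations: usize,
    ///     patience: usize,
    /// ) -> (Stop, usize) {
    ///     let mut seen = HashSet::from([fingerprint(&graph)]);
    ///     let mut stagnation = 0;
    ///     for it in 1..=max_iterations {
    ///         let next = step(graph.as_slice());
    ///         if next == graph {
    ///             return (Stop::FixedPoint, it);
    ///         }
    ///         // Graph changed but edge count did not.
    ///         stagnation = if next.len() == graph.len() { stagnation + 1 } else { 0 };
    ///         if patience > 0 && stagnation >= patience {
    ///             return (Stop::Stagnation, it);
    ///         }
    ///         // `insert` returns false if this state was seen before.
    ///         if !seen.insert(fingerprint(&next)) {
    ///             return (Stop::Cycle, it);
    ///         }
    ///         graph = next;
    ///     }
    ///     (Stop::MaxIterations, max_iterations)
    /// }
    ///
    /// fn main() {
    ///     // Adds edge (0, 1) once and then changes nothing.
    ///     let add_once = |g: &[(usize, usize)]| {
    ///         let mut next = g.to_vec();
    ///         if !next.contains(&(0, 1)) {
    ///             next.push((0, 1));
    ///         }
    ///         next
    ///     };
    ///     assert_eq!(run(vec![], add_once, 10, 0), (Stop::FixedPoint, 2));
    ///
    ///     // Flips between two different one-edge graphs.
    ///     let flip = |g: &[(usize, usize)]| if g == [(0, 1)] { vec![(1, 2)] } else { vec![(0, 1)] };
    ///     assert_eq!(run(vec![(0, 1)], flip, 10, 1), (Stop::Stagnation, 1));
    ///     assert_eq!(run(vec![(0, 1)], flip, 10, 0), (Stop::Cycle, 2));
    /// }
    /// ```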
    #[must_use]
    pub fn resolve_with_stats(&self, mentions: &[Mention]) -> (Vec<CorefChain>, GraphCorefStats) {
        let mut stats = GraphCorefStats::default();

        if mentions.is_empty() {
            return (vec![], stats);
        }

        // Drop empty or whitespace-only mentions and degenerate spans.
        let valid_mentions: Vec<&Mention> = mentions
            .iter()
            .filter(|m| !m.text.trim().is_empty() && m.start < m.end)
            .collect();

        if valid_mentions.is_empty() {
            return (vec![], stats);
        }

        // Iteration 0: one node per mention, no edges.
        let mut graph = CorefGraph::new(valid_mentions.len());
        stats.edge_history.push(0);

        // Early-stopping state (only used when `config.early_stop` is set).
        let mut stagnation: usize = 0;
        let mut last_edge_count = graph.edge_count();
        let mut seen: HashMap<u64, usize> = HashMap::new();
        let mut history: VecDeque<u64> = VecDeque::new();
        if let Some(cfg) = &self.config.early_stop {
            if cfg.detect_cycles {
                let fp0 = Self::graph_fingerprint(&graph);
                seen.insert(fp0, 0);
                history.push_back(fp0);
            }
        }

        for iteration in 0..self.config.max_iterations {
            let prev_graph = graph.clone();
            graph = self.refine_iteration(&valid_mentions, &graph);
            stats.edge_history.push(graph.edge_count());
            stats.iterations = iteration + 1;

            // Fixed point: this iteration reproduced its input graph.
            if graph == prev_graph {
                stats.converged = true;
                break;
            }

            if let Some(cfg) = &self.config.early_stop {
                // Stagnation: the graph still changes, but its edge count has
                // not changed for `stagnation_patience` consecutive iterations.
                let ec = graph.edge_count();
                if ec == last_edge_count {
                    stagnation += 1;
                } else {
                    stagnation = 0;
                    last_edge_count = ec;
                }
                if cfg.stagnation_patience > 0 && stagnation >= cfg.stagnation_patience {
                    stats.early_stopped = true;
                    stats.stagnation_detected = true;
                    break;
                }

                // Cycle: the current graph's fingerprint matches one of the
                // last `cycle_history` states (all previous states if 0).
                if cfg.detect_cycles {
                    let fp = Self::graph_fingerprint(&graph);
                    if seen.contains_key(&fp) {
                        stats.early_stopped = true;
                        stats.cycle_detected = true;
                        break;
                    }
                    seen.insert(fp, iteration + 1);
                    history.push_back(fp);
                    if cfg.cycle_history > 0 {
                        while history.len() > cfg.cycle_history {
                            if let Some(old) = history.pop_front() {
                                seen.remove(&old);
                            }
                        }
                    }
                }
            }
        }

        let clusters = graph.extract_clusters();
        stats.final_edges = graph.edge_count();
        stats.num_clusters = clusters.len();
        stats.num_chains = clusters.iter().filter(|c| c.len() > 1).count();

        let chains = self.graph_to_chains(&graph, &valid_mentions);
        (chains, stats)
    }
}

// =============================================================================
// Evaluation Helpers
// =============================================================================

/// Wrap `GraphCoref` output in a `CorefDocument` for evaluation.
///
/// The resulting document can be scored with the standard coreference
/// metrics used in CoNLL-style evaluation (MUC, B³, CEAF, LEA).
///
/// # Example
///
/// ```rust
/// use anno::backends::graph_coref::{GraphCoref, chains_to_document};
/// use anno::Mention;
///
/// let coref = GraphCoref::new();
/// let mentions = vec![
///     Mention::new("John", 0, 4),
///     Mention::new("He", 19, 21),
/// ];
///
/// let chains = coref.resolve(&mentions);
/// let doc = chains_to_document("John went to work. He was late.", chains);
/// ```
pub fn chains_to_document(
    text: impl Into<String>,
    chains: Vec<CorefChain>,
) -> anno_core::CorefDocument {
    anno_core::CorefDocument::new(text, chains)
}

// =============================================================================
// Tests
// =============================================================================

#[cfg(test)]
mod tests {
    use super::*;

    fn make_mention(text: &str, start: usize) -> Mention {
        Mention::new(text, start, start + text.chars().count())
    }

    fn make_typed_mention(text: &str, start: usize, mention_type: MentionType) -> Mention {
        Mention::with_type(text, start, start + text.chars().count(), mention_type)
    }

    // -------------------------------------------------------------------------
    // Basic functionality
    // -------------------------------------------------------------------------

    #[test]
    fn test_empty_input() {
        let coref = GraphCoref::new();
        let chains = coref.resolve(&[]);
        assert!(chains.is_empty());
    }

    #[test]
    fn test_single_mention() {
        let coref = GraphCoref::new();
        let mentions = vec![make_mention("John", 0)];
        let chains = coref.resolve(&mentions);
        assert!(
            chains.is_empty(),
            "Single mention should be filtered as singleton"
        );
    }

    #[test]
    fn test_single_mention_with_singletons() {
        let config = GraphCorefConfig {
            include_singletons: true,
            ..Default::default()
        };
        let coref = GraphCoref::with_config(config);
        let mentions = vec![make_mention("John", 0)];
        let chains = coref.resolve(&mentions);
        assert_eq!(chains.len(), 1, "Should include singleton when configured");
    }

    #[test]
    fn test_exact_match_linking() {
        let coref = GraphCoref::new();
        let mentions = vec![make_mention("John", 0), make_mention("John", 50)];

        let chains = coref.resolve(&mentions);
        assert_eq!(chains.len(), 1);
        assert_eq!(chains[0].mentions.len(), 2);
    }

    #[test]
    fn test_substring_linking() {
        let config = GraphCorefConfig {
            link_threshold: 0.4,
            ..Default::default()
        };
        let coref = GraphCoref::with_config(config);
        let mentions = vec![make_mention("Marie Curie", 0), make_mention("Curie", 50)];

        let chains = coref.resolve(&mentions);
        assert_eq!(chains.len(), 1);
        assert_eq!(chains[0].mentions.len(), 2);
    }

    // -------------------------------------------------------------------------
    // MentionType usage
    // -------------------------------------------------------------------------

    #[test]
    fn test_typed_pronoun_linking() {
        let config = GraphCorefConfig {
            link_threshold: 0.2,
            distance_weight: 0.0, // Disable distance penalty to isolate type signal
            ..Default::default()
        };
        let coref = GraphCoref::with_config(config);

        let mentions = vec![
            make_typed_mention("Marie", 0, MentionType::Proper),
            make_typed_mention("she", 20, MentionType::Pronominal),
        ];

        let chains = coref.resolve(&mentions);
        assert_eq!(chains.len(), 1, "Typed pronoun should link to proper noun");
    }

    #[test]
    fn test_inferred_pronoun_detection() {
        let coref = GraphCoref::new();

        // Create mention without type - should be inferred
        let he = make_mention("he", 0);
        assert_eq!(
            coref.infer_mention_type(&he),
            MentionType::Pronominal,
            "Should detect 'he' as pronoun"
        );

        let john = make_mention("John", 0);
        assert_eq!(
            coref.infer_mention_type(&john),
            MentionType::Proper,
            "Should detect 'John' as proper"
        );

        let dog = make_mention("the dog", 0);
        assert_eq!(
            coref.infer_mention_type(&dog),
            MentionType::Nominal,
            "Should detect 'the dog' as nominal"
        );
    }

    // -------------------------------------------------------------------------
    // Transitivity and graph refinement
    // -------------------------------------------------------------------------

    #[test]
    fn test_transitivity() {
        let config = GraphCorefConfig {
            max_iterations: 4,
            link_threshold: 0.3,
            transitivity_bonus: 0.3,
            per_shared_neighbor_bonus: 0.2,
            ..Default::default()
        };
        let coref = GraphCoref::with_config(config);

        let mentions = vec![
            make_mention("John Smith", 0),
            make_mention("Smith", 30),
            make_mention("John Smith", 60),
        ];

        let chains = coref.resolve(&mentions);
        assert_eq!(chains.len(), 1);
        assert_eq!(chains[0].mentions.len(), 3);
    }

    #[test]
    fn test_convergence() {
        let coref = GraphCoref::new();
        let max_iterations = coref.config().max_iterations;
        let mentions = vec![
            make_mention("Apple", 0),
            make_mention("Apple", 50),
            make_mention("Microsoft", 100),
        ];

        let (chains, stats) = coref.resolve_with_stats(&mentions);

        // The loop ends at a fixed point, via early stopping, or at the cap.
        assert!(stats.iterations <= max_iterations);
        assert!(stats.converged || stats.early_stopped || stats.iterations == max_iterations);
        assert_eq!(chains.len(), 1);
        assert_eq!(stats.num_chains, 1);
    }

    // -------------------------------------------------------------------------
    // CorefGraph tests
    // -------------------------------------------------------------------------

    #[test]
    fn test_coref_graph_basics() {
        let mut graph = CorefGraph::new(5);

        graph.add_edge(0, 1);
        graph.add_edge(1, 2);

        assert!(graph.has_edge(0, 1));
        assert!(graph.has_edge(1, 0)); // Symmetric
        assert!(graph.has_edge(1, 2));
        assert!(!graph.has_edge(0, 2));

        assert!(graph.transitively_connected(0, 2));

        let clusters = graph.extract_clusters();
        assert_eq!(clusters.len(), 3); // {0,1,2}, {3}, {4}

        let main_cluster = clusters.iter().find(|c| c.len() == 3).unwrap();
        assert!(main_cluster.contains(&0));
        assert!(main_cluster.contains(&1));
        assert!(main_cluster.contains(&2));
    }

    #[test]
    fn test_shared_neighbors() {
        let mut graph = CorefGraph::new(4);
        graph.add_edge(0, 2);
        graph.add_edge(1, 2);

        assert_eq!(graph.shared_neighbors(0, 1), 1);
        assert_eq!(graph.shared_neighbors(0, 3), 0);
    }

    #[test]
    fn test_graph_self_loop_ignored() {
        let mut graph = CorefGraph::new(3);
        graph.add_edge(0, 0); // Self-loop
        assert!(!graph.has_edge(0, 0));
        assert_eq!(graph.edge_count(), 0);
    }

    #[test]
    fn test_graph_out_of_bounds_ignored() {
        let mut graph = CorefGraph::new(3);
        graph.add_edge(0, 10); // Out of bounds
        assert_eq!(graph.edge_count(), 0);
    }

    // -------------------------------------------------------------------------
    // Edge cases
    // -------------------------------------------------------------------------

    #[test]
    fn test_empty_mention_filtered() {
        let coref = GraphCoref::new();
        let mentions = vec![
            make_mention("John", 0),
            Mention::new("", 10, 10),    // Empty
            Mention::new("   ", 20, 23), // Whitespace only
            make_mention("John", 50),
        ];

        let chains = coref.resolve(&mentions);
        assert_eq!(chains.len(), 1);
        assert_eq!(chains[0].mentions.len(), 2);
    }

    #[test]
    fn test_distance_filter() {
        let config = GraphCorefConfig {
            max_distance: Some(100),
            ..Default::default()
        };
        let coref = GraphCoref::with_config(config);

        let mentions = vec![make_mention("John", 0), make_mention("John", 200)];

        let chains = coref.resolve(&mentions);
        assert!(chains.is_empty());
    }

    #[test]
    fn test_stats_edge_history() {
        let coref = GraphCoref::new();
        let mentions = vec![
            make_mention("A", 0),
            make_mention("A", 10),
            make_mention("A", 20),
        ];

        let (_, stats) = coref.resolve_with_stats(&mentions);

        assert!(!stats.edge_history.is_empty());
        assert_eq!(stats.edge_history[0], 0);
        // One entry for the initial empty graph plus one per iteration.
        assert_eq!(stats.edge_history.len(), stats.iterations + 1);
    }

    // -------------------------------------------------------------------------
    // Unicode / multilingual
    // -------------------------------------------------------------------------

    #[test]
    fn test_unicode_cjk() {
        let coref = GraphCoref::new();
        let mentions = vec![
            make_mention("北京", 0),
            make_mention("北京", 20),
            make_mention("東京", 40),
        ];

        let chains = coref.resolve(&mentions);
        assert_eq!(chains.len(), 1);
        assert!(chains[0].mentions.iter().all(|m| m.text == "北京"));
    }

    #[test]
    fn test_unicode_diacritics() {
        let coref = GraphCoref::new();
        let mentions = vec![make_mention("François", 0), make_mention("François", 50)];

        let chains = coref.resolve(&mentions);
        assert_eq!(chains.len(), 1);
    }

    #[test]
    fn test_unicode_arabic_rtl() {
        let coref = GraphCoref::new();
        // Arabic: "Muhammad" repeated
        let mentions = vec![make_mention("محمد", 0), make_mention("محمد", 20)];

        let chains = coref.resolve(&mentions);
        assert_eq!(chains.len(), 1);
    }

    // -------------------------------------------------------------------------
    // Evaluation helper
    // -------------------------------------------------------------------------

    #[test]
    fn test_chains_to_document() {
        // "He" starts at character offset 16 in the text below.
        let chain = CorefChain::new(vec![make_mention("John", 0), make_mention("He", 16)]);

        let doc = chains_to_document("John went home. He slept.", vec![chain]);

        assert_eq!(doc.chain_count(), 1);
        assert_eq!(doc.mention_count(), 2);
    }

    // -------------------------------------------------------------------------
    // Co-occurrence seeding (SpanEIT-inspired)
    // -------------------------------------------------------------------------

    #[test]
    fn test_cooccurrence_seeding_basic() {
        let mut graph = CorefGraph::new(3);
        let positions = vec![0, 50, 200]; // Character offsets

        // Window of 100: should connect 0-1 (distance 50) but not 0-2 (distance 200)
        graph.seed_cooccurrence_edges(&positions, 100, None::<fn(usize, usize) -> bool>);

        assert!(graph.has_edge(0, 1), "Close mentions should be connected");
        assert!(
            !graph.has_edge(0, 2),
            "Distant mentions should not be connected"
        );
        assert!(
            !graph.has_edge(1, 2),
            "Distant mentions should not be connected"
        );
    }

    #[test]
    fn test_cooccurrence_seeding_with_scorer() {
        let mut graph = CorefGraph::new(3);
        let positions = vec![0, 50, 80];

        // Custom scorer: only connect if both indices are even
        let scorer = |i: usize, j: usize| i.is_multiple_of(2) && j.is_multiple_of(2);
        graph.seed_cooccurrence_edges(&positions, 100, Some(scorer));

        assert!(graph.has_edge(0, 2), "0 and 2 are both even");
        assert!(!graph.has_edge(0, 1), "1 is odd");
        assert!(!graph.has_edge(1, 2), "1 is odd");
    }

    #[test]
    fn test_cooccurrence_seeding_empty() {
        let mut graph = CorefGraph::new(3);
        let positions: Vec<usize> = vec![];

        graph.seed_cooccurrence_edges(&positions, 100, None::<fn(usize, usize) -> bool>);

        assert!(graph.is_empty(), "Empty positions should create no edges");
    }
}