Skip to main content

brainwires_seal/
coreference.rs

//! Coreference Resolution for Multi-Turn Conversations
//!
//! Resolves anaphoric references (pronouns, definite NPs, ellipsis) to concrete
//! entities from the conversation history. This is critical for understanding
//! queries like "fix it", "update the file", or "what does that function do?"
//!
//! ## Approach
//!
//! Uses a salience-based ranking algorithm that considers:
//! - **Recency**: More recently mentioned entities score higher
//! - **Frequency**: Entities mentioned multiple times score higher
//! - **Graph centrality**: Important entities in the relationship graph score higher
//! - **Type matching**: Entity type compatibility with the reference
//! - **Syntactic prominence**: Subjects score higher than objects
//!
//! ## Example
//!
//! ```rust,ignore
//! let resolver = CoreferenceResolver::new();
//! let mut dialog_state = DialogState::new();
//!
//! // After discussing "main.rs"
//! dialog_state.mention_entity("main.rs", EntityType::File);
//!
//! let refs = resolver.detect_references("Fix it and run the tests");
//! // refs[0] = UnresolvedReference { text: "it", ref_type: SingularNeutral, .. }
//!
//! let resolved = resolver.resolve(&refs, &dialog_state, &entity_store, None);
//! // resolved[0].antecedent = "main.rs"
//! ```
31
32use brainwires_core::graph::{EntityStoreT, EntityType, RelationshipGraphT};
33use regex::Regex;
34use std::collections::HashMap;
35use std::sync::LazyLock;
36
37// --- LazyLock regex statics for coreference pattern detection ---
38
39// Pronoun patterns
40static RE_SINGULAR_NEUTRAL: LazyLock<Regex> =
41    LazyLock::new(|| Regex::new(r"\b(it|this|that)\b").expect("valid regex"));
42static RE_PLURAL: LazyLock<Regex> =
43    LazyLock::new(|| Regex::new(r"\b(they|them|those|these)\b").expect("valid regex"));
44
45// Definite NP patterns
46static RE_THE_FILE: LazyLock<Regex> =
47    LazyLock::new(|| Regex::new(r"\bthe\s+(file|files)\b").expect("valid regex"));
48static RE_THE_FUNCTION: LazyLock<Regex> =
49    LazyLock::new(|| Regex::new(r"\bthe\s+(function|method|fn)\b").expect("valid regex"));
50static RE_THE_TYPE: LazyLock<Regex> = LazyLock::new(|| {
51    Regex::new(r"\bthe\s+(type|struct|class|enum|interface)\b").expect("valid regex")
52});
53static RE_THE_ERROR: LazyLock<Regex> =
54    LazyLock::new(|| Regex::new(r"\bthe\s+(error|bug|issue)\b").expect("valid regex"));
55static RE_THE_VARIABLE: LazyLock<Regex> =
56    LazyLock::new(|| Regex::new(r"\bthe\s+(variable|var|const|let)\b").expect("valid regex"));
57static RE_THE_COMMAND: LazyLock<Regex> =
58    LazyLock::new(|| Regex::new(r"\bthe\s+(command|cmd)\b").expect("valid regex"));
59
60// Demonstrative patterns
61static RE_DEMO_FILE: LazyLock<Regex> =
62    LazyLock::new(|| Regex::new(r"\b(that|this)\s+(file)\b").expect("valid regex"));
63static RE_DEMO_FUNCTION: LazyLock<Regex> =
64    LazyLock::new(|| Regex::new(r"\b(that|this)\s+(function|method|fn)\b").expect("valid regex"));
65static RE_DEMO_TYPE: LazyLock<Regex> = LazyLock::new(|| {
66    Regex::new(r"\b(that|this)\s+(type|struct|class|enum)\b").expect("valid regex")
67});
68static RE_DEMO_ERROR: LazyLock<Regex> =
69    LazyLock::new(|| Regex::new(r"\b(that|this)\s+(error|bug|issue)\b").expect("valid regex"));
70
71/// Types of anaphoric references we can detect
72#[derive(Debug, Clone, PartialEq)]
73pub enum ReferenceType {
74    /// Singular neutral pronouns: "it", "this", "that"
75    SingularNeutral,
76    /// Plural pronouns: "they", "them", "those", "these"
77    Plural,
78    /// Definite noun phrase with entity type: "the file", "the function"
79    DefiniteNP {
80        /// The entity type referenced by the noun phrase.
81        entity_type: EntityType,
82    },
83    /// Demonstrative with entity type: "that error", "this type"
84    Demonstrative {
85        /// The entity type referenced by the demonstrative.
86        entity_type: EntityType,
87    },
88    /// Missing subject from context (implied reference)
89    Ellipsis,
90}
91
92impl ReferenceType {
93    /// Get compatible entity types for this reference type
94    pub fn compatible_types(&self) -> Vec<EntityType> {
95        match self {
96            ReferenceType::SingularNeutral => vec![
97                EntityType::File,
98                EntityType::Function,
99                EntityType::Type,
100                EntityType::Variable,
101                EntityType::Error,
102                EntityType::Concept,
103                EntityType::Command,
104            ],
105            ReferenceType::Plural => vec![
106                EntityType::File,
107                EntityType::Function,
108                EntityType::Type,
109                EntityType::Variable,
110                EntityType::Error,
111            ],
112            ReferenceType::DefiniteNP { entity_type } => vec![entity_type.clone()],
113            ReferenceType::Demonstrative { entity_type } => vec![entity_type.clone()],
114            ReferenceType::Ellipsis => vec![
115                EntityType::File,
116                EntityType::Function,
117                EntityType::Type,
118                EntityType::Command,
119            ],
120        }
121    }
122}
123
124/// An unresolved reference detected in user input
125#[derive(Debug, Clone)]
126pub struct UnresolvedReference {
127    /// The text of the reference (e.g., "it", "the file")
128    pub text: String,
129    /// Type of reference
130    pub ref_type: ReferenceType,
131    /// Character offset in the original message
132    pub start: usize,
133    /// Character offset end
134    pub end: usize,
135}
136
137/// A resolved reference with its antecedent
138#[derive(Debug, Clone)]
139pub struct ResolvedReference {
140    /// The original reference
141    pub reference: UnresolvedReference,
142    /// The resolved entity name
143    pub antecedent: String,
144    /// Entity type of the antecedent
145    pub entity_type: EntityType,
146    /// Confidence score (0.0 - 1.0)
147    pub confidence: f32,
148    /// Salience breakdown for debugging
149    pub salience: SalienceScore,
150}
151
/// The individual factors used to rank antecedent candidates.
#[derive(Debug, Clone, Default)]
pub struct SalienceScore {
    /// Recency of the last mention (0.0 - 1.0); weighted 0.35.
    pub recency: f32,
    /// Mention frequency (0.0 - 1.0); weighted 0.15.
    pub frequency: f32,
    /// Importance within the relationship graph (0.0 - 1.0); weighted 0.20.
    pub graph_centrality: f32,
    /// Type compatibility (0.0 or 1.0); weighted 0.20.
    pub type_match: f32,
    /// Bonus for subject position (0.0 - 1.0); weighted 0.10.
    pub syntactic_prominence: f32,
}

impl SalienceScore {
    /// Returns the weighted sum of all factors.
    pub fn total(&self) -> f32 {
        let weighted_recency = self.recency * 0.35;
        let weighted_frequency = self.frequency * 0.15;
        let weighted_centrality = self.graph_centrality * 0.20;
        let weighted_type_match = self.type_match * 0.20;
        let weighted_prominence = self.syntactic_prominence * 0.10;

        // Summed left to right in field order.
        weighted_recency
            + weighted_frequency
            + weighted_centrality
            + weighted_type_match
            + weighted_prominence
    }
}
177
178/// Dialog state for tracking entities across conversation turns
179#[derive(Debug, Clone, Default)]
180pub struct DialogState {
181    /// Stack of entities in current focus (most recent first)
182    pub focus_stack: Vec<String>,
183    /// Entity name -> turn number when mentioned
184    pub mention_history: HashMap<String, Vec<u32>>,
185    /// Current turn number
186    pub current_turn: u32,
187    /// Entities that were recently modified/acted upon
188    pub recently_modified: Vec<String>,
189    /// Entity type cache
190    entity_types: HashMap<String, EntityType>,
191}
192
193impl DialogState {
194    /// Create a new dialog state
195    pub fn new() -> Self {
196        Self::default()
197    }
198
199    /// Advance to the next turn
200    pub fn next_turn(&mut self) {
201        self.current_turn += 1;
202    }
203
204    /// Record a mention of an entity
205    pub fn mention_entity(&mut self, name: &str, entity_type: EntityType) {
206        // Add to focus stack (remove if already present, add to front)
207        self.focus_stack.retain(|n| n != name);
208        self.focus_stack.insert(0, name.to_string());
209
210        // Limit focus stack size
211        if self.focus_stack.len() > 20 {
212            self.focus_stack.truncate(20);
213        }
214
215        // Record mention with turn number
216        self.mention_history
217            .entry(name.to_string())
218            .or_default()
219            .push(self.current_turn);
220
221        // Cache entity type
222        self.entity_types.insert(name.to_string(), entity_type);
223    }
224
225    /// Mark an entity as recently modified
226    pub fn mark_modified(&mut self, name: &str) {
227        self.recently_modified.retain(|n| n != name);
228        self.recently_modified.insert(0, name.to_string());
229
230        // Limit modified list size
231        if self.recently_modified.len() > 10 {
232            self.recently_modified.truncate(10);
233        }
234    }
235
236    /// Get the entity type for a name (if known)
237    pub fn get_entity_type(&self, name: &str) -> Option<&EntityType> {
238        self.entity_types.get(name)
239    }
240
241    /// Get recency score for an entity (1.0 for most recent, decays with age)
242    pub fn recency_score(&self, name: &str) -> f32 {
243        // Check focus stack position
244        if let Some(pos) = self.focus_stack.iter().position(|n| n == name) {
245            let focus_score = 1.0 - (pos as f32 / self.focus_stack.len() as f32);
246
247            // Bonus for recently modified
248            let modified_bonus = if self.recently_modified.contains(&name.to_string()) {
249                0.2
250            } else {
251                0.0
252            };
253
254            (focus_score + modified_bonus).min(1.0)
255        } else {
256            // Check mention history
257            if let Some(turns) = self.mention_history.get(name) {
258                if let Some(&last_turn) = turns.last() {
259                    let age = self.current_turn.saturating_sub(last_turn) as f32;
260                    (-0.1 * age).exp() // Exponential decay
261                } else {
262                    0.0
263                }
264            } else {
265                0.0
266            }
267        }
268    }
269
270    /// Get frequency score for an entity
271    pub fn frequency_score(&self, name: &str) -> f32 {
272        if let Some(turns) = self.mention_history.get(name) {
273            let count = turns.len() as f32;
274            // Logarithmic scaling to prevent domination by very frequent entities
275            (count.ln_1p() / 3.0).min(1.0)
276        } else {
277            0.0
278        }
279    }
280
281    /// Clear the state for a new conversation
282    pub fn clear(&mut self) {
283        self.focus_stack.clear();
284        self.mention_history.clear();
285        self.current_turn = 0;
286        self.recently_modified.clear();
287        self.entity_types.clear();
288    }
289}
290
291/// Pattern definition for detecting references
292struct ReferencePattern {
293    regex: &'static Regex,
294    ref_type_fn: fn(&regex::Captures) -> ReferenceType,
295}
296
297/// Coreference resolver for multi-turn conversations
298pub struct CoreferenceResolver {
299    /// Patterns for detecting pronoun references
300    pronoun_patterns: Vec<ReferencePattern>,
301    /// Patterns for detecting definite NP references
302    definite_np_patterns: Vec<ReferencePattern>,
303    /// Patterns for detecting demonstrative references
304    demonstrative_patterns: Vec<ReferencePattern>,
305}
306
307impl CoreferenceResolver {
308    /// Create a new coreference resolver
309    pub fn new() -> Self {
310        Self {
311            pronoun_patterns: Self::build_pronoun_patterns(),
312            definite_np_patterns: Self::build_definite_np_patterns(),
313            demonstrative_patterns: Self::build_demonstrative_patterns(),
314        }
315    }
316
317    fn build_pronoun_patterns() -> Vec<ReferencePattern> {
318        vec![
319            ReferencePattern {
320                regex: &RE_SINGULAR_NEUTRAL,
321                ref_type_fn: |_| ReferenceType::SingularNeutral,
322            },
323            ReferencePattern {
324                regex: &RE_PLURAL,
325                ref_type_fn: |_| ReferenceType::Plural,
326            },
327        ]
328    }
329
330    fn build_definite_np_patterns() -> Vec<ReferencePattern> {
331        vec![
332            ReferencePattern {
333                regex: &RE_THE_FILE,
334                ref_type_fn: |_| ReferenceType::DefiniteNP {
335                    entity_type: EntityType::File,
336                },
337            },
338            ReferencePattern {
339                regex: &RE_THE_FUNCTION,
340                ref_type_fn: |_| ReferenceType::DefiniteNP {
341                    entity_type: EntityType::Function,
342                },
343            },
344            ReferencePattern {
345                regex: &RE_THE_TYPE,
346                ref_type_fn: |_| ReferenceType::DefiniteNP {
347                    entity_type: EntityType::Type,
348                },
349            },
350            ReferencePattern {
351                regex: &RE_THE_ERROR,
352                ref_type_fn: |_| ReferenceType::DefiniteNP {
353                    entity_type: EntityType::Error,
354                },
355            },
356            ReferencePattern {
357                regex: &RE_THE_VARIABLE,
358                ref_type_fn: |_| ReferenceType::DefiniteNP {
359                    entity_type: EntityType::Variable,
360                },
361            },
362            ReferencePattern {
363                regex: &RE_THE_COMMAND,
364                ref_type_fn: |_| ReferenceType::DefiniteNP {
365                    entity_type: EntityType::Command,
366                },
367            },
368        ]
369    }
370
371    fn build_demonstrative_patterns() -> Vec<ReferencePattern> {
372        vec![
373            ReferencePattern {
374                regex: &RE_DEMO_FILE,
375                ref_type_fn: |_| ReferenceType::Demonstrative {
376                    entity_type: EntityType::File,
377                },
378            },
379            ReferencePattern {
380                regex: &RE_DEMO_FUNCTION,
381                ref_type_fn: |_| ReferenceType::Demonstrative {
382                    entity_type: EntityType::Function,
383                },
384            },
385            ReferencePattern {
386                regex: &RE_DEMO_TYPE,
387                ref_type_fn: |_| ReferenceType::Demonstrative {
388                    entity_type: EntityType::Type,
389                },
390            },
391            ReferencePattern {
392                regex: &RE_DEMO_ERROR,
393                ref_type_fn: |_| ReferenceType::Demonstrative {
394                    entity_type: EntityType::Error,
395                },
396            },
397        ]
398    }
399
400    /// Detect unresolved references in a message
401    pub fn detect_references(&self, message: &str) -> Vec<UnresolvedReference> {
402        let mut references = Vec::new();
403        let lower = message.to_lowercase();
404
405        // Check demonstratives first (they're more specific)
406        for pattern in &self.demonstrative_patterns {
407            for cap in pattern.regex.captures_iter(&lower) {
408                if let Some(m) = cap.get(0) {
409                    references.push(UnresolvedReference {
410                        text: m.as_str().to_string(),
411                        ref_type: (pattern.ref_type_fn)(&cap),
412                        start: m.start(),
413                        end: m.end(),
414                    });
415                }
416            }
417        }
418
419        // Check definite NPs
420        for pattern in &self.definite_np_patterns {
421            for cap in pattern.regex.captures_iter(&lower) {
422                if let Some(m) = cap.get(0) {
423                    // Skip if already covered by a demonstrative
424                    let overlaps = references
425                        .iter()
426                        .any(|r| r.start <= m.start() && r.end >= m.end());
427                    if !overlaps {
428                        references.push(UnresolvedReference {
429                            text: m.as_str().to_string(),
430                            ref_type: (pattern.ref_type_fn)(&cap),
431                            start: m.start(),
432                            end: m.end(),
433                        });
434                    }
435                }
436            }
437        }
438
439        // Check pronouns last
440        for pattern in &self.pronoun_patterns {
441            for cap in pattern.regex.captures_iter(&lower) {
442                if let Some(m) = cap.get(0) {
443                    // Skip if already covered
444                    let overlaps = references
445                        .iter()
446                        .any(|r| r.start <= m.start() && r.end >= m.end());
447                    if !overlaps {
448                        references.push(UnresolvedReference {
449                            text: m.as_str().to_string(),
450                            ref_type: (pattern.ref_type_fn)(&cap),
451                            start: m.start(),
452                            end: m.end(),
453                        });
454                    }
455                }
456            }
457        }
458
459        // Sort by position
460        references.sort_by_key(|r| r.start);
461        references
462    }
463
464    /// Resolve references using dialog state and entity store
465    pub fn resolve(
466        &self,
467        references: &[UnresolvedReference],
468        dialog_state: &DialogState,
469        entity_store: &dyn EntityStoreT,
470        graph: Option<&dyn RelationshipGraphT>,
471    ) -> Vec<ResolvedReference> {
472        let mut resolved = Vec::new();
473
474        for reference in references {
475            if let Some(resolution) =
476                self.resolve_single(reference, dialog_state, entity_store, graph)
477            {
478                resolved.push(resolution);
479            }
480        }
481
482        resolved
483    }
484
485    /// Resolve a single reference
486    fn resolve_single(
487        &self,
488        reference: &UnresolvedReference,
489        dialog_state: &DialogState,
490        entity_store: &dyn EntityStoreT,
491        graph: Option<&dyn RelationshipGraphT>,
492    ) -> Option<ResolvedReference> {
493        let compatible_types = reference.ref_type.compatible_types();
494
495        // Gather candidates from dialog state and entity store
496        let mut candidates: Vec<(&str, &EntityType, SalienceScore)> = Vec::new();
497
498        // Check focus stack first (most likely candidates)
499        for name in &dialog_state.focus_stack {
500            if let Some(entity_type) = dialog_state.get_entity_type(name)
501                && compatible_types.contains(entity_type)
502            {
503                let salience = self.compute_salience(name, entity_type, dialog_state, graph);
504                candidates.push((name, entity_type, salience));
505            }
506        }
507
508        // Also check entity store for additional candidates
509        let entity_names: Vec<(String, EntityType)> = compatible_types
510            .iter()
511            .flat_map(|et| {
512                entity_store
513                    .entity_names_by_type(et)
514                    .into_iter()
515                    .map(move |name| (name, et.clone()))
516            })
517            .collect();
518
519        for (entity_name, entity_type) in &entity_names {
520            // Skip if already in candidates
521            if candidates
522                .iter()
523                .any(|(n, _, _)| *n == entity_name.as_str())
524            {
525                continue;
526            }
527
528            let salience = self.compute_salience(entity_name, entity_type, dialog_state, graph);
529            candidates.push((entity_name, entity_type, salience));
530        }
531
532        // Sort by salience score
533        candidates.sort_by(|a, b| {
534            b.2.total()
535                .partial_cmp(&a.2.total())
536                .unwrap_or(std::cmp::Ordering::Equal)
537        });
538
539        // Take the best candidate
540        candidates
541            .first()
542            .map(|(name, entity_type, salience)| ResolvedReference {
543                reference: reference.clone(),
544                antecedent: name.to_string(),
545                entity_type: (*entity_type).clone(),
546                confidence: salience.total(),
547                salience: salience.clone(),
548            })
549    }
550
551    /// Compute salience score for a candidate antecedent
552    fn compute_salience(
553        &self,
554        name: &str,
555        _entity_type: &EntityType,
556        dialog_state: &DialogState,
557        graph: Option<&dyn RelationshipGraphT>,
558    ) -> SalienceScore {
559        let recency = dialog_state.recency_score(name);
560        let frequency = dialog_state.frequency_score(name);
561
562        let graph_centrality = if let Some(g) = graph {
563            if let Some(node) = g.get_node(name) {
564                node.importance
565            } else {
566                0.0
567            }
568        } else {
569            0.5 // Neutral if no graph available
570        };
571
572        // Type match is handled at the candidate selection stage
573        let type_match = 1.0;
574
575        // Syntactic prominence - subjects in focus stack get bonus
576        let syntactic_prominence = if dialog_state.focus_stack.first() == Some(&name.to_string()) {
577            1.0
578        } else if dialog_state.focus_stack.contains(&name.to_string()) {
579            0.5
580        } else {
581            0.0
582        };
583
584        SalienceScore {
585            recency,
586            frequency,
587            graph_centrality,
588            type_match,
589            syntactic_prominence,
590        }
591    }
592
593    /// Rewrite message with resolved references
594    pub fn rewrite_with_resolutions(
595        &self,
596        message: &str,
597        resolutions: &[ResolvedReference],
598    ) -> String {
599        if resolutions.is_empty() {
600            return message.to_string();
601        }
602
603        // Sort resolutions by position (descending) to replace from end first
604        let mut sorted = resolutions.to_vec();
605        sorted.sort_by(|a, b| b.reference.start.cmp(&a.reference.start));
606
607        let mut result = message.to_string();
608        let lower = message.to_lowercase();
609
610        for resolution in sorted {
611            // Find the actual position in the original string
612            // (accounting for case differences)
613            let search_start = resolution.reference.start;
614            let search_end = resolution.reference.end;
615
616            if search_end <= lower.len() && search_start < search_end {
617                // Create the replacement with bracket notation
618                let replacement = format!("[{}]", resolution.antecedent);
619
620                // Replace in the result string
621                // We need to find the corresponding position in the (possibly modified) result
622                let ref_text = &lower[search_start..search_end];
623                if let Some(pos) = result.to_lowercase().find(ref_text) {
624                    result = format!(
625                        "{}{}{}",
626                        &result[..pos],
627                        replacement,
628                        &result[pos + (search_end - search_start)..]
629                    );
630                }
631            }
632        }
633
634        result
635    }
636}
637
638impl Default for CoreferenceResolver {
639    fn default() -> Self {
640        Self::new()
641    }
642}
643
#[cfg(test)]
mod tests {
    use super::*;
    use brainwires_knowledge::knowledge::EntityStore;

    /// A bare pronoun is detected and classified as singular neutral.
    #[test]
    fn test_detect_pronouns() {
        let resolver = CoreferenceResolver::new();
        let refs = resolver.detect_references("Fix it and run the tests");

        assert!(!refs.is_empty());
        assert!(refs.iter().any(|r| r.text == "it"));
        assert_eq!(refs[0].ref_type, ReferenceType::SingularNeutral);
    }

    /// "the file" is detected as a definite NP constrained to files.
    #[test]
    fn test_detect_definite_np() {
        let resolver = CoreferenceResolver::new();
        let refs = resolver.detect_references("Update the file with the new logic");

        assert!(refs.iter().any(|r| r.text == "the file"));
        assert!(refs.iter().any(|r| matches!(
            &r.ref_type,
            ReferenceType::DefiniteNP { entity_type } if *entity_type == EntityType::File
        )));
    }

    /// "that error" is detected as a demonstrative constrained to errors.
    #[test]
    fn test_detect_demonstrative() {
        let resolver = CoreferenceResolver::new();
        let refs = resolver.detect_references("Fix that error in the code");

        assert!(refs.iter().any(|r| r.text == "that error"));
        assert!(refs.iter().any(|r| matches!(
            &r.ref_type,
            ReferenceType::Demonstrative { entity_type } if *entity_type == EntityType::Error
        )));
    }

    /// The most recent mention sits on top of the focus stack and scores highest.
    #[test]
    fn test_dialog_state_mention() {
        let mut state = DialogState::new();
        state.mention_entity("main.rs", EntityType::File);
        state.next_turn();
        state.mention_entity("config.toml", EntityType::File);

        // config.toml should be at the top of focus stack
        assert_eq!(state.focus_stack[0], "config.toml");
        assert_eq!(state.focus_stack[1], "main.rs");

        // Recency score should be higher for config.toml
        assert!(state.recency_score("config.toml") > state.recency_score("main.rs"));
    }

    /// Entities mentioned more often get a higher frequency score.
    #[test]
    fn test_dialog_state_frequency() {
        let mut state = DialogState::new();
        state.mention_entity("main.rs", EntityType::File);
        state.next_turn();
        state.mention_entity("main.rs", EntityType::File);
        state.next_turn();
        state.mention_entity("config.toml", EntityType::File);

        // main.rs mentioned twice, should have higher frequency
        assert!(state.frequency_score("main.rs") > state.frequency_score("config.toml"));
    }

    /// A pronoun resolves to the only entity in focus.
    #[test]
    fn test_resolve_pronoun() {
        let resolver = CoreferenceResolver::new();
        let mut state = DialogState::new();
        let entity_store = EntityStore::new();

        state.mention_entity("src/main.rs", EntityType::File);
        state.next_turn();

        let refs = resolver.detect_references("Fix it");
        let resolved = resolver.resolve(&refs, &state, &entity_store, None);

        assert_eq!(resolved.len(), 1);
        assert_eq!(resolved[0].antecedent, "src/main.rs");
    }

    /// A typed reference only resolves to entities of the matching type.
    #[test]
    fn test_resolve_type_constrained() {
        let resolver = CoreferenceResolver::new();
        let mut state = DialogState::new();
        let entity_store = EntityStore::new();

        // Mention a file and a function
        state.mention_entity("main.rs", EntityType::File);
        state.mention_entity("process_data", EntityType::Function);
        state.next_turn();

        // "the function" should resolve to the function, not the file
        let refs = resolver.detect_references("Update the function");
        let resolved = resolver.resolve(&refs, &state, &entity_store, None);

        assert_eq!(resolved.len(), 1);
        assert_eq!(resolved[0].antecedent, "process_data");
    }

    /// Resolved references are substituted in bracket notation.
    #[test]
    fn test_rewrite_with_resolutions() {
        let resolver = CoreferenceResolver::new();
        let mut state = DialogState::new();
        let entity_store = EntityStore::new();

        state.mention_entity("main.rs", EntityType::File);
        state.next_turn();

        let refs = resolver.detect_references("Fix it and test");
        let resolved = resolver.resolve(&refs, &state, &entity_store, None);
        let rewritten = resolver.rewrite_with_resolutions("Fix it and test", &resolved);

        assert_eq!(rewritten, "Fix [main.rs] and test");
    }

    /// The weighted total matches the documented weights.
    #[test]
    fn test_salience_score_total() {
        let score = SalienceScore {
            recency: 1.0,
            frequency: 0.5,
            graph_centrality: 0.8,
            type_match: 1.0,
            syntactic_prominence: 0.5,
        };

        // 1.0*0.35 + 0.5*0.15 + 0.8*0.20 + 1.0*0.20 + 0.5*0.10
        // = 0.35 + 0.075 + 0.16 + 0.20 + 0.05 = 0.835
        assert!((score.total() - 0.835).abs() < 0.001);
    }

    /// A message without any known reference pattern yields nothing.
    #[test]
    fn test_empty_references() {
        let resolver = CoreferenceResolver::new();
        let refs = resolver.detect_references("Build the project using cargo");

        // "the project" doesn't match our patterns, and the message contains no
        // pronouns, so nothing at all should be detected.
        assert!(refs.is_empty(), "unexpected references: {refs:?}");
    }

    /// Several references in one message are all reported.
    #[test]
    fn test_multiple_references() {
        let resolver = CoreferenceResolver::new();
        let refs = resolver.detect_references("Fix it and update the file");

        assert!(refs.len() >= 2);
        // Should have both "it" and "the file"
        let texts: Vec<_> = refs.iter().map(|r| r.text.as_str()).collect();
        assert!(texts.contains(&"it"));
        assert!(texts.contains(&"the file"));
    }
}