1use brainwires_core::graph::{EntityStoreT, EntityType, RelationshipGraphT};
33use regex::Regex;
34use std::collections::HashMap;
35use std::sync::LazyLock;
36
37static RE_SINGULAR_NEUTRAL: LazyLock<Regex> =
41 LazyLock::new(|| Regex::new(r"\b(it|this|that)\b").expect("valid regex"));
42static RE_PLURAL: LazyLock<Regex> =
43 LazyLock::new(|| Regex::new(r"\b(they|them|those|these)\b").expect("valid regex"));
44
45static RE_THE_FILE: LazyLock<Regex> =
47 LazyLock::new(|| Regex::new(r"\bthe\s+(file|files)\b").expect("valid regex"));
48static RE_THE_FUNCTION: LazyLock<Regex> =
49 LazyLock::new(|| Regex::new(r"\bthe\s+(function|method|fn)\b").expect("valid regex"));
50static RE_THE_TYPE: LazyLock<Regex> = LazyLock::new(|| {
51 Regex::new(r"\bthe\s+(type|struct|class|enum|interface)\b").expect("valid regex")
52});
53static RE_THE_ERROR: LazyLock<Regex> =
54 LazyLock::new(|| Regex::new(r"\bthe\s+(error|bug|issue)\b").expect("valid regex"));
55static RE_THE_VARIABLE: LazyLock<Regex> =
56 LazyLock::new(|| Regex::new(r"\bthe\s+(variable|var|const|let)\b").expect("valid regex"));
57static RE_THE_COMMAND: LazyLock<Regex> =
58 LazyLock::new(|| Regex::new(r"\bthe\s+(command|cmd)\b").expect("valid regex"));
59
60static RE_DEMO_FILE: LazyLock<Regex> =
62 LazyLock::new(|| Regex::new(r"\b(that|this)\s+(file)\b").expect("valid regex"));
63static RE_DEMO_FUNCTION: LazyLock<Regex> =
64 LazyLock::new(|| Regex::new(r"\b(that|this)\s+(function|method|fn)\b").expect("valid regex"));
65static RE_DEMO_TYPE: LazyLock<Regex> = LazyLock::new(|| {
66 Regex::new(r"\b(that|this)\s+(type|struct|class|enum)\b").expect("valid regex")
67});
68static RE_DEMO_ERROR: LazyLock<Regex> =
69 LazyLock::new(|| Regex::new(r"\b(that|this)\s+(error|bug|issue)\b").expect("valid regex"));
70
71#[derive(Debug, Clone, PartialEq)]
73pub enum ReferenceType {
74 SingularNeutral,
76 Plural,
78 DefiniteNP {
80 entity_type: EntityType,
82 },
83 Demonstrative {
85 entity_type: EntityType,
87 },
88 Ellipsis,
90}
91
92impl ReferenceType {
93 pub fn compatible_types(&self) -> Vec<EntityType> {
95 match self {
96 ReferenceType::SingularNeutral => vec![
97 EntityType::File,
98 EntityType::Function,
99 EntityType::Type,
100 EntityType::Variable,
101 EntityType::Error,
102 EntityType::Concept,
103 EntityType::Command,
104 ],
105 ReferenceType::Plural => vec![
106 EntityType::File,
107 EntityType::Function,
108 EntityType::Type,
109 EntityType::Variable,
110 EntityType::Error,
111 ],
112 ReferenceType::DefiniteNP { entity_type } => vec![entity_type.clone()],
113 ReferenceType::Demonstrative { entity_type } => vec![entity_type.clone()],
114 ReferenceType::Ellipsis => vec![
115 EntityType::File,
116 EntityType::Function,
117 EntityType::Type,
118 EntityType::Command,
119 ],
120 }
121 }
122}
123
/// A reference detected in a message that has not yet been tied to an entity.
#[derive(Debug, Clone)]
pub struct UnresolvedReference {
    /// The matched text, taken from the lowercased copy of the message that
    /// detection scans.
    pub text: String,
    /// Grammatical form of the reference; determines which entity types it
    /// may resolve to.
    pub ref_type: ReferenceType,
    /// Byte offset where the match starts in the scanned text.
    pub start: usize,
    /// Byte offset just past the end of the match in the scanned text.
    pub end: usize,
}
136
/// A reference together with the entity chosen as its antecedent.
#[derive(Debug, Clone)]
pub struct ResolvedReference {
    /// The reference that was resolved.
    pub reference: UnresolvedReference,
    /// Name of the entity the reference was resolved to.
    pub antecedent: String,
    /// Type of the chosen entity.
    pub entity_type: EntityType,
    /// Confidence in the resolution; the resolver sets this to the total
    /// salience of the chosen entity.
    pub confidence: f32,
    /// Per-factor salience breakdown for the chosen entity.
    pub salience: SalienceScore,
}
151
/// Per-factor salience of a candidate antecedent.
///
/// The factors are combined into one number by [`SalienceScore::total`].
#[derive(Debug, Clone, Default)]
pub struct SalienceScore {
    /// How recently the entity was in focus.
    pub recency: f32,
    /// How often the entity has been mentioned.
    pub frequency: f32,
    /// Importance of the entity in the relationship graph.
    pub graph_centrality: f32,
    /// How well the entity's type fits the reference.
    pub type_match: f32,
    /// Prominence of the entity in the focus stack.
    pub syntactic_prominence: f32,
}

impl SalienceScore {
    /// Weight applied to `recency`.
    const RECENCY_WEIGHT: f32 = 0.35;
    /// Weight applied to `frequency`.
    const FREQUENCY_WEIGHT: f32 = 0.15;
    /// Weight applied to `graph_centrality`.
    const GRAPH_CENTRALITY_WEIGHT: f32 = 0.20;
    /// Weight applied to `type_match`.
    const TYPE_MATCH_WEIGHT: f32 = 0.20;
    /// Weight applied to `syntactic_prominence`.
    const SYNTACTIC_PROMINENCE_WEIGHT: f32 = 0.10;

    /// Returns the weighted sum of all factors.
    ///
    /// The weights add up to 1.0, so the total stays within `[0, 1]` when
    /// every factor does.
    pub fn total(&self) -> f32 {
        self.recency * Self::RECENCY_WEIGHT
            + self.frequency * Self::FREQUENCY_WEIGHT
            + self.graph_centrality * Self::GRAPH_CENTRALITY_WEIGHT
            + self.type_match * Self::TYPE_MATCH_WEIGHT
            + self.syntactic_prominence * Self::SYNTACTIC_PROMINENCE_WEIGHT
    }
}
177
178#[derive(Debug, Clone, Default)]
180pub struct DialogState {
181 pub focus_stack: Vec<String>,
183 pub mention_history: HashMap<String, Vec<u32>>,
185 pub current_turn: u32,
187 pub recently_modified: Vec<String>,
189 entity_types: HashMap<String, EntityType>,
191}
192
193impl DialogState {
194 pub fn new() -> Self {
196 Self::default()
197 }
198
199 pub fn next_turn(&mut self) {
201 self.current_turn += 1;
202 }
203
204 pub fn mention_entity(&mut self, name: &str, entity_type: EntityType) {
206 self.focus_stack.retain(|n| n != name);
208 self.focus_stack.insert(0, name.to_string());
209
210 if self.focus_stack.len() > 20 {
212 self.focus_stack.truncate(20);
213 }
214
215 self.mention_history
217 .entry(name.to_string())
218 .or_default()
219 .push(self.current_turn);
220
221 self.entity_types.insert(name.to_string(), entity_type);
223 }
224
225 pub fn mark_modified(&mut self, name: &str) {
227 self.recently_modified.retain(|n| n != name);
228 self.recently_modified.insert(0, name.to_string());
229
230 if self.recently_modified.len() > 10 {
232 self.recently_modified.truncate(10);
233 }
234 }
235
236 pub fn get_entity_type(&self, name: &str) -> Option<&EntityType> {
238 self.entity_types.get(name)
239 }
240
241 pub fn recency_score(&self, name: &str) -> f32 {
243 if let Some(pos) = self.focus_stack.iter().position(|n| n == name) {
245 let focus_score = 1.0 - (pos as f32 / self.focus_stack.len() as f32);
246
247 let modified_bonus = if self.recently_modified.contains(&name.to_string()) {
249 0.2
250 } else {
251 0.0
252 };
253
254 (focus_score + modified_bonus).min(1.0)
255 } else {
256 if let Some(turns) = self.mention_history.get(name) {
258 if let Some(&last_turn) = turns.last() {
259 let age = self.current_turn.saturating_sub(last_turn) as f32;
260 (-0.1 * age).exp() } else {
262 0.0
263 }
264 } else {
265 0.0
266 }
267 }
268 }
269
270 pub fn frequency_score(&self, name: &str) -> f32 {
272 if let Some(turns) = self.mention_history.get(name) {
273 let count = turns.len() as f32;
274 (count.ln_1p() / 3.0).min(1.0)
276 } else {
277 0.0
278 }
279 }
280
281 pub fn clear(&mut self) {
283 self.focus_stack.clear();
284 self.mention_history.clear();
285 self.current_turn = 0;
286 self.recently_modified.clear();
287 self.entity_types.clear();
288 }
289}
290
291struct ReferencePattern {
293 regex: &'static Regex,
294 ref_type_fn: fn(®ex::Captures) -> ReferenceType,
295}
296
297pub struct CoreferenceResolver {
299 pronoun_patterns: Vec<ReferencePattern>,
301 definite_np_patterns: Vec<ReferencePattern>,
303 demonstrative_patterns: Vec<ReferencePattern>,
305}
306
307impl CoreferenceResolver {
308 pub fn new() -> Self {
310 Self {
311 pronoun_patterns: Self::build_pronoun_patterns(),
312 definite_np_patterns: Self::build_definite_np_patterns(),
313 demonstrative_patterns: Self::build_demonstrative_patterns(),
314 }
315 }
316
317 fn build_pronoun_patterns() -> Vec<ReferencePattern> {
318 vec![
319 ReferencePattern {
320 regex: &RE_SINGULAR_NEUTRAL,
321 ref_type_fn: |_| ReferenceType::SingularNeutral,
322 },
323 ReferencePattern {
324 regex: &RE_PLURAL,
325 ref_type_fn: |_| ReferenceType::Plural,
326 },
327 ]
328 }
329
330 fn build_definite_np_patterns() -> Vec<ReferencePattern> {
331 vec![
332 ReferencePattern {
333 regex: &RE_THE_FILE,
334 ref_type_fn: |_| ReferenceType::DefiniteNP {
335 entity_type: EntityType::File,
336 },
337 },
338 ReferencePattern {
339 regex: &RE_THE_FUNCTION,
340 ref_type_fn: |_| ReferenceType::DefiniteNP {
341 entity_type: EntityType::Function,
342 },
343 },
344 ReferencePattern {
345 regex: &RE_THE_TYPE,
346 ref_type_fn: |_| ReferenceType::DefiniteNP {
347 entity_type: EntityType::Type,
348 },
349 },
350 ReferencePattern {
351 regex: &RE_THE_ERROR,
352 ref_type_fn: |_| ReferenceType::DefiniteNP {
353 entity_type: EntityType::Error,
354 },
355 },
356 ReferencePattern {
357 regex: &RE_THE_VARIABLE,
358 ref_type_fn: |_| ReferenceType::DefiniteNP {
359 entity_type: EntityType::Variable,
360 },
361 },
362 ReferencePattern {
363 regex: &RE_THE_COMMAND,
364 ref_type_fn: |_| ReferenceType::DefiniteNP {
365 entity_type: EntityType::Command,
366 },
367 },
368 ]
369 }
370
371 fn build_demonstrative_patterns() -> Vec<ReferencePattern> {
372 vec![
373 ReferencePattern {
374 regex: &RE_DEMO_FILE,
375 ref_type_fn: |_| ReferenceType::Demonstrative {
376 entity_type: EntityType::File,
377 },
378 },
379 ReferencePattern {
380 regex: &RE_DEMO_FUNCTION,
381 ref_type_fn: |_| ReferenceType::Demonstrative {
382 entity_type: EntityType::Function,
383 },
384 },
385 ReferencePattern {
386 regex: &RE_DEMO_TYPE,
387 ref_type_fn: |_| ReferenceType::Demonstrative {
388 entity_type: EntityType::Type,
389 },
390 },
391 ReferencePattern {
392 regex: &RE_DEMO_ERROR,
393 ref_type_fn: |_| ReferenceType::Demonstrative {
394 entity_type: EntityType::Error,
395 },
396 },
397 ]
398 }
399
400 pub fn detect_references(&self, message: &str) -> Vec<UnresolvedReference> {
402 let mut references = Vec::new();
403 let lower = message.to_lowercase();
404
405 for pattern in &self.demonstrative_patterns {
407 for cap in pattern.regex.captures_iter(&lower) {
408 if let Some(m) = cap.get(0) {
409 references.push(UnresolvedReference {
410 text: m.as_str().to_string(),
411 ref_type: (pattern.ref_type_fn)(&cap),
412 start: m.start(),
413 end: m.end(),
414 });
415 }
416 }
417 }
418
419 for pattern in &self.definite_np_patterns {
421 for cap in pattern.regex.captures_iter(&lower) {
422 if let Some(m) = cap.get(0) {
423 let overlaps = references
425 .iter()
426 .any(|r| r.start <= m.start() && r.end >= m.end());
427 if !overlaps {
428 references.push(UnresolvedReference {
429 text: m.as_str().to_string(),
430 ref_type: (pattern.ref_type_fn)(&cap),
431 start: m.start(),
432 end: m.end(),
433 });
434 }
435 }
436 }
437 }
438
439 for pattern in &self.pronoun_patterns {
441 for cap in pattern.regex.captures_iter(&lower) {
442 if let Some(m) = cap.get(0) {
443 let overlaps = references
445 .iter()
446 .any(|r| r.start <= m.start() && r.end >= m.end());
447 if !overlaps {
448 references.push(UnresolvedReference {
449 text: m.as_str().to_string(),
450 ref_type: (pattern.ref_type_fn)(&cap),
451 start: m.start(),
452 end: m.end(),
453 });
454 }
455 }
456 }
457 }
458
459 references.sort_by_key(|r| r.start);
461 references
462 }
463
464 pub fn resolve(
466 &self,
467 references: &[UnresolvedReference],
468 dialog_state: &DialogState,
469 entity_store: &dyn EntityStoreT,
470 graph: Option<&dyn RelationshipGraphT>,
471 ) -> Vec<ResolvedReference> {
472 let mut resolved = Vec::new();
473
474 for reference in references {
475 if let Some(resolution) =
476 self.resolve_single(reference, dialog_state, entity_store, graph)
477 {
478 resolved.push(resolution);
479 }
480 }
481
482 resolved
483 }
484
485 fn resolve_single(
487 &self,
488 reference: &UnresolvedReference,
489 dialog_state: &DialogState,
490 entity_store: &dyn EntityStoreT,
491 graph: Option<&dyn RelationshipGraphT>,
492 ) -> Option<ResolvedReference> {
493 let compatible_types = reference.ref_type.compatible_types();
494
495 let mut candidates: Vec<(&str, &EntityType, SalienceScore)> = Vec::new();
497
498 for name in &dialog_state.focus_stack {
500 if let Some(entity_type) = dialog_state.get_entity_type(name)
501 && compatible_types.contains(entity_type)
502 {
503 let salience = self.compute_salience(name, entity_type, dialog_state, graph);
504 candidates.push((name, entity_type, salience));
505 }
506 }
507
508 let entity_names: Vec<(String, EntityType)> = compatible_types
510 .iter()
511 .flat_map(|et| {
512 entity_store
513 .entity_names_by_type(et)
514 .into_iter()
515 .map(move |name| (name, et.clone()))
516 })
517 .collect();
518
519 for (entity_name, entity_type) in &entity_names {
520 if candidates
522 .iter()
523 .any(|(n, _, _)| *n == entity_name.as_str())
524 {
525 continue;
526 }
527
528 let salience = self.compute_salience(entity_name, entity_type, dialog_state, graph);
529 candidates.push((entity_name, entity_type, salience));
530 }
531
532 candidates.sort_by(|a, b| {
534 b.2.total()
535 .partial_cmp(&a.2.total())
536 .unwrap_or(std::cmp::Ordering::Equal)
537 });
538
539 candidates
541 .first()
542 .map(|(name, entity_type, salience)| ResolvedReference {
543 reference: reference.clone(),
544 antecedent: name.to_string(),
545 entity_type: (*entity_type).clone(),
546 confidence: salience.total(),
547 salience: salience.clone(),
548 })
549 }
550
551 fn compute_salience(
553 &self,
554 name: &str,
555 _entity_type: &EntityType,
556 dialog_state: &DialogState,
557 graph: Option<&dyn RelationshipGraphT>,
558 ) -> SalienceScore {
559 let recency = dialog_state.recency_score(name);
560 let frequency = dialog_state.frequency_score(name);
561
562 let graph_centrality = if let Some(g) = graph {
563 if let Some(node) = g.get_node(name) {
564 node.importance
565 } else {
566 0.0
567 }
568 } else {
569 0.5 };
571
572 let type_match = 1.0;
574
575 let syntactic_prominence = if dialog_state.focus_stack.first() == Some(&name.to_string()) {
577 1.0
578 } else if dialog_state.focus_stack.contains(&name.to_string()) {
579 0.5
580 } else {
581 0.0
582 };
583
584 SalienceScore {
585 recency,
586 frequency,
587 graph_centrality,
588 type_match,
589 syntactic_prominence,
590 }
591 }
592
593 pub fn rewrite_with_resolutions(
595 &self,
596 message: &str,
597 resolutions: &[ResolvedReference],
598 ) -> String {
599 if resolutions.is_empty() {
600 return message.to_string();
601 }
602
603 let mut sorted = resolutions.to_vec();
605 sorted.sort_by(|a, b| b.reference.start.cmp(&a.reference.start));
606
607 let mut result = message.to_string();
608 let lower = message.to_lowercase();
609
610 for resolution in sorted {
611 let search_start = resolution.reference.start;
614 let search_end = resolution.reference.end;
615
616 if search_end <= lower.len() && search_start < search_end {
617 let replacement = format!("[{}]", resolution.antecedent);
619
620 let ref_text = &lower[search_start..search_end];
623 if let Some(pos) = result.to_lowercase().find(ref_text) {
624 result = format!(
625 "{}{}{}",
626 &result[..pos],
627 replacement,
628 &result[pos + (search_end - search_start)..]
629 );
630 }
631 }
632 }
633
634 result
635 }
636}
637
638impl Default for CoreferenceResolver {
639 fn default() -> Self {
640 Self::new()
641 }
642}
643
#[cfg(test)]
mod tests {
    use super::*;
    use brainwires_knowledge::knowledge::EntityStore;

    /// Runs reference detection on `message` with a fresh resolver.
    fn detect(message: &str) -> Vec<UnresolvedReference> {
        CoreferenceResolver::new().detect_references(message)
    }

    #[test]
    fn test_detect_pronouns() {
        let refs = detect("Fix it and run the tests");

        assert!(!refs.is_empty());
        assert!(refs.iter().any(|r| r.text == "it"));
        assert_eq!(refs[0].ref_type, ReferenceType::SingularNeutral);
    }

    #[test]
    fn test_detect_definite_np() {
        let refs = detect("Update the file with the new logic");

        assert!(refs.iter().any(|r| r.text == "the file"));
        assert!(refs.iter().any(|r| matches!(
            &r.ref_type,
            ReferenceType::DefiniteNP { entity_type } if *entity_type == EntityType::File
        )));
    }

    #[test]
    fn test_detect_demonstrative() {
        let refs = detect("Fix that error in the code");

        assert!(refs.iter().any(|r| r.text == "that error"));
        assert!(refs.iter().any(|r| matches!(
            &r.ref_type,
            ReferenceType::Demonstrative { entity_type } if *entity_type == EntityType::Error
        )));
    }

    #[test]
    fn test_dialog_state_mention() {
        let mut state = DialogState::new();
        state.mention_entity("main.rs", EntityType::File);
        state.next_turn();
        state.mention_entity("config.toml", EntityType::File);

        // The latest mention sits at the front of the focus stack.
        assert_eq!(state.focus_stack[0], "config.toml");
        assert_eq!(state.focus_stack[1], "main.rs");

        // ...and therefore scores as more recent.
        let newer = state.recency_score("config.toml");
        let older = state.recency_score("main.rs");
        assert!(newer > older);
    }

    #[test]
    fn test_dialog_state_frequency() {
        let mut state = DialogState::new();
        state.mention_entity("main.rs", EntityType::File);
        state.next_turn();
        state.mention_entity("main.rs", EntityType::File);
        state.next_turn();
        state.mention_entity("config.toml", EntityType::File);

        // Two mentions outrank one.
        let twice = state.frequency_score("main.rs");
        let once = state.frequency_score("config.toml");
        assert!(twice > once);
    }

    #[test]
    fn test_resolve_pronoun() {
        let resolver = CoreferenceResolver::new();
        let entity_store = EntityStore::new();
        let mut state = DialogState::new();

        state.mention_entity("src/main.rs", EntityType::File);
        state.next_turn();

        let refs = resolver.detect_references("Fix it");
        let resolved = resolver.resolve(&refs, &state, &entity_store, None);

        assert_eq!(resolved.len(), 1);
        assert_eq!(resolved[0].antecedent, "src/main.rs");
    }

    #[test]
    fn test_resolve_type_constrained() {
        let resolver = CoreferenceResolver::new();
        let entity_store = EntityStore::new();
        let mut state = DialogState::new();

        // A file and a function are both in focus.
        state.mention_entity("main.rs", EntityType::File);
        state.mention_entity("process_data", EntityType::Function);
        state.next_turn();

        // "the function" may only resolve to the function.
        let refs = resolver.detect_references("Update the function");
        let resolved = resolver.resolve(&refs, &state, &entity_store, None);

        assert_eq!(resolved.len(), 1);
        assert_eq!(resolved[0].antecedent, "process_data");
    }

    #[test]
    fn test_rewrite_with_resolutions() {
        let resolver = CoreferenceResolver::new();
        let entity_store = EntityStore::new();
        let mut state = DialogState::new();

        state.mention_entity("main.rs", EntityType::File);
        state.next_turn();

        let message = "Fix it and test";
        let refs = resolver.detect_references(message);
        let resolved = resolver.resolve(&refs, &state, &entity_store, None);
        let rewritten = resolver.rewrite_with_resolutions(message, &resolved);

        assert_eq!(rewritten, "Fix [main.rs] and test");
    }

    #[test]
    fn test_salience_score_total() {
        let score = SalienceScore {
            recency: 1.0,
            frequency: 0.5,
            graph_centrality: 0.8,
            type_match: 1.0,
            syntactic_prominence: 0.5,
        };

        // 0.35 + 0.075 + 0.16 + 0.20 + 0.05 = 0.835
        assert!((score.total() - 0.835).abs() < 0.001);
    }

    #[test]
    fn test_empty_references() {
        let refs = detect("Build the project using cargo");

        // "the project" matches none of the definite-NP nouns.
        assert!(refs.iter().all(|r| r.text != "the project"));
    }

    #[test]
    fn test_multiple_references() {
        let refs = detect("Fix it and update the file");

        assert!(refs.len() >= 2);
        let texts: Vec<&str> = refs.iter().map(|r| r.text.as_str()).collect();
        assert!(texts.contains(&"it"));
        assert!(texts.contains(&"the file"));
    }
}