use brainwires_core::graph::{EntityStoreT, EntityType, RelationshipGraphT};
use regex::Regex;
use std::collections::HashMap;
use std::sync::LazyLock;
// Pre-compiled reference patterns.
//
// Every pattern is written in lowercase only and is compiled lazily on first
// use. `expect("valid regex")` can only fire if one of the literals below is
// edited into an invalid expression.

/// Whole-word singular neutral pronouns: "it", "this", "that".
static RE_SINGULAR_NEUTRAL: LazyLock<Regex> =
LazyLock::new(|| Regex::new(r"\b(it|this|that)\b").expect("valid regex"));
/// Whole-word plural pronouns: "they", "them", "those", "these".
static RE_PLURAL: LazyLock<Regex> =
LazyLock::new(|| Regex::new(r"\b(they|them|those|these)\b").expect("valid regex"));
/// Definite noun phrase for files: "the file" / "the files".
static RE_THE_FILE: LazyLock<Regex> =
LazyLock::new(|| Regex::new(r"\bthe\s+(file|files)\b").expect("valid regex"));
/// Definite noun phrase for callables: "the function" / "the method" / "the fn".
static RE_THE_FUNCTION: LazyLock<Regex> =
LazyLock::new(|| Regex::new(r"\bthe\s+(function|method|fn)\b").expect("valid regex"));
/// Definite noun phrase for type-like items: "the type", "the struct",
/// "the class", "the enum", "the interface".
static RE_THE_TYPE: LazyLock<Regex> = LazyLock::new(|| {
Regex::new(r"\bthe\s+(type|struct|class|enum|interface)\b").expect("valid regex")
});
/// Definite noun phrase for problems: "the error" / "the bug" / "the issue".
static RE_THE_ERROR: LazyLock<Regex> =
LazyLock::new(|| Regex::new(r"\bthe\s+(error|bug|issue)\b").expect("valid regex"));
/// Definite noun phrase for bindings: "the variable", "the var", "the const", "the let".
static RE_THE_VARIABLE: LazyLock<Regex> =
LazyLock::new(|| Regex::new(r"\bthe\s+(variable|var|const|let)\b").expect("valid regex"));
/// Definite noun phrase for commands: "the command" / "the cmd".
static RE_THE_COMMAND: LazyLock<Regex> =
LazyLock::new(|| Regex::new(r"\bthe\s+(command|cmd)\b").expect("valid regex"));
/// Demonstrative phrase for a file: "that file" / "this file" (singular only).
static RE_DEMO_FILE: LazyLock<Regex> =
LazyLock::new(|| Regex::new(r"\b(that|this)\s+(file)\b").expect("valid regex"));
/// Demonstrative phrase for callables: "that/this function", "method" or "fn".
static RE_DEMO_FUNCTION: LazyLock<Regex> =
LazyLock::new(|| Regex::new(r"\b(that|this)\s+(function|method|fn)\b").expect("valid regex"));
/// Demonstrative phrase for type-like items: "that/this type", "struct",
/// "class" or "enum". Unlike `RE_THE_TYPE`, "interface" is not listed here.
static RE_DEMO_TYPE: LazyLock<Regex> = LazyLock::new(|| {
Regex::new(r"\b(that|this)\s+(type|struct|class|enum)\b").expect("valid regex")
});
/// Demonstrative phrase for problems: "that/this error", "bug" or "issue".
static RE_DEMO_ERROR: LazyLock<Regex> =
LazyLock::new(|| Regex::new(r"\b(that|this)\s+(error|bug|issue)\b").expect("valid regex"));
/// The grammatical shape of a detected reference, which determines which
/// entity types it is allowed to resolve to.
#[derive(Debug, Clone, PartialEq)]
pub enum ReferenceType {
    /// A singular neutral pronoun such as "it", "this" or "that".
    SingularNeutral,
    /// A plural pronoun such as "they", "them", "those" or "these".
    Plural,
    /// A definite noun phrase ("the file") that names its entity type.
    DefiniteNP {
        /// Entity type named by the noun.
        entity_type: EntityType,
    },
    /// A demonstrative noun phrase ("that error") that names its entity type.
    Demonstrative {
        /// Entity type named by the noun.
        entity_type: EntityType,
    },
    /// An elided reference; nothing in this file produces it during detection.
    Ellipsis,
}

impl ReferenceType {
    /// Returns the entity types this kind of reference may resolve to.
    ///
    /// Noun-phrase variants are restricted to the single type they name;
    /// pronouns and ellipsis accept a fixed list of types.
    pub fn compatible_types(&self) -> Vec<EntityType> {
        match self {
            // Both noun-phrase forms carry exactly one admissible type.
            Self::DefiniteNP { entity_type } | Self::Demonstrative { entity_type } => {
                vec![entity_type.clone()]
            }
            Self::SingularNeutral => [
                EntityType::File,
                EntityType::Function,
                EntityType::Type,
                EntityType::Variable,
                EntityType::Error,
                EntityType::Concept,
                EntityType::Command,
            ]
            .to_vec(),
            Self::Plural => [
                EntityType::File,
                EntityType::Function,
                EntityType::Type,
                EntityType::Variable,
                EntityType::Error,
            ]
            .to_vec(),
            Self::Ellipsis => [
                EntityType::File,
                EntityType::Function,
                EntityType::Type,
                EntityType::Command,
            ]
            .to_vec(),
        }
    }
}
/// A reference expression found in a message that has not yet been linked to
/// an entity.
#[derive(Debug, Clone)]
pub struct UnresolvedReference {
/// The matched text, as it appears in the string that was scanned.
pub text: String,
/// Grammatical kind of the reference; constrains the admissible entity types.
pub ref_type: ReferenceType,
/// Byte offset at which the match starts in the scanned string.
pub start: usize,
/// Byte offset just past the end of the match in the scanned string.
pub end: usize,
}
/// A reference together with the entity it was resolved to.
#[derive(Debug, Clone)]
pub struct ResolvedReference {
/// The reference that was resolved.
pub reference: UnresolvedReference,
/// Name of the entity chosen as the antecedent.
pub antecedent: String,
/// Type of the chosen entity.
pub entity_type: EntityType,
/// Confidence of the resolution; `resolve_single` sets it to `salience.total()`.
pub confidence: f32,
/// The individual salience factors behind `confidence`.
pub salience: SalienceScore,
}
/// The individual factors that make a candidate entity salient as an
/// antecedent. All factors default to `0.0`.
#[derive(Debug, Clone, Default)]
pub struct SalienceScore {
    /// How recently the entity was mentioned.
    pub recency: f32,
    /// How often the entity has been mentioned.
    pub frequency: f32,
    /// Importance of the entity in the relationship graph.
    pub graph_centrality: f32,
    /// How well the entity's type fits the reference.
    pub type_match: f32,
    /// Prominence of the entity in the current focus.
    pub syntactic_prominence: f32,
}

impl SalienceScore {
    // Relative weight of each factor in `total`; the five weights add up to 1.
    const RECENCY_WEIGHT: f32 = 0.35;
    const FREQUENCY_WEIGHT: f32 = 0.15;
    const GRAPH_CENTRALITY_WEIGHT: f32 = 0.20;
    const TYPE_MATCH_WEIGHT: f32 = 0.20;
    const SYNTACTIC_PROMINENCE_WEIGHT: f32 = 0.10;

    /// Returns the weighted sum of all five factors.
    pub fn total(&self) -> f32 {
        let recency = self.recency * Self::RECENCY_WEIGHT;
        let frequency = self.frequency * Self::FREQUENCY_WEIGHT;
        let centrality = self.graph_centrality * Self::GRAPH_CENTRALITY_WEIGHT;
        let type_match = self.type_match * Self::TYPE_MATCH_WEIGHT;
        let prominence = self.syntactic_prominence * Self::SYNTACTIC_PROMINENCE_WEIGHT;
        recency + frequency + centrality + type_match + prominence
    }
}
/// Conversation-level bookkeeping used to rank candidate antecedents.
#[derive(Debug, Clone, Default)]
pub struct DialogState {
    /// Entity names ordered from most to least recently mentioned.
    pub focus_stack: Vec<String>,
    /// For each entity name, the turn number of every mention, in order.
    pub mention_history: HashMap<String, Vec<u32>>,
    /// Number of the current dialog turn, starting at 0.
    pub current_turn: u32,
    /// Entity names ordered from most to least recently modified.
    pub recently_modified: Vec<String>,
    /// Type recorded for each entity at its latest mention.
    entity_types: HashMap<String, EntityType>,
}

impl DialogState {
    /// Maximum number of names kept in `focus_stack`.
    const FOCUS_STACK_LIMIT: usize = 20;
    /// Maximum number of names kept in `recently_modified`.
    const RECENTLY_MODIFIED_LIMIT: usize = 10;

    /// Creates an empty state positioned at turn 0.
    pub fn new() -> Self {
        Self::default()
    }

    /// Advances to the next dialog turn.
    pub fn next_turn(&mut self) {
        self.current_turn += 1;
    }

    /// Records a mention of `name` in the current turn.
    ///
    /// The entity moves to the front of the focus stack (the oldest entry is
    /// dropped once the stack exceeds 20 names), the turn is appended to its
    /// mention history, and `entity_type` replaces any type stored earlier.
    pub fn mention_entity(&mut self, name: &str, entity_type: EntityType) {
        let turn = self.current_turn;
        let owned = name.to_string();

        self.focus_stack.retain(|existing| existing != name);
        self.focus_stack.insert(0, owned.clone());
        self.focus_stack.truncate(Self::FOCUS_STACK_LIMIT);

        self.mention_history
            .entry(owned.clone())
            .or_default()
            .push(turn);
        self.entity_types.insert(owned, entity_type);
    }

    /// Marks `name` as the most recently modified entity, keeping at most 10
    /// names.
    pub fn mark_modified(&mut self, name: &str) {
        self.recently_modified.retain(|existing| existing != name);
        self.recently_modified.insert(0, name.to_string());
        self.recently_modified.truncate(Self::RECENTLY_MODIFIED_LIMIT);
    }

    /// Returns the type recorded for `name`, if it has ever been mentioned.
    pub fn get_entity_type(&self, name: &str) -> Option<&EntityType> {
        self.entity_types.get(name)
    }

    /// Scores how recently `name` was in focus, in `0.0..=1.0`.
    ///
    /// Entities on the focus stack score by stack position, plus a 0.2 bonus
    /// (capped at 1.0) when they were recently modified. Entities that are no
    /// longer on the stack decay exponentially with the number of turns since
    /// their last mention. Unknown entities score 0.
    pub fn recency_score(&self, name: &str) -> f32 {
        match self.focus_stack.iter().position(|entry| entry == name) {
            Some(depth) => {
                let focus_score = 1.0 - (depth as f32 / self.focus_stack.len() as f32);
                let was_modified = self.recently_modified.iter().any(|entry| entry == name);
                let modified_bonus = if was_modified { 0.2 } else { 0.0 };
                (focus_score + modified_bonus).min(1.0)
            }
            None => self
                .mention_history
                .get(name)
                .and_then(|turns| turns.last())
                .map_or(0.0, |&last_turn| {
                    let age = self.current_turn.saturating_sub(last_turn) as f32;
                    (-0.1 * age).exp()
                }),
        }
    }

    /// Scores how often `name` has been mentioned, in `0.0..=1.0`.
    ///
    /// The score grows logarithmically with the mention count (`ln(1 + n) / 3`)
    /// and is 0 for entities that were never mentioned.
    pub fn frequency_score(&self, name: &str) -> f32 {
        self.mention_history.get(name).map_or(0.0, |turns| {
            let count = turns.len() as f32;
            (count.ln_1p() / 3.0).min(1.0)
        })
    }

    /// Resets the state to what `new` returns.
    pub fn clear(&mut self) {
        self.current_turn = 0;
        self.focus_stack.clear();
        self.recently_modified.clear();
        self.mention_history.clear();
        self.entity_types.clear();
    }
}
struct ReferencePattern {
regex: &'static Regex,
ref_type_fn: fn(®ex::Captures) -> ReferenceType,
}
/// Detects reference expressions in a message and resolves them to entities.
pub struct CoreferenceResolver {
/// Patterns for bare pronouns ("it", "they", ...).
pronoun_patterns: Vec<ReferencePattern>,
/// Patterns for definite noun phrases ("the file", "the function", ...).
definite_np_patterns: Vec<ReferencePattern>,
/// Patterns for demonstrative noun phrases ("that error", "this file", ...).
demonstrative_patterns: Vec<ReferencePattern>,
}
impl CoreferenceResolver {
/// Creates a resolver with the built-in pronoun, definite-noun-phrase and
/// demonstrative pattern tables.
pub fn new() -> Self {
    let pronoun_patterns = Self::build_pronoun_patterns();
    let definite_np_patterns = Self::build_definite_np_patterns();
    let demonstrative_patterns = Self::build_demonstrative_patterns();
    Self {
        pronoun_patterns,
        definite_np_patterns,
        demonstrative_patterns,
    }
}
/// Builds the patterns for bare pronouns: singular neutral first, then plural.
fn build_pronoun_patterns() -> Vec<ReferencePattern> {
    let singular_neutral = ReferencePattern {
        regex: &RE_SINGULAR_NEUTRAL,
        ref_type_fn: |_| ReferenceType::SingularNeutral,
    };
    let plural = ReferencePattern {
        regex: &RE_PLURAL,
        ref_type_fn: |_| ReferenceType::Plural,
    };
    Vec::from([singular_neutral, plural])
}
fn build_definite_np_patterns() -> Vec<ReferencePattern> {
vec![
ReferencePattern {
regex: &RE_THE_FILE,
ref_type_fn: |_| ReferenceType::DefiniteNP {
entity_type: EntityType::File,
},
},
ReferencePattern {
regex: &RE_THE_FUNCTION,
ref_type_fn: |_| ReferenceType::DefiniteNP {
entity_type: EntityType::Function,
},
},
ReferencePattern {
regex: &RE_THE_TYPE,
ref_type_fn: |_| ReferenceType::DefiniteNP {
entity_type: EntityType::Type,
},
},
ReferencePattern {
regex: &RE_THE_ERROR,
ref_type_fn: |_| ReferenceType::DefiniteNP {
entity_type: EntityType::Error,
},
},
ReferencePattern {
regex: &RE_THE_VARIABLE,
ref_type_fn: |_| ReferenceType::DefiniteNP {
entity_type: EntityType::Variable,
},
},
ReferencePattern {
regex: &RE_THE_COMMAND,
ref_type_fn: |_| ReferenceType::DefiniteNP {
entity_type: EntityType::Command,
},
},
]
}
fn build_demonstrative_patterns() -> Vec<ReferencePattern> {
vec![
ReferencePattern {
regex: &RE_DEMO_FILE,
ref_type_fn: |_| ReferenceType::Demonstrative {
entity_type: EntityType::File,
},
},
ReferencePattern {
regex: &RE_DEMO_FUNCTION,
ref_type_fn: |_| ReferenceType::Demonstrative {
entity_type: EntityType::Function,
},
},
ReferencePattern {
regex: &RE_DEMO_TYPE,
ref_type_fn: |_| ReferenceType::Demonstrative {
entity_type: EntityType::Type,
},
},
ReferencePattern {
regex: &RE_DEMO_ERROR,
ref_type_fn: |_| ReferenceType::Demonstrative {
entity_type: EntityType::Error,
},
},
]
}
pub fn detect_references(&self, message: &str) -> Vec<UnresolvedReference> {
let mut references = Vec::new();
let lower = message.to_lowercase();
for pattern in &self.demonstrative_patterns {
for cap in pattern.regex.captures_iter(&lower) {
if let Some(m) = cap.get(0) {
references.push(UnresolvedReference {
text: m.as_str().to_string(),
ref_type: (pattern.ref_type_fn)(&cap),
start: m.start(),
end: m.end(),
});
}
}
}
for pattern in &self.definite_np_patterns {
for cap in pattern.regex.captures_iter(&lower) {
if let Some(m) = cap.get(0) {
let overlaps = references
.iter()
.any(|r| r.start <= m.start() && r.end >= m.end());
if !overlaps {
references.push(UnresolvedReference {
text: m.as_str().to_string(),
ref_type: (pattern.ref_type_fn)(&cap),
start: m.start(),
end: m.end(),
});
}
}
}
}
for pattern in &self.pronoun_patterns {
for cap in pattern.regex.captures_iter(&lower) {
if let Some(m) = cap.get(0) {
let overlaps = references
.iter()
.any(|r| r.start <= m.start() && r.end >= m.end());
if !overlaps {
references.push(UnresolvedReference {
text: m.as_str().to_string(),
ref_type: (pattern.ref_type_fn)(&cap),
start: m.start(),
end: m.end(),
});
}
}
}
}
references.sort_by_key(|r| r.start);
references
}
/// Resolves each reference to its most salient compatible entity.
///
/// References for which no candidate entity exists are omitted, so the
/// result may be shorter than `references`; the relative order is kept.
/// `graph` is optional and only feeds the graph-centrality factor.
pub fn resolve(
    &self,
    references: &[UnresolvedReference],
    dialog_state: &DialogState,
    entity_store: &dyn EntityStoreT,
    graph: Option<&dyn RelationshipGraphT>,
) -> Vec<ResolvedReference> {
    references
        .iter()
        .filter_map(|reference| {
            self.resolve_single(reference, dialog_state, entity_store, graph)
        })
        .collect()
}
/// Picks the most salient entity whose type is compatible with `reference`.
///
/// Candidates come from the dialog focus stack first and then from the
/// entity store (names already taken from the focus stack are not added
/// twice). Returns `None` when there is no candidate at all. On equal
/// scores the candidate gathered first wins.
fn resolve_single(
    &self,
    reference: &UnresolvedReference,
    dialog_state: &DialogState,
    entity_store: &dyn EntityStoreT,
    graph: Option<&dyn RelationshipGraphT>,
) -> Option<ResolvedReference> {
    let allowed_types = reference.ref_type.compatible_types();

    // Entities currently in focus whose recorded type fits the reference.
    let mut candidates: Vec<(&str, &EntityType, SalienceScore)> = dialog_state
        .focus_stack
        .iter()
        .filter_map(|focused| {
            let kind = dialog_state.get_entity_type(focused)?;
            if !allowed_types.contains(kind) {
                return None;
            }
            let score = self.compute_salience(focused, kind, dialog_state, graph);
            Some((focused.as_str(), kind, score))
        })
        .collect();

    // Every entity of an allowed type known to the store, paired with that type.
    let mut stored: Vec<(String, EntityType)> = Vec::new();
    for kind in &allowed_types {
        for stored_name in entity_store.entity_names_by_type(kind) {
            stored.push((stored_name, kind.clone()));
        }
    }

    for (stored_name, kind) in &stored {
        let already_listed = candidates
            .iter()
            .any(|(known, _, _)| *known == stored_name.as_str());
        if !already_listed {
            let score = self.compute_salience(stored_name, kind, dialog_state, graph);
            candidates.push((stored_name.as_str(), kind, score));
        }
    }

    // Highest total first; the sort is stable, so ties keep gathering order.
    candidates.sort_by(|lhs, rhs| {
        rhs.2
            .total()
            .partial_cmp(&lhs.2.total())
            .unwrap_or(std::cmp::Ordering::Equal)
    });

    let (name, kind, salience) = candidates.into_iter().next()?;
    Some(ResolvedReference {
        reference: reference.clone(),
        antecedent: name.to_string(),
        entity_type: kind.clone(),
        confidence: salience.total(),
        salience,
    })
}
/// Computes the salience factors of the entity `name`.
///
/// * `recency` and `frequency` come from `dialog_state`.
/// * `graph_centrality` is the node's `importance` when a graph is given,
///   `0.0` when the graph has no node for `name`, and a neutral `0.5` when
///   no graph is available.
/// * `type_match` is always `1.0`; `_entity_type` is currently unused.
/// * `syntactic_prominence` is `1.0` for the entity at the top of the focus
///   stack, `0.5` for any other entity on the stack, and `0.0` otherwise.
fn compute_salience(
    &self,
    name: &str,
    _entity_type: &EntityType,
    dialog_state: &DialogState,
    graph: Option<&dyn RelationshipGraphT>,
) -> SalienceScore {
    let recency = dialog_state.recency_score(name);
    let frequency = dialog_state.frequency_score(name);

    let graph_centrality = match graph {
        None => 0.5,
        Some(g) => match g.get_node(name) {
            Some(node) => node.importance,
            None => 0.0,
        },
    };

    let focus_position = dialog_state
        .focus_stack
        .iter()
        .position(|entry| entry == name);
    let syntactic_prominence = match focus_position {
        Some(0) => 1.0,
        Some(_) => 0.5,
        None => 0.0,
    };

    SalienceScore {
        recency,
        frequency,
        graph_centrality,
        type_match: 1.0,
        syntactic_prominence,
    }
}
/// Returns `message` with every resolved reference replaced by
/// `[antecedent]`.
///
/// Each replacement is applied to the exact byte span
/// `reference.start..reference.end` of `message`, so a reference is never
/// confused with an earlier occurrence of the same letters (for example the
/// "it" inside "Edit"). A resolution is skipped, leaving the text untouched,
/// when its span is empty, lies outside `message`, does not fall on character
/// boundaries, or overlaps a span that was already replaced.
pub fn rewrite_with_resolutions(
    &self,
    message: &str,
    resolutions: &[ResolvedReference],
) -> String {
    if resolutions.is_empty() {
        return message.to_string();
    }

    // Work from the end of the message towards the start so that a
    // replacement never shifts the offsets of spans still to be processed.
    let mut ordered: Vec<&ResolvedReference> = resolutions.iter().collect();
    ordered.sort_by(|a, b| b.reference.start.cmp(&a.reference.start));

    let mut result = message.to_string();
    // Everything before this offset is still identical to `message`.
    let mut untouched_end = message.len();

    for resolution in ordered {
        let span = resolution.reference.start..resolution.reference.end;
        if span.is_empty() || span.end > untouched_end {
            continue;
        }
        // `get` rejects out-of-range spans and spans that split a character.
        if message.get(span.clone()).is_none() {
            continue;
        }

        let replacement = format!("[{}]", resolution.antecedent);
        result.replace_range(span.clone(), &replacement);
        untouched_end = span.start;
    }

    result
}
}
impl Default for CoreferenceResolver {
fn default() -> Self {
Self::new()
}
}
// Unit tests for reference detection, dialog-state scoring, resolution and
// message rewriting.
#[cfg(test)]
mod tests {
use super::*;
use brainwires_knowledge::knowledge::EntityStore;
// A bare pronoun is detected and classified as singular neutral.
#[test]
fn test_detect_pronouns() {
let resolver = CoreferenceResolver::new();
let refs = resolver.detect_references("Fix it and run the tests");
assert!(!refs.is_empty());
assert!(refs.iter().any(|r| r.text == "it"));
assert!(refs[0].ref_type == ReferenceType::SingularNeutral);
}
// "the file" is detected as a definite noun phrase typed as File.
#[test]
fn test_detect_definite_np() {
let resolver = CoreferenceResolver::new();
let refs = resolver.detect_references("Update the file with the new logic");
assert!(refs.iter().any(|r| r.text == "the file"));
assert!(refs.iter().any(|r| matches!(
&r.ref_type,
ReferenceType::DefiniteNP { entity_type } if *entity_type == EntityType::File
)));
}
// "that error" is detected as a demonstrative typed as Error.
#[test]
fn test_detect_demonstrative() {
let resolver = CoreferenceResolver::new();
let refs = resolver.detect_references("Fix that error in the code");
assert!(refs.iter().any(|r| r.text == "that error"));
assert!(refs.iter().any(|r| matches!(
&r.ref_type,
ReferenceType::Demonstrative { entity_type } if *entity_type == EntityType::Error
)));
}
// The most recently mentioned entity sits on top of the focus stack and
// has the higher recency score.
#[test]
fn test_dialog_state_mention() {
let mut state = DialogState::new();
state.mention_entity("main.rs", EntityType::File);
state.next_turn();
state.mention_entity("config.toml", EntityType::File);
assert_eq!(state.focus_stack[0], "config.toml");
assert_eq!(state.focus_stack[1], "main.rs");
assert!(state.recency_score("config.toml") > state.recency_score("main.rs"));
}
// An entity mentioned twice outscores one mentioned once on frequency.
#[test]
fn test_dialog_state_frequency() {
let mut state = DialogState::new();
state.mention_entity("main.rs", EntityType::File);
state.next_turn();
state.mention_entity("main.rs", EntityType::File);
state.next_turn();
state.mention_entity("config.toml", EntityType::File);
assert!(state.frequency_score("main.rs") > state.frequency_score("config.toml"));
}
// With a single entity in focus, "it" resolves to that entity.
#[test]
fn test_resolve_pronoun() {
let resolver = CoreferenceResolver::new();
let mut state = DialogState::new();
let entity_store = EntityStore::new();
state.mention_entity("src/main.rs", EntityType::File);
state.next_turn();
let refs = resolver.detect_references("Fix it");
let resolved = resolver.resolve(&refs, &state, &entity_store, None);
assert_eq!(resolved.len(), 1);
assert_eq!(resolved[0].antecedent, "src/main.rs");
}
// "the function" resolves only to an entity of type Function, even though
// a File is also in focus.
#[test]
fn test_resolve_type_constrained() {
let resolver = CoreferenceResolver::new();
let mut state = DialogState::new();
let entity_store = EntityStore::new();
state.mention_entity("main.rs", EntityType::File);
state.mention_entity("process_data", EntityType::Function);
state.next_turn();
let refs = resolver.detect_references("Update the function");
let resolved = resolver.resolve(&refs, &state, &entity_store, None);
assert_eq!(resolved.len(), 1);
assert_eq!(resolved[0].antecedent, "process_data");
}
// The resolved pronoun is replaced in the message by "[antecedent]".
#[test]
fn test_rewrite_with_resolutions() {
let resolver = CoreferenceResolver::new();
let mut state = DialogState::new();
let entity_store = EntityStore::new();
state.mention_entity("main.rs", EntityType::File);
state.next_turn();
let refs = resolver.detect_references("Fix it and test");
let resolved = resolver.resolve(&refs, &state, &entity_store, None);
let rewritten = resolver.rewrite_with_resolutions("Fix it and test", &resolved);
assert_eq!(rewritten, "Fix [main.rs] and test");
}
// The total is the weighted sum:
// 1.0*0.35 + 0.5*0.15 + 0.8*0.20 + 1.0*0.20 + 0.5*0.10 = 0.835.
#[test]
fn test_salience_score_total() {
let score = SalienceScore {
recency: 1.0,
frequency: 0.5,
graph_centrality: 0.8,
type_match: 1.0,
syntactic_prominence: 0.5,
};
assert!((score.total() - 0.835).abs() < 0.001);
}
// "the project" is not one of the recognised noun phrases.
#[test]
fn test_empty_references() {
let resolver = CoreferenceResolver::new();
let refs = resolver.detect_references("Build the project using cargo");
assert!(refs.is_empty() || !refs.iter().any(|r| r.text == "the project"));
}
// A pronoun and a definite noun phrase in one message are both detected.
#[test]
fn test_multiple_references() {
let resolver = CoreferenceResolver::new();
let refs = resolver.detect_references("Fix it and update the file");
assert!(refs.len() >= 2);
let texts: Vec<_> = refs.iter().map(|r| r.text.as_str()).collect();
assert!(texts.contains(&"it"));
assert!(texts.contains(&"the file"));
}
}