agentic_evolve_core/matching/
context.rs1use crate::types::error::EvolveResult;
4use crate::types::match_result::{MatchContext, MatchResult, MatchScore};
5use crate::types::pattern::{FunctionSignature, Pattern};
6
/// Matches patterns against the context they are going to be used in (domain,
/// imports and surrounding code) rather than against a function signature.
///
/// The matcher carries no state, so it is trivially cheap to copy and share.
#[derive(Debug, Default, Clone, Copy)]
pub struct ContextMatcher;
10
11impl ContextMatcher {
12 pub fn new() -> Self {
13 Self
14 }
15
16 pub fn find_matches(
17 &self,
18 _signature: &FunctionSignature,
19 patterns: &[&Pattern],
20 context: &MatchContext,
21 limit: usize,
22 ) -> EvolveResult<Vec<MatchResult>> {
23 let mut results: Vec<MatchResult> = patterns
24 .iter()
25 .map(|p| {
26 let score = self.score_context(p, context);
27 MatchResult {
28 pattern_id: p.id.clone(),
29 pattern: (*p).clone(),
30 score: MatchScore::from_single(score),
31 suggested_bindings: std::collections::HashMap::new(),
32 }
33 })
34 .filter(|r| r.score.combined > 0.0)
35 .collect();
36
37 results.sort_by(|a, b| {
38 b.score
39 .combined
40 .partial_cmp(&a.score.combined)
41 .unwrap_or(std::cmp::Ordering::Equal)
42 });
43 results.truncate(limit);
44 Ok(results)
45 }
46
47 pub fn score_context(&self, pattern: &Pattern, context: &MatchContext) -> f64 {
48 let mut score = 0.0;
49 let mut factors = 0;
50
51 if let Some(domain) = &context.domain {
53 if pattern.domain.to_lowercase() == domain.to_lowercase() {
54 score += 1.0;
55 } else if pattern
56 .domain
57 .to_lowercase()
58 .contains(&domain.to_lowercase())
59 {
60 score += 0.5;
61 }
62 factors += 1;
63 }
64
65 if !context.imports.is_empty() {
67 let template_lower = pattern.template.to_lowercase();
68 let import_matches = context
69 .imports
70 .iter()
71 .filter(|imp| template_lower.contains(&imp.to_lowercase()))
72 .count();
73 if !context.imports.is_empty() {
74 score += import_matches as f64 / context.imports.len() as f64;
75 }
76 factors += 1;
77 }
78
79 if let Some(surrounding) = &context.surrounding_code {
81 let words: std::collections::HashSet<&str> = surrounding.split_whitespace().collect();
82 let template_words: std::collections::HashSet<&str> =
83 pattern.template.split_whitespace().collect();
84 let overlap = words.intersection(&template_words).count();
85 let total = words.len().max(1);
86 score += overlap as f64 / total as f64;
87 factors += 1;
88 }
89
90 if factors > 0 {
91 score / factors as f64
92 } else {
93 0.3 }
95 }
96}