Skip to main content

agentic_evolve_core/matching/
context.rs

1//! ContextMatcher — matches based on surrounding code context.
2
3use crate::types::error::EvolveResult;
4use crate::types::match_result::{MatchContext, MatchResult, MatchScore};
5use crate::types::pattern::{FunctionSignature, Pattern};
6
/// Stateless matcher that ranks patterns by how well they fit the code
/// context surrounding the call site (domain, imports, nearby source text).
///
/// The type carries no data, so every instance is interchangeable.
#[derive(Default, Debug)]
pub struct ContextMatcher;
10
11impl ContextMatcher {
12    pub fn new() -> Self {
13        Self
14    }
15
16    pub fn find_matches(
17        &self,
18        _signature: &FunctionSignature,
19        patterns: &[&Pattern],
20        context: &MatchContext,
21        limit: usize,
22    ) -> EvolveResult<Vec<MatchResult>> {
23        let mut results: Vec<MatchResult> = patterns
24            .iter()
25            .map(|p| {
26                let score = self.score_context(p, context);
27                MatchResult {
28                    pattern_id: p.id.clone(),
29                    pattern: (*p).clone(),
30                    score: MatchScore::from_single(score),
31                    suggested_bindings: std::collections::HashMap::new(),
32                }
33            })
34            .filter(|r| r.score.combined > 0.0)
35            .collect();
36
37        results.sort_by(|a, b| {
38            b.score
39                .combined
40                .partial_cmp(&a.score.combined)
41                .unwrap_or(std::cmp::Ordering::Equal)
42        });
43        results.truncate(limit);
44        Ok(results)
45    }
46
47    pub fn score_context(&self, pattern: &Pattern, context: &MatchContext) -> f64 {
48        let mut score = 0.0;
49        let mut factors = 0;
50
51        // Domain match
52        if let Some(domain) = &context.domain {
53            if pattern.domain.to_lowercase() == domain.to_lowercase() {
54                score += 1.0;
55            } else if pattern
56                .domain
57                .to_lowercase()
58                .contains(&domain.to_lowercase())
59            {
60                score += 0.5;
61            }
62            factors += 1;
63        }
64
65        // Import overlap
66        if !context.imports.is_empty() {
67            let template_lower = pattern.template.to_lowercase();
68            let import_matches = context
69                .imports
70                .iter()
71                .filter(|imp| template_lower.contains(&imp.to_lowercase()))
72                .count();
73            if !context.imports.is_empty() {
74                score += import_matches as f64 / context.imports.len() as f64;
75            }
76            factors += 1;
77        }
78
79        // Surrounding code similarity
80        if let Some(surrounding) = &context.surrounding_code {
81            let words: std::collections::HashSet<&str> = surrounding.split_whitespace().collect();
82            let template_words: std::collections::HashSet<&str> =
83                pattern.template.split_whitespace().collect();
84            let overlap = words.intersection(&template_words).count();
85            let total = words.len().max(1);
86            score += overlap as f64 / total as f64;
87            factors += 1;
88        }
89
90        if factors > 0 {
91            score / factors as f64
92        } else {
93            0.3 // Default context score when no context provided
94        }
95    }
96}