Skip to main content

heliosdb_proxy/rewriter/
matcher.rs

1//! Rule Matcher
2//!
3//! Efficient matching of queries against rewrite rules.
4
5use super::parser::ParsedQuery;
6use super::rules::{AstPattern, QueryPattern, RewriteRule};
7use regex::Regex;
8use std::collections::HashMap;
9
10/// Rule matcher for efficient query matching
11pub struct RuleMatcher {
12    /// Fingerprint index for fast lookup
13    fingerprint_index: HashMap<u64, Vec<usize>>,
14
15    /// Compiled regex patterns
16    regex_patterns: Vec<(Regex, usize)>,
17
18    /// Table index
19    table_index: HashMap<String, Vec<usize>>,
20
21    /// Rules that match all queries
22    all_rules: Vec<usize>,
23
24    /// AST pattern rules
25    ast_rules: Vec<usize>,
26}
27
28impl RuleMatcher {
29    /// Create a new matcher from rules
30    pub fn new(rules: &[RewriteRule]) -> Self {
31        let mut fingerprint_index: HashMap<u64, Vec<usize>> = HashMap::new();
32        let mut regex_patterns: Vec<(Regex, usize)> = Vec::new();
33        let mut table_index: HashMap<String, Vec<usize>> = HashMap::new();
34        let mut all_rules: Vec<usize> = Vec::new();
35        let mut ast_rules: Vec<usize> = Vec::new();
36
37        for (idx, rule) in rules.iter().enumerate() {
38            if !rule.enabled {
39                continue;
40            }
41
42            match &rule.pattern {
43                QueryPattern::Fingerprint(fp) => {
44                    fingerprint_index.entry(*fp).or_default().push(idx);
45                }
46                QueryPattern::Regex(pattern) => {
47                    if let Ok(re) = Regex::new(pattern) {
48                        regex_patterns.push((re, idx));
49                    }
50                }
51                QueryPattern::Table(table) => {
52                    table_index.entry(table.clone()).or_default().push(idx);
53                }
54                QueryPattern::TableAny(tables) => {
55                    for table in tables {
56                        table_index.entry(table.clone()).or_default().push(idx);
57                    }
58                }
59                QueryPattern::Ast(_) => {
60                    ast_rules.push(idx);
61                }
62                QueryPattern::All => {
63                    all_rules.push(idx);
64                }
65            }
66        }
67
68        Self {
69            fingerprint_index,
70            regex_patterns,
71            table_index,
72            all_rules,
73            ast_rules,
74        }
75    }
76
77    /// Match a query against rules
78    pub fn match_query<'a>(
79        &self,
80        parsed: &ParsedQuery,
81        rules: &'a [RewriteRule],
82    ) -> Vec<&'a RewriteRule> {
83        let mut matched_indices: Vec<usize> = Vec::new();
84
85        // Check fingerprint matches (fast path)
86        if let Some(indices) = self.fingerprint_index.get(&parsed.fingerprint()) {
87            matched_indices.extend(indices);
88        }
89
90        // Check regex matches
91        for (regex, idx) in &self.regex_patterns {
92            if regex.is_match(&parsed.original) {
93                matched_indices.push(*idx);
94            }
95        }
96
97        // Check table matches
98        for table in &parsed.tables {
99            if let Some(indices) = self.table_index.get(table) {
100                matched_indices.extend(indices);
101            }
102        }
103
104        // Check AST pattern matches
105        for &idx in &self.ast_rules {
106            if let Some(rule) = rules.get(idx) {
107                if self.matches_ast_pattern(&rule.pattern, parsed) {
108                    matched_indices.push(idx);
109                }
110            }
111        }
112
113        // Add all-matching rules
114        matched_indices.extend(&self.all_rules);
115
116        // Deduplicate and sort by priority
117        matched_indices.sort_unstable();
118        matched_indices.dedup();
119
120        let mut matched: Vec<&RewriteRule> = matched_indices
121            .into_iter()
122            .filter_map(|idx| rules.get(idx))
123            .filter(|r| r.enabled)
124            .collect();
125
126        // Sort by priority (highest first)
127        matched.sort_by_key(|r| -r.priority);
128
129        matched
130    }
131
132    /// Check if query matches AST pattern
133    fn matches_ast_pattern(&self, pattern: &QueryPattern, parsed: &ParsedQuery) -> bool {
134        match pattern {
135            QueryPattern::Ast(ast_pattern) => self.matches_ast(ast_pattern, parsed),
136            _ => false,
137        }
138    }
139
140    /// Match AST pattern against parsed query
141    fn matches_ast(&self, pattern: &AstPattern, parsed: &ParsedQuery) -> bool {
142        match pattern {
143            AstPattern::SelectStar => parsed.has_select_star,
144            AstPattern::SelectFrom { table } => parsed.is_select && parsed.tables.contains(table),
145            AstPattern::NoLimit => !parsed.has_limit,
146            AstPattern::NoWhere => !parsed.has_where,
147            AstPattern::Insert => parsed.is_insert,
148            AstPattern::Update => parsed.is_update,
149            AstPattern::Delete => parsed.is_delete,
150            AstPattern::Ddl => parsed.is_ddl,
151            AstPattern::NPlusOne { table } => {
152                // N+1 detection: SELECT ... WHERE id = $1 in loop
153                // Simplified: just check if table is accessed
154                parsed.tables.contains(table) && !parsed.has_limit
155            }
156            AstPattern::FullTableScan => parsed.is_select && !parsed.has_where,
157            AstPattern::And(patterns) => patterns.iter().all(|p| self.matches_ast(p, parsed)),
158            AstPattern::Or(patterns) => patterns.iter().any(|p| self.matches_ast(p, parsed)),
159        }
160    }
161
162    /// Get statistics about the matcher
163    pub fn stats(&self) -> MatcherStats {
164        MatcherStats {
165            fingerprint_rules: self.fingerprint_index.values().map(|v| v.len()).sum(),
166            regex_rules: self.regex_patterns.len(),
167            table_rules: self.table_index.values().map(|v| v.len()).sum(),
168            all_rules: self.all_rules.len(),
169            ast_rules: self.ast_rules.len(),
170        }
171    }
172}
173
/// Summary of a match: which rules matched, plus the query's fingerprint and tables.
#[derive(Debug, Clone)]
pub struct MatchResult {
    /// IDs of the rules that matched
    pub rule_ids: Vec<String>,

    /// Fingerprint of the query
    pub fingerprint: u64,

    /// Names of the tables the query references
    pub tables: Vec<String>,
}
186
/// Per-index entry counts reported by a matcher.
#[derive(Debug, Clone)]
pub struct MatcherStats {
    /// Entries in the fingerprint index
    pub fingerprint_rules: usize,

    /// Compiled regex patterns
    pub regex_rules: usize,

    /// Entries in the table index
    pub table_rules: usize,

    /// Rules that match every query
    pub all_rules: usize,

    /// Rules that carry an AST pattern
    pub ast_rules: usize,
}

impl MatcherStats {
    /// Sum of all five counters.
    pub fn total(&self) -> usize {
        [
            self.fingerprint_rules,
            self.regex_rules,
            self.table_rules,
            self.all_rules,
            self.ast_rules,
        ]
        .iter()
        .sum()
    }
}
216
#[cfg(test)]
mod tests {
    use super::super::rules::{RewriteRule, Transformation};
    use super::*;

    /// One rule per pattern kind, each with a distinct priority.
    fn test_rules() -> Vec<RewriteRule> {
        vec![
            RewriteRule::build("fp_match")
                .pattern(QueryPattern::Fingerprint(12345))
                .transform(Transformation::NoOp)
                .priority(100)
                .build(),
            RewriteRule::build("regex_match")
                .pattern(QueryPattern::Regex(r"SELECT .* FROM users".to_string()))
                .transform(Transformation::NoOp)
                .priority(50)
                .build(),
            RewriteRule::build("table_match")
                .pattern(QueryPattern::Table("orders".to_string()))
                .transform(Transformation::NoOp)
                .priority(75)
                .build(),
            RewriteRule::build("all_match")
                .pattern(QueryPattern::All)
                .transform(Transformation::AddLimit(1000))
                .priority(10)
                .build(),
            RewriteRule::build("ast_match")
                .pattern(QueryPattern::Ast(AstPattern::SelectStar))
                .transform(Transformation::NoOp)
                .priority(60)
                .build(),
        ]
    }

    /// Build a `ParsedQuery` describing a SELECT without LIMIT; only the
    /// fields that vary between tests are parameters.
    fn select_query(
        original: &str,
        normalized: &str,
        tables: &[&str],
        has_select_star: bool,
        has_where: bool,
    ) -> ParsedQuery {
        ParsedQuery {
            original: original.to_string(),
            normalized: normalized.to_string(),
            tables: tables.iter().map(|t| t.to_string()).collect(),
            has_select_star,
            has_limit: false,
            has_where,
            is_select: true,
            is_insert: false,
            is_update: false,
            is_delete: false,
            is_ddl: false,
        }
    }

    /// `SELECT * FROM orders`, shared by the table and priority tests.
    fn select_star_from_orders() -> ParsedQuery {
        select_query(
            "SELECT * FROM orders",
            "SELECT * FROM orders",
            &["orders"],
            true,
            false,
        )
    }

    #[test]
    fn test_matcher_creation() {
        let rules = test_rules();
        let matcher = RuleMatcher::new(&rules);

        let stats = matcher.stats();
        assert_eq!(stats.fingerprint_rules, 1);
        assert_eq!(stats.regex_rules, 1);
        assert_eq!(stats.table_rules, 1);
        assert_eq!(stats.all_rules, 1);
        assert_eq!(stats.ast_rules, 1);
    }

    #[test]
    fn test_matcher_all_rules() {
        let rules = test_rules();
        let matcher = RuleMatcher::new(&rules);

        let parsed = select_query("SELECT 1", "SELECT ?", &[], false, false);

        let matched = matcher.match_query(&parsed, &rules);
        assert!(matched.iter().any(|r| r.id == "all_match"));
    }

    #[test]
    fn test_matcher_regex() {
        let rules = test_rules();
        let matcher = RuleMatcher::new(&rules);

        let parsed = select_query(
            "SELECT id, name FROM users WHERE id = 1",
            "SELECT id, name FROM users WHERE id = ?",
            &["users"],
            false,
            true,
        );

        let matched = matcher.match_query(&parsed, &rules);
        assert!(matched.iter().any(|r| r.id == "regex_match"));
    }

    #[test]
    fn test_matcher_table() {
        let rules = test_rules();
        let matcher = RuleMatcher::new(&rules);

        let parsed = select_star_from_orders();

        let matched = matcher.match_query(&parsed, &rules);
        assert!(matched.iter().any(|r| r.id == "table_match"));
        assert!(matched.iter().any(|r| r.id == "ast_match")); // SELECT *
    }

    #[test]
    fn test_matcher_priority_ordering() {
        let rules = test_rules();
        let matcher = RuleMatcher::new(&rules);

        let parsed = select_star_from_orders();

        let matched = matcher.match_query(&parsed, &rules);
        assert!(matched.len() >= 3);

        // Every adjacent pair must be in non-increasing priority order,
        // not just the first two results.
        assert!(matched
            .windows(2)
            .all(|pair| pair[0].priority >= pair[1].priority));

        // Expected relative order: table_match (75), ast_match (60), all_match (10).
        let position = |id: &str| matched.iter().position(|r| r.id == id);
        let table = position("table_match").expect("table rule should match");
        let ast = position("ast_match").expect("AST rule should match");
        let all = position("all_match").expect("catch-all rule should match");
        assert!(table < ast);
        assert!(ast < all);
    }
}