Skip to main content

heliosdb_proxy/rewriter/
matcher.rs

1//! Rule Matcher
2//!
3//! Efficient matching of queries against rewrite rules.
4
5use super::rules::{RewriteRule, QueryPattern, AstPattern};
6use super::parser::ParsedQuery;
7use std::collections::HashMap;
8use regex::Regex;
9
10/// Rule matcher for efficient query matching
11pub struct RuleMatcher {
12    /// Fingerprint index for fast lookup
13    fingerprint_index: HashMap<u64, Vec<usize>>,
14
15    /// Compiled regex patterns
16    regex_patterns: Vec<(Regex, usize)>,
17
18    /// Table index
19    table_index: HashMap<String, Vec<usize>>,
20
21    /// Rules that match all queries
22    all_rules: Vec<usize>,
23
24    /// AST pattern rules
25    ast_rules: Vec<usize>,
26}
27
28impl RuleMatcher {
29    /// Create a new matcher from rules
30    pub fn new(rules: &[RewriteRule]) -> Self {
31        let mut fingerprint_index: HashMap<u64, Vec<usize>> = HashMap::new();
32        let mut regex_patterns: Vec<(Regex, usize)> = Vec::new();
33        let mut table_index: HashMap<String, Vec<usize>> = HashMap::new();
34        let mut all_rules: Vec<usize> = Vec::new();
35        let mut ast_rules: Vec<usize> = Vec::new();
36
37        for (idx, rule) in rules.iter().enumerate() {
38            if !rule.enabled {
39                continue;
40            }
41
42            match &rule.pattern {
43                QueryPattern::Fingerprint(fp) => {
44                    fingerprint_index.entry(*fp).or_default().push(idx);
45                }
46                QueryPattern::Regex(pattern) => {
47                    if let Ok(re) = Regex::new(pattern) {
48                        regex_patterns.push((re, idx));
49                    }
50                }
51                QueryPattern::Table(table) => {
52                    table_index.entry(table.clone()).or_default().push(idx);
53                }
54                QueryPattern::TableAny(tables) => {
55                    for table in tables {
56                        table_index.entry(table.clone()).or_default().push(idx);
57                    }
58                }
59                QueryPattern::Ast(_) => {
60                    ast_rules.push(idx);
61                }
62                QueryPattern::All => {
63                    all_rules.push(idx);
64                }
65            }
66        }
67
68        Self {
69            fingerprint_index,
70            regex_patterns,
71            table_index,
72            all_rules,
73            ast_rules,
74        }
75    }
76
77    /// Match a query against rules
78    pub fn match_query<'a>(&self, parsed: &ParsedQuery, rules: &'a [RewriteRule]) -> Vec<&'a RewriteRule> {
79        let mut matched_indices: Vec<usize> = Vec::new();
80
81        // Check fingerprint matches (fast path)
82        if let Some(indices) = self.fingerprint_index.get(&parsed.fingerprint()) {
83            matched_indices.extend(indices);
84        }
85
86        // Check regex matches
87        for (regex, idx) in &self.regex_patterns {
88            if regex.is_match(&parsed.original) {
89                matched_indices.push(*idx);
90            }
91        }
92
93        // Check table matches
94        for table in &parsed.tables {
95            if let Some(indices) = self.table_index.get(table) {
96                matched_indices.extend(indices);
97            }
98        }
99
100        // Check AST pattern matches
101        for &idx in &self.ast_rules {
102            if let Some(rule) = rules.get(idx) {
103                if self.matches_ast_pattern(&rule.pattern, parsed) {
104                    matched_indices.push(idx);
105                }
106            }
107        }
108
109        // Add all-matching rules
110        matched_indices.extend(&self.all_rules);
111
112        // Deduplicate and sort by priority
113        matched_indices.sort_unstable();
114        matched_indices.dedup();
115
116        let mut matched: Vec<&RewriteRule> = matched_indices
117            .into_iter()
118            .filter_map(|idx| rules.get(idx))
119            .filter(|r| r.enabled)
120            .collect();
121
122        // Sort by priority (highest first)
123        matched.sort_by_key(|r| -r.priority);
124
125        matched
126    }
127
128    /// Check if query matches AST pattern
129    fn matches_ast_pattern(&self, pattern: &QueryPattern, parsed: &ParsedQuery) -> bool {
130        match pattern {
131            QueryPattern::Ast(ast_pattern) => self.matches_ast(ast_pattern, parsed),
132            _ => false,
133        }
134    }
135
136    /// Match AST pattern against parsed query
137    fn matches_ast(&self, pattern: &AstPattern, parsed: &ParsedQuery) -> bool {
138        match pattern {
139            AstPattern::SelectStar => parsed.has_select_star,
140            AstPattern::SelectFrom { table } => {
141                parsed.is_select && parsed.tables.contains(table)
142            }
143            AstPattern::NoLimit => !parsed.has_limit,
144            AstPattern::NoWhere => !parsed.has_where,
145            AstPattern::Insert => parsed.is_insert,
146            AstPattern::Update => parsed.is_update,
147            AstPattern::Delete => parsed.is_delete,
148            AstPattern::Ddl => parsed.is_ddl,
149            AstPattern::NPlusOne { table } => {
150                // N+1 detection: SELECT ... WHERE id = $1 in loop
151                // Simplified: just check if table is accessed
152                parsed.tables.contains(table) && !parsed.has_limit
153            }
154            AstPattern::FullTableScan => {
155                parsed.is_select && !parsed.has_where
156            }
157            AstPattern::And(patterns) => {
158                patterns.iter().all(|p| self.matches_ast(p, parsed))
159            }
160            AstPattern::Or(patterns) => {
161                patterns.iter().any(|p| self.matches_ast(p, parsed))
162            }
163        }
164    }
165
166    /// Get statistics about the matcher
167    pub fn stats(&self) -> MatcherStats {
168        MatcherStats {
169            fingerprint_rules: self.fingerprint_index.values().map(|v| v.len()).sum(),
170            regex_rules: self.regex_patterns.len(),
171            table_rules: self.table_index.values().map(|v| v.len()).sum(),
172            all_rules: self.all_rules.len(),
173            ast_rules: self.ast_rules.len(),
174        }
175    }
176}
177
/// Summary of a match: which rules matched, plus the query's fingerprint and
/// the tables it references.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct MatchResult {
    /// IDs of the rules that matched.
    pub rule_ids: Vec<String>,

    /// Fingerprint of the query.
    pub fingerprint: u64,

    /// Tables referenced by the query.
    pub tables: Vec<String>,
}
190
/// Per-category rule counts for a matcher.
///
/// The default value has every counter at zero.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct MatcherStats {
    /// Number of fingerprint-indexed rules
    pub fingerprint_rules: usize,

    /// Number of regex rules
    pub regex_rules: usize,

    /// Number of table-indexed rules
    pub table_rules: usize,

    /// Number of all-matching rules
    pub all_rules: usize,

    /// Number of AST pattern rules
    pub ast_rules: usize,
}

impl MatcherStats {
    /// Sum of all five per-category counters.
    pub fn total(&self) -> usize {
        [
            self.fingerprint_rules,
            self.regex_rules,
            self.table_rules,
            self.all_rules,
            self.ast_rules,
        ]
        .iter()
        .sum()
    }
}
216
#[cfg(test)]
mod tests {
    use super::*;
    use super::super::rules::{Transformation, RewriteRule};

    /// One rule per pattern kind, each with a distinct priority.
    fn test_rules() -> Vec<RewriteRule> {
        vec![
            RewriteRule::new("fp_match")
                .pattern(QueryPattern::Fingerprint(12345))
                .transform(Transformation::NoOp)
                .priority(100)
                .build(),
            RewriteRule::new("regex_match")
                .pattern(QueryPattern::Regex(r"SELECT .* FROM users".to_string()))
                .transform(Transformation::NoOp)
                .priority(50)
                .build(),
            RewriteRule::new("table_match")
                .pattern(QueryPattern::Table("orders".to_string()))
                .transform(Transformation::NoOp)
                .priority(75)
                .build(),
            RewriteRule::new("all_match")
                .pattern(QueryPattern::All)
                .transform(Transformation::AddLimit(1000))
                .priority(10)
                .build(),
            RewriteRule::new("ast_match")
                .pattern(QueryPattern::Ast(AstPattern::SelectStar))
                .transform(Transformation::NoOp)
                .priority(60)
                .build(),
        ]
    }

    /// Build a SELECT-shaped `ParsedQuery` with no LIMIT clause.
    fn select_query(
        original: &str,
        normalized: &str,
        tables: &[&str],
        has_select_star: bool,
        has_where: bool,
    ) -> ParsedQuery {
        ParsedQuery {
            original: original.to_string(),
            normalized: normalized.to_string(),
            tables: tables.iter().map(|table| table.to_string()).collect(),
            has_select_star,
            has_limit: false,
            has_where,
            is_select: true,
            is_insert: false,
            is_update: false,
            is_delete: false,
            is_ddl: false,
        }
    }

    /// `SELECT * FROM orders`: hits the table rule, the SELECT * AST rule and
    /// the match-everything rule.
    fn select_star_from_orders() -> ParsedQuery {
        select_query("SELECT * FROM orders", "SELECT * FROM orders", &["orders"], true, false)
    }

    #[test]
    fn test_matcher_creation() {
        let rules = test_rules();
        let matcher = RuleMatcher::new(&rules);

        let stats = matcher.stats();
        assert_eq!(stats.fingerprint_rules, 1);
        assert_eq!(stats.regex_rules, 1);
        assert_eq!(stats.table_rules, 1);
        assert_eq!(stats.all_rules, 1);
        assert_eq!(stats.ast_rules, 1);
    }

    #[test]
    fn test_matcher_all_rules() {
        let rules = test_rules();
        let matcher = RuleMatcher::new(&rules);

        let parsed = select_query("SELECT 1", "SELECT ?", &[], false, false);

        let matched = matcher.match_query(&parsed, &rules);
        assert!(matched.iter().any(|r| r.id == "all_match"));
    }

    #[test]
    fn test_matcher_regex() {
        let rules = test_rules();
        let matcher = RuleMatcher::new(&rules);

        let parsed = select_query(
            "SELECT id, name FROM users WHERE id = 1",
            "SELECT id, name FROM users WHERE id = ?",
            &["users"],
            false,
            true,
        );

        let matched = matcher.match_query(&parsed, &rules);
        assert!(matched.iter().any(|r| r.id == "regex_match"));
    }

    #[test]
    fn test_matcher_table() {
        let rules = test_rules();
        let matcher = RuleMatcher::new(&rules);

        let parsed = select_star_from_orders();

        let matched = matcher.match_query(&parsed, &rules);
        assert!(matched.iter().any(|r| r.id == "table_match"));
        assert!(matched.iter().any(|r| r.id == "ast_match")); // SELECT *
    }

    #[test]
    fn test_matcher_priority_ordering() {
        let rules = test_rules();
        let matcher = RuleMatcher::new(&rules);

        let parsed = select_star_from_orders();

        let matched = matcher.match_query(&parsed, &rules);
        // Expected at least: table_match (75), ast_match (60), all_match (10).
        assert!(matched.len() >= 3);
        // Every adjacent pair must be in non-increasing priority order,
        // not just the first two entries.
        assert!(matched.windows(2).all(|pair| pair[0].priority >= pair[1].priority));
    }

    #[test]
    fn test_matcher_skips_disabled_rules() {
        let mut rules = test_rules();
        rules
            .iter_mut()
            .find(|r| r.id == "all_match")
            .expect("fixture contains the all_match rule")
            .enabled = false;

        let matcher = RuleMatcher::new(&rules);
        assert_eq!(matcher.stats().all_rules, 0);

        let parsed = select_query("SELECT 1", "SELECT ?", &[], false, false);
        let matched = matcher.match_query(&parsed, &rules);
        assert!(!matched.iter().any(|r| r.id == "all_match"));
    }
}