1use super::rules::{RewriteRule, QueryPattern, AstPattern};
6use super::parser::ParsedQuery;
7use std::collections::HashMap;
8use regex::Regex;
9
10pub struct RuleMatcher {
12 fingerprint_index: HashMap<u64, Vec<usize>>,
14
15 regex_patterns: Vec<(Regex, usize)>,
17
18 table_index: HashMap<String, Vec<usize>>,
20
21 all_rules: Vec<usize>,
23
24 ast_rules: Vec<usize>,
26}
27
28impl RuleMatcher {
29 pub fn new(rules: &[RewriteRule]) -> Self {
31 let mut fingerprint_index: HashMap<u64, Vec<usize>> = HashMap::new();
32 let mut regex_patterns: Vec<(Regex, usize)> = Vec::new();
33 let mut table_index: HashMap<String, Vec<usize>> = HashMap::new();
34 let mut all_rules: Vec<usize> = Vec::new();
35 let mut ast_rules: Vec<usize> = Vec::new();
36
37 for (idx, rule) in rules.iter().enumerate() {
38 if !rule.enabled {
39 continue;
40 }
41
42 match &rule.pattern {
43 QueryPattern::Fingerprint(fp) => {
44 fingerprint_index.entry(*fp).or_default().push(idx);
45 }
46 QueryPattern::Regex(pattern) => {
47 if let Ok(re) = Regex::new(pattern) {
48 regex_patterns.push((re, idx));
49 }
50 }
51 QueryPattern::Table(table) => {
52 table_index.entry(table.clone()).or_default().push(idx);
53 }
54 QueryPattern::TableAny(tables) => {
55 for table in tables {
56 table_index.entry(table.clone()).or_default().push(idx);
57 }
58 }
59 QueryPattern::Ast(_) => {
60 ast_rules.push(idx);
61 }
62 QueryPattern::All => {
63 all_rules.push(idx);
64 }
65 }
66 }
67
68 Self {
69 fingerprint_index,
70 regex_patterns,
71 table_index,
72 all_rules,
73 ast_rules,
74 }
75 }
76
77 pub fn match_query<'a>(&self, parsed: &ParsedQuery, rules: &'a [RewriteRule]) -> Vec<&'a RewriteRule> {
79 let mut matched_indices: Vec<usize> = Vec::new();
80
81 if let Some(indices) = self.fingerprint_index.get(&parsed.fingerprint()) {
83 matched_indices.extend(indices);
84 }
85
86 for (regex, idx) in &self.regex_patterns {
88 if regex.is_match(&parsed.original) {
89 matched_indices.push(*idx);
90 }
91 }
92
93 for table in &parsed.tables {
95 if let Some(indices) = self.table_index.get(table) {
96 matched_indices.extend(indices);
97 }
98 }
99
100 for &idx in &self.ast_rules {
102 if let Some(rule) = rules.get(idx) {
103 if self.matches_ast_pattern(&rule.pattern, parsed) {
104 matched_indices.push(idx);
105 }
106 }
107 }
108
109 matched_indices.extend(&self.all_rules);
111
112 matched_indices.sort_unstable();
114 matched_indices.dedup();
115
116 let mut matched: Vec<&RewriteRule> = matched_indices
117 .into_iter()
118 .filter_map(|idx| rules.get(idx))
119 .filter(|r| r.enabled)
120 .collect();
121
122 matched.sort_by_key(|r| -r.priority);
124
125 matched
126 }
127
128 fn matches_ast_pattern(&self, pattern: &QueryPattern, parsed: &ParsedQuery) -> bool {
130 match pattern {
131 QueryPattern::Ast(ast_pattern) => self.matches_ast(ast_pattern, parsed),
132 _ => false,
133 }
134 }
135
136 fn matches_ast(&self, pattern: &AstPattern, parsed: &ParsedQuery) -> bool {
138 match pattern {
139 AstPattern::SelectStar => parsed.has_select_star,
140 AstPattern::SelectFrom { table } => {
141 parsed.is_select && parsed.tables.contains(table)
142 }
143 AstPattern::NoLimit => !parsed.has_limit,
144 AstPattern::NoWhere => !parsed.has_where,
145 AstPattern::Insert => parsed.is_insert,
146 AstPattern::Update => parsed.is_update,
147 AstPattern::Delete => parsed.is_delete,
148 AstPattern::Ddl => parsed.is_ddl,
149 AstPattern::NPlusOne { table } => {
150 parsed.tables.contains(table) && !parsed.has_limit
153 }
154 AstPattern::FullTableScan => {
155 parsed.is_select && !parsed.has_where
156 }
157 AstPattern::And(patterns) => {
158 patterns.iter().all(|p| self.matches_ast(p, parsed))
159 }
160 AstPattern::Or(patterns) => {
161 patterns.iter().any(|p| self.matches_ast(p, parsed))
162 }
163 }
164 }
165
166 pub fn stats(&self) -> MatcherStats {
168 MatcherStats {
169 fingerprint_rules: self.fingerprint_index.values().map(|v| v.len()).sum(),
170 regex_rules: self.regex_patterns.len(),
171 table_rules: self.table_index.values().map(|v| v.len()).sum(),
172 all_rules: self.all_rules.len(),
173 ast_rules: self.ast_rules.len(),
174 }
175 }
176}
177
/// Summary of one matching pass: which rules matched and what was known about the query.
#[derive(Clone, Debug)]
pub struct MatchResult {
    /// Identifiers of the matched rules.
    pub rule_ids: Vec<String>,
    /// Fingerprint of the query that was matched.
    pub fingerprint: u64,
    /// Table names referenced by the query.
    pub tables: Vec<String>,
}
190
/// Per-strategy counts of the rules held by a rule matcher.
#[derive(Debug, Clone)]
pub struct MatcherStats {
    /// Rules keyed by an exact query fingerprint.
    pub fingerprint_rules: usize,
    /// Rules matched by a regular expression.
    pub regex_rules: usize,
    /// Rules keyed by table name.
    pub table_rules: usize,
    /// Rules that apply to every query.
    pub all_rules: usize,
    /// Rules evaluated against the parsed query structure.
    pub ast_rules: usize,
}

impl MatcherStats {
    /// Sum of all per-strategy counts.
    pub fn total(&self) -> usize {
        let counts = [
            self.fingerprint_rules,
            self.regex_rules,
            self.table_rules,
            self.all_rules,
            self.ast_rules,
        ];
        counts.iter().sum()
    }
}
216
#[cfg(test)]
mod tests {
    use super::*;
    use super::super::rules::{Transformation, RewriteRule};

    /// One rule per matching strategy, each with a distinct priority.
    fn test_rules() -> Vec<RewriteRule> {
        vec![
            RewriteRule::new("fp_match")
                .pattern(QueryPattern::Fingerprint(12345))
                .transform(Transformation::NoOp)
                .priority(100)
                .build(),
            RewriteRule::new("regex_match")
                .pattern(QueryPattern::Regex(r"SELECT .* FROM users".to_string()))
                .transform(Transformation::NoOp)
                .priority(50)
                .build(),
            RewriteRule::new("table_match")
                .pattern(QueryPattern::Table("orders".to_string()))
                .transform(Transformation::NoOp)
                .priority(75)
                .build(),
            RewriteRule::new("all_match")
                .pattern(QueryPattern::All)
                .transform(Transformation::AddLimit(1000))
                .priority(10)
                .build(),
            RewriteRule::new("ast_match")
                .pattern(QueryPattern::Ast(AstPattern::SelectStar))
                .transform(Transformation::NoOp)
                .priority(60)
                .build(),
        ]
    }

    /// Builds a SELECT `ParsedQuery` without LIMIT; all other statement-kind flags are false.
    fn select_query(
        original: &str,
        normalized: &str,
        tables: &[&str],
        has_select_star: bool,
        has_where: bool,
    ) -> ParsedQuery {
        ParsedQuery {
            original: original.to_string(),
            normalized: normalized.to_string(),
            tables: tables.iter().map(|&t| t.to_string()).collect(),
            has_select_star,
            has_limit: false,
            has_where,
            is_select: true,
            is_insert: false,
            is_update: false,
            is_delete: false,
            is_ddl: false,
        }
    }

    /// `SELECT * FROM orders`: matches the table, AST and catch-all rules.
    fn orders_star_query() -> ParsedQuery {
        select_query("SELECT * FROM orders", "SELECT * FROM orders", &["orders"], true, false)
    }

    #[test]
    fn test_matcher_creation() {
        let rules = test_rules();
        let matcher = RuleMatcher::new(&rules);

        let stats = matcher.stats();
        assert_eq!(stats.fingerprint_rules, 1);
        assert_eq!(stats.regex_rules, 1);
        assert_eq!(stats.table_rules, 1);
        assert_eq!(stats.all_rules, 1);
        assert_eq!(stats.ast_rules, 1);
    }

    #[test]
    fn test_matcher_all_rules() {
        let rules = test_rules();
        let matcher = RuleMatcher::new(&rules);

        let parsed = select_query("SELECT 1", "SELECT ?", &[], false, false);

        let matched = matcher.match_query(&parsed, &rules);
        assert!(matched.iter().any(|r| r.id == "all_match"));
    }

    #[test]
    fn test_matcher_regex() {
        let rules = test_rules();
        let matcher = RuleMatcher::new(&rules);

        let parsed = select_query(
            "SELECT id, name FROM users WHERE id = 1",
            "SELECT id, name FROM users WHERE id = ?",
            &["users"],
            false,
            true,
        );

        let matched = matcher.match_query(&parsed, &rules);
        assert!(matched.iter().any(|r| r.id == "regex_match"));
    }

    #[test]
    fn test_matcher_table() {
        let rules = test_rules();
        let matcher = RuleMatcher::new(&rules);

        let matched = matcher.match_query(&orders_star_query(), &rules);
        assert!(matched.iter().any(|r| r.id == "table_match"));
        assert!(matched.iter().any(|r| r.id == "ast_match"));
    }

    #[test]
    fn test_matcher_priority_ordering() {
        let rules = test_rules();
        let matcher = RuleMatcher::new(&rules);

        let matched = matcher.match_query(&orders_star_query(), &rules);
        assert!(matched.len() >= 3);
        // Every adjacent pair must be in non-increasing priority order.
        assert!(matched.windows(2).all(|w| w[0].priority >= w[1].priority));
        // "all_match" has the unique lowest priority, so it must come last.
        assert_eq!(matched.last().map(|r| r.id.as_str()), Some("all_match"));
    }

    #[test]
    fn test_matcher_skips_disabled_rules() {
        let mut rules = test_rules();
        rules[1].enabled = false; // "regex_match"
        let matcher = RuleMatcher::new(&rules);
        assert_eq!(matcher.stats().regex_rules, 0);

        let parsed = select_query(
            "SELECT id, name FROM users WHERE id = 1",
            "SELECT id, name FROM users WHERE id = ?",
            &["users"],
            false,
            true,
        );

        let matched = matcher.match_query(&parsed, &rules);
        assert!(!matched.iter().any(|r| r.id == "regex_match"));
        assert!(matched.iter().any(|r| r.id == "all_match"));
    }

    #[test]
    fn test_matcher_skips_invalid_regex() {
        let rules = vec![RewriteRule::new("bad_regex")
            .pattern(QueryPattern::Regex("(".to_string()))
            .transform(Transformation::NoOp)
            .priority(1)
            .build()];
        let matcher = RuleMatcher::new(&rules);

        let stats = matcher.stats();
        assert_eq!(stats.regex_rules, 0);
        assert_eq!(stats.total(), 0);
    }
}