1use std::collections::HashSet;
6
7#[derive(Debug, Clone)]
9pub struct RewriteRule {
10 pub id: String,
12
13 pub description: String,
15
16 pub pattern: QueryPattern,
18
19 pub transformation: Transformation,
21
22 pub condition: Option<Condition>,
24
25 pub priority: i32,
27
28 pub enabled: bool,
30
31 pub tags: HashSet<String>,
33}
34
35impl RewriteRule {
36 pub fn new(id: impl Into<String>) -> RewriteRuleBuilder {
38 RewriteRuleBuilder::new(id)
39 }
40
41 pub fn matches(&self, fingerprint: u64, query: &str, tables: &[String]) -> bool {
43 if !self.enabled {
44 return false;
45 }
46
47 match &self.pattern {
48 QueryPattern::Fingerprint(fp) => *fp == fingerprint,
49 QueryPattern::Regex(pattern) => {
50 regex::Regex::new(pattern)
51 .map(|re| re.is_match(query))
52 .unwrap_or(false)
53 }
54 QueryPattern::Table(table) => tables.contains(table),
55 QueryPattern::TableAny(table_patterns) => {
56 tables.iter().any(|t| table_patterns.contains(t))
57 }
58 QueryPattern::Ast(_ast_pattern) => {
59 false
61 }
62 QueryPattern::All => true,
63 }
64 }
65}
66
67pub struct RewriteRuleBuilder {
69 rule: RewriteRule,
70}
71
72impl RewriteRuleBuilder {
73 pub fn new(id: impl Into<String>) -> Self {
75 Self {
76 rule: RewriteRule {
77 id: id.into(),
78 description: String::new(),
79 pattern: QueryPattern::All,
80 transformation: Transformation::NoOp,
81 condition: None,
82 priority: 0,
83 enabled: true,
84 tags: HashSet::new(),
85 },
86 }
87 }
88
89 pub fn description(mut self, desc: impl Into<String>) -> Self {
91 self.rule.description = desc.into();
92 self
93 }
94
95 pub fn pattern(mut self, pattern: QueryPattern) -> Self {
97 self.rule.pattern = pattern;
98 self
99 }
100
101 pub fn transform(mut self, transformation: Transformation) -> Self {
103 self.rule.transformation = transformation;
104 self
105 }
106
107 pub fn condition(mut self, condition: Condition) -> Self {
109 self.rule.condition = Some(condition);
110 self
111 }
112
113 pub fn priority(mut self, priority: i32) -> Self {
115 self.rule.priority = priority;
116 self
117 }
118
119 pub fn enabled(mut self, enabled: bool) -> Self {
121 self.rule.enabled = enabled;
122 self
123 }
124
125 pub fn tag(mut self, tag: impl Into<String>) -> Self {
127 self.rule.tags.insert(tag.into());
128 self
129 }
130
131 pub fn build(self) -> RewriteRule {
133 self.rule
134 }
135}
136
137impl From<RewriteRuleBuilder> for RewriteRule {
138 fn from(builder: RewriteRuleBuilder) -> Self {
139 builder.build()
140 }
141}
142
143#[derive(Debug, Clone)]
145pub enum QueryPattern {
146 Fingerprint(u64),
148
149 Regex(String),
151
152 Table(String),
154
155 TableAny(HashSet<String>),
157
158 Ast(AstPattern),
160
161 All,
163}
164
165impl QueryPattern {
166 pub fn fingerprint(fp: u64) -> Self {
168 Self::Fingerprint(fp)
169 }
170
171 pub fn regex(pattern: impl Into<String>) -> Self {
173 Self::Regex(pattern.into())
174 }
175
176 pub fn table(table: impl Into<String>) -> Self {
178 Self::Table(table.into())
179 }
180
181 pub fn table_any(tables: impl IntoIterator<Item = impl Into<String>>) -> Self {
183 Self::TableAny(tables.into_iter().map(Into::into).collect())
184 }
185
186 pub fn ast(pattern: AstPattern) -> Self {
188 Self::Ast(pattern)
189 }
190
191 pub fn all() -> Self {
193 Self::All
194 }
195}
196
/// Structural (AST-level) query shapes a rule can be keyed on.
///
/// This file only declares the shapes. Nothing here evaluates them, and
/// `RewriteRule::matches` currently reports "no match" for every AST pattern. The
/// per-variant notes therefore describe the payload only.
// NOTE(review): the intended meaning of each variant is suggested by its name alone —
// confirm against the evaluator once one exists.
///
/// `PartialEq`/`Eq` are derived so patterns can be compared directly (for example with
/// `assert_eq!`) instead of through a hand-written `match`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum AstPattern {
    /// No payload.
    SelectStar,

    /// Carries the name of one table.
    SelectFrom { table: String },

    /// No payload.
    NoLimit,

    /// No payload.
    NoWhere,

    /// No payload.
    Insert,

    /// No payload.
    Update,

    /// No payload.
    Delete,

    /// No payload.
    Ddl,

    /// Carries the name of one table.
    NPlusOne { table: String },

    /// No payload.
    FullTableScan,

    /// Carries a list of sub-patterns (presumably all must hold — confirm).
    And(Vec<AstPattern>),

    /// Carries a list of sub-patterns (presumably any may hold — confirm).
    Or(Vec<AstPattern>),
}
236
237impl AstPattern {
238 pub fn select_star() -> Self {
240 Self::SelectStar
241 }
242
243 pub fn no_limit() -> Self {
245 Self::NoLimit
246 }
247
248 pub fn no_where() -> Self {
250 Self::NoWhere
251 }
252}
253
/// Rewrite operations a rule can carry.
///
/// This file only defines the data. The code that applies a transformation to a query is
/// not part of it, so the per-variant notes describe the payload only.
// NOTE(review): the exact SQL each variant produces is decided by the applying code —
// confirm there before relying on the variant names.
///
/// `PartialEq`/`Eq` are derived so transformations can be compared directly (for example
/// with `assert_eq!`) instead of through a hand-written `match`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Transformation {
    /// No payload; the default chosen by `RewriteRuleBuilder::new`.
    NoOp,

    /// Carries replacement text.
    Replace(String),

    /// Carries a table name and an index name.
    AddIndexHint {
        table: String,
        index: String,
    },

    /// Carries a list of column names.
    ExpandSelectStar {
        columns: Vec<String>,
    },

    /// Carries a row limit.
    AddLimit(u32),

    /// Carries condition text.
    AddWhereClause(String),

    /// Carries condition text.
    AppendWhereAnd(String),

    /// Carries the table name to replace and its replacement.
    ReplaceTable {
        from: String,
        to: String,
    },

    /// Carries a column name and a descending flag.
    AddOrderBy {
        column: String,
        descending: bool,
    },

    /// Carries hint text.
    AddHint(String),

    /// Carries hint text.
    AddBranchHint(String),

    /// Carries a timeout duration.
    AddTimeout(std::time::Duration),

    /// Carries free-form text whose interpretation is left to the applying code.
    Custom(String),

    /// Carries a list of transformations (presumably applied in order — confirm).
    Chain(Vec<Transformation>),
}
310
311impl Transformation {
312 pub fn replace(query: impl Into<String>) -> Self {
314 Self::Replace(query.into())
315 }
316
317 pub fn add_limit(limit: u32) -> Self {
319 Self::AddLimit(limit)
320 }
321
322 pub fn add_where(condition: impl Into<String>) -> Self {
324 Self::AddWhereClause(condition.into())
325 }
326
327 pub fn replace_table(from: impl Into<String>, to: impl Into<String>) -> Self {
329 Self::ReplaceTable {
330 from: from.into(),
331 to: to.into(),
332 }
333 }
334
335 pub fn expand_select_star(columns: Vec<impl Into<String>>) -> Self {
337 Self::ExpandSelectStar {
338 columns: columns.into_iter().map(Into::into).collect(),
339 }
340 }
341
342 pub fn add_index_hint(table: impl Into<String>, index: impl Into<String>) -> Self {
344 Self::AddIndexHint {
345 table: table.into(),
346 index: index.into(),
347 }
348 }
349
350 pub fn chain(transformations: Vec<Transformation>) -> Self {
352 Self::Chain(transformations)
353 }
354}
355
/// Extra preconditions that can be attached to a rule.
///
/// This file only defines the data. `RewriteRule::matches` does not evaluate conditions,
/// and the code that does is not part of this file, so the per-variant notes describe the
/// payload only.
// NOTE(review): confirm the exact meaning of each variant against the evaluating code.
///
/// `PartialEq`/`Eq` are derived so conditions can be compared directly (for example with
/// `assert_eq!`) instead of through a hand-written `match`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Condition {
    /// No payload.
    NoExistingLimit,

    /// No payload.
    NoExistingOrderBy,

    /// No payload.
    HasSelectStar,

    /// Carries a session-variable name and an `exists` flag.
    SessionVar {
        name: String,
        exists: bool,
    },

    /// Carries a client-type string.
    ClientType {
        client_type: String,
    },

    /// Carries a table name.
    TableExists {
        table: String,
    },

    /// Carries a list of sub-conditions (presumably all must hold — confirm).
    And(Vec<Condition>),

    /// Carries a list of sub-conditions (presumably any may hold — confirm).
    Or(Vec<Condition>),

    /// Carries one boxed sub-condition (presumably negated — confirm).
    Not(Box<Condition>),
}
393
394impl Condition {
395 pub fn no_limit() -> Self {
397 Self::NoExistingLimit
398 }
399
400 pub fn no_order_by() -> Self {
402 Self::NoExistingOrderBy
403 }
404
405 pub fn has_select_star() -> Self {
407 Self::HasSelectStar
408 }
409
410 pub fn session_var(name: impl Into<String>) -> Self {
412 Self::SessionVar {
413 name: name.into(),
414 exists: true,
415 }
416 }
417
418 pub fn client_type(client_type: impl Into<String>) -> Self {
420 Self::ClientType {
421 client_type: client_type.into(),
422 }
423 }
424
425 pub fn and(conditions: Vec<Condition>) -> Self {
427 Self::And(conditions)
428 }
429
430 pub fn or(conditions: Vec<Condition>) -> Self {
432 Self::Or(conditions)
433 }
434
435 pub fn not(condition: Condition) -> Self {
437 Self::Not(Box::new(condition))
438 }
439}
440
#[cfg(test)]
mod tests {
    use super::*;

    /// Values passed to the builder end up in the rule; `enabled` keeps its default.
    #[test]
    fn test_rule_builder() {
        let built = RewriteRule::new("test")
            .description("Test rule")
            .pattern(QueryPattern::All)
            .transform(Transformation::AddLimit(100))
            .priority(50)
            .tag("safety")
            .build();

        assert_eq!(built.id, "test");
        assert_eq!(built.description, "Test rule");
        assert_eq!(built.priority, 50);
        assert!(built.enabled);
        assert!(built.tags.contains("safety"));
    }

    /// `QueryPattern::table` yields the `Table` variant holding the given name.
    #[test]
    fn test_query_pattern_table() {
        if let QueryPattern::Table(name) = QueryPattern::table("users") {
            assert_eq!(name, "users");
        } else {
            panic!("Expected Table pattern");
        }
    }

    /// `Transformation::chain` keeps every step it is given.
    #[test]
    fn test_transformation_chain() {
        let steps = vec![
            Transformation::AddLimit(100),
            Transformation::AddOrderBy {
                column: "id".to_string(),
                descending: true,
            },
        ];

        if let Transformation::Chain(kept) = Transformation::chain(steps) {
            assert_eq!(kept.len(), 2);
        } else {
            panic!("Expected Chain");
        }
    }

    /// `Condition::and` keeps every sub-condition it is given.
    #[test]
    fn test_condition_and() {
        let parts = vec![Condition::NoExistingLimit, Condition::HasSelectStar];

        if let Condition::And(kept) = Condition::and(parts) {
            assert_eq!(kept.len(), 2);
        } else {
            panic!("Expected And");
        }
    }

    /// A `Table` rule matches only when the table list contains that table.
    #[test]
    fn test_rule_matches() {
        let rule = RewriteRule::new("test")
            .pattern(QueryPattern::Table("users".to_string()))
            .transform(Transformation::AddLimit(100))
            .build();

        let users = vec!["users".to_string()];
        let orders = vec!["orders".to_string()];

        assert!(rule.matches(0, "", &users));
        assert!(!rule.matches(0, "", &orders));
    }
}