1use std::collections::HashSet;
6
/// A single query-rewrite rule: a [`QueryPattern`] selecting which queries it
/// applies to, the [`Transformation`] to perform, and optional gating metadata.
///
/// Usually constructed through [`RewriteRule::build`] / [`RewriteRuleBuilder`].
#[derive(Debug, Clone)]
pub struct RewriteRule {
    /// Identifier of this rule.
    pub id: String,

    /// Human-readable description (empty when not set through the builder).
    pub description: String,

    /// Which queries this rule applies to; evaluated by [`RewriteRule::matches`].
    pub pattern: QueryPattern,

    /// The rewrite associated with this rule.
    pub transformation: Transformation,

    /// Optional extra condition.
    /// NOTE(review): not consulted by [`RewriteRule::matches`]; presumably
    /// evaluated by the rewrite engine elsewhere — confirm.
    pub condition: Option<Condition>,

    /// Rule priority (the builder defaults it to `0`).
    /// NOTE(review): whether higher or lower values win is decided by the
    /// caller, not here — confirm.
    pub priority: i32,

    /// When `false`, [`RewriteRule::matches`] always returns `false`.
    pub enabled: bool,

    /// Free-form string tags attached to the rule.
    pub tags: HashSet<String>,
}
34
35impl RewriteRule {
36 pub fn build(id: impl Into<String>) -> RewriteRuleBuilder {
38 RewriteRuleBuilder::new(id)
39 }
40
41 pub fn matches(&self, fingerprint: u64, query: &str, tables: &[String]) -> bool {
43 if !self.enabled {
44 return false;
45 }
46
47 match &self.pattern {
48 QueryPattern::Fingerprint(fp) => *fp == fingerprint,
49 QueryPattern::Regex(pattern) => regex::Regex::new(pattern)
50 .map(|re| re.is_match(query))
51 .unwrap_or(false),
52 QueryPattern::Table(table) => tables.contains(table),
53 QueryPattern::TableAny(table_patterns) => {
54 tables.iter().any(|t| table_patterns.contains(t))
55 }
56 QueryPattern::Ast(_ast_pattern) => {
57 false
59 }
60 QueryPattern::All => true,
61 }
62 }
63}
64
65pub struct RewriteRuleBuilder {
67 rule: RewriteRule,
68}
69
70impl RewriteRuleBuilder {
71 pub fn new(id: impl Into<String>) -> Self {
73 Self {
74 rule: RewriteRule {
75 id: id.into(),
76 description: String::new(),
77 pattern: QueryPattern::All,
78 transformation: Transformation::NoOp,
79 condition: None,
80 priority: 0,
81 enabled: true,
82 tags: HashSet::new(),
83 },
84 }
85 }
86
87 pub fn description(mut self, desc: impl Into<String>) -> Self {
89 self.rule.description = desc.into();
90 self
91 }
92
93 pub fn pattern(mut self, pattern: QueryPattern) -> Self {
95 self.rule.pattern = pattern;
96 self
97 }
98
99 pub fn transform(mut self, transformation: Transformation) -> Self {
101 self.rule.transformation = transformation;
102 self
103 }
104
105 pub fn condition(mut self, condition: Condition) -> Self {
107 self.rule.condition = Some(condition);
108 self
109 }
110
111 pub fn priority(mut self, priority: i32) -> Self {
113 self.rule.priority = priority;
114 self
115 }
116
117 pub fn enabled(mut self, enabled: bool) -> Self {
119 self.rule.enabled = enabled;
120 self
121 }
122
123 pub fn tag(mut self, tag: impl Into<String>) -> Self {
125 self.rule.tags.insert(tag.into());
126 self
127 }
128
129 pub fn build(self) -> RewriteRule {
131 self.rule
132 }
133}
134
135impl From<RewriteRuleBuilder> for RewriteRule {
136 fn from(builder: RewriteRuleBuilder) -> Self {
137 builder.build()
138 }
139}
140
141#[derive(Debug, Clone)]
143pub enum QueryPattern {
144 Fingerprint(u64),
146
147 Regex(String),
149
150 Table(String),
152
153 TableAny(HashSet<String>),
155
156 Ast(AstPattern),
158
159 All,
161}
162
163impl QueryPattern {
164 pub fn fingerprint(fp: u64) -> Self {
166 Self::Fingerprint(fp)
167 }
168
169 pub fn regex(pattern: impl Into<String>) -> Self {
171 Self::Regex(pattern.into())
172 }
173
174 pub fn table(table: impl Into<String>) -> Self {
176 Self::Table(table.into())
177 }
178
179 pub fn table_any(tables: impl IntoIterator<Item = impl Into<String>>) -> Self {
181 Self::TableAny(tables.into_iter().map(Into::into).collect())
182 }
183
184 pub fn ast(pattern: AstPattern) -> Self {
186 Self::Ast(pattern)
187 }
188
189 pub fn all() -> Self {
191 Self::All
192 }
193}
194
/// Structural (AST-level) query pattern, used through [`QueryPattern::Ast`].
///
/// NOTE(review): no matcher for these variants is visible in this file
/// (`RewriteRule::matches` returns `false` for every `QueryPattern::Ast`).
/// The per-variant notes describe the intent implied by each name; confirm
/// them against the matcher once it exists.
#[derive(Debug, Clone)]
pub enum AstPattern {
    /// Intended: a `SELECT *` projection.
    SelectStar,

    /// Intended: a `SELECT` reading from `table`.
    SelectFrom { table: String },

    /// Intended: a query without a `LIMIT` clause.
    NoLimit,

    /// Intended: a statement without a `WHERE` clause.
    NoWhere,

    /// Intended: an `INSERT` statement.
    Insert,

    /// Intended: an `UPDATE` statement.
    Update,

    /// Intended: a `DELETE` statement.
    Delete,

    /// Intended: a DDL statement.
    Ddl,

    /// Intended: an N+1 access pattern against `table`.
    NPlusOne { table: String },

    /// Intended: a query that scans a whole table.
    FullTableScan,

    /// Intended: every sub-pattern holds.
    And(Vec<AstPattern>),

    /// Intended: at least one sub-pattern holds.
    Or(Vec<AstPattern>),
}

impl AstPattern {
    /// Returns [`AstPattern::SelectStar`].
    pub fn select_star() -> Self {
        AstPattern::SelectStar
    }

    /// Returns [`AstPattern::NoLimit`].
    pub fn no_limit() -> Self {
        AstPattern::NoLimit
    }

    /// Returns [`AstPattern::NoWhere`].
    pub fn no_where() -> Self {
        AstPattern::NoWhere
    }
}
251
/// A rewrite associated with a [`RewriteRule`].
///
/// NOTE(review): no variant is applied anywhere in this file; the per-variant
/// notes below describe the intent implied by each name and payload and should
/// be checked against the rewrite engine.
#[derive(Debug, Clone)]
pub enum Transformation {
    /// No transformation (the builder's default).
    NoOp,

    /// Intended: replace the whole query with this text.
    Replace(String),

    /// Intended: add an index hint for `index` on `table`.
    AddIndexHint { table: String, index: String },

    /// Intended: replace `SELECT *` with the listed columns.
    ExpandSelectStar { columns: Vec<String> },

    /// Intended: add a `LIMIT` of this many rows.
    AddLimit(u32),

    /// Intended: add a `WHERE` clause with this condition text.
    AddWhereClause(String),

    /// Intended: `AND` this condition text onto an existing `WHERE` clause.
    AppendWhereAnd(String),

    /// Intended: substitute table `from` with table `to`.
    ReplaceTable { from: String, to: String },

    /// Intended: add `ORDER BY column` (descending when `descending` is true).
    AddOrderBy { column: String, descending: bool },

    /// Intended: attach this hint text to the query.
    AddHint(String),

    /// Intended: attach a branch hint with this value.
    AddBranchHint(String),

    /// Intended: apply this timeout to the query.
    AddTimeout(std::time::Duration),

    /// Intended: an engine-defined custom transformation identified by this string.
    Custom(String),

    /// Intended: apply several transformations, presumably in order.
    Chain(Vec<Transformation>),
}

impl Transformation {
    /// Builds [`Transformation::Replace`] with the replacement query text.
    pub fn replace(query: impl Into<String>) -> Self {
        Self::Replace(query.into())
    }

    /// Builds [`Transformation::AddLimit`].
    pub fn add_limit(limit: u32) -> Self {
        Self::AddLimit(limit)
    }

    /// Builds [`Transformation::AddWhereClause`] with the given condition text.
    pub fn add_where(condition: impl Into<String>) -> Self {
        Self::AddWhereClause(condition.into())
    }

    /// Builds [`Transformation::ReplaceTable`].
    pub fn replace_table(from: impl Into<String>, to: impl Into<String>) -> Self {
        Self::ReplaceTable {
            from: from.into(),
            to: to.into(),
        }
    }

    /// Builds [`Transformation::ExpandSelectStar`] from any iterable of column
    /// names: a `Vec`, an array, or an iterator of `&str`/`String`.
    ///
    /// Accepting `IntoIterator` (rather than only `Vec`) matches
    /// `QueryPattern::table_any`; every `Vec` argument accepted before still works.
    pub fn expand_select_star(columns: impl IntoIterator<Item = impl Into<String>>) -> Self {
        Self::ExpandSelectStar {
            columns: columns.into_iter().map(Into::into).collect(),
        }
    }

    /// Builds [`Transformation::AddIndexHint`].
    pub fn add_index_hint(table: impl Into<String>, index: impl Into<String>) -> Self {
        Self::AddIndexHint {
            table: table.into(),
            index: index.into(),
        }
    }

    /// Builds [`Transformation::Chain`] from the given transformations.
    pub fn chain(transformations: Vec<Transformation>) -> Self {
        Self::Chain(transformations)
    }
}
342
/// Optional gate attached to a [`RewriteRule`] via `RewriteRule::condition`.
///
/// NOTE(review): conditions are not evaluated in this file; the per-variant
/// notes describe the intent implied by each name — confirm against the
/// evaluator.
#[derive(Debug, Clone)]
pub enum Condition {
    /// Intended: the query has no `LIMIT` clause yet.
    NoExistingLimit,

    /// Intended: the query has no `ORDER BY` clause yet.
    NoExistingOrderBy,

    /// Intended: the query uses `SELECT *`.
    HasSelectStar,

    /// Intended: session variable `name` is present (`exists: true`) or absent (`exists: false`).
    SessionVar { name: String, exists: bool },

    /// Intended: the client is of type `client_type`.
    ClientType { client_type: String },

    /// Intended: table `table` exists.
    TableExists { table: String },

    /// Intended: every sub-condition holds.
    And(Vec<Condition>),

    /// Intended: at least one sub-condition holds.
    Or(Vec<Condition>),

    /// Intended: the wrapped condition does not hold.
    Not(Box<Condition>),
}

impl Condition {
    /// Returns [`Condition::NoExistingLimit`].
    pub fn no_limit() -> Self {
        Condition::NoExistingLimit
    }

    /// Returns [`Condition::NoExistingOrderBy`].
    pub fn no_order_by() -> Self {
        Condition::NoExistingOrderBy
    }

    /// Returns [`Condition::HasSelectStar`].
    pub fn has_select_star() -> Self {
        Condition::HasSelectStar
    }

    /// Builds [`Condition::SessionVar`] for `name` with `exists: true`.
    pub fn session_var(name: impl Into<String>) -> Self {
        let name: String = name.into();
        Condition::SessionVar { exists: true, name }
    }

    /// Builds [`Condition::ClientType`].
    pub fn client_type(client_type: impl Into<String>) -> Self {
        let client_type: String = client_type.into();
        Condition::ClientType { client_type }
    }

    /// Builds [`Condition::And`] over the given sub-conditions.
    pub fn and(conditions: Vec<Condition>) -> Self {
        Condition::And(conditions)
    }

    /// Builds [`Condition::Or`] over the given sub-conditions.
    pub fn or(conditions: Vec<Condition>) -> Self {
        Condition::Or(conditions)
    }

    /// Builds [`Condition::Not`], boxing the wrapped condition.
    // Named `not` for symmetry with the `and`/`or` constructors; the allow
    // silences clippy's suggestion to implement `std::ops::Not` instead.
    #[allow(clippy::should_implement_trait)]
    pub fn not(condition: Condition) -> Self {
        let inner = Box::new(condition);
        Condition::Not(inner)
    }
}
421
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_rule_builder() {
        let rule = RewriteRule::build("test")
            .description("Test rule")
            .pattern(QueryPattern::All)
            .transform(Transformation::AddLimit(100))
            .priority(50)
            .tag("safety")
            .build();

        assert_eq!(rule.id, "test");
        assert_eq!(rule.description, "Test rule");
        assert_eq!(rule.priority, 50);
        assert!(rule.enabled);
        assert!(rule.tags.contains("safety"));
    }

    /// The builder's documented defaults, and `.into()` as an alternative to `.build()`.
    #[test]
    fn test_builder_defaults_and_into() {
        let rule: RewriteRule = RewriteRule::build("defaults").into();

        assert_eq!(rule.id, "defaults");
        assert!(rule.description.is_empty());
        assert!(matches!(rule.pattern, QueryPattern::All));
        assert!(matches!(rule.transformation, Transformation::NoOp));
        assert!(rule.condition.is_none());
        assert_eq!(rule.priority, 0);
        assert!(rule.enabled);
        assert!(rule.tags.is_empty());
    }

    #[test]
    fn test_query_pattern_table() {
        let pattern = QueryPattern::table("users");

        match pattern {
            QueryPattern::Table(t) => assert_eq!(t, "users"),
            _ => panic!("Expected Table pattern"),
        }
    }

    #[test]
    fn test_transformation_chain() {
        let transform = Transformation::chain(vec![
            Transformation::AddLimit(100),
            Transformation::AddOrderBy {
                column: "id".to_string(),
                descending: true,
            },
        ]);

        match transform {
            Transformation::Chain(t) => assert_eq!(t.len(), 2),
            _ => panic!("Expected Chain"),
        }
    }

    #[test]
    fn test_condition_and() {
        let condition = Condition::and(vec![Condition::NoExistingLimit, Condition::HasSelectStar]);

        match condition {
            Condition::And(c) => assert_eq!(c.len(), 2),
            _ => panic!("Expected And"),
        }
    }

    #[test]
    fn test_rule_matches() {
        let rule = RewriteRule::build("test")
            .pattern(QueryPattern::Table("users".to_string()))
            .transform(Transformation::AddLimit(100))
            .build();

        assert!(rule.matches(0, "", &["users".to_string()]));
        assert!(!rule.matches(0, "", &["orders".to_string()]));
    }

    /// A disabled rule must not match even with the catch-all pattern.
    #[test]
    fn test_disabled_rule_never_matches() {
        let rule = RewriteRule::build("off")
            .pattern(QueryPattern::all())
            .enabled(false)
            .build();

        assert!(!rule.matches(0, "SELECT 1", &["users".to_string()]));
    }

    #[test]
    fn test_rule_matches_all() {
        let rule = RewriteRule::build("all").build();

        assert!(rule.matches(123, "anything", &[]));
    }

    #[test]
    fn test_rule_matches_fingerprint() {
        let rule = RewriteRule::build("fp")
            .pattern(QueryPattern::fingerprint(42))
            .build();

        assert!(rule.matches(42, "", &[]));
        assert!(!rule.matches(7, "", &[]));
    }

    #[test]
    fn test_rule_matches_table_any() {
        let rule = RewriteRule::build("any")
            .pattern(QueryPattern::table_any(vec!["users", "orders"]))
            .build();

        assert!(rule.matches(0, "", &["items".to_string(), "orders".to_string()]));
        assert!(!rule.matches(0, "", &["items".to_string()]));
        assert!(!rule.matches(0, "", &[]));
    }

    #[test]
    fn test_rule_matches_regex() {
        let rule = RewriteRule::build("re")
            .pattern(QueryPattern::regex("(?i)^select"))
            .build();

        assert!(rule.matches(0, "SELECT * FROM users", &[]));
        assert!(!rule.matches(0, "UPDATE users SET x = 1", &[]));
    }

    /// An uncompilable regex is treated as "no match" rather than panicking.
    #[test]
    fn test_invalid_regex_never_matches() {
        let rule = RewriteRule::build("bad")
            .pattern(QueryPattern::regex("("))
            .build();

        assert!(!rule.matches(0, "(", &[]));
    }
}