Skip to main content

heliosdb_proxy/rewriter/
rules.rs

1//! Rewrite Rules
2//!
3//! Rule definitions for query rewriting.
4
5use std::collections::HashSet;
6
/// A query-rewrite rule: a [`QueryPattern`] that selects queries plus the
/// [`Transformation`] to apply to them.
///
/// Construct one with [`RewriteRule::new`], which returns a [`RewriteRuleBuilder`].
#[derive(Debug, Clone)]
pub struct RewriteRule {
    /// Rule identifier
    pub id: String,

    /// Human-readable description
    pub description: String,

    /// Pattern to match; checked by [`RewriteRule::matches`].
    pub pattern: QueryPattern,

    /// Transformation to apply to matching queries.
    pub transformation: Transformation,

    /// Optional extra condition for applying the rule.
    ///
    /// Not evaluated by [`RewriteRule::matches`]. NOTE(review): presumably
    /// checked by the rewriter before applying `transformation` — confirm.
    pub condition: Option<Condition>,

    /// Priority (higher = applied first)
    pub priority: i32,

    /// Enabled/disabled. When `false`, [`RewriteRule::matches`] always
    /// returns `false`.
    pub enabled: bool,

    /// Rule tags for grouping; not consulted by [`RewriteRule::matches`].
    pub tags: HashSet<String>,
}
34
35impl RewriteRule {
36    /// Create a new rule
37    pub fn new(id: impl Into<String>) -> RewriteRuleBuilder {
38        RewriteRuleBuilder::new(id)
39    }
40
41    /// Check if rule matches query pattern
42    pub fn matches(&self, fingerprint: u64, query: &str, tables: &[String]) -> bool {
43        if !self.enabled {
44            return false;
45        }
46
47        match &self.pattern {
48            QueryPattern::Fingerprint(fp) => *fp == fingerprint,
49            QueryPattern::Regex(pattern) => {
50                regex::Regex::new(pattern)
51                    .map(|re| re.is_match(query))
52                    .unwrap_or(false)
53            }
54            QueryPattern::Table(table) => tables.contains(table),
55            QueryPattern::TableAny(table_patterns) => {
56                tables.iter().any(|t| table_patterns.contains(t))
57            }
58            QueryPattern::Ast(_ast_pattern) => {
59                // AST matching is done by the matcher
60                false
61            }
62            QueryPattern::All => true,
63        }
64    }
65}
66
67/// Builder for RewriteRule
68pub struct RewriteRuleBuilder {
69    rule: RewriteRule,
70}
71
72impl RewriteRuleBuilder {
73    /// Create a new builder
74    pub fn new(id: impl Into<String>) -> Self {
75        Self {
76            rule: RewriteRule {
77                id: id.into(),
78                description: String::new(),
79                pattern: QueryPattern::All,
80                transformation: Transformation::NoOp,
81                condition: None,
82                priority: 0,
83                enabled: true,
84                tags: HashSet::new(),
85            },
86        }
87    }
88
89    /// Set description
90    pub fn description(mut self, desc: impl Into<String>) -> Self {
91        self.rule.description = desc.into();
92        self
93    }
94
95    /// Set pattern
96    pub fn pattern(mut self, pattern: QueryPattern) -> Self {
97        self.rule.pattern = pattern;
98        self
99    }
100
101    /// Set transformation
102    pub fn transform(mut self, transformation: Transformation) -> Self {
103        self.rule.transformation = transformation;
104        self
105    }
106
107    /// Set condition
108    pub fn condition(mut self, condition: Condition) -> Self {
109        self.rule.condition = Some(condition);
110        self
111    }
112
113    /// Set priority
114    pub fn priority(mut self, priority: i32) -> Self {
115        self.rule.priority = priority;
116        self
117    }
118
119    /// Enable/disable
120    pub fn enabled(mut self, enabled: bool) -> Self {
121        self.rule.enabled = enabled;
122        self
123    }
124
125    /// Add a tag
126    pub fn tag(mut self, tag: impl Into<String>) -> Self {
127        self.rule.tags.insert(tag.into());
128        self
129    }
130
131    /// Build the rule
132    pub fn build(self) -> RewriteRule {
133        self.rule
134    }
135}
136
137impl From<RewriteRuleBuilder> for RewriteRule {
138    fn from(builder: RewriteRuleBuilder) -> Self {
139        builder.build()
140    }
141}
142
/// Query pattern for matching: selects which queries a [`RewriteRule`] applies to.
///
/// Checked by [`RewriteRule::matches`].
#[derive(Debug, Clone)]
pub enum QueryPattern {
    /// Match by fingerprint hash: equal to the query's fingerprint.
    Fingerprint(u64),

    /// Match by SQL pattern (regex) against the raw query text. A pattern that
    /// fails to compile never matches.
    Regex(String),

    /// Match by table name: the query references this table (exact,
    /// case-sensitive comparison).
    Table(String),

    /// Match if the query references any of these tables.
    TableAny(HashSet<String>),

    /// Match by AST pattern. `RewriteRule::matches` always returns `false` for
    /// this variant; AST matching is done by the matcher.
    Ast(AstPattern),

    /// Match all queries
    All,
}
164
165impl QueryPattern {
166    /// Create a fingerprint pattern
167    pub fn fingerprint(fp: u64) -> Self {
168        Self::Fingerprint(fp)
169    }
170
171    /// Create a regex pattern
172    pub fn regex(pattern: impl Into<String>) -> Self {
173        Self::Regex(pattern.into())
174    }
175
176    /// Create a table pattern
177    pub fn table(table: impl Into<String>) -> Self {
178        Self::Table(table.into())
179    }
180
181    /// Create a table-any pattern
182    pub fn table_any(tables: impl IntoIterator<Item = impl Into<String>>) -> Self {
183        Self::TableAny(tables.into_iter().map(Into::into).collect())
184    }
185
186    /// Create an AST pattern
187    pub fn ast(pattern: AstPattern) -> Self {
188        Self::Ast(pattern)
189    }
190
191    /// Create an all pattern
192    pub fn all() -> Self {
193        Self::All
194    }
195}
196
/// AST-level pattern matching, used through `QueryPattern::Ast`.
///
/// The matcher that evaluates these variants is not in this file; the variant
/// docs describe the intended shape only.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum AstPattern {
    /// SELECT * query
    SelectStar,

    /// SELECT with specific table
    SelectFrom { table: String },

    /// Query without LIMIT
    NoLimit,

    /// Query without WHERE
    NoWhere,

    /// INSERT statement
    Insert,

    /// UPDATE statement
    Update,

    /// DELETE statement
    Delete,

    /// DDL statement (CREATE, ALTER, DROP)
    Ddl,

    /// N+1 query pattern
    NPlusOne { table: String },

    /// Full table scan
    FullTableScan,

    /// Compound pattern (presumably every sub-pattern must match — confirm in the matcher)
    And(Vec<AstPattern>),

    /// Any of patterns
    Or(Vec<AstPattern>),
}

impl AstPattern {
    /// Create a SELECT * pattern
    pub fn select_star() -> Self {
        Self::SelectStar
    }

    /// Create a SELECT-from-`table` pattern
    pub fn select_from(table: impl Into<String>) -> Self {
        Self::SelectFrom {
            table: table.into(),
        }
    }

    /// Create a no-limit pattern
    pub fn no_limit() -> Self {
        Self::NoLimit
    }

    /// Create a no-where pattern
    pub fn no_where() -> Self {
        Self::NoWhere
    }

    /// Create a compound (`And`) pattern from sub-patterns
    pub fn and(patterns: Vec<AstPattern>) -> Self {
        Self::And(patterns)
    }

    /// Create an any-of (`Or`) pattern from sub-patterns
    pub fn or(patterns: Vec<AstPattern>) -> Self {
        Self::Or(patterns)
    }
}
253
/// Query transformation: the rewrite a [`RewriteRule`] applies to matching queries.
///
/// This enum only describes the rewrite; the code that applies it is not in this
/// file, so the variant docs state intent.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Transformation {
    /// No operation (pass through)
    NoOp,

    /// Replace entire query
    Replace(String),

    /// Add index hint
    AddIndexHint {
        table: String,
        index: String,
    },

    /// Rewrite SELECT * to specific columns
    ExpandSelectStar {
        columns: Vec<String>,
    },

    /// Add LIMIT clause
    AddLimit(u32),

    /// Add WHERE condition
    AddWhereClause(String),

    /// Append to WHERE clause with AND
    AppendWhereAnd(String),

    /// Replace table name
    ReplaceTable {
        from: String,
        to: String,
    },

    /// Add ORDER BY clause
    AddOrderBy {
        column: String,
        descending: bool,
    },

    /// Add query hint comment
    AddHint(String),

    /// Add branch routing hint
    AddBranchHint(String),

    /// Add timeout hint
    AddTimeout(std::time::Duration),

    /// Custom transformation by name
    Custom(String),

    /// Chain multiple transformations
    Chain(Vec<Transformation>),
}

impl Transformation {
    /// Create a replace transformation
    pub fn replace(query: impl Into<String>) -> Self {
        Self::Replace(query.into())
    }

    /// Create an add-limit transformation
    pub fn add_limit(limit: u32) -> Self {
        Self::AddLimit(limit)
    }

    /// Create an add-where transformation
    pub fn add_where(condition: impl Into<String>) -> Self {
        Self::AddWhereClause(condition.into())
    }

    /// Create a replace-table transformation
    pub fn replace_table(from: impl Into<String>, to: impl Into<String>) -> Self {
        Self::ReplaceTable {
            from: from.into(),
            to: to.into(),
        }
    }

    /// Create an expand-select-star transformation.
    ///
    /// Accepts any iterable of column names (`Vec<&str>`, `Vec<String>`, an
    /// iterator, ...), in the same style as `QueryPattern::table_any`. Column
    /// order is preserved.
    pub fn expand_select_star(columns: impl IntoIterator<Item = impl Into<String>>) -> Self {
        Self::ExpandSelectStar {
            columns: columns.into_iter().map(Into::into).collect(),
        }
    }

    /// Create an add-index-hint transformation
    pub fn add_index_hint(table: impl Into<String>, index: impl Into<String>) -> Self {
        Self::AddIndexHint {
            table: table.into(),
            index: index.into(),
        }
    }

    /// Create an add-order-by transformation
    pub fn add_order_by(column: impl Into<String>, descending: bool) -> Self {
        Self::AddOrderBy {
            column: column.into(),
            descending,
        }
    }

    /// Create a chain transformation
    pub fn chain(transformations: Vec<Transformation>) -> Self {
        Self::Chain(transformations)
    }
}
355
/// Condition for rule application, stored in a [`RewriteRule`]'s `condition` field.
///
/// The evaluator for these conditions is not in this file; the variant docs
/// state intent.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Condition {
    /// Query has no LIMIT clause
    NoExistingLimit,

    /// Query has no ORDER BY clause
    NoExistingOrderBy,

    /// Query has SELECT *
    HasSelectStar,

    /// Session variable check. `exists: true` is what [`Condition::session_var`]
    /// builds ("variable exists"); `exists: false` presumably requires the
    /// variable to be absent — confirm with the evaluator.
    SessionVar {
        name: String,
        exists: bool,
    },

    /// Client type check
    ClientType {
        client_type: String,
    },

    /// Table exists in schema
    TableExists {
        table: String,
    },

    /// All conditions must match
    And(Vec<Condition>),

    /// Any condition must match
    Or(Vec<Condition>),

    /// Negate condition
    Not(Box<Condition>),
}

impl Condition {
    /// No existing LIMIT
    pub fn no_limit() -> Self {
        Self::NoExistingLimit
    }

    /// No existing ORDER BY
    pub fn no_order_by() -> Self {
        Self::NoExistingOrderBy
    }

    /// Has SELECT *
    pub fn has_select_star() -> Self {
        Self::HasSelectStar
    }

    /// Session variable exists
    pub fn session_var(name: impl Into<String>) -> Self {
        Self::SessionVar {
            name: name.into(),
            exists: true,
        }
    }

    /// Client type matches
    pub fn client_type(client_type: impl Into<String>) -> Self {
        Self::ClientType {
            client_type: client_type.into(),
        }
    }

    /// Table exists in schema
    pub fn table_exists(table: impl Into<String>) -> Self {
        Self::TableExists {
            table: table.into(),
        }
    }

    /// AND conditions
    pub fn and(conditions: Vec<Condition>) -> Self {
        Self::And(conditions)
    }

    /// OR conditions
    pub fn or(conditions: Vec<Condition>) -> Self {
        Self::Or(conditions)
    }

    /// NOT condition
    pub fn not(condition: Condition) -> Self {
        Self::Not(Box::new(condition))
    }
}
440
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_rule_builder() {
        let rule = RewriteRule::new("test")
            .description("Test rule")
            .pattern(QueryPattern::All)
            .transform(Transformation::AddLimit(100))
            .priority(50)
            .tag("safety")
            .build();

        assert_eq!(rule.id, "test");
        assert_eq!(rule.description, "Test rule");
        assert_eq!(rule.priority, 50);
        assert!(rule.enabled);
        assert!(rule.tags.contains("safety"));
    }

    #[test]
    fn test_query_pattern_table() {
        let pattern = QueryPattern::table("users");

        match pattern {
            QueryPattern::Table(t) => assert_eq!(t, "users"),
            _ => panic!("Expected Table pattern"),
        }
    }

    #[test]
    fn test_transformation_chain() {
        let transform = Transformation::chain(vec![
            Transformation::AddLimit(100),
            Transformation::AddOrderBy {
                column: "id".to_string(),
                descending: true,
            },
        ]);

        match transform {
            Transformation::Chain(t) => assert_eq!(t.len(), 2),
            _ => panic!("Expected Chain"),
        }
    }

    #[test]
    fn test_condition_and() {
        let condition = Condition::and(vec![
            Condition::NoExistingLimit,
            Condition::HasSelectStar,
        ]);

        match condition {
            Condition::And(c) => assert_eq!(c.len(), 2),
            _ => panic!("Expected And"),
        }
    }

    #[test]
    fn test_rule_matches() {
        let rule = RewriteRule::new("test")
            .pattern(QueryPattern::Table("users".to_string()))
            .transform(Transformation::AddLimit(100))
            .build();

        assert!(rule.matches(0, "", &["users".to_string()]));
        assert!(!rule.matches(0, "", &["orders".to_string()]));
    }

    #[test]
    fn test_disabled_rule_never_matches() {
        let rule = RewriteRule::new("off")
            .pattern(QueryPattern::All)
            .enabled(false)
            .build();

        assert!(!rule.matches(0, "SELECT 1", &["users".to_string()]));
    }

    #[test]
    fn test_fingerprint_and_all_patterns() {
        let rule = RewriteRule::new("fp")
            .pattern(QueryPattern::fingerprint(42))
            .build();
        assert!(rule.matches(42, "", &[]));
        assert!(!rule.matches(7, "", &[]));

        // The builder defaults to QueryPattern::All.
        let rule = RewriteRule::new("all").build();
        assert!(rule.matches(0, "", &[]));
    }

    #[test]
    fn test_table_any_pattern() {
        let rule = RewriteRule::new("any")
            .pattern(QueryPattern::table_any(vec!["users", "orders"]))
            .build();

        assert!(rule.matches(0, "", &["items".to_string(), "orders".to_string()]));
        assert!(!rule.matches(0, "", &["items".to_string()]));
        assert!(!rule.matches(0, "", &[]));
    }

    #[test]
    fn test_regex_pattern() {
        let rule = RewriteRule::new("re")
            .pattern(QueryPattern::regex(r"(?i)^select\s"))
            .build();

        assert!(rule.matches(0, "SELECT * FROM users", &[]));
        assert!(!rule.matches(0, "DELETE FROM users", &[]));
        // Repeated evaluation of the same rule must give consistent results.
        assert!(rule.matches(0, "select 1", &[]));
        assert!(!rule.matches(0, "DELETE FROM users", &[]));
    }

    #[test]
    fn test_invalid_regex_never_matches() {
        let rule = RewriteRule::new("bad")
            .pattern(QueryPattern::regex("("))
            .build();

        assert!(!rule.matches(0, "(", &[]));
        assert!(!rule.matches(0, "(", &[]));
    }

    #[test]
    fn test_ast_pattern_not_matched_by_rule() {
        let rule = RewriteRule::new("ast")
            .pattern(QueryPattern::ast(AstPattern::select_star()))
            .build();

        assert!(!rule.matches(0, "SELECT * FROM users", &["users".to_string()]));
    }
}