Skip to main content

heliosdb_proxy/rewriter/
rules.rs

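//! Rewrite Rules
//!
//! Rule definitions for query rewriting: the patterns a query must match,
//! the transformations applied to it, and the conditions that gate them.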
4
5use std::collections::HashSet;
6
7/// A rewrite rule
8#[derive(Debug, Clone)]
9pub struct RewriteRule {
10    /// Rule identifier
11    pub id: String,
12
13    /// Human-readable description
14    pub description: String,
15
16    /// Pattern to match
17    pub pattern: QueryPattern,
18
19    /// Transformation to apply
20    pub transformation: Transformation,
21
22    /// Condition for applying rule
23    pub condition: Option<Condition>,
24
25    /// Priority (higher = applied first)
26    pub priority: i32,
27
28    /// Enabled/disabled
29    pub enabled: bool,
30
31    /// Rule tags for grouping
32    pub tags: HashSet<String>,
33}
34
35impl RewriteRule {
36    /// Create a new rule builder
37    pub fn build(id: impl Into<String>) -> RewriteRuleBuilder {
38        RewriteRuleBuilder::new(id)
39    }
40
41    /// Check if rule matches query pattern
42    pub fn matches(&self, fingerprint: u64, query: &str, tables: &[String]) -> bool {
43        if !self.enabled {
44            return false;
45        }
46
47        match &self.pattern {
48            QueryPattern::Fingerprint(fp) => *fp == fingerprint,
49            QueryPattern::Regex(pattern) => regex::Regex::new(pattern)
50                .map(|re| re.is_match(query))
51                .unwrap_or(false),
52            QueryPattern::Table(table) => tables.contains(table),
53            QueryPattern::TableAny(table_patterns) => {
54                tables.iter().any(|t| table_patterns.contains(t))
55            }
56            QueryPattern::Ast(_ast_pattern) => {
57                // AST matching is done by the matcher
58                false
59            }
60            QueryPattern::All => true,
61        }
62    }
63}
64
65/// Builder for RewriteRule
66pub struct RewriteRuleBuilder {
67    rule: RewriteRule,
68}
69
70impl RewriteRuleBuilder {
71    /// Create a new builder
72    pub fn new(id: impl Into<String>) -> Self {
73        Self {
74            rule: RewriteRule {
75                id: id.into(),
76                description: String::new(),
77                pattern: QueryPattern::All,
78                transformation: Transformation::NoOp,
79                condition: None,
80                priority: 0,
81                enabled: true,
82                tags: HashSet::new(),
83            },
84        }
85    }
86
87    /// Set description
88    pub fn description(mut self, desc: impl Into<String>) -> Self {
89        self.rule.description = desc.into();
90        self
91    }
92
93    /// Set pattern
94    pub fn pattern(mut self, pattern: QueryPattern) -> Self {
95        self.rule.pattern = pattern;
96        self
97    }
98
99    /// Set transformation
100    pub fn transform(mut self, transformation: Transformation) -> Self {
101        self.rule.transformation = transformation;
102        self
103    }
104
105    /// Set condition
106    pub fn condition(mut self, condition: Condition) -> Self {
107        self.rule.condition = Some(condition);
108        self
109    }
110
111    /// Set priority
112    pub fn priority(mut self, priority: i32) -> Self {
113        self.rule.priority = priority;
114        self
115    }
116
117    /// Enable/disable
118    pub fn enabled(mut self, enabled: bool) -> Self {
119        self.rule.enabled = enabled;
120        self
121    }
122
123    /// Add a tag
124    pub fn tag(mut self, tag: impl Into<String>) -> Self {
125        self.rule.tags.insert(tag.into());
126        self
127    }
128
129    /// Build the rule
130    pub fn build(self) -> RewriteRule {
131        self.rule
132    }
133}
134
135impl From<RewriteRuleBuilder> for RewriteRule {
136    fn from(builder: RewriteRuleBuilder) -> Self {
137        builder.build()
138    }
139}
140
/// Criterion a query must satisfy for a rule to match it.
///
/// Compared by value (`PartialEq`/`Eq`); `TableAny` compares as a set, so the
/// order in which table names were supplied does not matter.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum QueryPattern {
    /// Match by fingerprint hash.
    Fingerprint(u64),

    /// Match the SQL text against a regular expression, stored as its source text.
    Regex(String),

    /// Match when this table name is among the query's tables.
    Table(String),

    /// Match when any of these table names is among the query's tables.
    TableAny(HashSet<String>),

    /// Match by AST pattern (evaluated by the matcher).
    Ast(AstPattern),

    /// Match every query.
    All,
}

impl QueryPattern {
    /// Create a fingerprint pattern.
    pub fn fingerprint(fp: u64) -> Self {
        Self::Fingerprint(fp)
    }

    /// Create a regex pattern from its source text (not compiled here).
    pub fn regex(pattern: impl Into<String>) -> Self {
        Self::Regex(pattern.into())
    }

    /// Create a single-table pattern.
    pub fn table(table: impl Into<String>) -> Self {
        Self::Table(table.into())
    }

    /// Create a table-any pattern from any iterable of table names.
    ///
    /// Duplicate names collapse into one entry of the underlying set.
    pub fn table_any(tables: impl IntoIterator<Item = impl Into<String>>) -> Self {
        Self::TableAny(tables.into_iter().map(Into::into).collect())
    }

    /// Create an AST pattern.
    pub fn ast(pattern: AstPattern) -> Self {
        Self::Ast(pattern)
    }

    /// Create a pattern that matches every query.
    pub fn all() -> Self {
        Self::All
    }
}

/// AST-level pattern matching.
///
/// `And`/`Or` nest other patterns to build compound matches.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum AstPattern {
    /// SELECT * query
    SelectStar,

    /// SELECT with specific table
    SelectFrom { table: String },

    /// Query without LIMIT
    NoLimit,

    /// Query without WHERE
    NoWhere,

    /// INSERT statement
    Insert,

    /// UPDATE statement
    Update,

    /// DELETE statement
    Delete,

    /// DDL statement (CREATE, ALTER, DROP)
    Ddl,

    /// N+1 query pattern
    NPlusOne { table: String },

    /// Full table scan
    FullTableScan,

    /// All of the nested patterns
    And(Vec<AstPattern>),

    /// Any of the nested patterns
    Or(Vec<AstPattern>),
}

impl AstPattern {
    /// Create a SELECT * pattern.
    pub fn select_star() -> Self {
        Self::SelectStar
    }

    /// Create a no-LIMIT pattern.
    pub fn no_limit() -> Self {
        Self::NoLimit
    }

    /// Create a no-WHERE pattern.
    pub fn no_where() -> Self {
        Self::NoWhere
    }
}
251
/// Rewrite applied to a query when a rule matches.
///
/// Compared by value (`PartialEq`/`Eq`); `Chain` compares its steps in order.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Transformation {
    /// No operation (pass through)
    NoOp,

    /// Replace entire query
    Replace(String),

    /// Add index hint
    AddIndexHint { table: String, index: String },

    /// Rewrite SELECT * to specific columns
    ExpandSelectStar { columns: Vec<String> },

    /// Add LIMIT clause
    AddLimit(u32),

    /// Add WHERE condition
    AddWhereClause(String),

    /// Append to WHERE clause with AND
    AppendWhereAnd(String),

    /// Replace table name
    ReplaceTable { from: String, to: String },

    /// Add ORDER BY clause
    AddOrderBy { column: String, descending: bool },

    /// Add query hint comment
    AddHint(String),

    /// Add branch routing hint
    AddBranchHint(String),

    /// Add timeout hint
    AddTimeout(std::time::Duration),

    /// Custom transformation by name
    Custom(String),

    /// Chain multiple transformations
    Chain(Vec<Transformation>),
}

impl Transformation {
    /// Create a replace-entire-query transformation.
    pub fn replace(query: impl Into<String>) -> Self {
        Self::Replace(query.into())
    }

    /// Create an add-LIMIT transformation.
    pub fn add_limit(limit: u32) -> Self {
        Self::AddLimit(limit)
    }

    /// Create an add-WHERE transformation.
    pub fn add_where(condition: impl Into<String>) -> Self {
        Self::AddWhereClause(condition.into())
    }

    /// Create a replace-table transformation renaming `from` to `to`.
    pub fn replace_table(from: impl Into<String>, to: impl Into<String>) -> Self {
        Self::ReplaceTable {
            from: from.into(),
            to: to.into(),
        }
    }

    /// Create an expand-select-star transformation.
    ///
    /// Accepts any iterable of column names (a `Vec`, an array, an iterator),
    /// matching `QueryPattern::table_any`; column order is preserved.
    pub fn expand_select_star(columns: impl IntoIterator<Item = impl Into<String>>) -> Self {
        Self::ExpandSelectStar {
            columns: columns.into_iter().map(Into::into).collect(),
        }
    }

    /// Create an add-index-hint transformation for `index` on `table`.
    pub fn add_index_hint(table: impl Into<String>, index: impl Into<String>) -> Self {
        Self::AddIndexHint {
            table: table.into(),
            index: index.into(),
        }
    }

    /// Create a chain of transformations, kept in the given order.
    pub fn chain(transformations: Vec<Transformation>) -> Self {
        Self::Chain(transformations)
    }
}
342
/// Condition that must hold for a rule to be applied.
///
/// Leaf variants describe a single check; `And`, `Or` and `Not` combine them.
/// Compared by value (`PartialEq`/`Eq`).
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Condition {
    /// Query has no LIMIT clause
    NoExistingLimit,

    /// Query has no ORDER BY clause
    NoExistingOrderBy,

    /// Query has SELECT *
    HasSelectStar,

    /// Session variable check (`exists` selects presence vs. absence)
    SessionVar { name: String, exists: bool },

    /// Client type check
    ClientType { client_type: String },

    /// Table exists in schema
    TableExists { table: String },

    /// All conditions must match
    And(Vec<Condition>),

    /// Any condition must match
    Or(Vec<Condition>),

    /// Negate condition
    Not(Box<Condition>),
}

impl Condition {
    /// No existing LIMIT.
    pub fn no_limit() -> Self {
        Self::NoExistingLimit
    }

    /// No existing ORDER BY.
    pub fn no_order_by() -> Self {
        Self::NoExistingOrderBy
    }

    /// Has SELECT *.
    pub fn has_select_star() -> Self {
        Self::HasSelectStar
    }

    /// Session variable `name` exists (builds `SessionVar` with `exists: true`).
    pub fn session_var(name: impl Into<String>) -> Self {
        Self::SessionVar {
            name: name.into(),
            exists: true,
        }
    }

    /// Client type matches `client_type`.
    pub fn client_type(client_type: impl Into<String>) -> Self {
        Self::ClientType {
            client_type: client_type.into(),
        }
    }

    /// Table `table` exists in schema.
    pub fn table_exists(table: impl Into<String>) -> Self {
        Self::TableExists {
            table: table.into(),
        }
    }

    /// All of `conditions` must match.
    pub fn and(conditions: Vec<Condition>) -> Self {
        Self::And(conditions)
    }

    /// Any of `conditions` must match.
    pub fn or(conditions: Vec<Condition>) -> Self {
        Self::Or(conditions)
    }

    /// Negate `condition`.
    // Named `not` to mirror `and`/`or`. It is a constructor that takes the
    // inner condition as an argument rather than an operator on `self`, so the
    // clippy suggestion to implement `std::ops::Not` instead does not apply.
    #[allow(clippy::should_implement_trait)]
    pub fn not(condition: Condition) -> Self {
        Self::Not(Box::new(condition))
    }
}
421
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_rule_builder() {
        let built = RewriteRule::build("test")
            .description("Test rule")
            .pattern(QueryPattern::All)
            .transform(Transformation::AddLimit(100))
            .priority(50)
            .tag("safety")
            .build();

        // Explicitly configured values survive `build()`...
        assert_eq!(built.id, "test");
        assert_eq!(built.description, "Test rule");
        assert_eq!(built.priority, 50);
        assert!(built.tags.contains("safety"));
        // ...and values left untouched keep the builder defaults.
        assert!(built.enabled);
    }

    #[test]
    fn test_query_pattern_table() {
        if let QueryPattern::Table(name) = QueryPattern::table("users") {
            assert_eq!(name, "users");
        } else {
            panic!("Expected Table pattern");
        }
    }

    #[test]
    fn test_transformation_chain() {
        let order_by_id_desc = Transformation::AddOrderBy {
            column: "id".to_string(),
            descending: true,
        };
        let chained =
            Transformation::chain(vec![Transformation::AddLimit(100), order_by_id_desc]);

        let steps = match chained {
            Transformation::Chain(steps) => steps,
            _ => panic!("Expected Chain"),
        };
        assert_eq!(steps.len(), 2);
    }

    #[test]
    fn test_condition_and() {
        let parts = vec![Condition::NoExistingLimit, Condition::HasSelectStar];

        if let Condition::And(inner) = Condition::and(parts) {
            assert_eq!(inner.len(), 2);
        } else {
            panic!("Expected And");
        }
    }

    #[test]
    fn test_rule_matches() {
        let rule = RewriteRule::build("test")
            .pattern(QueryPattern::table("users"))
            .transform(Transformation::AddLimit(100))
            .build();

        let users = ["users".to_string()];
        let orders = ["orders".to_string()];
        assert!(rule.matches(0, "", &users));
        assert!(!rule.matches(0, "", &orders));
    }
}