use std::collections::HashSet;
/// A single query-rewrite rule: which queries it targets (`pattern`) and the
/// rewrite associated with it (`transformation`).
///
/// Construct one with [`RewriteRule::new`], which returns a [`RewriteRuleBuilder`].
#[derive(Debug, Clone)]
pub struct RewriteRule {
/// Identifier supplied to `RewriteRule::new` / `RewriteRuleBuilder::new`.
pub id: String,
/// Free-form human-readable description; empty unless set on the builder.
pub description: String,
/// Which queries this rule targets; evaluated by [`RewriteRule::matches`].
pub pattern: QueryPattern,
/// The rewrite carried by this rule. NOTE(review): nothing in this file
/// applies it; presumably a rewrite engine elsewhere does — confirm.
pub transformation: Transformation,
/// Optional extra guard. [`RewriteRule::matches`] does NOT evaluate it;
/// presumably the caller checks it separately — confirm.
pub condition: Option<Condition>,
/// Rule priority (builder default `0`). NOTE(review): whether higher or lower
/// values take precedence is decided outside this file — confirm.
pub priority: i32,
/// When `false`, [`RewriteRule::matches`] always returns `false`.
pub enabled: bool,
/// Free-form string labels attached to the rule (e.g. `"safety"`).
pub tags: HashSet<String>,
}
impl RewriteRule {
    /// Starts building a rule identified by `id`.
    ///
    /// Returns a [`RewriteRuleBuilder`] pre-populated with defaults; finish it
    /// with [`RewriteRuleBuilder::build`].
    pub fn new(id: impl Into<String>) -> RewriteRuleBuilder {
        RewriteRuleBuilder::new(id)
    }

    /// Reports whether this rule's `pattern` applies to a query.
    ///
    /// * `fingerprint` - compared against [`QueryPattern::Fingerprint`].
    /// * `query` - raw query text, tested by [`QueryPattern::Regex`].
    /// * `tables` - table names referenced by the query, used by
    ///   [`QueryPattern::Table`] and [`QueryPattern::TableAny`] (exact,
    ///   case-sensitive comparison).
    ///
    /// A disabled rule never matches. The rule's `condition` is not consulted.
    pub fn matches(&self, fingerprint: u64, query: &str, tables: &[String]) -> bool {
        self.enabled
            && match &self.pattern {
                QueryPattern::All => true,
                QueryPattern::Fingerprint(expected) => fingerprint == *expected,
                QueryPattern::Table(name) => tables.iter().any(|t| t == name),
                QueryPattern::TableAny(names) => tables.iter().any(|t| names.contains(t)),
                // The regex is compiled on every call; a pattern that fails to
                // compile is treated as "never matches" rather than an error.
                QueryPattern::Regex(source) => match regex::Regex::new(source) {
                    Ok(re) => re.is_match(query),
                    Err(_) => false,
                },
                // AST-based matching is not implemented, so these never match.
                QueryPattern::Ast(_) => false,
            }
    }
}
/// Fluent builder for [`RewriteRule`], obtained from [`RewriteRule::new`] or
/// [`RewriteRuleBuilder::new`].
///
/// Every setter consumes the builder and returns it, so a builder that is
/// dropped without calling [`RewriteRuleBuilder::build`] (or converting it via
/// `RewriteRule::from` / `.into()`) silently discards the configured rule;
/// `#[must_use]` turns that mistake into a compiler warning.
#[derive(Debug, Clone)]
#[must_use = "a RewriteRuleBuilder does nothing until `.build()` is called"]
pub struct RewriteRuleBuilder {
    rule: RewriteRule,
}
impl RewriteRuleBuilder {
    /// Creates a builder for a rule identified by `id`.
    ///
    /// Defaults: empty description, [`QueryPattern::All`],
    /// [`Transformation::NoOp`], no condition, priority `0`, enabled, no tags.
    pub fn new(id: impl Into<String>) -> Self {
        let rule = RewriteRule {
            id: id.into(),
            description: String::new(),
            pattern: QueryPattern::All,
            transformation: Transformation::NoOp,
            condition: None,
            priority: 0,
            enabled: true,
            tags: HashSet::new(),
        };
        Self { rule }
    }

    /// Runs `edit` against the rule under construction and hands the builder back.
    fn apply(mut self, edit: impl FnOnce(&mut RewriteRule)) -> Self {
        edit(&mut self.rule);
        self
    }

    /// Sets the human-readable description.
    pub fn description(self, desc: impl Into<String>) -> Self {
        self.apply(|r| r.description = desc.into())
    }

    /// Sets which queries the rule targets.
    pub fn pattern(self, pattern: QueryPattern) -> Self {
        self.apply(|r| r.pattern = pattern)
    }

    /// Sets the transformation carried by the rule.
    pub fn transform(self, transformation: Transformation) -> Self {
        self.apply(|r| r.transformation = transformation)
    }

    /// Attaches a condition, replacing any previously set one.
    pub fn condition(self, condition: Condition) -> Self {
        self.apply(|r| r.condition = Some(condition))
    }

    /// Sets the rule priority.
    pub fn priority(self, priority: i32) -> Self {
        self.apply(|r| r.priority = priority)
    }

    /// Enables or disables the rule.
    pub fn enabled(self, enabled: bool) -> Self {
        self.apply(|r| r.enabled = enabled)
    }

    /// Adds one tag; adding a tag that is already present has no effect.
    pub fn tag(self, tag: impl Into<String>) -> Self {
        self.apply(|r| {
            r.tags.insert(tag.into());
        })
    }

    /// Finishes the builder and returns the configured rule.
    pub fn build(self) -> RewriteRule {
        let Self { rule } = self;
        rule
    }
}
impl From<RewriteRuleBuilder> for RewriteRule {
fn from(builder: RewriteRuleBuilder) -> Self {
builder.build()
}
}
/// Selects which queries a [`RewriteRule`] applies to; evaluated by
/// `RewriteRule::matches`.
#[derive(Debug, Clone)]
pub enum QueryPattern {
/// Matches when the caller-supplied fingerprint equals this value.
Fingerprint(u64),
/// Regular-expression source tested against the raw query text. It is
/// compiled on every `matches` call; a pattern that fails to compile never
/// matches.
Regex(String),
/// Matches when the query's table list contains this exact name
/// (case-sensitive string equality).
Table(String),
/// Matches when any of the query's tables is a member of this set.
TableAny(HashSet<String>),
/// Structural pattern over the query. NOTE(review): `matches` currently
/// returns `false` for every `Ast` pattern.
Ast(AstPattern),
/// Matches every query (the builder's default pattern).
All,
}
impl QueryPattern {
pub fn fingerprint(fp: u64) -> Self {
Self::Fingerprint(fp)
}
pub fn regex(pattern: impl Into<String>) -> Self {
Self::Regex(pattern.into())
}
pub fn table(table: impl Into<String>) -> Self {
Self::Table(table.into())
}
pub fn table_any(tables: impl IntoIterator<Item = impl Into<String>>) -> Self {
Self::TableAny(tables.into_iter().map(Into::into).collect())
}
pub fn ast(pattern: AstPattern) -> Self {
Self::Ast(pattern)
}
pub fn all() -> Self {
Self::All
}
}
/// Structural query patterns wrapped by [`QueryPattern::Ast`].
///
/// NOTE(review): `RewriteRule::matches` does not evaluate these (it returns
/// `false` for every `QueryPattern::Ast`); the per-variant meanings below are
/// inferred from the names — confirm against the intended matcher.
#[derive(Debug, Clone)]
pub enum AstPattern {
/// Presumably a `SELECT *` projection.
SelectStar,
/// Presumably a `SELECT` reading from `table`.
SelectFrom { table: String },
/// Presumably a query without a `LIMIT` clause.
NoLimit,
/// Presumably a query without a `WHERE` clause.
NoWhere,
/// Presumably an `INSERT` statement.
Insert,
/// Presumably an `UPDATE` statement.
Update,
/// Presumably a `DELETE` statement.
Delete,
/// Presumably a DDL statement (schema-changing).
Ddl,
/// Presumably an N+1 access pattern against `table`.
NPlusOne { table: String },
/// Presumably a query that scans an entire table.
FullTableScan,
/// Presumably matches when every sub-pattern matches.
And(Vec<AstPattern>),
/// Presumably matches when at least one sub-pattern matches.
Or(Vec<AstPattern>),
}
impl AstPattern {
pub fn select_star() -> Self {
Self::SelectStar
}
pub fn no_limit() -> Self {
Self::NoLimit
}
pub fn no_where() -> Self {
Self::NoWhere
}
}
/// The rewrite carried by a [`RewriteRule`].
///
/// NOTE(review): nothing in this file applies these; the per-variant
/// descriptions are inferred from names and payloads — confirm against the
/// code that executes transformations.
#[derive(Debug, Clone)]
pub enum Transformation {
/// Presumably leaves the query unchanged; this is the builder's default.
NoOp,
/// Presumably replaces the whole query with this text.
Replace(String),
/// Presumably adds an index hint naming `index` for `table`.
AddIndexHint {
table: String,
index: String,
},
/// Presumably expands `SELECT *` into this explicit column list.
ExpandSelectStar {
columns: Vec<String>,
},
/// Presumably appends a `LIMIT` with this row count.
AddLimit(u32),
/// Presumably adds a `WHERE` clause with this condition text.
AddWhereClause(String),
/// Presumably ANDs this condition text onto an existing `WHERE` clause.
AppendWhereAnd(String),
/// Presumably renames references to table `from` as `to`.
ReplaceTable {
from: String,
to: String,
},
/// Presumably appends an `ORDER BY column` (descending when `descending`).
AddOrderBy {
column: String,
descending: bool,
},
/// Presumably attaches this hint text to the query.
AddHint(String),
/// Presumably attaches a branch hint with this value — confirm meaning.
AddBranchHint(String),
/// Presumably attaches a timeout of this duration.
AddTimeout(std::time::Duration),
/// Opaque, caller-defined transformation identified by this string.
Custom(String),
/// Presumably applies each transformation in order.
Chain(Vec<Transformation>),
}
impl Transformation {
pub fn replace(query: impl Into<String>) -> Self {
Self::Replace(query.into())
}
pub fn add_limit(limit: u32) -> Self {
Self::AddLimit(limit)
}
pub fn add_where(condition: impl Into<String>) -> Self {
Self::AddWhereClause(condition.into())
}
pub fn replace_table(from: impl Into<String>, to: impl Into<String>) -> Self {
Self::ReplaceTable {
from: from.into(),
to: to.into(),
}
}
pub fn expand_select_star(columns: Vec<impl Into<String>>) -> Self {
Self::ExpandSelectStar {
columns: columns.into_iter().map(Into::into).collect(),
}
}
pub fn add_index_hint(table: impl Into<String>, index: impl Into<String>) -> Self {
Self::AddIndexHint {
table: table.into(),
index: index.into(),
}
}
pub fn chain(transformations: Vec<Transformation>) -> Self {
Self::Chain(transformations)
}
}
/// Extra guard a [`RewriteRule`] may carry in its `condition` field.
///
/// NOTE(review): `RewriteRule::matches` does not evaluate conditions and no
/// evaluator is visible in this file; the variant meanings below are inferred
/// from names — confirm.
#[derive(Debug, Clone)]
pub enum Condition {
/// Presumably holds when the query has no `LIMIT` yet.
NoExistingLimit,
/// Presumably holds when the query has no `ORDER BY` yet.
NoExistingOrderBy,
/// Presumably holds when the query uses `SELECT *`.
HasSelectStar,
/// Presumably tests whether session variable `name` is set (`exists: true`)
/// or unset (`exists: false`). `Condition::session_var` uses `exists: true`.
SessionVar {
name: String,
exists: bool,
},
/// Presumably holds when the client's type equals `client_type`.
ClientType {
client_type: String,
},
/// Presumably holds when `table` exists.
TableExists {
table: String,
},
/// Presumably holds when every nested condition holds.
And(Vec<Condition>),
/// Presumably holds when at least one nested condition holds.
Or(Vec<Condition>),
/// Presumably negates the nested condition.
Not(Box<Condition>),
}
impl Condition {
pub fn no_limit() -> Self {
Self::NoExistingLimit
}
pub fn no_order_by() -> Self {
Self::NoExistingOrderBy
}
pub fn has_select_star() -> Self {
Self::HasSelectStar
}
pub fn session_var(name: impl Into<String>) -> Self {
Self::SessionVar {
name: name.into(),
exists: true,
}
}
pub fn client_type(client_type: impl Into<String>) -> Self {
Self::ClientType {
client_type: client_type.into(),
}
}
pub fn and(conditions: Vec<Condition>) -> Self {
Self::And(conditions)
}
pub fn or(conditions: Vec<Condition>) -> Self {
Self::Or(conditions)
}
pub fn not(condition: Condition) -> Self {
Self::Not(Box::new(condition))
}
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_rule_builder() {
        let rule = RewriteRule::new("test")
            .description("Test rule")
            .pattern(QueryPattern::All)
            .transform(Transformation::AddLimit(100))
            .priority(50)
            .tag("safety")
            .build();
        assert_eq!(rule.id, "test");
        assert_eq!(rule.description, "Test rule");
        assert_eq!(rule.priority, 50);
        assert!(rule.enabled);
        assert!(rule.tags.contains("safety"));
    }

    /// A bare builder must produce the documented defaults.
    #[test]
    fn test_builder_defaults() {
        let rule = RewriteRule::new("defaults").build();
        assert_eq!(rule.id, "defaults");
        assert!(rule.description.is_empty());
        assert!(matches!(rule.pattern, QueryPattern::All));
        assert!(matches!(rule.transformation, Transformation::NoOp));
        assert!(rule.condition.is_none());
        assert_eq!(rule.priority, 0);
        assert!(rule.enabled);
        assert!(rule.tags.is_empty());
    }

    /// `From<RewriteRuleBuilder>` must behave like `build()`.
    #[test]
    fn test_from_builder() {
        let rule: RewriteRule = RewriteRule::new("conv").priority(7).into();
        assert_eq!(rule.id, "conv");
        assert_eq!(rule.priority, 7);
    }

    #[test]
    fn test_query_pattern_table() {
        let pattern = QueryPattern::table("users");
        match pattern {
            QueryPattern::Table(t) => assert_eq!(t, "users"),
            _ => panic!("Expected Table pattern"),
        }
    }

    #[test]
    fn test_transformation_chain() {
        let transform = Transformation::chain(vec![
            Transformation::AddLimit(100),
            Transformation::AddOrderBy {
                column: "id".to_string(),
                descending: true,
            },
        ]);
        match transform {
            Transformation::Chain(t) => assert_eq!(t.len(), 2),
            _ => panic!("Expected Chain"),
        }
    }

    #[test]
    fn test_condition_and() {
        let condition = Condition::and(vec![
            Condition::NoExistingLimit,
            Condition::HasSelectStar,
        ]);
        match condition {
            Condition::And(c) => assert_eq!(c.len(), 2),
            _ => panic!("Expected And"),
        }
    }

    #[test]
    fn test_condition_not_and_session_var() {
        match Condition::not(Condition::no_limit()) {
            Condition::Not(inner) => assert!(matches!(*inner, Condition::NoExistingLimit)),
            _ => panic!("Expected Not"),
        }
        match Condition::session_var("tenant") {
            Condition::SessionVar { name, exists } => {
                assert_eq!(name, "tenant");
                assert!(exists);
            }
            _ => panic!("Expected SessionVar"),
        }
    }

    #[test]
    fn test_rule_matches() {
        let rule = RewriteRule::new("test")
            .pattern(QueryPattern::Table("users".to_string()))
            .transform(Transformation::AddLimit(100))
            .build();
        assert!(rule.matches(0, "", &["users".to_string()]));
        assert!(!rule.matches(0, "", &["orders".to_string()]));
    }

    /// A disabled rule never matches, even with the catch-all pattern.
    #[test]
    fn test_disabled_rule_never_matches() {
        let rule = RewriteRule::new("off").enabled(false).build();
        assert!(!rule.matches(0, "SELECT 1", &["users".to_string()]));
    }

    #[test]
    fn test_fingerprint_pattern() {
        let rule = RewriteRule::new("fp")
            .pattern(QueryPattern::fingerprint(42))
            .build();
        assert!(rule.matches(42, "", &[]));
        assert!(!rule.matches(7, "", &[]));
    }

    #[test]
    fn test_table_any_pattern() {
        let rule = RewriteRule::new("any")
            .pattern(QueryPattern::table_any(vec!["users", "orders"]))
            .build();
        assert!(rule.matches(0, "", &["items".to_string(), "orders".to_string()]));
        assert!(!rule.matches(0, "", &["items".to_string()]));
        assert!(!rule.matches(0, "", &[]));
    }

    #[test]
    fn test_regex_pattern() {
        let rule = RewriteRule::new("re")
            .pattern(QueryPattern::regex(r"(?i)^\s*select"))
            .build();
        assert!(rule.matches(0, "SELECT * FROM users", &[]));
        assert!(!rule.matches(0, "DELETE FROM users", &[]));
    }

    /// A regex that fails to compile is treated as "never matches".
    #[test]
    fn test_invalid_regex_never_matches() {
        let rule = RewriteRule::new("bad")
            .pattern(QueryPattern::regex("("))
            .build();
        assert!(!rule.matches(0, "(", &[]));
    }
}