1pub mod config;
49pub mod matcher;
50pub mod metrics;
51pub mod parser;
52pub mod rules;
53pub mod transformer;
54
55pub use config::{RewriterConfig, RewriterConfigBuilder};
57pub use matcher::{MatchResult, RuleMatcher};
58pub use metrics::{RewriteMetrics, RewriteStats, RuleStats};
59pub use parser::{ParsedQuery, SqlParser, SqlStatement};
60pub use rules::{
61 AstPattern, Condition, QueryPattern, RewriteRule, RewriteRuleBuilder, Transformation,
62};
63pub use transformer::{TransformError, TransformationEngine};
64
65use parking_lot::RwLock;
66use std::sync::Arc;
67
68pub struct QueryRewriter {
72 config: RewriterConfig,
74
75 parser: SqlParser,
77
78 rules: Arc<RwLock<Vec<RewriteRule>>>,
80
81 matcher: Arc<RwLock<RuleMatcher>>,
83
84 transformer: TransformationEngine,
86
87 metrics: Arc<RewriteMetrics>,
89}
90
91impl QueryRewriter {
92 pub fn new(config: RewriterConfig) -> Self {
94 let rules = Arc::new(RwLock::new(config.rules.clone()));
95 let matcher = Arc::new(RwLock::new(RuleMatcher::new(&config.rules)));
96 let parser = SqlParser::new();
97 let transformer = TransformationEngine::new();
98 let metrics = Arc::new(RewriteMetrics::new());
99
100 Self {
101 config,
102 parser,
103 rules,
104 matcher,
105 transformer,
106 metrics,
107 }
108 }
109
110 pub fn builder() -> QueryRewriterBuilder {
112 QueryRewriterBuilder::new()
113 }
114
115 pub fn is_enabled(&self) -> bool {
117 self.config.enabled
118 }
119
120 pub fn rewrite(&self, query: &str) -> Result<RewriteResult, RewriteError> {
124 if !self.config.enabled {
125 return Ok(RewriteResult::unchanged(query));
126 }
127
128 let start = std::time::Instant::now();
129
130 let parsed = self.parser.parse(query)?;
132 let fingerprint = parsed.fingerprint();
133
134 let rules = self.rules.read();
139 let matcher = self.matcher.read();
140 let matched = matcher.match_query(&parsed, &rules);
141
142 if matched.is_empty() {
143 self.metrics.record_no_match(start.elapsed());
144 return Ok(RewriteResult::unchanged(query));
145 }
146
147 let mut current_query = query.to_string();
149 let mut applied_rules = Vec::new();
150
151 for rule in matched {
152 if !rule.enabled {
153 continue;
154 }
155
156 if let Some(ref condition) = rule.condition {
158 if !self.evaluate_condition(condition, ¤t_query) {
159 continue;
160 }
161 }
162
163 match self.transformer.apply(¤t_query, &rule.transformation) {
165 Ok(rewritten) => {
166 current_query = rewritten;
167 applied_rules.push(rule.id.clone());
168 self.metrics.record_rule_match(&rule.id);
169 }
170 Err(e) => {
171 if self.config.log_errors {
172 eprintln!("Rewrite error for rule {}: {}", rule.id, e);
173 }
174 }
176 }
177 }
178
179 let duration = start.elapsed();
180 self.metrics
181 .record_rewrite(duration, !applied_rules.is_empty());
182
183 if applied_rules.is_empty() {
184 Ok(RewriteResult::unchanged(query))
185 } else {
186 if self.config.log_rewrites {
187 println!("Rewritten query:");
188 println!(" Original: {}", query);
189 println!(" Rewritten: {}", current_query);
190 println!(" Rules: {:?}", applied_rules);
191 }
192
193 Ok(RewriteResult {
194 original: query.to_string(),
195 rewritten: current_query,
196 rules_applied: applied_rules,
197 fingerprint,
198 duration,
199 })
200 }
201 }
202
203 pub fn test_rewrite(&self, query: &str) -> Result<RewriteResult, RewriteError> {
205 let parsed = self.parser.parse(query)?;
206 let fingerprint = parsed.fingerprint();
207
208 let rules = self.rules.read();
209 let matcher = self.matcher.read();
210 let matched = matcher.match_query(&parsed, &rules);
211
212 let mut current_query = query.to_string();
213 let mut applied_rules = Vec::new();
214
215 for rule in matched {
216 if !rule.enabled {
217 continue;
218 }
219
220 if let Some(ref condition) = rule.condition {
221 if !self.evaluate_condition(condition, ¤t_query) {
222 continue;
223 }
224 }
225
226 if let Ok(rewritten) = self.transformer.apply(¤t_query, &rule.transformation) {
227 current_query = rewritten;
228 applied_rules.push(rule.id.clone());
229 }
230 }
231
232 Ok(RewriteResult {
233 original: query.to_string(),
234 rewritten: current_query,
235 rules_applied: applied_rules,
236 fingerprint,
237 duration: std::time::Duration::ZERO,
238 })
239 }
240
241 pub fn add_rule(&self, rule: impl Into<RewriteRule>) {
243 let mut rules = self.rules.write();
244 rules.push(rule.into());
245
246 let mut matcher = self.matcher.write();
248 *matcher = RuleMatcher::new(&rules);
249 }
250
251 pub fn remove_rule(&self, rule_id: &str) -> bool {
253 let mut rules = self.rules.write();
254 let initial_len = rules.len();
255 rules.retain(|r| r.id != rule_id);
256
257 if rules.len() != initial_len {
258 let mut matcher = self.matcher.write();
259 *matcher = RuleMatcher::new(&rules);
260 true
261 } else {
262 false
263 }
264 }
265
266 pub fn update_rule(&self, rule_id: &str, update: impl FnOnce(&mut RewriteRule)) -> bool {
268 let mut rules = self.rules.write();
269 if let Some(rule) = rules.iter_mut().find(|r| r.id == rule_id) {
270 update(rule);
271
272 let mut matcher = self.matcher.write();
273 *matcher = RuleMatcher::new(&rules);
274 true
275 } else {
276 false
277 }
278 }
279
280 pub fn set_rule_enabled(&self, rule_id: &str, enabled: bool) -> bool {
282 self.update_rule(rule_id, |r| r.enabled = enabled)
283 }
284
285 pub fn get_rules(&self) -> Vec<RewriteRule> {
287 self.rules.read().clone()
288 }
289
290 pub fn get_rule(&self, rule_id: &str) -> Option<RewriteRule> {
292 self.rules.read().iter().find(|r| r.id == rule_id).cloned()
293 }
294
295 pub fn stats(&self) -> RewriteStats {
297 self.metrics.stats()
298 }
299
300 fn evaluate_condition(&self, condition: &Condition, query: &str) -> bool {
302 match condition {
303 Condition::NoExistingLimit => !query.to_uppercase().contains("LIMIT"),
304 Condition::NoExistingOrderBy => !query.to_uppercase().contains("ORDER BY"),
305 Condition::HasSelectStar => {
306 let upper = query.to_uppercase();
307 upper.contains("SELECT *") || upper.contains("SELECT *")
308 }
309 Condition::SessionVar { name: _, exists } => {
310 *exists
313 }
314 Condition::ClientType { client_type: _ } => {
315 true
317 }
318 Condition::TableExists { table: _ } => {
319 true
321 }
322 Condition::And(conditions) => {
323 conditions.iter().all(|c| self.evaluate_condition(c, query))
324 }
325 Condition::Or(conditions) => {
326 conditions.iter().any(|c| self.evaluate_condition(c, query))
327 }
328 Condition::Not(condition) => !self.evaluate_condition(condition, query),
329 }
330 }
331}
332
333pub struct QueryRewriterBuilder {
335 config: RewriterConfig,
336}
337
338impl QueryRewriterBuilder {
339 pub fn new() -> Self {
341 Self {
342 config: RewriterConfig::default(),
343 }
344 }
345
346 pub fn enabled(mut self, enabled: bool) -> Self {
348 self.config.enabled = enabled;
349 self
350 }
351
352 pub fn log_rewrites(mut self, log: bool) -> Self {
354 self.config.log_rewrites = log;
355 self
356 }
357
358 pub fn log_errors(mut self, log: bool) -> Self {
360 self.config.log_errors = log;
361 self
362 }
363
364 pub fn rule(mut self, rule: impl Into<RewriteRule>) -> Self {
366 self.config.rules.push(rule.into());
367 self
368 }
369
370 pub fn rules(mut self, rules: Vec<RewriteRule>) -> Self {
372 self.config.rules.extend(rules);
373 self
374 }
375
376 pub fn expand_select_star(mut self, enabled: bool) -> Self {
378 self.config.expand_select_star = enabled;
379 self
380 }
381
382 pub fn add_default_limit(mut self, enabled: bool) -> Self {
384 self.config.add_default_limit = enabled;
385 self
386 }
387
388 pub fn default_limit(mut self, limit: u32) -> Self {
390 self.config.default_limit = limit;
391 self
392 }
393
394 pub fn build(self) -> QueryRewriter {
396 QueryRewriter::new(self.config)
397 }
398}
399
400impl Default for QueryRewriterBuilder {
401 fn default() -> Self {
402 Self::new()
403 }
404}
405
/// Outcome of running a query through the rewriter.
#[derive(Debug, Clone)]
pub struct RewriteResult {
    /// Query text exactly as it was submitted.
    pub original: String,

    /// Query text after all applied rules; equals `original` when nothing
    /// was applied.
    pub rewritten: String,

    /// Ids of the rules that were applied, in application order.
    pub rules_applied: Vec<String>,

    /// Fingerprint of the query; `0` for results made by [`RewriteResult::unchanged`].
    pub fingerprint: u64,

    /// Time spent rewriting; zero for results made by [`RewriteResult::unchanged`].
    pub duration: std::time::Duration,
}

impl RewriteResult {
    /// Builds a result that passes `query` through untouched: no applied
    /// rules, fingerprint `0` and a zero duration.
    pub fn unchanged(query: &str) -> Self {
        let text = String::from(query);

        Self {
            rewritten: text.clone(),
            original: text,
            rules_applied: vec![],
            fingerprint: 0,
            duration: std::time::Duration::ZERO,
        }
    }

    /// Tells whether at least one rule was applied.
    pub fn was_rewritten(&self) -> bool {
        !matches!(self.rules_applied.as_slice(), [])
    }

    /// The query text to execute (the rewritten form).
    pub fn query(&self) -> &str {
        self.rewritten.as_str()
    }
}
447
/// Errors reported by the query rewriter.
///
/// Every variant carries a human-readable detail string that is appended to
/// a fixed prefix when the error is displayed.
#[derive(Debug, Clone)]
pub enum RewriteError {
    /// The query could not be parsed.
    ParseError(String),

    /// A transformation failed.
    TransformError(String),

    /// The rule with the given id does not exist.
    RuleNotFound(String),

    /// The named table is forbidden.
    ForbiddenTable(String),

    /// The configuration is invalid.
    ConfigError(String),
}

impl std::fmt::Display for RewriteError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Pick the fixed prefix and the payload once, then emit both together.
        let (prefix, detail) = match self {
            RewriteError::ParseError(detail) => ("Parse error: ", detail),
            RewriteError::TransformError(detail) => ("Transform error: ", detail),
            RewriteError::RuleNotFound(detail) => ("Rule not found: ", detail),
            RewriteError::ForbiddenTable(detail) => ("Forbidden table: ", detail),
            RewriteError::ConfigError(detail) => ("Config error: ", detail),
        };

        write!(f, "{}{}", prefix, detail)
    }
}

impl std::error::Error for RewriteError {}
480
481impl From<TransformError> for RewriteError {
482 fn from(e: TransformError) -> Self {
483 Self::TransformError(e.to_string())
484 }
485}
486
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an enabled rewriter whose only rule appends `LIMIT 100` to any
    /// query that does not already contain a LIMIT clause.
    fn limit_rewriter() -> QueryRewriter {
        QueryRewriter::builder()
            .enabled(true)
            .rule(
                RewriteRule::build("add_limit")
                    .pattern(QueryPattern::All)
                    .transform(Transformation::AddLimit(100))
                    .condition(Condition::NoExistingLimit),
            )
            .build()
    }

    /// A disabled rewriter hands every query back untouched.
    #[test]
    fn test_rewriter_disabled() {
        let rewriter = QueryRewriter::builder().enabled(false).build();
        assert!(!rewriter.is_enabled());

        let result = rewriter.rewrite("SELECT * FROM users").unwrap();
        assert!(!result.was_rewritten());
        assert_eq!(result.query(), "SELECT * FROM users");
        assert_eq!(result.original, "SELECT * FROM users");
    }

    /// The LIMIT rule fires on a query without a LIMIT clause.
    #[test]
    fn test_rewriter_add_limit() {
        let rewriter = limit_rewriter();

        let result = rewriter.rewrite("SELECT * FROM users").unwrap();
        assert!(result.was_rewritten());
        assert!(result.rewritten.contains("LIMIT 100"));
        assert_eq!(result.original, "SELECT * FROM users");
        assert!(result.rules_applied.contains(&"add_limit".to_string()));
    }

    /// The `NoExistingLimit` condition suppresses the rule when the query
    /// already has a LIMIT clause, and the text comes back unchanged.
    #[test]
    fn test_rewriter_skip_existing_limit() {
        let rewriter = limit_rewriter();

        let result = rewriter.rewrite("SELECT * FROM users LIMIT 50").unwrap();
        assert!(!result.was_rewritten());
        assert_eq!(result.query(), "SELECT * FROM users LIMIT 50");
    }

    /// A fingerprint-keyed rule must not fire for a query with a different
    /// fingerprint (assumes the test query does not hash to 12345).
    #[test]
    fn test_rewriter_replace_query() {
        let rewriter = QueryRewriter::builder()
            .enabled(true)
            .rule(
                RewriteRule::build("replace")
                    .pattern(QueryPattern::Fingerprint(12345))
                    .transform(Transformation::Replace("SELECT 1".to_string())),
            )
            .build();

        let result = rewriter.rewrite("SELECT * FROM users").unwrap();
        assert!(!result.was_rewritten());
        assert_eq!(result.query(), "SELECT * FROM users");
    }

    /// Rules can be added and removed at runtime; removing an unknown id
    /// reports `false`.
    #[test]
    fn test_add_remove_rule() {
        let rewriter = QueryRewriter::builder().enabled(true).build();

        assert!(rewriter.get_rules().is_empty());
        assert!(!rewriter.remove_rule("test"));

        rewriter.add_rule(
            RewriteRule::build("test")
                .pattern(QueryPattern::All)
                .transform(Transformation::AddLimit(100)),
        );

        assert_eq!(rewriter.get_rules().len(), 1);
        assert!(rewriter.get_rule("test").is_some());

        assert!(rewriter.remove_rule("test"));
        assert!(rewriter.get_rules().is_empty());
        assert!(rewriter.get_rule("test").is_none());

        // A second removal finds nothing left to remove.
        assert!(!rewriter.remove_rule("test"));
    }

    /// Toggling a rule reports whether the id was found, and a disabled rule
    /// is no longer applied.
    #[test]
    fn test_update_rule() {
        let rewriter = QueryRewriter::builder()
            .enabled(true)
            .rule(
                RewriteRule::build("test")
                    .pattern(QueryPattern::All)
                    .transform(Transformation::AddLimit(100)),
            )
            .build();

        assert!(rewriter.get_rule("test").unwrap().enabled);

        assert!(rewriter.set_rule_enabled("test", false));
        assert!(!rewriter.get_rule("test").unwrap().enabled);

        // Unknown ids are reported instead of being silently ignored.
        assert!(!rewriter.set_rule_enabled("missing", true));

        let result = rewriter.rewrite("SELECT * FROM users").unwrap();
        assert!(!result.was_rewritten());
    }
}