1pub mod config;
49pub mod rules;
50pub mod matcher;
51pub mod transformer;
52pub mod parser;
53pub mod metrics;
54
55pub use config::{RewriterConfig, RewriterConfigBuilder};
57pub use rules::{RewriteRule, RewriteRuleBuilder, QueryPattern, AstPattern, Transformation, Condition};
58pub use matcher::{RuleMatcher, MatchResult};
59pub use transformer::{TransformationEngine, TransformError};
60pub use parser::{SqlParser, ParsedQuery, SqlStatement};
61pub use metrics::{RewriteMetrics, RewriteStats, RuleStats};
62
63use std::sync::Arc;
64use parking_lot::RwLock;
65
66pub struct QueryRewriter {
70 config: RewriterConfig,
72
73 parser: SqlParser,
75
76 rules: Arc<RwLock<Vec<RewriteRule>>>,
78
79 matcher: Arc<RwLock<RuleMatcher>>,
81
82 transformer: TransformationEngine,
84
85 metrics: Arc<RewriteMetrics>,
87}
88
89impl QueryRewriter {
90 pub fn new(config: RewriterConfig) -> Self {
92 let rules = Arc::new(RwLock::new(config.rules.clone()));
93 let matcher = Arc::new(RwLock::new(RuleMatcher::new(&config.rules)));
94 let parser = SqlParser::new();
95 let transformer = TransformationEngine::new();
96 let metrics = Arc::new(RewriteMetrics::new());
97
98 Self {
99 config,
100 parser,
101 rules,
102 matcher,
103 transformer,
104 metrics,
105 }
106 }
107
108 pub fn builder() -> QueryRewriterBuilder {
110 QueryRewriterBuilder::new()
111 }
112
113 pub fn is_enabled(&self) -> bool {
115 self.config.enabled
116 }
117
118 pub fn rewrite(&self, query: &str) -> Result<RewriteResult, RewriteError> {
122 if !self.config.enabled {
123 return Ok(RewriteResult::unchanged(query));
124 }
125
126 let start = std::time::Instant::now();
127
128 let parsed = self.parser.parse(query)?;
130 let fingerprint = parsed.fingerprint();
131
132 let rules = self.rules.read();
137 let matcher = self.matcher.read();
138 let matched = matcher.match_query(&parsed, &rules);
139
140 if matched.is_empty() {
141 self.metrics.record_no_match(start.elapsed());
142 return Ok(RewriteResult::unchanged(query));
143 }
144
145 let mut current_query = query.to_string();
147 let mut applied_rules = Vec::new();
148
149 for rule in matched {
150 if !rule.enabled {
151 continue;
152 }
153
154 if let Some(ref condition) = rule.condition {
156 if !self.evaluate_condition(condition, ¤t_query) {
157 continue;
158 }
159 }
160
161 match self.transformer.apply(¤t_query, &rule.transformation) {
163 Ok(rewritten) => {
164 current_query = rewritten;
165 applied_rules.push(rule.id.clone());
166 self.metrics.record_rule_match(&rule.id);
167 }
168 Err(e) => {
169 if self.config.log_errors {
170 eprintln!("Rewrite error for rule {}: {}", rule.id, e);
171 }
172 }
174 }
175 }
176
177 let duration = start.elapsed();
178 self.metrics.record_rewrite(duration, !applied_rules.is_empty());
179
180 if applied_rules.is_empty() {
181 Ok(RewriteResult::unchanged(query))
182 } else {
183 if self.config.log_rewrites {
184 println!("Rewritten query:");
185 println!(" Original: {}", query);
186 println!(" Rewritten: {}", current_query);
187 println!(" Rules: {:?}", applied_rules);
188 }
189
190 Ok(RewriteResult {
191 original: query.to_string(),
192 rewritten: current_query,
193 rules_applied: applied_rules,
194 fingerprint,
195 duration,
196 })
197 }
198 }
199
200 pub fn test_rewrite(&self, query: &str) -> Result<RewriteResult, RewriteError> {
202 let parsed = self.parser.parse(query)?;
203 let fingerprint = parsed.fingerprint();
204
205 let rules = self.rules.read();
206 let matcher = self.matcher.read();
207 let matched = matcher.match_query(&parsed, &rules);
208
209 let mut current_query = query.to_string();
210 let mut applied_rules = Vec::new();
211
212 for rule in matched {
213 if !rule.enabled {
214 continue;
215 }
216
217 if let Some(ref condition) = rule.condition {
218 if !self.evaluate_condition(condition, ¤t_query) {
219 continue;
220 }
221 }
222
223 if let Ok(rewritten) = self.transformer.apply(¤t_query, &rule.transformation) {
224 current_query = rewritten;
225 applied_rules.push(rule.id.clone());
226 }
227 }
228
229 Ok(RewriteResult {
230 original: query.to_string(),
231 rewritten: current_query,
232 rules_applied: applied_rules,
233 fingerprint,
234 duration: std::time::Duration::ZERO,
235 })
236 }
237
238 pub fn add_rule(&self, rule: impl Into<RewriteRule>) {
240 let mut rules = self.rules.write();
241 rules.push(rule.into());
242
243 let mut matcher = self.matcher.write();
245 *matcher = RuleMatcher::new(&rules);
246 }
247
248 pub fn remove_rule(&self, rule_id: &str) -> bool {
250 let mut rules = self.rules.write();
251 let initial_len = rules.len();
252 rules.retain(|r| r.id != rule_id);
253
254 if rules.len() != initial_len {
255 let mut matcher = self.matcher.write();
256 *matcher = RuleMatcher::new(&rules);
257 true
258 } else {
259 false
260 }
261 }
262
263 pub fn update_rule(&self, rule_id: &str, update: impl FnOnce(&mut RewriteRule)) -> bool {
265 let mut rules = self.rules.write();
266 if let Some(rule) = rules.iter_mut().find(|r| r.id == rule_id) {
267 update(rule);
268
269 let mut matcher = self.matcher.write();
270 *matcher = RuleMatcher::new(&rules);
271 true
272 } else {
273 false
274 }
275 }
276
277 pub fn set_rule_enabled(&self, rule_id: &str, enabled: bool) -> bool {
279 self.update_rule(rule_id, |r| r.enabled = enabled)
280 }
281
282 pub fn get_rules(&self) -> Vec<RewriteRule> {
284 self.rules.read().clone()
285 }
286
287 pub fn get_rule(&self, rule_id: &str) -> Option<RewriteRule> {
289 self.rules.read().iter().find(|r| r.id == rule_id).cloned()
290 }
291
292 pub fn stats(&self) -> RewriteStats {
294 self.metrics.stats()
295 }
296
297 fn evaluate_condition(&self, condition: &Condition, query: &str) -> bool {
299 match condition {
300 Condition::NoExistingLimit => {
301 !query.to_uppercase().contains("LIMIT")
302 }
303 Condition::NoExistingOrderBy => {
304 !query.to_uppercase().contains("ORDER BY")
305 }
306 Condition::HasSelectStar => {
307 let upper = query.to_uppercase();
308 upper.contains("SELECT *") || upper.contains("SELECT *")
309 }
310 Condition::SessionVar { name: _, exists } => {
311 *exists
314 }
315 Condition::ClientType { client_type: _ } => {
316 true
318 }
319 Condition::TableExists { table: _ } => {
320 true
322 }
323 Condition::And(conditions) => {
324 conditions.iter().all(|c| self.evaluate_condition(c, query))
325 }
326 Condition::Or(conditions) => {
327 conditions.iter().any(|c| self.evaluate_condition(c, query))
328 }
329 Condition::Not(condition) => {
330 !self.evaluate_condition(condition, query)
331 }
332 }
333 }
334}
335
336pub struct QueryRewriterBuilder {
338 config: RewriterConfig,
339}
340
341impl QueryRewriterBuilder {
342 pub fn new() -> Self {
344 Self {
345 config: RewriterConfig::default(),
346 }
347 }
348
349 pub fn enabled(mut self, enabled: bool) -> Self {
351 self.config.enabled = enabled;
352 self
353 }
354
355 pub fn log_rewrites(mut self, log: bool) -> Self {
357 self.config.log_rewrites = log;
358 self
359 }
360
361 pub fn log_errors(mut self, log: bool) -> Self {
363 self.config.log_errors = log;
364 self
365 }
366
367 pub fn rule(mut self, rule: impl Into<RewriteRule>) -> Self {
369 self.config.rules.push(rule.into());
370 self
371 }
372
373 pub fn rules(mut self, rules: Vec<RewriteRule>) -> Self {
375 self.config.rules.extend(rules);
376 self
377 }
378
379 pub fn expand_select_star(mut self, enabled: bool) -> Self {
381 self.config.expand_select_star = enabled;
382 self
383 }
384
385 pub fn add_default_limit(mut self, enabled: bool) -> Self {
387 self.config.add_default_limit = enabled;
388 self
389 }
390
391 pub fn default_limit(mut self, limit: u32) -> Self {
393 self.config.default_limit = limit;
394 self
395 }
396
397 pub fn build(self) -> QueryRewriter {
399 QueryRewriter::new(self.config)
400 }
401}
402
403impl Default for QueryRewriterBuilder {
404 fn default() -> Self {
405 Self::new()
406 }
407}
408
/// Outcome of a rewrite attempt.
#[derive(Debug, Clone)]
pub struct RewriteResult {
    /// Query text as it was submitted.
    pub original: String,

    /// Query text to execute; a copy of `original` in results created by
    /// [`RewriteResult::unchanged`].
    pub rewritten: String,

    /// Ids of the rules that were applied, in application order.
    pub rules_applied: Vec<String>,

    /// Fingerprint of the query; `0` in results created by
    /// [`RewriteResult::unchanged`].
    pub fingerprint: u64,

    /// Time attributed to the rewrite; zero in results created by
    /// [`RewriteResult::unchanged`].
    pub duration: std::time::Duration,
}

impl RewriteResult {
    /// Builds a result describing a query that passed through untouched:
    /// both text fields hold `query`, no rules are listed, and the
    /// fingerprint and duration are zero.
    pub fn unchanged(query: &str) -> Self {
        let text = query.to_owned();
        Self {
            rewritten: text.clone(),
            original: text,
            rules_applied: vec![],
            fingerprint: 0,
            duration: std::time::Duration::ZERO,
        }
    }

    /// Returns `true` when at least one rule was applied.
    pub fn was_rewritten(&self) -> bool {
        let untouched = self.rules_applied.is_empty();
        !untouched
    }

    /// Returns the query text that should be executed (the rewritten form).
    pub fn query(&self) -> &str {
        self.rewritten.as_str()
    }
}
450
/// Errors reported by the query rewriter.
///
/// Every variant carries one string that is appended to a fixed prefix when
/// the error is displayed.
#[derive(Debug, Clone)]
pub enum RewriteError {
    /// Displayed as `Parse error: <message>`.
    ParseError(String),

    /// Displayed as `Transform error: <message>`.
    TransformError(String),

    /// Displayed as `Rule not found: <id>`.
    RuleNotFound(String),

    /// Displayed as `Forbidden table: <table>`.
    ForbiddenTable(String),

    /// Displayed as `Config error: <message>`.
    ConfigError(String),
}

impl std::fmt::Display for RewriteError {
    /// Writes the variant's prefix followed by its payload.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            RewriteError::ParseError(detail) => {
                f.write_fmt(format_args!("Parse error: {}", detail))
            }
            RewriteError::TransformError(detail) => {
                f.write_fmt(format_args!("Transform error: {}", detail))
            }
            RewriteError::RuleNotFound(rule_id) => {
                f.write_fmt(format_args!("Rule not found: {}", rule_id))
            }
            RewriteError::ForbiddenTable(name) => {
                f.write_fmt(format_args!("Forbidden table: {}", name))
            }
            RewriteError::ConfigError(detail) => {
                f.write_fmt(format_args!("Config error: {}", detail))
            }
        }
    }
}

impl std::error::Error for RewriteError {}
483
484impl From<TransformError> for RewriteError {
485 fn from(e: TransformError) -> Self {
486 Self::TransformError(e.to_string())
487 }
488}
489
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an enabled rewriter that holds exactly one rule.
    fn enabled_rewriter_with(rule: impl Into<RewriteRule>) -> QueryRewriter {
        QueryRewriter::builder().enabled(true).rule(rule).build()
    }

    /// A disabled rewriter hands the query back untouched.
    #[test]
    fn test_rewriter_disabled() {
        let rewriter = QueryRewriter::builder().enabled(false).build();

        let outcome = rewriter.rewrite("SELECT * FROM users").unwrap();

        assert!(!outcome.was_rewritten());
        assert_eq!(outcome.query(), "SELECT * FROM users");
    }

    /// A guarded add-limit rule fires on a query without a LIMIT clause.
    #[test]
    fn test_rewriter_add_limit() {
        let guarded_limit = RewriteRule::new("add_limit")
            .pattern(QueryPattern::All)
            .transform(Transformation::AddLimit(100))
            .condition(Condition::NoExistingLimit);
        let rewriter = enabled_rewriter_with(guarded_limit);

        let outcome = rewriter.rewrite("SELECT * FROM users").unwrap();

        assert!(outcome.was_rewritten());
        assert!(outcome.rewritten.contains("LIMIT 100"));
    }

    /// The same guarded rule is skipped when the query already has a LIMIT.
    #[test]
    fn test_rewriter_skip_existing_limit() {
        let guarded_limit = RewriteRule::new("add_limit")
            .pattern(QueryPattern::All)
            .transform(Transformation::AddLimit(100))
            .condition(Condition::NoExistingLimit);
        let rewriter = enabled_rewriter_with(guarded_limit);

        let outcome = rewriter.rewrite("SELECT * FROM users LIMIT 50").unwrap();

        assert!(!outcome.was_rewritten());
    }

    /// A fingerprint-bound replace rule is expected not to fire for this query.
    #[test]
    fn test_rewriter_replace_query() {
        let replace_rule = RewriteRule::new("replace")
            .pattern(QueryPattern::Fingerprint(12345))
            .transform(Transformation::Replace("SELECT 1".to_string()));
        let rewriter = enabled_rewriter_with(replace_rule);

        let outcome = rewriter.rewrite("SELECT * FROM users").unwrap();

        assert!(!outcome.was_rewritten());
    }

    /// Rules can be added to and removed from a live rewriter.
    #[test]
    fn test_add_remove_rule() {
        let rewriter = QueryRewriter::builder().enabled(true).build();
        assert!(rewriter.get_rules().is_empty());

        let limit_rule = RewriteRule::new("test")
            .pattern(QueryPattern::All)
            .transform(Transformation::AddLimit(100));
        rewriter.add_rule(limit_rule);
        assert_eq!(rewriter.get_rules().len(), 1);

        assert!(rewriter.remove_rule("test"));
        assert!(rewriter.get_rules().is_empty());
    }

    /// Toggling a rule's enabled flag is visible through `get_rule`.
    #[test]
    fn test_update_rule() {
        let limit_rule = RewriteRule::new("test")
            .pattern(QueryPattern::All)
            .transform(Transformation::AddLimit(100));
        let rewriter = enabled_rewriter_with(limit_rule);
        assert!(rewriter.get_rule("test").unwrap().enabled);

        rewriter.set_rule_enabled("test", false);

        assert!(!rewriter.get_rule("test").unwrap().enabled);
    }
}