Skip to main content

heliosdb_proxy/rewriter/
mod.rs

1//! Query Rewriting Module
2//!
3//! Transparent query rewriting at the proxy layer for optimization,
4//! compatibility, and security enforcement.
5//!
6//! # Features
7//!
8//! - **Pattern Matching**: Match queries by fingerprint, regex, AST, or table
9//! - **Transformations**: Index hints, SELECT * expansion, LIMIT addition
10//! - **Rule Engine**: Priority-based rule application
11//! - **AI Safety**: Agent query limits and forbidden table enforcement
12//!
13//! # Architecture
14//!
15//! ```text
16//!   Original Query → Parse → Match Rules → Apply Transformations → Rewritten Query
17//!                      │         │                  │
18//!                      │         │                  ├── Replace
19//!                      │         │                  ├── AddIndexHint
20//!                      │         ├── Fingerprint    ├── ExpandSelectStar
21//!                      │         ├── Regex          ├── AddLimit
22//!                      │         ├── AST            ├── AddWhereClause
23//!                      │         └── Table          └── ReplaceTable
24//!                      │
25//!                   SQL AST
26//! ```
27//!
28//! # Example
29//!
30//! ```rust,ignore
31//! use heliosdb::proxy::rewriter::{QueryRewriter, RewriteRule, QueryPattern, Transformation};
32//!
33//! let mut rewriter = QueryRewriter::builder()
34//!     .rule(RewriteRule::build("expand_star")
35//!         .pattern(QueryPattern::table("users"))
36//!         .transform(Transformation::ExpandSelectStar {
37//!             columns: vec!["id", "name", "email"]
38//!         }))
39//!     .rule(RewriteRule::build("add_limit")
40//!         .pattern(QueryPattern::all())
41//!         .transform(Transformation::AddLimit(1000)))
42//!     .build();
43//!
44//! let result = rewriter.rewrite("SELECT * FROM users")?;
45//! // Result: SELECT id, name, email FROM users LIMIT 1000
46//! ```
47
48pub mod config;
49pub mod matcher;
50pub mod metrics;
51pub mod parser;
52pub mod rules;
53pub mod transformer;
54
55// Re-export main types
56pub use config::{RewriterConfig, RewriterConfigBuilder};
57pub use matcher::{MatchResult, RuleMatcher};
58pub use metrics::{RewriteMetrics, RewriteStats, RuleStats};
59pub use parser::{ParsedQuery, SqlParser, SqlStatement};
60pub use rules::{
61    AstPattern, Condition, QueryPattern, RewriteRule, RewriteRuleBuilder, Transformation,
62};
63pub use transformer::{TransformError, TransformationEngine};
64
65use parking_lot::RwLock;
66use std::sync::Arc;
67
68/// Query rewriter
69///
70/// Main entry point for query rewriting operations.
71pub struct QueryRewriter {
72    /// Configuration
73    config: RewriterConfig,
74
75    /// SQL parser
76    parser: SqlParser,
77
78    /// Rewrite rules
79    rules: Arc<RwLock<Vec<RewriteRule>>>,
80
81    /// Rule matcher
82    matcher: Arc<RwLock<RuleMatcher>>,
83
84    /// Transformation engine
85    transformer: TransformationEngine,
86
87    /// Metrics
88    metrics: Arc<RewriteMetrics>,
89}
90
91impl QueryRewriter {
92    /// Create a new query rewriter
93    pub fn new(config: RewriterConfig) -> Self {
94        let rules = Arc::new(RwLock::new(config.rules.clone()));
95        let matcher = Arc::new(RwLock::new(RuleMatcher::new(&config.rules)));
96        let parser = SqlParser::new();
97        let transformer = TransformationEngine::new();
98        let metrics = Arc::new(RewriteMetrics::new());
99
100        Self {
101            config,
102            parser,
103            rules,
104            matcher,
105            transformer,
106            metrics,
107        }
108    }
109
110    /// Create a builder
111    pub fn builder() -> QueryRewriterBuilder {
112        QueryRewriterBuilder::new()
113    }
114
115    /// Check if rewriter is enabled
116    pub fn is_enabled(&self) -> bool {
117        self.config.enabled
118    }
119
120    /// Rewrite a query
121    ///
122    /// Returns the rewritten query and list of applied rules.
123    pub fn rewrite(&self, query: &str) -> Result<RewriteResult, RewriteError> {
124        if !self.config.enabled {
125            return Ok(RewriteResult::unchanged(query));
126        }
127
128        let start = std::time::Instant::now();
129
130        // Parse query to get fingerprint
131        let parsed = self.parser.parse(query)?;
132        let fingerprint = parsed.fingerprint();
133
134        // Check cache for previously rewritten query
135        // (In production, add caching by fingerprint)
136
137        // Match rules
138        let rules = self.rules.read();
139        let matcher = self.matcher.read();
140        let matched = matcher.match_query(&parsed, &rules);
141
142        if matched.is_empty() {
143            self.metrics.record_no_match(start.elapsed());
144            return Ok(RewriteResult::unchanged(query));
145        }
146
147        // Apply transformations
148        let mut current_query = query.to_string();
149        let mut applied_rules = Vec::new();
150
151        for rule in matched {
152            if !rule.enabled {
153                continue;
154            }
155
156            // Check condition
157            if let Some(ref condition) = rule.condition {
158                if !self.evaluate_condition(condition, &current_query) {
159                    continue;
160                }
161            }
162
163            // Apply transformation
164            match self.transformer.apply(&current_query, &rule.transformation) {
165                Ok(rewritten) => {
166                    current_query = rewritten;
167                    applied_rules.push(rule.id.clone());
168                    self.metrics.record_rule_match(&rule.id);
169                }
170                Err(e) => {
171                    if self.config.log_errors {
172                        eprintln!("Rewrite error for rule {}: {}", rule.id, e);
173                    }
174                    // Continue with other rules
175                }
176            }
177        }
178
179        let duration = start.elapsed();
180        self.metrics
181            .record_rewrite(duration, !applied_rules.is_empty());
182
183        if applied_rules.is_empty() {
184            Ok(RewriteResult::unchanged(query))
185        } else {
186            if self.config.log_rewrites {
187                println!("Rewritten query:");
188                println!("  Original: {}", query);
189                println!("  Rewritten: {}", current_query);
190                println!("  Rules: {:?}", applied_rules);
191            }
192
193            Ok(RewriteResult {
194                original: query.to_string(),
195                rewritten: current_query,
196                rules_applied: applied_rules,
197                fingerprint,
198                duration,
199            })
200        }
201    }
202
203    /// Test rewrite without metrics recording
204    pub fn test_rewrite(&self, query: &str) -> Result<RewriteResult, RewriteError> {
205        let parsed = self.parser.parse(query)?;
206        let fingerprint = parsed.fingerprint();
207
208        let rules = self.rules.read();
209        let matcher = self.matcher.read();
210        let matched = matcher.match_query(&parsed, &rules);
211
212        let mut current_query = query.to_string();
213        let mut applied_rules = Vec::new();
214
215        for rule in matched {
216            if !rule.enabled {
217                continue;
218            }
219
220            if let Some(ref condition) = rule.condition {
221                if !self.evaluate_condition(condition, &current_query) {
222                    continue;
223                }
224            }
225
226            if let Ok(rewritten) = self.transformer.apply(&current_query, &rule.transformation) {
227                current_query = rewritten;
228                applied_rules.push(rule.id.clone());
229            }
230        }
231
232        Ok(RewriteResult {
233            original: query.to_string(),
234            rewritten: current_query,
235            rules_applied: applied_rules,
236            fingerprint,
237            duration: std::time::Duration::ZERO,
238        })
239    }
240
241    /// Add a new rule
242    pub fn add_rule(&self, rule: impl Into<RewriteRule>) {
243        let mut rules = self.rules.write();
244        rules.push(rule.into());
245
246        // Rebuild matcher
247        let mut matcher = self.matcher.write();
248        *matcher = RuleMatcher::new(&rules);
249    }
250
251    /// Remove a rule by ID
252    pub fn remove_rule(&self, rule_id: &str) -> bool {
253        let mut rules = self.rules.write();
254        let initial_len = rules.len();
255        rules.retain(|r| r.id != rule_id);
256
257        if rules.len() != initial_len {
258            let mut matcher = self.matcher.write();
259            *matcher = RuleMatcher::new(&rules);
260            true
261        } else {
262            false
263        }
264    }
265
266    /// Update a rule
267    pub fn update_rule(&self, rule_id: &str, update: impl FnOnce(&mut RewriteRule)) -> bool {
268        let mut rules = self.rules.write();
269        if let Some(rule) = rules.iter_mut().find(|r| r.id == rule_id) {
270            update(rule);
271
272            let mut matcher = self.matcher.write();
273            *matcher = RuleMatcher::new(&rules);
274            true
275        } else {
276            false
277        }
278    }
279
280    /// Enable/disable a rule
281    pub fn set_rule_enabled(&self, rule_id: &str, enabled: bool) -> bool {
282        self.update_rule(rule_id, |r| r.enabled = enabled)
283    }
284
285    /// Get all rules
286    pub fn get_rules(&self) -> Vec<RewriteRule> {
287        self.rules.read().clone()
288    }
289
290    /// Get rule by ID
291    pub fn get_rule(&self, rule_id: &str) -> Option<RewriteRule> {
292        self.rules.read().iter().find(|r| r.id == rule_id).cloned()
293    }
294
295    /// Get statistics
296    pub fn stats(&self) -> RewriteStats {
297        self.metrics.stats()
298    }
299
300    /// Evaluate a condition
301    fn evaluate_condition(&self, condition: &Condition, query: &str) -> bool {
302        match condition {
303            Condition::NoExistingLimit => !query.to_uppercase().contains("LIMIT"),
304            Condition::NoExistingOrderBy => !query.to_uppercase().contains("ORDER BY"),
305            Condition::HasSelectStar => {
306                let upper = query.to_uppercase();
307                upper.contains("SELECT *") || upper.contains("SELECT  *")
308            }
309            Condition::SessionVar { name: _, exists } => {
310                // In production, check session variables
311                // For now, always return true if exists is expected
312                *exists
313            }
314            Condition::ClientType { client_type: _ } => {
315                // In production, check client metadata
316                true
317            }
318            Condition::TableExists { table: _ } => {
319                // In production, check schema cache
320                true
321            }
322            Condition::And(conditions) => {
323                conditions.iter().all(|c| self.evaluate_condition(c, query))
324            }
325            Condition::Or(conditions) => {
326                conditions.iter().any(|c| self.evaluate_condition(c, query))
327            }
328            Condition::Not(condition) => !self.evaluate_condition(condition, query),
329        }
330    }
331}
332
333/// Query rewriter builder
334pub struct QueryRewriterBuilder {
335    config: RewriterConfig,
336}
337
338impl QueryRewriterBuilder {
339    /// Create a new builder
340    pub fn new() -> Self {
341        Self {
342            config: RewriterConfig::default(),
343        }
344    }
345
346    /// Enable the rewriter
347    pub fn enabled(mut self, enabled: bool) -> Self {
348        self.config.enabled = enabled;
349        self
350    }
351
352    /// Log rewrites
353    pub fn log_rewrites(mut self, log: bool) -> Self {
354        self.config.log_rewrites = log;
355        self
356    }
357
358    /// Log errors
359    pub fn log_errors(mut self, log: bool) -> Self {
360        self.config.log_errors = log;
361        self
362    }
363
364    /// Add a rule
365    pub fn rule(mut self, rule: impl Into<RewriteRule>) -> Self {
366        self.config.rules.push(rule.into());
367        self
368    }
369
370    /// Add multiple rules
371    pub fn rules(mut self, rules: Vec<RewriteRule>) -> Self {
372        self.config.rules.extend(rules);
373        self
374    }
375
376    /// Enable SELECT * expansion
377    pub fn expand_select_star(mut self, enabled: bool) -> Self {
378        self.config.expand_select_star = enabled;
379        self
380    }
381
382    /// Add default LIMIT to queries
383    pub fn add_default_limit(mut self, enabled: bool) -> Self {
384        self.config.add_default_limit = enabled;
385        self
386    }
387
388    /// Set default LIMIT value
389    pub fn default_limit(mut self, limit: u32) -> Self {
390        self.config.default_limit = limit;
391        self
392    }
393
394    /// Build the rewriter
395    pub fn build(self) -> QueryRewriter {
396        QueryRewriter::new(self.config)
397    }
398}
399
400impl Default for QueryRewriterBuilder {
401    fn default() -> Self {
402        Self::new()
403    }
404}
405
/// Outcome of running a query through the rewriter.
#[derive(Debug, Clone)]
pub struct RewriteResult {
    /// The query text exactly as it was submitted.
    pub original: String,

    /// The query text after rewriting; equal to `original` when nothing changed.
    pub rewritten: String,

    /// IDs of the rules that were applied, in application order.
    pub rules_applied: Vec<String>,

    /// Fingerprint of the query.
    pub fingerprint: u64,

    /// Time spent rewriting.
    pub duration: std::time::Duration,
}

impl RewriteResult {
    /// Build a result that reports `query` as untouched: no applied rules,
    /// fingerprint `0` and a zero duration.
    pub fn unchanged(query: &str) -> Self {
        let text = query.to_owned();
        Self {
            original: text.clone(),
            rewritten: text,
            rules_applied: vec![],
            fingerprint: 0,
            duration: std::time::Duration::ZERO,
        }
    }

    /// `true` when at least one rule was applied.
    pub fn was_rewritten(&self) -> bool {
        let untouched = self.rules_applied.is_empty();
        !untouched
    }

    /// The query to execute: the rewritten text (identical to the original
    /// when nothing was applied).
    pub fn query(&self) -> &str {
        self.rewritten.as_str()
    }
}
447
/// Errors reported by the query rewriter.
///
/// Every variant carries a string that is appended to the variant's label
/// when the error is displayed.
#[derive(Debug, Clone)]
pub enum RewriteError {
    /// The query could not be parsed.
    ParseError(String),

    /// A transformation could not be applied.
    TransformError(String),

    /// No rule with the given ID exists.
    RuleNotFound(String),

    /// The query touches a table that is not allowed.
    ForbiddenTable(String),

    /// The rewriter configuration is invalid.
    ConfigError(String),
}

impl std::fmt::Display for RewriteError {
    /// Renders the error as `"<label>: <payload>"`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            RewriteError::ParseError(detail) => {
                f.write_fmt(format_args!("Parse error: {}", detail))
            }
            RewriteError::TransformError(detail) => {
                f.write_fmt(format_args!("Transform error: {}", detail))
            }
            RewriteError::RuleNotFound(rule_id) => {
                f.write_fmt(format_args!("Rule not found: {}", rule_id))
            }
            RewriteError::ForbiddenTable(name) => {
                f.write_fmt(format_args!("Forbidden table: {}", name))
            }
            RewriteError::ConfigError(detail) => {
                f.write_fmt(format_args!("Config error: {}", detail))
            }
        }
    }
}

impl std::error::Error for RewriteError {}
480
481impl From<TransformError> for RewriteError {
482    fn from(e: TransformError) -> Self {
483        Self::TransformError(e.to_string())
484    }
485}
486
#[cfg(test)]
mod tests {
    use super::*;

    /// Query used by most tests: a plain `SELECT *` without a LIMIT clause.
    const USERS_QUERY: &str = "SELECT * FROM users";

    /// Enabled rewriter with one rule that appends `LIMIT 100` to any query
    /// that does not already have a LIMIT.
    fn limit_rewriter() -> QueryRewriter {
        QueryRewriter::builder()
            .enabled(true)
            .rule(
                RewriteRule::build("add_limit")
                    .pattern(QueryPattern::All)
                    .transform(Transformation::AddLimit(100))
                    .condition(Condition::NoExistingLimit),
            )
            .build()
    }

    #[test]
    fn test_rewriter_disabled() {
        let rewriter = QueryRewriter::builder().enabled(false).build();
        assert!(!rewriter.is_enabled());

        let result = rewriter.rewrite(USERS_QUERY).unwrap();
        assert!(!result.was_rewritten());
        assert_eq!(result.query(), USERS_QUERY);
    }

    #[test]
    fn test_rewriter_add_limit() {
        let rewriter = limit_rewriter();

        let result = rewriter.rewrite(USERS_QUERY).unwrap();
        assert!(result.was_rewritten());
        assert!(result.rewritten.contains("LIMIT 100"));
        // The original text is preserved and the applied rule is reported.
        assert_eq!(result.original, USERS_QUERY);
        assert_eq!(result.rules_applied, vec!["add_limit".to_string()]);
    }

    #[test]
    fn test_rewriter_skip_existing_limit() {
        let rewriter = limit_rewriter();

        let result = rewriter.rewrite("SELECT * FROM users LIMIT 50").unwrap();
        assert!(!result.was_rewritten());
        assert_eq!(result.query(), "SELECT * FROM users LIMIT 50");
    }

    #[test]
    fn test_rewriter_replace_query() {
        let rewriter = QueryRewriter::builder()
            .enabled(true)
            .rule(
                RewriteRule::build("replace")
                    .pattern(QueryPattern::Fingerprint(12345))
                    .transform(Transformation::Replace("SELECT 1".to_string())),
            )
            .build();

        // This won't match because fingerprint doesn't match
        let result = rewriter.rewrite(USERS_QUERY).unwrap();
        assert!(!result.was_rewritten());
        assert_eq!(result.query(), USERS_QUERY);
    }

    #[test]
    fn test_add_remove_rule() {
        let rewriter = QueryRewriter::builder().enabled(true).build();

        assert!(rewriter.get_rules().is_empty());
        // Removing from an empty rule list reports that nothing was removed.
        assert!(!rewriter.remove_rule("test"));

        rewriter.add_rule(
            RewriteRule::build("test")
                .pattern(QueryPattern::All)
                .transform(Transformation::AddLimit(100)),
        );

        assert_eq!(rewriter.get_rules().len(), 1);
        assert!(rewriter.get_rule("test").is_some());

        assert!(rewriter.remove_rule("test"));
        assert!(rewriter.get_rules().is_empty());
        assert!(rewriter.get_rule("test").is_none());

        // A second removal finds nothing.
        assert!(!rewriter.remove_rule("test"));
    }

    #[test]
    fn test_update_rule() {
        let rewriter = QueryRewriter::builder()
            .enabled(true)
            .rule(
                RewriteRule::build("test")
                    .pattern(QueryPattern::All)
                    .transform(Transformation::AddLimit(100)),
            )
            .build();

        assert!(rewriter.get_rule("test").unwrap().enabled);

        // Known id: the update is applied and reported.
        assert!(rewriter.set_rule_enabled("test", false));
        assert!(!rewriter.get_rule("test").unwrap().enabled);

        // Unknown id: nothing to update.
        assert!(!rewriter.set_rule_enabled("missing", false));

        // A disabled rule must not be applied.
        let result = rewriter.rewrite(USERS_QUERY).unwrap();
        assert!(!result.was_rewritten());
        assert_eq!(result.query(), USERS_QUERY);
    }
}