Skip to main content

heliosdb_proxy/rewriter/
mod.rs

1//! Query Rewriting Module
2//!
3//! Transparent query rewriting at the proxy layer for optimization,
4//! compatibility, and security enforcement.
5//!
6//! # Features
7//!
8//! - **Pattern Matching**: Match queries by fingerprint, regex, AST, or table
9//! - **Transformations**: Index hints, SELECT * expansion, LIMIT addition
10//! - **Rule Engine**: Priority-based rule application
11//! - **AI Safety**: Agent query limits and forbidden table enforcement
12//!
13//! # Architecture
14//!
15//! ```text
16//!   Original Query → Parse → Match Rules → Apply Transformations → Rewritten Query
17//!                      │         │                  │
18//!                      │         │                  ├── Replace
19//!                      │         │                  ├── AddIndexHint
20//!                      │         ├── Fingerprint    ├── ExpandSelectStar
21//!                      │         ├── Regex          ├── AddLimit
22//!                      │         ├── AST            ├── AddWhereClause
23//!                      │         └── Table          └── ReplaceTable
24//!                      │
25//!                   SQL AST
26//! ```
27//!
28//! # Example
29//!
30//! ```rust,ignore
31//! use heliosdb::proxy::rewriter::{QueryRewriter, RewriteRule, QueryPattern, Transformation};
32//!
33//! let rewriter = QueryRewriter::builder()
34//!     .rule(RewriteRule::new("expand_star")
35//!         .pattern(QueryPattern::table("users"))
36//!         .transform(Transformation::ExpandSelectStar {
37//!             columns: vec!["id", "name", "email"]
38//!         }))
39//!     .rule(RewriteRule::new("add_limit")
40//!         .pattern(QueryPattern::All)
41//!         .transform(Transformation::AddLimit(1000)))
42//!     .build();
43//!
44//! let result = rewriter.rewrite("SELECT * FROM users")?;
45//! // result.query() == "SELECT id, name, email FROM users LIMIT 1000"
46//! ```
47
48pub mod config;
49pub mod rules;
50pub mod matcher;
51pub mod transformer;
52pub mod parser;
53pub mod metrics;
54
55// Re-export main types
56pub use config::{RewriterConfig, RewriterConfigBuilder};
57pub use rules::{RewriteRule, RewriteRuleBuilder, QueryPattern, AstPattern, Transformation, Condition};
58pub use matcher::{RuleMatcher, MatchResult};
59pub use transformer::{TransformationEngine, TransformError};
60pub use parser::{SqlParser, ParsedQuery, SqlStatement};
61pub use metrics::{RewriteMetrics, RewriteStats, RuleStats};
62
63use std::sync::Arc;
64use parking_lot::RwLock;
65
66/// Query rewriter
67///
68/// Main entry point for query rewriting operations.
69pub struct QueryRewriter {
70    /// Configuration
71    config: RewriterConfig,
72
73    /// SQL parser
74    parser: SqlParser,
75
76    /// Rewrite rules
77    rules: Arc<RwLock<Vec<RewriteRule>>>,
78
79    /// Rule matcher
80    matcher: Arc<RwLock<RuleMatcher>>,
81
82    /// Transformation engine
83    transformer: TransformationEngine,
84
85    /// Metrics
86    metrics: Arc<RewriteMetrics>,
87}
88
89impl QueryRewriter {
90    /// Create a new query rewriter
91    pub fn new(config: RewriterConfig) -> Self {
92        let rules = Arc::new(RwLock::new(config.rules.clone()));
93        let matcher = Arc::new(RwLock::new(RuleMatcher::new(&config.rules)));
94        let parser = SqlParser::new();
95        let transformer = TransformationEngine::new();
96        let metrics = Arc::new(RewriteMetrics::new());
97
98        Self {
99            config,
100            parser,
101            rules,
102            matcher,
103            transformer,
104            metrics,
105        }
106    }
107
108    /// Create a builder
109    pub fn builder() -> QueryRewriterBuilder {
110        QueryRewriterBuilder::new()
111    }
112
113    /// Check if rewriter is enabled
114    pub fn is_enabled(&self) -> bool {
115        self.config.enabled
116    }
117
118    /// Rewrite a query
119    ///
120    /// Returns the rewritten query and list of applied rules.
121    pub fn rewrite(&self, query: &str) -> Result<RewriteResult, RewriteError> {
122        if !self.config.enabled {
123            return Ok(RewriteResult::unchanged(query));
124        }
125
126        let start = std::time::Instant::now();
127
128        // Parse query to get fingerprint
129        let parsed = self.parser.parse(query)?;
130        let fingerprint = parsed.fingerprint();
131
132        // Check cache for previously rewritten query
133        // (In production, add caching by fingerprint)
134
135        // Match rules
136        let rules = self.rules.read();
137        let matcher = self.matcher.read();
138        let matched = matcher.match_query(&parsed, &rules);
139
140        if matched.is_empty() {
141            self.metrics.record_no_match(start.elapsed());
142            return Ok(RewriteResult::unchanged(query));
143        }
144
145        // Apply transformations
146        let mut current_query = query.to_string();
147        let mut applied_rules = Vec::new();
148
149        for rule in matched {
150            if !rule.enabled {
151                continue;
152            }
153
154            // Check condition
155            if let Some(ref condition) = rule.condition {
156                if !self.evaluate_condition(condition, &current_query) {
157                    continue;
158                }
159            }
160
161            // Apply transformation
162            match self.transformer.apply(&current_query, &rule.transformation) {
163                Ok(rewritten) => {
164                    current_query = rewritten;
165                    applied_rules.push(rule.id.clone());
166                    self.metrics.record_rule_match(&rule.id);
167                }
168                Err(e) => {
169                    if self.config.log_errors {
170                        eprintln!("Rewrite error for rule {}: {}", rule.id, e);
171                    }
172                    // Continue with other rules
173                }
174            }
175        }
176
177        let duration = start.elapsed();
178        self.metrics.record_rewrite(duration, !applied_rules.is_empty());
179
180        if applied_rules.is_empty() {
181            Ok(RewriteResult::unchanged(query))
182        } else {
183            if self.config.log_rewrites {
184                println!("Rewritten query:");
185                println!("  Original: {}", query);
186                println!("  Rewritten: {}", current_query);
187                println!("  Rules: {:?}", applied_rules);
188            }
189
190            Ok(RewriteResult {
191                original: query.to_string(),
192                rewritten: current_query,
193                rules_applied: applied_rules,
194                fingerprint,
195                duration,
196            })
197        }
198    }
199
200    /// Test rewrite without metrics recording
201    pub fn test_rewrite(&self, query: &str) -> Result<RewriteResult, RewriteError> {
202        let parsed = self.parser.parse(query)?;
203        let fingerprint = parsed.fingerprint();
204
205        let rules = self.rules.read();
206        let matcher = self.matcher.read();
207        let matched = matcher.match_query(&parsed, &rules);
208
209        let mut current_query = query.to_string();
210        let mut applied_rules = Vec::new();
211
212        for rule in matched {
213            if !rule.enabled {
214                continue;
215            }
216
217            if let Some(ref condition) = rule.condition {
218                if !self.evaluate_condition(condition, &current_query) {
219                    continue;
220                }
221            }
222
223            if let Ok(rewritten) = self.transformer.apply(&current_query, &rule.transformation) {
224                current_query = rewritten;
225                applied_rules.push(rule.id.clone());
226            }
227        }
228
229        Ok(RewriteResult {
230            original: query.to_string(),
231            rewritten: current_query,
232            rules_applied: applied_rules,
233            fingerprint,
234            duration: std::time::Duration::ZERO,
235        })
236    }
237
238    /// Add a new rule
239    pub fn add_rule(&self, rule: impl Into<RewriteRule>) {
240        let mut rules = self.rules.write();
241        rules.push(rule.into());
242
243        // Rebuild matcher
244        let mut matcher = self.matcher.write();
245        *matcher = RuleMatcher::new(&rules);
246    }
247
248    /// Remove a rule by ID
249    pub fn remove_rule(&self, rule_id: &str) -> bool {
250        let mut rules = self.rules.write();
251        let initial_len = rules.len();
252        rules.retain(|r| r.id != rule_id);
253
254        if rules.len() != initial_len {
255            let mut matcher = self.matcher.write();
256            *matcher = RuleMatcher::new(&rules);
257            true
258        } else {
259            false
260        }
261    }
262
263    /// Update a rule
264    pub fn update_rule(&self, rule_id: &str, update: impl FnOnce(&mut RewriteRule)) -> bool {
265        let mut rules = self.rules.write();
266        if let Some(rule) = rules.iter_mut().find(|r| r.id == rule_id) {
267            update(rule);
268
269            let mut matcher = self.matcher.write();
270            *matcher = RuleMatcher::new(&rules);
271            true
272        } else {
273            false
274        }
275    }
276
277    /// Enable/disable a rule
278    pub fn set_rule_enabled(&self, rule_id: &str, enabled: bool) -> bool {
279        self.update_rule(rule_id, |r| r.enabled = enabled)
280    }
281
282    /// Get all rules
283    pub fn get_rules(&self) -> Vec<RewriteRule> {
284        self.rules.read().clone()
285    }
286
287    /// Get rule by ID
288    pub fn get_rule(&self, rule_id: &str) -> Option<RewriteRule> {
289        self.rules.read().iter().find(|r| r.id == rule_id).cloned()
290    }
291
292    /// Get statistics
293    pub fn stats(&self) -> RewriteStats {
294        self.metrics.stats()
295    }
296
297    /// Evaluate a condition
298    fn evaluate_condition(&self, condition: &Condition, query: &str) -> bool {
299        match condition {
300            Condition::NoExistingLimit => {
301                !query.to_uppercase().contains("LIMIT")
302            }
303            Condition::NoExistingOrderBy => {
304                !query.to_uppercase().contains("ORDER BY")
305            }
306            Condition::HasSelectStar => {
307                let upper = query.to_uppercase();
308                upper.contains("SELECT *") || upper.contains("SELECT  *")
309            }
310            Condition::SessionVar { name: _, exists } => {
311                // In production, check session variables
312                // For now, always return true if exists is expected
313                *exists
314            }
315            Condition::ClientType { client_type: _ } => {
316                // In production, check client metadata
317                true
318            }
319            Condition::TableExists { table: _ } => {
320                // In production, check schema cache
321                true
322            }
323            Condition::And(conditions) => {
324                conditions.iter().all(|c| self.evaluate_condition(c, query))
325            }
326            Condition::Or(conditions) => {
327                conditions.iter().any(|c| self.evaluate_condition(c, query))
328            }
329            Condition::Not(condition) => {
330                !self.evaluate_condition(condition, query)
331            }
332        }
333    }
334}
335
336/// Query rewriter builder
337pub struct QueryRewriterBuilder {
338    config: RewriterConfig,
339}
340
341impl QueryRewriterBuilder {
342    /// Create a new builder
343    pub fn new() -> Self {
344        Self {
345            config: RewriterConfig::default(),
346        }
347    }
348
349    /// Enable the rewriter
350    pub fn enabled(mut self, enabled: bool) -> Self {
351        self.config.enabled = enabled;
352        self
353    }
354
355    /// Log rewrites
356    pub fn log_rewrites(mut self, log: bool) -> Self {
357        self.config.log_rewrites = log;
358        self
359    }
360
361    /// Log errors
362    pub fn log_errors(mut self, log: bool) -> Self {
363        self.config.log_errors = log;
364        self
365    }
366
367    /// Add a rule
368    pub fn rule(mut self, rule: impl Into<RewriteRule>) -> Self {
369        self.config.rules.push(rule.into());
370        self
371    }
372
373    /// Add multiple rules
374    pub fn rules(mut self, rules: Vec<RewriteRule>) -> Self {
375        self.config.rules.extend(rules);
376        self
377    }
378
379    /// Enable SELECT * expansion
380    pub fn expand_select_star(mut self, enabled: bool) -> Self {
381        self.config.expand_select_star = enabled;
382        self
383    }
384
385    /// Add default LIMIT to queries
386    pub fn add_default_limit(mut self, enabled: bool) -> Self {
387        self.config.add_default_limit = enabled;
388        self
389    }
390
391    /// Set default LIMIT value
392    pub fn default_limit(mut self, limit: u32) -> Self {
393        self.config.default_limit = limit;
394        self
395    }
396
397    /// Build the rewriter
398    pub fn build(self) -> QueryRewriter {
399        QueryRewriter::new(self.config)
400    }
401}
402
403impl Default for QueryRewriterBuilder {
404    fn default() -> Self {
405        Self::new()
406    }
407}
408
/// Result of a rewrite operation.
///
/// Derives `PartialEq`/`Eq` so results can be compared directly (e.g. in tests
/// or when de-duplicating).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RewriteResult {
    /// Original query text as received.
    pub original: String,

    /// Rewritten query (same as original if no changes).
    pub rewritten: String,

    /// IDs of the rules that were applied, in application order.
    pub rules_applied: Vec<String>,

    /// Query fingerprint (`0` when the query was not parsed).
    pub fingerprint: u64,

    /// Time taken to rewrite.
    pub duration: std::time::Duration,
}
427
428impl RewriteResult {
429    /// Create an unchanged result
430    pub fn unchanged(query: &str) -> Self {
431        Self {
432            original: query.to_string(),
433            rewritten: query.to_string(),
434            rules_applied: Vec::new(),
435            fingerprint: 0,
436            duration: std::time::Duration::ZERO,
437        }
438    }
439
440    /// Check if query was modified
441    pub fn was_rewritten(&self) -> bool {
442        !self.rules_applied.is_empty()
443    }
444
445    /// Get the final query (rewritten or original)
446    pub fn query(&self) -> &str {
447        &self.rewritten
448    }
449}
450
/// Errors produced by the query rewriter.
///
/// Every variant carries a human-readable detail string. The enum derives
/// `PartialEq`/`Eq` so callers can match on exact errors in assertions.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum RewriteError {
    /// Failed to parse query.
    ParseError(String),

    /// Transformation failed.
    TransformError(String),

    /// Rule not found (payload is the rule ID).
    RuleNotFound(String),

    /// Forbidden table access (payload is the table name).
    ForbiddenTable(String),

    /// Configuration error.
    ConfigError(String),
}
469
470impl std::fmt::Display for RewriteError {
471    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
472        match self {
473            Self::ParseError(msg) => write!(f, "Parse error: {}", msg),
474            Self::TransformError(msg) => write!(f, "Transform error: {}", msg),
475            Self::RuleNotFound(id) => write!(f, "Rule not found: {}", id),
476            Self::ForbiddenTable(table) => write!(f, "Forbidden table: {}", table),
477            Self::ConfigError(msg) => write!(f, "Config error: {}", msg),
478        }
479    }
480}
481
482impl std::error::Error for RewriteError {}
483
484impl From<TransformError> for RewriteError {
485    fn from(e: TransformError) -> Self {
486        Self::TransformError(e.to_string())
487    }
488}
489
#[cfg(test)]
mod tests {
    use super::*;

    /// Enabled rewriter with a single `add_limit` rule that only fires when
    /// the query has no LIMIT yet.
    fn limit_rewriter() -> QueryRewriter {
        QueryRewriter::builder()
            .enabled(true)
            .rule(
                RewriteRule::new("add_limit")
                    .pattern(QueryPattern::All)
                    .transform(Transformation::AddLimit(100))
                    .condition(Condition::NoExistingLimit),
            )
            .build()
    }

    #[test]
    fn test_rewriter_disabled() {
        let rewriter = QueryRewriter::builder().enabled(false).build();

        let result = rewriter.rewrite("SELECT * FROM users").unwrap();
        assert!(!result.was_rewritten());
        assert_eq!(result.query(), "SELECT * FROM users");
    }

    #[test]
    fn test_rewriter_add_limit() {
        let result = limit_rewriter().rewrite("SELECT * FROM users").unwrap();
        assert!(result.was_rewritten());
        assert!(result.rewritten.contains("LIMIT 100"));
        assert_eq!(result.original, "SELECT * FROM users");
        assert_eq!(result.rules_applied, vec!["add_limit".to_string()]);
    }

    #[test]
    fn test_rewriter_skip_existing_limit() {
        let result = limit_rewriter()
            .rewrite("SELECT * FROM users LIMIT 50")
            .unwrap();
        assert!(!result.was_rewritten());
        assert_eq!(result.query(), "SELECT * FROM users LIMIT 50");
    }

    #[test]
    fn test_rewriter_replace_query() {
        let rewriter = QueryRewriter::builder()
            .enabled(true)
            .rule(
                RewriteRule::new("replace")
                    .pattern(QueryPattern::Fingerprint(12345))
                    .transform(Transformation::Replace("SELECT 1".to_string())),
            )
            .build();

        // This won't match because the fingerprint doesn't match.
        let result = rewriter.rewrite("SELECT * FROM users").unwrap();
        assert!(!result.was_rewritten());
    }

    #[test]
    fn test_add_remove_rule() {
        let rewriter = QueryRewriter::builder().enabled(true).build();

        assert!(rewriter.get_rules().is_empty());

        rewriter.add_rule(
            RewriteRule::new("test")
                .pattern(QueryPattern::All)
                .transform(Transformation::AddLimit(100)),
        );

        assert_eq!(rewriter.get_rules().len(), 1);
        assert!(rewriter.get_rule("test").is_some());

        assert!(rewriter.remove_rule("test"));
        assert!(rewriter.get_rules().is_empty());

        // Removing a rule that no longer exists reports failure.
        assert!(!rewriter.remove_rule("test"));
    }

    #[test]
    fn test_update_rule() {
        let rewriter = QueryRewriter::builder()
            .enabled(true)
            .rule(
                RewriteRule::new("test")
                    .pattern(QueryPattern::All)
                    .transform(Transformation::AddLimit(100)),
            )
            .build();

        assert!(rewriter.get_rule("test").unwrap().enabled);

        assert!(rewriter.set_rule_enabled("test", false));
        assert!(!rewriter.get_rule("test").unwrap().enabled);

        // Unknown IDs are reported rather than silently ignored.
        assert!(!rewriter.set_rule_enabled("missing", true));
    }

    #[test]
    fn test_disabled_rule_is_not_applied() {
        let rewriter = limit_rewriter();
        assert!(rewriter.set_rule_enabled("add_limit", false));

        let result = rewriter.rewrite("SELECT * FROM users").unwrap();
        assert!(!result.was_rewritten());
        assert_eq!(result.query(), "SELECT * FROM users");
    }

    #[test]
    fn test_dry_run_applies_rules() {
        let result = limit_rewriter()
            .test_rewrite("SELECT * FROM users")
            .unwrap();
        assert!(result.was_rewritten());
        assert!(result.rewritten.contains("LIMIT 100"));
        assert_eq!(result.duration, std::time::Duration::ZERO);
    }
}