Skip to main content

heliosdb_proxy/rewriter/
transformer.rs

1//! Transformation Engine
2//!
3//! Applies transformations to SQL queries.
4
5use super::rules::Transformation;
6use regex::Regex;
7
8/// Transformation engine
9pub struct TransformationEngine {
10    /// Custom transformation functions
11    custom_functions: std::collections::HashMap<String, Box<dyn CustomTransform>>,
12}
13
14impl TransformationEngine {
15    /// Create a new transformation engine
16    pub fn new() -> Self {
17        Self {
18            custom_functions: std::collections::HashMap::new(),
19        }
20    }
21
22    /// Register a custom transformation function
23    pub fn register_custom(&mut self, name: String, transform: Box<dyn CustomTransform>) {
24        self.custom_functions.insert(name, transform);
25    }
26
27    /// Apply a transformation to a query
28    pub fn apply(
29        &self,
30        query: &str,
31        transformation: &Transformation,
32    ) -> Result<String, TransformError> {
33        match transformation {
34            Transformation::NoOp => Ok(query.to_string()),
35
36            Transformation::Replace(replacement) => Ok(replacement.clone()),
37
38            Transformation::AddIndexHint { table, index } => {
39                self.add_index_hint(query, table, index)
40            }
41
42            Transformation::ExpandSelectStar { columns } => self.expand_select_star(query, columns),
43
44            Transformation::AddLimit(limit) => self.add_limit(query, *limit),
45
46            Transformation::AddWhereClause(condition) => self.add_where_clause(query, condition),
47
48            Transformation::AppendWhereAnd(condition) => self.append_where_and(query, condition),
49
50            Transformation::ReplaceTable { from, to } => self.replace_table(query, from, to),
51
52            Transformation::AddOrderBy { column, descending } => {
53                self.add_order_by(query, column, *descending)
54            }
55
56            Transformation::AddHint(hint) => Ok(format!("/*{}*/ {}", hint, query)),
57
58            Transformation::AddBranchHint(branch) => {
59                Ok(format!("/*helios:branch={}*/ {}", branch, query))
60            }
61
62            Transformation::AddTimeout(duration) => {
63                let ms = duration.as_millis();
64                Ok(format!("/*helios:timeout={}ms*/ {}", ms, query))
65            }
66
67            Transformation::Custom(name) => {
68                if let Some(transform) = self.custom_functions.get(name) {
69                    transform.transform(query)
70                } else {
71                    Err(TransformError::UnknownCustomFunction(name.clone()))
72                }
73            }
74
75            Transformation::Chain(transformations) => {
76                let mut result = query.to_string();
77                for t in transformations {
78                    result = self.apply(&result, t)?;
79                }
80                Ok(result)
81            }
82        }
83    }
84
85    /// Add index hint to query
86    fn add_index_hint(
87        &self,
88        query: &str,
89        table: &str,
90        index: &str,
91    ) -> Result<String, TransformError> {
92        // PostgreSQL style: /*+ IndexScan(table index) */
93        // Insert after SELECT keyword
94        let upper = query.to_uppercase();
95
96        if let Some(pos) = upper.find("SELECT") {
97            let insert_pos = pos + 6;
98            let hint = format!(" /*+ IndexScan({} {}) */", table, index);
99
100            let mut result = query.to_string();
101            result.insert_str(insert_pos, &hint);
102            Ok(result)
103        } else {
104            // For non-SELECT queries, prepend the hint
105            Ok(format!("/*+ IndexScan({} {}) */ {}", table, index, query))
106        }
107    }
108
109    /// Expand SELECT * to column list
110    fn expand_select_star(
111        &self,
112        query: &str,
113        columns: &[String],
114    ) -> Result<String, TransformError> {
115        // Find SELECT * pattern and replace with column list
116        let re = Regex::new(r"(?i)SELECT\s+(\*|DISTINCT\s+\*|ALL\s+\*)")
117            .map_err(|e| TransformError::RegexError(e.to_string()))?;
118
119        if let Some(caps) = re.find(query) {
120            let matched = caps.as_str();
121            let is_distinct = matched.to_uppercase().contains("DISTINCT");
122            let is_all = matched.to_uppercase().contains("ALL");
123
124            let column_list = columns.join(", ");
125            let replacement = if is_distinct {
126                format!("SELECT DISTINCT {}", column_list)
127            } else if is_all {
128                format!("SELECT ALL {}", column_list)
129            } else {
130                format!("SELECT {}", column_list)
131            };
132
133            Ok(re.replace(query, replacement.as_str()).to_string())
134        } else {
135            // No SELECT * found, return unchanged
136            Ok(query.to_string())
137        }
138    }
139
140    /// Add LIMIT clause
141    fn add_limit(&self, query: &str, limit: u32) -> Result<String, TransformError> {
142        let upper = query.to_uppercase();
143
144        // Don't add if LIMIT already exists
145        if upper.contains(" LIMIT ") {
146            return Ok(query.to_string());
147        }
148
149        // Remove trailing semicolon if present
150        let trimmed = query.trim_end_matches(';').trim();
151
152        // Add LIMIT before potential FOR UPDATE/SHARE clause
153        if upper.contains(" FOR ") {
154            let for_pos = upper.rfind(" FOR ").unwrap();
155            let (before_for, after_for) = trimmed.split_at(for_pos);
156            Ok(format!("{} LIMIT {}{};", before_for, limit, after_for))
157        } else {
158            Ok(format!("{} LIMIT {};", trimmed, limit))
159        }
160    }
161
162    /// Add WHERE clause
163    fn add_where_clause(&self, query: &str, condition: &str) -> Result<String, TransformError> {
164        let upper = query.to_uppercase();
165
166        // Remove trailing semicolon
167        let trimmed = query.trim_end_matches(';').trim();
168
169        if upper.contains(" WHERE ") {
170            // Add to existing WHERE with AND
171            self.append_where_and(trimmed, condition)
172        } else {
173            // Find position to insert WHERE (before GROUP BY, ORDER BY, LIMIT, etc.)
174            let insert_keywords = [" GROUP BY", " ORDER BY", " LIMIT ", " OFFSET ", " FOR "];
175            let mut insert_pos = trimmed.len();
176
177            for keyword in &insert_keywords {
178                if let Some(pos) = upper.find(keyword) {
179                    if pos < insert_pos {
180                        insert_pos = pos;
181                    }
182                }
183            }
184
185            let (before, after) = trimmed.split_at(insert_pos);
186            Ok(format!("{} WHERE {}{};", before, condition, after))
187        }
188    }
189
190    /// Append to existing WHERE clause with AND
191    fn append_where_and(&self, query: &str, condition: &str) -> Result<String, TransformError> {
192        let upper = query.to_uppercase();
193        let trimmed = query.trim_end_matches(';').trim();
194
195        if let Some(where_pos) = upper.find(" WHERE ") {
196            // Find end of WHERE clause
197            let after_where = &upper[where_pos + 7..];
198            let end_keywords = [" GROUP BY", " ORDER BY", " LIMIT ", " OFFSET ", " FOR "];
199
200            let mut end_pos = trimmed.len();
201            for keyword in &end_keywords {
202                if let Some(pos) = after_where.find(keyword) {
203                    let abs_pos = where_pos + 7 + pos;
204                    if abs_pos < end_pos {
205                        end_pos = abs_pos;
206                    }
207                }
208            }
209
210            let (before, after) = trimmed.split_at(end_pos);
211            Ok(format!("{} AND ({}){}; ", before, condition, after))
212        } else {
213            // No WHERE, add new WHERE clause
214            self.add_where_clause(trimmed, condition)
215        }
216    }
217
218    /// Replace table name
219    fn replace_table(&self, query: &str, from: &str, to: &str) -> Result<String, TransformError> {
220        // Use word-boundary aware replacement
221        let pattern = format!(r"\b{}\b", regex::escape(from));
222        let re = Regex::new(&pattern).map_err(|e| TransformError::RegexError(e.to_string()))?;
223
224        Ok(re.replace_all(query, to).to_string())
225    }
226
227    /// Add ORDER BY clause
228    fn add_order_by(
229        &self,
230        query: &str,
231        column: &str,
232        descending: bool,
233    ) -> Result<String, TransformError> {
234        let upper = query.to_uppercase();
235        let trimmed = query.trim_end_matches(';').trim();
236
237        // Don't add if ORDER BY already exists
238        if upper.contains(" ORDER BY ") {
239            return Ok(query.to_string());
240        }
241
242        let direction = if descending { "DESC" } else { "ASC" };
243
244        // Find position to insert (before LIMIT, OFFSET, FOR)
245        let insert_keywords = [" LIMIT ", " OFFSET ", " FOR "];
246        let mut insert_pos = trimmed.len();
247
248        for keyword in &insert_keywords {
249            if let Some(pos) = upper.find(keyword) {
250                if pos < insert_pos {
251                    insert_pos = pos;
252                }
253            }
254        }
255
256        let (before, after) = trimmed.split_at(insert_pos);
257        Ok(format!(
258            "{} ORDER BY {} {}{};",
259            before, column, direction, after
260        ))
261    }
262}
263
264impl Default for TransformationEngine {
265    fn default() -> Self {
266        Self::new()
267    }
268}
269
270/// Custom transformation trait
271pub trait CustomTransform: Send + Sync {
272    /// Transform the query
273    fn transform(&self, query: &str) -> Result<String, TransformError>;
274}
275
/// Errors produced while applying a transformation.
///
/// Implements `PartialEq`/`Eq` so callers and tests can compare errors directly
/// (for example with `assert_eq!`).
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum TransformError {
    /// A regular expression could not be compiled; carries the regex error text.
    RegexError(String),

    /// Parse error, described by the message.
    ParseError(String),

    /// A `Custom` transformation named a function that is not registered;
    /// carries the requested name.
    UnknownCustomFunction(String),

    /// Transformation not applicable, described by the message.
    NotApplicable(String),
}
291
292impl std::fmt::Display for TransformError {
293    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
294        match self {
295            Self::RegexError(msg) => write!(f, "Regex error: {}", msg),
296            Self::ParseError(msg) => write!(f, "Parse error: {}", msg),
297            Self::UnknownCustomFunction(name) => write!(f, "Unknown custom function: {}", name),
298            Self::NotApplicable(msg) => write!(f, "Not applicable: {}", msg),
299        }
300    }
301}
302
303impl std::error::Error for TransformError {}
304
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an owned column list from string slices.
    fn owned(names: &[&str]) -> Vec<String> {
        names.iter().map(|name| name.to_string()).collect()
    }

    #[test]
    fn test_add_limit() {
        let engine = TransformationEngine::new();

        // A query without LIMIT gains one.
        let limited = engine.add_limit("SELECT * FROM users", 100).unwrap();
        assert!(limited.contains("LIMIT 100"));

        // An existing LIMIT is kept and not duplicated.
        let untouched = engine
            .add_limit("SELECT * FROM users LIMIT 50", 100)
            .unwrap();
        assert!(untouched.contains("LIMIT 50"));
        assert!(!untouched.contains("LIMIT 100"));
    }

    #[test]
    fn test_add_where() {
        let engine = TransformationEngine::new();

        // No WHERE yet: a fresh clause is inserted.
        let fresh = engine
            .add_where_clause("SELECT * FROM users", "active = true")
            .unwrap();
        assert!(fresh.contains("WHERE active = true"));

        // Existing WHERE: the condition is AND-ed onto it.
        let combined = engine
            .add_where_clause("SELECT * FROM users WHERE id = 1", "active = true")
            .unwrap();
        assert!(combined.contains("AND (active = true)"));
    }

    #[test]
    fn test_replace_table() {
        let engine = TransformationEngine::new();

        let rewritten = engine
            .replace_table("SELECT * FROM old_users", "old_users", "users_v2")
            .unwrap();
        assert!(rewritten.contains("users_v2"));
        assert!(!rewritten.contains("old_users"));
    }

    #[test]
    fn test_expand_select_star() {
        let engine = TransformationEngine::new();
        let columns = owned(&["id", "name", "email"]);

        let expanded = engine
            .expand_select_star("SELECT * FROM users", &columns)
            .unwrap();

        assert!(expanded.contains("id, name, email"));
        assert!(!expanded.contains("*"));
    }

    #[test]
    fn test_expand_select_distinct_star() {
        let engine = TransformationEngine::new();
        let columns = owned(&["id", "name"]);

        let expanded = engine
            .expand_select_star("SELECT DISTINCT * FROM users", &columns)
            .unwrap();

        assert!(expanded.contains("SELECT DISTINCT id, name"));
    }

    #[test]
    fn test_add_index_hint() {
        let engine = TransformationEngine::new();

        let hinted = engine
            .add_index_hint("SELECT * FROM users WHERE id = 1", "users", "idx_users_id")
            .unwrap();
        assert!(hinted.contains("IndexScan(users idx_users_id)"));
    }

    #[test]
    fn test_add_order_by() {
        let engine = TransformationEngine::new();

        let ordered = engine
            .add_order_by("SELECT * FROM users", "created_at", true)
            .unwrap();
        assert!(ordered.contains("ORDER BY created_at DESC"));
    }

    #[test]
    fn test_add_hint() {
        let engine = TransformationEngine::new();
        let hint = Transformation::AddHint("parallel=4".to_string());

        let hinted = engine.apply("SELECT * FROM users", &hint).unwrap();
        assert!(hinted.contains("/*parallel=4*/"));
    }

    #[test]
    fn test_add_branch_hint() {
        let engine = TransformationEngine::new();
        let branch = Transformation::AddBranchHint("analytics".to_string());

        let hinted = engine.apply("SELECT * FROM analytics", &branch).unwrap();
        assert!(hinted.contains("/*helios:branch=analytics*/"));
    }

    #[test]
    fn test_chain_transformations() {
        let engine = TransformationEngine::new();

        // LIMIT is added first; ORDER BY must then be placed in the same query.
        let chain = Transformation::Chain(vec![
            Transformation::AddLimit(100),
            Transformation::AddOrderBy {
                column: "id".to_string(),
                descending: false,
            },
        ]);

        let rewritten = engine.apply("SELECT * FROM users", &chain).unwrap();

        assert!(rewritten.contains("LIMIT 100"));
        assert!(rewritten.contains("ORDER BY id ASC"));
    }
}