Skip to main content

heliosdb_proxy/rewriter/
transformer.rs

1//! Transformation Engine
2//!
3//! Applies transformations to SQL queries.
4
5use super::rules::Transformation;
6use regex::Regex;
7
8/// Transformation engine
9pub struct TransformationEngine {
10    /// Custom transformation functions
11    custom_functions: std::collections::HashMap<String, Box<dyn CustomTransform>>,
12}
13
14impl TransformationEngine {
15    /// Create a new transformation engine
16    pub fn new() -> Self {
17        Self {
18            custom_functions: std::collections::HashMap::new(),
19        }
20    }
21
22    /// Register a custom transformation function
23    pub fn register_custom(&mut self, name: String, transform: Box<dyn CustomTransform>) {
24        self.custom_functions.insert(name, transform);
25    }
26
27    /// Apply a transformation to a query
28    pub fn apply(&self, query: &str, transformation: &Transformation) -> Result<String, TransformError> {
29        match transformation {
30            Transformation::NoOp => Ok(query.to_string()),
31
32            Transformation::Replace(replacement) => {
33                Ok(replacement.clone())
34            }
35
36            Transformation::AddIndexHint { table, index } => {
37                self.add_index_hint(query, table, index)
38            }
39
40            Transformation::ExpandSelectStar { columns } => {
41                self.expand_select_star(query, columns)
42            }
43
44            Transformation::AddLimit(limit) => {
45                self.add_limit(query, *limit)
46            }
47
48            Transformation::AddWhereClause(condition) => {
49                self.add_where_clause(query, condition)
50            }
51
52            Transformation::AppendWhereAnd(condition) => {
53                self.append_where_and(query, condition)
54            }
55
56            Transformation::ReplaceTable { from, to } => {
57                self.replace_table(query, from, to)
58            }
59
60            Transformation::AddOrderBy { column, descending } => {
61                self.add_order_by(query, column, *descending)
62            }
63
64            Transformation::AddHint(hint) => {
65                Ok(format!("/*{}*/ {}", hint, query))
66            }
67
68            Transformation::AddBranchHint(branch) => {
69                Ok(format!("/*helios:branch={}*/ {}", branch, query))
70            }
71
72            Transformation::AddTimeout(duration) => {
73                let ms = duration.as_millis();
74                Ok(format!("/*helios:timeout={}ms*/ {}", ms, query))
75            }
76
77            Transformation::Custom(name) => {
78                if let Some(transform) = self.custom_functions.get(name) {
79                    transform.transform(query)
80                } else {
81                    Err(TransformError::UnknownCustomFunction(name.clone()))
82                }
83            }
84
85            Transformation::Chain(transformations) => {
86                let mut result = query.to_string();
87                for t in transformations {
88                    result = self.apply(&result, t)?;
89                }
90                Ok(result)
91            }
92        }
93    }
94
95    /// Add index hint to query
96    fn add_index_hint(&self, query: &str, table: &str, index: &str) -> Result<String, TransformError> {
97        // PostgreSQL style: /*+ IndexScan(table index) */
98        // Insert after SELECT keyword
99        let upper = query.to_uppercase();
100
101        if let Some(pos) = upper.find("SELECT") {
102            let insert_pos = pos + 6;
103            let hint = format!(" /*+ IndexScan({} {}) */", table, index);
104
105            let mut result = query.to_string();
106            result.insert_str(insert_pos, &hint);
107            Ok(result)
108        } else {
109            // For non-SELECT queries, prepend the hint
110            Ok(format!("/*+ IndexScan({} {}) */ {}", table, index, query))
111        }
112    }
113
114    /// Expand SELECT * to column list
115    fn expand_select_star(&self, query: &str, columns: &[String]) -> Result<String, TransformError> {
116        // Find SELECT * pattern and replace with column list
117        let re = Regex::new(r"(?i)SELECT\s+(\*|DISTINCT\s+\*|ALL\s+\*)")
118            .map_err(|e| TransformError::RegexError(e.to_string()))?;
119
120        if let Some(caps) = re.find(query) {
121            let matched = caps.as_str();
122            let is_distinct = matched.to_uppercase().contains("DISTINCT");
123            let is_all = matched.to_uppercase().contains("ALL");
124
125            let column_list = columns.join(", ");
126            let replacement = if is_distinct {
127                format!("SELECT DISTINCT {}", column_list)
128            } else if is_all {
129                format!("SELECT ALL {}", column_list)
130            } else {
131                format!("SELECT {}", column_list)
132            };
133
134            Ok(re.replace(query, replacement.as_str()).to_string())
135        } else {
136            // No SELECT * found, return unchanged
137            Ok(query.to_string())
138        }
139    }
140
141    /// Add LIMIT clause
142    fn add_limit(&self, query: &str, limit: u32) -> Result<String, TransformError> {
143        let upper = query.to_uppercase();
144
145        // Don't add if LIMIT already exists
146        if upper.contains(" LIMIT ") {
147            return Ok(query.to_string());
148        }
149
150        // Remove trailing semicolon if present
151        let trimmed = query.trim_end_matches(';').trim();
152
153        // Add LIMIT before potential FOR UPDATE/SHARE clause
154        if upper.contains(" FOR ") {
155            let for_pos = upper.rfind(" FOR ").unwrap();
156            let (before_for, after_for) = trimmed.split_at(for_pos);
157            Ok(format!("{} LIMIT {}{};", before_for, limit, after_for))
158        } else {
159            Ok(format!("{} LIMIT {};", trimmed, limit))
160        }
161    }
162
163    /// Add WHERE clause
164    fn add_where_clause(&self, query: &str, condition: &str) -> Result<String, TransformError> {
165        let upper = query.to_uppercase();
166
167        // Remove trailing semicolon
168        let trimmed = query.trim_end_matches(';').trim();
169
170        if upper.contains(" WHERE ") {
171            // Add to existing WHERE with AND
172            self.append_where_and(trimmed, condition)
173        } else {
174            // Find position to insert WHERE (before GROUP BY, ORDER BY, LIMIT, etc.)
175            let insert_keywords = [" GROUP BY", " ORDER BY", " LIMIT ", " OFFSET ", " FOR "];
176            let mut insert_pos = trimmed.len();
177
178            for keyword in &insert_keywords {
179                if let Some(pos) = upper.find(keyword) {
180                    if pos < insert_pos {
181                        insert_pos = pos;
182                    }
183                }
184            }
185
186            let (before, after) = trimmed.split_at(insert_pos);
187            Ok(format!("{} WHERE {}{};", before, condition, after))
188        }
189    }
190
191    /// Append to existing WHERE clause with AND
192    fn append_where_and(&self, query: &str, condition: &str) -> Result<String, TransformError> {
193        let upper = query.to_uppercase();
194        let trimmed = query.trim_end_matches(';').trim();
195
196        if let Some(where_pos) = upper.find(" WHERE ") {
197            // Find end of WHERE clause
198            let after_where = &upper[where_pos + 7..];
199            let end_keywords = [" GROUP BY", " ORDER BY", " LIMIT ", " OFFSET ", " FOR "];
200
201            let mut end_pos = trimmed.len();
202            for keyword in &end_keywords {
203                if let Some(pos) = after_where.find(keyword) {
204                    let abs_pos = where_pos + 7 + pos;
205                    if abs_pos < end_pos {
206                        end_pos = abs_pos;
207                    }
208                }
209            }
210
211            let (before, after) = trimmed.split_at(end_pos);
212            Ok(format!("{} AND ({}){}; ", before, condition, after))
213        } else {
214            // No WHERE, add new WHERE clause
215            self.add_where_clause(trimmed, condition)
216        }
217    }
218
219    /// Replace table name
220    fn replace_table(&self, query: &str, from: &str, to: &str) -> Result<String, TransformError> {
221        // Use word-boundary aware replacement
222        let pattern = format!(r"\b{}\b", regex::escape(from));
223        let re = Regex::new(&pattern)
224            .map_err(|e| TransformError::RegexError(e.to_string()))?;
225
226        Ok(re.replace_all(query, to).to_string())
227    }
228
229    /// Add ORDER BY clause
230    fn add_order_by(&self, query: &str, column: &str, descending: bool) -> Result<String, TransformError> {
231        let upper = query.to_uppercase();
232        let trimmed = query.trim_end_matches(';').trim();
233
234        // Don't add if ORDER BY already exists
235        if upper.contains(" ORDER BY ") {
236            return Ok(query.to_string());
237        }
238
239        let direction = if descending { "DESC" } else { "ASC" };
240
241        // Find position to insert (before LIMIT, OFFSET, FOR)
242        let insert_keywords = [" LIMIT ", " OFFSET ", " FOR "];
243        let mut insert_pos = trimmed.len();
244
245        for keyword in &insert_keywords {
246            if let Some(pos) = upper.find(keyword) {
247                if pos < insert_pos {
248                    insert_pos = pos;
249                }
250            }
251        }
252
253        let (before, after) = trimmed.split_at(insert_pos);
254        Ok(format!("{} ORDER BY {} {}{};", before, column, direction, after))
255    }
256}
257
258impl Default for TransformationEngine {
259    fn default() -> Self {
260        Self::new()
261    }
262}
263
264/// Custom transformation trait
265pub trait CustomTransform: Send + Sync {
266    /// Transform the query
267    fn transform(&self, query: &str) -> Result<String, TransformError>;
268}
269
/// Error produced when a query transformation cannot be applied.
///
/// Derives `PartialEq`/`Eq` so callers and tests can compare errors directly
/// (e.g. with `assert_eq!`).
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum TransformError {
    /// An internal regular expression failed to compile; carries the regex error text.
    RegexError(String),

    /// The query could not be parsed; carries a description.
    ParseError(String),

    /// `Transformation::Custom` named a function that is not registered; carries the name.
    UnknownCustomFunction(String),

    /// The transformation does not apply to the query; carries a reason.
    NotApplicable(String),
}

impl std::fmt::Display for TransformError {
    /// Formats as `"<kind>: <detail>"`, e.g. `"Regex error: <msg>"`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let (kind, detail) = match self {
            Self::RegexError(msg) => ("Regex error", msg),
            Self::ParseError(msg) => ("Parse error", msg),
            Self::UnknownCustomFunction(name) => ("Unknown custom function", name),
            Self::NotApplicable(msg) => ("Not applicable", msg),
        };
        write!(f, "{}: {}", kind, detail)
    }
}

impl std::error::Error for TransformError {}
298
#[cfg(test)]
mod tests {
    use super::*;

    /// A fresh engine with no custom functions registered.
    fn engine() -> TransformationEngine {
        TransformationEngine::new()
    }

    /// Convert borrowed column names into the owned form the engine expects.
    fn cols(names: &[&str]) -> Vec<String> {
        names.iter().map(|name| name.to_string()).collect()
    }

    #[test]
    fn test_add_limit() {
        let e = engine();

        let limited = e.add_limit("SELECT * FROM users", 100).unwrap();
        assert!(limited.contains("LIMIT 100"));

        // An existing LIMIT wins; no second LIMIT is appended.
        let kept = e.add_limit("SELECT * FROM users LIMIT 50", 100).unwrap();
        assert!(kept.contains("LIMIT 50"));
        assert!(!kept.contains("LIMIT 100"));
    }

    #[test]
    fn test_add_where() {
        let e = engine();

        let fresh = e.add_where_clause("SELECT * FROM users", "active = true").unwrap();
        assert!(fresh.contains("WHERE active = true"));

        // A query that already filters gets the new condition ANDed on.
        let combined = e
            .add_where_clause("SELECT * FROM users WHERE id = 1", "active = true")
            .unwrap();
        assert!(combined.contains("AND (active = true)"));
    }

    #[test]
    fn test_replace_table() {
        let rewritten = engine()
            .replace_table("SELECT * FROM old_users", "old_users", "users_v2")
            .unwrap();
        assert!(rewritten.contains("users_v2"));
        assert!(!rewritten.contains("old_users"));
    }

    #[test]
    fn test_expand_select_star() {
        let expanded = engine()
            .expand_select_star("SELECT * FROM users", &cols(&["id", "name", "email"]))
            .unwrap();
        assert!(expanded.contains("id, name, email"));
        assert!(!expanded.contains('*'));
    }

    #[test]
    fn test_expand_select_distinct_star() {
        let expanded = engine()
            .expand_select_star("SELECT DISTINCT * FROM users", &cols(&["id", "name"]))
            .unwrap();
        assert!(expanded.contains("SELECT DISTINCT id, name"));
    }

    #[test]
    fn test_add_index_hint() {
        let hinted = engine()
            .add_index_hint("SELECT * FROM users WHERE id = 1", "users", "idx_users_id")
            .unwrap();
        assert!(hinted.contains("IndexScan(users idx_users_id)"));
    }

    #[test]
    fn test_add_order_by() {
        let ordered = engine()
            .add_order_by("SELECT * FROM users", "created_at", true)
            .unwrap();
        assert!(ordered.contains("ORDER BY created_at DESC"));
    }

    #[test]
    fn test_add_hint() {
        let hint = Transformation::AddHint("parallel=4".to_string());
        let out = engine().apply("SELECT * FROM users", &hint).unwrap();
        assert!(out.contains("/*parallel=4*/"));
    }

    #[test]
    fn test_add_branch_hint() {
        let branch = Transformation::AddBranchHint("analytics".to_string());
        let out = engine().apply("SELECT * FROM analytics", &branch).unwrap();
        assert!(out.contains("/*helios:branch=analytics*/"));
    }

    #[test]
    fn test_chain_transformations() {
        let chain = Transformation::Chain(vec![
            Transformation::AddLimit(100),
            Transformation::AddOrderBy {
                column: "id".to_string(),
                descending: false,
            },
        ]);

        let out = engine().apply("SELECT * FROM users", &chain).unwrap();
        assert!(out.contains("LIMIT 100"));
        assert!(out.contains("ORDER BY id ASC"));
    }
}