1use super::rules::Transformation;
6use regex::Regex;
7
8pub struct TransformationEngine {
10 custom_functions: std::collections::HashMap<String, Box<dyn CustomTransform>>,
12}
13
14impl TransformationEngine {
15 pub fn new() -> Self {
17 Self {
18 custom_functions: std::collections::HashMap::new(),
19 }
20 }
21
22 pub fn register_custom(&mut self, name: String, transform: Box<dyn CustomTransform>) {
24 self.custom_functions.insert(name, transform);
25 }
26
27 pub fn apply(
29 &self,
30 query: &str,
31 transformation: &Transformation,
32 ) -> Result<String, TransformError> {
33 match transformation {
34 Transformation::NoOp => Ok(query.to_string()),
35
36 Transformation::Replace(replacement) => Ok(replacement.clone()),
37
38 Transformation::AddIndexHint { table, index } => {
39 self.add_index_hint(query, table, index)
40 }
41
42 Transformation::ExpandSelectStar { columns } => self.expand_select_star(query, columns),
43
44 Transformation::AddLimit(limit) => self.add_limit(query, *limit),
45
46 Transformation::AddWhereClause(condition) => self.add_where_clause(query, condition),
47
48 Transformation::AppendWhereAnd(condition) => self.append_where_and(query, condition),
49
50 Transformation::ReplaceTable { from, to } => self.replace_table(query, from, to),
51
52 Transformation::AddOrderBy { column, descending } => {
53 self.add_order_by(query, column, *descending)
54 }
55
56 Transformation::AddHint(hint) => Ok(format!("/*{}*/ {}", hint, query)),
57
58 Transformation::AddBranchHint(branch) => {
59 Ok(format!("/*helios:branch={}*/ {}", branch, query))
60 }
61
62 Transformation::AddTimeout(duration) => {
63 let ms = duration.as_millis();
64 Ok(format!("/*helios:timeout={}ms*/ {}", ms, query))
65 }
66
67 Transformation::Custom(name) => {
68 if let Some(transform) = self.custom_functions.get(name) {
69 transform.transform(query)
70 } else {
71 Err(TransformError::UnknownCustomFunction(name.clone()))
72 }
73 }
74
75 Transformation::Chain(transformations) => {
76 let mut result = query.to_string();
77 for t in transformations {
78 result = self.apply(&result, t)?;
79 }
80 Ok(result)
81 }
82 }
83 }
84
85 fn add_index_hint(
87 &self,
88 query: &str,
89 table: &str,
90 index: &str,
91 ) -> Result<String, TransformError> {
92 let upper = query.to_uppercase();
95
96 if let Some(pos) = upper.find("SELECT") {
97 let insert_pos = pos + 6;
98 let hint = format!(" /*+ IndexScan({} {}) */", table, index);
99
100 let mut result = query.to_string();
101 result.insert_str(insert_pos, &hint);
102 Ok(result)
103 } else {
104 Ok(format!("/*+ IndexScan({} {}) */ {}", table, index, query))
106 }
107 }
108
109 fn expand_select_star(
111 &self,
112 query: &str,
113 columns: &[String],
114 ) -> Result<String, TransformError> {
115 let re = Regex::new(r"(?i)SELECT\s+(\*|DISTINCT\s+\*|ALL\s+\*)")
117 .map_err(|e| TransformError::RegexError(e.to_string()))?;
118
119 if let Some(caps) = re.find(query) {
120 let matched = caps.as_str();
121 let is_distinct = matched.to_uppercase().contains("DISTINCT");
122 let is_all = matched.to_uppercase().contains("ALL");
123
124 let column_list = columns.join(", ");
125 let replacement = if is_distinct {
126 format!("SELECT DISTINCT {}", column_list)
127 } else if is_all {
128 format!("SELECT ALL {}", column_list)
129 } else {
130 format!("SELECT {}", column_list)
131 };
132
133 Ok(re.replace(query, replacement.as_str()).to_string())
134 } else {
135 Ok(query.to_string())
137 }
138 }
139
140 fn add_limit(&self, query: &str, limit: u32) -> Result<String, TransformError> {
142 let upper = query.to_uppercase();
143
144 if upper.contains(" LIMIT ") {
146 return Ok(query.to_string());
147 }
148
149 let trimmed = query.trim_end_matches(';').trim();
151
152 if upper.contains(" FOR ") {
154 let for_pos = upper.rfind(" FOR ").unwrap();
155 let (before_for, after_for) = trimmed.split_at(for_pos);
156 Ok(format!("{} LIMIT {}{};", before_for, limit, after_for))
157 } else {
158 Ok(format!("{} LIMIT {};", trimmed, limit))
159 }
160 }
161
162 fn add_where_clause(&self, query: &str, condition: &str) -> Result<String, TransformError> {
164 let upper = query.to_uppercase();
165
166 let trimmed = query.trim_end_matches(';').trim();
168
169 if upper.contains(" WHERE ") {
170 self.append_where_and(trimmed, condition)
172 } else {
173 let insert_keywords = [" GROUP BY", " ORDER BY", " LIMIT ", " OFFSET ", " FOR "];
175 let mut insert_pos = trimmed.len();
176
177 for keyword in &insert_keywords {
178 if let Some(pos) = upper.find(keyword) {
179 if pos < insert_pos {
180 insert_pos = pos;
181 }
182 }
183 }
184
185 let (before, after) = trimmed.split_at(insert_pos);
186 Ok(format!("{} WHERE {}{};", before, condition, after))
187 }
188 }
189
190 fn append_where_and(&self, query: &str, condition: &str) -> Result<String, TransformError> {
192 let upper = query.to_uppercase();
193 let trimmed = query.trim_end_matches(';').trim();
194
195 if let Some(where_pos) = upper.find(" WHERE ") {
196 let after_where = &upper[where_pos + 7..];
198 let end_keywords = [" GROUP BY", " ORDER BY", " LIMIT ", " OFFSET ", " FOR "];
199
200 let mut end_pos = trimmed.len();
201 for keyword in &end_keywords {
202 if let Some(pos) = after_where.find(keyword) {
203 let abs_pos = where_pos + 7 + pos;
204 if abs_pos < end_pos {
205 end_pos = abs_pos;
206 }
207 }
208 }
209
210 let (before, after) = trimmed.split_at(end_pos);
211 Ok(format!("{} AND ({}){}; ", before, condition, after))
212 } else {
213 self.add_where_clause(trimmed, condition)
215 }
216 }
217
218 fn replace_table(&self, query: &str, from: &str, to: &str) -> Result<String, TransformError> {
220 let pattern = format!(r"\b{}\b", regex::escape(from));
222 let re = Regex::new(&pattern).map_err(|e| TransformError::RegexError(e.to_string()))?;
223
224 Ok(re.replace_all(query, to).to_string())
225 }
226
227 fn add_order_by(
229 &self,
230 query: &str,
231 column: &str,
232 descending: bool,
233 ) -> Result<String, TransformError> {
234 let upper = query.to_uppercase();
235 let trimmed = query.trim_end_matches(';').trim();
236
237 if upper.contains(" ORDER BY ") {
239 return Ok(query.to_string());
240 }
241
242 let direction = if descending { "DESC" } else { "ASC" };
243
244 let insert_keywords = [" LIMIT ", " OFFSET ", " FOR "];
246 let mut insert_pos = trimmed.len();
247
248 for keyword in &insert_keywords {
249 if let Some(pos) = upper.find(keyword) {
250 if pos < insert_pos {
251 insert_pos = pos;
252 }
253 }
254 }
255
256 let (before, after) = trimmed.split_at(insert_pos);
257 Ok(format!(
258 "{} ORDER BY {} {}{};",
259 before, column, direction, after
260 ))
261 }
262}
263
264impl Default for TransformationEngine {
265 fn default() -> Self {
266 Self::new()
267 }
268}
269
270pub trait CustomTransform: Send + Sync {
272 fn transform(&self, query: &str) -> Result<String, TransformError>;
274}
275
/// Errors produced while rewriting a query.
#[derive(Debug, Clone)]
pub enum TransformError {
    /// A regular expression failed to compile; carries the regex error text.
    /// Rendered as `Regex error: <msg>`.
    RegexError(String),

    /// Parse failure; carries a message. Rendered as `Parse error: <msg>`.
    ParseError(String),

    /// No custom transform is registered under the carried name.
    /// Rendered as `Unknown custom function: <name>`.
    UnknownCustomFunction(String),

    /// The rewrite does not apply; carries a message. Rendered as `Not applicable: <msg>`.
    NotApplicable(String),
}

impl std::fmt::Display for TransformError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            TransformError::RegexError(detail) => write!(f, "Regex error: {}", detail),
            TransformError::ParseError(detail) => write!(f, "Parse error: {}", detail),
            TransformError::UnknownCustomFunction(function) => {
                write!(f, "Unknown custom function: {}", function)
            }
            TransformError::NotApplicable(reason) => write!(f, "Not applicable: {}", reason),
        }
    }
}

/// Lets a `TransformError` be boxed as `dyn std::error::Error`.
impl std::error::Error for TransformError {}
304
#[cfg(test)]
mod tests {
    use super::*;

    /// Converts string literals into the owned column list `expand_select_star` expects.
    fn column_list(names: &[&str]) -> Vec<String> {
        names.iter().map(|name| name.to_string()).collect()
    }

    #[test]
    fn test_add_limit() {
        let engine = TransformationEngine::new();

        let limited = engine.add_limit("SELECT * FROM users", 100).unwrap();
        assert!(limited.contains("LIMIT 100"));

        // An existing LIMIT must survive untouched.
        let preserved = engine
            .add_limit("SELECT * FROM users LIMIT 50", 100)
            .unwrap();
        assert!(preserved.contains("LIMIT 50"));
        assert!(!preserved.contains("LIMIT 100"));
    }

    #[test]
    fn test_add_where() {
        let engine = TransformationEngine::new();

        let fresh = engine
            .add_where_clause("SELECT * FROM users", "active = true")
            .unwrap();
        assert!(fresh.contains("WHERE active = true"));

        // With a WHERE already present, the condition is AND-ed on.
        let combined = engine
            .add_where_clause("SELECT * FROM users WHERE id = 1", "active = true")
            .unwrap();
        assert!(combined.contains("AND (active = true)"));
    }

    #[test]
    fn test_replace_table() {
        let renamed = TransformationEngine::new()
            .replace_table("SELECT * FROM old_users", "old_users", "users_v2")
            .unwrap();
        assert!(renamed.contains("users_v2"));
        assert!(!renamed.contains("old_users"));
    }

    #[test]
    fn test_expand_select_star() {
        let expanded = TransformationEngine::new()
            .expand_select_star(
                "SELECT * FROM users",
                &column_list(&["id", "name", "email"]),
            )
            .unwrap();

        assert!(expanded.contains("id, name, email"));
        assert!(!expanded.contains('*'));
    }

    #[test]
    fn test_expand_select_distinct_star() {
        let expanded = TransformationEngine::new()
            .expand_select_star(
                "SELECT DISTINCT * FROM users",
                &column_list(&["id", "name"]),
            )
            .unwrap();

        assert!(expanded.contains("SELECT DISTINCT id, name"));
    }

    #[test]
    fn test_add_index_hint() {
        let hinted = TransformationEngine::new()
            .add_index_hint("SELECT * FROM users WHERE id = 1", "users", "idx_users_id")
            .unwrap();
        assert!(hinted.contains("IndexScan(users idx_users_id)"));
    }

    #[test]
    fn test_add_order_by() {
        let ordered = TransformationEngine::new()
            .add_order_by("SELECT * FROM users", "created_at", true)
            .unwrap();
        assert!(ordered.contains("ORDER BY created_at DESC"));
    }

    #[test]
    fn test_add_hint() {
        let rule = Transformation::AddHint("parallel=4".to_string());
        let hinted = TransformationEngine::new()
            .apply("SELECT * FROM users", &rule)
            .unwrap();
        assert!(hinted.contains("/*parallel=4*/"));
    }

    #[test]
    fn test_add_branch_hint() {
        let rule = Transformation::AddBranchHint("analytics".to_string());
        let routed = TransformationEngine::new()
            .apply("SELECT * FROM analytics", &rule)
            .unwrap();
        assert!(routed.contains("/*helios:branch=analytics*/"));
    }

    #[test]
    fn test_chain_transformations() {
        let pipeline = Transformation::Chain(vec![
            Transformation::AddLimit(100),
            Transformation::AddOrderBy {
                column: "id".to_string(),
                descending: false,
            },
        ]);

        let rewritten = TransformationEngine::new()
            .apply("SELECT * FROM users", &pipeline)
            .unwrap();

        assert!(rewritten.contains("LIMIT 100"));
        assert!(rewritten.contains("ORDER BY id ASC"));
    }
}