1use super::rules::Transformation;
6use regex::Regex;
7
8pub struct TransformationEngine {
10 custom_functions: std::collections::HashMap<String, Box<dyn CustomTransform>>,
12}
13
14impl TransformationEngine {
15 pub fn new() -> Self {
17 Self {
18 custom_functions: std::collections::HashMap::new(),
19 }
20 }
21
22 pub fn register_custom(&mut self, name: String, transform: Box<dyn CustomTransform>) {
24 self.custom_functions.insert(name, transform);
25 }
26
27 pub fn apply(&self, query: &str, transformation: &Transformation) -> Result<String, TransformError> {
29 match transformation {
30 Transformation::NoOp => Ok(query.to_string()),
31
32 Transformation::Replace(replacement) => {
33 Ok(replacement.clone())
34 }
35
36 Transformation::AddIndexHint { table, index } => {
37 self.add_index_hint(query, table, index)
38 }
39
40 Transformation::ExpandSelectStar { columns } => {
41 self.expand_select_star(query, columns)
42 }
43
44 Transformation::AddLimit(limit) => {
45 self.add_limit(query, *limit)
46 }
47
48 Transformation::AddWhereClause(condition) => {
49 self.add_where_clause(query, condition)
50 }
51
52 Transformation::AppendWhereAnd(condition) => {
53 self.append_where_and(query, condition)
54 }
55
56 Transformation::ReplaceTable { from, to } => {
57 self.replace_table(query, from, to)
58 }
59
60 Transformation::AddOrderBy { column, descending } => {
61 self.add_order_by(query, column, *descending)
62 }
63
64 Transformation::AddHint(hint) => {
65 Ok(format!("/*{}*/ {}", hint, query))
66 }
67
68 Transformation::AddBranchHint(branch) => {
69 Ok(format!("/*helios:branch={}*/ {}", branch, query))
70 }
71
72 Transformation::AddTimeout(duration) => {
73 let ms = duration.as_millis();
74 Ok(format!("/*helios:timeout={}ms*/ {}", ms, query))
75 }
76
77 Transformation::Custom(name) => {
78 if let Some(transform) = self.custom_functions.get(name) {
79 transform.transform(query)
80 } else {
81 Err(TransformError::UnknownCustomFunction(name.clone()))
82 }
83 }
84
85 Transformation::Chain(transformations) => {
86 let mut result = query.to_string();
87 for t in transformations {
88 result = self.apply(&result, t)?;
89 }
90 Ok(result)
91 }
92 }
93 }
94
95 fn add_index_hint(&self, query: &str, table: &str, index: &str) -> Result<String, TransformError> {
97 let upper = query.to_uppercase();
100
101 if let Some(pos) = upper.find("SELECT") {
102 let insert_pos = pos + 6;
103 let hint = format!(" /*+ IndexScan({} {}) */", table, index);
104
105 let mut result = query.to_string();
106 result.insert_str(insert_pos, &hint);
107 Ok(result)
108 } else {
109 Ok(format!("/*+ IndexScan({} {}) */ {}", table, index, query))
111 }
112 }
113
114 fn expand_select_star(&self, query: &str, columns: &[String]) -> Result<String, TransformError> {
116 let re = Regex::new(r"(?i)SELECT\s+(\*|DISTINCT\s+\*|ALL\s+\*)")
118 .map_err(|e| TransformError::RegexError(e.to_string()))?;
119
120 if let Some(caps) = re.find(query) {
121 let matched = caps.as_str();
122 let is_distinct = matched.to_uppercase().contains("DISTINCT");
123 let is_all = matched.to_uppercase().contains("ALL");
124
125 let column_list = columns.join(", ");
126 let replacement = if is_distinct {
127 format!("SELECT DISTINCT {}", column_list)
128 } else if is_all {
129 format!("SELECT ALL {}", column_list)
130 } else {
131 format!("SELECT {}", column_list)
132 };
133
134 Ok(re.replace(query, replacement.as_str()).to_string())
135 } else {
136 Ok(query.to_string())
138 }
139 }
140
141 fn add_limit(&self, query: &str, limit: u32) -> Result<String, TransformError> {
143 let upper = query.to_uppercase();
144
145 if upper.contains(" LIMIT ") {
147 return Ok(query.to_string());
148 }
149
150 let trimmed = query.trim_end_matches(';').trim();
152
153 if upper.contains(" FOR ") {
155 let for_pos = upper.rfind(" FOR ").unwrap();
156 let (before_for, after_for) = trimmed.split_at(for_pos);
157 Ok(format!("{} LIMIT {}{};", before_for, limit, after_for))
158 } else {
159 Ok(format!("{} LIMIT {};", trimmed, limit))
160 }
161 }
162
163 fn add_where_clause(&self, query: &str, condition: &str) -> Result<String, TransformError> {
165 let upper = query.to_uppercase();
166
167 let trimmed = query.trim_end_matches(';').trim();
169
170 if upper.contains(" WHERE ") {
171 self.append_where_and(trimmed, condition)
173 } else {
174 let insert_keywords = [" GROUP BY", " ORDER BY", " LIMIT ", " OFFSET ", " FOR "];
176 let mut insert_pos = trimmed.len();
177
178 for keyword in &insert_keywords {
179 if let Some(pos) = upper.find(keyword) {
180 if pos < insert_pos {
181 insert_pos = pos;
182 }
183 }
184 }
185
186 let (before, after) = trimmed.split_at(insert_pos);
187 Ok(format!("{} WHERE {}{};", before, condition, after))
188 }
189 }
190
191 fn append_where_and(&self, query: &str, condition: &str) -> Result<String, TransformError> {
193 let upper = query.to_uppercase();
194 let trimmed = query.trim_end_matches(';').trim();
195
196 if let Some(where_pos) = upper.find(" WHERE ") {
197 let after_where = &upper[where_pos + 7..];
199 let end_keywords = [" GROUP BY", " ORDER BY", " LIMIT ", " OFFSET ", " FOR "];
200
201 let mut end_pos = trimmed.len();
202 for keyword in &end_keywords {
203 if let Some(pos) = after_where.find(keyword) {
204 let abs_pos = where_pos + 7 + pos;
205 if abs_pos < end_pos {
206 end_pos = abs_pos;
207 }
208 }
209 }
210
211 let (before, after) = trimmed.split_at(end_pos);
212 Ok(format!("{} AND ({}){}; ", before, condition, after))
213 } else {
214 self.add_where_clause(trimmed, condition)
216 }
217 }
218
219 fn replace_table(&self, query: &str, from: &str, to: &str) -> Result<String, TransformError> {
221 let pattern = format!(r"\b{}\b", regex::escape(from));
223 let re = Regex::new(&pattern)
224 .map_err(|e| TransformError::RegexError(e.to_string()))?;
225
226 Ok(re.replace_all(query, to).to_string())
227 }
228
229 fn add_order_by(&self, query: &str, column: &str, descending: bool) -> Result<String, TransformError> {
231 let upper = query.to_uppercase();
232 let trimmed = query.trim_end_matches(';').trim();
233
234 if upper.contains(" ORDER BY ") {
236 return Ok(query.to_string());
237 }
238
239 let direction = if descending { "DESC" } else { "ASC" };
240
241 let insert_keywords = [" LIMIT ", " OFFSET ", " FOR "];
243 let mut insert_pos = trimmed.len();
244
245 for keyword in &insert_keywords {
246 if let Some(pos) = upper.find(keyword) {
247 if pos < insert_pos {
248 insert_pos = pos;
249 }
250 }
251 }
252
253 let (before, after) = trimmed.split_at(insert_pos);
254 Ok(format!("{} ORDER BY {} {}{};", before, column, direction, after))
255 }
256}
257
258impl Default for TransformationEngine {
259 fn default() -> Self {
260 Self::new()
261 }
262}
263
264pub trait CustomTransform: Send + Sync {
266 fn transform(&self, query: &str) -> Result<String, TransformError>;
268}
269
/// Errors produced while applying a [`Transformation`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum TransformError {
    /// An internal regular expression failed to compile; carries the regex error text.
    RegexError(String),

    /// Not constructed by the built-in transformations in this file. Presumably used by
    /// custom transforms to report a query they cannot parse — confirm with callers.
    ParseError(String),

    /// `Transformation::Custom` named a transform that was never registered; carries the name.
    UnknownCustomFunction(String),

    /// Not constructed by the built-in transformations in this file. Presumably signals
    /// that a transformation does not apply to the given query — confirm with callers.
    NotApplicable(String),
}
285
286impl std::fmt::Display for TransformError {
287 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
288 match self {
289 Self::RegexError(msg) => write!(f, "Regex error: {}", msg),
290 Self::ParseError(msg) => write!(f, "Parse error: {}", msg),
291 Self::UnknownCustomFunction(name) => write!(f, "Unknown custom function: {}", name),
292 Self::NotApplicable(msg) => write!(f, "Not applicable: {}", msg),
293 }
294 }
295}
296
297impl std::error::Error for TransformError {}
298
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a fresh engine with no custom transforms registered.
    fn engine() -> TransformationEngine {
        TransformationEngine::new()
    }

    /// Turns string literals into the owned column names `expand_select_star` takes.
    fn cols(names: &[&str]) -> Vec<String> {
        names.iter().map(|name| name.to_string()).collect()
    }

    #[test]
    fn test_add_limit() {
        let eng = engine();

        let limited = eng.add_limit("SELECT * FROM users", 100).unwrap();
        assert!(limited.contains("LIMIT 100"));

        // A query that is already limited keeps its own LIMIT.
        let kept = eng.add_limit("SELECT * FROM users LIMIT 50", 100).unwrap();
        assert!(kept.contains("LIMIT 50"));
        assert!(!kept.contains("LIMIT 100"));
    }

    #[test]
    fn test_add_where() {
        let eng = engine();

        let fresh = eng
            .add_where_clause("SELECT * FROM users", "active = true")
            .unwrap();
        assert!(fresh.contains("WHERE active = true"));

        // With an existing WHERE, the new condition is ANDed in, parenthesised.
        let combined = eng
            .add_where_clause("SELECT * FROM users WHERE id = 1", "active = true")
            .unwrap();
        assert!(combined.contains("AND (active = true)"));
    }

    #[test]
    fn test_replace_table() {
        let renamed = engine()
            .replace_table("SELECT * FROM old_users", "old_users", "users_v2")
            .unwrap();
        assert!(renamed.contains("users_v2"));
        assert!(!renamed.contains("old_users"));
    }

    #[test]
    fn test_expand_select_star() {
        let expanded = engine()
            .expand_select_star("SELECT * FROM users", &cols(&["id", "name", "email"]))
            .unwrap();
        assert!(expanded.contains("id, name, email"));
        assert!(!expanded.contains('*'));
    }

    #[test]
    fn test_expand_select_distinct_star() {
        let expanded = engine()
            .expand_select_star("SELECT DISTINCT * FROM users", &cols(&["id", "name"]))
            .unwrap();
        assert!(expanded.contains("SELECT DISTINCT id, name"));
    }

    #[test]
    fn test_add_index_hint() {
        let hinted = engine()
            .add_index_hint("SELECT * FROM users WHERE id = 1", "users", "idx_users_id")
            .unwrap();
        assert!(hinted.contains("IndexScan(users idx_users_id)"));
    }

    #[test]
    fn test_add_order_by() {
        let ordered = engine()
            .add_order_by("SELECT * FROM users", "created_at", true)
            .unwrap();
        assert!(ordered.contains("ORDER BY created_at DESC"));
    }

    #[test]
    fn test_add_hint() {
        let hint = Transformation::AddHint("parallel=4".to_string());
        let hinted = engine().apply("SELECT * FROM users", &hint).unwrap();
        assert!(hinted.contains("/*parallel=4*/"));
    }

    #[test]
    fn test_add_branch_hint() {
        let branch = Transformation::AddBranchHint("analytics".to_string());
        let routed = engine().apply("SELECT * FROM analytics", &branch).unwrap();
        assert!(routed.contains("/*helios:branch=analytics*/"));
    }

    #[test]
    fn test_chain_transformations() {
        let steps = Transformation::Chain(vec![
            Transformation::AddLimit(100),
            Transformation::AddOrderBy {
                column: "id".to_string(),
                descending: false,
            },
        ]);

        let chained = engine().apply("SELECT * FROM users", &steps).unwrap();
        assert!(chained.contains("LIMIT 100"));
        assert!(chained.contains("ORDER BY id ASC"));
    }
}