krishiv_sql/
spark_sql_ext.rs1use crate::{SqlError, SqlResult};
14
/// Reports whether `sql` mentions Spark's `LATERAL VIEW` clause.
///
/// The check is a case-insensitive substring search. `LATERAL VIEW OUTER` needs no
/// separate probe because it begins with `LATERAL VIEW`.
pub fn contains_lateral_view(sql: &str) -> bool {
    const CLAUSE: &str = "LATERAL VIEW";
    let normalized = sql.to_uppercase();
    normalized.contains(CLAUSE)
}
22
23pub fn rewrite_lateral_view(sql: &str) -> SqlResult<String> {
44 if !contains_lateral_view(sql) {
45 return Ok(sql.to_string());
46 }
47
48 let mut result = sql.to_string();
49
50 while let Some(pos) = find_keyword_boundary(&result, "LATERAL VIEW OUTER") {
52 if let Some(replacement) = rewrite_lateral_view_at(&result, pos, "LATERAL VIEW OUTER", true)
53 {
54 result = replacement;
55 } else {
56 break;
57 }
58 }
59
60 while let Some(pos) = find_keyword_boundary(&result, "LATERAL VIEW") {
62 if let Some(replacement) = rewrite_lateral_view_at(&result, pos, "LATERAL VIEW", false) {
63 result = replacement;
64 } else {
65 break;
66 }
67 }
68
69 Ok(result)
70}
71
72fn rewrite_lateral_view_at(sql: &str, pos: usize, keyword: &str, is_outer: bool) -> Option<String> {
74 let before = &sql[..pos];
75 let after_keyword = &sql[pos + keyword.len()..];
76
77 let trimmed = after_keyword.trim_start();
80 let keyword_offset = after_keyword.len() - trimmed.len();
81
82 let upper_trimmed = trimmed.to_uppercase();
84 let as_pos = upper_trimmed.find(" AS ")?;
85 let func_call = trimmed[..as_pos].trim();
86
87 let alias_start = as_pos + 4;
89 let alias_text = &trimmed[alias_start..];
90
91 let alias_len = find_alias_length(alias_text);
93 let alias_part = alias_text[..alias_len].trim();
94
95 let consumed = keyword.len() + keyword_offset + as_pos + 4 + alias_len;
97 let rest = &sql[pos + consumed..];
98
99 let join_type = if is_outer {
100 "LEFT JOIN LATERAL"
101 } else {
102 "CROSS JOIN LATERAL"
103 };
104
105 let on_clause = if is_outer { " ON TRUE" } else { "" };
106
107 Some(format!(
108 "{} {} {} AS {}{}{}",
109 before, join_type, func_call, alias_part, on_clause, rest
110 ))
111}
112
/// Returns the byte length of the alias at the start of `text`.
///
/// The alias consists of optional leading spaces/tabs, an identifier made of ASCII
/// alphanumerics and `_`, and optionally a parenthesised column list (which may be
/// separated from the identifier by spaces and may contain nested parentheses).
///
/// Whitespace after the identifier is only counted when a column list follows it;
/// otherwise the length stops right after the identifier so that the caller keeps
/// the separator between the alias and whatever comes next (e.g. `x WHERE ...`).
///
/// Returns `0` when `text` does not start with an identifier. If the column list is
/// never closed, the rest of `text` is treated as part of the alias.
fn find_alias_length(text: &str) -> usize {
    let bytes = text.as_bytes();
    let is_ident = |b: u8| b.is_ascii_alphanumeric() || b == b'_';

    let name_start = bytes
        .iter()
        .take_while(|&&b| b == b' ' || b == b'\t')
        .count();
    let name_len = bytes[name_start..]
        .iter()
        .take_while(|&&b| is_ident(b))
        .count();
    if name_len == 0 {
        return 0;
    }
    let name_end = name_start + name_len;

    let gap = bytes[name_end..]
        .iter()
        .take_while(|&&b| b == b' ')
        .count();
    let paren_start = name_end + gap;
    if bytes.get(paren_start) != Some(&b'(') {
        // No column list: the spaces after the name belong to the following text.
        return name_end;
    }

    // Walk the column list, tracking nesting until the opening paren is closed.
    let mut depth = 0usize;
    for (offset, &b) in bytes[paren_start..].iter().enumerate() {
        match b {
            b'(' => depth += 1,
            b')' => {
                depth -= 1;
                if depth == 0 {
                    return paren_start + offset + 1;
                }
            }
            _ => {}
        }
    }

    bytes.len()
}
159
/// Finds the first case-insensitive occurrence of `keyword` in `sql` that stands on
/// its own, and returns its byte offset within `sql`.
///
/// An occurrence qualifies when it is
/// * preceded by the start of the input, ASCII whitespace or `,`, and
/// * followed by the end of the input, ASCII whitespace or `(`.
///
/// Occurrences embedded in longer words (e.g. `COLLATERAL VIEW`) are skipped and the
/// search continues behind them. Returns `None` when nothing qualifies or when
/// `keyword` is empty.
fn find_keyword_boundary(sql: &str, keyword: &str) -> Option<usize> {
    // Step used to resume the search just past the start of a rejected match; it is
    // the width of the keyword's first character so slicing stays on a boundary.
    let step = keyword.chars().next()?.len_utf8();

    // ASCII-only case folding keeps every byte offset identical to `sql`; full
    // Unicode upper-casing can change the byte length of non-ASCII characters.
    let haystack = sql.to_ascii_uppercase();
    let needle = keyword.to_ascii_uppercase();
    let bytes = sql.as_bytes();

    let mut search_start = 0;
    while let Some(found) = haystack[search_start..].find(&needle) {
        let abs_pos = search_start + found;

        let before_ok = match abs_pos.checked_sub(1).and_then(|idx| bytes.get(idx)) {
            None => true,
            Some(&b) => b.is_ascii_whitespace() || b == b',',
        };
        let after_ok = match bytes.get(abs_pos + needle.len()) {
            None => true,
            Some(&b) => b.is_ascii_whitespace() || b == b'(',
        };

        if before_ok && after_ok {
            return Some(abs_pos);
        }
        search_start = abs_pos + step;
    }

    None
}
188
/// Reports whether `sql` contains the text `TABLESAMPLE`, ignoring case.
///
/// This is a plain substring test; it does not check that the match is a keyword.
pub fn contains_tablesample(sql: &str) -> bool {
    const KEYWORD: &str = "TABLESAMPLE";
    let normalized = sql.to_uppercase();
    normalized.contains(KEYWORD)
}
195
196pub fn rewrite_tablesample(sql: &str) -> SqlResult<String> {
209 if !contains_tablesample(sql) {
210 return Ok(sql.to_string());
211 }
212
213 let upper = sql.to_uppercase();
214
215 if let Some(pos) = upper.find("TABLESAMPLE") {
217 let after = sql[pos + "TABLESAMPLE".len()..].trim_start();
218 if !after.starts_with('(') {
219 return Err(SqlError::DataFusion {
220 message: "TABLESAMPLE requires parentheses: TABLESAMPLE (n PERCENT)".into(),
221 });
222 }
223 if let Some(close) = after.find(')') {
224 let inner = after[1..close].trim().to_uppercase();
225 if inner.ends_with("PERCENT") || inner.ends_with("ROWS") || inner.ends_with("BUCKET") {
226 return Ok(sql.to_string());
227 }
228 if inner.parse::<f64>().is_ok() {
230 return Ok(sql.to_string());
231 }
232 return Err(SqlError::DataFusion {
233 message: format!("TABLESAMPLE requires PERCENT, ROWS, or BUCKET: got '{inner}'"),
234 });
235 }
236 }
237
238 Ok(sql.to_string())
239}
240
/// Reports whether `sql` contains a `TRANSFORM(` or `TRANSFORM (` call, ignoring case.
pub fn contains_transform(sql: &str) -> bool {
    let normalized = sql.to_uppercase();
    ["TRANSFORM(", "TRANSFORM ("]
        .iter()
        .any(|&needle| normalized.contains(needle))
}
247
248pub fn rewrite_transform(sql: &str) -> SqlResult<String> {
253 Ok(sql.to_string())
255}
256
/// Reports whether `sql` is a `DESCRIBE [TABLE] EXTENDED ...` statement.
///
/// The statement must begin (after optional leading whitespace) with `DESCRIBE` or
/// `DESC`, optionally followed by `TABLE`, and then the `EXTENDED` keyword. Keywords
/// are compared case-insensitively and as whole whitespace-separated words, so
/// ordinary queries that merely contain those letters (for example
/// `SELECT extended FROM my_table ORDER BY a DESC`) are not reported.
pub fn contains_describe_extended(sql: &str) -> bool {
    let mut tokens = sql.split_whitespace();

    let Some(command) = tokens.next() else {
        return false;
    };
    if !(command.eq_ignore_ascii_case("DESCRIBE") || command.eq_ignore_ascii_case("DESC")) {
        return false;
    }

    // `TABLE` is optional between the command and the `EXTENDED` modifier.
    let mut modifier = tokens.next();
    if modifier.is_some_and(|word| word.eq_ignore_ascii_case("TABLE")) {
        modifier = tokens.next();
    }

    modifier.is_some_and(|word| word.eq_ignore_ascii_case("EXTENDED"))
}
266
267pub fn rewrite_describe_extended(sql: &str) -> SqlResult<String> {
273 if !contains_describe_extended(sql) {
274 return Ok(sql.to_string());
275 }
276
277 let result = regex_replace(sql, r"(?i)\bEXTENDED\b\s*", "")?;
279 Ok(result.trim().to_string())
280}
281
/// Reports whether `sql` contains the text `SHOW TBLPROPERTIES`, ignoring case.
pub fn contains_show_tblproperties(sql: &str) -> bool {
    const COMMAND: &str = "SHOW TBLPROPERTIES";
    let normalized = sql.to_uppercase();
    normalized.contains(COMMAND)
}
288
289pub fn rewrite_show_tblproperties(sql: &str) -> SqlResult<String> {
291 if !contains_show_tblproperties(sql) {
292 return Ok(sql.to_string());
293 }
294
295 let upper = sql.to_uppercase();
296 if let Some(pos) = upper.find("SHOW TBLPROPERTIES") {
298 let after = sql[pos + "SHOW TBLPROPERTIES".len()..].trim_start();
299 let table_name = after.trim_end_matches(';').trim();
301 if table_name.is_empty() {
302 return Err(SqlError::DataFusion {
303 message: "SHOW TBLPROPERTIES requires a table name".into(),
304 });
305 }
306 return Ok(format!(
308 "SELECT key, value FROM information_schema.table_properties WHERE table_name = '{table_name}'"
309 ));
310 }
311
312 Ok(sql.to_string())
313}
314
315fn regex_replace(input: &str, pattern: &str, replacement: &str) -> SqlResult<String> {
319 let _ = replacement;
321
322 if pattern == r"(?i)\bEXTENDED\b\s*" {
324 let mut result = input.to_string();
326 while let Some(pos) = result.to_uppercase().find("EXTENDED") {
327 let bytes = result.as_bytes();
329 let before_ok =
330 pos == 0 || bytes.get(pos - 1).is_some_and(|&b| b == b' ' || b == b'\t');
331 let after_pos = pos + "EXTENDED".len();
332 let after_ok = after_pos >= result.len()
333 || bytes
334 .get(after_pos)
335 .is_some_and(|&b| b == b' ' || b == b'\t' || b == b'\n');
336
337 if before_ok && after_ok {
338 let end = if bytes.get(after_pos).is_some_and(|&b| b == b' ') {
340 after_pos + 1
341 } else {
342 after_pos
343 };
344 result = format!("{}{}", &result[..pos], &result[end..]);
345 } else {
346 break;
347 }
348 }
349 return Ok(result);
350 }
351
352 Ok(input.to_string())
353}
354
355pub fn preprocess_spark_sql(sql: &str) -> SqlResult<String> {
359 let mut result = sql.to_string();
360
361 result = rewrite_lateral_view(&result)?;
363 result = rewrite_tablesample(&result)?;
364 result = rewrite_transform(&result)?;
365 result = rewrite_describe_extended(&result)?;
366 result = rewrite_show_tblproperties(&result)?;
367
368 Ok(result)
369}
370
#[cfg(test)]
mod tests {
    use super::*;

    // ---- LATERAL VIEW ----

    #[test]
    fn lateral_view_basic() {
        let rewritten =
            rewrite_lateral_view("SELECT id, val FROM t LATERAL VIEW explode(tags) AS tag")
                .unwrap();
        assert!(rewritten.contains("CROSS JOIN LATERAL explode(tags) AS tag"));
        assert!(!rewritten.contains("LATERAL VIEW"));
    }

    #[test]
    fn lateral_view_outer() {
        let rewritten =
            rewrite_lateral_view("SELECT id, val FROM t LATERAL VIEW OUTER explode(tags) AS tag")
                .unwrap();
        assert!(rewritten.contains("LEFT JOIN LATERAL explode(tags) AS tag ON TRUE"));
        assert!(!rewritten.contains("LATERAL VIEW"));
    }

    #[test]
    fn lateral_view_with_column_list() {
        let rewritten =
            rewrite_lateral_view("SELECT id, val FROM t LATERAL VIEW posexplode(arr) AS pos, val")
                .unwrap();
        assert!(rewritten.contains("CROSS JOIN LATERAL"));
    }

    #[test]
    fn lateral_view_no_change_when_absent() {
        let input = "SELECT * FROM t WHERE id = 1";
        assert_eq!(rewrite_lateral_view(input).unwrap(), input);
    }

    #[test]
    fn contains_lateral_view_true() {
        let plain = "SELECT * FROM t LATERAL VIEW explode(a) AS x";
        let outer = "SELECT * FROM t LATERAL VIEW OUTER explode(a) AS x";
        assert!(contains_lateral_view(plain));
        assert!(contains_lateral_view(outer));
        assert!(!contains_lateral_view("SELECT * FROM t"));
    }

    // ---- TABLESAMPLE ----

    #[test]
    fn tablesample_passthrough() {
        let input = "SELECT * FROM t TABLESAMPLE (10 PERCENT)";
        assert_eq!(rewrite_tablesample(input).unwrap(), input);
    }

    #[test]
    fn tablesample_rows() {
        let input = "SELECT * FROM t TABLESAMPLE (100 ROWS)";
        assert_eq!(rewrite_tablesample(input).unwrap(), input);
    }

    #[test]
    fn tablesample_no_parens_errors() {
        let outcome = rewrite_tablesample("SELECT * FROM t TABLESAMPLE 10 PERCENT");
        assert!(outcome.is_err());
    }

    #[test]
    fn contains_tablesample_true() {
        assert!(contains_tablesample(
            "SELECT * FROM t TABLESAMPLE (10 PERCENT)"
        ));
        assert!(!contains_tablesample("SELECT * FROM t"));
    }

    // ---- DESCRIBE ... EXTENDED ----

    #[test]
    fn describe_extended_rewrite() {
        let rewritten = rewrite_describe_extended("DESCRIBE TABLE EXTENDED my_table").unwrap();
        assert!(!rewritten.to_uppercase().contains("EXTENDED"));
        assert!(rewritten.contains("my_table"));
    }

    #[test]
    fn describe_extended_case_insensitive() {
        let rewritten = rewrite_describe_extended("desc table extended my_table").unwrap();
        assert!(!rewritten.to_uppercase().contains("EXTENDED"));
    }

    #[test]
    fn contains_describe_extended_true() {
        for statement in ["DESCRIBE TABLE EXTENDED t", "desc table extended t"] {
            assert!(contains_describe_extended(statement));
        }
        assert!(!contains_describe_extended("DESCRIBE TABLE t"));
    }

    // ---- SHOW TBLPROPERTIES ----

    #[test]
    fn show_tblproperties_rewrite() {
        let rewritten = rewrite_show_tblproperties("SHOW TBLPROPERTIES my_table").unwrap();
        assert!(rewritten.contains("my_table"));
        assert!(rewritten.contains("information_schema"));
    }

    #[test]
    fn show_tblproperties_with_semicolon() {
        let rewritten = rewrite_show_tblproperties("SHOW TBLPROPERTIES my_table;").unwrap();
        assert!(rewritten.contains("my_table"));
    }

    #[test]
    fn show_tblproperties_empty_errors() {
        let outcome = rewrite_show_tblproperties("SHOW TBLPROPERTIES");
        assert!(outcome.is_err());
    }

    // ---- Full pipeline ----

    #[test]
    fn preprocess_spark_sql_lateral_view() {
        let rewritten =
            preprocess_spark_sql("SELECT id, val FROM t LATERAL VIEW explode(tags) AS tag")
                .unwrap();
        assert!(rewritten.contains("CROSS JOIN LATERAL"));
    }

    #[test]
    fn preprocess_spark_sql_passthrough() {
        let input = "SELECT 1 + 1";
        assert_eq!(preprocess_spark_sql(input).unwrap(), input);
    }
}