Skip to main content

krishiv_sql/
spark_sql_ext.rs

//! Spark SQL feature extensions — pre-processors for SQL constructs that
//! DataFusion doesn't parse natively.
//!
//! Supported Spark SQL features:
//!
//! - **LATERAL VIEW**: `SELECT ... FROM t LATERAL VIEW explode(arr) AS col`
//! - **LATERAL VIEW OUTER**: `SELECT ... FROM t LATERAL VIEW OUTER explode(arr) AS col`
//! - **TABLESAMPLE**: `SELECT ... FROM t TABLESAMPLE (10 PERCENT)` (validated, then passed through unchanged)
//! - **TRANSFORM**: `SELECT TRANSFORM(...) FROM t` (detected only; not rewritten yet)
//! - **DESCRIBE TABLE EXTENDED**: `DESCRIBE TABLE EXTENDED t` (the `EXTENDED` keyword is stripped)
//! - **SHOW TBLPROPERTIES**: `SHOW TBLPROPERTIES t`
12
13use crate::{SqlError, SqlResult};
14
15// ── LATERAL VIEW ─────────────────────────────────────────────────────────────
16
/// Returns `true` if `sql` contains the phrase `LATERAL VIEW` in any ASCII letter case.
///
/// This also covers `LATERAL VIEW OUTER`, which always contains the same phrase.
/// It is a cheap pre-filter: word boundaries and string literals are not
/// considered, so `true` does not guarantee that a rewrite will happen.
pub fn contains_lateral_view(sql: &str) -> bool {
    // ASCII-only folding: SQL keywords are ASCII, and this matches how the
    // rewrite itself locates the keyword.
    sql.to_ascii_uppercase().contains("LATERAL VIEW")
}
22
23/// Rewrites Spark-style `LATERAL VIEW` to standard SQL `CROSS JOIN LATERAL`.
24///
25/// # Transformations
26///
27/// ```sql
28/// -- Input
29/// SELECT id, val FROM t LATERAL VIEW explode(tags) AS tag
30///
31/// -- Output
32/// SELECT id, val FROM t CROSS JOIN LATERAL explode(tags) AS tag
33/// ```
34///
35/// Also handles `LATERAL VIEW OUTER`:
36/// ```sql
37/// -- Input
38/// SELECT id, val FROM t LATERAL VIEW OUTER explode(tags) AS tag
39///
40/// -- Output
41/// SELECT id, val FROM t LEFT JOIN LATERAL explode(tags) AS tag ON TRUE
42/// ```
43pub fn rewrite_lateral_view(sql: &str) -> SqlResult<String> {
44    if !contains_lateral_view(sql) {
45        return Ok(sql.to_string());
46    }
47
48    let mut result = sql.to_string();
49
50    // Rewrite LATERAL VIEW OUTER first (more specific pattern)
51    while let Some(pos) = find_keyword_boundary(&result, "LATERAL VIEW OUTER") {
52        if let Some(replacement) = rewrite_lateral_view_at(&result, pos, "LATERAL VIEW OUTER", true)
53        {
54            result = replacement;
55        } else {
56            break;
57        }
58    }
59
60    // Rewrite LATERAL VIEW
61    while let Some(pos) = find_keyword_boundary(&result, "LATERAL VIEW") {
62        if let Some(replacement) = rewrite_lateral_view_at(&result, pos, "LATERAL VIEW", false) {
63            result = replacement;
64        } else {
65            break;
66        }
67    }
68
69    Ok(result)
70}
71
/// Rewrite the single `LATERAL VIEW` clause whose keyword starts at byte offset `pos`.
///
/// `keyword` is the keyword text matched at `pos` (`"LATERAL VIEW"` or
/// `"LATERAL VIEW OUTER"`). The clause `<keyword> <func_call> AS <alias>` becomes
/// `CROSS JOIN LATERAL <func_call> AS <alias>`, or, when `is_outer` is set,
/// `LEFT JOIN LATERAL <func_call> AS <alias> ON TRUE`. Text before and after the
/// clause is preserved.
///
/// Returns `None` (the caller then leaves the SQL unchanged) when `pos` is not a
/// valid offset into `sql`, or when no `AS` keyword follows the generator call
/// outside of parentheses and quotes.
fn rewrite_lateral_view_at(sql: &str, pos: usize, keyword: &str, is_outer: bool) -> Option<String> {
    // Trimmed so the output has exactly one space before the join keyword.
    let before = sql.get(..pos)?.trim_end();
    let clause = sql.get(pos + keyword.len()..)?.trim_start();

    // Split `<func_call> AS <alias>` at the first top-level AS, so an AS nested
    // inside the call (e.g. `explode(CAST(x AS ARRAY<INT>))`) is skipped.
    let as_pos = find_top_level_as(clause)?;
    let func_call = clause[..as_pos].trim();

    let alias_text = clause[as_pos + 2..].trim_start();
    let alias_len = find_alias_length(alias_text);
    let alias_part = alias_text[..alias_len].trim();
    // Begins right after the alias, so it keeps its own leading separator
    // (e.g. " WHERE ...").
    let rest = &alias_text[alias_len..];

    let (join_type, on_clause) = if is_outer {
        ("LEFT JOIN LATERAL", " ON TRUE")
    } else {
        ("CROSS JOIN LATERAL", "")
    };

    Some(format!(
        "{before} {join_type} {func_call} AS {alias_part}{on_clause}{rest}"
    ))
}

/// Byte offset of the first `AS` keyword in `text` (ASCII case-insensitive,
/// whitespace on both sides) that is outside parentheses and outside quoted
/// sections (`'...'`, `"..."`, `` `...` ``), or `None` if there is none.
fn find_top_level_as(text: &str) -> Option<usize> {
    let bytes = text.as_bytes();
    let mut depth = 0usize;
    let mut quote: Option<u8> = None;
    let mut i = 0;
    while let Some(&b) = bytes.get(i) {
        match quote {
            // Inside string literals a backslash escapes the next character.
            Some(q) if b == b'\\' && q != b'`' => i += 1,
            Some(q) if b == q => quote = None,
            Some(_) => {}
            None => match b {
                b'\'' | b'"' | b'`' => quote = Some(b),
                b'(' => depth += 1,
                b')' => depth = depth.saturating_sub(1),
                b'A' | b'a'
                    if depth == 0
                        && i > 0
                        && bytes[i - 1].is_ascii_whitespace()
                        && bytes.get(i + 1).is_some_and(|c| c.eq_ignore_ascii_case(&b's'))
                        && bytes.get(i + 2).is_some_and(u8::is_ascii_whitespace) =>
                {
                    return Some(i);
                }
                _ => {}
            },
        }
        i += 1;
    }
    None
}

/// Byte length of the alias at the start of `text`, including any leading whitespace.
///
/// An alias is an identifier (`tag`, or a `` `quoted` `` / `"quoted"` name),
/// optionally followed by a parenthesised column list (`tag(col1, col2)`).
/// Whitespace after the identifier is counted only when a column list follows,
/// so the text after the alias keeps its separator from the next token.
/// Returns 0 when `text` does not start with an alias.
fn find_alias_length(text: &str) -> usize {
    let bytes = text.as_bytes();
    let skip_whitespace = |mut i: usize| {
        while bytes.get(i).is_some_and(u8::is_ascii_whitespace) {
            i += 1;
        }
        i
    };

    let start = skip_whitespace(0);
    let name_end = match bytes.get(start) {
        Some(&q) if q == b'`' || q == b'"' => {
            // Quoted identifier: runs through the matching closing quote.
            match bytes[start + 1..].iter().position(|&b| b == q) {
                Some(offset) => start + offset + 2,
                None => return 0,
            }
        }
        _ => {
            let len = bytes[start..]
                .iter()
                .take_while(|b| b.is_ascii_alphanumeric() || **b == b'_')
                .count();
            if len == 0 {
                return 0;
            }
            start + len
        }
    };

    // Optional column list: `name(a, b)` or `name (a, b)`.
    let paren = skip_whitespace(name_end);
    if bytes.get(paren) != Some(&b'(') {
        return name_end;
    }
    let mut depth = 0usize;
    for (offset, &b) in bytes[paren..].iter().enumerate() {
        match b {
            b'(' => depth += 1,
            b')' => {
                depth -= 1;
                if depth == 0 {
                    return paren + offset + 1;
                }
            }
            _ => {}
        }
    }
    // Unbalanced parentheses: treat the rest of the text as the column list.
    bytes.len()
}
159
/// Find the byte offset of the first occurrence of `keyword` in `sql` that
/// stands as a separate word, matching ASCII case-insensitively.
///
/// A match must be preceded by the start of input, ASCII whitespace, or `,`,
/// and followed by the end of input, ASCII whitespace, or `(`. Returns `None`
/// when there is no such occurrence.
fn find_keyword_boundary(sql: &str, keyword: &str) -> Option<usize> {
    // ASCII-only uppercasing keeps byte offsets identical to `sql`, so offsets
    // found in `upper` can be used to slice `sql`. Full Unicode uppercasing can
    // change byte lengths (e.g. 'ı' -> 'I') and misalign them.
    let upper = sql.to_ascii_uppercase();
    let keyword_upper = keyword.to_ascii_uppercase();
    let bytes = sql.as_bytes();

    let mut search_start = 0;
    while let Some(rel) = upper
        .get(search_start..)
        .and_then(|tail| tail.find(&keyword_upper))
    {
        let abs_pos = search_start + rel;
        let after_pos = abs_pos + keyword_upper.len();

        let before_ok = abs_pos == 0
            || bytes
                .get(abs_pos - 1)
                .is_some_and(|&b| b.is_ascii_whitespace() || b == b',');
        let after_ok = after_pos >= bytes.len()
            || bytes
                .get(after_pos)
                .is_some_and(|&b| b.is_ascii_whitespace() || b == b'(');

        if before_ok && after_ok {
            return Some(abs_pos);
        }
        // Advance by one whole character so overlapping candidates are still
        // examined and the next slice starts on a char boundary.
        search_start = abs_pos + upper[abs_pos..].chars().next().map_or(1, char::len_utf8);
    }
    None
}
188
189// ── TABLESAMPLE ──────────────────────────────────────────────────────────────
190
/// Returns `true` when `sql` mentions `TABLESAMPLE` in any letter case.
///
/// This is a plain substring test, so identifiers that merely contain the
/// word also match.
pub fn contains_tablesample(sql: &str) -> bool {
    const KEYWORD: &str = "TABLESAMPLE";
    let normalized = sql.to_uppercase();
    normalized.contains(KEYWORD)
}
195
196/// Rewrites Spark `TABLESAMPLE(n PERCENT)` to DataFusion-compatible form.
197///
198/// ```sql
199/// -- Input
200/// SELECT * FROM t TABLESAMPLE (10 PERCENT)
201///
202/// -- Output
203/// SELECT * FROM t TABLESAMPLE (10 PERCENT)
204/// ```
205///
206/// DataFusion supports TABLESAMPLE natively (since v38), so this is mostly
207/// a passthrough with validation.
208pub fn rewrite_tablesample(sql: &str) -> SqlResult<String> {
209    if !contains_tablesample(sql) {
210        return Ok(sql.to_string());
211    }
212
213    let upper = sql.to_uppercase();
214
215    // Validate TABLESAMPLE syntax: TABLESAMPLE (n PERCENT) or TABLESAMPLE (n ROWS)
216    if let Some(pos) = upper.find("TABLESAMPLE") {
217        let after = sql[pos + "TABLESAMPLE".len()..].trim_start();
218        if !after.starts_with('(') {
219            return Err(SqlError::DataFusion {
220                message: "TABLESAMPLE requires parentheses: TABLESAMPLE (n PERCENT)".into(),
221            });
222        }
223        if let Some(close) = after.find(')') {
224            let inner = after[1..close].trim().to_uppercase();
225            if inner.ends_with("PERCENT") || inner.ends_with("ROWS") || inner.ends_with("BUCKET") {
226                return Ok(sql.to_string());
227            }
228            // Try numeric-only (implicit PERCENT for Spark compat)
229            if inner.parse::<f64>().is_ok() {
230                return Ok(sql.to_string());
231            }
232            return Err(SqlError::DataFusion {
233                message: format!("TABLESAMPLE requires PERCENT, ROWS, or BUCKET: got '{inner}'"),
234            });
235        }
236    }
237
238    Ok(sql.to_string())
239}
240
241// ── TRANSFORM ────────────────────────────────────────────────────────────────
242
/// Returns `true` when `sql` contains `TRANSFORM(` or `TRANSFORM (` in any
/// letter case.
pub fn contains_transform(sql: &str) -> bool {
    let upper = sql.to_uppercase();
    ["TRANSFORM(", "TRANSFORM ("]
        .iter()
        .any(|needle| upper.contains(needle))
}
247
248/// Rewrites Spark `TRANSFORM(...)` to standard SQL.
249///
250/// Spark's `TRANSFORM` is an alias for `SELECT TRANSFORM(...)`. This rewrites
251/// it to a DataFusion-compatible form.
252pub fn rewrite_transform(sql: &str) -> SqlResult<String> {
253    // TRANSFORM is complex and Spark-specific; for now pass through with a note
254    Ok(sql.to_string())
255}
256
257// ── DESCRIBE TABLE EXTENDED ─────────────────────────────────────────────────
258
/// Returns `true` when `sql` is a `DESC[RIBE] [TABLE] EXTENDED ...` statement.
///
/// Keywords are matched ASCII case-insensitively against the leading
/// whitespace-separated words. A query that merely mentions `DESC`, `TABLE`
/// and `EXTENDED` somewhere (e.g. `SELECT extended FROM my_table ORDER BY id
/// DESC`) is therefore not detected.
pub fn contains_describe_extended(sql: &str) -> bool {
    fn is_keyword(word: Option<&str>, keyword: &str) -> bool {
        word.is_some_and(|w| w.eq_ignore_ascii_case(keyword))
    }

    let mut words = sql.split_whitespace();

    let first = words.next();
    if !(is_keyword(first, "DESCRIBE") || is_keyword(first, "DESC")) {
        return false;
    }

    // `TABLE` is optional between DESCRIBE and EXTENDED.
    let second = words.next();
    if is_keyword(second, "TABLE") {
        is_keyword(words.next(), "EXTENDED")
    } else {
        is_keyword(second, "EXTENDED")
    }
}
266
267/// Rewrites `DESCRIBE TABLE EXTENDED <table>` to standard `DESCRIBE TABLE <table>`.
268///
269/// DataFusion doesn't support the `EXTENDED` keyword; we strip it and let
270/// the basic DESCRIBE pass through. Extended metadata (partition info, etc.)
271/// is a follow-up.
272pub fn rewrite_describe_extended(sql: &str) -> SqlResult<String> {
273    if !contains_describe_extended(sql) {
274        return Ok(sql.to_string());
275    }
276
277    // Remove EXTENDED keyword
278    let result = regex_replace(sql, r"(?i)\bEXTENDED\b\s*", "")?;
279    Ok(result.trim().to_string())
280}
281
282// ── SHOW TABLE PROPERTIES ────────────────────────────────────────────────────
283
/// Returns `true` when `sql` contains the phrase `SHOW TBLPROPERTIES`
/// (single space, any letter case) anywhere in the text.
pub fn contains_show_tblproperties(sql: &str) -> bool {
    const PHRASE: &str = "SHOW TBLPROPERTIES";
    let upper = sql.to_uppercase();
    upper.contains(PHRASE)
}
288
289/// Rewrites `SHOW TBLPROPERTIES <table>` to a query against the catalog.
290pub fn rewrite_show_tblproperties(sql: &str) -> SqlResult<String> {
291    if !contains_show_tblproperties(sql) {
292        return Ok(sql.to_string());
293    }
294
295    let upper = sql.to_uppercase();
296    // Extract table name after SHOW TBLPROPERTIES
297    if let Some(pos) = upper.find("SHOW TBLPROPERTIES") {
298        let after = sql[pos + "SHOW TBLPROPERTIES".len()..].trim_start();
299        // Remove trailing semicolon
300        let table_name = after.trim_end_matches(';').trim();
301        if table_name.is_empty() {
302            return Err(SqlError::DataFusion {
303                message: "SHOW TBLPROPERTIES requires a table name".into(),
304            });
305        }
306        // Rewrite to a standard query against table_properties metadata
307        return Ok(format!(
308            "SELECT key, value FROM information_schema.table_properties WHERE table_name = '{table_name}'"
309        ));
310    }
311
312    Ok(sql.to_string())
313}
314
315// ── Utility ──────────────────────────────────────────────────────────────────
316
317/// Simple regex-like replacement for single patterns.
318fn regex_replace(input: &str, pattern: &str, replacement: &str) -> SqlResult<String> {
319    // Simple case-insensitive replacement (no regex crate needed)
320    let _ = replacement;
321
322    // For simple patterns without wildcards, just do string replacement
323    if pattern == r"(?i)\bEXTENDED\b\s*" {
324        // Remove EXTENDED and surrounding whitespace
325        let mut result = input.to_string();
326        while let Some(pos) = result.to_uppercase().find("EXTENDED") {
327            // Check word boundaries
328            let bytes = result.as_bytes();
329            let before_ok =
330                pos == 0 || bytes.get(pos - 1).is_some_and(|&b| b == b' ' || b == b'\t');
331            let after_pos = pos + "EXTENDED".len();
332            let after_ok = after_pos >= result.len()
333                || bytes
334                    .get(after_pos)
335                    .is_some_and(|&b| b == b' ' || b == b'\t' || b == b'\n');
336
337            if before_ok && after_ok {
338                // Remove EXTENDED plus trailing space
339                let end = if bytes.get(after_pos).is_some_and(|&b| b == b' ') {
340                    after_pos + 1
341                } else {
342                    after_pos
343                };
344                result = format!("{}{}", &result[..pos], &result[end..]);
345            } else {
346                break;
347            }
348        }
349        return Ok(result);
350    }
351
352    Ok(input.to_string())
353}
354
355// ── Unified Pre-Processor ────────────────────────────────────────────────────
356
357/// Apply all Spark SQL pre-processing rewrites to a SQL string.
358pub fn preprocess_spark_sql(sql: &str) -> SqlResult<String> {
359    let mut result = sql.to_string();
360
361    // Order: LATERAL VIEW (most complex), then others
362    result = rewrite_lateral_view(&result)?;
363    result = rewrite_tablesample(&result)?;
364    result = rewrite_transform(&result)?;
365    result = rewrite_describe_extended(&result)?;
366    result = rewrite_show_tblproperties(&result)?;
367
368    Ok(result)
369}
370
#[cfg(test)]
mod tests {
    use super::*;

    /// Collapses every run of whitespace to a single space, so assertions do
    /// not depend on exact spacing in rewritten SQL.
    fn normalize(sql: &str) -> String {
        sql.split_whitespace().collect::<Vec<_>>().join(" ")
    }

    // ── LATERAL VIEW tests ────────────────────────────────────────────────

    #[test]
    fn lateral_view_basic() {
        let sql = "SELECT id, val FROM t LATERAL VIEW explode(tags) AS tag";
        let result = rewrite_lateral_view(sql).unwrap();
        assert_eq!(
            normalize(&result),
            "SELECT id, val FROM t CROSS JOIN LATERAL explode(tags) AS tag"
        );
        assert!(!result.contains("LATERAL VIEW"));
    }

    #[test]
    fn lateral_view_outer() {
        let sql = "SELECT id, val FROM t LATERAL VIEW OUTER explode(tags) AS tag";
        let result = rewrite_lateral_view(sql).unwrap();
        assert_eq!(
            normalize(&result),
            "SELECT id, val FROM t LEFT JOIN LATERAL explode(tags) AS tag ON TRUE"
        );
        assert!(!result.contains("LATERAL VIEW"));
    }

    #[test]
    fn lateral_view_with_column_list() {
        // Only the first name after AS is taken as the alias; the rest of the
        // list is left in place after it.
        let sql = "SELECT id, val FROM t LATERAL VIEW posexplode(arr) AS pos, val";
        let result = rewrite_lateral_view(sql).unwrap();
        assert_eq!(
            normalize(&result),
            "SELECT id, val FROM t CROSS JOIN LATERAL posexplode(arr) AS pos, val"
        );
    }

    #[test]
    fn lateral_view_no_change_when_absent() {
        let sql = "SELECT * FROM t WHERE id = 1";
        let result = rewrite_lateral_view(sql).unwrap();
        assert_eq!(result, sql);
    }

    #[test]
    fn contains_lateral_view_detection() {
        assert!(contains_lateral_view(
            "SELECT * FROM t LATERAL VIEW explode(a) AS x"
        ));
        assert!(contains_lateral_view(
            "SELECT * FROM t LATERAL VIEW OUTER explode(a) AS x"
        ));
        assert!(contains_lateral_view(
            "select * from t lateral view explode(a) as x"
        ));
        assert!(!contains_lateral_view("SELECT * FROM t"));
    }

    // ── TABLESAMPLE tests ─────────────────────────────────────────────────

    #[test]
    fn tablesample_passthrough() {
        let sql = "SELECT * FROM t TABLESAMPLE (10 PERCENT)";
        let result = rewrite_tablesample(sql).unwrap();
        assert_eq!(result, sql);
    }

    #[test]
    fn tablesample_rows() {
        let sql = "SELECT * FROM t TABLESAMPLE (100 ROWS)";
        let result = rewrite_tablesample(sql).unwrap();
        assert_eq!(result, sql);
    }

    #[test]
    fn tablesample_implicit_percent() {
        let sql = "SELECT * FROM t TABLESAMPLE (5)";
        let result = rewrite_tablesample(sql).unwrap();
        assert_eq!(result, sql);
    }

    #[test]
    fn tablesample_no_parens_errors() {
        let sql = "SELECT * FROM t TABLESAMPLE 10 PERCENT";
        let result = rewrite_tablesample(sql);
        assert!(result.is_err());
    }

    #[test]
    fn tablesample_unknown_unit_errors() {
        let sql = "SELECT * FROM t TABLESAMPLE (10 FOO)";
        let result = rewrite_tablesample(sql);
        assert!(result.is_err());
    }

    #[test]
    fn contains_tablesample_detection() {
        assert!(contains_tablesample(
            "SELECT * FROM t TABLESAMPLE (10 PERCENT)"
        ));
        assert!(!contains_tablesample("SELECT * FROM t"));
    }

    // ── DESCRIBE EXTENDED tests ───────────────────────────────────────────

    #[test]
    fn describe_extended_rewrite() {
        let sql = "DESCRIBE TABLE EXTENDED my_table";
        let result = rewrite_describe_extended(sql).unwrap();
        assert_eq!(result, "DESCRIBE TABLE my_table");
    }

    #[test]
    fn describe_extended_case_insensitive() {
        let sql = "desc table extended my_table";
        let result = rewrite_describe_extended(sql).unwrap();
        assert_eq!(result, "desc table my_table");
    }

    #[test]
    fn describe_without_extended_unchanged() {
        let sql = "DESCRIBE TABLE t";
        let result = rewrite_describe_extended(sql).unwrap();
        assert_eq!(result, sql);
    }

    #[test]
    fn contains_describe_extended_detection() {
        assert!(contains_describe_extended("DESCRIBE TABLE EXTENDED t"));
        assert!(contains_describe_extended("desc table extended t"));
        assert!(!contains_describe_extended("DESCRIBE TABLE t"));
    }

    // ── SHOW TBLPROPERTIES tests ──────────────────────────────────────────

    #[test]
    fn show_tblproperties_rewrite() {
        let sql = "SHOW TBLPROPERTIES my_table";
        let result = rewrite_show_tblproperties(sql).unwrap();
        assert_eq!(
            result,
            "SELECT key, value FROM information_schema.table_properties WHERE table_name = 'my_table'"
        );
    }

    #[test]
    fn show_tblproperties_with_semicolon() {
        let sql = "SHOW TBLPROPERTIES my_table;";
        let result = rewrite_show_tblproperties(sql).unwrap();
        // The trailing semicolon must not end up inside the quoted name.
        assert!(result.contains("'my_table'"));
    }

    #[test]
    fn show_tblproperties_empty_errors() {
        let sql = "SHOW TBLPROPERTIES";
        let result = rewrite_show_tblproperties(sql);
        assert!(result.is_err());
    }

    // ── Unified pre-processor tests ───────────────────────────────────────

    #[test]
    fn preprocess_spark_sql_lateral_view() {
        let sql = "SELECT id, val FROM t LATERAL VIEW explode(tags) AS tag";
        let result = preprocess_spark_sql(sql).unwrap();
        assert!(result.contains("CROSS JOIN LATERAL"));
    }

    #[test]
    fn preprocess_spark_sql_describe_extended() {
        let result = preprocess_spark_sql("DESCRIBE TABLE EXTENDED t").unwrap();
        assert_eq!(result, "DESCRIBE TABLE t");
    }

    #[test]
    fn preprocess_spark_sql_passthrough() {
        let sql = "SELECT 1 + 1";
        let result = preprocess_spark_sql(sql).unwrap();
        assert_eq!(result, sql);
    }
}