Skip to main content

krishiv_sql/
pivot_sql.rs

1//! E5.4 — PIVOT / UNPIVOT SQL macro rewrite layer.
2//!
3//! DataFusion does not parse `PIVOT` or `UNPIVOT` natively. This module rewrites
4//! those constructs into equivalent standard SQL before passing the query to DataFusion.
5//!
6//! # PIVOT rewrite
7//!
8//! ```sql
9//! SELECT * FROM sales
10//! PIVOT (SUM(amount) FOR category IN ('food', 'tech', 'clothing'))
11//! ```
//! becomes (no `GROUP BY` is generated, so all source rows collapse into one aggregated row):
13//! ```sql
14//! SELECT
15//!   SUM(CASE WHEN category = 'food' THEN amount END) AS "food",
16//!   SUM(CASE WHEN category = 'tech' THEN amount END) AS "tech",
17//!   SUM(CASE WHEN category = 'clothing' THEN amount END) AS "clothing"
18//! FROM sales
19//! ```
20//!
21//! # UNPIVOT rewrite
22//!
23//! ```sql
24//! SELECT * FROM monthly
25//! UNPIVOT (value FOR month IN (jan, feb, mar))
26//! ```
//! becomes a `UNION ALL` of one `SELECT` per listed column (identifiers are
//! double-quoted in the generated SQL):
//! ```sql
//! SELECT 'jan' AS "month", "jan" AS "value" FROM monthly
//! UNION ALL
//! SELECT 'feb' AS "month", "feb" AS "value" FROM monthly
//! UNION ALL
//! SELECT 'mar' AS "month", "mar" AS "value" FROM monthly
//! ```
35
36use crate::{SqlError, SqlResult};
37
38// ── Detection ─────────────────────────────────────────────────────────────────
39
/// Returns `true` if `sql` contains a `PIVOT` clause (case-insensitive).
///
/// Matches the two spellings ` PIVOT (` and ` PIVOT(`. `UNPIVOT` never matches
/// because the keyword must be preceded by a space.
pub fn contains_pivot(sql: &str) -> bool {
    // Uppercase once instead of allocating a fresh copy for each pattern.
    let upper = sql.to_ascii_uppercase();
    upper.contains(" PIVOT (") || upper.contains(" PIVOT(")
}
44
/// Returns `true` if `sql` contains an `UNPIVOT` clause (case-insensitive).
///
/// Matches the two spellings ` UNPIVOT (` and ` UNPIVOT(`.
pub fn contains_unpivot(sql: &str) -> bool {
    // Uppercase once instead of allocating a fresh copy for each pattern.
    let upper = sql.to_ascii_uppercase();
    upper.contains(" UNPIVOT (") || upper.contains(" UNPIVOT(")
}
50
51// ── PIVOT rewrite ─────────────────────────────────────────────────────────────
52
/// The pieces of a `PIVOT (AGG(col) FOR dim IN (v1, ...))` clause, as
/// extracted from the SQL text.
#[derive(Clone, Debug)]
pub struct PivotClause {
    /// Name of the aggregate, e.g. `"SUM"`, `"COUNT"` or `"MAX"`.
    pub agg_fn: String,
    /// Argument passed to the aggregate, e.g. `"amount"`.
    pub agg_column: String,
    /// Column whose values select the output column, e.g. `"category"`.
    pub for_column: String,
    /// The `IN (...)` entries, kept verbatim (including any quotes).
    pub in_values: Vec<String>,
    /// Everything in the statement that precedes the `PIVOT` keyword.
    pub source: String,
}
67
68/// Parse a simple `SELECT * FROM <source> PIVOT (<agg>(<col>) FOR <dim> IN (<vals>))` statement.
69///
70/// Returns `Ok(None)` when the SQL does not contain a PIVOT clause.
71pub fn parse_pivot(sql: &str) -> SqlResult<Option<PivotClause>> {
72    let upper = sql.to_ascii_uppercase();
73    let pivot_kw = " PIVOT (";
74    let pivot_pos = match upper.find(pivot_kw) {
75        Some(p) => p,
76        None => {
77            // Try without space before paren.
78            match upper.find(" PIVOT(") {
79                Some(p) => p,
80                None => return Ok(None),
81            }
82        }
83    };
84
85    let source = sql[..pivot_pos].trim().to_owned();
86
87    // Find the matching closing paren.
88    let body_start = pivot_pos + pivot_kw.len();
89    let body_end = find_closing_paren(&sql[body_start..]).ok_or_else(|| SqlError::Unsupported {
90        feature: "PIVOT: unmatched parenthesis".into(),
91    })? + body_start;
92
93    let body = sql[body_start..body_end].trim();
94    let body_upper = body.to_ascii_uppercase();
95
96    // Parse: AGG(col) FOR dim IN (v1, v2, ...)
97    let for_pos = body_upper
98        .find(" FOR ")
99        .ok_or_else(|| SqlError::Unsupported {
100            feature: "PIVOT: missing FOR keyword".into(),
101        })?;
102    let in_pos = body_upper
103        .find(" IN (")
104        .ok_or_else(|| SqlError::Unsupported {
105            feature: "PIVOT: missing IN keyword".into(),
106        })?;
107
108    let agg_expr = body[..for_pos].trim();
109    let for_column = body[for_pos + 5..in_pos].trim().to_owned();
110
111    // Parse AGG(col)
112    let lp = agg_expr.find('(').ok_or_else(|| SqlError::Unsupported {
113        feature: "PIVOT: aggregation must be in the form AGG(column)".into(),
114    })?;
115    let rp = agg_expr.rfind(')').ok_or_else(|| SqlError::Unsupported {
116        feature: "PIVOT: aggregation must end with ')'".into(),
117    })?;
118    let agg_fn = agg_expr[..lp].trim().to_owned();
119    let agg_column = agg_expr[lp + 1..rp].trim().to_owned();
120
121    // Parse IN (v1, v2, ...)
122    let in_list_start = in_pos + 5;
123    let in_list_end = body[in_list_start..]
124        .find(')')
125        .ok_or_else(|| SqlError::Unsupported {
126            feature: "PIVOT: IN list is not closed".into(),
127        })?
128        + in_list_start;
129    let in_list = &body[in_list_start..in_list_end];
130
131    let in_values: Vec<String> = in_list
132        .split(',')
133        .map(|v| v.trim().to_owned())
134        .filter(|v| !v.is_empty())
135        .collect();
136
137    if in_values.is_empty() {
138        return Err(SqlError::Unsupported {
139            feature: "PIVOT: IN list must contain at least one value".into(),
140        });
141    }
142
143    Ok(Some(PivotClause {
144        agg_fn,
145        agg_column,
146        for_column,
147        in_values,
148        source,
149    }))
150}
151
152/// Rewrite a PIVOT statement to equivalent `CASE WHEN` SQL.
153///
154/// Returns the original `sql` unchanged when no PIVOT clause is found.
155pub fn rewrite_pivot(sql: &str) -> SqlResult<String> {
156    let Some(pivot) = parse_pivot(sql)? else {
157        return Ok(sql.to_owned());
158    };
159
160    let mut cols = Vec::with_capacity(pivot.in_values.len());
161    for val in &pivot.in_values {
162        // Strip surrounding quotes from the alias name for readability.
163        let alias = val.trim_matches('\'').trim_matches('"');
164        cols.push(format!(
165            "{}(CASE WHEN {} = {} THEN {} END) AS \"{}\"",
166            pivot.agg_fn, pivot.for_column, val, pivot.agg_column, alias,
167        ));
168    }
169
170    // Strip the leading SELECT ... FROM from source to get just the FROM clause.
171    let from_clause = strip_select_star_prefix(&pivot.source);
172
173    Ok(format!("SELECT {} FROM {}", cols.join(", "), from_clause))
174}
175
176// ── UNPIVOT rewrite ───────────────────────────────────────────────────────────
177
/// The pieces of an `UNPIVOT (val FOR name IN (c1, ...))` clause, as
/// extracted from the SQL text.
#[derive(Clone, Debug)]
pub struct UnpivotClause {
    /// Name of the output column that holds each unpivoted value.
    pub value_column: String,
    /// Name of the output column that holds the originating column's name.
    pub name_column: String,
    /// The `IN (...)` column list, kept verbatim.
    pub in_columns: Vec<String>,
    /// Everything in the statement that precedes the `UNPIVOT` keyword.
    pub source: String,
}
190
191/// Parse a simple `SELECT * FROM <source> UNPIVOT (<val_col> FOR <name_col> IN (<cols>))`.
192///
193/// Returns `Ok(None)` when the SQL does not contain an UNPIVOT clause.
194pub fn parse_unpivot(sql: &str) -> SqlResult<Option<UnpivotClause>> {
195    let upper = sql.to_ascii_uppercase();
196    let kw = " UNPIVOT (";
197    let kw_short = " UNPIVOT(";
198    let unpivot_pos = match upper.find(kw) {
199        Some(p) => p,
200        None => match upper.find(kw_short) {
201            Some(p) => p,
202            None => return Ok(None),
203        },
204    };
205
206    let source = sql[..unpivot_pos].trim().to_owned();
207    let body_start = unpivot_pos
208        + sql[unpivot_pos..]
209            .find('(')
210            .ok_or_else(|| SqlError::Unsupported {
211                feature: "UNPIVOT: missing opening parenthesis".into(),
212            })?
213        + 1;
214    let body_end = find_closing_paren(&sql[body_start..]).ok_or_else(|| SqlError::Unsupported {
215        feature: "UNPIVOT: unmatched parenthesis".into(),
216    })? + body_start;
217    let body = sql[body_start..body_end].trim();
218    let body_upper = body.to_ascii_uppercase();
219
220    let for_pos = body_upper
221        .find(" FOR ")
222        .ok_or_else(|| SqlError::Unsupported {
223            feature: "UNPIVOT: missing FOR keyword".into(),
224        })?;
225    let in_pos = body_upper
226        .find(" IN (")
227        .ok_or_else(|| SqlError::Unsupported {
228            feature: "UNPIVOT: missing IN keyword".into(),
229        })?;
230
231    let value_column = body[..for_pos].trim().to_owned();
232    let name_column = body[for_pos + 5..in_pos].trim().to_owned();
233
234    let in_list_start = in_pos + 5;
235    let in_list_end = body[in_list_start..]
236        .find(')')
237        .ok_or_else(|| SqlError::Unsupported {
238            feature: "UNPIVOT: IN list is not closed".into(),
239        })?
240        + in_list_start;
241    let in_list = &body[in_list_start..in_list_end];
242
243    let in_columns: Vec<String> = in_list
244        .split(',')
245        .map(|v| v.trim().to_owned())
246        .filter(|v| !v.is_empty())
247        .collect();
248
249    if in_columns.is_empty() {
250        return Err(SqlError::Unsupported {
251            feature: "UNPIVOT: IN list must contain at least one column".into(),
252        });
253    }
254
255    Ok(Some(UnpivotClause {
256        value_column,
257        name_column,
258        in_columns,
259        source,
260    }))
261}
262
263/// Rewrite an UNPIVOT statement to a `UNION ALL` of SELECT statements.
264///
265/// Returns the original `sql` unchanged when no UNPIVOT clause is found.
266pub fn rewrite_unpivot(sql: &str) -> SqlResult<String> {
267    let Some(unpivot) = parse_unpivot(sql)? else {
268        return Ok(sql.to_owned());
269    };
270
271    let from_clause = strip_select_star_prefix(&unpivot.source);
272
273    let mut branches = Vec::with_capacity(unpivot.in_columns.len());
274    for col in &unpivot.in_columns {
275        // Double-quote identifiers to handle reserved words and special characters.
276        let col_quoted = col.replace('"', "\"\"");
277        let name_col_quoted = unpivot.name_column.replace('"', "\"\"");
278        let val_col_quoted = unpivot.value_column.replace('"', "\"\"");
279        branches.push(format!(
280            "SELECT '{}' AS \"{}\", \"{}\" AS \"{}\" FROM {}",
281            col.replace('\'', "''"),
282            name_col_quoted,
283            col_quoted,
284            val_col_quoted,
285            from_clause,
286        ));
287    }
288
289    Ok(branches.join(" UNION ALL "))
290}
291
292/// Entry point: rewrite PIVOT or UNPIVOT if present, otherwise return unchanged.
293pub fn rewrite_pivot_unpivot(sql: &str) -> SqlResult<String> {
294    if contains_pivot(sql) {
295        rewrite_pivot(sql)
296    } else if contains_unpivot(sql) {
297        rewrite_unpivot(sql)
298    } else {
299        Ok(sql.to_owned())
300    }
301}
302
303// ── Helpers ───────────────────────────────────────────────────────────────────
304
/// Find the index of the closing `)` matching the first `(` already consumed.
///
/// `s` starts *after* the opening `(`. Returns the byte index of `)` relative
/// to `s`, or `None` when the parenthesis is never closed.
///
/// Parentheses inside single-quoted string literals (`'…'`) and double-quoted
/// identifiers (`"…"`) are ignored. SQL's doubled-quote escape (`''`, `""`)
/// needs no special case: it closes and immediately reopens the quote.
fn find_closing_paren(s: &str) -> Option<usize> {
    let mut depth = 1usize;
    let mut quote: Option<char> = None;
    for (i, ch) in s.char_indices() {
        if let Some(q) = quote {
            if ch == q {
                quote = None;
            }
            continue;
        }
        match ch {
            '\'' | '"' => quote = Some(ch),
            '(' => depth += 1,
            ')' => {
                depth -= 1;
                if depth == 0 {
                    return Some(i);
                }
            }
            _ => {}
        }
    }
    None
}
325
/// Strip a leading `SELECT * FROM ` or `SELECT … FROM ` prefix from the source
/// fragment so the caller can use it directly as a FROM clause.
///
/// Uses the first `FROM` keyword that is at parenthesis depth 0, outside quotes,
/// and surrounded by whitespace of any kind (space, tab, newline).
/// - `SELECT * FROM (SELECT a FROM b) x` yields `(SELECT a FROM b) x`.
/// - `SELECT (SELECT 1 FROM u) AS z FROM t` yields `t`.
///
/// When no such keyword exists, the trimmed input is returned unchanged.
fn strip_select_star_prefix(s: &str) -> &str {
    let bytes = s.as_bytes();
    // `FROM` starting at byte `i`, with whitespace on both sides.
    let is_from_at = |i: usize| {
        i > 0
            && bytes[i - 1].is_ascii_whitespace()
            && bytes.len() > i + 4
            && bytes[i..i + 4].eq_ignore_ascii_case(b"FROM")
            && bytes[i + 4].is_ascii_whitespace()
    };

    let mut depth = 0usize;
    let mut quote: Option<u8> = None;
    for (i, &b) in bytes.iter().enumerate() {
        if let Some(q) = quote {
            if b == q {
                quote = None;
            }
            continue;
        }
        match b {
            b'\'' | b'"' => quote = Some(b),
            b'(' => depth += 1,
            b')' => depth = depth.saturating_sub(1),
            // `i + 4` is a char boundary: the byte there is ASCII whitespace.
            _ if depth == 0 && is_from_at(i) => return s[i + 4..].trim(),
            _ => {}
        }
    }
    s.trim()
}
336
337// ── Tests ─────────────────────────────────────────────────────────────────────
338
#[cfg(test)]
mod tests {
    use super::*;

    /// Case-insensitive containment check; `needle` must already be uppercase.
    fn has_ci(haystack: &str, needle: &str) -> bool {
        haystack.to_ascii_uppercase().contains(needle)
    }

    // ── PIVOT ──────────────────────────────────────────────────────────────────

    #[test]
    fn detects_pivot() {
        let pivot_sql = "SELECT * FROM t PIVOT (SUM(x) FOR y IN ('a'))";
        let plain_sql = "SELECT * FROM t WHERE x = 1";
        assert!(contains_pivot(pivot_sql));
        assert!(!contains_pivot(plain_sql));
    }

    #[test]
    fn parses_pivot() {
        let sql = "SELECT * FROM sales PIVOT (SUM(amount) FOR category IN ('food', 'tech'))";
        let PivotClause {
            agg_fn,
            agg_column,
            for_column,
            in_values,
            ..
        } = parse_pivot(sql).unwrap().unwrap();
        assert_eq!(agg_fn, "SUM");
        assert_eq!(agg_column, "amount");
        assert_eq!(for_column, "category");
        assert_eq!(in_values, vec!["'food'", "'tech'"]);
    }

    #[test]
    fn rewrites_pivot_to_case_when() {
        let sql = "SELECT * FROM sales PIVOT (SUM(amount) FOR category IN ('food', 'tech'))";
        let rewritten = rewrite_pivot(sql).unwrap();
        for needle in &["CASE WHEN", "SUM("] {
            assert!(has_ci(&rewritten, needle));
        }
        for literal in &["'food'", "'tech'"] {
            assert!(rewritten.contains(*literal));
        }
        assert!(!has_ci(&rewritten, "PIVOT"));
    }

    #[test]
    fn pivot_rewrite_generates_correct_aliases() {
        let rewritten =
            rewrite_pivot("SELECT * FROM t PIVOT (MAX(val) FOR dim IN ('x', 'y'))").unwrap();
        for alias in &["\"x\"", "\"y\""] {
            assert!(rewritten.contains(*alias));
        }
    }

    #[test]
    fn returns_unchanged_when_no_pivot() {
        let sql = "SELECT * FROM t WHERE x = 1";
        assert_eq!(rewrite_pivot(sql).unwrap(), sql);
    }

    #[test]
    fn rejects_pivot_without_for() {
        let outcome = parse_pivot("SELECT * FROM t PIVOT (SUM(x) IN ('a'))");
        assert!(matches!(outcome, Err(SqlError::Unsupported { .. })));
    }

    // ── UNPIVOT ────────────────────────────────────────────────────────────────

    #[test]
    fn detects_unpivot() {
        let unpivot_sql = "SELECT * FROM t UNPIVOT (val FOR month IN (jan, feb))";
        let plain_sql = "SELECT * FROM t WHERE x = 1";
        assert!(contains_unpivot(unpivot_sql));
        assert!(!contains_unpivot(plain_sql));
    }

    #[test]
    fn parses_unpivot() {
        let sql = "SELECT * FROM monthly UNPIVOT (value FOR month IN (jan, feb, mar))";
        let UnpivotClause {
            value_column,
            name_column,
            in_columns,
            ..
        } = parse_unpivot(sql).unwrap().unwrap();
        assert_eq!(value_column, "value");
        assert_eq!(name_column, "month");
        assert_eq!(in_columns, vec!["jan", "feb", "mar"]);
    }

    #[test]
    fn rewrites_unpivot_to_union_all() {
        let sql = "SELECT * FROM monthly UNPIVOT (value FOR month IN (jan, feb, mar))";
        let rewritten = rewrite_unpivot(sql).unwrap();
        assert!(has_ci(&rewritten, "UNION ALL"));
        for label in &["'jan'", "'feb'", "'mar'"] {
            assert!(rewritten.contains(*label));
        }
        assert!(!has_ci(&rewritten, "UNPIVOT"));
    }

    #[test]
    fn returns_unchanged_when_no_unpivot() {
        let sql = "SELECT * FROM t";
        assert_eq!(rewrite_unpivot(sql).unwrap(), sql);
    }

    #[test]
    fn rewrite_pivot_unpivot_dispatches_correctly() {
        let pivoted =
            rewrite_pivot_unpivot("SELECT * FROM t PIVOT (SUM(v) FOR k IN ('a', 'b'))").unwrap();
        assert!(has_ci(&pivoted, "CASE WHEN"));

        let unpivoted =
            rewrite_pivot_unpivot("SELECT * FROM t UNPIVOT (val FOR month IN (jan, feb))").unwrap();
        assert!(has_ci(&unpivoted, "UNION ALL"));

        let plain = "SELECT * FROM t";
        assert_eq!(rewrite_pivot_unpivot(plain).unwrap(), plain);
    }
}
447}