1use crate::{SqlError, SqlResult};
37
/// Returns `true` when `sql` appears to contain a `PIVOT (` or `PIVOT(` clause.
///
/// Matching is a plain ASCII case-insensitive substring check that requires a
/// space immediately before `PIVOT`, so `UNPIVOT` is not reported as a PIVOT.
/// The check is textual only: it does not skip string literals or comments.
pub fn contains_pivot(sql: &str) -> bool {
    // Upper-case once and reuse it for both spellings.
    let upper = sql.to_ascii_uppercase();
    upper.contains(" PIVOT (") || upper.contains(" PIVOT(")
}
44
/// Returns `true` when `sql` appears to contain an `UNPIVOT (` or `UNPIVOT(` clause.
///
/// Matching is a plain ASCII case-insensitive substring check that requires a
/// space immediately before `UNPIVOT`. The check is textual only: it does not
/// skip string literals or comments.
pub fn contains_unpivot(sql: &str) -> bool {
    // Upper-case once and reuse it for both spellings.
    let upper = sql.to_ascii_uppercase();
    upper.contains(" UNPIVOT (") || upper.contains(" UNPIVOT(")
}
50
/// A parsed `PIVOT (AGG(agg_column) FOR for_column IN (v1, v2, ...))` clause.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct PivotClause {
    /// Aggregate function name as written before `(`, e.g. `SUM`.
    pub agg_fn: String,
    /// Text between the aggregate's parentheses, e.g. `amount`.
    pub agg_column: String,
    /// Expression compared against each `IN` value; every value yields one output column.
    pub for_column: String,
    /// Raw `IN` list entries, quotes included (e.g. `'food'`).
    pub in_values: Vec<String>,
    /// Query text preceding the `PIVOT` keyword, trimmed.
    pub source: String,
}
67
68pub fn parse_pivot(sql: &str) -> SqlResult<Option<PivotClause>> {
72 let upper = sql.to_ascii_uppercase();
73 let pivot_kw = " PIVOT (";
74 let pivot_pos = match upper.find(pivot_kw) {
75 Some(p) => p,
76 None => {
77 match upper.find(" PIVOT(") {
79 Some(p) => p,
80 None => return Ok(None),
81 }
82 }
83 };
84
85 let source = sql[..pivot_pos].trim().to_owned();
86
87 let body_start = pivot_pos + pivot_kw.len();
89 let body_end = find_closing_paren(&sql[body_start..]).ok_or_else(|| SqlError::Unsupported {
90 feature: "PIVOT: unmatched parenthesis".into(),
91 })? + body_start;
92
93 let body = sql[body_start..body_end].trim();
94 let body_upper = body.to_ascii_uppercase();
95
96 let for_pos = body_upper
98 .find(" FOR ")
99 .ok_or_else(|| SqlError::Unsupported {
100 feature: "PIVOT: missing FOR keyword".into(),
101 })?;
102 let in_pos = body_upper
103 .find(" IN (")
104 .ok_or_else(|| SqlError::Unsupported {
105 feature: "PIVOT: missing IN keyword".into(),
106 })?;
107
108 let agg_expr = body[..for_pos].trim();
109 let for_column = body[for_pos + 5..in_pos].trim().to_owned();
110
111 let lp = agg_expr.find('(').ok_or_else(|| SqlError::Unsupported {
113 feature: "PIVOT: aggregation must be in the form AGG(column)".into(),
114 })?;
115 let rp = agg_expr.rfind(')').ok_or_else(|| SqlError::Unsupported {
116 feature: "PIVOT: aggregation must end with ')'".into(),
117 })?;
118 let agg_fn = agg_expr[..lp].trim().to_owned();
119 let agg_column = agg_expr[lp + 1..rp].trim().to_owned();
120
121 let in_list_start = in_pos + 5;
123 let in_list_end = body[in_list_start..]
124 .find(')')
125 .ok_or_else(|| SqlError::Unsupported {
126 feature: "PIVOT: IN list is not closed".into(),
127 })?
128 + in_list_start;
129 let in_list = &body[in_list_start..in_list_end];
130
131 let in_values: Vec<String> = in_list
132 .split(',')
133 .map(|v| v.trim().to_owned())
134 .filter(|v| !v.is_empty())
135 .collect();
136
137 if in_values.is_empty() {
138 return Err(SqlError::Unsupported {
139 feature: "PIVOT: IN list must contain at least one value".into(),
140 });
141 }
142
143 Ok(Some(PivotClause {
144 agg_fn,
145 agg_column,
146 for_column,
147 in_values,
148 source,
149 }))
150}
151
152pub fn rewrite_pivot(sql: &str) -> SqlResult<String> {
156 let Some(pivot) = parse_pivot(sql)? else {
157 return Ok(sql.to_owned());
158 };
159
160 let mut cols = Vec::with_capacity(pivot.in_values.len());
161 for val in &pivot.in_values {
162 let alias = val.trim_matches('\'').trim_matches('"');
164 cols.push(format!(
165 "{}(CASE WHEN {} = {} THEN {} END) AS \"{}\"",
166 pivot.agg_fn, pivot.for_column, val, pivot.agg_column, alias,
167 ));
168 }
169
170 let from_clause = strip_select_star_prefix(&pivot.source);
172
173 Ok(format!("SELECT {} FROM {}", cols.join(", "), from_clause))
174}
175
/// A parsed `UNPIVOT (value_column FOR name_column IN (c1, c2, ...))` clause.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct UnpivotClause {
    /// Output column that receives each unpivoted value.
    pub value_column: String,
    /// Output column that receives the name of the source column.
    pub name_column: String,
    /// Source columns listed in the `IN` list, as written.
    pub in_columns: Vec<String>,
    /// Query text preceding the `UNPIVOT` keyword, trimmed.
    pub source: String,
}
190
191pub fn parse_unpivot(sql: &str) -> SqlResult<Option<UnpivotClause>> {
195 let upper = sql.to_ascii_uppercase();
196 let kw = " UNPIVOT (";
197 let kw_short = " UNPIVOT(";
198 let unpivot_pos = match upper.find(kw) {
199 Some(p) => p,
200 None => match upper.find(kw_short) {
201 Some(p) => p,
202 None => return Ok(None),
203 },
204 };
205
206 let source = sql[..unpivot_pos].trim().to_owned();
207 let body_start = unpivot_pos
208 + sql[unpivot_pos..]
209 .find('(')
210 .ok_or_else(|| SqlError::Unsupported {
211 feature: "UNPIVOT: missing opening parenthesis".into(),
212 })?
213 + 1;
214 let body_end = find_closing_paren(&sql[body_start..]).ok_or_else(|| SqlError::Unsupported {
215 feature: "UNPIVOT: unmatched parenthesis".into(),
216 })? + body_start;
217 let body = sql[body_start..body_end].trim();
218 let body_upper = body.to_ascii_uppercase();
219
220 let for_pos = body_upper
221 .find(" FOR ")
222 .ok_or_else(|| SqlError::Unsupported {
223 feature: "UNPIVOT: missing FOR keyword".into(),
224 })?;
225 let in_pos = body_upper
226 .find(" IN (")
227 .ok_or_else(|| SqlError::Unsupported {
228 feature: "UNPIVOT: missing IN keyword".into(),
229 })?;
230
231 let value_column = body[..for_pos].trim().to_owned();
232 let name_column = body[for_pos + 5..in_pos].trim().to_owned();
233
234 let in_list_start = in_pos + 5;
235 let in_list_end = body[in_list_start..]
236 .find(')')
237 .ok_or_else(|| SqlError::Unsupported {
238 feature: "UNPIVOT: IN list is not closed".into(),
239 })?
240 + in_list_start;
241 let in_list = &body[in_list_start..in_list_end];
242
243 let in_columns: Vec<String> = in_list
244 .split(',')
245 .map(|v| v.trim().to_owned())
246 .filter(|v| !v.is_empty())
247 .collect();
248
249 if in_columns.is_empty() {
250 return Err(SqlError::Unsupported {
251 feature: "UNPIVOT: IN list must contain at least one column".into(),
252 });
253 }
254
255 Ok(Some(UnpivotClause {
256 value_column,
257 name_column,
258 in_columns,
259 source,
260 }))
261}
262
263pub fn rewrite_unpivot(sql: &str) -> SqlResult<String> {
267 let Some(unpivot) = parse_unpivot(sql)? else {
268 return Ok(sql.to_owned());
269 };
270
271 let from_clause = strip_select_star_prefix(&unpivot.source);
272
273 let mut branches = Vec::with_capacity(unpivot.in_columns.len());
274 for col in &unpivot.in_columns {
275 let col_quoted = col.replace('"', "\"\"");
277 let name_col_quoted = unpivot.name_column.replace('"', "\"\"");
278 let val_col_quoted = unpivot.value_column.replace('"', "\"\"");
279 branches.push(format!(
280 "SELECT '{}' AS \"{}\", \"{}\" AS \"{}\" FROM {}",
281 col.replace('\'', "''"),
282 name_col_quoted,
283 col_quoted,
284 val_col_quoted,
285 from_clause,
286 ));
287 }
288
289 Ok(branches.join(" UNION ALL "))
290}
291
292pub fn rewrite_pivot_unpivot(sql: &str) -> SqlResult<String> {
294 if contains_pivot(sql) {
295 rewrite_pivot(sql)
296 } else if contains_unpivot(sql) {
297 rewrite_unpivot(sql)
298 } else {
299 Ok(sql.to_owned())
300 }
301}
302
/// Returns the byte index of the `)` that closes an already-opened `(`.
///
/// `s` must start just *after* the opening parenthesis, so scanning begins at
/// depth 1. Parentheses inside single- or double-quoted literals are ignored.
/// SQL's doubled-quote escape (`'it''s'`) works naturally, because each quote
/// toggles the quoted state. Returns `None` if the parenthesis is never closed.
fn find_closing_paren(s: &str) -> Option<usize> {
    let mut depth = 1usize;
    let mut quote: Option<char> = None;
    for (i, ch) in s.char_indices() {
        match (quote, ch) {
            (Some(q), c) if c == q => quote = None,
            (Some(_), _) => {}
            (None, '\'' | '"') => quote = Some(ch),
            (None, '(') => depth += 1,
            (None, ')') => {
                depth -= 1;
                if depth == 0 {
                    return Some(i);
                }
            }
            (None, _) => {}
        }
    }
    None
}
325
/// Returns the trimmed text after the first top-level `FROM` keyword in `s`.
///
/// `s` is the query text preceding a PIVOT/UNPIVOT keyword (e.g.
/// `SELECT * FROM sales`); the result is the FROM target (`sales`). The keyword
/// is matched ASCII case-insensitively and must be delimited by whitespace on
/// both sides. It is ignored inside quotes and parentheses, so subqueries such
/// as `SELECT * FROM (SELECT a FROM b) x` keep their inner `FROM`. If no such
/// keyword exists, the trimmed input is returned.
fn strip_select_star_prefix(s: &str) -> &str {
    /// True when `bytes[i..]` starts a whitespace-delimited `FROM` keyword.
    fn is_from_keyword(bytes: &[u8], i: usize) -> bool {
        i > 0
            && bytes[i - 1].is_ascii_whitespace()
            && bytes
                .get(i..i + 4)
                .map_or(false, |word| word.eq_ignore_ascii_case(b"FROM"))
            && matches!(bytes.get(i + 4), Some(b) if b.is_ascii_whitespace())
    }

    // Byte-level scan: every byte we react to is ASCII, and UTF-8 continuation
    // bytes never equal them, so the slice index below is a char boundary.
    let bytes = s.as_bytes();
    let mut depth = 0usize;
    let mut quote: Option<u8> = None;
    for (i, &b) in bytes.iter().enumerate() {
        if let Some(q) = quote {
            if b == q {
                quote = None;
            }
            continue;
        }
        match b {
            b'\'' | b'"' => quote = Some(b),
            b'(' => depth += 1,
            b')' => depth = depth.saturating_sub(1),
            _ if depth == 0 && is_from_keyword(bytes, i) => return s[i + 4..].trim(),
            _ => {}
        }
    }
    s.trim()
}
336
#[cfg(test)]
mod tests {
    use super::*;

    /// Canonical PIVOT query shared by the parse and rewrite tests.
    const SALES_PIVOT: &str =
        "SELECT * FROM sales PIVOT (SUM(amount) FOR category IN ('food', 'tech'))";

    /// Canonical UNPIVOT query shared by the parse and rewrite tests.
    const MONTHLY_UNPIVOT: &str =
        "SELECT * FROM monthly UNPIVOT (value FOR month IN (jan, feb, mar))";

    // ---- PIVOT ----

    #[test]
    fn detects_pivot() {
        let pivoted = "SELECT * FROM t PIVOT (SUM(x) FOR y IN ('a'))";
        let plain = "SELECT * FROM t WHERE x = 1";
        assert!(contains_pivot(pivoted));
        assert!(!contains_pivot(plain));
    }

    #[test]
    fn parses_pivot() {
        let clause = parse_pivot(SALES_PIVOT)
            .expect("well-formed PIVOT must parse")
            .expect("query contains a PIVOT clause");
        assert_eq!(clause.agg_fn, "SUM");
        assert_eq!(clause.agg_column, "amount");
        assert_eq!(clause.for_column, "category");
        assert_eq!(clause.in_values, vec!["'food'", "'tech'"]);
    }

    #[test]
    fn rewrites_pivot_to_case_when() {
        let out = rewrite_pivot(SALES_PIVOT).unwrap();
        let upper = out.to_ascii_uppercase();
        assert!(upper.contains("CASE WHEN"));
        assert!(upper.contains("SUM("));
        for needle in ["'food'", "'tech'"] {
            assert!(out.contains(needle), "missing {needle} in {out}");
        }
        assert!(!upper.contains("PIVOT"));
    }

    #[test]
    fn pivot_rewrite_generates_correct_aliases() {
        let out = rewrite_pivot("SELECT * FROM t PIVOT (MAX(val) FOR dim IN ('x', 'y'))").unwrap();
        for alias in ["\"x\"", "\"y\""] {
            assert!(out.contains(alias), "missing alias {alias} in {out}");
        }
    }

    #[test]
    fn returns_unchanged_when_no_pivot() {
        let untouched = "SELECT * FROM t WHERE x = 1";
        assert_eq!(rewrite_pivot(untouched).unwrap(), untouched);
    }

    #[test]
    fn rejects_pivot_without_for() {
        let outcome = parse_pivot("SELECT * FROM t PIVOT (SUM(x) IN ('a'))");
        assert!(matches!(outcome, Err(SqlError::Unsupported { .. })));
    }

    // ---- UNPIVOT ----

    #[test]
    fn detects_unpivot() {
        let unpivoted = "SELECT * FROM t UNPIVOT (val FOR month IN (jan, feb))";
        let plain = "SELECT * FROM t WHERE x = 1";
        assert!(contains_unpivot(unpivoted));
        assert!(!contains_unpivot(plain));
    }

    #[test]
    fn parses_unpivot() {
        let clause = parse_unpivot(MONTHLY_UNPIVOT)
            .expect("well-formed UNPIVOT must parse")
            .expect("query contains an UNPIVOT clause");
        assert_eq!(clause.value_column, "value");
        assert_eq!(clause.name_column, "month");
        assert_eq!(clause.in_columns, vec!["jan", "feb", "mar"]);
    }

    #[test]
    fn rewrites_unpivot_to_union_all() {
        let out = rewrite_unpivot(MONTHLY_UNPIVOT).unwrap();
        let upper = out.to_ascii_uppercase();
        assert!(upper.contains("UNION ALL"));
        for label in ["'jan'", "'feb'", "'mar'"] {
            assert!(out.contains(label), "missing {label} in {out}");
        }
        assert!(!upper.contains("UNPIVOT"));
    }

    #[test]
    fn returns_unchanged_when_no_unpivot() {
        let untouched = "SELECT * FROM t";
        assert_eq!(rewrite_unpivot(untouched).unwrap(), untouched);
    }

    // ---- dispatch ----

    #[test]
    fn rewrite_pivot_unpivot_dispatches_correctly() {
        let pivoted = rewrite_pivot_unpivot("SELECT * FROM t PIVOT (SUM(v) FOR k IN ('a', 'b'))")
            .unwrap();
        assert!(pivoted.to_ascii_uppercase().contains("CASE WHEN"));

        let unpivoted =
            rewrite_pivot_unpivot("SELECT * FROM t UNPIVOT (val FOR month IN (jan, feb))").unwrap();
        assert!(unpivoted.to_ascii_uppercase().contains("UNION ALL"));

        let plain = "SELECT * FROM t";
        assert_eq!(rewrite_pivot_unpivot(plain).unwrap(), plain);
    }
}
447}