use krishiv_plan::NodeOp;
/// Reports whether `sql` looks like it calls `UNNEST`.
///
/// The check is ASCII case-insensitive and accepts either `UNNEST(` or
/// `UNNEST (` (exactly one space before the parenthesis).
pub fn contains_unnest(sql: &str) -> bool {
    let normalized = sql.to_ascii_uppercase();
    ["UNNEST(", "UNNEST ("]
        .iter()
        .any(|needle| normalized.contains(*needle))
}
/// Reports whether `sql` contains the `LATERAL` keyword.
///
/// Matching is ASCII case-insensitive and whole-word: `LATERAL` must not be
/// preceded or followed by an identifier byte. So `t,LATERAL`, a keyword
/// surrounded by newlines or tabs, and a keyword at the very start or end of
/// the text are all detected. `lateral_flag` and `bilateral` are not.
///
/// This is a lexical check only: occurrences inside string literals, quoted
/// identifiers, or comments are also reported.
pub fn contains_lateral(sql: &str) -> bool {
    const KEYWORD: &[u8] = b"LATERAL";
    // Bytes that can continue an unquoted identifier. Non-ASCII bytes count so
    // that a UTF-8 identifier ending in "lateral" is not taken for the keyword.
    let is_ident_byte = |b: u8| b.is_ascii_alphanumeric() || b == b'_' || !b.is_ascii();
    let bytes = sql.as_bytes();
    bytes
        .windows(KEYWORD.len())
        .enumerate()
        .any(|(start, window)| {
            window.eq_ignore_ascii_case(KEYWORD)
                && (start == 0 || !is_ident_byte(bytes[start - 1]))
                && !matches!(bytes.get(start + KEYWORD.len()), Some(&b) if is_ident_byte(b))
        })
}
/// Rewrites `LATERAL UNNEST(...)` table references into equivalent forms that
/// do not use the `LATERAL` keyword. It returns the rewritten SQL, or an
/// unchanged copy when nothing matches.
///
/// Matching is ASCII case-insensitive. `LATERAL` and `UNNEST` must be whole
/// words separated by whitespace, and `UNNEST` must be followed (after optional
/// whitespace) by `(`. The rewrite depends on what precedes `LATERAL`:
///
/// * a comma: `t, LATERAL UNNEST(x)` becomes `t CROSS JOIN UNNEST(x)`;
/// * the keyword `JOIN` or `FROM`: `LATERAL` is simply dropped, so
///   `LEFT JOIN LATERAL UNNEST(x) ON TRUE` becomes `LEFT JOIN UNNEST(x) ON TRUE`.
///   Emitting `CROSS JOIN` here would yield invalid `JOIN CROSS JOIN`;
/// * any other token followed by whitespace: `LATERAL` becomes `CROSS JOIN`.
///
/// Any other occurrence (e.g. `(LATERAL UNNEST(`) is left untouched. The
/// original spelling of `UNNEST` and everything after it is preserved.
///
/// This is a lexical rewrite. It does not understand string literals, quoted
/// identifiers, or comments, so matching text inside them is rewritten too.
pub fn rewrite_lateral_unnest(sql: &str) -> String {
    const LATERAL: &[u8] = b"LATERAL";
    const UNNEST: &[u8] = b"UNNEST";

    let bytes = sql.as_bytes();
    let mut out = String::with_capacity(sql.len() + 16);
    // Everything in `sql` before `copied` has already been written to `out`.
    let mut copied = 0;
    let mut i = 0;

    while i < bytes.len() {
        if !keyword_at(bytes, i, LATERAL) {
            i += 1;
            continue;
        }
        let lateral_start = i;
        let after_lateral = lateral_start + LATERAL.len();
        let unnest_start = skip_ascii_whitespace(bytes, after_lateral);
        let is_unnest_call = keyword_at(bytes, unnest_start, UNNEST)
            && bytes.get(skip_ascii_whitespace(bytes, unnest_start + UNNEST.len())) == Some(&b'(');
        if !is_unnest_call {
            i = after_lateral;
            continue;
        }

        // End (exclusive) of the last non-whitespace token before LATERAL.
        // The scan never looks back into text that has already been emitted.
        let mut prev_end = lateral_start;
        while prev_end > copied && bytes[prev_end - 1].is_ascii_whitespace() {
            prev_end -= 1;
        }

        if prev_end > copied && bytes[prev_end - 1] == b',' {
            // `a, LATERAL UNNEST(..)` -> `a CROSS JOIN UNNEST(..)`
            out.push_str(&sql[copied..prev_end - 1]);
            out.push_str(" CROSS JOIN ");
        } else if ends_with_keyword(bytes, prev_end, b"JOIN")
            || ends_with_keyword(bytes, prev_end, b"FROM")
        {
            // `.. JOIN LATERAL UNNEST(..)` -> `.. JOIN UNNEST(..)`
            out.push_str(&sql[copied..lateral_start]);
        } else if prev_end < lateral_start {
            // `a LATERAL UNNEST(..)` -> `a CROSS JOIN UNNEST(..)`
            out.push_str(&sql[copied..lateral_start]);
            out.push_str("CROSS JOIN ");
        } else {
            // Not a join position this rewrite understands; leave it as is.
            i = after_lateral;
            continue;
        }

        // Resume copying at UNNEST so its original spelling and arguments survive.
        copied = unnest_start;
        i = unnest_start + UNNEST.len();
    }

    out.push_str(&sql[copied..]);
    out
}

/// Returns `true` for bytes that can continue an unquoted SQL identifier or
/// keyword. Non-ASCII bytes are treated as identifier bytes.
fn is_sql_ident_byte(b: u8) -> bool {
    b.is_ascii_alphanumeric() || b == b'_' || !b.is_ascii()
}

/// Returns `true` when `bytes[at..]` starts with `keyword` (ASCII
/// case-insensitive) and the match is not embedded in a longer identifier.
fn keyword_at(bytes: &[u8], at: usize, keyword: &[u8]) -> bool {
    let end = at + keyword.len();
    matches!(bytes.get(at..end), Some(word) if word.eq_ignore_ascii_case(keyword))
        && (at == 0 || !is_sql_ident_byte(bytes[at - 1]))
        && !matches!(bytes.get(end), Some(&b) if is_sql_ident_byte(b))
}

/// Returns `true` when the text ending just before `end` is the whole word `keyword`.
fn ends_with_keyword(bytes: &[u8], end: usize, keyword: &[u8]) -> bool {
    end >= keyword.len() && keyword_at(bytes, end - keyword.len(), keyword)
}

/// Returns the index of the first non-ASCII-whitespace byte at or after `at`.
fn skip_ascii_whitespace(bytes: &[u8], mut at: usize) -> usize {
    while at < bytes.len() && bytes[at].is_ascii_whitespace() {
        at += 1;
    }
    at
}
/// Constructs a [`NodeOp::Unnest`] node from the given column names.
///
/// `array_column` and `output_column` are converted into owned `String`s;
/// `with_ordinality` is stored unchanged.
pub fn build_unnest_op(
    array_column: impl Into<String>,
    output_column: impl Into<String>,
    with_ordinality: bool,
) -> NodeOp {
    let array_column: String = array_column.into();
    let output_column: String = output_column.into();
    NodeOp::Unnest {
        array_column,
        output_column,
        with_ordinality,
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn detects_unnest_call() {
        assert!(contains_unnest("SELECT UNNEST(tags) FROM t"));
        assert!(contains_unnest(
            "SELECT * FROM t, LATERAL UNNEST(t.ids) AS id(v)"
        ));
        assert!(!contains_unnest("SELECT * FROM t WHERE x = 1"));
    }

    #[test]
    fn detects_unnest_case_insensitively() {
        // Lowercase keyword with the single-space-before-paren spelling.
        assert!(contains_unnest("select unnest (tags) from t"));
    }

    #[test]
    fn detects_lateral() {
        assert!(contains_lateral(
            "SELECT * FROM t, LATERAL UNNEST(t.ids) AS id(v)"
        ));
        assert!(!contains_lateral(
            "SELECT * FROM t CROSS JOIN UNNEST(t.ids)"
        ));
    }

    #[test]
    fn rewrites_lateral_unnest_with_comma() {
        let sql = "SELECT * FROM t, LATERAL UNNEST(t.tags) AS tag(value)";
        let rewritten = rewrite_lateral_unnest(sql);
        assert!(!rewritten.to_ascii_uppercase().contains(" LATERAL "));
        assert!(rewritten.to_ascii_uppercase().contains("CROSS JOIN UNNEST"));
    }

    #[test]
    fn rewrites_every_lateral_unnest() {
        // Every occurrence must be rewritten, not just the first one.
        let sql = "SELECT * FROM t, LATERAL UNNEST(t.a) AS x(a), LATERAL UNNEST(t.b) AS y(b)";
        assert_eq!(
            rewrite_lateral_unnest(sql),
            "SELECT * FROM t CROSS JOIN UNNEST(t.a) AS x(a) CROSS JOIN UNNEST(t.b) AS y(b)"
        );
    }

    #[test]
    fn rewrite_is_case_insensitive() {
        let sql = "select * from t, lateral unnest(t.ids) as x(v)";
        let upper = rewrite_lateral_unnest(sql).to_ascii_uppercase();
        assert!(!upper.contains("LATERAL"), "rewritten: {upper}");
        assert!(upper.contains("T CROSS JOIN UNNEST(T.IDS) AS X(V)"), "rewritten: {upper}");
    }

    #[test]
    fn rewrites_lateral_unnest_preserves_alias() {
        let sql = "SELECT t.id, tag.value FROM t, LATERAL UNNEST(t.tags) AS tag(value)";
        let rewritten = rewrite_lateral_unnest(sql);
        assert!(
            rewritten.contains("tag(value)"),
            "alias preserved: {rewritten}"
        );
    }

    #[test]
    fn passthrough_plain_unnest() {
        let sql = "SELECT id, UNNEST(tags) AS tag FROM t";
        let rewritten = rewrite_lateral_unnest(sql);
        assert_eq!(rewritten, sql, "unchanged: {rewritten}");
    }

    #[test]
    fn passthrough_non_lateral_unnest_unchanged() {
        let sql = "SELECT * FROM t CROSS JOIN UNNEST(t.ids) AS elem(v)";
        let rewritten = rewrite_lateral_unnest(sql);
        assert_eq!(rewritten, sql);
    }

    #[test]
    fn build_unnest_op_returns_correct_variant() {
        let op = build_unnest_op("tags", "tag", false);
        match op {
            NodeOp::Unnest {
                array_column,
                output_column,
                with_ordinality,
            } => {
                assert_eq!(array_column, "tags");
                assert_eq!(output_column, "tag");
                assert!(!with_ordinality);
            }
            // NodeOp has other variants; any of them here is a failure.
            _ => panic!("expected Unnest variant"),
        }
    }

    #[test]
    fn build_unnest_op_with_ordinality() {
        let op = build_unnest_op("items", "item", true);
        match op {
            NodeOp::Unnest {
                with_ordinality, ..
            } => assert!(with_ordinality),
            _ => panic!("expected Unnest"),
        }
    }
}