/// Returns the output column names of the `SELECT` list in `sql`.
///
/// Public entry point that forwards unchanged to `extract_columns_regex`
/// (which, despite its name, is a hand-written scanner rather than a regex).
///
/// # Errors
///
/// Propagates the `String` error produced by `extract_columns_regex`, e.g.
/// when no `SELECT`/`FROM` keyword is found or the column list is empty.
pub fn parse_select_columns(sql: &str) -> Result<Vec<String>, String> {
extract_columns_regex(sql)
}
/// Returns `(column_name, expression_text)` pairs for the `SELECT` list in `sql`.
///
/// Public entry point that forwards unchanged to
/// `extract_columns_with_expressions_regex`.
///
/// # Errors
///
/// Propagates the `String` error produced by
/// `extract_columns_with_expressions_regex`.
pub fn parse_select_columns_with_expressions(sql: &str) -> Result<Vec<(String, String)>, String> {
extract_columns_with_expressions_regex(sql)
}
fn extract_columns_regex(sql: &str) -> Result<Vec<String>, String> {
let mut columns = Vec::new();
let sql_lower = sql.to_lowercase();
let cte_offset = skip_cte_preamble(&sql_lower)?;
let select_start = sql_lower[cte_offset..]
.find("select")
.map(|p| p + cte_offset)
.ok_or("No SELECT keyword found")?;
let union_bound = find_outer_union(&sql_lower, select_start).unwrap_or(sql_lower.len());
let from_start =
find_outer_from(&sql_lower, select_start, union_bound).ok_or("No FROM keyword found")?;
if from_start <= select_start {
return Err("FROM appears before SELECT".to_string());
}
let col_start = skip_distinct_clause(&sql_lower, select_start + 6);
let select_clause = &sql[col_start..from_start].trim();
if select_clause.is_empty() {
return Err("No columns found in SELECT statement".to_string());
}
let parts = split_by_top_level_comma(select_clause);
for part in parts {
let trimmed = part.trim();
if trimmed.is_empty() {
continue;
}
let col_name = extract_column_name(trimmed)?;
columns.push(col_name);
}
if columns.is_empty() {
return Err("No columns found in SELECT statement".to_string());
}
Ok(columns)
}
/// Finds the byte offset of the first top-level `from` keyword.
///
/// Scanning starts at `after_pos` and stops before `end_bound` (or the end of
/// the string, whichever is smaller). A match must sit outside parentheses,
/// outside `'...'` string literals and `"..."` quoted identifiers, and must
/// be delimited by non-identifier bytes on both sides.
///
/// `sql_lower` is expected to be already lower-cased.
fn find_outer_from(sql_lower: &str, after_pos: usize, end_bound: usize) -> Option<usize> {
    let bytes = sql_lower.as_bytes();
    let len = bytes.len().min(end_bound);
    let is_ident = |b: u8| b.is_ascii_alphanumeric() || b == b'_';

    // Unsigned depth so that `saturating_sub` really clamps at zero: a stray
    // `)` must not push the depth negative and hide every later keyword.
    let mut depth: usize = 0;
    let mut i = after_pos;
    while i < len {
        match bytes[i] {
            b'(' => {
                depth += 1;
                i += 1;
            }
            b')' => {
                depth = depth.saturating_sub(1);
                i += 1;
            }
            // Skip string literals and quoted identifiers wholesale. A doubled
            // quote (`''` / `""`) simply reads as two adjacent quoted runs.
            quote @ (b'\'' | b'"') => {
                i += 1;
                while i < len && bytes[i] != quote {
                    i += 1;
                }
                if i < len {
                    i += 1;
                }
            }
            _ => {
                if depth == 0 && bytes[i..len].starts_with(b"from") {
                    let before_ok = i == 0 || !is_ident(bytes[i - 1]);
                    let after_ok = i + 4 >= len || !is_ident(bytes[i + 4]);
                    if before_ok && after_ok {
                        return Some(i);
                    }
                }
                i += 1;
            }
        }
    }
    None
}
/// Finds the byte offset of the first top-level `union` keyword at or after
/// `start`.
///
/// A match must sit outside parentheses, outside `'...'` string literals and
/// `"..."` quoted identifiers, and must be delimited by non-identifier bytes.
/// Returns `None` when there is no such keyword.
///
/// `sql_lower` is expected to be already lower-cased.
pub fn find_outer_union(sql_lower: &str, start: usize) -> Option<usize> {
    let bytes = sql_lower.as_bytes();
    let len = bytes.len();
    let is_ident = |b: u8| b.is_ascii_alphanumeric() || b == b'_';

    // Unsigned depth so that `saturating_sub` really clamps at zero. When the
    // scan starts inside `(select ...) union ...`, the first `)` must bring
    // the depth back to the top level rather than to -1.
    let mut depth: usize = 0;
    let mut i = start;
    while i < len {
        match bytes[i] {
            b'(' => {
                depth += 1;
                i += 1;
            }
            b')' => {
                depth = depth.saturating_sub(1);
                i += 1;
            }
            // Skip string literals and quoted identifiers wholesale.
            quote @ (b'\'' | b'"') => {
                i += 1;
                while i < len && bytes[i] != quote {
                    i += 1;
                }
                if i < len {
                    i += 1;
                }
            }
            _ => {
                if depth == 0 && bytes[i..].starts_with(b"union") {
                    let before_ok = i == 0 || !is_ident(bytes[i - 1]);
                    let after_ok = i + 5 >= len || !is_ident(bytes[i + 5]);
                    if before_ok && after_ok {
                        return Some(i);
                    }
                }
                i += 1;
            }
        }
    }
    None
}
fn skip_distinct_clause(sql_lower: &str, after_select: usize) -> usize {
let bytes = sql_lower.as_bytes();
let len = bytes.len();
let mut i = after_select;
while i < len && bytes[i].is_ascii_whitespace() {
i += 1;
}
if i + 8 > len || &bytes[i..i + 8] != b"distinct" {
return after_select; }
let after_distinct = i + 8;
if after_distinct < len
&& (bytes[after_distinct].is_ascii_alphanumeric() || bytes[after_distinct] == b'_')
{
return after_select; }
i += 8;
while i < len && bytes[i].is_ascii_whitespace() {
i += 1;
}
if i + 2 <= len && &bytes[i..i + 2] == b"on" {
let after_on = i + 2;
if after_on >= len || (!bytes[after_on].is_ascii_alphanumeric() && bytes[after_on] != b'_')
{
i += 2;
while i < len && bytes[i].is_ascii_whitespace() {
i += 1;
}
if i < len
&& bytes[i] == b'('
&& let Ok(end) = skip_paren_block(bytes, i)
{
i = end;
}
}
}
i
}
/// Extracts the key names listed in a leading `DISTINCT ON (...)` clause.
///
/// Each key is reduced to the segment after its last `.` (so `c.id` becomes
/// `id`). Returns an empty vector when the first `SELECT` has no
/// `DISTINCT ON (` clause, including plain `SELECT DISTINCT`.
///
/// # Errors
///
/// Returns an error when the `WITH` preamble cannot be skipped, when no
/// `SELECT` keyword is found, or when the key list parentheses are unbalanced.
pub fn extract_distinct_on_keys(sql: &str) -> Result<Vec<String>, String> {
    // ASCII-only lowering keeps every byte offset identical to `sql`.
    // `str::to_lowercase` can change the byte length of non-ASCII characters,
    // which would make the slice of `sql` below wrong or panic.
    let sql_lower = sql.to_ascii_lowercase();
    let cte_offset = skip_cte_preamble(&sql_lower)?;
    let select_start = sql_lower[cte_offset..]
        .find("select")
        .map(|p| p + cte_offset)
        .ok_or("No SELECT keyword found")?;

    let bytes = sql_lower.as_bytes();
    let len = bytes.len();
    let is_word_byte = |b: u8| b.is_ascii_alphanumeric() || b == b'_';
    let skip_ws = |mut pos: usize| {
        while pos < len && bytes[pos].is_ascii_whitespace() {
            pos += 1;
        }
        pos
    };
    // True when `kw` starts at `pos` and is not the prefix of a longer word.
    let keyword_at = |pos: usize, kw: &[u8]| {
        bytes[pos..].starts_with(kw)
            && !bytes.get(pos + kw.len()).is_some_and(|&b| is_word_byte(b))
    };

    // `+ 6` steps over the `select` keyword itself.
    let mut i = skip_ws(select_start + 6);
    if !keyword_at(i, b"distinct") {
        return Ok(Vec::new());
    }
    i = skip_ws(i + 8);
    if !keyword_at(i, b"on") {
        return Ok(Vec::new());
    }
    i = skip_ws(i + 2);
    if bytes.get(i) != Some(&b'(') {
        return Ok(Vec::new());
    }

    // `close` is one past the matching `)`, so the key list is `i+1..close-1`.
    let close = skip_paren_block(bytes, i)?;
    let inner = sql[i + 1..close - 1].trim();
    let keys = split_by_top_level_comma(inner)
        .iter()
        .map(|key| key.rsplit('.').next().unwrap_or_default().trim().to_string())
        .filter(|key| !key.is_empty())
        .collect();
    Ok(keys)
}
/// Returns the byte offset just past a leading `with ... as (...)` preamble.
///
/// Returns `Ok(0)` when the text does not start with the `with` keyword.
/// Otherwise each `name [(columns)] as (body)` definition is stepped over,
/// following commas to further definitions, and the offset of the first
/// non-whitespace byte after the last body is returned.
///
/// `sql_lower` is expected to be already lower-cased.
///
/// # Errors
///
/// Returns an error for `with recursive`, for a malformed definition
/// (missing `as`, missing body, unterminated quoted name, unbalanced
/// parentheses) and when the text ends before a main statement follows.
fn skip_cte_preamble(sql_lower: &str) -> Result<usize, String> {
    fn is_word_byte(b: u8) -> bool {
        b.is_ascii_alphanumeric() || b == b'_'
    }

    fn skip_spaces(bytes: &[u8], mut pos: usize) -> usize {
        while pos < bytes.len() && bytes[pos].is_ascii_whitespace() {
            pos += 1;
        }
        pos
    }

    // True when `keyword` starts at `pos` and is not the prefix of a longer word.
    fn keyword_at(bytes: &[u8], pos: usize, keyword: &[u8]) -> bool {
        bytes[pos..].starts_with(keyword)
            && !bytes
                .get(pos + keyword.len())
                .is_some_and(|&b| is_word_byte(b))
    }

    let bytes = sql_lower.as_bytes();

    let mut i = skip_spaces(bytes, 0);
    if !keyword_at(bytes, i, b"with") {
        return Ok(0);
    }
    i = skip_spaces(bytes, i + 4);

    if keyword_at(bytes, i, b"recursive") {
        return Err("WITH RECURSIVE is not supported in TVIEWs. \
                    Consider using a non-recursive CTE or a subquery."
            .to_string());
    }

    loop {
        // CTE name: either a "quoted identifier" or a bare word.
        i = skip_spaces(bytes, i);
        match bytes.get(i) {
            None => {
                return Err("Unexpected end of SQL while parsing CTE preamble".to_string());
            }
            Some(b'"') => {
                let closing = bytes[i + 1..]
                    .iter()
                    .position(|&b| b == b'"')
                    .ok_or_else(|| "Unterminated quoted identifier in CTE name".to_string())?;
                // Step over the opening quote, the name and the closing quote.
                i += closing + 2;
            }
            Some(&b) if b.is_ascii_alphabetic() || b == b'_' => {
                while i < bytes.len() && is_word_byte(bytes[i]) {
                    i += 1;
                }
            }
            // Anything else cannot start a CTE name; hand the offset back.
            Some(_) => return Ok(i),
        }

        // Optional explicit column list: `name (a, b)`.
        i = skip_spaces(bytes, i);
        if bytes.get(i) == Some(&b'(') {
            i = skip_spaces(bytes, skip_paren_block(bytes, i)?);
        }

        if !keyword_at(bytes, i, b"as") {
            return Err(format!(
                "Expected AS keyword in CTE definition at byte offset {i}"
            ));
        }
        i = skip_spaces(bytes, i + 2);

        // Parenthesised CTE body.
        if bytes.get(i) != Some(&b'(') {
            return Err(format!("Expected '(' for CTE body at byte offset {i}"));
        }
        i = skip_spaces(bytes, skip_paren_block(bytes, i)?);

        // A comma introduces another definition; anything else ends the preamble.
        match bytes.get(i) {
            None => return Err("SQL ends after CTE body with no main SELECT".to_string()),
            Some(b',') => i += 1,
            Some(_) => return Ok(i),
        }
    }
}
/// Returns the offset one past the `)` that closes the parenthesis opened at
/// `start`.
///
/// Parentheses inside `'...'` literals (with `''` as an escaped quote) and
/// inside `"..."` identifiers are ignored.
///
/// # Errors
///
/// Returns an error when the input ends before the nesting returns to zero.
fn skip_paren_block(bytes: &[u8], start: usize) -> Result<usize, String> {
    let mut nesting: i32 = 0;
    let mut pos = start;

    while let Some(&byte) = bytes.get(pos) {
        pos += 1;
        match byte {
            b'(' => nesting += 1,
            b')' => {
                nesting -= 1;
                if nesting == 0 {
                    return Ok(pos);
                }
            }
            b'\'' => {
                // Consume up to the closing quote; a doubled quote stays inside.
                while let Some(&inner) = bytes.get(pos) {
                    pos += 1;
                    if inner == b'\'' {
                        if bytes.get(pos) == Some(&b'\'') {
                            pos += 1;
                        } else {
                            break;
                        }
                    }
                }
            }
            b'"' => {
                // Consume up to and including the closing double quote.
                while let Some(&inner) = bytes.get(pos) {
                    pos += 1;
                    if inner == b'"' {
                        break;
                    }
                }
            }
            _ => {}
        }
    }

    Err("Unbalanced parentheses in SQL".to_string())
}
/// Extracts `(column_name, expression_text)` pairs from the first top-level
/// `SELECT` list.
///
/// The scan skips a leading `WITH ...` preamble, stops at the first top-level
/// `UNION`, and ignores a leading `DISTINCT` / `DISTINCT ON (...)` clause.
/// The expression text is the trimmed select item exactly as written in
/// `sql`; the name is derived from it by `extract_column_name`.
///
/// # Errors
///
/// Returns an error when no `SELECT` or top-level `FROM` is found, when the
/// select list is empty, or when a column name cannot be derived.
fn extract_columns_with_expressions_regex(sql: &str) -> Result<Vec<(String, String)>, String> {
    // ASCII-only lowering keeps every byte offset identical to `sql`.
    // `str::to_lowercase` can change the byte length of non-ASCII characters,
    // which would make offsets found in the lowered copy invalid for `sql`.
    let sql_lower = sql.to_ascii_lowercase();
    let cte_offset = skip_cte_preamble(&sql_lower)?;
    let select_start = sql_lower[cte_offset..]
        .find("select")
        .map(|p| p + cte_offset)
        .ok_or("No SELECT keyword found")?;

    // Only the first branch of a UNION determines the columns.
    let union_bound = find_outer_union(&sql_lower, select_start).unwrap_or(sql_lower.len());
    let from_start =
        find_outer_from(&sql_lower, select_start, union_bound).ok_or("No FROM keyword found")?;
    if from_start <= select_start {
        return Err("FROM appears before SELECT".to_string());
    }

    // `+ 6` steps over the `select` keyword itself.
    let col_start = skip_distinct_clause(&sql_lower, select_start + 6);
    // Checked slicing: an inverted range is treated as an empty select list
    // instead of panicking.
    let select_clause = sql.get(col_start..from_start).map_or("", str::trim);
    if select_clause.is_empty() {
        return Err("No columns found in SELECT statement".to_string());
    }

    let columns = split_by_top_level_comma(select_clause)
        .iter()
        .map(|part| part.trim())
        .filter(|expr| !expr.is_empty())
        .map(|expr| extract_column_name(expr).map(|name| (name, expr.to_string())))
        .collect::<Result<Vec<_>, _>>()?;
    if columns.is_empty() {
        return Err("No columns found in SELECT statement".to_string());
    }
    Ok(columns)
}
/// Splits `s` on commas that are outside parentheses and outside quotes.
///
/// Every piece is trimmed. Empty pieces between two commas are kept; an
/// empty trailing piece is dropped. A quote preceded by a backslash does not
/// toggle the quote state.
fn split_by_top_level_comma(s: &str) -> Vec<String> {
    let mut pieces = Vec::new();
    let mut nesting: i32 = 0;
    let mut inside_single = false;
    let mut inside_double = false;
    let mut previous = '\0';
    // Byte offset where the piece currently being scanned begins.
    let mut piece_start = 0;

    for (offset, ch) in s.char_indices() {
        let quoted = inside_single || inside_double;
        let escaped = previous == '\\';
        match ch {
            '(' if !quoted => nesting += 1,
            ')' if !quoted => nesting = nesting.saturating_sub(1),
            '\'' if !inside_double && !escaped => inside_single = !inside_single,
            '"' if !inside_single && !escaped => inside_double = !inside_double,
            ',' if !quoted && nesting == 0 => {
                pieces.push(s[piece_start..offset].trim().to_string());
                // A comma is one byte wide, so the next piece starts right after it.
                piece_start = offset + 1;
            }
            _ => {}
        }
        previous = ch;
    }

    let tail = s[piece_start..].trim();
    if !tail.is_empty() {
        pieces.push(tail.to_string());
    }
    pieces
}
/// Derives the output column name of a single select-list item.
///
/// When the item has a top-level `AS`, everything after the last one
/// (trimmed) is the name. Otherwise the last whitespace-separated word is
/// taken, trailing punctuation is stripped, and only the segment after the
/// last `.` is kept (so `u.name` becomes `name`).
///
/// # Errors
///
/// Returns an error for an empty item, an empty alias after `AS`, or a last
/// word that contains no identifier characters.
fn extract_column_name(part: &str) -> Result<String, String> {
    // ASCII-only lowering keeps byte offsets identical to `part`, so the
    // position reported by `find_last_as` is valid for slicing `part`.
    // `str::to_lowercase` can change the byte length of non-ASCII characters.
    let part_lower = part.to_ascii_lowercase();
    if let Some(as_pos) = find_last_as(&part_lower) {
        // `+ 2` steps over the `as` keyword itself.
        let alias = part[as_pos + 2..].trim();
        if alias.is_empty() {
            return Err("Empty alias after AS".to_string());
        }
        return Ok(alias.to_string());
    }

    let last_word = part
        .split_whitespace()
        .next_back()
        .ok_or_else(|| "Empty column expression".to_string())?;
    let clean_name = last_word.trim_end_matches(|c: char| !c.is_alphanumeric() && c != '_');
    if clean_name.is_empty() {
        return Err("Could not extract column name".to_string());
    }

    // Drop any table/schema qualifier.
    let col = clean_name.rsplit('.').next().unwrap_or(clean_name);
    Ok(col.to_string())
}
/// Finds the byte offset of the last top-level `as` keyword in `sql_lower`.
///
/// A match must be outside parentheses, outside `'...'` string literals and
/// `"..."` quoted identifiers, preceded by whitespace (or the start of the
/// string) and followed by whitespace (or the end of the string).
///
/// The scan is a single left-to-right pass, so it runs in time linear in the
/// input length. `sql_lower` is expected to be already lower-cased.
fn find_last_as(sql_lower: &str) -> Option<usize> {
    let bytes = sql_lower.as_bytes();
    let len = bytes.len();
    let mut last_as_pos = None;
    // Unsigned so that a stray `)` clamps at the top level.
    let mut depth: usize = 0;
    let mut i = 0;

    while i < len {
        match bytes[i] {
            b'(' => {
                depth += 1;
                i += 1;
            }
            b')' => {
                depth = depth.saturating_sub(1);
                i += 1;
            }
            // Text inside quotes is data, never a keyword. A doubled quote
            // (`''` / `""`) simply reads as two adjacent quoted runs.
            quote @ (b'\'' | b'"') => {
                i += 1;
                while i < len && bytes[i] != quote {
                    i += 1;
                }
                if i < len {
                    i += 1;
                }
            }
            _ => {
                if depth == 0 && bytes[i..].starts_with(b"as") {
                    let preceded_by_space = i == 0 || bytes[i - 1].is_ascii_whitespace();
                    let followed_by_space = i + 2 >= len || bytes[i + 2].is_ascii_whitespace();
                    if preceded_by_space && followed_by_space {
                        last_as_pos = Some(i);
                    }
                }
                i += 1;
            }
        }
    }
    last_as_pos
}
// Unit tests for the SELECT-list scanner: plain column lists, aliases, CTE
// preambles, DISTINCT / DISTINCT ON handling and UNION boundaries.
#[cfg(test)]
mod tests {
use super::*;
// --- Basic column extraction -------------------------------------------
#[test]
fn test_extract_columns_simple() {
let sql = "SELECT id, name, data FROM users";
let cols = parse_select_columns(sql).unwrap();
assert_eq!(cols, vec!["id", "name", "data"]);
}
// An alias after AS wins over the underlying expression or qualified name.
#[test]
fn test_extract_columns_with_alias() {
let sql = "SELECT u.id AS user_id, u.name, 'literal' AS data FROM users u";
let cols = parse_select_columns(sql).unwrap();
assert_eq!(cols, vec!["user_id", "name", "data"]);
}
// Table qualifiers (`u.`, `p.`) are stripped from unaliased columns.
#[test]
fn test_extract_columns_table_qualified() {
let sql = "SELECT u.id, u.name, p.title FROM users u JOIN posts p ON u.id = p.user_id";
let cols = parse_select_columns(sql).unwrap();
assert_eq!(cols, vec!["id", "name", "title"]);
}
// Commas inside a function call must not split the select list.
#[test]
fn test_extract_columns_complex_expression() {
let sql =
"SELECT pk_post, id, jsonb_build_object('id', id, 'title', title) AS data FROM posts";
let cols = parse_select_columns(sql).unwrap();
assert_eq!(cols, vec!["pk_post", "id", "data"]);
}
#[test]
fn test_extract_columns_empty_select() {
let sql = "SELECT FROM users";
let result = parse_select_columns(sql);
assert!(result.is_err());
assert!(result.unwrap_err().contains("No columns"));
}
// Despite the test name, SELECT is present here; FROM precedes it, so no
// FROM is found after the SELECT keyword.
#[test]
fn test_extract_columns_no_select() {
let sql = "FROM users SELECT id";
let result = parse_select_columns(sql);
assert!(result.is_err());
assert!(result.unwrap_err().contains("No FROM keyword found"));
}
// --- extract_column_name / find_last_as --------------------------------
#[test]
fn test_extract_column_name_simple() {
assert_eq!(extract_column_name("id").unwrap(), "id");
assert_eq!(extract_column_name("pk_post").unwrap(), "pk_post");
assert_eq!(extract_column_name("u.name").unwrap(), "name");
}
#[test]
fn test_extract_column_name_with_alias() {
assert_eq!(extract_column_name("u.id AS user_id").unwrap(), "user_id");
assert_eq!(
extract_column_name("jsonb_build_object('key', 'value') AS data").unwrap(),
"data"
);
}
// Expected values are byte offsets of the `as` keyword in the input.
#[test]
fn test_find_last_as() {
assert_eq!(find_last_as("id as user_id"), Some(3));
assert_eq!(
find_last_as("jsonb_build_object('id', id) as data"),
Some(29)
);
assert_eq!(find_last_as("id"), None);
}
// With two top-level `as` keywords, the offset of the last one is returned.
#[test]
fn test_find_last_as_nested() {
let sql = "jsonb_build_object('id', id) as data, name as full_name";
assert_eq!(find_last_as(sql), Some(43)); }
// --- CTE (WITH ...) preambles ------------------------------------------
// Columns selected inside the CTE body must not leak into the result.
#[test]
fn test_parse_cte_columns() {
let sql = "WITH labels AS (SELECT item_id, label FROM tb_i18n) \
SELECT i.pk_item, i.id, i.name, l.label AS data \
FROM tb_item i LEFT JOIN labels l ON l.item_id = i.pk_item";
let cols = parse_select_columns(sql).unwrap();
assert!(
cols.contains(&"pk_item".to_string()),
"expected pk_item, got {cols:?}"
);
assert!(
cols.contains(&"id".to_string()),
"expected id, got {cols:?}"
);
assert!(
cols.contains(&"name".to_string()),
"expected name, got {cols:?}"
);
assert!(
cols.contains(&"data".to_string()),
"expected data, got {cols:?}"
);
assert!(
!cols.contains(&"item_id".to_string()),
"CTE-only column item_id leaked: {cols:?}"
);
assert!(
!cols.contains(&"label".to_string()),
"CTE-only column label leaked: {cols:?}"
);
}
#[test]
fn test_parse_cte_columns_with_expressions() {
let sql = "WITH labels AS (SELECT item_id, label FROM tb_i18n) \
SELECT i.pk_item, i.id, i.name, l.label AS data \
FROM tb_item i LEFT JOIN labels l ON l.item_id = i.pk_item";
let cols = parse_select_columns_with_expressions(sql).unwrap();
let names: Vec<&str> = cols.iter().map(|(n, _)| n.as_str()).collect();
assert!(
names.contains(&"pk_item"),
"expected pk_item, got {names:?}"
);
assert!(names.contains(&"data"), "expected data, got {names:?}");
assert!(
!names.contains(&"item_id"),
"CTE-only column item_id leaked: {names:?}"
);
}
// Several comma-separated CTE definitions are all skipped.
#[test]
fn test_parse_multiple_ctes() {
let sql = "WITH a AS (SELECT x FROM t1), b AS (SELECT y FROM a) \
SELECT pk_item, id, b.y AS data FROM tb_item JOIN b ON b.y = tb_item.pk_item";
let cols = parse_select_columns(sql).unwrap();
assert_eq!(cols, vec!["pk_item", "id", "data"]);
}
// A CTE may carry an explicit column list before AS.
#[test]
fn test_parse_cte_with_explicit_column_list() {
let sql = "WITH labeled(item_id, label) AS (SELECT item_id, label FROM tb_i18n) \
SELECT pk_item, id, label AS data FROM tb_item \
JOIN labeled ON labeled.item_id = tb_item.pk_item";
let cols = parse_select_columns(sql).unwrap();
assert!(cols.contains(&"pk_item".to_string()));
assert!(cols.contains(&"data".to_string()));
assert!(!cols.contains(&"item_id".to_string()));
}
// A `)` inside a string literal must not close the CTE body early.
#[test]
fn test_parse_cte_with_paren_in_string_literal() {
let sql = "WITH filtered AS (SELECT id FROM t WHERE name = 'a)b') \
SELECT pk_item, id, name AS data FROM tb_item";
let cols = parse_select_columns(sql).unwrap();
assert_eq!(cols, vec!["pk_item", "id", "data"]);
}
#[test]
fn test_parse_cte_with_nested_subquery() {
let sql = "WITH top AS (SELECT id FROM t WHERE id IN (SELECT id FROM s)) \
SELECT pk_item, id, name AS data FROM tb_item";
let cols = parse_select_columns(sql).unwrap();
assert_eq!(cols, vec!["pk_item", "id", "data"]);
}
// WITH RECURSIVE is rejected outright, in either letter case.
#[test]
fn test_recursive_cte_rejected() {
let sql = "WITH RECURSIVE tree AS (SELECT 1) SELECT pk_item, id, name AS data FROM tb_item";
let result = parse_select_columns(sql);
assert!(result.is_err(), "expected error for WITH RECURSIVE");
assert!(
result.unwrap_err().contains("RECURSIVE"),
"error should mention RECURSIVE"
);
}
#[test]
fn test_recursive_cte_rejected_lowercase() {
let sql = "with recursive tree as (select 1) select pk_item, id, name as data from tb_item";
let result = parse_select_columns(sql);
assert!(result.is_err(), "expected error for with recursive");
assert!(result.unwrap_err().contains("RECURSIVE"));
}
// --- skip_paren_block --------------------------------------------------
// The returned offset is one past the matching `)`.
#[test]
fn test_skip_paren_block_simple() {
let s = "(hello world) rest";
let end = skip_paren_block(s.as_bytes(), 0).unwrap();
assert_eq!(end, 13); }
#[test]
fn test_skip_paren_block_nested() {
let s = "((a) (b)) rest";
let end = skip_paren_block(s.as_bytes(), 0).unwrap();
assert_eq!(end, 9);
}
#[test]
fn test_skip_paren_block_string_with_paren() {
let s = "(WHERE name = 'a)b') rest";
let end = skip_paren_block(s.as_bytes(), 0).unwrap();
assert_eq!(end, 20);
}
// Regression guard: CTE support must not change plain SELECT parsing.
#[test]
fn test_non_cte_sql_unchanged() {
let sql = "SELECT pk_post, id, jsonb_build_object('id', id) AS data FROM tb_post";
let cols = parse_select_columns(sql).unwrap();
assert_eq!(cols, vec!["pk_post", "id", "data"]);
}
// --- DISTINCT / DISTINCT ON --------------------------------------------
// The DISTINCT ON (...) key list is not part of the output columns.
#[test]
fn test_parse_distinct_on_columns() {
let sql = "SELECT DISTINCT ON (c.id) c.pk_contract, c.id, c.name, \
jsonb_build_object('id', c.id, 'name', c.name) AS data \
FROM tenant.tb_contract c ORDER BY c.id, c.version DESC";
let cols = parse_select_columns(sql).unwrap();
assert_eq!(
cols,
vec!["pk_contract", "id", "name", "data"],
"got: {cols:?}"
);
}
#[test]
fn test_parse_distinct_on_columns_with_expressions() {
let sql = "SELECT DISTINCT ON (c.id) c.pk_contract, c.id, c.name, \
jsonb_build_object('id', c.id) AS data \
FROM tenant.tb_contract c ORDER BY c.id, c.version DESC";
let cols = parse_select_columns_with_expressions(sql).unwrap();
let names: Vec<&str> = cols.iter().map(|(n, _)| n.as_str()).collect();
assert!(names.contains(&"pk_contract"), "got: {names:?}");
assert!(names.contains(&"data"), "got: {names:?}");
}
#[test]
fn test_parse_plain_distinct_columns() {
let sql = "SELECT DISTINCT pk_item, id, name AS data FROM tb_item";
let cols = parse_select_columns(sql).unwrap();
assert_eq!(cols, vec!["pk_item", "id", "data"]);
}
#[test]
fn test_parse_composite_distinct_on() {
let sql = "SELECT DISTINCT ON (c.tenant_id, c.id) c.pk_contract, c.id, c.name AS data \
FROM tb_contract c ORDER BY c.tenant_id, c.id, c.version DESC";
let cols = parse_select_columns(sql).unwrap();
assert_eq!(cols, vec!["pk_contract", "id", "data"]);
}
// extract_distinct_on_keys returns the keys with table qualifiers removed.
#[test]
fn test_extract_distinct_on_keys_single() {
let sql = "SELECT DISTINCT ON (c.id) c.pk_contract, c.id FROM tb_contract c";
let keys = extract_distinct_on_keys(sql).unwrap();
assert_eq!(keys, vec!["id"]);
}
#[test]
fn test_extract_distinct_on_keys_composite() {
let sql = "SELECT DISTINCT ON (c.tenant_id, c.id) c.pk_contract, c.id FROM tb_contract c";
let keys = extract_distinct_on_keys(sql).unwrap();
assert_eq!(keys, vec!["tenant_id", "id"]);
}
#[test]
fn test_extract_distinct_on_keys_none() {
let sql = "SELECT pk_post, id, name FROM tb_post";
let keys = extract_distinct_on_keys(sql).unwrap();
assert!(keys.is_empty(), "expected no keys, got {keys:?}");
}
// Plain DISTINCT (without ON) yields no keys.
#[test]
fn test_extract_distinct_on_keys_plain_distinct() {
let sql = "SELECT DISTINCT pk_post, id FROM tb_post";
let keys = extract_distinct_on_keys(sql).unwrap();
assert!(keys.is_empty());
}
#[test]
fn test_extract_distinct_on_keys_with_cte() {
let sql = "WITH cte AS (SELECT 1) SELECT DISTINCT ON (c.id) c.pk_contract, c.id FROM tb_contract c";
let keys = extract_distinct_on_keys(sql).unwrap();
assert_eq!(keys, vec!["id"]);
}
// `ON` in a JOIN clause must not be mistaken for DISTINCT ON.
#[test]
fn test_non_distinct_sql_unchanged() {
let sql = "SELECT pk_post, id, data FROM tb_post JOIN tb_other ON TRUE";
let cols = parse_select_columns(sql).unwrap();
assert_eq!(cols, vec!["pk_post", "id", "data"]);
}
// --- UNION / UNION ALL -------------------------------------------------
// Only the first branch of a UNION contributes column names.
#[test]
fn test_parse_union_all_columns() {
let sql = "SELECT i.pk_item, i.id, l.label AS name, \
jsonb_build_object('id', i.id, 'name', l.label) AS data \
FROM catalog.tb_item i \
JOIN catalog.tb_item_i18n l ON l.item_id = i.pk_item \
WHERE l.locale = 'en' \
UNION ALL \
SELECT i.pk_item, i.id, i.name, \
jsonb_build_object('id', i.id, 'name', i.name) AS data \
FROM catalog.tb_item i \
WHERE NOT EXISTS (SELECT 1 FROM catalog.tb_item_i18n l \
WHERE l.item_id = i.pk_item AND l.locale = 'en')";
let cols = parse_select_columns(sql).unwrap();
assert_eq!(cols, vec!["pk_item", "id", "name", "data"]);
}
#[test]
fn test_parse_union_deduplicated_columns() {
let sql = "SELECT pk_post, id, title AS name FROM tb_post \
UNION \
SELECT pk_post, id, slug AS name FROM tb_post_draft";
let cols = parse_select_columns(sql).unwrap();
assert_eq!(cols, vec!["pk_post", "id", "name"]);
}
// The word UNION inside a string literal is not a set operator.
#[test]
fn test_union_in_string_literal_not_matched() {
let sql = "SELECT pk_post, id, 'UNION ALL rocks' AS label FROM tb_post";
let cols = parse_select_columns(sql).unwrap();
assert_eq!(cols, vec!["pk_post", "id", "label"]);
}
// A UNION nested in a parenthesised subquery does not bound the outer SELECT.
#[test]
fn test_union_inside_subquery_not_matched() {
let sql = "SELECT pk_post, id, (SELECT MAX(v) FROM (SELECT 1 AS v UNION SELECT 2 AS v) s) AS top \
FROM tb_post";
let cols = parse_select_columns(sql).unwrap();
assert_eq!(cols, vec!["pk_post", "id", "top"]);
}
#[test]
fn test_parse_cte_with_union_all() {
let sql = "WITH fallback AS (SELECT pk_item, id, name FROM tb_item) \
SELECT pk_item, id, name AS label FROM tb_item \
UNION ALL \
SELECT pk_item, id, name AS label FROM fallback";
let cols = parse_select_columns(sql).unwrap();
assert_eq!(cols, vec!["pk_item", "id", "label"]);
}
#[test]
fn test_parse_distinct_on_union_all() {
let sql = "SELECT DISTINCT ON (c.id) c.pk_contract, c.id, c.name, \
jsonb_build_object('id', c.id) AS data \
FROM tb_contract c ORDER BY c.id, c.version DESC \
UNION ALL \
SELECT pk_draft, id, name, \
jsonb_build_object('id', id) AS data \
FROM tb_contract_draft";
let cols = parse_select_columns(sql).unwrap();
assert_eq!(cols[0], "pk_contract");
assert_eq!(cols[1], "id");
assert_eq!(cols[2], "name");
assert_eq!(cols[3], "data");
}
}