#[cfg(feature = "column-lineage")]
use polyglot_sql::expressions::Expression;
use regex::Regex;
use std::sync::LazyLock;
/// Matches a Jinja expression tag (`{{ ... }}`) or statement tag (`{% ... %}`), each with
/// optional `-` whitespace-control markers next to the delimiters. The body is matched
/// lazily and may span newlines (`[\s\S]*?`). Compiled once on first use.
static JINJA_TAG: LazyLock<Regex> =
LazyLock::new(|| Regex::new(r"\{\{-?[\s\S]*?-?\}\}|\{%-?[\s\S]*?-?%\}").unwrap());
/// Matches a Jinja comment (`{# ... #}`); the body is matched lazily and may span newlines.
/// Compiled once on first use.
static JINJA_COMMENT: LazyLock<Regex> = LazyLock::new(|| Regex::new(r"\{#[\s\S]*?#\}").unwrap());
/// Extracts the output column names of the final top-level `SELECT` in a (possibly
/// Jinja-templated) SQL string using a lightweight text scan.
///
/// Jinja comments are removed and every Jinja expression/statement tag is replaced by the
/// placeholder `__jinja__` before scanning. The select list is the text between the last
/// `SELECT` found outside parentheses (an immediately following `DISTINCT` is skipped) and
/// the first `FROM` outside parentheses after it. Each comma-separated item contributes its
/// alias or trailing column name.
///
/// Returns an empty vector when there is no top-level `SELECT`, or no top-level `FROM`
/// after it.
pub fn extract_select_columns(sql: &str) -> Vec<String> {
    let without_comments = JINJA_COMMENT.replace_all(sql, "");
    let cleaned = JINJA_TAG.replace_all(&without_comments, "__jinja__");

    let Some(body_start) = find_last_top_level_select(&cleaned) else {
        return vec![];
    };
    let tail = &cleaned[body_start..];

    let Some(from_pos) = find_top_level_from(tail) else {
        return vec![];
    };

    split_top_level_commas(&tail[..from_pos])
        .iter()
        .filter_map(|raw| classify_select_item(raw.trim()))
        .collect()
}
/// Maps one trimmed select-list item to its output name.
///
/// Empty items yield `None`. Items that start with `(` are named by whatever follows their
/// last closing parenthesis. Everything else goes through `extract_column_name`, and an
/// empty result is reported as `None`.
fn classify_select_item(item: &str) -> Option<String> {
    match item {
        "" => None,
        expr if expr.starts_with('(') => extract_alias_after_paren(expr),
        expr => Some(extract_column_name(expr)).filter(|name| !name.is_empty()),
    }
}
/// Finds the last `SELECT` keyword that sits outside any parentheses and returns the byte
/// offset where its select list begins.
///
/// The offset points past the keyword, past any ASCII whitespace, and past an immediately
/// following `DISTINCT` keyword plus its trailing whitespace. Returns `None` when no
/// top-level `SELECT` exists.
fn find_last_top_level_select(s: &str) -> Option<usize> {
    let bytes = s.as_bytes();
    let len = bytes.len();
    let mut nesting: u32 = 0;
    let mut body_start: Option<usize> = None;

    for (idx, &byte) in bytes.iter().enumerate() {
        match byte {
            b'(' => nesting += 1,
            // Saturate so a stray `)` cannot push the depth below zero.
            b')' => nesting = nesting.saturating_sub(1),
            b's' | b'S' if nesting == 0 && check_keyword_at(bytes, idx, len, b"SELECT") => {
                // 6 == "SELECT".len(), 8 == "DISTINCT".len()
                let after_select = skip_whitespace(bytes, idx + 6, len);
                body_start = Some(if check_keyword_at(bytes, after_select, len, b"DISTINCT") {
                    skip_whitespace(bytes, after_select + 8, len)
                } else {
                    after_select
                });
            }
            _ => {}
        }
    }

    body_start
}
/// Reports whether `keyword` occurs at byte offset `i` as a whole word, compared
/// ASCII-case-insensitively.
///
/// "Whole word" means the byte before `i` (if any) and the byte after the keyword (if it is
/// still below `len`) both satisfy `is_word_boundary`. `len` is the scan limit the caller
/// uses for `bytes`.
fn check_keyword_at(bytes: &[u8], i: usize, len: usize, keyword: &[u8]) -> bool {
    let end = i + keyword.len();
    if end > len {
        return false;
    }

    let spelled = keyword
        .iter()
        .enumerate()
        .all(|(offset, expected)| bytes[i + offset].eq_ignore_ascii_case(expected));
    if !spelled {
        return false;
    }

    let starts_word = i == 0 || is_word_boundary(bytes[i - 1]);
    let ends_word = end >= len || is_word_boundary(bytes[end]);
    starts_word && ends_word
}
/// Returns the index of the first byte at or after `start` (and below `len`) that is not
/// ASCII whitespace.
///
/// If every byte in `start..len` is whitespace the result is `len`; if `start` is already
/// at or past `len`, `start` is returned unchanged.
fn skip_whitespace(bytes: &[u8], start: usize, len: usize) -> usize {
    (start..len)
        .find(|&idx| !bytes[idx].is_ascii_whitespace())
        .unwrap_or(start.max(len))
}
/// Returns `true` for any byte that cannot be part of an identifier-like word, i.e. anything
/// other than an ASCII letter, an ASCII digit, or `_`.
fn is_word_boundary(b: u8) -> bool {
    !matches!(b, b'_' | b'0'..=b'9' | b'a'..=b'z' | b'A'..=b'Z')
}
/// Reports whether the whole word `FROM` (any ASCII letter case) starts at byte offset `i`.
///
/// The byte before `i` (if any) and the byte after the keyword (if it is still below `len`)
/// must both satisfy `is_word_boundary`. `_s` is unused.
fn check_from_at(_s: &str, bytes: &[u8], i: usize, len: usize) -> bool {
    const FROM: &[u8; 4] = b"FROM";

    let end = i + FROM.len();
    if end > len {
        return false;
    }

    let spelled = FROM
        .iter()
        .enumerate()
        .all(|(offset, letter)| bytes[i + offset].eq_ignore_ascii_case(letter));
    if !spelled {
        return false;
    }

    let starts_word = i == 0 || is_word_boundary(bytes[i - 1]);
    let ends_word = end >= len || is_word_boundary(bytes[end]);
    starts_word && ends_word
}
/// Returns the byte offset of the first whole-word `FROM` that is outside any parentheses,
/// or `None` if there is none.
fn find_top_level_from(s: &str) -> Option<usize> {
    let bytes = s.as_bytes();
    let len = bytes.len();
    let mut nesting: u32 = 0;

    for (idx, &byte) in bytes.iter().enumerate() {
        match byte {
            b'(' => nesting += 1,
            // Saturate so a stray `)` cannot push the depth below zero.
            b')' => nesting = nesting.saturating_sub(1),
            b'f' | b'F' if nesting == 0 && check_from_at(s, bytes, idx, len) => {
                return Some(idx);
            }
            _ => {}
        }
    }

    None
}
/// Splits `s` on commas that are not enclosed in parentheses.
///
/// Pieces are returned verbatim (not trimmed). A piece between two separators is kept even
/// when empty; the final piece is dropped when it is empty or whitespace-only.
///
/// The nesting depth saturates at zero, matching the other scanners in this file: an
/// unmatched `)` (for example one inside a string literal) must not stop later top-level
/// commas from being recognised as separators.
fn split_top_level_commas(s: &str) -> Vec<String> {
    let mut items = Vec::new();
    let mut current = String::new();
    let mut depth: u32 = 0;

    for ch in s.chars() {
        match ch {
            // Hand the finished piece over without copying it.
            ',' if depth == 0 => items.push(std::mem::take(&mut current)),
            '(' => {
                depth += 1;
                current.push(ch);
            }
            ')' => {
                depth = depth.saturating_sub(1);
                current.push(ch);
            }
            _ => current.push(ch),
        }
    }

    if !current.trim().is_empty() {
        items.push(current);
    }
    items
}
/// Returns the alias written after the last `)` of a parenthesised select item.
///
/// The text after the last `)` is trimmed; a leading `AS` (any case) followed by ASCII
/// whitespace is dropped, and the remainder is passed through `clean_identifier`.
/// Returns `None` when the item has no `)` or nothing follows it.
fn extract_alias_after_paren(item: &str) -> Option<String> {
    let (_, tail) = item.rsplit_once(')')?;
    let tail = tail.trim();
    if tail.is_empty() {
        return None;
    }

    let alias = match tail.as_bytes() {
        [b'a' | b'A', b's' | b'S', sep, ..] if sep.is_ascii_whitespace() => tail[2..].trim(),
        _ => tail,
    };

    (!alias.is_empty()).then(|| clean_identifier(alias))
}
/// Derives the output name of a non-parenthesised select item.
///
/// If the item has a top-level `AS`, the text after the last one is the name. Otherwise the
/// last whitespace-separated token is used, with any qualifier up to its last `.` removed.
/// The result is passed through `clean_identifier`.
fn extract_column_name(item: &str) -> String {
    let item = item.trim();
    match find_last_as_alias(item) {
        Some(alias) => clean_identifier(&alias),
        None => {
            let tail = item.split_whitespace().last().unwrap_or(item);
            let bare = tail.rsplit_once('.').map_or(tail, |(_, name)| name);
            clean_identifier(bare)
        }
    }
}
/// Checks whether the two bytes after position `i` spell `AS` (any case) and are followed by
/// a space, tab, line feed or carriage return.
///
/// `i` is expected to be the index of the whitespace byte preceding the keyword (the caller
/// in this file only invokes it on such bytes). On a match, returns the offset just past the
/// whitespace that follows `AS`, i.e. `i + 4`. `_item` is unused.
fn is_as_keyword_at(_item: &str, bytes: &[u8], i: usize, len: usize) -> Option<usize> {
    // Need three more bytes after `i`: 'A', 'S', and the trailing whitespace.
    if i + 3 >= len {
        return None;
    }

    let hit = bytes[i + 1].eq_ignore_ascii_case(&b'a')
        && bytes[i + 2].eq_ignore_ascii_case(&b's')
        && matches!(bytes[i + 3], b' ' | b'\t' | b'\n' | b'\r');
    hit.then_some(i + 4)
}
/// Returns the trimmed text following the last whitespace-delimited `AS` keyword that is
/// outside parentheses, or `None` if the item contains no such keyword.
fn find_last_as_alias(item: &str) -> Option<String> {
    let bytes = item.as_bytes();
    let len = bytes.len();
    let mut nesting: u32 = 0;
    let mut alias_start: Option<usize> = None;

    for (idx, &byte) in bytes.iter().enumerate() {
        match byte {
            b'(' => nesting += 1,
            // An unmatched `)` leaves the depth at zero.
            b')' => nesting = nesting.saturating_sub(1),
            b' ' | b'\t' | b'\n' | b'\r' if nesting == 0 => {
                // Later matches overwrite earlier ones so the last `AS` wins.
                if let Some(start) = is_as_keyword_at(item, bytes, idx, len) {
                    alias_start = Some(start);
                }
            }
            _ => {}
        }
    }

    alias_start.map(|start| item[start..].trim().to_string())
}
/// Normalises an identifier: trims surrounding whitespace, then strips any leading/trailing
/// backticks, then any leading/trailing double quotes (in that order).
fn clean_identifier(s: &str) -> String {
    s.trim().trim_matches('`').trim_matches('"').to_owned()
}
/// Extracts output column names from an already-parsed SQL expression.
///
/// Works on a clone of `expr`, first passing it through
/// `polyglot_sql::lineage::expand_cte_stars` together with `schema` (presumably this
/// rewrites `select *` over CTEs into explicit columns — the tests in this file rely on
/// that). If the result is a `Select`, each projection contributes:
/// - an `Alias`: its alias name,
/// - a `Column`: its column name, unless that name is `*`,
/// - an `Identifier`: its name.
///
/// Every other projection kind is skipped. A non-`Select` expression yields an empty vector.
#[cfg(feature = "column-lineage")]
pub fn extract_select_columns_from_expr(
    expr: &Expression,
    schema: Option<&dyn polyglot_sql::Schema>,
) -> Vec<String> {
    let mut expanded = expr.clone();
    polyglot_sql::lineage::expand_cte_stars(&mut expanded, schema);

    let Expression::Select(select) = &expanded else {
        return vec![];
    };

    let mut names: Vec<String> = Vec::new();
    for projection in select.expressions.iter() {
        match projection {
            Expression::Alias(aliased) => names.push(aliased.alias.name.clone()),
            Expression::Column(column) => {
                // A bare `*` column carries no usable name.
                if column.name.name != "*" {
                    names.push(column.name.name.clone());
                }
            }
            Expression::Identifier(ident) => names.push(ident.name.clone()),
            _ => {}
        }
    }
    names
}
// Unit tests for the SELECT-list extraction helpers defined in this file.
#[cfg(test)]
mod tests {
use super::*;
// AST-based path: `select *` over a CTE is expected to come back as that CTE's own
// output columns.
#[cfg(feature = "column-lineage")]
#[test]
fn test_extract_from_expr_cte_star() {
let sql = r#"with
source as (select * from "raw"."raw_orders"),
renamed as (
select id as order_id, customer as customer_id, ordered_at
from source
)
select * from renamed"#;
let expr = polyglot_sql::parse_one(sql, polyglot_sql::DialectType::Generic).unwrap();
let cols = extract_select_columns_from_expr(&expr, None);
assert_eq!(cols, vec!["order_id", "customer_id", "ordered_at"]);
}
// Same as above with cast / function-call projections: the `renamed` CTE has ten
// projections, four of which are spot-checked by name.
#[cfg(feature = "column-lineage")]
#[test]
fn test_extract_from_expr_cte_star_with_cast() {
let sql = r#"with
source as (
select * from "jaffle_shop"."raw"."raw_orders"
),
renamed as (
select
id as order_id,
store_id as location_id,
customer as customer_id,
subtotal as subtotal_cents,
tax_paid as tax_paid_cents,
order_total as order_total_cents,
(subtotal / 100)::numeric(16, 2) as subtotal,
(tax_paid / 100)::numeric(16, 2) as tax_paid,
(order_total / 100)::numeric(16, 2) as order_total,
date_trunc('day', ordered_at) as ordered_at
from source
)
select * from renamed"#;
let expr = polyglot_sql::parse_one(sql, polyglot_sql::DialectType::Generic).unwrap();
let cols = extract_select_columns_from_expr(&expr, None);
assert!(cols.contains(&"order_id".to_string()), "cols: {:?}", cols);
assert!(
cols.contains(&"customer_id".to_string()),
"cols: {:?}",
cols
);
assert!(cols.contains(&"ordered_at".to_string()), "cols: {:?}", cols);
assert!(
cols.contains(&"order_total".to_string()),
"cols: {:?}",
cols
);
assert_eq!(cols.len(), 10, "cols: {:?}", cols);
}
// Text-scan path (`extract_select_columns`): basic shapes.
#[test]
fn test_simple_select() {
let sql = "SELECT col1, col2 FROM my_table";
let cols = extract_select_columns(sql);
assert_eq!(cols, vec!["col1", "col2"]);
}
#[test]
fn test_select_with_aliases() {
let sql = "SELECT col1 AS alias1, col2 as alias2 FROM my_table";
let cols = extract_select_columns(sql);
assert_eq!(cols, vec!["alias1", "alias2"]);
}
#[test]
fn test_select_with_table_prefixes() {
let sql = "SELECT t.col1, t.col2 FROM my_table t";
let cols = extract_select_columns(sql);
assert_eq!(cols, vec!["col1", "col2"]);
}
// The text scan does not expand `*`; it is returned as a literal name.
#[test]
fn test_select_star() {
let sql = "SELECT * FROM my_table";
let cols = extract_select_columns(sql);
assert_eq!(cols, vec!["*"]);
}
#[test]
fn test_select_distinct() {
let sql = "SELECT DISTINCT col1, col2 FROM my_table";
let cols = extract_select_columns(sql);
assert_eq!(cols, vec!["col1", "col2"]);
}
// Jinja tags are replaced by the `__jinja__` placeholder before scanning, so a tag used
// as a select item is reported under that name.
#[test]
fn test_select_with_jinja() {
let sql = r#"
{{ config(materialized='table') }}
SELECT
order_id,
{{ dbt_utils.star(from=ref('stg_orders')) }},
customer_id
FROM {{ ref('stg_orders') }}
"#;
let cols = extract_select_columns(sql);
assert_eq!(cols, vec!["order_id", "__jinja__", "customer_id"]);
}
#[test]
fn test_multiline_select() {
let sql = r#"
SELECT
order_id,
customer_id,
order_date,
status
FROM orders
"#;
let cols = extract_select_columns(sql);
assert_eq!(
cols,
vec!["order_id", "customer_id", "order_date", "status"]
);
}
// Only the last SELECT outside parentheses is used, so SELECTs inside CTE bodies are ignored.
#[test]
fn test_cte_gets_outer_select() {
let sql = r#"
WITH cte AS (
SELECT inner_col1, inner_col2 FROM raw_table
)
SELECT outer_col1, outer_col2 FROM cte
"#;
let cols = extract_select_columns(sql);
assert_eq!(cols, vec!["outer_col1", "outer_col2"]);
}
#[test]
fn test_multiple_ctes_gets_final_select() {
let sql = r#"
WITH cte1 AS (
SELECT * FROM raw_table
),
cte2 AS (
SELECT a, b FROM cte1
)
SELECT
onramp_name,
count(distinct client_id) as total_known_clients,
sum(total_deals) as total_deals
FROM cte2
GROUP BY 1
"#;
let cols = extract_select_columns(sql);
assert_eq!(
cols,
vec!["onramp_name", "total_known_clients", "total_deals"]
);
}
#[test]
fn test_select_with_function() {
let sql = "SELECT COUNT(*) AS total, SUM(amount) AS total_amount FROM orders";
let cols = extract_select_columns(sql);
assert_eq!(cols, vec!["total", "total_amount"]);
}
#[test]
fn test_select_table_prefix_with_alias() {
let sql = "SELECT t.col1 AS alias1, t.col2 FROM my_table t";
let cols = extract_select_columns(sql);
assert_eq!(cols, vec!["alias1", "col2"]);
}
#[test]
fn test_no_select() {
let sql = "INSERT INTO my_table VALUES (1, 2, 3)";
let cols = extract_select_columns(sql);
assert!(cols.is_empty());
}
#[test]
fn test_select_with_jinja_comments() {
let sql = r#"
{# Select all order columns #}
SELECT order_id, status FROM orders
"#;
let cols = extract_select_columns(sql);
assert_eq!(cols, vec!["order_id", "status"]);
}
// The `AS` inside `CAST(...)` is nested in parentheses and must not be taken as the alias.
#[test]
fn test_select_with_cast() {
let sql = "SELECT CAST(order_id AS INTEGER) AS order_id, status FROM orders";
let cols = extract_select_columns(sql);
assert_eq!(cols, vec!["order_id", "status"]);
}
#[test]
fn test_select_with_subquery_alias() {
let sql = "SELECT (SELECT MAX(id) FROM t) AS max_id, name FROM users";
let cols = extract_select_columns(sql);
assert_eq!(cols, vec!["max_id", "name"]);
}
#[test]
fn test_typical_dbt_model() {
let sql = r#"
{{ config(materialized='view') }}
SELECT
order_id,
customer_id,
order_date,
status,
amount
FROM {{ ref('stg_orders') }}
"#;
let cols = extract_select_columns(sql);
assert_eq!(
cols,
vec!["order_id", "customer_id", "order_date", "status", "amount"]
);
}
#[test]
fn test_select_case_insensitive() {
let sql = "select col1, col2 from my_table";
let cols = extract_select_columns(sql);
assert_eq!(cols, vec!["col1", "col2"]);
}
// The scanner slices the SQL by byte offsets; the following inputs contain multi-byte
// UTF-8 so that any slice landing inside a character would panic.
#[test]
fn test_select_with_multibyte_utf8_comment() {
let sql = r#"SELECT
case
when flag = true then false -- 日本語コメント
else flag
end as flag
FROM my_table"#;
let cols = extract_select_columns(sql);
assert_eq!(cols, vec!["flag"]);
}
#[test]
fn test_select_with_multibyte_utf8_string_literal() {
let sql = "SELECT '中文字符' AS label, col1 FROM my_table";
let cols = extract_select_columns(sql);
assert_eq!(cols, vec!["label", "col1"]);
}
// The next two tests only assert a non-empty result, not the exact names.
#[test]
fn test_select_with_korean_comment_no_panic() {
let sql = "SELECT col1, col2 -- 한국어 코멘트\nFROM my_table";
let cols = extract_select_columns(sql);
assert!(!cols.is_empty());
}
#[test]
fn test_select_with_emoji_comment_no_panic() {
let sql = "SELECT col1 -- 🎉 celebration\nFROM my_table";
let cols = extract_select_columns(sql);
assert!(!cols.is_empty());
}
#[test]
fn test_select_with_backtick_identifiers() {
let sql = "SELECT `col1`, `col2` FROM my_table";
let cols = extract_select_columns(sql);
assert_eq!(cols, vec!["col1", "col2"]);
}
// Direct tests of `extract_alias_after_paren`.
#[test]
fn test_extract_alias_after_paren_no_alias() {
let result = extract_alias_after_paren("(SELECT 1)");
assert!(result.is_none());
}
#[test]
fn test_extract_alias_after_paren_bare_alias() {
let result = extract_alias_after_paren("(SELECT 1) my_alias");
assert_eq!(result, Some("my_alias".to_string()));
}
#[test]
fn test_extract_alias_after_paren_as_alias() {
let result = extract_alias_after_paren("(SELECT 1) AS my_alias");
assert_eq!(result, Some("my_alias".to_string()));
}
#[test]
fn test_extract_alias_after_paren_no_paren() {
let result = extract_alias_after_paren("SELECT 1");
assert!(result.is_none());
}
}