Skip to main content

dlin_core/parser/
columns.rs

1#[cfg(feature = "column-lineage")]
2use polyglot_sql::expressions::Expression;
3use regex::Regex;
4use std::sync::LazyLock;
5
6/// Regex to strip Jinja tags {{ ... }} and {%- ... -%} etc.
7static JINJA_TAG: LazyLock<Regex> =
8    LazyLock::new(|| Regex::new(r"\{\{-?[\s\S]*?-?\}\}|\{%-?[\s\S]*?-?%\}").unwrap());
9
10/// Regex to strip Jinja comments {# ... #}
11static JINJA_COMMENT: LazyLock<Regex> = LazyLock::new(|| Regex::new(r"\{#[\s\S]*?#\}").unwrap());
12
13/// Extract column names from the outermost SELECT clause of a SQL string.
14///
15/// This is a best-effort regex-based extraction, not a full SQL parser.
16/// It handles:
17/// - `SELECT col1, col2 FROM ...` -> `["col1", "col2"]`
18/// - `SELECT t.col1 AS alias1` -> `["alias1"]`
19/// - `SELECT col1 as alias1` -> `["alias1"]`
20/// - `SELECT *` -> `["*"]`
21/// - `SELECT DISTINCT col1, col2` -> `["col1", "col2"]`
22/// - Jinja tags are stripped before parsing
23/// - Subqueries in parentheses are skipped
24/// - Multiline SELECT clauses are handled
25pub fn extract_select_columns(sql: &str) -> Vec<String> {
26    // Strip Jinja comments and tags
27    let cleaned = JINJA_COMMENT.replace_all(sql, "");
28    let cleaned = JINJA_TAG.replace_all(&cleaned, "__jinja__");
29
30    // Find the last top-level SELECT keyword (not inside parentheses).
31    // This handles CTEs correctly: `WITH cte AS (SELECT ... ) SELECT ...`
32    // where the CTE's SELECT is inside parentheses.
33    let select_end = match find_last_top_level_select(&cleaned) {
34        Some(end) => end,
35        None => return vec![],
36    };
37
38    // Find the first top-level FROM after the SELECT (not inside parentheses)
39    let after_select = &cleaned[select_end..];
40    let select_body = match find_top_level_from(after_select) {
41        Some(pos) => &after_select[..pos],
42        None => return vec![],
43    };
44
45    // Split on commas, but not commas inside parentheses
46    let items = split_top_level_commas(select_body);
47
48    items
49        .iter()
50        .filter_map(|item| classify_select_item(item.trim()))
51        .collect()
52}
53
54/// Classify a single SELECT item and return its column name, if any.
55fn classify_select_item(item: &str) -> Option<String> {
56    if item.is_empty() {
57        return None;
58    }
59
60    // Items starting with '(' are subqueries; check for alias after closing paren
61    if item.starts_with('(') {
62        return extract_alias_after_paren(item);
63    }
64
65    let col = extract_column_name(item);
66    if col.is_empty() { None } else { Some(col) }
67}
68
69/// Find the byte offset just after the last top-level SELECT keyword (not inside parentheses).
70/// Also skips DISTINCT if present. Returns the position where the column list begins.
71fn find_last_top_level_select(s: &str) -> Option<usize> {
72    let bytes = s.as_bytes();
73    let len = bytes.len();
74    let mut depth: u32 = 0;
75    let mut last_select_end: Option<usize> = None;
76    let mut i = 0;
77
78    while i < len {
79        match bytes[i] {
80            b'(' => depth += 1,
81            b')' => {
82                depth = depth.saturating_sub(1);
83            }
84            b's' | b'S' if depth == 0 && check_keyword_at(bytes, i, len, b"SELECT") => {
85                let end = i + 6;
86                // Skip optional DISTINCT
87                let after = skip_whitespace(bytes, end, len);
88                if check_keyword_at(bytes, after, len, b"DISTINCT") {
89                    let after_distinct = skip_whitespace(bytes, after + 8, len);
90                    last_select_end = Some(after_distinct);
91                } else {
92                    last_select_end = Some(after);
93                }
94            }
95            _ => {}
96        }
97        i += 1;
98    }
99
100    last_select_end
101}
102
103/// Check if the bytes at position `i` match the given keyword (case-insensitive)
104/// with word boundaries on both sides.
105fn check_keyword_at(bytes: &[u8], i: usize, len: usize, keyword: &[u8]) -> bool {
106    let klen = keyword.len();
107    if i + klen > len {
108        return false;
109    }
110    for j in 0..klen {
111        if !bytes[i + j].eq_ignore_ascii_case(&keyword[j]) {
112            return false;
113        }
114    }
115    let before_ok = i == 0 || is_word_boundary(bytes[i - 1]);
116    let after_ok = i + klen >= len || is_word_boundary(bytes[i + klen]);
117    before_ok && after_ok
118}
119
/// Return the first index in `start..len` whose byte is not ASCII whitespace.
///
/// If every byte in that range is whitespace, returns `len`. If the range is
/// empty (`start >= len`), returns `start` unchanged.
fn skip_whitespace(bytes: &[u8], start: usize, len: usize) -> usize {
    (start..len)
        .find(|&idx| !bytes[idx].is_ascii_whitespace())
        .unwrap_or(start.max(len))
}
128
/// Return `true` when `b` cannot be part of an identifier word.
///
/// ASCII letters, ASCII digits and `_` are word characters. Every other byte
/// counts as a boundary, including non-ASCII UTF-8 bytes.
fn is_word_boundary(b: u8) -> bool {
    !matches!(b, b'a'..=b'z' | b'A'..=b'Z' | b'0'..=b'9' | b'_')
}
133
134/// Check if position `i` in string `s` starts a top-level `FROM` keyword with proper boundaries.
135/// Uses byte-level comparison to avoid panics on multi-byte UTF-8 characters.
136fn check_from_at(_s: &str, bytes: &[u8], i: usize, len: usize) -> bool {
137    if i + 4 > len {
138        return false;
139    }
140    let from_match = matches!(bytes[i], b'f' | b'F')
141        && matches!(bytes[i + 1], b'r' | b'R')
142        && matches!(bytes[i + 2], b'o' | b'O')
143        && matches!(bytes[i + 3], b'm' | b'M');
144    if !from_match {
145        return false;
146    }
147    let before_ok = i == 0 || is_word_boundary(bytes[i - 1]);
148    let after_ok = i + 4 >= len || is_word_boundary(bytes[i + 4]);
149    before_ok && after_ok
150}
151
152/// Find the position of the first top-level `FROM` keyword (not inside parentheses).
153/// Returns the byte offset of the start of `FROM` relative to the input string.
154fn find_top_level_from(s: &str) -> Option<usize> {
155    let bytes = s.as_bytes();
156    let len = bytes.len();
157    let mut depth: u32 = 0;
158    let mut i = 0;
159
160    while i < len {
161        match bytes[i] {
162            b'(' => depth += 1,
163            b')' => {
164                depth = depth.saturating_sub(1);
165            }
166            b'f' | b'F' if depth == 0 && check_from_at(s, bytes, i, len) => {
167                return Some(i);
168            }
169            _ => {}
170        }
171        i += 1;
172    }
173
174    None
175}
176
/// Split a string on commas that are not inside parentheses.
///
/// Pieces are returned untrimmed and in source order. Empty pieces between
/// consecutive commas are kept, but a final piece that is empty or
/// whitespace-only is dropped.
///
/// A stray `)` with no matching `(` (for example one inside the literal `')'`)
/// is ignored rather than driving the nesting depth negative. A negative depth
/// would stop every later comma from splitting. The floor at zero matches the
/// saturating depth tracking of the other scanners in this module.
fn split_top_level_commas(s: &str) -> Vec<String> {
    let mut items = Vec::new();
    let mut current = String::new();
    let mut depth: u32 = 0;

    for ch in s.chars() {
        match ch {
            ',' if depth == 0 => {
                // Hand the finished piece over without cloning it.
                items.push(std::mem::take(&mut current));
                continue;
            }
            '(' => depth += 1,
            ')' => depth = depth.saturating_sub(1),
            _ => {}
        }
        current.push(ch);
    }

    if !current.trim().is_empty() {
        items.push(current);
    }

    items
}
209
210/// Extract the alias after a closing parenthesis, e.g., `(SELECT ...) AS alias`
211fn extract_alias_after_paren(item: &str) -> Option<String> {
212    // Find the last closing paren
213    let close = item.rfind(')')?;
214    let after = item[close + 1..].trim();
215    if after.is_empty() {
216        return None;
217    }
218    // Strip leading AS (case-insensitive) using byte comparison to avoid
219    // panics on multi-byte UTF-8 characters
220    let after = if after.len() >= 3
221        && matches!(after.as_bytes()[0], b'a' | b'A')
222        && matches!(after.as_bytes()[1], b's' | b'S')
223        && after.as_bytes()[2].is_ascii_whitespace()
224    {
225        after[2..].trim()
226    } else {
227        after
228    };
229    if after.is_empty() {
230        None
231    } else {
232        Some(clean_identifier(after))
233    }
234}
235
236/// Extract the effective column name from a single SELECT item.
237///
238/// Rules:
239/// 1. If `AS alias` is present, return the alias.
240/// 2. If `table.column`, return column.
241/// 3. Otherwise return the token itself (e.g., `*`, `col1`).
242fn extract_column_name(item: &str) -> String {
243    let item = item.trim();
244
245    // Check for AS alias (case-insensitive) - look for last " AS " or " as "
246    // We search from the end to handle expressions like `CAST(x AS int) AS col`
247    if let Some(alias) = find_last_as_alias(item) {
248        return clean_identifier(&alias);
249    }
250
251    // No alias; take the last token (handles `table.col` and bare `col`)
252    let last_token = item.split_whitespace().last().unwrap_or(item);
253
254    // Handle table.column
255    if let Some(pos) = last_token.rfind('.') {
256        return clean_identifier(&last_token[pos + 1..]);
257    }
258
259    clean_identifier(last_token)
260}
261
/// Test whether the whitespace byte at `i` is followed by an `AS` keyword.
///
/// Matches `AS` in any ASCII case when it is immediately followed by a space,
/// tab, newline or carriage return. On a match, returns the offset just past
/// that trailing whitespace byte (`i + 4`). Byte-level comparison avoids
/// panics on multi-byte UTF-8. The `_item` parameter is unused.
fn is_as_keyword_at(_item: &str, bytes: &[u8], i: usize, len: usize) -> Option<usize> {
    let hit = i + 3 < len
        && bytes[i + 1].eq_ignore_ascii_case(&b'a')
        && bytes[i + 2].eq_ignore_ascii_case(&b's')
        && matches!(bytes[i + 3], b' ' | b'\t' | b'\n' | b'\r');
    hit.then_some(i + 4)
}
276
277/// Find the alias from the last ` AS ` keyword that is not inside parentheses.
278fn find_last_as_alias(item: &str) -> Option<String> {
279    let bytes = item.as_bytes();
280    let len = bytes.len();
281    let mut depth = 0;
282    let mut last_as_pos: Option<usize> = None;
283
284    let mut i = 0;
285    while i < len {
286        match bytes[i] {
287            b'(' => depth += 1,
288            b')' if depth > 0 => {
289                depth -= 1;
290            }
291            b' ' | b'\t' | b'\n' | b'\r' if depth == 0 => {
292                if let Some(pos) = is_as_keyword_at(item, bytes, i, len) {
293                    last_as_pos = Some(pos);
294                }
295            }
296            _ => {}
297        }
298        i += 1;
299    }
300
301    last_as_pos.map(|pos| item[pos..].trim().to_string())
302}
303
/// Normalise an identifier for output.
///
/// Trims surrounding whitespace, then strips backticks from both ends, then
/// strips double quotes from both ends, and returns the result as an owned string.
fn clean_identifier(s: &str) -> String {
    s.trim().trim_matches('`').trim_matches('"').to_owned()
}
311
/// Extract output column names from a parsed polyglot-sql Expression.
///
/// The expression is cloned, and `expand_cte_stars` runs on the copy (with the
/// optional `schema`), so `SELECT *` over a CTE can be resolved without
/// touching the caller's expression. Names are then read from the projection
/// list of the outermost SELECT:
/// - an aliased item yields its alias;
/// - a column reference yields its name, unless it is an unresolved qualified star;
/// - a bare identifier yields its name.
///
/// Unresolved `*` items and any other unaliased expressions are dropped.
/// Returns an empty `Vec` if the expanded expression is not a SELECT.
#[cfg(feature = "column-lineage")]
pub fn extract_select_columns_from_expr(
    expr: &Expression,
    schema: Option<&dyn polyglot_sql::Schema>,
) -> Vec<String> {
    let mut expanded = expr.clone();
    polyglot_sql::lineage::expand_cte_stars(&mut expanded, schema);

    let Expression::Select(select) = &expanded else {
        return vec![];
    };

    let mut names = Vec::new();
    for projection in select.expressions.iter() {
        let name = match projection {
            Expression::Alias(alias) => &alias.alias.name,
            Expression::Column(column) if column.name.name != "*" => &column.name.name,
            Expression::Identifier(ident) => &ident.name,
            // Stars (bare or qualified) that expansion could not resolve, and
            // expressions without an alias, have no usable output name.
            _ => continue,
        };
        names.push(name.clone());
    }
    names
}
346
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for the regex-based extractor under test.
    fn columns(sql: &str) -> Vec<String> {
        extract_select_columns(sql)
    }

    // ----- AST-based extraction (polyglot-sql) -----

    #[cfg(feature = "column-lineage")]
    #[test]
    fn test_extract_from_expr_cte_star() {
        let sql = r#"with
source as (select * from "raw"."raw_orders"),
renamed as (
    select id as order_id, customer as customer_id, ordered_at
    from source
)
select * from renamed"#;
        let parsed = polyglot_sql::parse_one(sql, polyglot_sql::DialectType::Generic).unwrap();
        assert_eq!(
            extract_select_columns_from_expr(&parsed, None),
            ["order_id", "customer_id", "ordered_at"]
        );
    }

    #[cfg(feature = "column-lineage")]
    #[test]
    fn test_extract_from_expr_cte_star_with_cast() {
        // Realistic dbt stg_orders pattern with ::numeric cast
        let sql = r#"with
source as (
    select * from "jaffle_shop"."raw"."raw_orders"
),
renamed as (
    select
        id as order_id,
        store_id as location_id,
        customer as customer_id,
        subtotal as subtotal_cents,
        tax_paid as tax_paid_cents,
        order_total as order_total_cents,
        (subtotal / 100)::numeric(16, 2) as subtotal,
        (tax_paid / 100)::numeric(16, 2) as tax_paid,
        (order_total / 100)::numeric(16, 2) as order_total,
        date_trunc('day', ordered_at) as ordered_at
    from source
)
select * from renamed"#;
        let parsed = polyglot_sql::parse_one(sql, polyglot_sql::DialectType::Generic).unwrap();
        let cols = extract_select_columns_from_expr(&parsed, None);
        for expected in ["order_id", "customer_id", "ordered_at", "order_total"] {
            assert!(cols.iter().any(|c| c == expected), "cols: {:?}", cols);
        }
        assert_eq!(cols.len(), 10, "cols: {:?}", cols);
    }

    // ----- regex-based extraction -----

    #[test]
    fn test_simple_select() {
        assert_eq!(columns("SELECT col1, col2 FROM my_table"), ["col1", "col2"]);
    }

    #[test]
    fn test_select_with_aliases() {
        assert_eq!(
            columns("SELECT col1 AS alias1, col2 as alias2 FROM my_table"),
            ["alias1", "alias2"]
        );
    }

    #[test]
    fn test_select_with_table_prefixes() {
        assert_eq!(columns("SELECT t.col1, t.col2 FROM my_table t"), ["col1", "col2"]);
    }

    #[test]
    fn test_select_star() {
        assert_eq!(columns("SELECT * FROM my_table"), ["*"]);
    }

    #[test]
    fn test_select_distinct() {
        assert_eq!(columns("SELECT DISTINCT col1, col2 FROM my_table"), ["col1", "col2"]);
    }

    #[test]
    fn test_select_with_jinja() {
        // A Jinja expression in the column list survives as a `__jinja__` placeholder.
        let sql = r#"
            {{ config(materialized='table') }}

            SELECT
                order_id,
                {{ dbt_utils.star(from=ref('stg_orders')) }},
                customer_id
            FROM {{ ref('stg_orders') }}
        "#;
        assert_eq!(columns(sql), ["order_id", "__jinja__", "customer_id"]);
    }

    #[test]
    fn test_multiline_select() {
        let sql = r#"
            SELECT
                order_id,
                customer_id,
                order_date,
                status
            FROM orders
        "#;
        assert_eq!(
            columns(sql),
            ["order_id", "customer_id", "order_date", "status"]
        );
    }

    #[test]
    fn test_cte_gets_outer_select() {
        // The CTE's SELECT is parenthesized, so only the final SELECT counts.
        let sql = r#"
            WITH cte AS (
                SELECT inner_col1, inner_col2 FROM raw_table
            )
            SELECT outer_col1, outer_col2 FROM cte
        "#;
        assert_eq!(columns(sql), ["outer_col1", "outer_col2"]);
    }

    #[test]
    fn test_multiple_ctes_gets_final_select() {
        let sql = r#"
            WITH cte1 AS (
                SELECT * FROM raw_table
            ),
            cte2 AS (
                SELECT a, b FROM cte1
            )
            SELECT
                onramp_name,
                count(distinct client_id) as total_known_clients,
                sum(total_deals) as total_deals
            FROM cte2
            GROUP BY 1
        "#;
        assert_eq!(
            columns(sql),
            ["onramp_name", "total_known_clients", "total_deals"]
        );
    }

    #[test]
    fn test_select_with_function() {
        assert_eq!(
            columns("SELECT COUNT(*) AS total, SUM(amount) AS total_amount FROM orders"),
            ["total", "total_amount"]
        );
    }

    #[test]
    fn test_select_table_prefix_with_alias() {
        assert_eq!(
            columns("SELECT t.col1 AS alias1, t.col2 FROM my_table t"),
            ["alias1", "col2"]
        );
    }

    #[test]
    fn test_no_select() {
        assert!(columns("INSERT INTO my_table VALUES (1, 2, 3)").is_empty());
    }

    #[test]
    fn test_select_with_jinja_comments() {
        let sql = r#"
            {# Select all order columns #}
            SELECT order_id, status FROM orders
        "#;
        assert_eq!(columns(sql), ["order_id", "status"]);
    }

    #[test]
    fn test_select_with_cast() {
        // The AS inside CAST(...) is nested and must not be taken as the alias.
        assert_eq!(
            columns("SELECT CAST(order_id AS INTEGER) AS order_id, status FROM orders"),
            ["order_id", "status"]
        );
    }

    #[test]
    fn test_select_with_subquery_alias() {
        assert_eq!(
            columns("SELECT (SELECT MAX(id) FROM t) AS max_id, name FROM users"),
            ["max_id", "name"]
        );
    }

    #[test]
    fn test_typical_dbt_model() {
        let sql = r#"
            {{ config(materialized='view') }}

            SELECT
                order_id,
                customer_id,
                order_date,
                status,
                amount
            FROM {{ ref('stg_orders') }}
        "#;
        assert_eq!(
            columns(sql),
            ["order_id", "customer_id", "order_date", "status", "amount"]
        );
    }

    #[test]
    fn test_select_case_insensitive() {
        assert_eq!(columns("select col1, col2 from my_table"), ["col1", "col2"]);
    }

    #[test]
    fn test_select_with_multibyte_utf8_comment() {
        // GitHub Issue #1: panic on multi-byte UTF-8 characters
        let sql = r#"SELECT
    case
      when flag = true then false -- 日本語コメント
      else flag
    end as flag
FROM my_table"#;
        assert_eq!(columns(sql), ["flag"]);
    }

    #[test]
    fn test_select_with_multibyte_utf8_string_literal() {
        assert_eq!(
            columns("SELECT '中文字符' AS label, col1 FROM my_table"),
            ["label", "col1"]
        );
    }

    #[test]
    fn test_select_with_korean_comment_no_panic() {
        // Regression check: Korean text in a trailing `--` comment must not panic.
        let sql = "SELECT col1, col2 -- 한국어 코멘트\nFROM my_table";
        assert!(!columns(sql).is_empty());
    }

    #[test]
    fn test_select_with_emoji_comment_no_panic() {
        // Regression check: an emoji in a trailing `--` comment must not panic.
        let sql = "SELECT col1 -- 🎉 celebration\nFROM my_table";
        assert!(!columns(sql).is_empty());
    }

    #[test]
    fn test_select_with_backtick_identifiers() {
        assert_eq!(columns("SELECT `col1`, `col2` FROM my_table"), ["col1", "col2"]);
    }

    #[test]
    fn test_extract_alias_after_paren_no_alias() {
        // Subquery with no alias after the closing paren
        assert_eq!(extract_alias_after_paren("(SELECT 1)"), None);
    }

    #[test]
    fn test_extract_alias_after_paren_bare_alias() {
        // Subquery with bare alias (no AS keyword)
        assert_eq!(
            extract_alias_after_paren("(SELECT 1) my_alias").as_deref(),
            Some("my_alias")
        );
    }

    #[test]
    fn test_extract_alias_after_paren_as_alias() {
        // Subquery with AS alias
        assert_eq!(
            extract_alias_after_paren("(SELECT 1) AS my_alias").as_deref(),
            Some("my_alias")
        );
    }

    #[test]
    fn test_extract_alias_after_paren_no_paren() {
        // No closing paren at all
        assert_eq!(extract_alias_after_paren("SELECT 1"), None);
    }
}