Skip to main content

dlin_core/parser/
columns.rs

1use regex::Regex;
2use std::sync::LazyLock;
3
4/// Regex to strip Jinja tags {{ ... }} and {%- ... -%} etc.
5static JINJA_TAG: LazyLock<Regex> =
6    LazyLock::new(|| Regex::new(r"\{\{-?[\s\S]*?-?\}\}|\{%-?[\s\S]*?-?%\}").unwrap());
7
8/// Regex to strip Jinja comments {# ... #}
9static JINJA_COMMENT: LazyLock<Regex> = LazyLock::new(|| Regex::new(r"\{#[\s\S]*?#\}").unwrap());
10
11/// Extract column names from the outermost SELECT clause of a SQL string.
12///
13/// This is a best-effort regex-based extraction, not a full SQL parser.
14/// It handles:
15/// - `SELECT col1, col2 FROM ...` -> `["col1", "col2"]`
16/// - `SELECT t.col1 AS alias1` -> `["alias1"]`
17/// - `SELECT col1 as alias1` -> `["alias1"]`
18/// - `SELECT *` -> `["*"]`
19/// - `SELECT DISTINCT col1, col2` -> `["col1", "col2"]`
20/// - Jinja tags are stripped before parsing
21/// - Subqueries in parentheses are skipped
22/// - Multiline SELECT clauses are handled
23pub fn extract_select_columns(sql: &str) -> Vec<String> {
24    // Strip Jinja comments and tags
25    let cleaned = JINJA_COMMENT.replace_all(sql, "");
26    let cleaned = JINJA_TAG.replace_all(&cleaned, "__jinja__");
27
28    // Find the last top-level SELECT keyword (not inside parentheses).
29    // This handles CTEs correctly: `WITH cte AS (SELECT ... ) SELECT ...`
30    // where the CTE's SELECT is inside parentheses.
31    let select_end = match find_last_top_level_select(&cleaned) {
32        Some(end) => end,
33        None => return vec![],
34    };
35
36    // Find the first top-level FROM after the SELECT (not inside parentheses)
37    let after_select = &cleaned[select_end..];
38    let select_body = match find_top_level_from(after_select) {
39        Some(pos) => &after_select[..pos],
40        None => return vec![],
41    };
42
43    // Split on commas, but not commas inside parentheses
44    let items = split_top_level_commas(select_body);
45
46    items
47        .iter()
48        .filter_map(|item| classify_select_item(item.trim()))
49        .collect()
50}
51
52/// Classify a single SELECT item and return its column name, if any.
53fn classify_select_item(item: &str) -> Option<String> {
54    if item.is_empty() {
55        return None;
56    }
57
58    // Items starting with '(' are subqueries; check for alias after closing paren
59    if item.starts_with('(') {
60        return extract_alias_after_paren(item);
61    }
62
63    let col = extract_column_name(item);
64    if col.is_empty() { None } else { Some(col) }
65}
66
67/// Find the byte offset just after the last top-level SELECT keyword (not inside parentheses).
68/// Also skips DISTINCT if present. Returns the position where the column list begins.
69fn find_last_top_level_select(s: &str) -> Option<usize> {
70    let bytes = s.as_bytes();
71    let len = bytes.len();
72    let mut depth: u32 = 0;
73    let mut last_select_end: Option<usize> = None;
74    let mut i = 0;
75
76    while i < len {
77        match bytes[i] {
78            b'(' => depth += 1,
79            b')' => {
80                depth = depth.saturating_sub(1);
81            }
82            b's' | b'S' if depth == 0 => {
83                if check_keyword_at(bytes, i, len, b"SELECT") {
84                    let end = i + 6;
85                    // Skip optional DISTINCT
86                    let after = skip_whitespace(bytes, end, len);
87                    if check_keyword_at(bytes, after, len, b"DISTINCT") {
88                        let after_distinct = skip_whitespace(bytes, after + 8, len);
89                        last_select_end = Some(after_distinct);
90                    } else {
91                        last_select_end = Some(after);
92                    }
93                }
94            }
95            _ => {}
96        }
97        i += 1;
98    }
99
100    last_select_end
101}
102
103/// Check if the bytes at position `i` match the given keyword (case-insensitive)
104/// with word boundaries on both sides.
105fn check_keyword_at(bytes: &[u8], i: usize, len: usize, keyword: &[u8]) -> bool {
106    let klen = keyword.len();
107    if i + klen > len {
108        return false;
109    }
110    for j in 0..klen {
111        if !bytes[i + j].eq_ignore_ascii_case(&keyword[j]) {
112            return false;
113        }
114    }
115    let before_ok = i == 0 || is_word_boundary(bytes[i - 1]);
116    let after_ok = i + klen >= len || is_word_boundary(bytes[i + klen]);
117    before_ok && after_ok
118}
119
/// Return the offset of the first non-ASCII-whitespace byte in `bytes[start..len]`.
///
/// Returns `len` if every byte in that range is whitespace. Returns `start` unchanged when
/// `start >= len`.
fn skip_whitespace(bytes: &[u8], start: usize, len: usize) -> usize {
    (start..len)
        .find(|&pos| !bytes[pos].is_ascii_whitespace())
        .unwrap_or(len.max(start))
}
128
/// Return `true` if `b` cannot be part of an identifier, i.e. it is not an ASCII letter,
/// an ASCII digit or `_`. Every non-ASCII byte therefore counts as a boundary.
fn is_word_boundary(b: u8) -> bool {
    !matches!(b, b'a'..=b'z' | b'A'..=b'Z' | b'0'..=b'9' | b'_')
}
133
134/// Check if position `i` in string `s` starts a top-level `FROM` keyword with proper boundaries.
135/// Uses byte-level comparison to avoid panics on multi-byte UTF-8 characters.
136fn check_from_at(_s: &str, bytes: &[u8], i: usize, len: usize) -> bool {
137    if i + 4 > len {
138        return false;
139    }
140    let from_match = matches!(bytes[i], b'f' | b'F')
141        && matches!(bytes[i + 1], b'r' | b'R')
142        && matches!(bytes[i + 2], b'o' | b'O')
143        && matches!(bytes[i + 3], b'm' | b'M');
144    if !from_match {
145        return false;
146    }
147    let before_ok = i == 0 || is_word_boundary(bytes[i - 1]);
148    let after_ok = i + 4 >= len || is_word_boundary(bytes[i + 4]);
149    before_ok && after_ok
150}
151
152/// Find the position of the first top-level `FROM` keyword (not inside parentheses).
153/// Returns the byte offset of the start of `FROM` relative to the input string.
154fn find_top_level_from(s: &str) -> Option<usize> {
155    let bytes = s.as_bytes();
156    let len = bytes.len();
157    let mut depth: u32 = 0;
158    let mut i = 0;
159
160    while i < len {
161        match bytes[i] {
162            b'(' => depth += 1,
163            b')' => {
164                depth = depth.saturating_sub(1);
165            }
166            b'f' | b'F' if depth == 0 => {
167                if check_from_at(s, bytes, i, len) {
168                    return Some(i);
169                }
170            }
171            _ => {}
172        }
173        i += 1;
174    }
175
176    None
177}
178
/// Split `s` on commas at parenthesis depth 0, keeping each piece's surrounding whitespace.
///
/// Commas inside `(...)` (function arguments, subqueries) do not split. A final piece that
/// is only whitespace is dropped; empty pieces between consecutive commas are kept.
///
/// A stray `)` never drives the depth below zero, matching the other scanners in this
/// module. Without that, a lone `')'` string literal would stop every later comma from
/// splitting.
fn split_top_level_commas(s: &str) -> Vec<String> {
    let mut items = Vec::new();
    let mut current = String::new();
    let mut depth: u32 = 0;

    for ch in s.chars() {
        match ch {
            '(' => depth += 1,
            ')' => depth = depth.saturating_sub(1),
            ',' if depth == 0 => {
                // Hand the finished piece over and start a fresh buffer without cloning.
                items.push(std::mem::take(&mut current));
                continue;
            }
            _ => {}
        }
        current.push(ch);
    }

    if !current.trim().is_empty() {
        items.push(current);
    }

    items
}
211
212/// Extract the alias after a closing parenthesis, e.g., `(SELECT ...) AS alias`
213fn extract_alias_after_paren(item: &str) -> Option<String> {
214    // Find the last closing paren
215    let close = item.rfind(')')?;
216    let after = item[close + 1..].trim();
217    if after.is_empty() {
218        return None;
219    }
220    // Strip leading AS (case-insensitive) using byte comparison to avoid
221    // panics on multi-byte UTF-8 characters
222    let after = if after.len() >= 3
223        && matches!(after.as_bytes()[0], b'a' | b'A')
224        && matches!(after.as_bytes()[1], b's' | b'S')
225        && after.as_bytes()[2].is_ascii_whitespace()
226    {
227        after[2..].trim()
228    } else {
229        after
230    };
231    if after.is_empty() {
232        None
233    } else {
234        Some(clean_identifier(after))
235    }
236}
237
238/// Extract the effective column name from a single SELECT item.
239///
240/// Rules:
241/// 1. If `AS alias` is present, return the alias.
242/// 2. If `table.column`, return column.
243/// 3. Otherwise return the token itself (e.g., `*`, `col1`).
244fn extract_column_name(item: &str) -> String {
245    let item = item.trim();
246
247    // Check for AS alias (case-insensitive) - look for last " AS " or " as "
248    // We search from the end to handle expressions like `CAST(x AS int) AS col`
249    if let Some(alias) = find_last_as_alias(item) {
250        return clean_identifier(&alias);
251    }
252
253    // No alias; take the last token (handles `table.col` and bare `col`)
254    let last_token = item.split_whitespace().last().unwrap_or(item);
255
256    // Handle table.column
257    if let Some(pos) = last_token.rfind('.') {
258        return clean_identifier(&last_token[pos + 1..]);
259    }
260
261    clean_identifier(last_token)
262}
263
/// Test whether the three bytes after offset `i` are `AS` (any case) followed by a space,
/// tab, LF or CR. `i` is expected to be a whitespace byte; the caller checks this.
///
/// On a match, returns `Some(i + 4)`, the offset just past that trailing whitespace byte.
/// Otherwise returns `None`, including when fewer than four bytes remain from `i`. Only
/// ASCII bytes are compared, so multi-byte UTF-8 input is safe. `_item` is unused.
fn is_as_keyword_at(_item: &str, bytes: &[u8], i: usize, len: usize) -> Option<usize> {
    let hit = i + 3 < len
        && bytes[i + 1].eq_ignore_ascii_case(&b'a')
        && bytes[i + 2].eq_ignore_ascii_case(&b's')
        && matches!(bytes[i + 3], b' ' | b'\t' | b'\n' | b'\r');
    hit.then_some(i + 4)
}
278
279/// Find the alias from the last ` AS ` keyword that is not inside parentheses.
280fn find_last_as_alias(item: &str) -> Option<String> {
281    let bytes = item.as_bytes();
282    let len = bytes.len();
283    let mut depth = 0;
284    let mut last_as_pos: Option<usize> = None;
285
286    let mut i = 0;
287    while i < len {
288        match bytes[i] {
289            b'(' => depth += 1,
290            b')' => {
291                if depth > 0 {
292                    depth -= 1;
293                }
294            }
295            b' ' | b'\t' | b'\n' | b'\r' if depth == 0 => {
296                if let Some(pos) = is_as_keyword_at(item, bytes, i, len) {
297                    last_as_pos = Some(pos);
298                }
299            }
300            _ => {}
301        }
302        i += 1;
303    }
304
305    last_as_pos.map(|pos| item[pos..].trim().to_string())
306}
307
308/// Clean an identifier: trim whitespace and remove surrounding backticks or quotes.
309fn clean_identifier(s: &str) -> String {
310    let s = s.trim();
311    let s = s.trim_matches('`');
312    let s = s.trim_matches('"');
313    s.to_string()
314}
315
#[cfg(test)]
mod tests {
    // Unit tests for SELECT-column extraction and its byte-level scanning helpers.
    use super::*;

    #[test]
    fn test_simple_select() {
        let sql = "SELECT col1, col2 FROM my_table";
        let cols = extract_select_columns(sql);
        assert_eq!(cols, vec!["col1", "col2"]);
    }

    #[test]
    fn test_select_with_aliases() {
        let sql = "SELECT col1 AS alias1, col2 as alias2 FROM my_table";
        let cols = extract_select_columns(sql);
        assert_eq!(cols, vec!["alias1", "alias2"]);
    }

    #[test]
    fn test_select_with_table_prefixes() {
        let sql = "SELECT t.col1, t.col2 FROM my_table t";
        let cols = extract_select_columns(sql);
        assert_eq!(cols, vec!["col1", "col2"]);
    }

    #[test]
    fn test_select_star() {
        let sql = "SELECT * FROM my_table";
        let cols = extract_select_columns(sql);
        assert_eq!(cols, vec!["*"]);
    }

    #[test]
    fn test_select_distinct() {
        let sql = "SELECT DISTINCT col1, col2 FROM my_table";
        let cols = extract_select_columns(sql);
        assert_eq!(cols, vec!["col1", "col2"]);
    }

    #[test]
    fn test_select_with_jinja() {
        let sql = r#"
            {{ config(materialized='table') }}

            SELECT
                order_id,
                {{ dbt_utils.star(from=ref('stg_orders')) }},
                customer_id
            FROM {{ ref('stg_orders') }}
        "#;
        let cols = extract_select_columns(sql);
        assert_eq!(cols, vec!["order_id", "__jinja__", "customer_id"]);
    }

    #[test]
    fn test_multiline_select() {
        let sql = r#"
            SELECT
                order_id,
                customer_id,
                order_date,
                status
            FROM orders
        "#;
        let cols = extract_select_columns(sql);
        assert_eq!(
            cols,
            vec!["order_id", "customer_id", "order_date", "status"]
        );
    }

    #[test]
    fn test_cte_gets_outer_select() {
        let sql = r#"
            WITH cte AS (
                SELECT inner_col1, inner_col2 FROM raw_table
            )
            SELECT outer_col1, outer_col2 FROM cte
        "#;
        let cols = extract_select_columns(sql);
        assert_eq!(cols, vec!["outer_col1", "outer_col2"]);
    }

    #[test]
    fn test_multiple_ctes_gets_final_select() {
        let sql = r#"
            WITH cte1 AS (
                SELECT * FROM raw_table
            ),
            cte2 AS (
                SELECT a, b FROM cte1
            )
            SELECT
                onramp_name,
                count(distinct client_id) as total_known_clients,
                sum(total_deals) as total_deals
            FROM cte2
            GROUP BY 1
        "#;
        let cols = extract_select_columns(sql);
        assert_eq!(
            cols,
            vec!["onramp_name", "total_known_clients", "total_deals"]
        );
    }

    #[test]
    fn test_select_with_function() {
        let sql = "SELECT COUNT(*) AS total, SUM(amount) AS total_amount FROM orders";
        let cols = extract_select_columns(sql);
        assert_eq!(cols, vec!["total", "total_amount"]);
    }

    #[test]
    fn test_select_table_prefix_with_alias() {
        let sql = "SELECT t.col1 AS alias1, t.col2 FROM my_table t";
        let cols = extract_select_columns(sql);
        assert_eq!(cols, vec!["alias1", "col2"]);
    }

    #[test]
    fn test_no_select() {
        let sql = "INSERT INTO my_table VALUES (1, 2, 3)";
        let cols = extract_select_columns(sql);
        assert!(cols.is_empty());
    }

    #[test]
    fn test_select_with_jinja_comments() {
        let sql = r#"
            {# Select all order columns #}
            SELECT order_id, status FROM orders
        "#;
        let cols = extract_select_columns(sql);
        assert_eq!(cols, vec!["order_id", "status"]);
    }

    #[test]
    fn test_select_with_cast() {
        let sql = "SELECT CAST(order_id AS INTEGER) AS order_id, status FROM orders";
        let cols = extract_select_columns(sql);
        assert_eq!(cols, vec!["order_id", "status"]);
    }

    #[test]
    fn test_select_with_subquery_alias() {
        let sql = "SELECT (SELECT MAX(id) FROM t) AS max_id, name FROM users";
        let cols = extract_select_columns(sql);
        assert_eq!(cols, vec!["max_id", "name"]);
    }

    #[test]
    fn test_typical_dbt_model() {
        let sql = r#"
            {{ config(materialized='view') }}

            SELECT
                order_id,
                customer_id,
                order_date,
                status,
                amount
            FROM {{ ref('stg_orders') }}
        "#;
        let cols = extract_select_columns(sql);
        assert_eq!(
            cols,
            vec!["order_id", "customer_id", "order_date", "status", "amount"]
        );
    }

    #[test]
    fn test_select_case_insensitive() {
        let sql = "select col1, col2 from my_table";
        let cols = extract_select_columns(sql);
        assert_eq!(cols, vec!["col1", "col2"]);
    }

    #[test]
    fn test_select_with_multibyte_utf8_comment() {
        // GitHub Issue #1: panic on multi-byte UTF-8 characters
        let sql = r#"SELECT
    case
      when flag = true then false -- 日本語コメント
      else flag
    end as flag
FROM my_table"#;
        let cols = extract_select_columns(sql);
        assert_eq!(cols, vec!["flag"]);
    }

    #[test]
    fn test_select_with_multibyte_utf8_string_literal() {
        let sql = "SELECT '中文字符' AS label, col1 FROM my_table";
        let cols = extract_select_columns(sql);
        assert_eq!(cols, vec!["label", "col1"]);
    }

    #[test]
    fn test_select_with_korean_comment_no_panic() {
        // Verify no panic on Korean characters (comment stripping is a separate concern)
        let sql = "SELECT col1, col2 -- 한국어 코멘트\nFROM my_table";
        let cols = extract_select_columns(sql);
        assert!(!cols.is_empty());
    }

    #[test]
    fn test_select_with_emoji_comment_no_panic() {
        // Verify no panic on emoji characters (comment stripping is a separate concern)
        let sql = "SELECT col1 -- 🎉 celebration\nFROM my_table";
        let cols = extract_select_columns(sql);
        assert!(!cols.is_empty());
    }

    #[test]
    fn test_select_with_backtick_identifiers() {
        let sql = "SELECT `col1`, `col2` FROM my_table";
        let cols = extract_select_columns(sql);
        assert_eq!(cols, vec!["col1", "col2"]);
    }

    #[test]
    fn test_extract_alias_after_paren_no_alias() {
        // Subquery with no alias after the closing paren
        let result = extract_alias_after_paren("(SELECT 1)");
        assert!(result.is_none());
    }

    #[test]
    fn test_extract_alias_after_paren_bare_alias() {
        // Subquery with bare alias (no AS keyword)
        let result = extract_alias_after_paren("(SELECT 1) my_alias");
        assert_eq!(result, Some("my_alias".to_string()));
    }

    #[test]
    fn test_extract_alias_after_paren_as_alias() {
        // Subquery with AS alias
        let result = extract_alias_after_paren("(SELECT 1) AS my_alias");
        assert_eq!(result, Some("my_alias".to_string()));
    }

    #[test]
    fn test_extract_alias_after_paren_no_paren() {
        // No closing paren at all
        let result = extract_alias_after_paren("SELECT 1");
        assert!(result.is_none());
    }

    #[test]
    fn test_union_all_uses_last_select() {
        // Only the last top-level SELECT is inspected.
        let sql = "SELECT a FROM t UNION ALL SELECT b FROM u";
        let cols = extract_select_columns(sql);
        assert_eq!(cols, vec!["b"]);
    }

    #[test]
    fn test_select_without_from_returns_empty() {
        // No top-level FROM means the end of the column list is unknown.
        let sql = "SELECT 1 AS x";
        let cols = extract_select_columns(sql);
        assert!(cols.is_empty());
    }

    #[test]
    fn test_select_with_bare_aliases() {
        let sql = "SELECT col1 alias1, t.col2 c2 FROM my_table t";
        let cols = extract_select_columns(sql);
        assert_eq!(cols, vec!["alias1", "c2"]);
    }

    #[test]
    fn test_unaliased_subquery_is_skipped() {
        let sql = "SELECT (SELECT MAX(id) FROM t), name FROM users";
        let cols = extract_select_columns(sql);
        assert_eq!(cols, vec!["name"]);
    }

    #[test]
    fn test_from_inside_identifiers_is_not_a_keyword() {
        let sql = "SELECT from_date, date_from FROM my_table";
        let cols = extract_select_columns(sql);
        assert_eq!(cols, vec!["from_date", "date_from"]);
    }

    #[test]
    fn test_find_top_level_from_ignores_nested_from() {
        let s = "EXTRACT(year FROM d) AS y FROM t";
        let pos = find_top_level_from(s).expect("top-level FROM should be found");
        assert_eq!(&s[pos..], "FROM t");
    }

    #[test]
    fn test_check_keyword_at_requires_word_boundaries() {
        assert!(!check_keyword_at(b"selected", 0, 8, b"SELECT"));
        assert!(!check_keyword_at(b"my_select", 3, 9, b"SELECT"));
        assert!(check_keyword_at(b"(Select)", 1, 8, b"SELECT"));
        assert!(!check_keyword_at(b"sel", 0, 3, b"SELECT"));
    }

    #[test]
    fn test_split_top_level_commas_keeps_nested_commas() {
        let items = split_top_level_commas("a, f(b, c), d");
        assert_eq!(items, vec!["a", " f(b, c)", " d"]);
    }
}