1#[cfg(feature = "column-lineage")]
2use polyglot_sql::expressions::Expression;
3use regex::Regex;
4use std::sync::LazyLock;
5
6static JINJA_TAG: LazyLock<Regex> =
8 LazyLock::new(|| Regex::new(r"\{\{-?[\s\S]*?-?\}\}|\{%-?[\s\S]*?-?%\}").unwrap());
9
10static JINJA_COMMENT: LazyLock<Regex> = LazyLock::new(|| Regex::new(r"\{#[\s\S]*?#\}").unwrap());
12
13pub fn extract_select_columns(sql: &str) -> Vec<String> {
26 let cleaned = JINJA_COMMENT.replace_all(sql, "");
28 let cleaned = JINJA_TAG.replace_all(&cleaned, "__jinja__");
29
30 let select_end = match find_last_top_level_select(&cleaned) {
34 Some(end) => end,
35 None => return vec![],
36 };
37
38 let after_select = &cleaned[select_end..];
40 let select_body = match find_top_level_from(after_select) {
41 Some(pos) => &after_select[..pos],
42 None => return vec![],
43 };
44
45 let items = split_top_level_commas(select_body);
47
48 items
49 .iter()
50 .filter_map(|item| classify_select_item(item.trim()))
51 .collect()
52}
53
54fn classify_select_item(item: &str) -> Option<String> {
56 if item.is_empty() {
57 return None;
58 }
59
60 if item.starts_with('(') {
62 return extract_alias_after_paren(item);
63 }
64
65 let col = extract_column_name(item);
66 if col.is_empty() { None } else { Some(col) }
67}
68
69fn find_last_top_level_select(s: &str) -> Option<usize> {
72 let bytes = s.as_bytes();
73 let len = bytes.len();
74 let mut depth: u32 = 0;
75 let mut last_select_end: Option<usize> = None;
76 let mut i = 0;
77
78 while i < len {
79 match bytes[i] {
80 b'(' => depth += 1,
81 b')' => {
82 depth = depth.saturating_sub(1);
83 }
84 b's' | b'S' if depth == 0 && check_keyword_at(bytes, i, len, b"SELECT") => {
85 let end = i + 6;
86 let after = skip_whitespace(bytes, end, len);
88 if check_keyword_at(bytes, after, len, b"DISTINCT") {
89 let after_distinct = skip_whitespace(bytes, after + 8, len);
90 last_select_end = Some(after_distinct);
91 } else {
92 last_select_end = Some(after);
93 }
94 }
95 _ => {}
96 }
97 i += 1;
98 }
99
100 last_select_end
101}
102
103fn check_keyword_at(bytes: &[u8], i: usize, len: usize, keyword: &[u8]) -> bool {
106 let klen = keyword.len();
107 if i + klen > len {
108 return false;
109 }
110 for j in 0..klen {
111 if !bytes[i + j].eq_ignore_ascii_case(&keyword[j]) {
112 return false;
113 }
114 }
115 let before_ok = i == 0 || is_word_boundary(bytes[i - 1]);
116 let after_ok = i + klen >= len || is_word_boundary(bytes[i + klen]);
117 before_ok && after_ok
118}
119
/// Returns the first offset at or after `start` (and below `len`) whose byte
/// is not ASCII whitespace, or the offset where the scan stopped when only
/// whitespace remains. `start` is returned unchanged when it is not below `len`.
fn skip_whitespace(bytes: &[u8], start: usize, len: usize) -> usize {
    let run = bytes
        .iter()
        .take(len)
        .skip(start)
        .take_while(|byte| byte.is_ascii_whitespace())
        .count();
    start + run
}
128
/// A byte is a word boundary when it cannot be part of an identifier, i.e. it
/// is neither an ASCII letter or digit nor an underscore.
fn is_word_boundary(b: u8) -> bool {
    let is_identifier_byte = b.is_ascii_alphanumeric() || b == b'_';
    !is_identifier_byte
}
133
134fn check_from_at(_s: &str, bytes: &[u8], i: usize, len: usize) -> bool {
137 if i + 4 > len {
138 return false;
139 }
140 let from_match = matches!(bytes[i], b'f' | b'F')
141 && matches!(bytes[i + 1], b'r' | b'R')
142 && matches!(bytes[i + 2], b'o' | b'O')
143 && matches!(bytes[i + 3], b'm' | b'M');
144 if !from_match {
145 return false;
146 }
147 let before_ok = i == 0 || is_word_boundary(bytes[i - 1]);
148 let after_ok = i + 4 >= len || is_word_boundary(bytes[i + 4]);
149 before_ok && after_ok
150}
151
152fn find_top_level_from(s: &str) -> Option<usize> {
155 let bytes = s.as_bytes();
156 let len = bytes.len();
157 let mut depth: u32 = 0;
158 let mut i = 0;
159
160 while i < len {
161 match bytes[i] {
162 b'(' => depth += 1,
163 b')' => {
164 depth = depth.saturating_sub(1);
165 }
166 b'f' | b'F' if depth == 0 && check_from_at(s, bytes, i, len) => {
167 return Some(i);
168 }
169 _ => {}
170 }
171 i += 1;
172 }
173
174 None
175}
176
/// Splits `s` on commas that are not nested inside parentheses.
///
/// Items are returned untrimmed, in order. Empty items between two commas are
/// kept, but a final item consisting only of whitespace is dropped.
///
/// The nesting depth saturates at zero: an unmatched `)` (for example one
/// inside a string literal) is copied through and does not stop later
/// top-level commas from being recognised.
fn split_top_level_commas(s: &str) -> Vec<String> {
    let mut items = Vec::new();
    let mut current = String::new();
    let mut depth: u32 = 0;

    for ch in s.chars() {
        match ch {
            // Hand the finished item over without copying it.
            ',' if depth == 0 => items.push(std::mem::take(&mut current)),
            '(' => {
                depth += 1;
                current.push(ch);
            }
            ')' => {
                depth = depth.saturating_sub(1);
                current.push(ch);
            }
            _ => current.push(ch),
        }
    }

    if !current.trim().is_empty() {
        items.push(current);
    }

    items
}
209
210fn extract_alias_after_paren(item: &str) -> Option<String> {
212 let close = item.rfind(')')?;
214 let after = item[close + 1..].trim();
215 if after.is_empty() {
216 return None;
217 }
218 let after = if after.len() >= 3
221 && matches!(after.as_bytes()[0], b'a' | b'A')
222 && matches!(after.as_bytes()[1], b's' | b'S')
223 && after.as_bytes()[2].is_ascii_whitespace()
224 {
225 after[2..].trim()
226 } else {
227 after
228 };
229 if after.is_empty() {
230 None
231 } else {
232 Some(clean_identifier(after))
233 }
234}
235
236fn extract_column_name(item: &str) -> String {
243 let item = item.trim();
244
245 if let Some(alias) = find_last_as_alias(item) {
248 return clean_identifier(&alias);
249 }
250
251 let last_token = item.split_whitespace().last().unwrap_or(item);
253
254 if let Some(pos) = last_token.rfind('.') {
256 return clean_identifier(&last_token[pos + 1..]);
257 }
258
259 clean_identifier(last_token)
260}
261
/// Checks whether the three bytes following offset `i` spell `AS` (any ASCII
/// case) followed by a space, tab, newline or carriage return.
///
/// On success returns the offset just past that trailing whitespace byte
/// (`i + 4`). The byte at `i` itself is not inspected and `_item` is unused.
fn is_as_keyword_at(_item: &str, bytes: &[u8], i: usize, len: usize) -> Option<usize> {
    if i + 3 >= len {
        return None;
    }

    match &bytes[i + 1..=i + 3] {
        [b'a' | b'A', b's' | b'S', b' ' | b'\t' | b'\n' | b'\r'] => Some(i + 4),
        _ => None,
    }
}
276
277fn find_last_as_alias(item: &str) -> Option<String> {
279 let bytes = item.as_bytes();
280 let len = bytes.len();
281 let mut depth = 0;
282 let mut last_as_pos: Option<usize> = None;
283
284 let mut i = 0;
285 while i < len {
286 match bytes[i] {
287 b'(' => depth += 1,
288 b')' if depth > 0 => {
289 depth -= 1;
290 }
291 b' ' | b'\t' | b'\n' | b'\r' if depth == 0 => {
292 if let Some(pos) = is_as_keyword_at(item, bytes, i, len) {
293 last_as_pos = Some(pos);
294 }
295 }
296 _ => {}
297 }
298 i += 1;
299 }
300
301 last_as_pos.map(|pos| item[pos..].trim().to_string())
302}
303
/// Normalises an identifier: strips surrounding whitespace, then any leading
/// and trailing backticks, then any leading and trailing double quotes.
fn clean_identifier(s: &str) -> String {
    s.trim()
        .trim_matches('`')
        .trim_matches('"')
        .to_string()
}
311
/// Collects the output column names of a parsed `SELECT` expression.
///
/// The expression is cloned and passed through
/// `polyglot_sql::lineage::expand_cte_stars` (with the optional `schema`)
/// before the projections are read. For each projection:
/// * an alias contributes its alias name;
/// * a column contributes its name, unless that name is `*`;
/// * a bare identifier contributes its name;
/// * anything else, including a star, contributes nothing.
///
/// Returns an empty vector when the expression is not a `SELECT`.
#[cfg(feature = "column-lineage")]
pub fn extract_select_columns_from_expr(
    expr: &Expression,
    schema: Option<&dyn polyglot_sql::Schema>,
) -> Vec<String> {
    let mut expanded = expr.clone();
    polyglot_sql::lineage::expand_cte_stars(&mut expanded, schema);

    let Expression::Select(select) = &expanded else {
        return Vec::new();
    };

    let mut names = Vec::new();
    for projection in select.expressions.iter() {
        match projection {
            Expression::Alias(alias) => names.push(alias.alias.name.clone()),
            Expression::Column(column) if column.name.name != "*" => {
                names.push(column.name.name.clone());
            }
            Expression::Identifier(ident) => names.push(ident.name.clone()),
            // Stars, `*` columns and every other expression kind have no name to report.
            _ => {}
        }
    }
    names
}
346
#[cfg(test)]
mod tests {
    use super::*;

    // ---- AST-based extraction (requires the `column-lineage` feature) ----

    /// `select *` over a CTE chain resolves to the CTE's projected names.
    #[cfg(feature = "column-lineage")]
    #[test]
    fn test_extract_from_expr_cte_star() {
        let sql = r#"with
source as (select * from "raw"."raw_orders"),
renamed as (
    select id as order_id, customer as customer_id, ordered_at
    from source
)
select * from renamed"#;
        let expr = polyglot_sql::parse_one(sql, polyglot_sql::DialectType::Generic).unwrap();
        assert_eq!(
            extract_select_columns_from_expr(&expr, None),
            ["order_id", "customer_id", "ordered_at"]
        );
    }

    /// Cast and function projections inside the CTE are reported by alias.
    #[cfg(feature = "column-lineage")]
    #[test]
    fn test_extract_from_expr_cte_star_with_cast() {
        let sql = r#"with
source as (
    select * from "jaffle_shop"."raw"."raw_orders"
),
renamed as (
    select
        id as order_id,
        store_id as location_id,
        customer as customer_id,
        subtotal as subtotal_cents,
        tax_paid as tax_paid_cents,
        order_total as order_total_cents,
        (subtotal / 100)::numeric(16, 2) as subtotal,
        (tax_paid / 100)::numeric(16, 2) as tax_paid,
        (order_total / 100)::numeric(16, 2) as order_total,
        date_trunc('day', ordered_at) as ordered_at
    from source
)
select * from renamed"#;
        let expr = polyglot_sql::parse_one(sql, polyglot_sql::DialectType::Generic).unwrap();
        let cols = extract_select_columns_from_expr(&expr, None);
        for expected in ["order_id", "customer_id", "ordered_at", "order_total"] {
            assert!(cols.iter().any(|col| col == expected), "cols: {:?}", cols);
        }
        assert_eq!(cols.len(), 10, "cols: {:?}", cols);
    }

    // ---- Text-based extraction ----

    #[test]
    fn test_simple_select() {
        assert_eq!(
            extract_select_columns("SELECT col1, col2 FROM my_table"),
            ["col1", "col2"]
        );
    }

    #[test]
    fn test_select_with_aliases() {
        assert_eq!(
            extract_select_columns("SELECT col1 AS alias1, col2 as alias2 FROM my_table"),
            ["alias1", "alias2"]
        );
    }

    #[test]
    fn test_select_with_table_prefixes() {
        assert_eq!(
            extract_select_columns("SELECT t.col1, t.col2 FROM my_table t"),
            ["col1", "col2"]
        );
    }

    #[test]
    fn test_select_star() {
        assert_eq!(extract_select_columns("SELECT * FROM my_table"), ["*"]);
    }

    #[test]
    fn test_select_distinct() {
        assert_eq!(
            extract_select_columns("SELECT DISTINCT col1, col2 FROM my_table"),
            ["col1", "col2"]
        );
    }

    /// A Jinja expression used as a select item surfaces as the placeholder.
    #[test]
    fn test_select_with_jinja() {
        let sql = r#"
            {{ config(materialized='table') }}

            SELECT
                order_id,
                {{ dbt_utils.star(from=ref('stg_orders')) }},
                customer_id
            FROM {{ ref('stg_orders') }}
        "#;
        assert_eq!(
            extract_select_columns(sql),
            ["order_id", "__jinja__", "customer_id"]
        );
    }

    #[test]
    fn test_multiline_select() {
        let sql = r#"
            SELECT
                order_id,
                customer_id,
                order_date,
                status
            FROM orders
        "#;
        assert_eq!(
            extract_select_columns(sql),
            ["order_id", "customer_id", "order_date", "status"]
        );
    }

    /// Only the outer query's projections are reported, not the CTE's.
    #[test]
    fn test_cte_gets_outer_select() {
        let sql = r#"
            WITH cte AS (
                SELECT inner_col1, inner_col2 FROM raw_table
            )
            SELECT outer_col1, outer_col2 FROM cte
        "#;
        assert_eq!(extract_select_columns(sql), ["outer_col1", "outer_col2"]);
    }

    #[test]
    fn test_multiple_ctes_gets_final_select() {
        let sql = r#"
            WITH cte1 AS (
                SELECT * FROM raw_table
            ),
            cte2 AS (
                SELECT a, b FROM cte1
            )
            SELECT
                onramp_name,
                count(distinct client_id) as total_known_clients,
                sum(total_deals) as total_deals
            FROM cte2
            GROUP BY 1
        "#;
        assert_eq!(
            extract_select_columns(sql),
            ["onramp_name", "total_known_clients", "total_deals"]
        );
    }

    #[test]
    fn test_select_with_function() {
        let sql = "SELECT COUNT(*) AS total, SUM(amount) AS total_amount FROM orders";
        assert_eq!(extract_select_columns(sql), ["total", "total_amount"]);
    }

    #[test]
    fn test_select_table_prefix_with_alias() {
        let sql = "SELECT t.col1 AS alias1, t.col2 FROM my_table t";
        assert_eq!(extract_select_columns(sql), ["alias1", "col2"]);
    }

    #[test]
    fn test_no_select() {
        let sql = "INSERT INTO my_table VALUES (1, 2, 3)";
        assert!(extract_select_columns(sql).is_empty());
    }

    #[test]
    fn test_select_with_jinja_comments() {
        let sql = r#"
            {# Select all order columns #}
            SELECT order_id, status FROM orders
        "#;
        assert_eq!(extract_select_columns(sql), ["order_id", "status"]);
    }

    /// The `AS` inside `CAST(...)` must not be taken for the item's alias.
    #[test]
    fn test_select_with_cast() {
        let sql = "SELECT CAST(order_id AS INTEGER) AS order_id, status FROM orders";
        assert_eq!(extract_select_columns(sql), ["order_id", "status"]);
    }

    #[test]
    fn test_select_with_subquery_alias() {
        let sql = "SELECT (SELECT MAX(id) FROM t) AS max_id, name FROM users";
        assert_eq!(extract_select_columns(sql), ["max_id", "name"]);
    }

    #[test]
    fn test_typical_dbt_model() {
        let sql = r#"
            {{ config(materialized='view') }}

            SELECT
                order_id,
                customer_id,
                order_date,
                status,
                amount
            FROM {{ ref('stg_orders') }}
        "#;
        assert_eq!(
            extract_select_columns(sql),
            ["order_id", "customer_id", "order_date", "status", "amount"]
        );
    }

    #[test]
    fn test_select_case_insensitive() {
        assert_eq!(
            extract_select_columns("select col1, col2 from my_table"),
            ["col1", "col2"]
        );
    }

    // ---- Multi-byte input must never cause a slicing panic ----

    #[test]
    fn test_select_with_multibyte_utf8_comment() {
        let sql = r#"SELECT
    case
        when flag = true then false -- 日本語コメント
        else flag
    end as flag
FROM my_table"#;
        assert_eq!(extract_select_columns(sql), ["flag"]);
    }

    #[test]
    fn test_select_with_multibyte_utf8_string_literal() {
        let sql = "SELECT '中文字符' AS label, col1 FROM my_table";
        assert_eq!(extract_select_columns(sql), ["label", "col1"]);
    }

    #[test]
    fn test_select_with_korean_comment_no_panic() {
        let sql = "SELECT col1, col2 -- 한국어 코멘트\nFROM my_table";
        assert!(!extract_select_columns(sql).is_empty());
    }

    #[test]
    fn test_select_with_emoji_comment_no_panic() {
        let sql = "SELECT col1 -- 🎉 celebration\nFROM my_table";
        assert!(!extract_select_columns(sql).is_empty());
    }

    #[test]
    fn test_select_with_backtick_identifiers() {
        assert_eq!(
            extract_select_columns("SELECT `col1`, `col2` FROM my_table"),
            ["col1", "col2"]
        );
    }

    // ---- extract_alias_after_paren ----

    #[test]
    fn test_extract_alias_after_paren_no_alias() {
        assert!(extract_alias_after_paren("(SELECT 1)").is_none());
    }

    #[test]
    fn test_extract_alias_after_paren_bare_alias() {
        assert_eq!(
            extract_alias_after_paren("(SELECT 1) my_alias"),
            Some("my_alias".to_string())
        );
    }

    #[test]
    fn test_extract_alias_after_paren_as_alias() {
        assert_eq!(
            extract_alias_after_paren("(SELECT 1) AS my_alias"),
            Some("my_alias".to_string())
        );
    }

    #[test]
    fn test_extract_alias_after_paren_no_paren() {
        assert!(extract_alias_after_paren("SELECT 1").is_none());
    }
}