1use regex::Regex;
2use std::sync::LazyLock;
3
4static JINJA_TAG: LazyLock<Regex> =
6 LazyLock::new(|| Regex::new(r"\{\{-?[\s\S]*?-?\}\}|\{%-?[\s\S]*?-?%\}").unwrap());
7
8static JINJA_COMMENT: LazyLock<Regex> = LazyLock::new(|| Regex::new(r"\{#[\s\S]*?#\}").unwrap());
10
11pub fn extract_select_columns(sql: &str) -> Vec<String> {
24 let cleaned = JINJA_COMMENT.replace_all(sql, "");
26 let cleaned = JINJA_TAG.replace_all(&cleaned, "__jinja__");
27
28 let select_end = match find_last_top_level_select(&cleaned) {
32 Some(end) => end,
33 None => return vec![],
34 };
35
36 let after_select = &cleaned[select_end..];
38 let select_body = match find_top_level_from(after_select) {
39 Some(pos) => &after_select[..pos],
40 None => return vec![],
41 };
42
43 let items = split_top_level_commas(select_body);
45
46 items
47 .iter()
48 .filter_map(|item| classify_select_item(item.trim()))
49 .collect()
50}
51
52fn classify_select_item(item: &str) -> Option<String> {
54 if item.is_empty() {
55 return None;
56 }
57
58 if item.starts_with('(') {
60 return extract_alias_after_paren(item);
61 }
62
63 let col = extract_column_name(item);
64 if col.is_empty() { None } else { Some(col) }
65}
66
67fn find_last_top_level_select(s: &str) -> Option<usize> {
70 let bytes = s.as_bytes();
71 let len = bytes.len();
72 let mut depth: u32 = 0;
73 let mut last_select_end: Option<usize> = None;
74 let mut i = 0;
75
76 while i < len {
77 match bytes[i] {
78 b'(' => depth += 1,
79 b')' => {
80 depth = depth.saturating_sub(1);
81 }
82 b's' | b'S' if depth == 0 => {
83 if check_keyword_at(bytes, i, len, b"SELECT") {
84 let end = i + 6;
85 let after = skip_whitespace(bytes, end, len);
87 if check_keyword_at(bytes, after, len, b"DISTINCT") {
88 let after_distinct = skip_whitespace(bytes, after + 8, len);
89 last_select_end = Some(after_distinct);
90 } else {
91 last_select_end = Some(after);
92 }
93 }
94 }
95 _ => {}
96 }
97 i += 1;
98 }
99
100 last_select_end
101}
102
103fn check_keyword_at(bytes: &[u8], i: usize, len: usize, keyword: &[u8]) -> bool {
106 let klen = keyword.len();
107 if i + klen > len {
108 return false;
109 }
110 for j in 0..klen {
111 if !bytes[i + j].eq_ignore_ascii_case(&keyword[j]) {
112 return false;
113 }
114 }
115 let before_ok = i == 0 || is_word_boundary(bytes[i - 1]);
116 let after_ok = i + klen >= len || is_word_boundary(bytes[i + klen]);
117 before_ok && after_ok
118}
119
/// Returns the index of the first byte at or after `start` (and before
/// `len`) that is not ASCII whitespace.
///
/// When every byte in `start..len` is whitespace the result is `len`; when
/// `start` is already at or past `len` it is returned unchanged.
fn skip_whitespace(bytes: &[u8], start: usize, len: usize) -> usize {
    (start..len)
        .find(|&idx| !bytes[idx].is_ascii_whitespace())
        .unwrap_or(start.max(len))
}
128
/// Returns `true` when `b` cannot be part of an identifier, i.e. it is
/// neither an ASCII letter, an ASCII digit, nor an underscore.
fn is_word_boundary(b: u8) -> bool {
    !matches!(b, b'a'..=b'z' | b'A'..=b'Z' | b'0'..=b'9' | b'_')
}
133
134fn check_from_at(_s: &str, bytes: &[u8], i: usize, len: usize) -> bool {
137 if i + 4 > len {
138 return false;
139 }
140 let from_match = matches!(bytes[i], b'f' | b'F')
141 && matches!(bytes[i + 1], b'r' | b'R')
142 && matches!(bytes[i + 2], b'o' | b'O')
143 && matches!(bytes[i + 3], b'm' | b'M');
144 if !from_match {
145 return false;
146 }
147 let before_ok = i == 0 || is_word_boundary(bytes[i - 1]);
148 let after_ok = i + 4 >= len || is_word_boundary(bytes[i + 4]);
149 before_ok && after_ok
150}
151
152fn find_top_level_from(s: &str) -> Option<usize> {
155 let bytes = s.as_bytes();
156 let len = bytes.len();
157 let mut depth: u32 = 0;
158 let mut i = 0;
159
160 while i < len {
161 match bytes[i] {
162 b'(' => depth += 1,
163 b')' => {
164 depth = depth.saturating_sub(1);
165 }
166 b'f' | b'F' if depth == 0 => {
167 if check_from_at(s, bytes, i, len) {
168 return Some(i);
169 }
170 }
171 _ => {}
172 }
173 i += 1;
174 }
175
176 None
177}
178
/// Splits `s` on commas that are not nested inside parentheses.
///
/// Pieces are returned untrimmed and in order. Empty pieces between two
/// commas are kept; a final piece that is empty or whitespace-only is
/// dropped.
///
/// The nesting depth is clamped at zero, so an unmatched `)` cannot stop the
/// commas that follow it from being treated as top-level separators.
fn split_top_level_commas(s: &str) -> Vec<String> {
    let mut items = Vec::new();
    let mut current = String::new();
    let mut depth: u32 = 0;

    for ch in s.chars() {
        match ch {
            ',' if depth == 0 => {
                // Hand the finished piece over without cloning it.
                items.push(std::mem::take(&mut current));
                continue;
            }
            '(' => depth += 1,
            ')' => depth = depth.saturating_sub(1),
            _ => {}
        }
        current.push(ch);
    }

    if !current.trim().is_empty() {
        items.push(current);
    }

    items
}
211
212fn extract_alias_after_paren(item: &str) -> Option<String> {
214 let close = item.rfind(')')?;
216 let after = item[close + 1..].trim();
217 if after.is_empty() {
218 return None;
219 }
220 let after = if after.len() >= 3
223 && matches!(after.as_bytes()[0], b'a' | b'A')
224 && matches!(after.as_bytes()[1], b's' | b'S')
225 && after.as_bytes()[2].is_ascii_whitespace()
226 {
227 after[2..].trim()
228 } else {
229 after
230 };
231 if after.is_empty() {
232 None
233 } else {
234 Some(clean_identifier(after))
235 }
236}
237
238fn extract_column_name(item: &str) -> String {
245 let item = item.trim();
246
247 if let Some(alias) = find_last_as_alias(item) {
250 return clean_identifier(&alias);
251 }
252
253 let last_token = item.split_whitespace().last().unwrap_or(item);
255
256 if let Some(pos) = last_token.rfind('.') {
258 return clean_identifier(&last_token[pos + 1..]);
259 }
260
261 clean_identifier(last_token)
262}
263
/// Checks whether the two bytes after position `i` spell `AS` (ASCII
/// case-insensitive) and are followed by a space, tab, newline or carriage
/// return.
///
/// `i` is expected to be the index of the whitespace byte in front of the
/// keyword; the byte at `i` itself is not inspected. On success the index
/// just past the trailing whitespace byte (`i + 4`) is returned. `len` is
/// the number of bytes of `bytes` to consider; `_item` is accepted but not
/// used.
fn is_as_keyword_at(_item: &str, bytes: &[u8], i: usize, len: usize) -> Option<usize> {
    let separator = i + 3;
    if separator >= len {
        return None;
    }

    let spells_as =
        bytes[i + 1].eq_ignore_ascii_case(&b'a') && bytes[i + 2].eq_ignore_ascii_case(&b's');
    let followed_by_space =
        spells_as && matches!(bytes[separator], b' ' | b'\t' | b'\n' | b'\r');

    followed_by_space.then_some(separator + 1)
}
278
279fn find_last_as_alias(item: &str) -> Option<String> {
281 let bytes = item.as_bytes();
282 let len = bytes.len();
283 let mut depth = 0;
284 let mut last_as_pos: Option<usize> = None;
285
286 let mut i = 0;
287 while i < len {
288 match bytes[i] {
289 b'(' => depth += 1,
290 b')' => {
291 if depth > 0 {
292 depth -= 1;
293 }
294 }
295 b' ' | b'\t' | b'\n' | b'\r' if depth == 0 => {
296 if let Some(pos) = is_as_keyword_at(item, bytes, i, len) {
297 last_as_pos = Some(pos);
298 }
299 }
300 _ => {}
301 }
302 i += 1;
303 }
304
305 last_as_pos.map(|pos| item[pos..].trim().to_string())
306}
307
/// Normalises an identifier: surrounding whitespace is trimmed first, then
/// any leading/trailing backticks, then any leading/trailing double quotes.
fn clean_identifier(s: &str) -> String {
    s.trim().trim_matches('`').trim_matches('"').to_owned()
}
315
#[cfg(test)]
mod tests {
    use super::*;

    /// Runs `extract_select_columns` on `sql` and compares the result with
    /// `expected`; failures are reported at the calling test's location.
    #[track_caller]
    fn assert_columns(sql: &str, expected: &[&str]) {
        assert_eq!(extract_select_columns(sql), expected);
    }

    /// Asserts that extraction succeeds and yields at least one column.
    #[track_caller]
    fn assert_some_columns(sql: &str) {
        assert!(!extract_select_columns(sql).is_empty());
    }

    #[test]
    fn test_simple_select() {
        assert_columns("SELECT col1, col2 FROM my_table", &["col1", "col2"]);
    }

    #[test]
    fn test_select_with_aliases() {
        assert_columns(
            "SELECT col1 AS alias1, col2 as alias2 FROM my_table",
            &["alias1", "alias2"],
        );
    }

    #[test]
    fn test_select_with_table_prefixes() {
        assert_columns("SELECT t.col1, t.col2 FROM my_table t", &["col1", "col2"]);
    }

    #[test]
    fn test_select_star() {
        assert_columns("SELECT * FROM my_table", &["*"]);
    }

    #[test]
    fn test_select_distinct() {
        assert_columns("SELECT DISTINCT col1, col2 FROM my_table", &["col1", "col2"]);
    }

    #[test]
    fn test_select_with_jinja() {
        // A Jinja expression used as a select item surfaces as the placeholder.
        let sql = r#"
            {{ config(materialized='table') }}

            SELECT
                order_id,
                {{ dbt_utils.star(from=ref('stg_orders')) }},
                customer_id
            FROM {{ ref('stg_orders') }}
        "#;
        assert_columns(sql, &["order_id", "__jinja__", "customer_id"]);
    }

    #[test]
    fn test_multiline_select() {
        let sql = r#"
            SELECT
                order_id,
                customer_id,
                order_date,
                status
            FROM orders
        "#;
        assert_columns(sql, &["order_id", "customer_id", "order_date", "status"]);
    }

    #[test]
    fn test_cte_gets_outer_select() {
        // The SELECT inside the CTE is nested in parentheses and is skipped.
        let sql = r#"
            WITH cte AS (
                SELECT inner_col1, inner_col2 FROM raw_table
            )
            SELECT outer_col1, outer_col2 FROM cte
        "#;
        assert_columns(sql, &["outer_col1", "outer_col2"]);
    }

    #[test]
    fn test_multiple_ctes_gets_final_select() {
        let sql = r#"
            WITH cte1 AS (
                SELECT * FROM raw_table
            ),
            cte2 AS (
                SELECT a, b FROM cte1
            )
            SELECT
                onramp_name,
                count(distinct client_id) as total_known_clients,
                sum(total_deals) as total_deals
            FROM cte2
            GROUP BY 1
        "#;
        assert_columns(sql, &["onramp_name", "total_known_clients", "total_deals"]);
    }

    #[test]
    fn test_select_with_function() {
        assert_columns(
            "SELECT COUNT(*) AS total, SUM(amount) AS total_amount FROM orders",
            &["total", "total_amount"],
        );
    }

    #[test]
    fn test_select_table_prefix_with_alias() {
        assert_columns(
            "SELECT t.col1 AS alias1, t.col2 FROM my_table t",
            &["alias1", "col2"],
        );
    }

    #[test]
    fn test_no_select() {
        let columns = extract_select_columns("INSERT INTO my_table VALUES (1, 2, 3)");
        assert!(columns.is_empty());
    }

    #[test]
    fn test_select_with_jinja_comments() {
        let sql = r#"
            {# Select all order columns #}
            SELECT order_id, status FROM orders
        "#;
        assert_columns(sql, &["order_id", "status"]);
    }

    #[test]
    fn test_select_with_cast() {
        // The AS inside CAST(...) is nested and must not be taken as the alias.
        assert_columns(
            "SELECT CAST(order_id AS INTEGER) AS order_id, status FROM orders",
            &["order_id", "status"],
        );
    }

    #[test]
    fn test_select_with_subquery_alias() {
        assert_columns(
            "SELECT (SELECT MAX(id) FROM t) AS max_id, name FROM users",
            &["max_id", "name"],
        );
    }

    #[test]
    fn test_typical_dbt_model() {
        let sql = r#"
            {{ config(materialized='view') }}

            SELECT
                order_id,
                customer_id,
                order_date,
                status,
                amount
            FROM {{ ref('stg_orders') }}
        "#;
        assert_columns(
            sql,
            &["order_id", "customer_id", "order_date", "status", "amount"],
        );
    }

    #[test]
    fn test_select_case_insensitive() {
        assert_columns("select col1, col2 from my_table", &["col1", "col2"]);
    }

    #[test]
    fn test_select_with_multibyte_utf8_comment() {
        let sql = r#"SELECT
    case
        when flag = true then false -- 日本語コメント
        else flag
    end as flag
FROM my_table"#;
        assert_columns(sql, &["flag"]);
    }

    #[test]
    fn test_select_with_multibyte_utf8_string_literal() {
        assert_columns(
            "SELECT '中文字符' AS label, col1 FROM my_table",
            &["label", "col1"],
        );
    }

    #[test]
    fn test_select_with_korean_comment_no_panic() {
        assert_some_columns("SELECT col1, col2 -- 한국어 코멘트\nFROM my_table");
    }

    #[test]
    fn test_select_with_emoji_comment_no_panic() {
        assert_some_columns("SELECT col1 -- 🎉 celebration\nFROM my_table");
    }

    #[test]
    fn test_select_with_backtick_identifiers() {
        assert_columns("SELECT `col1`, `col2` FROM my_table", &["col1", "col2"]);
    }

    #[test]
    fn test_extract_alias_after_paren_no_alias() {
        assert!(extract_alias_after_paren("(SELECT 1)").is_none());
    }

    #[test]
    fn test_extract_alias_after_paren_bare_alias() {
        assert_eq!(
            extract_alias_after_paren("(SELECT 1) my_alias"),
            Some("my_alias".to_string())
        );
    }

    #[test]
    fn test_extract_alias_after_paren_as_alias() {
        assert_eq!(
            extract_alias_after_paren("(SELECT 1) AS my_alias"),
            Some("my_alias".to_string())
        );
    }

    #[test]
    fn test_extract_alias_after_paren_no_paren() {
        assert!(extract_alias_after_paren("SELECT 1").is_none());
    }
}