Skip to main content

dlin_core/parser/
sql.rs

1use regex::Regex;
2use std::sync::LazyLock;
3
4/// A reference to another dbt model via ref()
5#[derive(Debug, Clone, PartialEq, Eq, Hash, serde::Serialize, serde::Deserialize)]
6pub struct RefCall {
7    /// Optional package name (for cross-project refs)
8    pub package: Option<String>,
9    /// Model name
10    pub name: String,
11    /// Version from ref('name', version=N) or ref('name', version='alpha').
12    /// Stored as a string to support both integer and non-integer versions.
13    #[serde(default)]
14    pub version: Option<String>,
15}
16
17/// A reference to a dbt source via source()
18#[derive(Debug, Clone, PartialEq, Eq, Hash, serde::Serialize, serde::Deserialize)]
19pub struct SourceCall {
20    /// Source name
21    pub source_name: String,
22    /// Table name within the source
23    pub table_name: String,
24}
25
26static JINJA_COMMENT: LazyLock<Regex> = LazyLock::new(|| Regex::new(r"\{#[\s\S]*?#\}").unwrap());
27
28// Matches ref('name'), ref("name"), ref('pkg', 'name'), ref("pkg", "name"),
29// ref('name', version=N), ref('name', v=N), and the pkg variants.
30// Both `version=` and `v=` are accepted per dbt-core v2.
31// Version values may be bare integers (version=2) or quoted strings (version='alpha').
32// Handles {{ ref(...) }} and {{- ref(...) -}} whitespace control.
33// Capture groups:
34//   1, 2 → pkg, name      (two-positional-arg form)
35//   3    → version        (optional version=/v= kwarg in two-arg form)
36//   4, 5 → name, version  (single-arg + version=/v= kwarg form)
37//   6    → name           (single-arg form)
38static REF_PATTERN: LazyLock<Regex> = LazyLock::new(|| {
39    Regex::new(
40        r#"(?x)
41        \{\{-?\s*
42        ref\s*\(\s*
43        (?:
44            # Two-argument form: ref('pkg', 'name') or ref('pkg', 'name', version=N) or ref('pkg', 'name', v=N)
45            (?:['"]([^'"]+)['"]\s*,\s*['"]([^'"]+)['"]\s*(?:,\s*(?:version|v)\s*=\s*(-?\d+|'[^']*'|"[^"]*"))?)
46            |
47            # Single-arg + version kwarg: ref('name', version=N) or ref('name', v=N)
48            (?:['"]([^'"]+)['"]\s*,\s*(?:version|v)\s*=\s*(-?\d+|'[^']*'|"[^"]*"))
49            |
50            # Single-argument form: ref('name') or ref("name")
51            ['"]([^'"]+)['"]
52        )
53        \s*\)\s*
54        -?\}\}
55    "#,
56    )
57    .unwrap()
58});
59
60// Matches source('src_name', 'table_name')
61static SOURCE_PATTERN: LazyLock<Regex> = LazyLock::new(|| {
62    Regex::new(
63        r#"(?x)
64        \{\{-?\s*
65        source\s*\(\s*
66        ['"]([^'"]+)['"]\s*,\s*['"]([^'"]+)['"]
67        \s*\)\s*
68        -?\}\}
69    "#,
70    )
71    .unwrap()
72});
73
74/// Strip Jinja comments from SQL content
75fn strip_jinja_comments(sql: &str) -> String {
76    JINJA_COMMENT.replace_all(sql, "").to_string()
77}
78
79/// Extract all refs, sources, and config from SQL content in a single pass.
80/// Tries minijinja rendering first; falls back to regex on failure.
81///
82/// `macro_prefix` is the pre-built concatenation of valid macro SQL files
83/// so that custom macros containing ref()/source() are expanded and tracked.
84pub fn extract_all(sql: &str, macro_prefix: &str) -> super::jinja::JinjaExtraction {
85    extract_all_with_vars(sql, macro_prefix, &std::collections::HashMap::new())
86}
87
88/// Like [`extract_all`] but resolves `var()` calls using project-level variables.
89pub fn extract_all_with_vars(
90    sql: &str,
91    macro_prefix: &str,
92    vars: &std::collections::HashMap<String, serde_json::Value>,
93) -> super::jinja::JinjaExtraction {
94    if let Some(ext) = super::jinja::extract_via_jinja_with_vars(sql, macro_prefix, vars) {
95        return ext;
96    }
97    super::jinja::JinjaExtraction {
98        refs: extract_refs_regex(sql),
99        sources: extract_sources_regex(sql),
100        config: extract_config_regex(sql),
101    }
102}
103
104/// Extract all ref() and source() calls from SQL content in a single pass.
105/// Tries minijinja rendering first; falls back to regex on failure.
106///
107/// `macro_prefix` is the pre-built concatenation of valid macro SQL files
108/// so that custom macros containing ref()/source() are expanded and tracked.
109pub fn extract_refs_and_sources(sql: &str, macro_prefix: &str) -> (Vec<RefCall>, Vec<SourceCall>) {
110    extract_refs_and_sources_with_vars(sql, macro_prefix, &std::collections::HashMap::new())
111}
112
113/// Like [`extract_refs_and_sources`] but resolves `var()` calls using project-level variables.
114pub fn extract_refs_and_sources_with_vars(
115    sql: &str,
116    macro_prefix: &str,
117    vars: &std::collections::HashMap<String, serde_json::Value>,
118) -> (Vec<RefCall>, Vec<SourceCall>) {
119    if let Some(ext) = super::jinja::extract_via_jinja_with_vars(sql, macro_prefix, vars) {
120        return (ext.refs, ext.sources);
121    }
122    (extract_refs_regex(sql), extract_sources_regex(sql))
123}
124
125/// Extract all ref() calls from SQL content.
126pub fn extract_refs(sql: &str) -> Vec<RefCall> {
127    extract_refs_and_sources(sql, "").0
128}
129
130/// Extract all source() calls from SQL content.
131pub fn extract_sources(sql: &str) -> Vec<SourceCall> {
132    extract_refs_and_sources(sql, "").1
133}
134
/// Unquotes a captured version kwarg value.
///
/// Surrounding whitespace is trimmed. Then, if the value is wrapped in a
/// matching pair of single or double quotes, the pair is removed. Anything
/// else, such as a bare integer, a lone quote or mismatched quotes, is
/// returned as-is.
fn strip_version_quotes(s: &str) -> String {
    let trimmed = s.trim();
    let unquoted = ['\'', '"'].iter().find_map(|&quote| {
        trimmed
            .strip_prefix(quote)
            .and_then(|rest| rest.strip_suffix(quote))
    });
    unquoted.unwrap_or(trimmed).to_string()
}
147
148/// Normalize a version string to a canonical form, matching the normalization
149/// applied to YAML string version values in version_value_to_str().
150/// Integer strings (including zero-padded) are normalized: "02" → "2".
151/// Non-integer strings (including "2.0") are returned as-is.
152/// Using i64 only (no f64 fallback) keeps this consistent with the YAML string
153/// path so that ref(version='2.0') resolves to the same ID as `v: "2.0"`.
154pub(super) fn normalize_version_str(s: &str) -> String {
155    if let Ok(n) = s.parse::<i64>() {
156        return n.to_string();
157    }
158    s.to_string()
159}
160
161/// Regex fallback for extracting ref() calls
162fn extract_refs_regex(sql: &str) -> Vec<RefCall> {
163    let cleaned = strip_jinja_comments(sql);
164    let mut refs = Vec::new();
165
166    for cap in REF_PATTERN.captures_iter(&cleaned) {
167        if let (Some(pkg), Some(name)) = (cap.get(1), cap.get(2)) {
168            // Two-positional-arg form: ref('pkg', 'name') or ref('pkg', 'name', version=N)
169            refs.push(RefCall {
170                package: Some(pkg.as_str().to_string()),
171                name: name.as_str().to_string(),
172                version: cap
173                    .get(3)
174                    .map(|v| normalize_version_str(&strip_version_quotes(v.as_str()))),
175            });
176        } else if let (Some(name), Some(ver)) = (cap.get(4), cap.get(5)) {
177            // Single-arg + version kwarg: ref('name', version=N) or ref('name', version='str')
178            refs.push(RefCall {
179                package: None,
180                name: name.as_str().to_string(),
181                version: Some(normalize_version_str(&strip_version_quotes(ver.as_str()))),
182            });
183        } else if let Some(name) = cap.get(6) {
184            // Single-arg form: ref('name')
185            refs.push(RefCall {
186                package: None,
187                name: name.as_str().to_string(),
188                version: None,
189            });
190        }
191    }
192
193    refs
194}
195
196/// Regex fallback for extracting source() calls
197fn extract_sources_regex(sql: &str) -> Vec<SourceCall> {
198    let cleaned = strip_jinja_comments(sql);
199    let mut sources = Vec::new();
200
201    for cap in SOURCE_PATTERN.captures_iter(&cleaned) {
202        sources.push(SourceCall {
203            source_name: cap[1].to_string(),
204            table_name: cap[2].to_string(),
205        });
206    }
207
208    sources
209}
210
211/// Parsed config block from SQL
212#[derive(Debug, Clone, Default, serde::Serialize, serde::Deserialize)]
213pub struct SqlConfig {
214    pub materialized: Option<String>,
215    pub tags: Vec<String>,
216}
217
218// Matches {{ config(...) }} blocks — captures the inner arguments
219static CONFIG_PATTERN: LazyLock<Regex> = LazyLock::new(|| {
220    Regex::new(
221        r#"(?x)
222        \{\{-?\s*
223        config\s*\(
224        ([\s\S]*?)
225        \)\s*
226        -?\}\}
227    "#,
228    )
229    .unwrap()
230});
231
232// Matches materialized='value' or materialized="value"
233static MATERIALIZED_PATTERN: LazyLock<Regex> =
234    LazyLock::new(|| Regex::new(r#"materialized\s*=\s*['"]([^'"]+)['"]"#).unwrap());
235
236// Matches tags=['a', 'b'] or tags=["a", "b"]
237static TAGS_PATTERN: LazyLock<Regex> =
238    LazyLock::new(|| Regex::new(r#"tags\s*=\s*\[([^\]]*)\]"#).unwrap());
239
240// Matches individual tag values inside the tags list
241static TAG_VALUE: LazyLock<Regex> = LazyLock::new(|| Regex::new(r#"['"]([^'"]+)['"]"#).unwrap());
242
243/// Extract config() block settings from SQL content.
244/// Tries minijinja rendering first; falls back to regex on failure.
245pub fn extract_config(sql: &str, macro_prefix: &str) -> SqlConfig {
246    if let Some(ext) = super::jinja::extract_via_jinja(sql, macro_prefix) {
247        return ext.config;
248    }
249    extract_config_regex(sql)
250}
251
252/// Regex fallback for extracting config() settings
253fn extract_config_regex(sql: &str) -> SqlConfig {
254    let cleaned = strip_jinja_comments(sql);
255    let mut config = SqlConfig::default();
256
257    if let Some(cap) = CONFIG_PATTERN.captures(&cleaned) {
258        let inner = &cap[1];
259
260        if let Some(mat) = MATERIALIZED_PATTERN.captures(inner) {
261            config.materialized = Some(mat[1].to_string());
262        }
263
264        if let Some(tags_cap) = TAGS_PATTERN.captures(inner) {
265            let tags_inner = &tags_cap[1];
266            config.tags = TAG_VALUE
267                .captures_iter(tags_inner)
268                .map(|c| c[1].to_string())
269                .collect();
270        }
271    }
272
273    config
274}
275
#[cfg(test)]
mod tests {
    // The glob import also brings in this module's private helpers
    // (`extract_refs_regex`, `extract_sources_regex`, `extract_config_regex`,
    // `strip_version_quotes`, ...). That lets the regex fallbacks be exercised
    // directly, independent of the Jinja extractor that normally runs first.
    use super::*;

    #[test]
    fn test_single_ref() {
        let sql = "SELECT * FROM {{ ref('stg_orders') }}";
        let refs = extract_refs(sql);
        assert_eq!(refs.len(), 1);
        assert_eq!(refs[0].name, "stg_orders");
        assert!(refs[0].package.is_none());
    }

    #[test]
    fn test_double_quoted_ref() {
        let sql = r#"SELECT * FROM {{ ref("stg_orders") }}"#;
        let refs = extract_refs(sql);
        assert_eq!(refs.len(), 1);
        assert_eq!(refs[0].name, "stg_orders");
    }

    #[test]
    fn test_two_arg_ref() {
        let sql = "SELECT * FROM {{ ref('other_project', 'stg_orders') }}";
        let refs = extract_refs(sql);
        assert_eq!(refs.len(), 1);
        assert_eq!(refs[0].package.as_deref(), Some("other_project"));
        assert_eq!(refs[0].name, "stg_orders");
    }

    #[test]
    fn test_whitespace_control() {
        let sql = "SELECT * FROM {{- ref('stg_orders') -}}";
        let refs = extract_refs(sql);
        assert_eq!(refs.len(), 1);
        assert_eq!(refs[0].name, "stg_orders");
    }

    #[test]
    fn test_multiple_refs() {
        let sql = r#"
            SELECT
                o.*,
                c.name
            FROM {{ ref('stg_orders') }} o
            JOIN {{ ref('stg_customers') }} c ON o.customer_id = c.id
        "#;
        let refs = extract_refs(sql);
        assert_eq!(refs.len(), 2);
        assert_eq!(refs[0].name, "stg_orders");
        assert_eq!(refs[1].name, "stg_customers");
    }

    #[test]
    fn test_source() {
        let sql = "SELECT * FROM {{ source('raw', 'orders') }}";
        let sources = extract_sources(sql);
        assert_eq!(sources.len(), 1);
        assert_eq!(sources[0].source_name, "raw");
        assert_eq!(sources[0].table_name, "orders");
    }

    #[test]
    fn test_source_whitespace_control() {
        let sql = "SELECT * FROM {{- source('raw', 'orders') -}}";
        let sources = extract_sources(sql);
        assert_eq!(sources.len(), 1);
        assert_eq!(sources[0].source_name, "raw");
    }

    #[test]
    fn test_strip_jinja_comments() {
        let sql = r#"
            {# This is a comment with {{ ref('should_be_ignored') }} #}
            SELECT * FROM {{ ref('actual_model') }}
        "#;
        let refs = extract_refs(sql);
        assert_eq!(refs.len(), 1);
        assert_eq!(refs[0].name, "actual_model");
    }

    #[test]
    fn test_mixed_refs_and_sources() {
        let sql = r#"
            SELECT *
            FROM {{ source('raw', 'orders') }}
            JOIN {{ ref('stg_customers') }} ON 1=1
        "#;
        let refs = extract_refs(sql);
        let sources = extract_sources(sql);
        assert_eq!(refs.len(), 1);
        assert_eq!(sources.len(), 1);
    }

    #[test]
    fn test_no_refs() {
        let sql = "SELECT 1 as id";
        let refs = extract_refs(sql);
        assert!(refs.is_empty());
    }

    #[test]
    fn test_extra_spaces() {
        let sql = "SELECT * FROM {{  ref(  'stg_orders'  )  }}";
        let refs = extract_refs(sql);
        assert_eq!(refs.len(), 1);
        assert_eq!(refs[0].name, "stg_orders");
    }

    #[test]
    fn test_ref_with_version_kwarg() {
        let sql = "SELECT * FROM {{ ref('my_model', version=2) }}";
        let refs = extract_refs(sql);
        assert_eq!(refs.len(), 1);
        assert_eq!(refs[0].name, "my_model");
        assert_eq!(refs[0].version.as_deref(), Some("2"));
        assert!(refs[0].package.is_none());
    }

    #[test]
    fn test_ref_with_version_kwarg_spaced() {
        let sql = "SELECT * FROM {{ ref('my_model', version = 3) }}";
        let refs = extract_refs(sql);
        assert_eq!(refs.len(), 1);
        assert_eq!(refs[0].name, "my_model");
        assert_eq!(refs[0].version.as_deref(), Some("3"));
    }

    #[test]
    fn test_ref_without_version_has_none() {
        let sql = "SELECT * FROM {{ ref('my_model') }}";
        let refs = extract_refs(sql);
        assert_eq!(refs.len(), 1);
        assert_eq!(refs[0].version, None);
    }

    #[test]
    fn test_ref_two_arg_has_no_version() {
        let sql = "SELECT * FROM {{ ref('pkg', 'my_model') }}";
        let refs = extract_refs(sql);
        assert_eq!(refs.len(), 1);
        assert_eq!(refs[0].package.as_deref(), Some("pkg"));
        assert_eq!(refs[0].name, "my_model");
        assert_eq!(refs[0].version, None);
    }

    #[test]
    fn test_version_does_not_conflict_with_two_arg_form() {
        // ref('pkg', 'name') must NOT match the version=N branch
        let sql = "SELECT * FROM {{ ref('mypkg', 'model_a') }}";
        let refs = extract_refs(sql);
        assert_eq!(refs.len(), 1);
        assert_eq!(refs[0].package.as_deref(), Some("mypkg"));
        assert_eq!(refs[0].name, "model_a");
        assert_eq!(refs[0].version, None);
    }

    #[test]
    fn test_two_arg_ref_with_version_kwarg() {
        let sql = "SELECT * FROM {{ ref('mypkg', 'my_model', version=3) }}";
        let refs = extract_refs(sql);
        assert_eq!(refs.len(), 1);
        assert_eq!(refs[0].package.as_deref(), Some("mypkg"));
        assert_eq!(refs[0].name, "my_model");
        assert_eq!(refs[0].version.as_deref(), Some("3"));
    }

    #[test]
    fn test_ref_with_v_shorthand_kwarg() {
        let sql = "SELECT * FROM {{ ref('my_model', v=2) }}";
        let refs = extract_refs(sql);
        assert_eq!(refs.len(), 1);
        assert_eq!(refs[0].name, "my_model");
        assert_eq!(refs[0].version.as_deref(), Some("2"));
        assert!(refs[0].package.is_none());
    }

    #[test]
    fn test_two_arg_ref_with_v_shorthand_kwarg() {
        let sql = "SELECT * FROM {{ ref('mypkg', 'my_model', v=3) }}";
        let refs = extract_refs(sql);
        assert_eq!(refs.len(), 1);
        assert_eq!(refs[0].package.as_deref(), Some("mypkg"));
        assert_eq!(refs[0].name, "my_model");
        assert_eq!(refs[0].version.as_deref(), Some("3"));
    }

    #[test]
    fn test_ref_with_string_version_kwarg() {
        // dbt-core accepts version='alpha' for non-integer version strings
        let refs = extract_refs_regex("SELECT * FROM {{ ref('my_model', version='alpha') }}");
        assert_eq!(refs.len(), 1);
        assert_eq!(refs[0].name, "my_model");
        assert_eq!(refs[0].version.as_deref(), Some("alpha"));
    }

    #[test]
    fn test_ref_with_quoted_integer_version_kwarg() {
        // version='2' (quoted) must resolve identically to version=2 (bare integer)
        let refs = extract_refs_regex("SELECT * FROM {{ ref('my_model', version='2') }}");
        assert_eq!(refs.len(), 1);
        assert_eq!(refs[0].version.as_deref(), Some("2"));
    }

    #[test]
    fn test_ref_with_padded_integer_version_kwarg() {
        // version='02' must normalize to "2" to match YAML v: 2 → version_value_to_str → "2"
        let refs = extract_refs_regex("SELECT * FROM {{ ref('my_model', version='02') }}");
        assert_eq!(refs.len(), 1);
        assert_eq!(refs[0].version.as_deref(), Some("2"));
    }

    #[test]
    fn test_ref_with_decimal_version_kwarg() {
        // version='2.0' stays as "2.0" — matching YAML `v: "2.0"` which also keeps "2.0".
        // Both use i64-only normalization so non-integer numeric strings are not rewritten.
        let refs = extract_refs_regex("SELECT * FROM {{ ref('my_model', version='2.0') }}");
        assert_eq!(refs.len(), 1);
        assert_eq!(refs[0].version.as_deref(), Some("2.0"));
    }

    // These tests call the regex fallbacks directly, independent of the
    // Jinja extractor.

    #[test]
    fn test_regex_fallback_v_shorthand_kwarg() {
        let refs = extract_refs_regex("SELECT * FROM {{ ref('my_model', v=2) }}");
        assert_eq!(refs.len(), 1);
        assert_eq!(refs[0].name, "my_model");
        assert_eq!(refs[0].version.as_deref(), Some("2"));
        assert!(refs[0].package.is_none());
    }

    #[test]
    fn test_regex_fallback_two_arg_v_shorthand_kwarg() {
        let refs = extract_refs_regex("SELECT * FROM {{ ref('mypkg', 'my_model', v=3) }}");
        assert_eq!(refs.len(), 1);
        assert_eq!(refs[0].package.as_deref(), Some("mypkg"));
        assert_eq!(refs[0].name, "my_model");
        assert_eq!(refs[0].version.as_deref(), Some("3"));
    }

    #[test]
    fn test_regex_fallback_mixed_ref_forms() {
        let sql = "SELECT * FROM {{ ref('a') }} JOIN {{ ref('pkg', 'b') }} JOIN {{ ref('c', version=2) }}";
        let refs = extract_refs_regex(sql);
        assert_eq!(
            refs,
            vec![
                RefCall { package: None, name: "a".into(), version: None },
                RefCall { package: Some("pkg".into()), name: "b".into(), version: None },
                RefCall { package: None, name: "c".into(), version: Some("2".into()) },
            ]
        );
    }

    #[test]
    fn test_regex_fallback_ref_in_comment_ignored() {
        let sql = "{# {{ ref('old_model') }} #}\nSELECT * FROM {{ ref('new_model') }}";
        let refs = extract_refs_regex(sql);
        assert_eq!(refs.len(), 1);
        assert_eq!(refs[0].name, "new_model");
    }

    #[test]
    fn test_regex_fallback_sources() {
        let sql = r#"SELECT * FROM {{ source('raw', 'orders') }} JOIN {{- source("raw", "customers") -}} USING (id)"#;
        let sources = extract_sources_regex(sql);
        assert_eq!(
            sources,
            vec![
                SourceCall { source_name: "raw".into(), table_name: "orders".into() },
                SourceCall { source_name: "raw".into(), table_name: "customers".into() },
            ]
        );
    }

    #[test]
    fn test_regex_fallback_source_in_comment_ignored() {
        let sql = "{# {{ source('raw', 'old') }} #}\nSELECT * FROM {{ source('raw', 'orders') }}";
        let sources = extract_sources_regex(sql);
        assert_eq!(sources.len(), 1);
        assert_eq!(sources[0].table_name, "orders");
    }

    #[test]
    fn test_strip_version_quotes() {
        assert_eq!(strip_version_quotes("'alpha'"), "alpha");
        assert_eq!(strip_version_quotes("\"2\""), "2");
        assert_eq!(strip_version_quotes(" 7 "), "7");
        assert_eq!(strip_version_quotes("'"), "'");
        assert_eq!(strip_version_quotes("'mixed\""), "'mixed\"");
    }

    #[test]
    fn test_normalize_version_str() {
        assert_eq!(normalize_version_str("02"), "2");
        assert_eq!(normalize_version_str("2.0"), "2.0");
        assert_eq!(normalize_version_str("alpha"), "alpha");
    }

    // ─── Config extraction tests ───

    #[test]
    fn test_config_materialized() {
        let sql = "{{ config(materialized='incremental') }}\nSELECT 1";
        let config = extract_config(sql, "");
        assert_eq!(config.materialized.as_deref(), Some("incremental"));
        assert!(config.tags.is_empty());
    }

    #[test]
    fn test_config_materialized_double_quotes() {
        let sql = r#"{{ config(materialized="table") }}"#;
        let config = extract_config(sql, "");
        assert_eq!(config.materialized.as_deref(), Some("table"));
    }

    #[test]
    fn test_config_tags() {
        let sql = "{{ config(tags=['nightly', 'finance']) }}\nSELECT 1";
        let config = extract_config(sql, "");
        assert_eq!(config.tags, vec!["nightly", "finance"]);
    }

    #[test]
    fn test_config_both() {
        let sql = "{{ config(materialized='view', tags=['daily']) }}\nSELECT 1";
        let config = extract_config(sql, "");
        assert_eq!(config.materialized.as_deref(), Some("view"));
        assert_eq!(config.tags, vec!["daily"]);
    }

    #[test]
    fn test_config_whitespace_control() {
        let sql = "{{- config(materialized='ephemeral') -}}\nSELECT 1";
        let config = extract_config(sql, "");
        assert_eq!(config.materialized.as_deref(), Some("ephemeral"));
    }

    #[test]
    fn test_config_multiline() {
        let sql = r#"{{
            config(
                materialized='incremental',
                tags=['nightly', 'warehouse']
            )
        }}
        SELECT 1"#;
        let config = extract_config(sql, "");
        assert_eq!(config.materialized.as_deref(), Some("incremental"));
        assert_eq!(config.tags, vec!["nightly", "warehouse"]);
    }

    #[test]
    fn test_no_config() {
        let sql = "SELECT * FROM {{ ref('orders') }}";
        let config = extract_config(sql, "");
        assert!(config.materialized.is_none());
        assert!(config.tags.is_empty());
    }

    #[test]
    fn test_config_in_comment_ignored() {
        let sql = r#"
            {# {{ config(materialized='table') }} #}
            SELECT 1
        "#;
        let config = extract_config(sql, "");
        assert!(config.materialized.is_none());
    }

    #[test]
    fn test_regex_fallback_config() {
        let sql = "{{ config(materialized='incremental', tags=['nightly', \"finance\"]) }}\nSELECT 1";
        let config = extract_config_regex(sql);
        assert_eq!(config.materialized.as_deref(), Some("incremental"));
        assert_eq!(config.tags, vec!["nightly", "finance"]);
    }

    #[test]
    fn test_regex_fallback_config_in_comment_ignored() {
        let config = extract_config_regex("{# {{ config(materialized='table') }} #}\nSELECT 1");
        assert!(config.materialized.is_none());
        assert!(config.tags.is_empty());
    }

    #[test]
    fn test_regex_fallback_no_config() {
        let config = extract_config_regex("SELECT * FROM {{ ref('orders') }}");
        assert!(config.materialized.is_none());
        assert!(config.tags.is_empty());
    }
}