Skip to main content

dlin_core/parser/
jinja.rs

1use std::collections::HashMap;
2use std::sync::{Arc, Mutex};
3
4use minijinja::value::Kwargs;
5use minijinja::{Environment, ErrorKind, Value};
6
7use super::sql::{RefCall, SourceCall, SqlConfig};
8
/// All extracted information from rendering a dbt Jinja SQL template.
#[derive(Debug, Clone, Default, serde::Serialize, serde::Deserialize)]
pub struct JinjaExtraction {
    /// `ref()` calls observed during rendering, in call order. Calls from the
    /// second (incremental) render are appended only if not already present.
    pub refs: Vec<RefCall>,
    /// `source()` calls observed during rendering, in call order. Calls from
    /// the second (incremental) render are appended only if not already present.
    pub sources: Vec<SourceCall>,
    /// Values captured from `config(...)` calls (`materialized`, `tags`).
    pub config: SqlConfig,
}
16
17/// Try to extract refs, sources, and config from SQL content using minijinja.
18/// Renders twice (with `is_incremental()` returning both false and true) and
19/// merges results to capture refs/sources from all conditional branches.
20/// Returns `None` if the template fails to render (caller should fall back to regex).
21///
22/// `macro_prefix` is the pre-built concatenation of valid macro SQL files.
23/// It is prepended to the template so that custom macros containing
24/// ref()/source() calls are expanded and tracked.
25pub fn extract_via_jinja(sql: &str, macro_prefix: &str) -> Option<JinjaExtraction> {
26    extract_via_jinja_with_vars(sql, macro_prefix, &HashMap::new())
27}
28
29/// Like [`extract_via_jinja`] but resolves `var()` calls using the given
30/// project-level variables (parsed from `dbt_project.yml`).
31pub fn extract_via_jinja_with_vars(
32    sql: &str,
33    macro_prefix: &str,
34    vars: &HashMap<String, serde_json::Value>,
35) -> Option<JinjaExtraction> {
36    let template = if macro_prefix.is_empty() {
37        sql.to_string()
38    } else {
39        format!("{}\n{}", macro_prefix, sql)
40    };
41
42    // Render with is_incremental=false first (full-load path)
43    let mut result = render_with_incremental(&template, false, vars)?;
44
45    // Render again with is_incremental=true to capture incremental-only refs
46    if let Some(incr) = render_with_incremental(&template, true, vars) {
47        merge_extraction(&mut result, incr);
48    }
49
50    Some(result)
51}
52
53/// Build a macro prefix string from individual macro sources, skipping
54/// any that fail to parse as valid minijinja templates. This ensures one
55/// bad macro file doesn't disable jinja-based extraction for all models.
56pub fn build_macro_prefix(macro_sources: &[String]) -> String {
57    if macro_sources.is_empty() {
58        return String::new();
59    }
60    let env = Environment::new();
61    let mut prefix = String::new();
62    for source in macro_sources {
63        // Only include macros that minijinja can parse individually
64        if env.template_from_str(source).is_err() {
65            continue;
66        }
67        // Verify the accumulated prefix still parses after adding this macro
68        let len = prefix.len();
69        prefix.push_str(source);
70        prefix.push('\n');
71        if env.template_from_str(&prefix).is_err() {
72            prefix.truncate(len);
73        }
74    }
75    prefix
76}
77
78/// Merge `other` into `base`, adding only deduplicated refs and sources
79fn merge_extraction(base: &mut JinjaExtraction, other: JinjaExtraction) {
80    for r in other.refs {
81        if !base.refs.contains(&r) {
82            base.refs.push(r);
83        }
84    }
85    for s in other.sources {
86        if !base.sources.contains(&s) {
87            base.sources.push(s);
88        }
89    }
90    // config from first render takes precedence
91}
92
/// Convert a `serde_json::Value` to a `minijinja::Value`.
///
/// Goes through [`Value::from_serialize`], so structured JSON keeps its shape:
/// for example, a list-valued project var can be iterated element-wise in a
/// `{% for %}` loop instead of character-by-character as a string.
fn json_to_minijinja(v: &serde_json::Value) -> Value {
    Value::from_serialize(v)
}
97
98/// Render a dbt SQL template once with the given `is_incremental` value.
99fn render_with_incremental(
100    sql: &str,
101    is_incremental: bool,
102    vars: &HashMap<String, serde_json::Value>,
103) -> Option<JinjaExtraction> {
104    let extraction = Arc::new(Mutex::new(JinjaExtraction::default()));
105
106    let mut env = Environment::new();
107    env.set_undefined_behavior(minijinja::UndefinedBehavior::Lenient);
108
109    // ref('name') or ref('package', 'name')
110    let ext = extraction.clone();
111    env.add_function(
112        "ref",
113        move |args: &[Value]| -> Result<Value, minijinja::Error> {
114            let mut ext = ext.lock().unwrap();
115            match args.len() {
116                1 => {
117                    let name = args[0].to_string();
118                    ext.refs.push(RefCall {
119                        package: None,
120                        name: name.clone(),
121                    });
122                    Ok(Value::from(format!("__dbt_ref_{}__", name)))
123                }
124                2 => {
125                    let pkg = args[0].to_string();
126                    let name = args[1].to_string();
127                    ext.refs.push(RefCall {
128                        package: Some(pkg),
129                        name: name.clone(),
130                    });
131                    Ok(Value::from(format!("__dbt_ref_{}__", name)))
132                }
133                _ => Err(minijinja::Error::new(
134                    ErrorKind::TooManyArguments,
135                    "ref() takes 1 or 2 arguments",
136                )),
137            }
138        },
139    );
140
141    // source('source_name', 'table_name')
142    let ext = extraction.clone();
143    env.add_function(
144        "source",
145        move |args: &[Value]| -> Result<Value, minijinja::Error> {
146            if args.len() >= 2 {
147                let source_name = args[0].to_string();
148                let table_name = args[1].to_string();
149                ext.lock().unwrap().sources.push(SourceCall {
150                    source_name: source_name.clone(),
151                    table_name: table_name.clone(),
152                });
153                Ok(Value::from(format!(
154                    "__dbt_source_{}_{}__",
155                    source_name, table_name
156                )))
157            } else {
158                Err(minijinja::Error::new(
159                    ErrorKind::MissingArgument,
160                    "source() requires 2 arguments",
161                ))
162            }
163        },
164    );
165
166    // config(materialized='...', tags=[...], ...)
167    // Unknown kwargs (schema, alias, unique_key, etc.) are silently ignored.
168    let ext = extraction.clone();
169    env.add_function(
170        "config",
171        move |kwargs: Kwargs| -> Result<Value, minijinja::Error> {
172            let mut ext = ext.lock().unwrap();
173            if let Ok(mat) = kwargs.get::<&str>("materialized") {
174                ext.config.materialized = Some(mat.to_string());
175            }
176            if let Ok(tags_val) = kwargs.get::<Value>("tags")
177                && let Ok(iter) = tags_val.try_iter()
178            {
179                ext.config.tags = iter.map(|v| v.to_string()).collect();
180            }
181            Ok(Value::from(""))
182        },
183    );
184
185    // is_incremental() → parameterized
186    env.add_function(
187        "is_incremental",
188        move || -> Result<Value, minijinja::Error> { Ok(Value::from(is_incremental)) },
189    );
190
191    // this → dummy relation object
192    env.add_global("this", Value::from("__dbt_this__"));
193
194    // var() → resolves from dbt_project.yml vars, then default, then truthy sentinel
195    let vars_map: HashMap<String, Value> = vars
196        .iter()
197        .map(|(k, v)| (k.clone(), json_to_minijinja(v)))
198        .collect();
199    env.add_function(
200        "var",
201        move |args: &[Value]| -> Result<Value, minijinja::Error> {
202            if let Some(key) = args.first()
203                && let Some(key_str) = key.as_str()
204                && let Some(val) = vars_map.get(key_str)
205            {
206                return Ok(val.clone());
207            }
208            // Fall back to default argument (2nd arg) or truthy sentinel
209            if args.len() >= 2 {
210                Ok(args[1].clone())
211            } else {
212                Ok(Value::from("__dbt_var_unknown__"))
213            }
214        },
215    );
216
217    // env_var() → returns default or empty string
218    env.add_function(
219        "env_var",
220        |args: &[Value]| -> Result<Value, minijinja::Error> {
221            if args.len() >= 2 {
222                Ok(args[1].clone())
223            } else {
224                Ok(Value::from(""))
225            }
226        },
227    );
228
229    // return() → pass through
230    env.add_function(
231        "return",
232        |args: &[Value]| -> Result<Value, minijinja::Error> {
233            Ok(args.first().cloned().unwrap_or(Value::from("")))
234        },
235    );
236
237    // log() → no-op
238    env.add_function(
239        "log",
240        |_args: &[Value]| -> Result<Value, minijinja::Error> { Ok(Value::from("")) },
241    );
242
243    // run_query → no-op
244    env.add_function(
245        "run_query",
246        |_args: &[Value]| -> Result<Value, minijinja::Error> { Ok(Value::from("")) },
247    );
248
249    // statement → no-op
250    env.add_function(
251        "statement",
252        |_args: &[Value]| -> Result<Value, minijinja::Error> { Ok(Value::from("")) },
253    );
254
255    // Common dbt globals
256    env.add_global("adapter", Value::from("__dbt_adapter__"));
257    env.add_global("exceptions", Value::from("__dbt_exceptions__"));
258    env.add_global("api", Value::from("__dbt_api__"));
259    env.add_global("graph", Value::from("__dbt_graph__"));
260    env.add_global("target", Value::from("__dbt_target__"));
261    env.add_global("invocation_id", Value::from("__dbt_invocation_id__"));
262    env.add_global("run_started_at", Value::from("2025-01-01T00:00:00Z"));
263    env.add_global("flags", Value::from("__dbt_flags__"));
264    env.add_global("modules", Value::from("__dbt_modules__"));
265    env.add_global("dbt_version", Value::from("1.0.0"));
266    env.add_global("model", Value::from("__dbt_model__"));
267    env.add_global("execute", Value::from(true));
268
269    let render_result = env.render_str(sql, ());
270    drop(env);
271
272    match render_result {
273        Ok(_) => {
274            let result = Arc::try_unwrap(extraction)
275                .expect("single owner")
276                .into_inner()
277                .unwrap_or_else(|e| e.into_inner());
278            Some(result)
279        }
280        Err(_) => None,
281    }
282}
283
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_simple_ref() {
        let sql = "SELECT * FROM {{ ref('stg_orders') }}";
        let ext = extract_via_jinja(sql, "").unwrap();
        assert_eq!(ext.refs.len(), 1);
        assert_eq!(ext.refs[0].name, "stg_orders");
        assert!(ext.refs[0].package.is_none());
    }

    #[test]
    fn test_two_arg_ref() {
        let sql = "SELECT * FROM {{ ref('other_pkg', 'stg_orders') }}";
        let ext = extract_via_jinja(sql, "").unwrap();
        assert_eq!(ext.refs.len(), 1);
        assert_eq!(ext.refs[0].package.as_deref(), Some("other_pkg"));
        assert_eq!(ext.refs[0].name, "stg_orders");
    }

    #[test]
    fn test_source() {
        let sql = "SELECT * FROM {{ source('raw', 'orders') }}";
        let ext = extract_via_jinja(sql, "").unwrap();
        assert_eq!(ext.sources.len(), 1);
        assert_eq!(ext.sources[0].source_name, "raw");
        assert_eq!(ext.sources[0].table_name, "orders");
    }

    #[test]
    fn test_source_with_one_arg_fails() {
        // source() requires two arguments; the render fails so the caller
        // falls back to regex extraction.
        let sql = "SELECT * FROM {{ source('raw') }}";
        assert!(extract_via_jinja(sql, "").is_none());
    }

    #[test]
    fn test_config() {
        let sql = "{{ config(materialized='incremental', tags=['nightly', 'finance']) }}\nSELECT 1";
        let ext = extract_via_jinja(sql, "").unwrap();
        assert_eq!(ext.config.materialized.as_deref(), Some("incremental"));
        assert_eq!(ext.config.tags, vec!["nightly", "finance"]);
    }

    #[test]
    fn test_mixed() {
        let sql = r#"
            {{ config(materialized='table') }}
            SELECT
                o.*,
                c.name
            FROM {{ ref('stg_orders') }} o
            JOIN {{ source('raw', 'customers') }} c ON o.customer_id = c.id
        "#;
        let ext = extract_via_jinja(sql, "").unwrap();
        assert_eq!(ext.refs.len(), 1);
        assert_eq!(ext.sources.len(), 1);
        assert_eq!(ext.config.materialized.as_deref(), Some("table"));
    }

    #[test]
    fn test_ref_inside_set() {
        let sql = r#"
            {% set orders = ref('stg_orders') %}
            SELECT * FROM {{ orders }}
        "#;
        let ext = extract_via_jinja(sql, "").unwrap();
        assert_eq!(ext.refs.len(), 1);
        assert_eq!(ext.refs[0].name, "stg_orders");
    }

    #[test]
    fn test_is_incremental_both_branches() {
        let sql = r#"
            {% if is_incremental() %}
            SELECT * FROM {{ ref('stg_incremental_orders') }}
            WHERE updated_at > (SELECT max(updated_at) FROM {{ this }})
            {% else %}
            SELECT * FROM {{ ref('stg_full_orders') }}
            {% endif %}
        "#;
        let ext = extract_via_jinja(sql, "").unwrap();
        // Both branches are rendered: unique refs from each branch
        assert_eq!(ext.refs.len(), 2);
        assert!(ext.refs.iter().any(|r| r.name == "stg_full_orders"));
        assert!(ext.refs.iter().any(|r| r.name == "stg_incremental_orders"));
    }

    #[test]
    fn test_same_ref_in_both_branches_deduplicated() {
        // A ref seen in both renders must be reported only once.
        let sql = r#"
            {% if is_incremental() %}
            SELECT * FROM {{ ref('stg_orders') }} WHERE id > 0
            {% else %}
            SELECT * FROM {{ ref('stg_orders') }}
            {% endif %}
        "#;
        let ext = extract_via_jinja(sql, "").unwrap();
        assert_eq!(ext.refs.len(), 1);
        assert_eq!(ext.refs[0].name, "stg_orders");
    }

    #[test]
    fn test_config_from_full_load_render_wins() {
        // Config captured by the is_incremental=false render takes precedence.
        let sql = r#"
            {% if is_incremental() %}
            {{ config(materialized='incremental') }}
            {% else %}
            {{ config(materialized='table') }}
            {% endif %}
            SELECT 1
        "#;
        let ext = extract_via_jinja(sql, "").unwrap();
        assert_eq!(ext.config.materialized.as_deref(), Some("table"));
    }

    #[test]
    fn test_jinja_comment_ignored() {
        let sql = r#"
            {# This is a comment with {{ ref('should_be_ignored') }} #}
            SELECT * FROM {{ ref('actual_model') }}
        "#;
        let ext = extract_via_jinja(sql, "").unwrap();
        assert_eq!(ext.refs.len(), 1);
        assert_eq!(ext.refs[0].name, "actual_model");
    }

    #[test]
    fn test_whitespace_control() {
        let sql = "SELECT * FROM {{- ref('stg_orders') -}}";
        let ext = extract_via_jinja(sql, "").unwrap();
        assert_eq!(ext.refs.len(), 1);
        assert_eq!(ext.refs[0].name, "stg_orders");
    }

    #[test]
    fn test_var_with_default() {
        let sql = "SELECT * FROM {{ ref('model_' ~ var('suffix', 'default')) }}";
        let ext = extract_via_jinja(sql, "").unwrap();
        assert_eq!(ext.refs.len(), 1);
        assert_eq!(ext.refs[0].name, "model_default");
    }

    #[test]
    fn test_unknown_var_is_truthy() {
        // var() without a project value or default yields a truthy sentinel,
        // so feature-flagged branches are still explored.
        let sql = "{% if var('enable_feature') %}SELECT * FROM {{ ref('feature_model') }}{% endif %}";
        let ext = extract_via_jinja(sql, "").unwrap();
        assert_eq!(ext.refs.len(), 1);
        assert_eq!(ext.refs[0].name, "feature_model");
    }

    #[test]
    fn test_env_var_falls_back_to_default() {
        let sql = "SELECT * FROM {{ ref('model_' ~ env_var('DBT_TARGET_ENV', 'dev')) }}";
        let ext = extract_via_jinja(sql, "").unwrap();
        assert_eq!(ext.refs.len(), 1);
        assert_eq!(ext.refs[0].name, "model_dev");
    }

    #[test]
    fn test_var_resolved_from_project_vars() {
        let sql = "SELECT * FROM {{ ref('model_' ~ var('suffix')) }}";
        let mut vars = HashMap::new();
        vars.insert(
            "suffix".to_string(),
            serde_json::Value::String("prod".to_string()),
        );
        let ext = extract_via_jinja_with_vars(sql, "", &vars).unwrap();
        assert_eq!(ext.refs.len(), 1);
        assert_eq!(ext.refs[0].name, "model_prod");
    }

    #[test]
    fn test_var_list_expansion_in_for_loop() {
        // Reproduces the reported bug: var() returning a list should iterate
        // as a list, not char-by-char as a string.
        let sql = r#"
            {%- set categories = var("product_categories") -%}
            {%- for cat in categories -%}
                SELECT * FROM {{ ref('stg_' ~ cat ~ '_summary') }}
                {% if not loop.last %}UNION ALL{% endif %}
            {% endfor -%}
        "#;
        let mut vars = HashMap::new();
        vars.insert(
            "product_categories".to_string(),
            serde_json::json!(["electronics", "clothing"]),
        );
        let ext = extract_via_jinja_with_vars(sql, "", &vars).unwrap();
        assert_eq!(ext.refs.len(), 2);
        assert!(ext.refs.iter().any(|r| r.name == "stg_electronics_summary"));
        assert!(ext.refs.iter().any(|r| r.name == "stg_clothing_summary"));
    }

    #[test]
    fn test_var_project_overrides_default() {
        // When project vars are provided, they should take precedence over
        // the default argument in var().
        let sql = "SELECT * FROM {{ ref('model_' ~ var('env', 'dev')) }}";
        let mut vars = HashMap::new();
        vars.insert(
            "env".to_string(),
            serde_json::Value::String("staging".to_string()),
        );
        let ext = extract_via_jinja_with_vars(sql, "", &vars).unwrap();
        assert_eq!(ext.refs.len(), 1);
        assert_eq!(ext.refs[0].name, "model_staging");
    }

    #[test]
    fn test_var_unknown_falls_back_to_default() {
        // When a var is not in project vars, fall back to the default argument.
        let sql = "SELECT * FROM {{ ref('model_' ~ var('missing', 'fallback')) }}";
        let mut vars = HashMap::new();
        vars.insert(
            "other_var".to_string(),
            serde_json::Value::String("unused".to_string()),
        );
        let ext = extract_via_jinja_with_vars(sql, "", &vars).unwrap();
        assert_eq!(ext.refs.len(), 1);
        assert_eq!(ext.refs[0].name, "model_fallback");
    }

    #[test]
    fn test_for_loop_with_refs() {
        let sql = r#"
            {% for src in ['orders', 'customers'] %}
                SELECT * FROM {{ source('raw', src) }}
                {% if not loop.last %}UNION ALL{% endif %}
            {% endfor %}
        "#;
        let ext = extract_via_jinja(sql, "").unwrap();
        assert_eq!(ext.sources.len(), 2);
        assert_eq!(ext.sources[0].source_name, "raw");
        assert_eq!(ext.sources[0].table_name, "orders");
        assert_eq!(ext.sources[1].source_name, "raw");
        assert_eq!(ext.sources[1].table_name, "customers");
    }

    #[test]
    fn test_config_with_extra_kwargs() {
        let sql = "{{ config(materialized='incremental', schema='analytics', unique_key='id', tags=['nightly']) }}\nSELECT 1";
        let ext = extract_via_jinja(sql, "").unwrap();
        assert_eq!(ext.config.materialized.as_deref(), Some("incremental"));
        assert_eq!(ext.config.tags, vec!["nightly"]);
    }

    #[test]
    fn test_returns_none_on_unsupported_template() {
        // Unknown block tags should cause failure
        let sql = "{% materialization table, default %} SELECT 1 {% endmaterialization %}";
        let result = extract_via_jinja(sql, "");
        assert!(result.is_none());
    }

    #[test]
    fn test_macro_ref_extraction() {
        let macro_src = r#"
            {% macro my_cte() %}
                SELECT * FROM {{ ref('base_model') }}
            {% endmacro %}
        "#;
        let sql = "SELECT * FROM ({{ my_cte() }})";
        let ext = extract_via_jinja(sql, macro_src).unwrap();
        assert_eq!(ext.refs.len(), 1);
        assert_eq!(ext.refs[0].name, "base_model");
    }

    #[test]
    fn test_macro_source_extraction() {
        let macro_src = r#"
            {% macro raw_data(table) %}
                SELECT * FROM {{ source('raw', table) }}
            {% endmacro %}
        "#;
        let sql = "SELECT * FROM ({{ raw_data('orders') }})";
        let ext = extract_via_jinja(sql, macro_src).unwrap();
        assert_eq!(ext.sources.len(), 1);
        assert_eq!(ext.sources[0].source_name, "raw");
        assert_eq!(ext.sources[0].table_name, "orders");
    }

    #[test]
    fn test_macro_with_multiple_refs() {
        let macro_src = r#"
            {% macro join_tables(period) %}
                SELECT * FROM {{ ref('deals') }}
                LEFT JOIN {{ ref('providers') }} ON 1=1
                LEFT JOIN {{ source('raw', 'prices') }} ON 1=1
            {% endmacro %}
        "#;
        let sql = "{{ join_tables('day') }}";
        let ext = extract_via_jinja(sql, macro_src).unwrap();
        assert_eq!(ext.refs.len(), 2);
        assert!(ext.refs.iter().any(|r| r.name == "deals"));
        assert!(ext.refs.iter().any(|r| r.name == "providers"));
        assert_eq!(ext.sources.len(), 1);
        assert_eq!(ext.sources[0].table_name, "prices");
    }

    #[test]
    fn test_multiple_macro_files() {
        let sources = vec![
            r#"
            {% macro get_orders() %}
                SELECT * FROM {{ ref('stg_orders') }}
            {% endmacro %}
            "#
            .to_string(),
            r#"
            {% macro get_customers() %}
                SELECT * FROM {{ ref('stg_customers') }}
            {% endmacro %}
            "#
            .to_string(),
        ];
        let prefix = build_macro_prefix(&sources);
        let sql = "{{ get_orders() }} UNION ALL {{ get_customers() }}";
        let ext = extract_via_jinja(sql, &prefix).unwrap();
        assert_eq!(ext.refs.len(), 2);
        assert!(ext.refs.iter().any(|r| r.name == "stg_orders"));
        assert!(ext.refs.iter().any(|r| r.name == "stg_customers"));
    }

    #[test]
    fn test_build_macro_prefix_empty_input() {
        assert!(build_macro_prefix(&[]).is_empty());
    }

    #[test]
    fn test_build_macro_prefix_skips_invalid() {
        let sources = vec![
            "{% macro good() %}SELECT 1{% endmacro %}".to_string(),
            // Invalid: unsupported block tag
            "{% materialization custom %} stuff {% endmaterialization %}".to_string(),
            "{% macro also_good() %}SELECT 2{% endmacro %}".to_string(),
            // Invalid: unclosed raw block
            "{% raw %}unclosed raw content".to_string(),
        ];
        let prefix = build_macro_prefix(&sources);
        assert!(prefix.contains("{% macro good() %}"));
        assert!(prefix.contains("{% macro also_good() %}"));
        assert!(!prefix.contains("materialization"));
        assert!(!prefix.contains("{% raw %}"));
    }

    #[test]
    fn test_build_macro_prefix_includes_compatible_macros() {
        let env = Environment::new();

        let macro_a = "{% macro a() %}ok{% endmacro %}".to_string();
        let macro_b = "{% macro b() %}ok{% endmacro %}".to_string();
        assert!(env.template_from_str(&macro_a).is_ok());
        assert!(env.template_from_str(&macro_b).is_ok());

        let sources = vec![macro_a, macro_b];
        let prefix = build_macro_prefix(&sources);
        assert!(prefix.contains("{% macro a() %}"));
        assert!(prefix.contains("{% macro b() %}"));
    }

    #[test]
    fn test_invalid_macro_skipped_refs_still_extracted() {
        let sources = vec![
            // Bad macro that would poison everything if not filtered
            "{% materialization custom %} stuff {% endmaterialization %}".to_string(),
        ];
        let prefix = build_macro_prefix(&sources);
        let sql = "SELECT * FROM {{ ref('orders') }}";
        let ext = extract_via_jinja(sql, &prefix).unwrap();
        assert_eq!(ext.refs.len(), 1);
        assert_eq!(ext.refs[0].name, "orders");
    }
}