Skip to main content

dlin_core/parser/
jinja.rs

1use std::collections::HashMap;
2use std::sync::{Arc, Mutex};
3
4use minijinja::value::{Kwargs, from_args};
5use minijinja::{Environment, ErrorKind, Value};
6
7use super::sql::{RefCall, SourceCall, SqlConfig, normalize_version_str};
8
/// All extracted information from rendering a dbt Jinja SQL template.
///
/// Populated by the `ref()`, `source()` and `config()` stand-in functions
/// that are registered on the minijinja environment while a template is
/// rendered. `Default` yields the empty starting state.
#[derive(Debug, Clone, Default, serde::Serialize, serde::Deserialize)]
pub struct JinjaExtraction {
    /// `ref(...)` calls recorded during rendering.
    pub refs: Vec<RefCall>,
    /// `source(...)` calls recorded during rendering.
    pub sources: Vec<SourceCall>,
    /// Settings captured from `config(...)` calls.
    pub config: SqlConfig,
}
16
17/// Try to extract refs, sources, and config from SQL content using minijinja.
18/// Renders twice (with `is_incremental()` returning both false and true) and
19/// merges results to capture refs/sources from all conditional branches.
20/// Returns `None` if the template fails to render (caller should fall back to regex).
21///
22/// `macro_prefix` is the pre-built concatenation of valid macro SQL files.
23/// It is prepended to the template so that custom macros containing
24/// ref()/source() calls are expanded and tracked.
25pub fn extract_via_jinja(sql: &str, macro_prefix: &str) -> Option<JinjaExtraction> {
26    extract_via_jinja_with_vars(sql, macro_prefix, &HashMap::new())
27}
28
29/// Like [`extract_via_jinja`] but resolves `var()` calls using the given
30/// project-level variables (parsed from `dbt_project.yml`).
31pub fn extract_via_jinja_with_vars(
32    sql: &str,
33    macro_prefix: &str,
34    vars: &HashMap<String, serde_json::Value>,
35) -> Option<JinjaExtraction> {
36    let template = if macro_prefix.is_empty() {
37        sql.to_string()
38    } else {
39        format!("{}\n{}", macro_prefix, sql)
40    };
41
42    // Render with is_incremental=false first (full-load path)
43    let mut result = render_with_incremental(&template, false, vars)?;
44
45    // Render again with is_incremental=true to capture incremental-only refs
46    if let Some(incr) = render_with_incremental(&template, true, vars) {
47        merge_extraction(&mut result, incr);
48    }
49
50    Some(result)
51}
52
53/// Build a macro prefix string from individual macro sources, skipping
54/// any that fail to parse as valid minijinja templates. This ensures one
55/// bad macro file doesn't disable jinja-based extraction for all models.
56pub fn build_macro_prefix(macro_sources: &[String]) -> String {
57    if macro_sources.is_empty() {
58        return String::new();
59    }
60    let env = Environment::new();
61    let mut prefix = String::new();
62    for source in macro_sources {
63        // Only include macros that minijinja can parse individually
64        if env.template_from_str(source).is_err() {
65            continue;
66        }
67        // Verify the accumulated prefix still parses after adding this macro
68        let len = prefix.len();
69        prefix.push_str(source);
70        prefix.push('\n');
71        if env.template_from_str(&prefix).is_err() {
72            prefix.truncate(len);
73        }
74    }
75    prefix
76}
77
78/// Merge `other` into `base`, adding only deduplicated refs and sources
79fn merge_extraction(base: &mut JinjaExtraction, other: JinjaExtraction) {
80    for r in other.refs {
81        if !base.refs.contains(&r) {
82            base.refs.push(r);
83        }
84    }
85    for s in other.sources {
86        if !base.sources.contains(&s) {
87            base.sources.push(s);
88        }
89    }
90    // config from first render takes precedence
91}
92
/// Convert a `serde_json::Value` to a `minijinja::Value`.
///
/// The conversion goes through serde serialization
/// (`Value::from_serialize`), so the variable reaches templates as
/// structured data. The test `test_var_list_expansion_in_for_loop` relies on
/// a JSON array being iterable as a list of items.
fn json_to_minijinja(v: &serde_json::Value) -> Value {
    Value::from_serialize(v)
}
97
98/// Render a dbt SQL template once with the given `is_incremental` value.
99fn render_with_incremental(
100    sql: &str,
101    is_incremental: bool,
102    vars: &HashMap<String, serde_json::Value>,
103) -> Option<JinjaExtraction> {
104    let extraction = Arc::new(Mutex::new(JinjaExtraction::default()));
105
106    let mut env = Environment::new();
107    env.set_undefined_behavior(minijinja::UndefinedBehavior::Lenient);
108
109    // ref('name'), ref('package', 'name'), or ref('name', version=N)
110    // kwargs (e.g. version=2) are appended by minijinja as the last element of args.
111    // from_args splits positional args from kwargs so we can extract version.
112    let ext = extraction.clone();
113    env.add_function(
114        "ref",
115        move |args: &[Value]| -> Result<Value, minijinja::Error> {
116            let mut ext = ext.lock().unwrap();
117            let (positional, kwargs): (&[Value], Kwargs) = from_args(args)
118                .map_err(|e| minijinja::Error::new(ErrorKind::InvalidOperation, e.to_string()))?;
119            // dbt accepts both `version=N` and `v=N` as shorthand.
120            // The value may be an integer (version=2) or a quoted string (version='alpha'),
121            // matching dbt-core which uses StringOrInteger for version kwargs.
122            let version: Option<String> = kwargs
123                .peek::<i64>("version")
124                .ok()
125                .map(|n| n.to_string())
126                .or_else(|| {
127                    kwargs
128                        .peek::<String>("version")
129                        .ok()
130                        .map(|s| normalize_version_str(&s))
131                })
132                .or_else(|| kwargs.peek::<i64>("v").ok().map(|n| n.to_string()))
133                .or_else(|| {
134                    kwargs
135                        .peek::<String>("v")
136                        .ok()
137                        .map(|s| normalize_version_str(&s))
138                });
139            match positional.len() {
140                1 => {
141                    let name = positional[0].to_string();
142                    ext.refs.push(RefCall {
143                        package: None,
144                        name: name.clone(),
145                        version,
146                    });
147                    Ok(Value::from(format!("__dbt_ref_{}__", name)))
148                }
149                2 => {
150                    let pkg = positional[0].to_string();
151                    let name = positional[1].to_string();
152                    ext.refs.push(RefCall {
153                        package: Some(pkg),
154                        name: name.clone(),
155                        version,
156                    });
157                    Ok(Value::from(format!("__dbt_ref_{}__", name)))
158                }
159                _ => Err(minijinja::Error::new(
160                    ErrorKind::TooManyArguments,
161                    "ref() takes 1 or 2 positional arguments",
162                )),
163            }
164        },
165    );
166
167    // source('source_name', 'table_name')
168    let ext = extraction.clone();
169    env.add_function(
170        "source",
171        move |args: &[Value]| -> Result<Value, minijinja::Error> {
172            if args.len() >= 2 {
173                let source_name = args[0].to_string();
174                let table_name = args[1].to_string();
175                ext.lock().unwrap().sources.push(SourceCall {
176                    source_name: source_name.clone(),
177                    table_name: table_name.clone(),
178                });
179                Ok(Value::from(format!(
180                    "__dbt_source_{}_{}__",
181                    source_name, table_name
182                )))
183            } else {
184                Err(minijinja::Error::new(
185                    ErrorKind::MissingArgument,
186                    "source() requires 2 arguments",
187                ))
188            }
189        },
190    );
191
192    // config(materialized='...', tags=[...], ...)
193    // Unknown kwargs (schema, alias, unique_key, etc.) are silently ignored.
194    let ext = extraction.clone();
195    env.add_function(
196        "config",
197        move |kwargs: Kwargs| -> Result<Value, minijinja::Error> {
198            let mut ext = ext.lock().unwrap();
199            if let Ok(mat) = kwargs.get::<&str>("materialized") {
200                ext.config.materialized = Some(mat.to_string());
201            }
202            if let Ok(tags_val) = kwargs.get::<Value>("tags")
203                && let Ok(iter) = tags_val.try_iter()
204            {
205                ext.config.tags = iter.map(|v| v.to_string()).collect();
206            }
207            Ok(Value::from(""))
208        },
209    );
210
211    // is_incremental() → parameterized
212    env.add_function(
213        "is_incremental",
214        move || -> Result<Value, minijinja::Error> { Ok(Value::from(is_incremental)) },
215    );
216
217    // this → dummy relation object
218    env.add_global("this", Value::from("__dbt_this__"));
219
220    // var() → resolves from dbt_project.yml vars, then default, then truthy sentinel
221    let vars_map: HashMap<String, Value> = vars
222        .iter()
223        .map(|(k, v)| (k.clone(), json_to_minijinja(v)))
224        .collect();
225    env.add_function(
226        "var",
227        move |args: &[Value]| -> Result<Value, minijinja::Error> {
228            if let Some(key) = args.first()
229                && let Some(key_str) = key.as_str()
230                && let Some(val) = vars_map.get(key_str)
231            {
232                return Ok(val.clone());
233            }
234            // Fall back to default argument (2nd arg) or truthy sentinel
235            if args.len() >= 2 {
236                Ok(args[1].clone())
237            } else {
238                Ok(Value::from("__dbt_var_unknown__"))
239            }
240        },
241    );
242
243    // env_var() → returns default or empty string
244    env.add_function(
245        "env_var",
246        |args: &[Value]| -> Result<Value, minijinja::Error> {
247            if args.len() >= 2 {
248                Ok(args[1].clone())
249            } else {
250                Ok(Value::from(""))
251            }
252        },
253    );
254
255    // return() → pass through
256    env.add_function(
257        "return",
258        |args: &[Value]| -> Result<Value, minijinja::Error> {
259            Ok(args.first().cloned().unwrap_or(Value::from("")))
260        },
261    );
262
263    // log() → no-op
264    env.add_function(
265        "log",
266        |_args: &[Value]| -> Result<Value, minijinja::Error> { Ok(Value::from("")) },
267    );
268
269    // run_query → no-op
270    env.add_function(
271        "run_query",
272        |_args: &[Value]| -> Result<Value, minijinja::Error> { Ok(Value::from("")) },
273    );
274
275    // statement → no-op
276    env.add_function(
277        "statement",
278        |_args: &[Value]| -> Result<Value, minijinja::Error> { Ok(Value::from("")) },
279    );
280
281    // Common dbt globals
282    env.add_global("adapter", Value::from("__dbt_adapter__"));
283    env.add_global("exceptions", Value::from("__dbt_exceptions__"));
284    env.add_global("api", Value::from("__dbt_api__"));
285    env.add_global("graph", Value::from("__dbt_graph__"));
286    env.add_global("target", Value::from("__dbt_target__"));
287    env.add_global("invocation_id", Value::from("__dbt_invocation_id__"));
288    env.add_global("run_started_at", Value::from("2025-01-01T00:00:00Z"));
289    env.add_global("flags", Value::from("__dbt_flags__"));
290    env.add_global("modules", Value::from("__dbt_modules__"));
291    env.add_global("dbt_version", Value::from("1.0.0"));
292    env.add_global("model", Value::from("__dbt_model__"));
293    env.add_global("execute", Value::from(true));
294
295    let render_result = env.render_str(sql, ());
296    drop(env);
297
298    match render_result {
299        Ok(_) => {
300            let result = Arc::try_unwrap(extraction)
301                .expect("single owner")
302                .into_inner()
303                .unwrap_or_else(|e| e.into_inner());
304            Some(result)
305        }
306        Err(_) => None,
307    }
308}
309
// Unit tests for the minijinja-based extraction. Each test renders a small
// dbt-style template and inspects the recorded refs, sources, and config.
#[cfg(test)]
mod tests {
    use super::*;

    // ---- Basic ref() / source() / config() capture ----

    #[test]
    fn test_simple_ref() {
        let sql = "SELECT * FROM {{ ref('stg_orders') }}";
        let ext = extract_via_jinja(sql, "").unwrap();
        assert_eq!(ext.refs.len(), 1);
        assert_eq!(ext.refs[0].name, "stg_orders");
        assert!(ext.refs[0].package.is_none());
    }

    #[test]
    fn test_two_arg_ref() {
        // Two positional arguments are interpreted as (package, name).
        let sql = "SELECT * FROM {{ ref('other_pkg', 'stg_orders') }}";
        let ext = extract_via_jinja(sql, "").unwrap();
        assert_eq!(ext.refs.len(), 1);
        assert_eq!(ext.refs[0].package.as_deref(), Some("other_pkg"));
        assert_eq!(ext.refs[0].name, "stg_orders");
    }

    #[test]
    fn test_source() {
        let sql = "SELECT * FROM {{ source('raw', 'orders') }}";
        let ext = extract_via_jinja(sql, "").unwrap();
        assert_eq!(ext.sources.len(), 1);
        assert_eq!(ext.sources[0].source_name, "raw");
        assert_eq!(ext.sources[0].table_name, "orders");
    }

    #[test]
    fn test_config() {
        let sql = "{{ config(materialized='incremental', tags=['nightly', 'finance']) }}\nSELECT 1";
        let ext = extract_via_jinja(sql, "").unwrap();
        assert_eq!(ext.config.materialized.as_deref(), Some("incremental"));
        assert_eq!(ext.config.tags, vec!["nightly", "finance"]);
    }

    #[test]
    fn test_mixed() {
        // config, ref and source in one template are all captured.
        let sql = r#"
            {{ config(materialized='table') }}
            SELECT
                o.*,
                c.name
            FROM {{ ref('stg_orders') }} o
            JOIN {{ source('raw', 'customers') }} c ON o.customer_id = c.id
        "#;
        let ext = extract_via_jinja(sql, "").unwrap();
        assert_eq!(ext.refs.len(), 1);
        assert_eq!(ext.sources.len(), 1);
        assert_eq!(ext.config.materialized.as_deref(), Some("table"));
    }

    #[test]
    fn test_ref_inside_set() {
        // A ref() evaluated inside `{% set %}` is recorded like one in `{{ }}`.
        let sql = r#"
            {% set orders = ref('stg_orders') %}
            SELECT * FROM {{ orders }}
        "#;
        let ext = extract_via_jinja(sql, "").unwrap();
        assert_eq!(ext.refs.len(), 1);
        assert_eq!(ext.refs[0].name, "stg_orders");
    }

    // ---- Template features: branches, comments, whitespace control ----

    #[test]
    fn test_is_incremental_both_branches() {
        let sql = r#"
            {% if is_incremental() %}
            SELECT * FROM {{ ref('stg_incremental_orders') }}
            WHERE updated_at > (SELECT max(updated_at) FROM {{ this }})
            {% else %}
            SELECT * FROM {{ ref('stg_full_orders') }}
            {% endif %}
        "#;
        let ext = extract_via_jinja(sql, "").unwrap();
        // Both branches are rendered: unique refs from each branch
        assert_eq!(ext.refs.len(), 2);
        assert!(ext.refs.iter().any(|r| r.name == "stg_full_orders"));
        assert!(ext.refs.iter().any(|r| r.name == "stg_incremental_orders"));
    }

    #[test]
    fn test_jinja_comment_ignored() {
        let sql = r#"
            {# This is a comment with {{ ref('should_be_ignored') }} #}
            SELECT * FROM {{ ref('actual_model') }}
        "#;
        let ext = extract_via_jinja(sql, "").unwrap();
        assert_eq!(ext.refs.len(), 1);
        assert_eq!(ext.refs[0].name, "actual_model");
    }

    #[test]
    fn test_whitespace_control() {
        let sql = "SELECT * FROM {{- ref('stg_orders') -}}";
        let ext = extract_via_jinja(sql, "").unwrap();
        assert_eq!(ext.refs.len(), 1);
        assert_eq!(ext.refs[0].name, "stg_orders");
    }

    // ---- var() resolution ----

    #[test]
    fn test_var_with_default() {
        // With no project vars, var() yields its second (default) argument.
        let sql = "SELECT * FROM {{ ref('model_' ~ var('suffix', 'default')) }}";
        let ext = extract_via_jinja(sql, "").unwrap();
        assert_eq!(ext.refs.len(), 1);
        assert_eq!(ext.refs[0].name, "model_default");
    }

    #[test]
    fn test_var_resolved_from_project_vars() {
        let sql = "SELECT * FROM {{ ref('model_' ~ var('suffix')) }}";
        let mut vars = HashMap::new();
        vars.insert(
            "suffix".to_string(),
            serde_json::Value::String("prod".to_string()),
        );
        let ext = extract_via_jinja_with_vars(sql, "", &vars).unwrap();
        assert_eq!(ext.refs.len(), 1);
        assert_eq!(ext.refs[0].name, "model_prod");
    }

    #[test]
    fn test_var_list_expansion_in_for_loop() {
        // Reproduces the reported bug: var() returning a list should iterate
        // as a list, not char-by-char as a string.
        let sql = r#"
            {%- set categories = var("product_categories") -%}
            {%- for cat in categories -%}
                SELECT * FROM {{ ref('stg_' ~ cat ~ '_summary') }}
                {% if not loop.last %}UNION ALL{% endif %}
            {% endfor -%}
        "#;
        let mut vars = HashMap::new();
        vars.insert(
            "product_categories".to_string(),
            serde_json::json!(["electronics", "clothing"]),
        );
        let ext = extract_via_jinja_with_vars(sql, "", &vars).unwrap();
        assert_eq!(ext.refs.len(), 2);
        assert!(ext.refs.iter().any(|r| r.name == "stg_electronics_summary"));
        assert!(ext.refs.iter().any(|r| r.name == "stg_clothing_summary"));
    }

    #[test]
    fn test_var_project_overrides_default() {
        // When project vars are provided, they should take precedence over
        // the default argument in var().
        let sql = "SELECT * FROM {{ ref('model_' ~ var('env', 'dev')) }}";
        let mut vars = HashMap::new();
        vars.insert(
            "env".to_string(),
            serde_json::Value::String("staging".to_string()),
        );
        let ext = extract_via_jinja_with_vars(sql, "", &vars).unwrap();
        assert_eq!(ext.refs.len(), 1);
        assert_eq!(ext.refs[0].name, "model_staging");
    }

    #[test]
    fn test_var_unknown_falls_back_to_default() {
        // When a var is not in project vars, fall back to the default argument.
        let sql = "SELECT * FROM {{ ref('model_' ~ var('missing', 'fallback')) }}";
        let mut vars = HashMap::new();
        vars.insert(
            "other_var".to_string(),
            serde_json::Value::String("unused".to_string()),
        );
        let ext = extract_via_jinja_with_vars(sql, "", &vars).unwrap();
        assert_eq!(ext.refs.len(), 1);
        assert_eq!(ext.refs[0].name, "model_fallback");
    }

    #[test]
    fn test_for_loop_with_refs() {
        // Each loop iteration records its own source() call, in order.
        let sql = r#"
            {% for src in ['orders', 'customers'] %}
                SELECT * FROM {{ source('raw', src) }}
                {% if not loop.last %}UNION ALL{% endif %}
            {% endfor %}
        "#;
        let ext = extract_via_jinja(sql, "").unwrap();
        assert_eq!(ext.sources.len(), 2);
        assert_eq!(ext.sources[0].source_name, "raw");
        assert_eq!(ext.sources[0].table_name, "orders");
        assert_eq!(ext.sources[1].source_name, "raw");
        assert_eq!(ext.sources[1].table_name, "customers");
    }

    #[test]
    fn test_config_with_extra_kwargs() {
        // Kwargs other than `materialized` and `tags` must not break rendering.
        let sql = "{{ config(materialized='incremental', schema='analytics', unique_key='id', tags=['nightly']) }}\nSELECT 1";
        let ext = extract_via_jinja(sql, "").unwrap();
        assert_eq!(ext.config.materialized.as_deref(), Some("incremental"));
        assert_eq!(ext.config.tags, vec!["nightly"]);
    }

    // ---- ref() version kwargs ----

    #[test]
    fn test_ref_with_version_kwarg() {
        let sql = "SELECT * FROM {{ ref('my_model', version=2) }}";
        let ext = extract_via_jinja(sql, "").unwrap();
        assert_eq!(ext.refs.len(), 1);
        assert_eq!(ext.refs[0].name, "my_model");
        assert_eq!(ext.refs[0].version.as_deref(), Some("2"));
        assert!(ext.refs[0].package.is_none());
    }

    #[test]
    fn test_ref_with_version_kwarg_and_package() {
        let sql = "SELECT * FROM {{ ref('mypkg', 'my_model', version=3) }}";
        let ext = extract_via_jinja(sql, "").unwrap();
        assert_eq!(ext.refs.len(), 1);
        assert_eq!(ext.refs[0].package.as_deref(), Some("mypkg"));
        assert_eq!(ext.refs[0].name, "my_model");
        assert_eq!(ext.refs[0].version.as_deref(), Some("3"));
    }

    #[test]
    fn test_ref_without_version_has_none() {
        let sql = "SELECT * FROM {{ ref('my_model') }}";
        let ext = extract_via_jinja(sql, "").unwrap();
        assert_eq!(ext.refs[0].version, None);
    }

    #[test]
    fn test_ref_with_v_shorthand_kwarg() {
        let sql = "SELECT * FROM {{ ref('my_model', v=2) }}";
        let ext = extract_via_jinja(sql, "").unwrap();
        assert_eq!(ext.refs.len(), 1);
        assert_eq!(ext.refs[0].name, "my_model");
        assert_eq!(ext.refs[0].version.as_deref(), Some("2"));
        assert!(ext.refs[0].package.is_none());
    }

    #[test]
    fn test_ref_with_v_shorthand_kwarg_and_package() {
        let sql = "SELECT * FROM {{ ref('mypkg', 'my_model', v=3) }}";
        let ext = extract_via_jinja(sql, "").unwrap();
        assert_eq!(ext.refs.len(), 1);
        assert_eq!(ext.refs[0].package.as_deref(), Some("mypkg"));
        assert_eq!(ext.refs[0].name, "my_model");
        assert_eq!(ext.refs[0].version.as_deref(), Some("3"));
    }

    #[test]
    fn test_ref_with_string_version_kwarg() {
        // version='alpha' (non-numeric string) passes through unchanged
        let sql = "SELECT * FROM {{ ref('my_model', version='alpha') }}";
        let ext = extract_via_jinja(sql, "").unwrap();
        assert_eq!(ext.refs[0].version.as_deref(), Some("alpha"));
    }

    #[test]
    fn test_ref_with_padded_integer_version_kwarg() {
        // version='02' (string kwarg) must normalize to "2"
        let sql = "SELECT * FROM {{ ref('my_model', version='02') }}";
        let ext = extract_via_jinja(sql, "").unwrap();
        assert_eq!(ext.refs[0].version.as_deref(), Some("2"));
    }

    #[test]
    fn test_ref_with_decimal_version_kwarg() {
        // version='2.0' stays as "2.0" — matching YAML `v: "2.0"` which also keeps "2.0".
        let sql = "SELECT * FROM {{ ref('my_model', version='2.0') }}";
        let ext = extract_via_jinja(sql, "").unwrap();
        assert_eq!(ext.refs[0].version.as_deref(), Some("2.0"));
    }

    // ---- Render failure ----

    #[test]
    fn test_returns_none_on_unsupported_template() {
        // Unknown block tags should cause failure
        let sql = "{% materialization table, default %} SELECT 1 {% endmaterialization %}";
        let result = extract_via_jinja(sql, "");
        assert!(result.is_none());
    }

    // ---- Macro prefix handling ----

    #[test]
    fn test_macro_ref_extraction() {
        // A ref() inside a macro is recorded when the model calls the macro.
        let macro_src = r#"
            {% macro my_cte() %}
                SELECT * FROM {{ ref('base_model') }}
            {% endmacro %}
        "#;
        let sql = "SELECT * FROM ({{ my_cte() }})";
        let ext = extract_via_jinja(sql, macro_src).unwrap();
        assert_eq!(ext.refs.len(), 1);
        assert_eq!(ext.refs[0].name, "base_model");
    }

    #[test]
    fn test_macro_source_extraction() {
        // Macro arguments flow through to source().
        let macro_src = r#"
            {% macro raw_data(table) %}
                SELECT * FROM {{ source('raw', table) }}
            {% endmacro %}
        "#;
        let sql = "SELECT * FROM ({{ raw_data('orders') }})";
        let ext = extract_via_jinja(sql, macro_src).unwrap();
        assert_eq!(ext.sources.len(), 1);
        assert_eq!(ext.sources[0].source_name, "raw");
        assert_eq!(ext.sources[0].table_name, "orders");
    }

    #[test]
    fn test_macro_with_multiple_refs() {
        let macro_src = r#"
            {% macro join_tables(period) %}
                SELECT * FROM {{ ref('deals') }}
                LEFT JOIN {{ ref('providers') }} ON 1=1
                LEFT JOIN {{ source('raw', 'prices') }} ON 1=1
            {% endmacro %}
        "#;
        let sql = "{{ join_tables('day') }}";
        let ext = extract_via_jinja(sql, macro_src).unwrap();
        assert_eq!(ext.refs.len(), 2);
        assert!(ext.refs.iter().any(|r| r.name == "deals"));
        assert!(ext.refs.iter().any(|r| r.name == "providers"));
        assert_eq!(ext.sources.len(), 1);
        assert_eq!(ext.sources[0].table_name, "prices");
    }

    #[test]
    fn test_multiple_macro_files() {
        // Macros from separate sources are combined by build_macro_prefix.
        let sources = vec![
            r#"
            {% macro get_orders() %}
                SELECT * FROM {{ ref('stg_orders') }}
            {% endmacro %}
            "#
            .to_string(),
            r#"
            {% macro get_customers() %}
                SELECT * FROM {{ ref('stg_customers') }}
            {% endmacro %}
            "#
            .to_string(),
        ];
        let prefix = build_macro_prefix(&sources);
        let sql = "{{ get_orders() }} UNION ALL {{ get_customers() }}";
        let ext = extract_via_jinja(sql, &prefix).unwrap();
        assert_eq!(ext.refs.len(), 2);
        assert!(ext.refs.iter().any(|r| r.name == "stg_orders"));
        assert!(ext.refs.iter().any(|r| r.name == "stg_customers"));
    }

    #[test]
    fn test_build_macro_prefix_skips_invalid() {
        let sources = vec![
            "{% macro good() %}SELECT 1{% endmacro %}".to_string(),
            // Invalid: unsupported block tag
            "{% materialization custom %} stuff {% endmaterialization %}".to_string(),
            "{% macro also_good() %}SELECT 2{% endmacro %}".to_string(),
            // Invalid: unclosed raw block
            "{% raw %}unclosed raw content".to_string(),
        ];
        let prefix = build_macro_prefix(&sources);
        assert!(prefix.contains("{% macro good() %}"));
        assert!(prefix.contains("{% macro also_good() %}"));
        assert!(!prefix.contains("materialization"));
        assert!(!prefix.contains("{% raw %}"));
    }

    #[test]
    fn test_build_macro_prefix_includes_compatible_macros() {
        let env = Environment::new();

        // Sanity check: both macros parse on their own before being combined.
        let macro_a = "{% macro a() %}ok{% endmacro %}".to_string();
        let macro_b = "{% macro b() %}ok{% endmacro %}".to_string();
        assert!(env.template_from_str(&macro_a).is_ok());
        assert!(env.template_from_str(&macro_b).is_ok());

        let sources = vec![macro_a, macro_b];
        let prefix = build_macro_prefix(&sources);
        assert!(prefix.contains("{% macro a() %}"));
        assert!(prefix.contains("{% macro b() %}"));
    }

    #[test]
    fn test_invalid_macro_skipped_refs_still_extracted() {
        let sources = vec![
            // Bad macro that would poison everything if not filtered
            "{% materialization custom %} stuff {% endmaterialization %}".to_string(),
        ];
        let prefix = build_macro_prefix(&sources);
        let sql = "SELECT * FROM {{ ref('orders') }}";
        let ext = extract_via_jinja(sql, &prefix).unwrap();
        assert_eq!(ext.refs.len(), 1);
        assert_eq!(ext.refs[0].name, "orders");
    }
}
700}