1use std::collections::HashMap;
2use std::sync::{Arc, Mutex};
3
4use minijinja::value::{Kwargs, from_args};
5use minijinja::{Environment, ErrorKind, Value};
6
7use super::sql::{RefCall, SourceCall, SqlConfig, normalize_version_str};
8
/// Everything captured while rendering a model's SQL through minijinja with the
/// recording stubs installed by `render_with_incremental`.
#[derive(Debug, Clone, Default, serde::Serialize, serde::Deserialize)]
pub struct JinjaExtraction {
    /// One entry per `ref(...)` call executed during rendering, in call order.
    pub refs: Vec<RefCall>,
    /// One entry per `source(...)` call executed during rendering, in call order.
    pub sources: Vec<SourceCall>,
    /// Values read from the keyword arguments of `config(...)` (`materialized`, `tags`).
    pub config: SqlConfig,
}
16
17pub fn extract_via_jinja(sql: &str, macro_prefix: &str) -> Option<JinjaExtraction> {
26 extract_via_jinja_with_vars(sql, macro_prefix, &HashMap::new())
27}
28
29pub fn extract_via_jinja_with_vars(
32 sql: &str,
33 macro_prefix: &str,
34 vars: &HashMap<String, serde_json::Value>,
35) -> Option<JinjaExtraction> {
36 let template = if macro_prefix.is_empty() {
37 sql.to_string()
38 } else {
39 format!("{}\n{}", macro_prefix, sql)
40 };
41
42 let mut result = render_with_incremental(&template, false, vars)?;
44
45 if let Some(incr) = render_with_incremental(&template, true, vars) {
47 merge_extraction(&mut result, incr);
48 }
49
50 Some(result)
51}
52
53pub fn build_macro_prefix(macro_sources: &[String]) -> String {
57 if macro_sources.is_empty() {
58 return String::new();
59 }
60 let env = Environment::new();
61 let mut prefix = String::new();
62 for source in macro_sources {
63 if env.template_from_str(source).is_err() {
65 continue;
66 }
67 let len = prefix.len();
69 prefix.push_str(source);
70 prefix.push('\n');
71 if env.template_from_str(&prefix).is_err() {
72 prefix.truncate(len);
73 }
74 }
75 prefix
76}
77
78fn merge_extraction(base: &mut JinjaExtraction, other: JinjaExtraction) {
80 for r in other.refs {
81 if !base.refs.contains(&r) {
82 base.refs.push(r);
83 }
84 }
85 for s in other.sources {
86 if !base.sources.contains(&s) {
87 base.sources.push(s);
88 }
89 }
90 }
92
93fn json_to_minijinja(v: &serde_json::Value) -> Value {
95 Value::from_serialize(v)
96}
97
98fn render_with_incremental(
100 sql: &str,
101 is_incremental: bool,
102 vars: &HashMap<String, serde_json::Value>,
103) -> Option<JinjaExtraction> {
104 let extraction = Arc::new(Mutex::new(JinjaExtraction::default()));
105
106 let mut env = Environment::new();
107 env.set_undefined_behavior(minijinja::UndefinedBehavior::Lenient);
108
109 let ext = extraction.clone();
113 env.add_function(
114 "ref",
115 move |args: &[Value]| -> Result<Value, minijinja::Error> {
116 let mut ext = ext.lock().unwrap();
117 let (positional, kwargs): (&[Value], Kwargs) = from_args(args)
118 .map_err(|e| minijinja::Error::new(ErrorKind::InvalidOperation, e.to_string()))?;
119 let version: Option<String> = kwargs
123 .peek::<i64>("version")
124 .ok()
125 .map(|n| n.to_string())
126 .or_else(|| {
127 kwargs
128 .peek::<String>("version")
129 .ok()
130 .map(|s| normalize_version_str(&s))
131 })
132 .or_else(|| kwargs.peek::<i64>("v").ok().map(|n| n.to_string()))
133 .or_else(|| {
134 kwargs
135 .peek::<String>("v")
136 .ok()
137 .map(|s| normalize_version_str(&s))
138 });
139 match positional.len() {
140 1 => {
141 let name = positional[0].to_string();
142 ext.refs.push(RefCall {
143 package: None,
144 name: name.clone(),
145 version,
146 });
147 Ok(Value::from(format!("__dbt_ref_{}__", name)))
148 }
149 2 => {
150 let pkg = positional[0].to_string();
151 let name = positional[1].to_string();
152 ext.refs.push(RefCall {
153 package: Some(pkg),
154 name: name.clone(),
155 version,
156 });
157 Ok(Value::from(format!("__dbt_ref_{}__", name)))
158 }
159 _ => Err(minijinja::Error::new(
160 ErrorKind::TooManyArguments,
161 "ref() takes 1 or 2 positional arguments",
162 )),
163 }
164 },
165 );
166
167 let ext = extraction.clone();
169 env.add_function(
170 "source",
171 move |args: &[Value]| -> Result<Value, minijinja::Error> {
172 if args.len() >= 2 {
173 let source_name = args[0].to_string();
174 let table_name = args[1].to_string();
175 ext.lock().unwrap().sources.push(SourceCall {
176 source_name: source_name.clone(),
177 table_name: table_name.clone(),
178 });
179 Ok(Value::from(format!(
180 "__dbt_source_{}_{}__",
181 source_name, table_name
182 )))
183 } else {
184 Err(minijinja::Error::new(
185 ErrorKind::MissingArgument,
186 "source() requires 2 arguments",
187 ))
188 }
189 },
190 );
191
192 let ext = extraction.clone();
195 env.add_function(
196 "config",
197 move |kwargs: Kwargs| -> Result<Value, minijinja::Error> {
198 let mut ext = ext.lock().unwrap();
199 if let Ok(mat) = kwargs.get::<&str>("materialized") {
200 ext.config.materialized = Some(mat.to_string());
201 }
202 if let Ok(tags_val) = kwargs.get::<Value>("tags")
203 && let Ok(iter) = tags_val.try_iter()
204 {
205 ext.config.tags = iter.map(|v| v.to_string()).collect();
206 }
207 Ok(Value::from(""))
208 },
209 );
210
211 env.add_function(
213 "is_incremental",
214 move || -> Result<Value, minijinja::Error> { Ok(Value::from(is_incremental)) },
215 );
216
217 env.add_global("this", Value::from("__dbt_this__"));
219
220 let vars_map: HashMap<String, Value> = vars
222 .iter()
223 .map(|(k, v)| (k.clone(), json_to_minijinja(v)))
224 .collect();
225 env.add_function(
226 "var",
227 move |args: &[Value]| -> Result<Value, minijinja::Error> {
228 if let Some(key) = args.first()
229 && let Some(key_str) = key.as_str()
230 && let Some(val) = vars_map.get(key_str)
231 {
232 return Ok(val.clone());
233 }
234 if args.len() >= 2 {
236 Ok(args[1].clone())
237 } else {
238 Ok(Value::from("__dbt_var_unknown__"))
239 }
240 },
241 );
242
243 env.add_function(
245 "env_var",
246 |args: &[Value]| -> Result<Value, minijinja::Error> {
247 if args.len() >= 2 {
248 Ok(args[1].clone())
249 } else {
250 Ok(Value::from(""))
251 }
252 },
253 );
254
255 env.add_function(
257 "return",
258 |args: &[Value]| -> Result<Value, minijinja::Error> {
259 Ok(args.first().cloned().unwrap_or(Value::from("")))
260 },
261 );
262
263 env.add_function(
265 "log",
266 |_args: &[Value]| -> Result<Value, minijinja::Error> { Ok(Value::from("")) },
267 );
268
269 env.add_function(
271 "run_query",
272 |_args: &[Value]| -> Result<Value, minijinja::Error> { Ok(Value::from("")) },
273 );
274
275 env.add_function(
277 "statement",
278 |_args: &[Value]| -> Result<Value, minijinja::Error> { Ok(Value::from("")) },
279 );
280
281 env.add_global("adapter", Value::from("__dbt_adapter__"));
283 env.add_global("exceptions", Value::from("__dbt_exceptions__"));
284 env.add_global("api", Value::from("__dbt_api__"));
285 env.add_global("graph", Value::from("__dbt_graph__"));
286 env.add_global("target", Value::from("__dbt_target__"));
287 env.add_global("invocation_id", Value::from("__dbt_invocation_id__"));
288 env.add_global("run_started_at", Value::from("2025-01-01T00:00:00Z"));
289 env.add_global("flags", Value::from("__dbt_flags__"));
290 env.add_global("modules", Value::from("__dbt_modules__"));
291 env.add_global("dbt_version", Value::from("1.0.0"));
292 env.add_global("model", Value::from("__dbt_model__"));
293 env.add_global("execute", Value::from(true));
294
295 let render_result = env.render_str(sql, ());
296 drop(env);
297
298 match render_result {
299 Ok(_) => {
300 let result = Arc::try_unwrap(extraction)
301 .expect("single owner")
302 .into_inner()
303 .unwrap_or_else(|e| e.into_inner());
304 Some(result)
305 }
306 Err(_) => None,
307 }
308}
309
// Unit tests for Jinja-based extraction of refs, sources, config, vars and macros.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_simple_ref() {
        let sql = "SELECT * FROM {{ ref('stg_orders') }}";
        let ext = extract_via_jinja(sql, "").unwrap();
        assert_eq!(ext.refs.len(), 1);
        assert_eq!(ext.refs[0].name, "stg_orders");
        assert!(ext.refs[0].package.is_none());
    }

    #[test]
    fn test_two_arg_ref() {
        let sql = "SELECT * FROM {{ ref('other_pkg', 'stg_orders') }}";
        let ext = extract_via_jinja(sql, "").unwrap();
        assert_eq!(ext.refs.len(), 1);
        assert_eq!(ext.refs[0].package.as_deref(), Some("other_pkg"));
        assert_eq!(ext.refs[0].name, "stg_orders");
    }

    #[test]
    fn test_source() {
        let sql = "SELECT * FROM {{ source('raw', 'orders') }}";
        let ext = extract_via_jinja(sql, "").unwrap();
        assert_eq!(ext.sources.len(), 1);
        assert_eq!(ext.sources[0].source_name, "raw");
        assert_eq!(ext.sources[0].table_name, "orders");
    }

    #[test]
    fn test_config() {
        let sql = "{{ config(materialized='incremental', tags=['nightly', 'finance']) }}\nSELECT 1";
        let ext = extract_via_jinja(sql, "").unwrap();
        assert_eq!(ext.config.materialized.as_deref(), Some("incremental"));
        assert_eq!(ext.config.tags, vec!["nightly", "finance"]);
    }

    #[test]
    fn test_mixed() {
        let sql = r#"
        {{ config(materialized='table') }}
        SELECT
            o.*,
            c.name
        FROM {{ ref('stg_orders') }} o
        JOIN {{ source('raw', 'customers') }} c ON o.customer_id = c.id
        "#;
        let ext = extract_via_jinja(sql, "").unwrap();
        assert_eq!(ext.refs.len(), 1);
        assert_eq!(ext.sources.len(), 1);
        assert_eq!(ext.config.materialized.as_deref(), Some("table"));
    }

    #[test]
    fn test_ref_inside_set() {
        let sql = r#"
        {% set orders = ref('stg_orders') %}
        SELECT * FROM {{ orders }}
        "#;
        let ext = extract_via_jinja(sql, "").unwrap();
        assert_eq!(ext.refs.len(), 1);
        assert_eq!(ext.refs[0].name, "stg_orders");
    }

    #[test]
    fn test_is_incremental_both_branches() {
        let sql = r#"
        {% if is_incremental() %}
        SELECT * FROM {{ ref('stg_incremental_orders') }}
        WHERE updated_at > (SELECT max(updated_at) FROM {{ this }})
        {% else %}
        SELECT * FROM {{ ref('stg_full_orders') }}
        {% endif %}
        "#;
        let ext = extract_via_jinja(sql, "").unwrap();
        assert_eq!(ext.refs.len(), 2);
        assert!(ext.refs.iter().any(|r| r.name == "stg_full_orders"));
        assert!(ext.refs.iter().any(|r| r.name == "stg_incremental_orders"));
    }

    #[test]
    fn test_jinja_comment_ignored() {
        let sql = r#"
        {# This is a comment with {{ ref('should_be_ignored') }} #}
        SELECT * FROM {{ ref('actual_model') }}
        "#;
        let ext = extract_via_jinja(sql, "").unwrap();
        assert_eq!(ext.refs.len(), 1);
        assert_eq!(ext.refs[0].name, "actual_model");
    }

    #[test]
    fn test_whitespace_control() {
        let sql = "SELECT * FROM {{- ref('stg_orders') -}}";
        let ext = extract_via_jinja(sql, "").unwrap();
        assert_eq!(ext.refs.len(), 1);
        assert_eq!(ext.refs[0].name, "stg_orders");
    }

    #[test]
    fn test_var_with_default() {
        let sql = "SELECT * FROM {{ ref('model_' ~ var('suffix', 'default')) }}";
        let ext = extract_via_jinja(sql, "").unwrap();
        assert_eq!(ext.refs.len(), 1);
        assert_eq!(ext.refs[0].name, "model_default");
    }

    #[test]
    fn test_var_resolved_from_project_vars() {
        let sql = "SELECT * FROM {{ ref('model_' ~ var('suffix')) }}";
        let mut vars = HashMap::new();
        vars.insert(
            "suffix".to_string(),
            serde_json::Value::String("prod".to_string()),
        );
        let ext = extract_via_jinja_with_vars(sql, "", &vars).unwrap();
        assert_eq!(ext.refs.len(), 1);
        assert_eq!(ext.refs[0].name, "model_prod");
    }

    #[test]
    fn test_var_list_expansion_in_for_loop() {
        let sql = r#"
        {%- set categories = var("product_categories") -%}
        {%- for cat in categories -%}
        SELECT * FROM {{ ref('stg_' ~ cat ~ '_summary') }}
        {% if not loop.last %}UNION ALL{% endif %}
        {% endfor -%}
        "#;
        let mut vars = HashMap::new();
        vars.insert(
            "product_categories".to_string(),
            serde_json::json!(["electronics", "clothing"]),
        );
        let ext = extract_via_jinja_with_vars(sql, "", &vars).unwrap();
        assert_eq!(ext.refs.len(), 2);
        assert!(ext.refs.iter().any(|r| r.name == "stg_electronics_summary"));
        assert!(ext.refs.iter().any(|r| r.name == "stg_clothing_summary"));
    }

    #[test]
    fn test_var_project_overrides_default() {
        let sql = "SELECT * FROM {{ ref('model_' ~ var('env', 'dev')) }}";
        let mut vars = HashMap::new();
        vars.insert(
            "env".to_string(),
            serde_json::Value::String("staging".to_string()),
        );
        let ext = extract_via_jinja_with_vars(sql, "", &vars).unwrap();
        assert_eq!(ext.refs.len(), 1);
        assert_eq!(ext.refs[0].name, "model_staging");
    }

    #[test]
    fn test_var_unknown_falls_back_to_default() {
        let sql = "SELECT * FROM {{ ref('model_' ~ var('missing', 'fallback')) }}";
        let mut vars = HashMap::new();
        vars.insert(
            "other_var".to_string(),
            serde_json::Value::String("unused".to_string()),
        );
        let ext = extract_via_jinja_with_vars(sql, "", &vars).unwrap();
        assert_eq!(ext.refs.len(), 1);
        assert_eq!(ext.refs[0].name, "model_fallback");
    }

    #[test]
    fn test_for_loop_with_refs() {
        let sql = r#"
        {% for src in ['orders', 'customers'] %}
        SELECT * FROM {{ source('raw', src) }}
        {% if not loop.last %}UNION ALL{% endif %}
        {% endfor %}
        "#;
        let ext = extract_via_jinja(sql, "").unwrap();
        assert_eq!(ext.sources.len(), 2);
        assert_eq!(ext.sources[0].source_name, "raw");
        assert_eq!(ext.sources[0].table_name, "orders");
        assert_eq!(ext.sources[1].source_name, "raw");
        assert_eq!(ext.sources[1].table_name, "customers");
    }

    #[test]
    fn test_config_with_extra_kwargs() {
        let sql = "{{ config(materialized='incremental', schema='analytics', unique_key='id', tags=['nightly']) }}\nSELECT 1";
        let ext = extract_via_jinja(sql, "").unwrap();
        assert_eq!(ext.config.materialized.as_deref(), Some("incremental"));
        assert_eq!(ext.config.tags, vec!["nightly"]);
    }

    #[test]
    fn test_ref_with_version_kwarg() {
        let sql = "SELECT * FROM {{ ref('my_model', version=2) }}";
        let ext = extract_via_jinja(sql, "").unwrap();
        assert_eq!(ext.refs.len(), 1);
        assert_eq!(ext.refs[0].name, "my_model");
        assert_eq!(ext.refs[0].version.as_deref(), Some("2"));
        assert!(ext.refs[0].package.is_none());
    }

    #[test]
    fn test_ref_with_version_kwarg_and_package() {
        let sql = "SELECT * FROM {{ ref('mypkg', 'my_model', version=3) }}";
        let ext = extract_via_jinja(sql, "").unwrap();
        assert_eq!(ext.refs.len(), 1);
        assert_eq!(ext.refs[0].package.as_deref(), Some("mypkg"));
        assert_eq!(ext.refs[0].name, "my_model");
        assert_eq!(ext.refs[0].version.as_deref(), Some("3"));
    }

    #[test]
    fn test_ref_without_version_has_none() {
        let sql = "SELECT * FROM {{ ref('my_model') }}";
        let ext = extract_via_jinja(sql, "").unwrap();
        assert_eq!(ext.refs[0].version, None);
    }

    #[test]
    fn test_ref_with_v_shorthand_kwarg() {
        let sql = "SELECT * FROM {{ ref('my_model', v=2) }}";
        let ext = extract_via_jinja(sql, "").unwrap();
        assert_eq!(ext.refs.len(), 1);
        assert_eq!(ext.refs[0].name, "my_model");
        assert_eq!(ext.refs[0].version.as_deref(), Some("2"));
        assert!(ext.refs[0].package.is_none());
    }

    #[test]
    fn test_ref_with_v_shorthand_kwarg_and_package() {
        let sql = "SELECT * FROM {{ ref('mypkg', 'my_model', v=3) }}";
        let ext = extract_via_jinja(sql, "").unwrap();
        assert_eq!(ext.refs.len(), 1);
        assert_eq!(ext.refs[0].package.as_deref(), Some("mypkg"));
        assert_eq!(ext.refs[0].name, "my_model");
        assert_eq!(ext.refs[0].version.as_deref(), Some("3"));
    }

    #[test]
    fn test_ref_with_string_version_kwarg() {
        let sql = "SELECT * FROM {{ ref('my_model', version='alpha') }}";
        let ext = extract_via_jinja(sql, "").unwrap();
        assert_eq!(ext.refs[0].version.as_deref(), Some("alpha"));
    }

    #[test]
    fn test_ref_with_padded_integer_version_kwarg() {
        let sql = "SELECT * FROM {{ ref('my_model', version='02') }}";
        let ext = extract_via_jinja(sql, "").unwrap();
        assert_eq!(ext.refs[0].version.as_deref(), Some("2"));
    }

    #[test]
    fn test_ref_with_decimal_version_kwarg() {
        let sql = "SELECT * FROM {{ ref('my_model', version='2.0') }}";
        let ext = extract_via_jinja(sql, "").unwrap();
        assert_eq!(ext.refs[0].version.as_deref(), Some("2.0"));
    }

    #[test]
    fn test_returns_none_on_unsupported_template() {
        let sql = "{% materialization table, default %} SELECT 1 {% endmaterialization %}";
        let result = extract_via_jinja(sql, "");
        assert!(result.is_none());
    }

    #[test]
    fn test_macro_ref_extraction() {
        let macro_src = r#"
        {% macro my_cte() %}
        SELECT * FROM {{ ref('base_model') }}
        {% endmacro %}
        "#;
        let sql = "SELECT * FROM ({{ my_cte() }})";
        let ext = extract_via_jinja(sql, macro_src).unwrap();
        assert_eq!(ext.refs.len(), 1);
        assert_eq!(ext.refs[0].name, "base_model");
    }

    #[test]
    fn test_macro_source_extraction() {
        let macro_src = r#"
        {% macro raw_data(table) %}
        SELECT * FROM {{ source('raw', table) }}
        {% endmacro %}
        "#;
        let sql = "SELECT * FROM ({{ raw_data('orders') }})";
        let ext = extract_via_jinja(sql, macro_src).unwrap();
        assert_eq!(ext.sources.len(), 1);
        assert_eq!(ext.sources[0].source_name, "raw");
        assert_eq!(ext.sources[0].table_name, "orders");
    }

    #[test]
    fn test_macro_with_multiple_refs() {
        let macro_src = r#"
        {% macro join_tables(period) %}
        SELECT * FROM {{ ref('deals') }}
        LEFT JOIN {{ ref('providers') }} ON 1=1
        LEFT JOIN {{ source('raw', 'prices') }} ON 1=1
        {% endmacro %}
        "#;
        let sql = "{{ join_tables('day') }}";
        let ext = extract_via_jinja(sql, macro_src).unwrap();
        assert_eq!(ext.refs.len(), 2);
        assert!(ext.refs.iter().any(|r| r.name == "deals"));
        assert!(ext.refs.iter().any(|r| r.name == "providers"));
        assert_eq!(ext.sources.len(), 1);
        assert_eq!(ext.sources[0].table_name, "prices");
    }

    #[test]
    fn test_multiple_macro_files() {
        let sources = vec![
            r#"
        {% macro get_orders() %}
        SELECT * FROM {{ ref('stg_orders') }}
        {% endmacro %}
        "#
            .to_string(),
            r#"
        {% macro get_customers() %}
        SELECT * FROM {{ ref('stg_customers') }}
        {% endmacro %}
        "#
            .to_string(),
        ];
        let prefix = build_macro_prefix(&sources);
        let sql = "{{ get_orders() }} UNION ALL {{ get_customers() }}";
        let ext = extract_via_jinja(sql, &prefix).unwrap();
        assert_eq!(ext.refs.len(), 2);
        assert!(ext.refs.iter().any(|r| r.name == "stg_orders"));
        assert!(ext.refs.iter().any(|r| r.name == "stg_customers"));
    }

    #[test]
    fn test_build_macro_prefix_skips_invalid() {
        let sources = vec![
            "{% macro good() %}SELECT 1{% endmacro %}".to_string(),
            "{% materialization custom %} stuff {% endmaterialization %}".to_string(),
            "{% macro also_good() %}SELECT 2{% endmacro %}".to_string(),
            "{% raw %}unclosed raw content".to_string(),
        ];
        let prefix = build_macro_prefix(&sources);
        assert!(prefix.contains("{% macro good() %}"));
        assert!(prefix.contains("{% macro also_good() %}"));
        assert!(!prefix.contains("materialization"));
        assert!(!prefix.contains("{% raw %}"));
    }

    #[test]
    fn test_build_macro_prefix_includes_compatible_macros() {
        let env = Environment::new();

        let macro_a = "{% macro a() %}ok{% endmacro %}".to_string();
        let macro_b = "{% macro b() %}ok{% endmacro %}".to_string();
        assert!(env.template_from_str(&macro_a).is_ok());
        assert!(env.template_from_str(&macro_b).is_ok());

        let sources = vec![macro_a, macro_b];
        let prefix = build_macro_prefix(&sources);
        assert!(prefix.contains("{% macro a() %}"));
        assert!(prefix.contains("{% macro b() %}"));
    }

    #[test]
    fn test_invalid_macro_skipped_refs_still_extracted() {
        let sources = vec![
            "{% materialization custom %} stuff {% endmaterialization %}".to_string(),
        ];
        let prefix = build_macro_prefix(&sources);
        let sql = "SELECT * FROM {{ ref('orders') }}";
        let ext = extract_via_jinja(sql, &prefix).unwrap();
        assert_eq!(ext.refs.len(), 1);
        assert_eq!(ext.refs[0].name, "orders");
    }
}