1use std::collections::HashMap;
2use std::sync::{Arc, Mutex};
3
4use minijinja::value::Kwargs;
5use minijinja::{Environment, ErrorKind, Value};
6
7use super::sql::{RefCall, SourceCall, SqlConfig};
8
/// Everything collected while rendering one model's SQL through Jinja.
#[derive(Debug, Clone, Default, serde::Serialize, serde::Deserialize)]
pub struct JinjaExtraction {
    /// `ref(...)` calls recorded during rendering.
    pub refs: Vec<RefCall>,
    /// `source(...)` calls recorded during rendering.
    pub sources: Vec<SourceCall>,
    /// Settings captured from `config(...)` calls.
    pub config: SqlConfig,
}
16
17pub fn extract_via_jinja(sql: &str, macro_prefix: &str) -> Option<JinjaExtraction> {
26 extract_via_jinja_with_vars(sql, macro_prefix, &HashMap::new())
27}
28
29pub fn extract_via_jinja_with_vars(
32 sql: &str,
33 macro_prefix: &str,
34 vars: &HashMap<String, serde_json::Value>,
35) -> Option<JinjaExtraction> {
36 let template = if macro_prefix.is_empty() {
37 sql.to_string()
38 } else {
39 format!("{}\n{}", macro_prefix, sql)
40 };
41
42 let mut result = render_with_incremental(&template, false, vars)?;
44
45 if let Some(incr) = render_with_incremental(&template, true, vars) {
47 merge_extraction(&mut result, incr);
48 }
49
50 Some(result)
51}
52
53pub fn build_macro_prefix(macro_sources: &[String]) -> String {
57 if macro_sources.is_empty() {
58 return String::new();
59 }
60 let env = Environment::new();
61 let mut prefix = String::new();
62 for source in macro_sources {
63 if env.template_from_str(source).is_err() {
65 continue;
66 }
67 let len = prefix.len();
69 prefix.push_str(source);
70 prefix.push('\n');
71 if env.template_from_str(&prefix).is_err() {
72 prefix.truncate(len);
73 }
74 }
75 prefix
76}
77
78fn merge_extraction(base: &mut JinjaExtraction, other: JinjaExtraction) {
80 for r in other.refs {
81 if !base.refs.contains(&r) {
82 base.refs.push(r);
83 }
84 }
85 for s in other.sources {
86 if !base.sources.contains(&s) {
87 base.sources.push(s);
88 }
89 }
90 }
92
/// Converts a JSON value into a minijinja [`Value`] by passing it through
/// `Value::from_serialize`.
fn json_to_minijinja(v: &serde_json::Value) -> Value {
    Value::from_serialize(v)
}
97
98fn render_with_incremental(
100 sql: &str,
101 is_incremental: bool,
102 vars: &HashMap<String, serde_json::Value>,
103) -> Option<JinjaExtraction> {
104 let extraction = Arc::new(Mutex::new(JinjaExtraction::default()));
105
106 let mut env = Environment::new();
107 env.set_undefined_behavior(minijinja::UndefinedBehavior::Lenient);
108
109 let ext = extraction.clone();
111 env.add_function(
112 "ref",
113 move |args: &[Value]| -> Result<Value, minijinja::Error> {
114 let mut ext = ext.lock().unwrap();
115 match args.len() {
116 1 => {
117 let name = args[0].to_string();
118 ext.refs.push(RefCall {
119 package: None,
120 name: name.clone(),
121 });
122 Ok(Value::from(format!("__dbt_ref_{}__", name)))
123 }
124 2 => {
125 let pkg = args[0].to_string();
126 let name = args[1].to_string();
127 ext.refs.push(RefCall {
128 package: Some(pkg),
129 name: name.clone(),
130 });
131 Ok(Value::from(format!("__dbt_ref_{}__", name)))
132 }
133 _ => Err(minijinja::Error::new(
134 ErrorKind::TooManyArguments,
135 "ref() takes 1 or 2 arguments",
136 )),
137 }
138 },
139 );
140
141 let ext = extraction.clone();
143 env.add_function(
144 "source",
145 move |args: &[Value]| -> Result<Value, minijinja::Error> {
146 if args.len() >= 2 {
147 let source_name = args[0].to_string();
148 let table_name = args[1].to_string();
149 ext.lock().unwrap().sources.push(SourceCall {
150 source_name: source_name.clone(),
151 table_name: table_name.clone(),
152 });
153 Ok(Value::from(format!(
154 "__dbt_source_{}_{}__",
155 source_name, table_name
156 )))
157 } else {
158 Err(minijinja::Error::new(
159 ErrorKind::MissingArgument,
160 "source() requires 2 arguments",
161 ))
162 }
163 },
164 );
165
166 let ext = extraction.clone();
169 env.add_function(
170 "config",
171 move |kwargs: Kwargs| -> Result<Value, minijinja::Error> {
172 let mut ext = ext.lock().unwrap();
173 if let Ok(mat) = kwargs.get::<&str>("materialized") {
174 ext.config.materialized = Some(mat.to_string());
175 }
176 if let Ok(tags_val) = kwargs.get::<Value>("tags")
177 && let Ok(iter) = tags_val.try_iter()
178 {
179 ext.config.tags = iter.map(|v| v.to_string()).collect();
180 }
181 Ok(Value::from(""))
182 },
183 );
184
185 env.add_function(
187 "is_incremental",
188 move || -> Result<Value, minijinja::Error> { Ok(Value::from(is_incremental)) },
189 );
190
191 env.add_global("this", Value::from("__dbt_this__"));
193
194 let vars_map: HashMap<String, Value> = vars
196 .iter()
197 .map(|(k, v)| (k.clone(), json_to_minijinja(v)))
198 .collect();
199 env.add_function(
200 "var",
201 move |args: &[Value]| -> Result<Value, minijinja::Error> {
202 if let Some(key) = args.first()
203 && let Some(key_str) = key.as_str()
204 && let Some(val) = vars_map.get(key_str)
205 {
206 return Ok(val.clone());
207 }
208 if args.len() >= 2 {
210 Ok(args[1].clone())
211 } else {
212 Ok(Value::from("__dbt_var_unknown__"))
213 }
214 },
215 );
216
217 env.add_function(
219 "env_var",
220 |args: &[Value]| -> Result<Value, minijinja::Error> {
221 if args.len() >= 2 {
222 Ok(args[1].clone())
223 } else {
224 Ok(Value::from(""))
225 }
226 },
227 );
228
229 env.add_function(
231 "return",
232 |args: &[Value]| -> Result<Value, minijinja::Error> {
233 Ok(args.first().cloned().unwrap_or(Value::from("")))
234 },
235 );
236
237 env.add_function(
239 "log",
240 |_args: &[Value]| -> Result<Value, minijinja::Error> { Ok(Value::from("")) },
241 );
242
243 env.add_function(
245 "run_query",
246 |_args: &[Value]| -> Result<Value, minijinja::Error> { Ok(Value::from("")) },
247 );
248
249 env.add_function(
251 "statement",
252 |_args: &[Value]| -> Result<Value, minijinja::Error> { Ok(Value::from("")) },
253 );
254
255 env.add_global("adapter", Value::from("__dbt_adapter__"));
257 env.add_global("exceptions", Value::from("__dbt_exceptions__"));
258 env.add_global("api", Value::from("__dbt_api__"));
259 env.add_global("graph", Value::from("__dbt_graph__"));
260 env.add_global("target", Value::from("__dbt_target__"));
261 env.add_global("invocation_id", Value::from("__dbt_invocation_id__"));
262 env.add_global("run_started_at", Value::from("2025-01-01T00:00:00Z"));
263 env.add_global("flags", Value::from("__dbt_flags__"));
264 env.add_global("modules", Value::from("__dbt_modules__"));
265 env.add_global("dbt_version", Value::from("1.0.0"));
266 env.add_global("model", Value::from("__dbt_model__"));
267 env.add_global("execute", Value::from(true));
268
269 let render_result = env.render_str(sql, ());
270 drop(env);
271
272 match render_result {
273 Ok(_) => {
274 let result = Arc::try_unwrap(extraction)
275 .expect("single owner")
276 .into_inner()
277 .unwrap_or_else(|e| e.into_inner());
278 Some(result)
279 }
280 Err(_) => None,
281 }
282}
283
#[cfg(test)]
mod tests {
    use super::*;

    /// A single-argument `ref()` is recorded without a package.
    #[test]
    fn test_simple_ref() {
        let sql = "SELECT * FROM {{ ref('stg_orders') }}";
        let ext = extract_via_jinja(sql, "").unwrap();
        assert_eq!(ext.refs.len(), 1);
        assert_eq!(ext.refs[0].name, "stg_orders");
        assert!(ext.refs[0].package.is_none());
    }

    /// A two-argument `ref()` records the package and the model name.
    #[test]
    fn test_two_arg_ref() {
        let sql = "SELECT * FROM {{ ref('other_pkg', 'stg_orders') }}";
        let ext = extract_via_jinja(sql, "").unwrap();
        assert_eq!(ext.refs.len(), 1);
        assert_eq!(ext.refs[0].package.as_deref(), Some("other_pkg"));
        assert_eq!(ext.refs[0].name, "stg_orders");
    }

    /// `source()` records the source name and the table name.
    #[test]
    fn test_source() {
        let sql = "SELECT * FROM {{ source('raw', 'orders') }}";
        let ext = extract_via_jinja(sql, "").unwrap();
        assert_eq!(ext.sources.len(), 1);
        assert_eq!(ext.sources[0].source_name, "raw");
        assert_eq!(ext.sources[0].table_name, "orders");
    }

    /// `config()` captures `materialized` and `tags`.
    #[test]
    fn test_config() {
        let sql = "{{ config(materialized='incremental', tags=['nightly', 'finance']) }}\nSELECT 1";
        let ext = extract_via_jinja(sql, "").unwrap();
        assert_eq!(ext.config.materialized.as_deref(), Some("incremental"));
        assert_eq!(ext.config.tags, vec!["nightly", "finance"]);
    }

    /// Config, ref and source in one template are all captured.
    #[test]
    fn test_mixed() {
        let sql = r#"
            {{ config(materialized='table') }}
            SELECT
                o.*,
                c.name
            FROM {{ ref('stg_orders') }} o
            JOIN {{ source('raw', 'customers') }} c ON o.customer_id = c.id
        "#;
        let ext = extract_via_jinja(sql, "").unwrap();
        assert_eq!(ext.refs.len(), 1);
        assert_eq!(ext.sources.len(), 1);
        assert_eq!(ext.config.materialized.as_deref(), Some("table"));
    }

    /// A `ref()` evaluated inside `{% set %}` is still recorded.
    #[test]
    fn test_ref_inside_set() {
        let sql = r#"
            {% set orders = ref('stg_orders') %}
            SELECT * FROM {{ orders }}
        "#;
        let ext = extract_via_jinja(sql, "").unwrap();
        assert_eq!(ext.refs.len(), 1);
        assert_eq!(ext.refs[0].name, "stg_orders");
    }

    /// Refs from both sides of an `is_incremental()` branch are collected.
    #[test]
    fn test_is_incremental_both_branches() {
        let sql = r#"
            {% if is_incremental() %}
                SELECT * FROM {{ ref('stg_incremental_orders') }}
                WHERE updated_at > (SELECT max(updated_at) FROM {{ this }})
            {% else %}
                SELECT * FROM {{ ref('stg_full_orders') }}
            {% endif %}
        "#;
        let ext = extract_via_jinja(sql, "").unwrap();
        assert_eq!(ext.refs.len(), 2);
        assert!(ext.refs.iter().any(|r| r.name == "stg_full_orders"));
        assert!(ext.refs.iter().any(|r| r.name == "stg_incremental_orders"));
    }

    /// A `ref()` written inside a Jinja comment is not recorded.
    #[test]
    fn test_jinja_comment_ignored() {
        let sql = r#"
            {# This is a comment with {{ ref('should_be_ignored') }} #}
            SELECT * FROM {{ ref('actual_model') }}
        "#;
        let ext = extract_via_jinja(sql, "").unwrap();
        assert_eq!(ext.refs.len(), 1);
        assert_eq!(ext.refs[0].name, "actual_model");
    }

    /// Whitespace-control markers do not interfere with extraction.
    #[test]
    fn test_whitespace_control() {
        let sql = "SELECT * FROM {{- ref('stg_orders') -}}";
        let ext = extract_via_jinja(sql, "").unwrap();
        assert_eq!(ext.refs.len(), 1);
        assert_eq!(ext.refs[0].name, "stg_orders");
    }

    /// With no project vars, `var()` yields its default argument.
    #[test]
    fn test_var_with_default() {
        let sql = "SELECT * FROM {{ ref('model_' ~ var('suffix', 'default')) }}";
        let ext = extract_via_jinja(sql, "").unwrap();
        assert_eq!(ext.refs.len(), 1);
        assert_eq!(ext.refs[0].name, "model_default");
    }

    /// `var()` resolves against the supplied project vars.
    #[test]
    fn test_var_resolved_from_project_vars() {
        let sql = "SELECT * FROM {{ ref('model_' ~ var('suffix')) }}";
        let mut vars = HashMap::new();
        vars.insert(
            "suffix".to_string(),
            serde_json::Value::String("prod".to_string()),
        );
        let ext = extract_via_jinja_with_vars(sql, "", &vars).unwrap();
        assert_eq!(ext.refs.len(), 1);
        assert_eq!(ext.refs[0].name, "model_prod");
    }

    /// A list-valued var can drive a `for` loop that emits one ref per item.
    #[test]
    fn test_var_list_expansion_in_for_loop() {
        let sql = r#"
            {%- set categories = var("product_categories") -%}
            {%- for cat in categories -%}
                SELECT * FROM {{ ref('stg_' ~ cat ~ '_summary') }}
                {% if not loop.last %}UNION ALL{% endif %}
            {% endfor -%}
        "#;
        let mut vars = HashMap::new();
        vars.insert(
            "product_categories".to_string(),
            serde_json::json!(["electronics", "clothing"]),
        );
        let ext = extract_via_jinja_with_vars(sql, "", &vars).unwrap();
        assert_eq!(ext.refs.len(), 2);
        assert!(ext.refs.iter().any(|r| r.name == "stg_electronics_summary"));
        assert!(ext.refs.iter().any(|r| r.name == "stg_clothing_summary"));
    }

    /// A project var takes precedence over the default passed to `var()`.
    #[test]
    fn test_var_project_overrides_default() {
        let sql = "SELECT * FROM {{ ref('model_' ~ var('env', 'dev')) }}";
        let mut vars = HashMap::new();
        vars.insert(
            "env".to_string(),
            serde_json::Value::String("staging".to_string()),
        );
        let ext = extract_via_jinja_with_vars(sql, "", &vars).unwrap();
        assert_eq!(ext.refs.len(), 1);
        assert_eq!(ext.refs[0].name, "model_staging");
    }

    /// A var missing from the project vars falls back to the default.
    #[test]
    fn test_var_unknown_falls_back_to_default() {
        let sql = "SELECT * FROM {{ ref('model_' ~ var('missing', 'fallback')) }}";
        let mut vars = HashMap::new();
        vars.insert(
            "other_var".to_string(),
            serde_json::Value::String("unused".to_string()),
        );
        let ext = extract_via_jinja_with_vars(sql, "", &vars).unwrap();
        assert_eq!(ext.refs.len(), 1);
        assert_eq!(ext.refs[0].name, "model_fallback");
    }

    /// `source()` calls inside a `for` loop are recorded in iteration order.
    #[test]
    fn test_for_loop_with_refs() {
        let sql = r#"
            {% for src in ['orders', 'customers'] %}
                SELECT * FROM {{ source('raw', src) }}
                {% if not loop.last %}UNION ALL{% endif %}
            {% endfor %}
        "#;
        let ext = extract_via_jinja(sql, "").unwrap();
        assert_eq!(ext.sources.len(), 2);
        assert_eq!(ext.sources[0].source_name, "raw");
        assert_eq!(ext.sources[0].table_name, "orders");
        assert_eq!(ext.sources[1].source_name, "raw");
        assert_eq!(ext.sources[1].table_name, "customers");
    }

    /// Keywords other than `materialized` and `tags` are accepted and ignored.
    #[test]
    fn test_config_with_extra_kwargs() {
        let sql = "{{ config(materialized='incremental', schema='analytics', unique_key='id', tags=['nightly']) }}\nSELECT 1";
        let ext = extract_via_jinja(sql, "").unwrap();
        assert_eq!(ext.config.materialized.as_deref(), Some("incremental"));
        assert_eq!(ext.config.tags, vec!["nightly"]);
    }

    /// A template using an unsupported block tag yields `None`.
    #[test]
    fn test_returns_none_on_unsupported_template() {
        let sql = "{% materialization table, default %} SELECT 1 {% endmaterialization %}";
        let result = extract_via_jinja(sql, "");
        assert!(result.is_none());
    }

    /// A `ref()` inside a macro is recorded when the model calls the macro.
    #[test]
    fn test_macro_ref_extraction() {
        let macro_src = r#"
            {% macro my_cte() %}
                SELECT * FROM {{ ref('base_model') }}
            {% endmacro %}
        "#;
        let sql = "SELECT * FROM ({{ my_cte() }})";
        let ext = extract_via_jinja(sql, macro_src).unwrap();
        assert_eq!(ext.refs.len(), 1);
        assert_eq!(ext.refs[0].name, "base_model");
    }

    /// A `source()` inside a macro sees the argument passed by the caller.
    #[test]
    fn test_macro_source_extraction() {
        let macro_src = r#"
            {% macro raw_data(table) %}
                SELECT * FROM {{ source('raw', table) }}
            {% endmacro %}
        "#;
        let sql = "SELECT * FROM ({{ raw_data('orders') }})";
        let ext = extract_via_jinja(sql, macro_src).unwrap();
        assert_eq!(ext.sources.len(), 1);
        assert_eq!(ext.sources[0].source_name, "raw");
        assert_eq!(ext.sources[0].table_name, "orders");
    }

    /// Every ref and source used by one macro is recorded.
    #[test]
    fn test_macro_with_multiple_refs() {
        let macro_src = r#"
            {% macro join_tables(period) %}
                SELECT * FROM {{ ref('deals') }}
                LEFT JOIN {{ ref('providers') }} ON 1=1
                LEFT JOIN {{ source('raw', 'prices') }} ON 1=1
            {% endmacro %}
        "#;
        let sql = "{{ join_tables('day') }}";
        let ext = extract_via_jinja(sql, macro_src).unwrap();
        assert_eq!(ext.refs.len(), 2);
        assert!(ext.refs.iter().any(|r| r.name == "deals"));
        assert!(ext.refs.iter().any(|r| r.name == "providers"));
        assert_eq!(ext.sources.len(), 1);
        assert_eq!(ext.sources[0].table_name, "prices");
    }

    /// Macros from several source files are all callable through the prefix.
    #[test]
    fn test_multiple_macro_files() {
        let sources = vec![
            r#"
            {% macro get_orders() %}
                SELECT * FROM {{ ref('stg_orders') }}
            {% endmacro %}
            "#
            .to_string(),
            r#"
            {% macro get_customers() %}
                SELECT * FROM {{ ref('stg_customers') }}
            {% endmacro %}
            "#
            .to_string(),
        ];
        let prefix = build_macro_prefix(&sources);
        let sql = "{{ get_orders() }} UNION ALL {{ get_customers() }}";
        let ext = extract_via_jinja(sql, &prefix).unwrap();
        assert_eq!(ext.refs.len(), 2);
        assert!(ext.refs.iter().any(|r| r.name == "stg_orders"));
        assert!(ext.refs.iter().any(|r| r.name == "stg_customers"));
    }

    /// Sources that fail to parse are left out; valid ones are kept.
    #[test]
    fn test_build_macro_prefix_skips_invalid() {
        let sources = vec![
            "{% macro good() %}SELECT 1{% endmacro %}".to_string(),
            "{% materialization custom %} stuff {% endmaterialization %}".to_string(),
            "{% macro also_good() %}SELECT 2{% endmacro %}".to_string(),
            "{% raw %}unclosed raw content".to_string(),
        ];
        let prefix = build_macro_prefix(&sources);
        assert!(prefix.contains("{% macro good() %}"));
        assert!(prefix.contains("{% macro also_good() %}"));
        assert!(!prefix.contains("materialization"));
        assert!(!prefix.contains("{% raw %}"));
    }

    /// Macros that each parse on their own are both included in the prefix.
    #[test]
    fn test_build_macro_prefix_includes_compatible_macros() {
        let env = Environment::new();

        let macro_a = "{% macro a() %}ok{% endmacro %}".to_string();
        let macro_b = "{% macro b() %}ok{% endmacro %}".to_string();
        assert!(env.template_from_str(&macro_a).is_ok());
        assert!(env.template_from_str(&macro_b).is_ok());

        let sources = vec![macro_a, macro_b];
        let prefix = build_macro_prefix(&sources);
        assert!(prefix.contains("{% macro a() %}"));
        assert!(prefix.contains("{% macro b() %}"));
    }

    /// When the only macro source is invalid, the model's refs are still found.
    #[test]
    fn test_invalid_macro_skipped_refs_still_extracted() {
        let sources = vec![
            "{% materialization custom %} stuff {% endmaterialization %}".to_string(),
        ];
        let prefix = build_macro_prefix(&sources);
        let sql = "SELECT * FROM {{ ref('orders') }}";
        let ext = extract_via_jinja(sql, &prefix).unwrap();
        assert_eq!(ext.refs.len(), 1);
        assert_eq!(ext.refs[0].name, "orders");
    }
}