use minijinja::{Environment, Value};
use std::collections::HashMap;
use std::sync::Arc;
/// Package names that are always registered as passthrough namespaces by
/// `register_dbt_builtins`, so templates can call e.g. `dbt_utils.star(...)`
/// without the package being installed. Extra names can be supplied through the
/// `dbt_packages` context key.
const DEFAULT_DBT_PACKAGES: &[&str] = &[
    "dbt_utils",
    "dbt_expectations",
    "dbt_date",
    "audit_helper",
    "codegen",
    "metrics",
    "elementary",
    "fivetran_utils",
];
/// Lightweight stand-in for a dbt relation, as produced by `ref()`, `source()`
/// and the `this` global.
///
/// Renders as `database.schema.identifier`, `schema.identifier`, or just
/// `identifier`, depending on which optional parts are present.
#[derive(Debug, Clone, PartialEq, Eq)]
struct RelationEmulator {
    /// Optional database name; only rendered when `schema` is also set.
    database: Option<String>,
    /// Optional schema name.
    schema: Option<String>,
    /// Object name (always present).
    identifier: String,
}

impl RelationEmulator {
    /// Relation consisting of the identifier only.
    fn new(identifier: impl Into<String>) -> Self {
        Self {
            database: None,
            schema: None,
            identifier: identifier.into(),
        }
    }

    /// Relation qualified by a schema: renders as `schema.identifier`.
    fn with_schema(schema: impl Into<String>, identifier: impl Into<String>) -> Self {
        Self {
            schema: Some(schema.into()),
            ..Self::new(identifier)
        }
    }

    /// Fully qualified relation: renders as `database.schema.identifier`.
    fn with_database(
        database: impl Into<String>,
        schema: impl Into<String>,
        identifier: impl Into<String>,
    ) -> Self {
        Self {
            database: Some(database.into()),
            ..Self::with_schema(schema, identifier)
        }
    }
}

impl std::fmt::Display for RelationEmulator {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // A database prefix is only emitted when a schema is present too;
        // a database without a schema renders as the bare identifier.
        if let Some(schema) = &self.schema {
            if let Some(database) = &self.database {
                write!(f, "{database}.")?;
            }
            write!(f, "{schema}.")?;
        }
        f.write_str(&self.identifier)
    }
}
impl minijinja::value::Object for RelationEmulator {
fn get_value(self: &Arc<Self>, key: &Value) -> Option<Value> {
match key.as_str()? {
"database" => Some(
self.database
.as_ref()
.map(|s| Value::from(s.as_str()))
.unwrap_or(Value::UNDEFINED),
),
"schema" => Some(
self.schema
.as_ref()
.map(|s| Value::from(s.as_str()))
.unwrap_or(Value::UNDEFINED),
),
"identifier" | "name" | "table" => Some(Value::from(self.identifier.as_str())),
"is_table" | "is_view" | "is_cte" => Some(Value::from(true)),
_ => None,
}
}
fn call_method(
self: &Arc<Self>,
_state: &minijinja::State,
method: &str,
_args: &[Value],
) -> Result<Value, minijinja::Error> {
match method {
"render" => Ok(Value::from(self.to_string())),
"quote" | "include" | "exclude" | "replace" => Ok(Value::from_object((**self).clone())),
_ => Ok(Value::from_object((**self).clone())),
}
}
fn render(self: &Arc<Self>, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
std::fmt::Display::fmt(self, f)
}
}
pub(crate) fn register_dbt_builtins(
env: &mut Environment,
context: &HashMap<String, serde_json::Value>,
) {
let vars = extract_vars(context);
env.add_function("ref", |args: &[Value]| -> Result<Value, minijinja::Error> {
match args.len() {
1 => {
let model = args[0].as_str().unwrap_or("model");
Ok(Value::from_object(RelationEmulator::new(model)))
}
2 => {
let project = args[0].as_str().unwrap_or("project");
let model = args[1].as_str().unwrap_or("model");
Ok(Value::from_object(RelationEmulator::with_schema(
project, model,
)))
}
_ => Err(minijinja::Error::new(
minijinja::ErrorKind::InvalidOperation,
"ref() expects 1 or 2 arguments",
)),
}
});
env.add_function(
"source",
|schema: Value, table: Value| -> Result<Value, minijinja::Error> {
let schema_str = schema.as_str().unwrap_or("schema");
let table_str = table.as_str().unwrap_or("table");
Ok(Value::from_object(RelationEmulator::with_schema(
schema_str, table_str,
)))
},
);
env.add_function("config", |_args: &[Value]| -> Value { Value::from("") });
let vars_clone = vars.clone();
env.add_function(
"var",
move |args: &[Value]| -> Result<Value, minijinja::Error> {
match args.len() {
1 => {
let name = args[0].as_str().unwrap_or("");
match vars_clone.get(name) {
Some(v) => Ok(v.clone()),
None => Ok(Value::from(name)), }
}
2 => {
let name = args[0].as_str().unwrap_or("");
let default = &args[1];
match vars_clone.get(name) {
Some(v) => Ok(v.clone()),
None => Ok(default.clone()),
}
}
_ => Err(minijinja::Error::new(
minijinja::ErrorKind::InvalidOperation,
"var() expects 1 or 2 arguments",
)),
}
},
);
env.add_function("is_incremental", || -> Value { Value::from(false) });
let this_value = extract_this_relation(context);
env.add_global("this", this_value);
env.add_global("execute", Value::from(false));
let env_vars = extract_env_vars(context);
env.add_function(
"env_var",
move |args: &[Value]| -> Result<Value, minijinja::Error> {
match args.len() {
1 => {
let name = args[0].as_str().unwrap_or("");
match env_vars.get(name) {
Some(v) => Ok(v.clone()),
None => Ok(Value::from(format!("__ENV_VAR_{name}__"))),
}
}
2 => {
let name = args[0].as_str().unwrap_or("");
let default = &args[1];
match env_vars.get(name) {
Some(v) => Ok(v.clone()),
None => Ok(default.clone()),
}
}
_ => Err(minijinja::Error::new(
minijinja::ErrorKind::InvalidOperation,
"env_var() expects 1 or 2 arguments",
)),
}
},
);
env.add_function("run_query", |_sql: Value| -> Value {
Value::from(Vec::<Value>::new())
});
env.add_function("zip", |args: &[Value]| -> Result<Value, minijinja::Error> {
zip_impl(args, false)
});
env.add_function(
"zip_strict",
|args: &[Value]| -> Result<Value, minijinja::Error> { zip_impl(args, true) },
);
for package in DEFAULT_DBT_PACKAGES {
env.add_global(
*package,
Value::from_object(PassthroughNamespace::new(package)),
);
}
if let Some(serde_json::Value::Array(packages)) = context.get("dbt_packages") {
for pkg in packages {
if let Some(name) = pkg.as_str() {
if !DEFAULT_DBT_PACKAGES.contains(&name) {
env.add_global(
name.to_string(),
Value::from_object(PassthroughNamespace::new(name)),
);
}
}
}
}
}
/// Placeholder global standing in for a dbt package (e.g. `dbt_utils`).
///
/// Macro calls on it are answered by its `minijinja::value::Object`
/// implementation. `name` is the package name it was registered under.
#[derive(Debug)]
struct PassthroughNamespace {
    name: String,
}

impl PassthroughNamespace {
    /// Creates a namespace for the package called `name`.
    fn new(name: &str) -> Self {
        let name = String::from(name);
        Self { name }
    }
}

impl std::fmt::Display for PassthroughNamespace {
    /// Renders as `__<name>_namespace__`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("__")?;
        f.write_str(&self.name)?;
        f.write_str("_namespace__")
    }
}
/// Converts a macro argument into the text a stubbed macro renders in its place.
///
/// - Strings are returned verbatim.
/// - Objects (e.g. relations), numbers and booleans are rendered with their
///   `Display` form.
/// - Undefined, `none` and every other kind of value yield `None`.
pub(super) fn passthrough_arg_to_string(value: &Value) -> Option<String> {
    if let Some(text) = value.as_str() {
        return Some(text.to_owned());
    }
    if value.is_undefined() || value.is_none() {
        return None;
    }
    let displayable = value.as_object().is_some()
        || value.is_number()
        || matches!(value.kind(), minijinja::value::ValueKind::Bool);
    if displayable {
        Some(value.to_string())
    } else {
        None
    }
}
/// Maximum number of sequences `zip()` / `zip_strict()` accept in one call.
const MAX_ZIP_ARGS: usize = 100;
/// Maximum number of items any single `zip()` / `zip_strict()` argument may hold.
const MAX_ZIP_SEQUENCE_LENGTH: usize = 10_000;
fn zip_impl(args: &[Value], strict: bool) -> Result<Value, minijinja::Error> {
if args.is_empty() {
return Ok(Value::from(Vec::<Value>::new()));
}
if args.len() > MAX_ZIP_ARGS {
return Err(minijinja::Error::new(
minijinja::ErrorKind::InvalidOperation,
format!(
"zip: too many sequences ({}, max: {})",
args.len(),
MAX_ZIP_ARGS
),
));
}
let sequences: Vec<Vec<Value>> = args
.iter()
.map(|v| {
let seq: Vec<Value> = v.try_iter()?.collect();
if seq.len() > MAX_ZIP_SEQUENCE_LENGTH {
return Err(minijinja::Error::new(
minijinja::ErrorKind::InvalidOperation,
format!(
"zip: sequence too long ({} elements, max: {})",
seq.len(),
MAX_ZIP_SEQUENCE_LENGTH
),
));
}
Ok(seq)
})
.collect::<Result<_, _>>()?;
if sequences.is_empty() {
return Ok(Value::from(Vec::<Value>::new()));
}
let result_len = if strict {
let first_len = sequences[0].len();
for (i, seq) in sequences.iter().enumerate().skip(1) {
if seq.len() != first_len {
return Err(minijinja::Error::new(
minijinja::ErrorKind::InvalidOperation,
format!(
"zip_strict: argument {} has length {} but argument 0 has length {}",
i,
seq.len(),
first_len
),
));
}
}
first_len
} else {
sequences.iter().map(|s| s.len()).min().unwrap_or(0)
};
let result: Vec<Value> = (0..result_len)
.map(|i| Value::from(sequences.iter().map(|s| s[i].clone()).collect::<Vec<_>>()))
.collect();
Ok(Value::from(result))
}
impl minijinja::value::Object for PassthroughNamespace {
fn call_method(
self: &std::sync::Arc<Self>,
_state: &minijinja::State,
method: &str,
args: &[Value],
) -> Result<Value, minijinja::Error> {
if let Some(first) = args.first() {
if let Some(rendered) = passthrough_arg_to_string(first) {
return Ok(Value::from(rendered));
}
}
Ok(Value::from(format!("__{}_{method}__", self.name)))
}
}
/// Converts the JSON object stored under `key` in `context` into template values.
///
/// Returns an empty map when `key` is absent or its value is not a JSON object.
fn extract_context_object(
    context: &HashMap<String, serde_json::Value>,
    key: &str,
) -> HashMap<String, Value> {
    match context.get(key) {
        Some(serde_json::Value::Object(entries)) => entries
            .iter()
            .map(|(name, raw)| (name.clone(), Value::from_serialize(raw)))
            .collect(),
        _ => HashMap::new(),
    }
}
/// Project variables for `var()`, read from the `vars` context object.
fn extract_vars(context: &HashMap<String, serde_json::Value>) -> HashMap<String, Value> {
    const VARS_KEY: &str = "vars";
    extract_context_object(context, VARS_KEY)
}
/// Environment variables for `env_var()`, read from the `env_vars` context object.
fn extract_env_vars(context: &HashMap<String, serde_json::Value>) -> HashMap<String, Value> {
    const ENV_VARS_KEY: &str = "env_vars";
    extract_context_object(context, ENV_VARS_KEY)
}
/// Builds the `this` global from the `model_name`, `schema` and `database`
/// context entries. Only string values are used.
///
/// - Without a `model_name`, the result is undefined.
/// - With both `database` and `schema`, the relation is `database.schema.model`.
/// - With only `schema`, it is `schema.model`.
/// - Otherwise it is just the model name; a database without a schema is ignored.
fn extract_this_relation(context: &HashMap<String, serde_json::Value>) -> Value {
    fn string_entry<'a>(
        context: &'a HashMap<String, serde_json::Value>,
        key: &str,
    ) -> Option<&'a str> {
        context.get(key).and_then(serde_json::Value::as_str)
    }

    match string_entry(context, "model_name") {
        None => Value::UNDEFINED,
        Some(model) => {
            let schema = string_entry(context, "schema");
            let database = string_entry(context, "database");
            let relation = match (database, schema) {
                (Some(db), Some(sch)) => RelationEmulator::with_database(db, sch, model),
                (None, Some(sch)) => RelationEmulator::with_schema(sch, model),
                _ => RelationEmulator::new(model),
            };
            Value::from_object(relation)
        }
    }
}
#[cfg(test)]
mod tests {
    use super::super::jinja::render_dbt;
    use std::collections::HashMap;

    type Context = HashMap<String, serde_json::Value>;

    /// Builds a render context from `(key, value)` pairs.
    fn context(entries: Vec<(&str, serde_json::Value)>) -> Context {
        entries
            .into_iter()
            .map(|(key, value)| (key.to_string(), value))
            .collect()
    }

    /// Renders `template` against `ctx`, panicking if rendering fails.
    fn render(template: &str, ctx: &Context) -> String {
        render_dbt(template, ctx).unwrap()
    }

    /// Renders `template` against an empty context.
    fn render_plain(template: &str) -> String {
        render(template, &Context::new())
    }

    #[test]
    fn ref_single_arg() {
        assert_eq!(
            render_plain("SELECT * FROM {{ ref('users') }}"),
            "SELECT * FROM users"
        );
    }

    #[test]
    fn ref_two_args() {
        assert_eq!(
            render_plain("SELECT * FROM {{ ref('analytics', 'users') }}"),
            "SELECT * FROM analytics.users"
        );
    }

    #[test]
    fn source_macro() {
        assert_eq!(
            render_plain("SELECT * FROM {{ source('raw', 'events') }}"),
            "SELECT * FROM raw.events"
        );
    }

    #[test]
    fn config_macro_returns_empty() {
        assert_eq!(
            render_plain("{{ config(materialized='table') }}SELECT * FROM users"),
            "SELECT * FROM users"
        );
    }

    #[test]
    fn var_with_default() {
        assert_eq!(
            render_plain("SELECT * FROM {{ var('schema', 'public') }}.users"),
            "SELECT * FROM public.users"
        );
    }

    #[test]
    fn var_from_context() {
        let ctx = context(vec![(
            "vars",
            serde_json::json!({ "schema": "analytics" }),
        )]);
        assert_eq!(
            render("SELECT * FROM {{ var('schema', 'public') }}.users", &ctx),
            "SELECT * FROM analytics.users"
        );
    }

    #[test]
    fn is_incremental_returns_false() {
        let template = r#"{% if is_incremental() %}WHERE updated_at > (SELECT MAX(updated_at) FROM {{ this }}){% endif %}SELECT * FROM users"#;
        assert_eq!(render_plain(template), "SELECT * FROM users");
    }

    #[test]
    fn complex_dbt_model() {
        let template = r#"{{ config(materialized='incremental') }}
SELECT
id,
name,
created_at
FROM {{ ref('stg_users') }}
{% if is_incremental() %}
WHERE created_at > (SELECT MAX(created_at) FROM {{ this }})
{% endif %}"#;
        let rendered = render_plain(template);
        assert!(rendered.contains("FROM stg_users"));
        assert!(!rendered.contains("is_incremental"));
    }

    #[test]
    fn custom_dbt_package_from_context() {
        let ctx = context(vec![(
            "dbt_packages",
            serde_json::json!(["my_custom_pkg"]),
        )]);
        assert_eq!(
            render(
                "SELECT {{ my_custom_pkg.generate_column('user_id') }} FROM users",
                &ctx
            ),
            "SELECT user_id FROM users"
        );
    }

    #[test]
    fn default_dbt_utils_package() {
        assert_eq!(
            render_plain("SELECT {{ dbt_utils.star('users') }} FROM users"),
            "SELECT users FROM users"
        );
    }

    #[test]
    fn dbt_utils_relation_argument_passthrough() {
        assert_eq!(
            render_plain("SELECT {{ dbt_utils.star(ref('orders')) }} FROM dual"),
            "SELECT orders FROM dual"
        );
    }

    #[test]
    fn stubbed_custom_macro_preserves_relation_argument() {
        assert_eq!(
            render_plain("SELECT {{ custom_macro(ref('users')) }}"),
            "SELECT users"
        );
    }

    #[test]
    fn ref_returns_relation_with_attribute_access() {
        assert_eq!(
            render_plain("SELECT * FROM {{ ref('users').identifier }}"),
            "SELECT * FROM users"
        );
    }

    #[test]
    fn source_returns_relation_with_schema_attribute() {
        let template = "SELECT '{{ source('raw', 'events').schema }}' as schema_name FROM {{ source('raw', 'events') }}";
        assert_eq!(
            render_plain(template),
            "SELECT 'raw' as schema_name FROM raw.events"
        );
    }

    #[test]
    fn ref_relation_include_method() {
        assert_eq!(
            render_plain("SELECT * FROM {{ ref('users').include() }}"),
            "SELECT * FROM users"
        );
    }

    #[test]
    fn ref_relation_quote_method() {
        assert_eq!(
            render_plain("SELECT * FROM {{ ref('users').quote(identifier=False).identifier }}"),
            "SELECT * FROM users"
        );
    }

    #[test]
    fn this_undefined_without_model_name() {
        assert_eq!(
            render_plain("SELECT '{{ this }}' as this_value FROM users"),
            "SELECT '' as this_value FROM users"
        );
    }

    #[test]
    fn this_with_model_name() {
        let ctx = context(vec![("model_name", serde_json::json!("orders"))]);
        assert_eq!(
            render("SELECT * FROM {{ this }}", &ctx),
            "SELECT * FROM orders"
        );
    }

    #[test]
    fn this_with_model_name_and_schema() {
        let ctx = context(vec![
            ("model_name", serde_json::json!("orders")),
            ("schema", serde_json::json!("analytics")),
        ]);
        assert_eq!(
            render("SELECT * FROM {{ this }}", &ctx),
            "SELECT * FROM analytics.orders"
        );
    }

    #[test]
    fn this_with_full_context() {
        let ctx = context(vec![
            ("model_name", serde_json::json!("orders")),
            ("schema", serde_json::json!("analytics")),
            ("database", serde_json::json!("warehouse")),
        ]);
        assert_eq!(
            render("SELECT * FROM {{ this }}", &ctx),
            "SELECT * FROM warehouse.analytics.orders"
        );
    }

    #[test]
    fn this_attribute_access() {
        let ctx = context(vec![
            ("model_name", serde_json::json!("orders")),
            ("schema", serde_json::json!("analytics")),
        ]);
        let template = "SELECT '{{ this.schema }}' as schema, '{{ this.identifier }}' as table_name FROM users";
        assert_eq!(
            render(template, &ctx),
            "SELECT 'analytics' as schema, 'orders' as table_name FROM users"
        );
    }

    #[test]
    fn execute_flag_is_false() {
        assert_eq!(
            render_plain("{% if execute %}RUN THIS{% else %}SKIP THIS{% endif %}"),
            "SKIP THIS"
        );
    }

    #[test]
    fn env_var_with_default() {
        assert_eq!(
            render_plain("SELECT '{{ env_var('DB_HOST', 'localhost') }}' as host FROM dual"),
            "SELECT 'localhost' as host FROM dual"
        );
    }

    #[test]
    fn env_var_from_context() {
        let ctx = context(vec![(
            "env_vars",
            serde_json::json!({ "DB_HOST": "prod-db.example.com" }),
        )]);
        assert_eq!(
            render(
                "SELECT '{{ env_var('DB_HOST', 'localhost') }}' as host FROM dual",
                &ctx
            ),
            "SELECT 'prod-db.example.com' as host FROM dual"
        );
    }

    #[test]
    fn env_var_without_default() {
        assert_eq!(
            render_plain("SELECT '{{ env_var('UNDEFINED_VAR') }}' as value FROM dual"),
            "SELECT '__ENV_VAR_UNDEFINED_VAR__' as value FROM dual"
        );
    }

    #[test]
    fn run_query_returns_empty_iterable() {
        let template = r#"{% for row in run_query("SELECT 1") %}{{ row }}{% endfor %}SELECT 1"#;
        assert_eq!(render_plain(template), "SELECT 1");
    }

    #[test]
    fn run_query_with_execute_check() {
        let template = r#"{% if execute %}{% set results = run_query("SELECT 1") %}{% endif %}SELECT * FROM users"#;
        assert_eq!(render_plain(template), "SELECT * FROM users");
    }

    #[test]
    fn zip_two_lists() {
        let template = r#"{% for a, b in zip(['x', 'y'], [1, 2]) %}{{ a }}{{ b }}{% if not loop.last %},{% endif %}{% endfor %}"#;
        assert_eq!(render_plain(template), "x1,y2");
    }

    #[test]
    fn zip_three_lists() {
        let template = r#"{% for a, b, c in zip(['x', 'y'], [1, 2], ['!', '?']) %}{{ a }}{{ b }}{{ c }}{% if not loop.last %},{% endif %}{% endfor %}"#;
        assert_eq!(render_plain(template), "x1!,y2?");
    }

    #[test]
    fn zip_unequal_lengths_truncates() {
        let template = r#"{% for a, b in zip(['x', 'y', 'z'], [1, 2]) %}{{ a }}{{ b }}{% if not loop.last %},{% endif %}{% endfor %}"#;
        assert_eq!(render_plain(template), "x1,y2");
    }

    #[test]
    fn zip_empty_list() {
        let template = r#"{% for a, b in zip([], [1, 2]) %}{{ a }}{{ b }}{% endfor %}DONE"#;
        assert_eq!(render_plain(template), "DONE");
    }

    #[test]
    fn zip_strict_equal_lengths() {
        let template = r#"{% for a, b in zip_strict(['x', 'y'], [1, 2]) %}{{ a }}{{ b }}{% if not loop.last %},{% endif %}{% endfor %}"#;
        assert_eq!(render_plain(template), "x1,y2");
    }

    #[test]
    fn zip_strict_unequal_lengths_errors() {
        let template =
            r#"{% for a, b in zip_strict(['x', 'y', 'z'], [1, 2]) %}{{ a }}{{ b }}{% endfor %}"#;
        let outcome = render_dbt(template, &Context::new());
        assert!(
            outcome.is_err(),
            "zip_strict with unequal lengths should error"
        );
        let err = outcome.unwrap_err();
        assert!(
            err.to_string().contains("zip_strict"),
            "Error should mention zip_strict: {}",
            err
        );
    }

    #[test]
    fn loop_first_variable() {
        let template = r#"{% for item in ['a', 'b', 'c'] %}{% if loop.first %}FIRST:{% endif %}{{ item }}{% if not loop.last %},{% endif %}{% endfor %}"#;
        assert_eq!(render_plain(template), "FIRST:a,b,c");
    }

    #[test]
    fn loop_index_variables() {
        let template = r#"{% for item in ['a', 'b'] %}{{ loop.index }}:{{ loop.index0 }}:{{ item }}{% if not loop.last %},{% endif %}{% endfor %}"#;
        assert_eq!(render_plain(template), "1:0:a,2:1:b");
    }

    #[test]
    fn nested_loops_with_conditionals() {
        let template = r#"{% for outer in ['X', 'Y'] %}{% for inner in [1, 2] %}{% if loop.first %}[{% endif %}{{ outer }}{{ inner }}{% if loop.last %}]{% endif %}{% endfor %}{% endfor %}"#;
        assert_eq!(render_plain(template), "[X1X2][Y1Y2]");
    }

    #[test]
    fn whitespace_control_tags() {
        let template = "SELECT\n {%- for col in ['a', 'b'] %}\n {{ col }}{%- if not loop.last %},{% endif %}\n {%- endfor %}\nFROM t";
        let rendered = render_plain(template);
        assert!(
            rendered.contains("a,"),
            "Should have 'a,' without extra whitespace: {}",
            rendered
        );
        assert!(
            !rendered.contains("\n\n\n"),
            "Should not have multiple blank lines: {}",
            rendered
        );
    }

    #[test]
    fn raw_block_preserves_syntax() {
        let template = r#"{% raw %}{{ this_is_literal }}{% endraw %} and {{ ref('real') }}"#;
        assert_eq!(render_plain(template), "{{ this_is_literal }} and real");
    }

    #[test]
    fn multi_variable_assignment_with() {
        let template = r#"{% with x = 'hello', y = 'world' %}{{ x }} {{ y }}{% endwith %}"#;
        assert_eq!(render_plain(template), "hello world");
    }

    #[test]
    fn inline_string_with_newlines() {
        let template = r#"{{ "line1\nline2" }}"#;
        assert_eq!(render_plain(template), "line1\nline2");
    }
}