use minijinja::Environment;
use serde_json::Value;
use crate::error::Result;
use crate::interceptor::{Interceptor, InterceptorFactory, InterceptorRef};
use crate::SqlnessError;
/// Factory that builds a [`TemplateInterceptor`] from a JSON context string.
///
/// See the `InterceptorFactory` implementation in this file for how the
/// context is parsed.
pub struct TemplateInterceptorFactory;
/// Prefix reported in `InvalidContext` errors produced by this interceptor's factory.
// NOTE(review): presumably also the directive name that selects this
// interceptor in a test file — confirm against the interceptor registry.
pub const PREFIX: &str = "TEMPLATE";
/// Marker string returned by the `sql_delimiter()` template function.
// NOTE(review): presumably recognised downstream as a statement separator —
// confirm against the code that consumes the rendered query.
pub const DELIMITER: &str = "__sqlness_delimiter__";
/// Interceptor that renders the query text as a minijinja template before it
/// is executed.
#[derive(Debug)]
pub struct TemplateInterceptor {
/// JSON value handed to the template as its rendering context.
data_bindings: Value,
}
/// Template function registered under the name `sql_delimiter`.
///
/// Takes no arguments and never fails: it always yields an owned copy of
/// [`DELIMITER`].
fn sql_delimiter() -> std::result::Result<String, minijinja::Error> {
    let marker = String::from(DELIMITER);
    Ok(marker)
}
impl Interceptor for TemplateInterceptor {
    /// Renders the query as a minijinja template and replaces it in place.
    ///
    /// The lines of `execute_query` are joined with `\n`, rendered with
    /// `data_bindings` as the template context (the `sql_delimiter` function
    /// is available to the template), and the output is split on `\n` back
    /// into `execute_query`. The query context is not used.
    ///
    /// # Panics
    ///
    /// Panics if the joined query is not a valid template or if rendering
    /// fails; `execute_query` is left untouched in that case.
    fn before_execute(&self, execute_query: &mut Vec<String>, _context: &mut crate::QueryContext) {
        let source = execute_query.join("\n");

        // A fresh environment per call: the template borrows `source`, so it
        // cannot outlive this function anyway.
        let mut engine = Environment::new();
        engine.add_function("sql_delimiter", sql_delimiter);
        engine.add_template("sql", &source).unwrap();

        let output = engine
            .get_template("sql")
            .unwrap()
            .render(&self.data_bindings)
            .unwrap();

        // Only overwrite the caller's query once rendering has succeeded.
        execute_query.clear();
        execute_query.extend(output.split('\n').map(String::from));
    }
}
impl InterceptorFactory for TemplateInterceptorFactory {
fn try_new(&self, ctx: &str) -> Result<InterceptorRef> {
let data_bindings = if ctx.is_empty() {
serde_json::from_str("{}")
} else {
serde_json::from_str(ctx)
}
.map_err(|e| SqlnessError::InvalidContext {
prefix: PREFIX.to_string(),
msg: format!("Expect json, err:{e}"),
})?;
Ok(Box::new(TemplateInterceptor { data_bindings }))
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Converts borrowed lines into the owned form the interceptor works on.
    fn to_lines(lines: &[&str]) -> Vec<String> {
        lines.iter().map(|line| line.to_string()).collect()
    }

    /// Builds an interceptor from `ctx`, runs it over `query`, and returns
    /// the rewritten query lines.
    fn render(ctx: &str, query: &[&str]) -> Vec<String> {
        let interceptor = TemplateInterceptorFactory.try_new(ctx).unwrap();
        let mut lines = to_lines(query);
        interceptor.before_execute(&mut lines, &mut crate::QueryContext::default());
        lines
    }

    /// A scalar binding is substituted into the query.
    #[test]
    fn basic_template() {
        let rendered = render(
            r#"{"name": "test"}"#,
            &["SELECT * FROM table where name = '{{name}}'"],
        );

        assert_eq!(
            rendered,
            to_lines(&["SELECT * FROM table where name = 'test'"])
        );
    }

    /// A `for` loop over an array binding emits one statement per element.
    #[test]
    fn vector_template() {
        let rendered = render(
            r#"{"aggr": ["sum", "count", "avg"]}"#,
            &[
                "{%- for item in aggr %}",
                "SELECT {{item}}(c) from t;",
                "{%- endfor %}",
            ],
        );

        // The rendered output starts with an empty line.
        let expected = to_lines(&[
            "",
            "SELECT sum(c) from t;",
            "SELECT count(c) from t;",
            "SELECT avg(c) from t;",
        ]);
        assert_eq!(rendered, expected);
    }

    /// An empty context still allows templates that need no bindings.
    #[test]
    fn range_template() {
        let rendered = render(
            "",
            &[
                "INSERT INTO t (c) VALUES",
                "{%- for num in range(1, 5) %}",
                "({{ num }}){%if not loop.last %}, {% endif %}",
                "{%- endfor %}",
                ";",
            ],
        );

        let expected = to_lines(&[
            "INSERT INTO t (c) VALUES",
            "(1), ",
            "(2), ",
            "(3), ",
            "(4)",
            ";",
        ]);
        assert_eq!(rendered, expected);
    }
}