use crate::error::{Result, TemplateError};
use crate::renderer::TemplateRenderer;
use serde_json::Value;
use std::collections::HashMap;
use tera::{Filter, Function, Tera};
/// A named wrapper around a closure so the closure can be used as a Tera
/// [`Function`] (see the `Function` impl for this type in this module).
pub struct CustomFunction<F> {
/// Name given at construction; returned by `name()`.
name: String,
/// The wrapped closure, invoked with the template call's arguments.
func: F,
}
impl<F> CustomFunction<F>
where
    F: Fn(&HashMap<String, Value>) -> Result<Value> + Send + Sync + 'static,
{
    /// Wraps `func` under `name`. The name is copied into an owned `String`.
    pub fn new(name: &str, func: F) -> Self {
        let name = String::from(name);
        Self { name, func }
    }

    /// Returns the name supplied to [`CustomFunction::new`].
    pub fn name(&self) -> &str {
        self.name.as_str()
    }
}
impl<F> Function for CustomFunction<F>
where
F: Fn(&HashMap<String, Value>) -> Result<Value> + Send + Sync + 'static,
{
fn call(&self, args: &HashMap<String, Value>) -> tera::Result<Value> {
(self.func)(args).map_err(|e| tera::Error::msg(e.to_string()))
}
}
/// A named wrapper around a closure so the closure can be used as a Tera
/// [`Filter`] (see the `Filter` impl for this type in this module).
pub struct CustomFilter<F> {
/// Name given at construction; returned by `name()`.
name: String,
/// The wrapped closure, invoked with the filtered value and the filter arguments.
filter: F,
}
impl<F> CustomFilter<F>
where
    F: Fn(&Value, &HashMap<String, Value>) -> Result<Value> + Send + Sync + 'static,
{
    /// Wraps `filter` under `name`. The name is copied into an owned `String`.
    pub fn new(name: &str, filter: F) -> Self {
        let name = String::from(name);
        Self { name, filter }
    }

    /// Returns the name supplied to [`CustomFilter::new`].
    pub fn name(&self) -> &str {
        self.name.as_str()
    }
}
impl<F> Filter for CustomFilter<F>
where
F: Fn(&Value, &HashMap<String, Value>) -> Result<Value> + Send + Sync + 'static,
{
fn filter(&self, value: &Value, args: &HashMap<String, Value>) -> tera::Result<Value> {
(self.filter)(value, args).map_err(|e| tera::Error::msg(e.to_string()))
}
}
/// A builder-style collection of custom functions and filters that can later
/// be installed on a [`Tera`] instance with [`FunctionRegistry::register_all`].
///
/// Each entry keeps the name taken from the [`CustomFunction`] /
/// [`CustomFilter`] it was built from, because Tera registers callables by
/// name and the type-erased trait object alone does not carry one.
#[derive(Default)]
pub struct FunctionRegistry {
    /// `(name, function)` pairs in insertion order.
    functions: Vec<(String, std::sync::Arc<dyn Function + Send + Sync>)>,
    /// `(name, filter)` pairs in insertion order.
    filters: Vec<(String, std::sync::Arc<dyn Filter + Send + Sync>)>,
}

impl FunctionRegistry {
    /// Creates an empty registry.
    pub fn new() -> Self {
        Self::default()
    }

    /// Adds `func` to the registry under `func.name()` and returns the
    /// registry for chaining.
    pub fn add_function<F>(mut self, func: CustomFunction<F>) -> Self
    where
        F: Fn(&HashMap<String, Value>) -> Result<Value> + Send + Sync + 'static,
    {
        let name = func.name().to_owned();
        let shared: std::sync::Arc<dyn Function + Send + Sync> = std::sync::Arc::new(func);
        self.functions.push((name, shared));
        self
    }

    /// Adds `filter` to the registry under `filter.name()` and returns the
    /// registry for chaining.
    pub fn add_filter<F>(mut self, filter: CustomFilter<F>) -> Self
    where
        F: Fn(&Value, &HashMap<String, Value>) -> Result<Value> + Send + Sync + 'static,
    {
        let name = filter.name().to_owned();
        let shared: std::sync::Arc<dyn Filter + Send + Sync> = std::sync::Arc::new(filter);
        self.filters.push((name, shared));
        self
    }

    /// Registers every stored function and filter on `tera`, in insertion
    /// order, under the name each entry was added with.
    ///
    /// The registry is only borrowed, so it can be applied to several `Tera`
    /// instances; each registration shares the stored callable through an
    /// `Arc` instead of moving it out.
    ///
    /// # Errors
    ///
    /// Currently never fails; the `Result` return type is kept for interface
    /// stability.
    pub fn register_all(&self, tera: &mut Tera) -> Result<()> {
        for (name, function) in &self.functions {
            let function = std::sync::Arc::clone(function);
            // Tera takes ownership of what it registers; a closure that
            // forwards to the shared trait object satisfies that without
            // consuming the registry.
            tera.register_function(name, move |args: &HashMap<String, Value>| {
                function.call(args)
            });
        }
        for (name, filter) in &self.filters {
            let filter = std::sync::Arc::clone(filter);
            tera.register_filter(
                name,
                move |value: &Value, args: &HashMap<String, Value>| filter.filter(value, args),
            );
        }
        Ok(())
    }

    /// Number of functions added so far.
    pub fn function_count(&self) -> usize {
        self.functions.len()
    }

    /// Number of filters added so far.
    pub fn filter_count(&self) -> usize {
        self.filters.len()
    }
}
/// Wraps `func` in a [`CustomFunction`] named `name` and registers it on
/// `tera` under that same name.
///
/// # Errors
///
/// Always returns `Ok(())`.
pub fn register_custom_function<F>(tera: &mut Tera, name: &str, func: F) -> Result<()>
where
    F: Fn(&HashMap<String, Value>) -> Result<Value> + Send + Sync + 'static,
{
    tera.register_function(name, CustomFunction::new(name, func));
    Ok(())
}
/// Wraps `filter` in a [`CustomFilter`] named `name` and registers it on
/// `tera` under that same name.
///
/// # Errors
///
/// Always returns `Ok(())`.
pub fn register_custom_filter<F>(tera: &mut Tera, name: &str, filter: F) -> Result<()>
where
    F: Fn(&Value, &HashMap<String, Value>) -> Result<Value> + Send + Sync + 'static,
{
    tera.register_filter(name, CustomFilter::new(name, filter));
    Ok(())
}
/// Builds a function closure that ignores its arguments and always returns
/// `value` as a `Value::String`.
///
/// `value` is copied into the closure, so the result does not borrow from the
/// argument. It is declared `'static` so it can be handed to
/// [`CustomFunction::new`] or [`register_custom_function`], both of which
/// require a `'static` closure.
pub fn simple_string_function(
    value: &str,
) -> impl Fn(&HashMap<String, Value>) -> Result<Value> + Send + Sync + 'static {
    let value = value.to_owned();
    move |_| Ok(Value::String(value.clone()))
}
/// Builds a function closure that fills `{key}` placeholders in `format_str`
/// from the call arguments.
///
/// For every `{key}` in the template where `key` is an argument name, the
/// placeholder is replaced by the argument's value: strings are inserted
/// verbatim, any other value through its `to_string()` rendering. Placeholders
/// with no matching argument, and stray braces, are left untouched.
///
/// The template is scanned once, left to right, and inserted values are never
/// re-scanned. The output therefore does not depend on `HashMap` iteration
/// order, even when a value itself contains text that looks like a
/// placeholder. A placeholder ends at the first `}` after its `{`, so argument
/// names containing `}` cannot be matched.
///
/// The template is copied into the closure, which is `'static` and can be
/// passed to [`CustomFunction::new`] or [`register_custom_function`].
pub fn format_function(
    format_str: &str,
) -> impl Fn(&HashMap<String, Value>) -> Result<Value> + Send + Sync + 'static {
    let template = format_str.to_owned();
    move |args| {
        let mut output = String::with_capacity(template.len());
        let mut rest = template.as_str();
        while let Some(open) = rest.find('{') {
            output.push_str(&rest[..open]);
            let tail = &rest[open + 1..];
            let hit = tail
                .find('}')
                .and_then(|close| args.get(&tail[..close]).map(|value| (close, value)));
            if let Some((close, value)) = hit {
                match value {
                    Value::String(text) => output.push_str(text),
                    other => output.push_str(&other.to_string()),
                }
                rest = &tail[close + 1..];
            } else {
                // Not a known placeholder: keep the brace literally and retry
                // from the next character, so `{{name}` still expands the
                // inner `{name}`.
                output.push('{');
                rest = tail;
            }
        }
        output.push_str(rest);
        Ok(Value::String(output))
    }
}
/// Builds a function closure that applies `operation` to the numeric
/// arguments `a` and `b`.
///
/// An argument that is missing, or for which `as_f64()` yields `None`, is
/// treated as `0.0`. The result is always produced from an `f64`.
///
/// # Errors
///
/// The returned closure yields `TemplateError::ValidationError` when
/// * the operation is `Divide` and `b` is `0.0`, or
/// * the result is not a finite number (for example after overflow), instead
///   of silently reporting `0`.
pub fn arithmetic_function(
    operation: ArithmeticOp,
) -> impl Fn(&HashMap<String, Value>) -> Result<Value> + Send + Sync + 'static {
    move |args| {
        let operand = |key: &str| args.get(key).and_then(|v| v.as_f64()).unwrap_or(0.0);
        let a = operand("a");
        let b = operand("b");
        let result = match operation {
            ArithmeticOp::Add => a + b,
            ArithmeticOp::Subtract => a - b,
            ArithmeticOp::Multiply => a * b,
            ArithmeticOp::Divide => {
                if b == 0.0 {
                    return Err(TemplateError::ValidationError(
                        "Division by zero".to_string(),
                    ));
                }
                a / b
            }
        };
        // `from_f64` yields `None` for NaN and infinities, which JSON numbers
        // cannot represent; surface that as an error.
        serde_json::Number::from_f64(result)
            .map(Value::Number)
            .ok_or_else(|| {
                TemplateError::ValidationError(
                    "Arithmetic result is not a finite number".to_string(),
                )
            })
    }
}
/// The binary operation performed by the closure that `arithmetic_function` builds.
#[derive(Debug, Clone, Copy)]
pub enum ArithmeticOp {
/// `a + b`
Add,
/// `a - b`
Subtract,
/// `a * b`
Multiply,
/// `a / b`; `arithmetic_function` reports an error when `b` is `0.0`.
Divide,
}
/// Builds a filter closure that upper-cases string values.
///
/// Non-string values are returned unchanged (cloned). Filter arguments are
/// ignored.
pub fn uppercase_filter(
) -> impl Fn(&Value, &HashMap<String, Value>) -> Result<Value> + Send + Sync + 'static {
    |value, _args| {
        let converted = if let Value::String(text) = value {
            Value::String(text.to_uppercase())
        } else {
            value.clone()
        };
        Ok(converted)
    }
}
/// Builds a filter closure that lower-cases string values.
///
/// Non-string values are returned unchanged (cloned). Filter arguments are
/// ignored.
pub fn lowercase_filter(
) -> impl Fn(&Value, &HashMap<String, Value>) -> Result<Value> + Send + Sync + 'static {
    |value, _args| {
        let converted = if let Value::String(text) = value {
            Value::String(text.to_lowercase())
        } else {
            value.clone()
        };
        Ok(converted)
    }
}
/// Builds a filter closure that shortens strings to at most `max_len`
/// characters, appending `...` when something was cut off.
///
/// Length is counted in characters (Unicode scalar values), so the cut always
/// lands on a character boundary and multi-byte text cannot cause a slicing
/// panic. For ASCII input this is the same as counting bytes. Strings of
/// `max_len` characters or fewer, and all non-string values, are returned
/// unchanged (cloned). Filter arguments are ignored.
pub fn truncate_filter(
    max_len: usize,
) -> impl Fn(&Value, &HashMap<String, Value>) -> Result<Value> + Send + Sync + 'static {
    move |value, _args| match value {
        Value::String(text) => {
            // The character at index `max_len` exists only when the string is
            // longer than the limit; its byte offset is where to cut.
            let shortened = match text.char_indices().nth(max_len) {
                Some((cut, _)) => format!("{}...", &text[..cut]),
                None => text.clone(),
            };
            Ok(Value::String(shortened))
        }
        _ => Ok(value.clone()),
    }
}
/// Builds a filter closure that joins the elements of an array value into a
/// single string, placing `separator` between elements.
///
/// String elements are inserted verbatim; any other element is inserted
/// through its `to_string()` rendering. Non-array values are returned
/// unchanged (cloned). Filter arguments are ignored.
pub fn join_filter(
    separator: &str,
) -> impl Fn(&Value, &HashMap<String, Value>) -> Result<Value> + Send + Sync + 'static {
    let separator = separator.to_string();
    move |value, _args| {
        let items = match value {
            Value::Array(items) => items,
            _ => return Ok(value.clone()),
        };
        let mut joined = String::new();
        for (index, item) in items.iter().enumerate() {
            if index > 0 {
                joined.push_str(&separator);
            }
            match item {
                Value::String(text) => joined.push_str(text),
                other => joined.push_str(&other.to_string()),
            }
        }
        Ok(Value::String(joined))
    }
}
pub struct ExtendedTemplateRenderer {
renderer: TemplateRenderer,
registry: FunctionRegistry,
}
impl ExtendedTemplateRenderer {
pub fn new() -> Result<Self> {
let mut renderer = TemplateRenderer::new()?;
let registry = FunctionRegistry::new();
Self::register_common_functions(&mut renderer.tera)?;
Ok(Self { renderer, registry })
}
fn register_common_functions(tera: &mut Tera) -> Result<()> {
register_custom_function(tera, "uppercase", |args| {
let input = args.get("input").and_then(|v| v.as_str()).unwrap_or("");
Ok(Value::String(input.to_uppercase()))
})?;
register_custom_function(tera, "lowercase", |args| {
let input = args.get("input").and_then(|v| v.as_str()).unwrap_or("");
Ok(Value::String(input.to_lowercase()))
})?;
register_custom_function(tera, "length", |args| {
let input = args.get("input");
let len = match input {
Some(Value::Array(arr)) => arr.len(),
Some(Value::String(s)) => s.len(),
Some(Value::Object(obj)) => obj.len(),
_ => 0,
};
Ok(Value::Number(len.into()))
})?;
register_custom_function(tera, "now_iso", |_| {
Ok(Value::String(chrono::Utc::now().to_rfc3339()))
})?;
register_custom_function(tera, "timestamp", |_| {
Ok(Value::Number(chrono::Utc::now().timestamp().into()))
})?;
register_custom_function(tera, "default", |args| {
let value = args.get("value");
let default = args.get("default");
match (value, default) {
(Some(v), _) if !v.is_null() => Ok(v.clone()),
(_, Some(d)) => Ok(d.clone()),
_ => Ok(Value::Null),
}
})?;
Ok(())
}
pub fn add_function<F>(mut self, func: CustomFunction<F>) -> Self
where
F: Fn(&HashMap<String, Value>) -> Result<Value> + Send + Sync + 'static,
{
self.registry = self.registry.add_function(func);
self
}
pub fn add_filter<F>(mut self, filter: CustomFilter<F>) -> Self
where
F: Fn(&Value, &HashMap<String, Value>) -> Result<Value> + Send + Sync + 'static,
{
self.registry = self.registry.add_filter(filter);
self
}
pub fn render(&mut self, template: &str, name: &str) -> Result<String> {
self.renderer.render_str(template, name)
}
pub fn renderer(&self) -> &TemplateRenderer {
&self.renderer
}
pub fn renderer_mut(&mut self) -> &mut TemplateRenderer {
&mut self.renderer
}
}
/// Shorthand for `CustomFunction::new($name, $func)`.
///
/// The type is resolved through the path `$crate::custom::CustomFunction`.
#[macro_export]
macro_rules! custom_function {
($name:expr, $func:expr) => {
$crate::custom::CustomFunction::new($name, $func)
};
}
/// Shorthand for `CustomFilter::new($name, $filter)`.
///
/// The type is resolved through the path `$crate::custom::CustomFilter`.
#[macro_export]
macro_rules! custom_filter {
($name:expr, $filter:expr) => {
$crate::custom::CustomFilter::new($name, $filter)
};
}
/// Registers several custom functions at once.
///
/// Takes a Tera expression followed by a brace-delimited, comma-separated
/// list of `name => func` pairs (a trailing comma is allowed). Expands to a
/// block that calls `$crate::custom::register_custom_function` once per pair,
/// in order, and evaluates to `Ok::<(), $crate::error::TemplateError>(())`.
///
/// Each call is followed by `?`, so a failing registration returns early from
/// the function the macro is used in; that function's return type must
/// therefore accept the error.
///
/// NOTE(review): `$tera` is pasted once per pair, so the expression is
/// evaluated once for every registration.
#[macro_export]
macro_rules! register_functions {
($tera:expr, { $($name:expr => $func:expr),* $(,)? }) => {{
$(
$crate::custom::register_custom_function($tera, $name, $func)?;
)*
Ok::<(), $crate::error::TemplateError>(())
}};
}
/// Registers several custom filters at once.
///
/// Takes a Tera expression followed by a brace-delimited, comma-separated
/// list of `name => filter` pairs (a trailing comma is allowed). Expands to a
/// block that calls `$crate::custom::register_custom_filter` once per pair,
/// in order, and evaluates to `Ok::<(), $crate::error::TemplateError>(())`.
///
/// Each call is followed by `?`, so a failing registration returns early from
/// the function the macro is used in; that function's return type must
/// therefore accept the error.
///
/// NOTE(review): `$tera` is pasted once per pair, so the expression is
/// evaluated once for every registration.
#[macro_export]
macro_rules! register_filters {
($tera:expr, { $($name:expr => $filter:expr),* $(,)? }) => {{
$(
$crate::custom::register_custom_filter($tera, $name, $filter)?;
)*
Ok::<(), $crate::error::TemplateError>(())
}};
}
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::Value;
    use std::collections::HashMap;

    /// A function registered through the helper is visible on the Tera instance.
    #[test]
    fn test_custom_function_registration() {
        let mut tera = Tera::default();
        register_custom_function(&mut tera, "test_func", |args| {
            let input = args.get("input").and_then(|v| v.as_str()).unwrap_or("");
            Ok(Value::String(format!("Processed: {}", input)))
        })
        .unwrap();
        assert!(tera.get_function("test_func").is_some());
    }

    /// Adding two numbers yields their sum.
    #[test]
    fn test_arithmetic_function() {
        let add_func = arithmetic_function(ArithmeticOp::Add);
        let mut args = HashMap::new();
        args.insert("a".to_string(), Value::Number(5.into()));
        args.insert("b".to_string(), Value::Number(3.into()));
        let result = add_func(&args).unwrap();
        // The result is built from an `f64`, and a float `Number` does not
        // compare equal to an integer `Number`, so compare numerically.
        assert_eq!(result.as_f64(), Some(8.0));
    }

    /// Placeholders use single braces with the bare argument name.
    #[test]
    fn test_format_function() {
        let format_func = format_function("Hello {name}, count: {count}");
        let mut args = HashMap::new();
        args.insert("name".to_string(), Value::String("World".to_string()));
        args.insert("count".to_string(), Value::String("42".to_string()));
        let result = format_func(&args).unwrap();
        assert_eq!(result, Value::String("Hello World, count: 42".to_string()));
    }

    /// The registry counts functions and filters separately.
    #[test]
    fn test_function_registry() {
        let registry = FunctionRegistry::new()
            .add_function(CustomFunction::new("test1", |_args| {
                Ok(Value::String("test1".to_string()))
            }))
            .add_filter(CustomFilter::new("test2", |value, _args| Ok(value.clone())));
        assert_eq!(registry.function_count(), 1);
        assert_eq!(registry.filter_count(), 1);
    }

    /// The built-in `uppercase` helper is available when rendering.
    #[test]
    fn test_extended_renderer() {
        let mut renderer = ExtendedTemplateRenderer::new().unwrap();
        assert!(renderer.renderer().has_template("_macros.toml.tera"));
        let result = renderer
            .render("Hello {{ uppercase(input='world') }}!", "test")
            .unwrap();
        assert_eq!(result, "Hello WORLD!");
    }
}