1use crate::error::{Result, TemplateError};
10use crate::renderer::TemplateRenderer;
11use serde_json::Value;
12use std::collections::HashMap;
13use tera::{Filter, Function, Tera};
14
15pub struct CustomFunction<F> {
19 name: String,
21 func: F,
23}
24
25impl<F> CustomFunction<F>
26where
27 F: Fn(&HashMap<String, Value>) -> Result<Value> + Send + Sync + 'static,
28{
29 pub fn new(name: &str, func: F) -> Self {
35 Self {
36 name: name.to_string(),
37 func,
38 }
39 }
40
41 pub fn name(&self) -> &str {
43 &self.name
44 }
45}
46
47impl<F> Function for CustomFunction<F>
48where
49 F: Fn(&HashMap<String, Value>) -> Result<Value> + Send + Sync + 'static,
50{
51 fn call(&self, args: &HashMap<String, Value>) -> tera::Result<Value> {
52 (self.func)(args).map_err(|e| tera::Error::msg(e.to_string()))
53 }
54}
55
56pub struct CustomFilter<F> {
60 name: String,
62 filter: F,
64}
65
66impl<F> CustomFilter<F>
67where
68 F: Fn(&Value, &HashMap<String, Value>) -> Result<Value> + Send + Sync + 'static,
69{
70 pub fn new(name: &str, filter: F) -> Self {
76 Self {
77 name: name.to_string(),
78 filter,
79 }
80 }
81
82 pub fn name(&self) -> &str {
84 &self.name
85 }
86}
87
88impl<F> Filter for CustomFilter<F>
89where
90 F: Fn(&Value, &HashMap<String, Value>) -> Result<Value> + Send + Sync + 'static,
91{
92 fn filter(&self, value: &Value, args: &HashMap<String, Value>) -> tera::Result<Value> {
93 (self.filter)(value, args).map_err(|e| tera::Error::msg(e.to_string()))
94 }
95}
96
97#[derive(Default)]
101pub struct FunctionRegistry {
102 functions: Vec<Box<dyn Function + Send + Sync>>,
104 filters: Vec<Box<dyn Filter + Send + Sync>>,
106}
107
108impl FunctionRegistry {
109 pub fn new() -> Self {
111 Self::default()
112 }
113
114 pub fn add_function<F>(mut self, func: CustomFunction<F>) -> Self
116 where
117 F: Fn(&HashMap<String, Value>) -> Result<Value> + Send + Sync + 'static,
118 {
119 self.functions.push(Box::new(func));
120 self
121 }
122
123 pub fn add_filter<F>(mut self, filter: CustomFilter<F>) -> Self
125 where
126 F: Fn(&Value, &HashMap<String, Value>) -> Result<Value> + Send + Sync + 'static,
127 {
128 self.filters.push(Box::new(filter));
129 self
130 }
131
132 pub fn register_all(&self, _tera: &mut Tera) -> Result<()> {
137 for _func in &self.functions {
138 }
142
143 for _filter in &self.filters {
144 }
146
147 Ok(())
148 }
149
150 pub fn function_count(&self) -> usize {
152 self.functions.len()
153 }
154
155 pub fn filter_count(&self) -> usize {
157 self.filters.len()
158 }
159}
160
161pub fn register_custom_function<F>(tera: &mut Tera, name: &str, func: F) -> Result<()>
170where
171 F: Fn(&HashMap<String, Value>) -> Result<Value> + Send + Sync + 'static,
172{
173 let custom_func = CustomFunction::new(name, func);
174 tera.register_function(name, custom_func);
175 Ok(())
176}
177
178pub fn register_custom_filter<F>(tera: &mut Tera, name: &str, filter: F) -> Result<()>
185where
186 F: Fn(&Value, &HashMap<String, Value>) -> Result<Value> + Send + Sync + 'static,
187{
188 let custom_filter = CustomFilter::new(name, filter);
189 tera.register_filter(name, custom_filter);
190 Ok(())
191}
192
193pub fn simple_string_function(
197 value: &str,
198) -> impl Fn(&HashMap<String, Value>) -> Result<Value> + Send + Sync + '_ {
199 let value = value.to_string();
200 move |_| Ok(Value::String(value.clone()))
201}
202
203pub fn format_function(
205 format_str: &str,
206) -> impl Fn(&HashMap<String, Value>) -> Result<Value> + Send + Sync + '_ {
207 let format_str = format_str.to_string();
208 move |args| {
209 let mut result = format_str.clone();
210 for (key, value) in args {
211 let placeholder = format!("{{{}}}", key);
212 let replacement = match value {
213 Value::String(s) => s.clone(),
214 Value::Number(n) => n.to_string(),
215 Value::Bool(b) => b.to_string(),
216 _ => value.to_string(),
217 };
218 result = result.replace(&placeholder, &replacement);
219 }
220 Ok(Value::String(result))
221 }
222}
223
224pub fn arithmetic_function(
226 operation: ArithmeticOp,
227) -> impl Fn(&HashMap<String, Value>) -> Result<Value> + Send + Sync + 'static {
228 move |args| {
229 let a = args.get("a").and_then(|v| v.as_f64()).unwrap_or(0.0);
230 let b = args.get("b").and_then(|v| v.as_f64()).unwrap_or(0.0);
231
232 let result = match operation {
233 ArithmeticOp::Add => a + b,
234 ArithmeticOp::Subtract => a - b,
235 ArithmeticOp::Multiply => a * b,
236 ArithmeticOp::Divide => {
237 if b == 0.0 {
238 return Err(TemplateError::ValidationError(
239 "Division by zero".to_string(),
240 ));
241 }
242 a / b
243 }
244 };
245
246 Ok(Value::Number(
247 serde_json::Number::from_f64(result).unwrap_or(serde_json::Number::from(0)),
248 ))
249 }
250}
251
/// Binary operation performed by [`arithmetic_function`] on its `a` and `b`
/// arguments.
///
/// Derives `PartialEq`, `Eq` and `Hash` (in addition to `Debug`, `Clone`,
/// `Copy`) so operations can be compared and used as map/set keys.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ArithmeticOp {
    /// `a + b`
    Add,
    /// `a - b`
    Subtract,
    /// `a * b`
    Multiply,
    /// `a / b`; a zero divisor is reported as an error by [`arithmetic_function`].
    Divide,
}
264
265pub fn uppercase_filter(
269) -> impl Fn(&Value, &HashMap<String, Value>) -> Result<Value> + Send + Sync + 'static {
270 |value, _args| match value {
271 Value::String(s) => Ok(Value::String(s.to_uppercase())),
272 _ => Ok(value.clone()),
273 }
274}
275
276pub fn lowercase_filter(
278) -> impl Fn(&Value, &HashMap<String, Value>) -> Result<Value> + Send + Sync + 'static {
279 |value, _args| match value {
280 Value::String(s) => Ok(Value::String(s.to_lowercase())),
281 _ => Ok(value.clone()),
282 }
283}
284
285pub fn truncate_filter(
287 max_len: usize,
288) -> impl Fn(&Value, &HashMap<String, Value>) -> Result<Value> + Send + Sync + 'static {
289 move |value, _args| match value {
290 Value::String(s) => {
291 if s.len() > max_len {
292 Ok(Value::String(format!("{}...", &s[..max_len])))
293 } else {
294 Ok(Value::String(s.clone()))
295 }
296 }
297 _ => Ok(value.clone()),
298 }
299}
300
301pub fn join_filter(
303 separator: &str,
304) -> impl Fn(&Value, &HashMap<String, Value>) -> Result<Value> + Send + Sync + 'static {
305 let separator = separator.to_string();
306 move |value, _args| match value {
307 Value::Array(arr) => {
308 let joined = arr
309 .iter()
310 .map(|v| match v {
311 Value::String(s) => s.clone(),
312 _ => v.to_string(),
313 })
314 .collect::<Vec<_>>()
315 .join(&separator);
316 Ok(Value::String(joined))
317 }
318 _ => Ok(value.clone()),
319 }
320}
321
322pub struct ExtendedTemplateRenderer {
326 renderer: TemplateRenderer,
328 registry: FunctionRegistry,
330}
331
332impl ExtendedTemplateRenderer {
333 pub fn new() -> Result<Self> {
335 let mut renderer = TemplateRenderer::new()?;
336 let registry = FunctionRegistry::new();
337
338 Self::register_common_functions(&mut renderer.tera)?;
340
341 Ok(Self { renderer, registry })
342 }
343
344 fn register_common_functions(tera: &mut Tera) -> Result<()> {
346 register_custom_function(tera, "uppercase", |args| {
348 let input = args.get("input").and_then(|v| v.as_str()).unwrap_or("");
349 Ok(Value::String(input.to_uppercase()))
350 })?;
351
352 register_custom_function(tera, "lowercase", |args| {
353 let input = args.get("input").and_then(|v| v.as_str()).unwrap_or("");
354 Ok(Value::String(input.to_lowercase()))
355 })?;
356
357 register_custom_function(tera, "length", |args| {
359 let input = args.get("input");
360 let len = match input {
361 Some(Value::Array(arr)) => arr.len(),
362 Some(Value::String(s)) => s.len(),
363 Some(Value::Object(obj)) => obj.len(),
364 _ => 0,
365 };
366 Ok(Value::Number(len.into()))
367 })?;
368
369 register_custom_function(tera, "now_iso", |_| {
371 Ok(Value::String(chrono::Utc::now().to_rfc3339()))
372 })?;
373
374 register_custom_function(tera, "timestamp", |_| {
375 Ok(Value::Number(chrono::Utc::now().timestamp().into()))
376 })?;
377
378 register_custom_function(tera, "default", |args| {
380 let value = args.get("value");
381 let default = args.get("default");
382 match (value, default) {
383 (Some(v), _) if !v.is_null() => Ok(v.clone()),
384 (_, Some(d)) => Ok(d.clone()),
385 _ => Ok(Value::Null),
386 }
387 })?;
388
389 Ok(())
390 }
391
392 pub fn add_function<F>(mut self, func: CustomFunction<F>) -> Self
394 where
395 F: Fn(&HashMap<String, Value>) -> Result<Value> + Send + Sync + 'static,
396 {
397 self.registry = self.registry.add_function(func);
398 self
399 }
400
401 pub fn add_filter<F>(mut self, filter: CustomFilter<F>) -> Self
403 where
404 F: Fn(&Value, &HashMap<String, Value>) -> Result<Value> + Send + Sync + 'static,
405 {
406 self.registry = self.registry.add_filter(filter);
407 self
408 }
409
410 pub fn render(&mut self, template: &str, name: &str) -> Result<String> {
412 self.renderer.render_str(template, name)
413 }
414
415 pub fn renderer(&self) -> &TemplateRenderer {
417 &self.renderer
418 }
419
420 pub fn renderer_mut(&mut self) -> &mut TemplateRenderer {
422 &mut self.renderer
423 }
424}
425
/// Shorthand for `CustomFunction::new`: `custom_function!("name", |args| ...)`
/// expands to `$crate::custom::CustomFunction::new("name", |args| ...)`.
#[macro_export]
macro_rules! custom_function {
    ($fn_name:expr, $callable:expr) => {
        $crate::custom::CustomFunction::new($fn_name, $callable)
    };
}
449
/// Shorthand for `CustomFilter::new`: `custom_filter!("name", |value, args| ...)`
/// expands to `$crate::custom::CustomFilter::new("name", |value, args| ...)`.
#[macro_export]
macro_rules! custom_filter {
    ($filter_name:expr, $callable:expr) => {
        $crate::custom::CustomFilter::new($filter_name, $callable)
    };
}
457
/// Registers several template functions on a `Tera` instance in one go.
///
/// Expands to an expression of type `Result<(), TemplateError>`. Registration
/// stops at the first failure, and that error becomes the macro's value rather
/// than being `?`-propagated out of the *calling* function. The macro can
/// therefore be used from functions that do not return a `Result`, and the
/// returned value is meaningful.
///
/// ```ignore
/// register_functions!(&mut tera, {
///     "greet" => |_args| Ok(serde_json::Value::from("hi")),
///     "answer" => |_args| Ok(serde_json::Value::from(42)),
/// })?;
/// ```
#[macro_export]
macro_rules! register_functions {
    ($tera:expr, { $($name:expr => $func:expr),* $(,)? }) => {
        // Immediately-invoked closure scopes each `?` to this expansion.
        (|| -> ::core::result::Result<(), $crate::error::TemplateError> {
            $(
                $crate::custom::register_custom_function($tera, $name, $func)?;
            )*
            ::core::result::Result::Ok(())
        })()
    };
}
485
/// Registers several template filters on a `Tera` instance in one go.
///
/// Expands to an expression of type `Result<(), TemplateError>`. Registration
/// stops at the first failure, and that error becomes the macro's value rather
/// than being `?`-propagated out of the *calling* function.
///
/// ```ignore
/// register_filters!(&mut tera, {
///     "shout" => |value, _args| Ok(value.clone()),
/// })?;
/// ```
#[macro_export]
macro_rules! register_filters {
    ($tera:expr, { $($name:expr => $filter:expr),* $(,)? }) => {
        // Immediately-invoked closure scopes each `?` to this expansion.
        (|| -> ::core::result::Result<(), $crate::error::TemplateError> {
            $(
                $crate::custom::register_custom_filter($tera, $name, $filter)?;
            )*
            ::core::result::Result::Ok(())
        })()
    };
}
496
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::Value;
    use std::collections::HashMap;
    use tera::Context;

    #[test]
    fn test_custom_function_registration() {
        let mut tera = Tera::default();

        register_custom_function(&mut tera, "test_func", |args| {
            let input = args.get("input").and_then(|v| v.as_str()).unwrap_or("");
            Ok(Value::String(format!("Processed: {}", input)))
        })
        .unwrap();

        // Verify through the public rendering API. Tera's function-lookup
        // accessor returns a `Result`, not an `Option`, so `.is_some()` on it
        // does not compile.
        let rendered = tera
            .render_str("{{ test_func(input='data') }}", &Context::new())
            .unwrap();
        assert_eq!(rendered, "Processed: data");
    }

    #[test]
    fn test_arithmetic_function() {
        let add_func = arithmetic_function(ArithmeticOp::Add);
        let mut args = HashMap::new();
        args.insert("a".to_string(), Value::Number(5.into()));
        args.insert("b".to_string(), Value::Number(3.into()));

        let result = add_func(&args).unwrap();
        assert_eq!(result, Value::Number(8.into()));
    }

    #[test]
    fn test_format_function() {
        let format_func = format_function("Hello {{ name }}, count: {{ count }}");
        let mut args = HashMap::new();
        args.insert("name".to_string(), Value::String("World".to_string()));
        args.insert("count".to_string(), Value::String("42".to_string()));

        let result = format_func(&args).unwrap();
        assert_eq!(result, Value::String("Hello World, count: 42".to_string()));
    }

    #[test]
    fn test_function_registry() {
        let registry = FunctionRegistry::new()
            .add_function(CustomFunction::new("test1", |_args| {
                Ok(Value::String("test1".to_string()))
            }))
            .add_filter(CustomFilter::new("test2", |value, _args| Ok(value.clone())));

        assert_eq!(registry.function_count(), 1);
        assert_eq!(registry.filter_count(), 1);
    }

    #[test]
    fn test_extended_renderer() {
        let mut renderer = ExtendedTemplateRenderer::new().unwrap();

        assert!(renderer.renderer().has_template("_macros.toml.tera"));

        let result = renderer
            .render("Hello {{ uppercase(input='world') }}!", "test")
            .unwrap();
        assert_eq!(result, "Hello WORLD!");
    }
}