use std::collections::HashMap;
use chrono::{DateTime, Local};
use uuid::Uuid;
use crate::error::NikaError;
/// Characters that may never appear in a custom variable value: path
/// separators and NUL would let a substituted value escape its directory.
const FORBIDDEN_VAR_CHARS: &[char] = &['/', '\\', '\0'];
/// Substrings that may never appear in a custom variable value
/// (parent-directory and home-directory shorthands).
const FORBIDDEN_VAR_PATTERNS: &[&str] = &["..", "~"];

/// Checks that the value of custom variable `key` is safe to splice into a path.
///
/// Empty values are always accepted. Otherwise every entry of
/// `FORBIDDEN_VAR_CHARS` is checked first (in list order), then every entry of
/// `FORBIDDEN_VAR_PATTERNS`; the first hit is reported.
///
/// # Errors
/// Returns `NikaError::TemplateError` whose `template` is `{{key}}` and whose
/// `reason` names the offending character or pattern.
fn validate_var_value(key: &str, value: &str) -> Result<(), NikaError> {
    // Nothing to inspect: an empty value substitutes to nothing.
    if value.is_empty() {
        return Ok(());
    }

    // Both failure kinds share the same error shape; only the reason differs.
    let rejection = |reason: String| NikaError::TemplateError {
        template: format!("{{{{{}}}}}", key),
        reason,
    };

    if let Some(bad_char) = FORBIDDEN_VAR_CHARS.iter().find(|&&ch| value.contains(ch)) {
        return Err(rejection(format!(
            "Variable value contains forbidden character '{}': path traversal risk",
            bad_char
        )));
    }

    match FORBIDDEN_VAR_PATTERNS
        .iter()
        .find(|&&pat| value.contains(pat))
    {
        Some(bad_pattern) => Err(rejection(format!(
            "Variable value contains forbidden pattern '{}': path traversal risk",
            bad_pattern
        ))),
        None => Ok(()),
    }
}
/// Expands `{{variable}}` placeholders in output-path templates.
///
/// Built-in variables:
/// - `task_id`, and `workflow_name` (alias `workflow`)
/// - `date` (`%Y-%m-%d`), `time` (`%H-%M-%S`, hyphen-separated),
///   `timestamp` (Unix seconds), `uuid` (a fresh v4 UUID per occurrence)
/// - `date.<fmt>` / `time.<fmt>`, where `<fmt>` may contain the tokens
///   `YYYY`, `MM`, `DD` (date) or `HH`, `mm`, `ss` (time); other text is kept
///
/// Any other name is looked up among the custom variables registered with
/// [`TemplateResolver::with_var`] / [`TemplateResolver::with_vars`].
///
/// Every date/time variable reads the single timestamp captured by
/// [`TemplateResolver::new`] (or set via [`TemplateResolver::with_timestamp`]),
/// so all placeholders resolved by one resolver agree on the instant.
#[derive(Debug)]
pub struct TemplateResolver {
    /// Value substituted for `{{task_id}}`.
    task_id: String,
    /// Value substituted for `{{workflow_name}}` and `{{workflow}}`.
    workflow_name: String,
    /// Instant used by every date/time variable.
    timestamp: DateTime<Local>,
    /// User-supplied variables; values have passed `validate_var_value`.
    custom_vars: HashMap<String, String>,
}

impl TemplateResolver {
    /// Creates a resolver for `task_id` within `workflow_name`, timestamped
    /// with `Local::now()` and with no custom variables.
    pub fn new(task_id: impl Into<String>, workflow_name: impl Into<String>) -> Self {
        Self {
            task_id: task_id.into(),
            workflow_name: workflow_name.into(),
            timestamp: Local::now(),
            custom_vars: HashMap::new(),
        }
    }

    /// Registers one custom variable, replacing any previous value for `key`.
    ///
    /// Built-in variables are matched before custom ones, so a key such as
    /// `date`, `uuid`, `task_id` or `date.anything` is accepted here but can
    /// never be resolved. Placeholder names are trimmed before lookup, so a key
    /// with leading/trailing whitespace is likewise unreachable.
    ///
    /// # Errors
    /// Returns [`NikaError::TemplateError`] if `key` is empty or `value`
    /// contains a forbidden path-traversal character or pattern.
    pub fn with_var(
        mut self,
        key: impl Into<String>,
        value: impl Into<String>,
    ) -> Result<Self, NikaError> {
        let key = key.into();
        let value = value.into();
        Self::check_custom_var(&key, &value)?;
        self.custom_vars.insert(key, value);
        Ok(self)
    }

    /// Registers several custom variables at once.
    ///
    /// Every entry is validated before any is stored. When several entries are
    /// invalid, which one is reported depends on `HashMap` iteration order.
    ///
    /// # Errors
    /// Same conditions as [`TemplateResolver::with_var`], for any entry.
    pub fn with_vars(mut self, vars: HashMap<String, String>) -> Result<Self, NikaError> {
        for (key, value) in &vars {
            Self::check_custom_var(key, value)?;
        }
        self.custom_vars.extend(vars);
        Ok(self)
    }

    /// Overrides the instant used by every date/time variable, e.g. to get
    /// reproducible output.
    pub fn with_timestamp(mut self, timestamp: DateTime<Local>) -> Self {
        self.timestamp = timestamp;
        self
    }

    /// Returns `template` with every `{{name}}` placeholder replaced by its value.
    ///
    /// - Whitespace inside the braces is ignored: `{{ task_id }}` resolves like
    ///   `{{task_id}}`.
    /// - Placeholders are resolved left to right. Substituted values are never
    ///   re-scanned, so a value containing `{{...}}` is inserted literally.
    /// - An opening `{{` with no later `}}` stops expansion. That `{{` and
    ///   everything after it are copied through unchanged.
    ///
    /// # Errors
    /// Returns [`NikaError::TemplateError`] for the first placeholder whose name
    /// is neither a built-in nor a registered custom variable.
    pub fn resolve(&self, template: &str) -> Result<String, NikaError> {
        // Build the output in one pass. Splicing into a copy of the template
        // with `replace_range` shifts the whole tail on every substitution.
        let mut output = String::with_capacity(template.len());
        let mut rest = template;
        while let Some(open) = rest.find("{{") {
            let after_open = &rest[open + 2..];
            let Some(close) = after_open.find("}}") else {
                // Unterminated placeholder: leave it and the remainder verbatim.
                break;
            };
            output.push_str(&rest[..open]);
            output.push_str(&self.resolve_variable(after_open[..close].trim())?);
            rest = &after_open[close + 2..];
        }
        output.push_str(rest);
        Ok(output)
    }

    /// Validates one custom variable: the name must be non-empty and the value
    /// must pass `validate_var_value`.
    fn check_custom_var(key: &str, value: &str) -> Result<(), NikaError> {
        if key.is_empty() {
            return Err(NikaError::TemplateError {
                template: "{{}}".to_string(),
                reason: "Variable name cannot be empty".to_string(),
            });
        }
        validate_var_value(key, value)
    }

    /// Produces the value of a single, already-trimmed placeholder name.
    /// Built-ins (including any `date.`/`time.` prefix) win over custom variables.
    fn resolve_variable(&self, var_name: &str) -> Result<String, NikaError> {
        if let Some(format) = var_name.strip_prefix("date.") {
            return Ok(self.format_date(format));
        }
        if let Some(format) = var_name.strip_prefix("time.") {
            return Ok(self.format_time(format));
        }
        match var_name {
            // NOTE(review): unlike custom variables, task_id and workflow_name
            // are substituted without the path-traversal checks in
            // `validate_var_value`; confirm both always come from trusted input.
            "task_id" => Ok(self.task_id.clone()),
            "workflow_name" | "workflow" => Ok(self.workflow_name.clone()),
            "date" => Ok(self.timestamp.format("%Y-%m-%d").to_string()),
            "time" => Ok(self.timestamp.format("%H-%M-%S").to_string()),
            "timestamp" => Ok(self.timestamp.timestamp().to_string()),
            // A new UUID for every occurrence, even within one template.
            "uuid" => Ok(Uuid::new_v4().to_string()),
            _ => self
                .custom_vars
                .get(var_name)
                .cloned()
                .ok_or_else(|| NikaError::TemplateError {
                    template: format!("{{{{{}}}}}", var_name),
                    reason: format!("Unknown template variable: {}", var_name),
                }),
        }
    }

    /// Expands `YYYY`, `MM` and `DD` in a `{{date.<fmt>}}` format string.
    fn format_date(&self, format: &str) -> String {
        self.expand_tokens(format, &[("YYYY", "%Y"), ("MM", "%m"), ("DD", "%d")])
    }

    /// Expands `HH`, `mm` and `ss` in a `{{time.<fmt>}}` format string.
    fn format_time(&self, format: &str) -> String {
        self.expand_tokens(format, &[("HH", "%H"), ("mm", "%M"), ("ss", "%S")])
    }

    /// Replaces each `token` with `self.timestamp` rendered through its
    /// strftime `spec`. The pairs are applied in order, exactly like successive
    /// `str::replace` calls.
    fn expand_tokens(&self, format: &str, tokens: &[(&str, &str)]) -> String {
        tokens.iter().fold(format.to_owned(), |text, &(token, spec)| {
            text.replace(token, &self.timestamp.format(spec).to_string())
        })
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use chrono::TimeZone;

    /// Builds a resolver pinned to 2024-01-15 14:30:45 local time so the
    /// date/time assertions below are deterministic.
    fn fixed_resolver() -> TemplateResolver {
        let ts = Local.with_ymd_and_hms(2024, 1, 15, 14, 30, 45).unwrap();
        TemplateResolver::new("test_task", "test_workflow").with_timestamp(ts)
    }

    // ---- Built-in variables ----

    #[test]
    fn test_resolve_task_id() {
        let resolver = fixed_resolver();
        let result = resolver.resolve("{{task_id}}/output.json").unwrap();
        assert_eq!(result, "test_task/output.json");
    }

    #[test]
    fn test_resolve_workflow_name() {
        let resolver = fixed_resolver();
        let result = resolver
            .resolve("{{workflow_name}}/{{task_id}}.json")
            .unwrap();
        assert_eq!(result, "test_workflow/test_task.json");
    }

    #[test]
    fn test_resolve_date() {
        let resolver = fixed_resolver();
        let result = resolver.resolve("{{date}}/output.json").unwrap();
        assert_eq!(result, "2024-01-15/output.json");
    }

    #[test]
    fn test_resolve_time() {
        // `{{time}}` is hyphen-separated (%H-%M-%S), not colon-separated.
        let resolver = fixed_resolver();
        let result = resolver.resolve("{{time}}.json").unwrap();
        assert_eq!(result, "14-30-45.json");
    }

    #[test]
    fn test_resolve_timestamp() {
        // The fixed time is built with `Local`, so the Unix value depends on the
        // machine's UTC offset; only check that it parses as an integer.
        let resolver = fixed_resolver();
        let result = resolver.resolve("{{timestamp}}.json").unwrap();
        assert!(result.ends_with(".json"));
        let ts_str = result.strip_suffix(".json").unwrap();
        assert!(ts_str.parse::<i64>().is_ok());
    }

    #[test]
    fn test_resolve_uuid() {
        // A fresh random v4 UUID is generated, so only check it parses.
        let resolver = fixed_resolver();
        let result = resolver.resolve("{{uuid}}.json").unwrap();
        assert!(result.ends_with(".json"));
        let uuid_str = result.strip_suffix(".json").unwrap();
        assert!(Uuid::parse_str(uuid_str).is_ok());
    }

    // ---- `date.<fmt>` / `time.<fmt>` token formats ----

    #[test]
    fn test_resolve_date_format() {
        let resolver = fixed_resolver();
        let result = resolver.resolve("{{date.YYYY-MM-DD}}.json").unwrap();
        assert_eq!(result, "2024-01-15.json");
    }

    #[test]
    fn test_resolve_date_format_custom() {
        // The format text comes from the template itself, so '/' is allowed
        // here; only custom variable *values* go through `validate_var_value`.
        let resolver = fixed_resolver();
        let result = resolver.resolve("{{date.YYYY/MM/DD}}.json").unwrap();
        assert_eq!(result, "2024/01/15.json");
    }

    #[test]
    fn test_resolve_time_format() {
        let resolver = fixed_resolver();
        let result = resolver.resolve("{{time.HH-mm-ss}}.json").unwrap();
        assert_eq!(result, "14-30-45.json");
    }

    // ---- Custom variables ----

    #[test]
    fn test_resolve_custom_var() {
        let resolver = fixed_resolver().with_var("entity", "qr-code").unwrap();
        let result = resolver.resolve("{{entity}}/{{task_id}}.json").unwrap();
        assert_eq!(result, "qr-code/test_task.json");
    }

    #[test]
    fn test_resolve_multiple_vars() {
        let mut vars = HashMap::new();
        vars.insert("locale".to_string(), "fr-FR".to_string());
        vars.insert("version".to_string(), "v1".to_string());
        let resolver = fixed_resolver().with_vars(vars).unwrap();
        let result = resolver
            .resolve("{{locale}}/{{version}}/{{task_id}}.json")
            .unwrap();
        assert_eq!(result, "fr-FR/v1/test_task.json");
    }

    // ---- Custom variable validation ----

    #[test]
    fn test_var_path_traversal_rejected() {
        // "../escape" trips the '/' character check, which runs before the ".."
        // pattern check; either message mentions "path traversal".
        let result = fixed_resolver().with_var("entity", "../escape");
        assert!(result.is_err());
        let err = result.unwrap_err();
        if let NikaError::TemplateError { reason, .. } = err {
            assert!(reason.contains("path traversal"));
        } else {
            panic!("Expected TemplateError");
        }
    }

    #[test]
    fn test_var_slash_rejected() {
        let result = fixed_resolver().with_var("path", "a/b/c");
        assert!(result.is_err());
        let err = result.unwrap_err();
        if let NikaError::TemplateError { reason, .. } = err {
            assert!(reason.contains("forbidden character"));
        } else {
            panic!("Expected TemplateError");
        }
    }

    #[test]
    fn test_empty_var_name_rejected() {
        let result = fixed_resolver().with_var("", "value");
        assert!(result.is_err());
        let err = result.unwrap_err();
        if let NikaError::TemplateError { reason, .. } = err {
            assert!(reason.contains("empty"));
        } else {
            panic!("Expected TemplateError");
        }
    }

    #[test]
    fn test_empty_var_value_allowed() {
        // Empty values skip validation and substitute to nothing.
        let resolver = fixed_resolver().with_var("empty", "").unwrap();
        let result = resolver.resolve("prefix{{empty}}suffix").unwrap();
        assert_eq!(result, "prefixsuffix");
    }

    // ---- Template parsing edge cases ----

    #[test]
    fn test_resolve_unknown_var() {
        let resolver = fixed_resolver();
        let result = resolver.resolve("{{unknown}}/output.json");
        assert!(result.is_err());
        let err = result.unwrap_err();
        assert!(matches!(err, NikaError::TemplateError { .. }));
    }

    #[test]
    fn test_resolve_unclosed_template() {
        // A `{{` with no matching `}}` is copied through verbatim, not an error.
        let resolver = fixed_resolver();
        let result = resolver.resolve("{{task_id/output.json").unwrap();
        assert_eq!(result, "{{task_id/output.json");
    }

    #[test]
    fn test_resolve_no_templates() {
        let resolver = fixed_resolver();
        let result = resolver.resolve("simple/path/output.json").unwrap();
        assert_eq!(result, "simple/path/output.json");
    }

    #[test]
    fn test_resolve_whitespace_in_var() {
        // Placeholder names are trimmed before lookup.
        let resolver = fixed_resolver();
        let result = resolver.resolve("{{ task_id }}/output.json").unwrap();
        assert_eq!(result, "test_task/output.json");
    }

    #[test]
    fn test_resolve_complex_path() {
        let resolver = fixed_resolver().with_var("locale", "es-MX").unwrap();
        let result = resolver
            .resolve("{{workflow_name}}/{{date}}/{{locale}}/{{task_id}}_{{time}}.json")
            .unwrap();
        assert_eq!(
            result,
            "test_workflow/2024-01-15/es-MX/test_task_14-30-45.json"
        );
    }

    #[test]
    fn test_resolve_workflow_alias() {
        // `{{workflow}}` is an alias for `{{workflow_name}}`.
        let resolver = fixed_resolver();
        let result = resolver.resolve("{{workflow}}/{{task_id}}.json").unwrap();
        assert_eq!(result, "test_workflow/test_task.json");
    }
}