use crate::utils::error::{Error, Result};
use std::collections::{HashMap, HashSet};
use tera::{Context, Tera, Value};
/// Default upper bound, in bytes, on template source accepted by `SecureTeraEnvironment` (1 MiB).
pub const MAX_TEMPLATE_SIZE: usize = 1024 * 1024;
/// Upper bound, in bytes, on a variable name accepted by `TemplateValidator::validate_variable_name`.
pub const MAX_VARIABLE_NAME_LENGTH: usize = 2 * 128;
/// Reasons the template sandbox in this module rejects a template or one of its inputs.
///
/// The `#[error(...)]` strings are the `Display` text; that text is what survives the
/// conversion into the crate-wide `Error` type.
#[derive(Debug, thiserror::Error)]
pub enum TemplateSecurityError {
    /// Template source is larger than allowed: `(actual_bytes, max_bytes)`.
    #[error("Template too large: {0} bytes (max: {1})")]
    TemplateTooLarge(usize, usize),
    /// A variable name failed validation; the payload describes which rule failed.
    #[error("Invalid variable name: {0}")]
    InvalidVariableName(String),
    /// A disallowed function or filter was referenced; the payload names it.
    #[error("Forbidden function: {0}")]
    ForbiddenFunction(String),
    /// A deny-listed pattern was found in the template source; the payload describes it.
    #[error("Template injection detected: {0}")]
    TemplateInjection(String),
    /// An include-style tag attempted to reach outside the template root.
    #[error("Path traversal in include: {0}")]
    PathTraversal(String),
    /// An escaping context was required but not supplied; the payload describes what is missing.
    #[error("Context required for escaping: {0}")]
    MissingContext(String),
    /// The template could not be parsed; the payload carries the engine's message.
    #[error("Invalid template syntax: {0}")]
    InvalidSyntax(String),
}
impl From<TemplateSecurityError> for Error {
fn from(err: TemplateSecurityError) -> Self {
Error::new(&err.to_string())
}
}
/// A template engine wrapper that checks template source before rendering it.
pub trait TemplateSandbox {
    /// Runs the implementation's static checks on `source` without rendering it.
    fn validate_template(&self, source: &str) -> Result<()>;
    /// Checks `source` and, if it is accepted, renders it with `context`.
    fn render_safe(&self, source: &str, context: &Context) -> Result<String>;
    /// Reports whether the filter/function `name` is on the implementation's allow-list.
    fn is_function_allowed(&self, name: &str) -> bool;
    /// Checks that `name` is acceptable as a template variable name.
    fn validate_variable_name(&self, name: &str) -> Result<()>;
}
/// Tera-backed [`TemplateSandbox`] that applies size, deny-list and include checks before rendering.
pub struct SecureTeraEnvironment {
    // Names reported as permitted by `is_function_allowed`.
    allowed_functions: HashSet<String>,
    // Patterns whose presence makes `check_forbidden_patterns` reject a template.
    forbidden_patterns: Vec<&'static str>,
    // Largest template source, in bytes, accepted by `validate_size`.
    max_size: usize,
}
impl SecureTeraEnvironment {
    /// Creates a sandbox with the default allow-list, deny-list and a
    /// `MAX_TEMPLATE_SIZE` byte limit.
    pub fn new() -> Self {
        Self {
            allowed_functions: Self::default_allowed_functions(),
            forbidden_patterns: Self::default_forbidden_patterns(),
            max_size: MAX_TEMPLATE_SIZE,
        }
    }

    /// Builder-style override of the maximum accepted template size, in bytes.
    pub fn with_max_size(mut self, max_size: usize) -> Self {
        self.max_size = max_size;
        self
    }

    /// Filter/function names reported as allowed by `is_function_allowed`.
    ///
    /// NOTE(review): within this file the set is only consulted by `is_function_allowed`;
    /// `render_safe` does not enforce it.
    fn default_allowed_functions() -> HashSet<String> {
        [
            "upper", "lower", "trim", "truncate", "wordcount", "replace", "split", "join",
            "capitalize", "title", "round", "abs", "plus", "minus", "times", "divided_by",
            "length", "first", "last", "nth", "slice", "concat", "reverse", "sort", "unique",
            "group_by", "filter", "map", "date", "json_encode", "urlencode", "escape", "safe",
            "escape_html", "escape_sql", "escape_shell",
        ]
        .iter()
        .map(|name| name.to_string())
        .collect()
    }

    /// Deny-list used by `check_forbidden_patterns`.
    ///
    /// Entries made only of ASCII letters, digits and `_` are matched as whole identifiers;
    /// everything else (path fragments) is matched as a plain substring.
    fn default_forbidden_patterns() -> Vec<&'static str> {
        vec![
            "include_raw",
            "read_file",
            "write_file",
            "http",
            "https",
            "fetch",
            "curl",
            "exec",
            "system",
            "shell",
            "cmd",
            "bash",
            // Tera ships a built-in `get_env` function that reads process environment
            // variables; `render_safe` renders with `Tera::default()`, so block it here.
            "get_env",
            "../",
            "..\\",
            "/etc/",
            "/proc/",
            "/sys/",
            "C:\\Windows",
        ]
    }

    /// Rejects `source` if it is longer than `max_size` bytes.
    fn validate_size(&self, source: &str) -> Result<()> {
        let size = source.len();
        if size > self.max_size {
            return Err(TemplateSecurityError::TemplateTooLarge(size, self.max_size).into());
        }
        Ok(())
    }

    /// Rejects `source` if it contains any deny-listed pattern.
    ///
    /// Identifier-like patterns must appear as a whole identifier. Otherwise the `shell` entry
    /// would reject the sandbox's own `escape_shell` filter, and `cmd` would reject a
    /// variable such as `user_cmd_count`. Path fragments such as `../` are plain substrings.
    fn check_forbidden_patterns(&self, source: &str) -> Result<()> {
        for &pattern in &self.forbidden_patterns {
            let found = if Self::is_identifier_like(pattern) {
                Self::contains_identifier(source, pattern)
            } else {
                source.contains(pattern)
            };
            if found {
                return Err(TemplateSecurityError::TemplateInjection(format!(
                    "Forbidden pattern detected: {}",
                    pattern
                ))
                .into());
            }
        }
        Ok(())
    }

    /// Rejects `..` inside any `{% include %}`, `{% import %}` or `{% extends %}` tag.
    ///
    /// Every `{% ... %}` tag is scanned wherever it appears (mid-line, multi-line, with `-`
    /// whitespace control), not only tags at the start of a line.
    fn validate_includes(&self, source: &str) -> Result<()> {
        let mut rest = source;
        while let Some(open) = rest.find("{%") {
            let after_open = &rest[open + 2..];
            // An unterminated tag is still inspected, up to the end of the source.
            let (tag, remainder) = match after_open.find("%}") {
                Some(close) => (&after_open[..close], &after_open[close + 2..]),
                None => (after_open, ""),
            };
            let body = tag.trim_start_matches('-').trim_start();
            let keyword = body
                .split(|c: char| !(c.is_ascii_alphanumeric() || c == '_'))
                .next()
                .unwrap_or("");
            if matches!(keyword, "include" | "import" | "extends") && tag.contains("..") {
                return Err(TemplateSecurityError::PathTraversal(
                    "Path traversal detected in include".to_string(),
                )
                .into());
            }
            rest = remainder;
        }
        Ok(())
    }

    /// True when `pattern` is non-empty and consists only of identifier bytes.
    fn is_identifier_like(pattern: &str) -> bool {
        !pattern.is_empty() && pattern.bytes().all(Self::is_identifier_byte)
    }

    /// True for ASCII letters, ASCII digits and `_`.
    fn is_identifier_byte(byte: u8) -> bool {
        byte.is_ascii_alphanumeric() || byte == b'_'
    }

    /// True when `ident` occurs in `haystack` with no identifier byte directly before or after it.
    ///
    /// `ident` must be non-empty ASCII (guaranteed by `is_identifier_like`).
    fn contains_identifier(haystack: &str, ident: &str) -> bool {
        let bytes = haystack.as_bytes();
        let mut from = 0;
        while let Some(offset) = haystack[from..].find(ident) {
            let start = from + offset;
            let end = start + ident.len();
            let starts_word = start == 0 || !Self::is_identifier_byte(bytes[start - 1]);
            let ends_word = end == bytes.len() || !Self::is_identifier_byte(bytes[end]);
            if starts_word && ends_word {
                return true;
            }
            // `ident` begins with an ASCII byte, so `start + 1` is always a char boundary.
            from = start + 1;
        }
        false
    }
}
impl Default for SecureTeraEnvironment {
fn default() -> Self {
Self::new()
}
}
impl TemplateSandbox for SecureTeraEnvironment {
fn validate_template(&self, source: &str) -> Result<()> {
self.validate_size(source)?;
self.check_forbidden_patterns(source)?;
self.validate_includes(source)?;
Ok(())
}
fn render_safe(&self, source: &str, context: &Context) -> Result<String> {
self.validate_template(source)?;
let mut tera = Tera::default();
tera.register_filter("escape_html", escape_html_filter);
tera.register_filter("escape_sql", escape_sql_filter);
tera.register_filter("escape_shell", escape_shell_filter);
tera.render_str(source, context)
.map_err(|e| Error::new(&format!("Template rendering failed: {}", e)))
}
fn is_function_allowed(&self, name: &str) -> bool {
self.allowed_functions.contains(name)
}
fn validate_variable_name(&self, name: &str) -> Result<()> {
TemplateValidator::validate_variable_name(name)
}
}
/// Stateless helpers for validating template inputs.
pub struct TemplateValidator;

impl TemplateValidator {
    /// Checks that `name` is usable as a plain template variable identifier.
    ///
    /// Rules, checked in this order:
    /// - `name` is non-empty.
    /// - It is at most `MAX_VARIABLE_NAME_LENGTH` bytes.
    /// - It contains only ASCII letters, ASCII digits and `_`.
    /// - It does not start with a digit.
    ///
    /// # Errors
    ///
    /// A `TemplateSecurityError::InvalidVariableName` describing the first rule that failed.
    pub fn validate_variable_name(name: &str) -> Result<()> {
        if name.is_empty() {
            return Err(TemplateSecurityError::InvalidVariableName(
                "Variable name is empty".to_string(),
            )
            .into());
        }
        if name.len() > MAX_VARIABLE_NAME_LENGTH {
            return Err(TemplateSecurityError::InvalidVariableName(format!(
                "Variable name too long: {} (max: {})",
                name.len(),
                MAX_VARIABLE_NAME_LENGTH
            ))
            .into());
        }
        // ASCII only: `char::is_alphanumeric` also accepts Unicode letters and digits
        // (e.g. 'é', '٣'), which are not valid in template identifiers and allow
        // look-alike names.
        if !name.bytes().all(|b| b.is_ascii_alphanumeric() || b == b'_') {
            return Err(TemplateSecurityError::InvalidVariableName(format!(
                "Variable name contains invalid characters: {}",
                name
            ))
            .into());
        }
        if name.starts_with(|c: char| c.is_ascii_digit()) {
            return Err(TemplateSecurityError::InvalidVariableName(
                "Variable name cannot start with number".to_string(),
            )
            .into());
        }
        Ok(())
    }

    /// Placeholder hook for context validation; currently accepts every context.
    pub fn validate_context(_context: &Context) -> Result<()> {
        Ok(())
    }

    /// Checks that `source` parses as a Tera template, without rendering it.
    ///
    /// The template is registered (parsed) instead of rendered. Rendering against an empty
    /// context fails on every variable reference, so any template containing `{{ name }}`
    /// would be misreported as a syntax error.
    ///
    /// NOTE(review): registration may also try to resolve `{% extends %}` / `{% import %}`
    /// targets, which are never loaded in this standalone instance; confirm whether such
    /// templates should be accepted.
    ///
    /// # Errors
    ///
    /// `TemplateSecurityError::InvalidSyntax` carrying Tera's parse error.
    pub fn validate_syntax(source: &str) -> Result<()> {
        let mut tera = Tera::default();
        tera.add_raw_template("__syntax_check__", source)
            .map_err(|e| {
                TemplateSecurityError::InvalidSyntax(format!("Template syntax error: {}", e)).into()
            })
    }
}
/// Output-context-specific escaping helpers.
///
/// Each function targets exactly one output context. Applying the wrong one (e.g. HTML
/// escaping inside a `<script>` block) does not make the output safe.
pub struct ContextEscaper;

impl ContextEscaper {
    /// Escapes text for HTML element content or a quoted attribute value.
    ///
    /// `&`, `<`, `>`, `"`, `'` and `/` become `&amp;`, `&lt;`, `&gt;`, `&quot;`, `&#x27;` and
    /// `&#x2F;`. The single pass means `&` is never escaped twice.
    pub fn escape_html(input: &str) -> String {
        let mut out = String::with_capacity(input.len());
        for c in input.chars() {
            match c {
                '&' => out.push_str("&amp;"),
                '<' => out.push_str("&lt;"),
                '>' => out.push_str("&gt;"),
                '"' => out.push_str("&quot;"),
                '\'' => out.push_str("&#x27;"),
                '/' => out.push_str("&#x2F;"),
                _ => out.push(c),
            }
        }
        out
    }

    /// Escapes a value for a single-quoted SQL string literal by doubling every `'`.
    ///
    /// Only `'` is handled. SQL dialects that also treat backslash as an escape inside string
    /// literals are not covered, so prefer bound query parameters wherever possible.
    pub fn escape_sql(input: &str) -> String {
        input.replace('\'', "''")
    }

    /// Backslash-escapes shell metacharacters so the value is read as one literal word
    /// when placed unquoted in a POSIX shell command.
    ///
    /// Both quote characters are escaped. Leaving `'` bare let a value open or close a
    /// single-quoted span and change how the rest of the command line is split.
    /// A newline becomes backslash + newline, which POSIX shells treat as a line continuation:
    /// the newline is dropped, not executed.
    pub fn escape_shell(input: &str) -> String {
        let mut result = String::with_capacity(input.len() * 2);
        for c in input.chars() {
            if matches!(
                c,
                '$' | '`' | '"' | '\'' | '\\' | '!' | '\n' | '&' | ';' | '|' | '(' | ')' | '<'
                    | '>' | ' ' | '\t' | '*' | '?' | '[' | ']' | '{' | '}' | '~' | '#'
            ) {
                result.push('\\');
            }
            result.push(c);
        }
        result
    }

    /// Percent-encodes text for a URL query component (form-encoding style).
    ///
    /// ASCII letters, digits and `-` `_` `.` `~` pass through. Space becomes `+`. Every
    /// other byte of the UTF-8 encoding becomes `%XX`, so 'é' encodes as `%C3%A9`.
    pub fn escape_url(input: &str) -> String {
        let mut encoded = String::with_capacity(input.len());
        // Iterate bytes, not chars: multi-byte characters must emit one `%XX` per UTF-8 byte.
        for &byte in input.as_bytes() {
            match byte {
                b'A'..=b'Z' | b'a'..=b'z' | b'0'..=b'9' | b'-' | b'_' | b'.' | b'~' => {
                    encoded.push(char::from(byte))
                }
                b' ' => encoded.push('+'),
                _ => encoded.push_str(&format!("%{:02X}", byte)),
            }
        }
        encoded
    }

    /// Escapes text for inclusion inside a single- or double-quoted JavaScript string literal.
    ///
    /// The escaped characters are:
    /// - backslash, both quote characters, `\n`, `\r` and `\t`;
    /// - the line terminators U+2028 and U+2029, which older engines reject inside string
    ///   literals;
    /// - `<` and `>`, so the output cannot close a surrounding `<script>` element.
    pub fn escape_js(input: &str) -> String {
        let mut out = String::with_capacity(input.len());
        for c in input.chars() {
            match c {
                '\\' => out.push_str("\\\\"),
                '"' => out.push_str("\\\""),
                '\'' => out.push_str("\\'"),
                '\n' => out.push_str("\\n"),
                '\r' => out.push_str("\\r"),
                '\t' => out.push_str("\\t"),
                '<' => out.push_str("\\x3C"),
                '>' => out.push_str("\\x3E"),
                '\u{2028}' => out.push_str("\\u2028"),
                '\u{2029}' => out.push_str("\\u2029"),
                _ => out.push(c),
            }
        }
        out
    }
}
/// Tera filter backing `escape_html`: HTML-escapes string values and returns every other
/// value unchanged. Filter arguments are ignored.
fn escape_html_filter(value: &Value, _args: &HashMap<String, Value>) -> tera::Result<Value> {
    let result = if let Value::String(text) = value {
        Value::String(ContextEscaper::escape_html(text))
    } else {
        value.clone()
    };
    Ok(result)
}
/// Tera filter backing `escape_sql`: doubles single quotes in string values and returns
/// every other value unchanged. Filter arguments are ignored.
fn escape_sql_filter(value: &Value, _args: &HashMap<String, Value>) -> tera::Result<Value> {
    if let Value::String(text) = value {
        return Ok(Value::String(ContextEscaper::escape_sql(text)));
    }
    Ok(value.clone())
}
/// Tera filter backing `escape_shell`: backslash-escapes shell metacharacters in string
/// values and returns every other value unchanged. Filter arguments are ignored.
fn escape_shell_filter(value: &Value, _args: &HashMap<String, Value>) -> tera::Result<Value> {
    let output = match value {
        Value::String(raw) => Value::String(ContextEscaper::escape_shell(raw)),
        untouched => untouched.clone(),
    };
    Ok(output)
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_template_size_validation_valid() {
        let sandbox = SecureTeraEnvironment::new();
        let small_template = "Hello {{ name }}!";
        assert!(sandbox.validate_size(small_template).is_ok());
    }

    #[test]
    fn test_template_size_validation_too_large() {
        let sandbox = SecureTeraEnvironment::new();
        let large_template = "a".repeat(MAX_TEMPLATE_SIZE + 1);
        let result = sandbox.validate_size(&large_template);
        assert!(result.is_err());
    }

    #[test]
    fn test_custom_max_size() {
        let sandbox = SecureTeraEnvironment::new().with_max_size(100);
        let template = "a".repeat(101);
        assert!(sandbox.validate_size(&template).is_err());
    }

    #[test]
    fn test_valid_variable_names() {
        assert!(TemplateValidator::validate_variable_name("user_name").is_ok());
        assert!(TemplateValidator::validate_variable_name("count_123").is_ok());
        assert!(TemplateValidator::validate_variable_name("_private").is_ok());
        assert!(TemplateValidator::validate_variable_name("CamelCase").is_ok());
        assert!(TemplateValidator::validate_variable_name("snake_case_123").is_ok());
    }

    #[test]
    fn test_invalid_variable_names() {
        assert!(TemplateValidator::validate_variable_name("").is_err());
        assert!(TemplateValidator::validate_variable_name("user.name").is_err());
        assert!(TemplateValidator::validate_variable_name("user[0]").is_err());
        assert!(TemplateValidator::validate_variable_name("user-name").is_err());
        assert!(TemplateValidator::validate_variable_name("user@host").is_err());
        assert!(TemplateValidator::validate_variable_name("../etc/passwd").is_err());
        assert!(TemplateValidator::validate_variable_name("..\\windows").is_err());
        assert!(TemplateValidator::validate_variable_name("123abc").is_err());
        let long_name = "a".repeat(MAX_VARIABLE_NAME_LENGTH + 1);
        assert!(TemplateValidator::validate_variable_name(&long_name).is_err());
    }

    #[test]
    fn test_forbidden_patterns_detected() {
        let sandbox = SecureTeraEnvironment::new();
        assert!(sandbox
            .check_forbidden_patterns("{{ include_raw('secret.txt') }}")
            .is_err());
        assert!(sandbox
            .check_forbidden_patterns("{{ read_file('/etc/passwd') }}")
            .is_err());
        assert!(sandbox
            .check_forbidden_patterns("{{ http('evil.com') }}")
            .is_err());
        assert!(sandbox
            .check_forbidden_patterns("{{ exec('rm -rf /') }}")
            .is_err());
        assert!(sandbox
            .check_forbidden_patterns("../../../etc/passwd")
            .is_err());
    }

    #[test]
    fn test_safe_patterns_allowed() {
        let sandbox = SecureTeraEnvironment::new();
        assert!(sandbox
            .check_forbidden_patterns("Hello {{ name | upper }}!")
            .is_ok());
        assert!(sandbox
            .check_forbidden_patterns("{{ items | length }}")
            .is_ok());
    }

    #[test]
    fn test_include_path_traversal_detection() {
        let sandbox = SecureTeraEnvironment::new();
        assert!(sandbox
            .validate_includes("{% include '../../../etc/passwd' %}")
            .is_err());
        assert!(sandbox
            .validate_includes("{% include 'header.html' %}")
            .is_ok());
    }

    #[test]
    fn test_allowed_functions() {
        let sandbox = SecureTeraEnvironment::new();
        assert!(sandbox.is_function_allowed("upper"));
        assert!(sandbox.is_function_allowed("lower"));
        assert!(sandbox.is_function_allowed("trim"));
        assert!(sandbox.is_function_allowed("round"));
        assert!(sandbox.is_function_allowed("abs"));
        assert!(sandbox.is_function_allowed("length"));
        assert!(sandbox.is_function_allowed("first"));
        assert!(sandbox.is_function_allowed("escape_html"));
        assert!(sandbox.is_function_allowed("escape_sql"));
        assert!(sandbox.is_function_allowed("escape_shell"));
    }

    #[test]
    fn test_forbidden_functions() {
        let sandbox = SecureTeraEnvironment::new();
        assert!(!sandbox.is_function_allowed("include_raw"));
        assert!(!sandbox.is_function_allowed("read_file"));
        assert!(!sandbox.is_function_allowed("http"));
        assert!(!sandbox.is_function_allowed("exec"));
    }

    #[test]
    fn test_escape_html_basic() {
        let input = "<script>alert('xss')</script>";
        let escaped = ContextEscaper::escape_html(input);
        assert_eq!(
            escaped,
            "&lt;script&gt;alert(&#x27;xss&#x27;)&lt;&#x2F;script&gt;"
        );
    }

    #[test]
    fn test_escape_html_all_chars() {
        assert_eq!(ContextEscaper::escape_html("&"), "&amp;");
        assert_eq!(ContextEscaper::escape_html("<"), "&lt;");
        assert_eq!(ContextEscaper::escape_html(">"), "&gt;");
        assert_eq!(ContextEscaper::escape_html("\""), "&quot;");
        assert_eq!(ContextEscaper::escape_html("'"), "&#x27;");
        assert_eq!(ContextEscaper::escape_html("/"), "&#x2F;");
    }

    #[test]
    fn test_escape_html_complex() {
        let input = "<div onclick=\"alert('xss')\" data-value='test'>";
        let escaped = ContextEscaper::escape_html(input);
        assert!(!escaped.contains('<'));
        assert!(!escaped.contains('>'));
        assert!(!escaped.contains('"'));
        assert!(!escaped.contains('\''));
    }

    #[test]
    fn test_escape_sql_basic() {
        // Standard SQL escaping doubles the single quote: `'` becomes `''`.
        let input = "'; DROP TABLE users; --";
        let escaped = ContextEscaper::escape_sql(input);
        assert_eq!(escaped, "''; DROP TABLE users; --");
    }

    #[test]
    fn test_escape_sql_multiple_quotes() {
        let input = "O'Reilly's book";
        let escaped = ContextEscaper::escape_sql(input);
        assert_eq!(escaped, "O''Reilly''s book");
    }

    #[test]
    fn test_escape_sql_no_quotes() {
        let input = "normal text";
        let escaped = ContextEscaper::escape_sql(input);
        assert_eq!(escaped, "normal text");
    }

    #[test]
    fn test_escape_shell_basic() {
        let input = "file.txt; rm -rf /";
        let escaped = ContextEscaper::escape_shell(input);
        assert_eq!(escaped, "file.txt\\;\\ rm\\ -rf\\ /");
    }

    #[test]
    fn test_escape_shell_metacharacters() {
        assert!(ContextEscaper::escape_shell("$VAR").contains("\\$"));
        assert!(ContextEscaper::escape_shell("`cmd`").contains("\\`"));
        assert!(ContextEscaper::escape_shell("a & b").contains("\\&"));
        assert!(ContextEscaper::escape_shell("a | b").contains("\\|"));
        assert!(ContextEscaper::escape_shell("a; b").contains("\\;"));
    }

    #[test]
    fn test_escape_shell_complex() {
        let input = "$(whoami) && echo 'pwned'";
        let escaped = ContextEscaper::escape_shell(input);
        // Every `$` must be neutralised by an immediately preceding backslash.
        assert!(escaped
            .match_indices('$')
            .all(|(i, _)| escaped[..i].ends_with('\\')));
        assert!(escaped.starts_with("\\$\\(whoami\\)"));
        assert!(escaped.contains("\\&\\&"));
    }

    #[test]
    fn test_escape_url_ascii() {
        assert_eq!(ContextEscaper::escape_url("a b&c"), "a+b%26c");
        assert_eq!(ContextEscaper::escape_url("A-z_0.9~"), "A-z_0.9~");
    }

    #[test]
    fn test_escape_js_script_tag() {
        assert_eq!(
            ContextEscaper::escape_js("</script>"),
            "\\x3C/script\\x3E"
        );
    }

    #[test]
    fn test_sandbox_render_safe_simple() {
        let sandbox = SecureTeraEnvironment::new();
        let mut context = Context::new();
        context.insert("name", &"World");
        let template = "Hello {{ name }}!";
        let result = sandbox.render_safe(template, &context);
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), "Hello World!");
    }

    #[test]
    fn test_sandbox_reject_dangerous_template() {
        let sandbox = SecureTeraEnvironment::new();
        let context = Context::new();
        let template = "{{ read_file('/etc/passwd') }}";
        let result = sandbox.render_safe(template, &context);
        assert!(result.is_err());
    }

    #[test]
    fn test_full_validation_workflow() {
        let sandbox = SecureTeraEnvironment::new();
        let valid_template = "Hello {{ user_name | upper }}!";
        assert!(sandbox.validate_template(valid_template).is_ok());
        let large_template = "a".repeat(MAX_TEMPLATE_SIZE + 1);
        assert!(sandbox.validate_template(&large_template).is_err());
        let dangerous_template = "{{ exec('rm -rf /') }}";
        assert!(sandbox.validate_template(dangerous_template).is_err());
        let traversal_template = "{% include '../../../etc/passwd' %}";
        assert!(sandbox.validate_template(traversal_template).is_err());
    }
}