use std::path::{Path, PathBuf};
use bytes::Bytes;
use http::header::{CONTENT_LENGTH, CONTENT_TYPE, HOST};
use salvo::{Depot, FlowCtrl, Request, Response, async_trait};
use tracing::debug;
const MAX_TEMPLATE_SIZE: usize = 1024 * 1024;
pub struct TemplatesHoop {
root: Option<PathBuf>,
}
impl TemplatesHoop {
pub fn new(root: Option<String>) -> Self {
debug!(root = root.as_deref(), "templates middleware initialized");
Self {
root: root.map(PathBuf::from),
}
}
}
struct TemplateVars {
host: String,
path: String,
method: String,
scheme: String,
client_ip: String,
query: String,
uri: String,
remote_addr: String,
server_name: String,
}
#[async_trait]
impl salvo::Handler for TemplatesHoop {
async fn handle(
&self,
req: &mut Request,
depot: &mut Depot,
res: &mut Response,
ctrl: &mut FlowCtrl,
) {
let client_addr = super::client_addr(req);
let host = req
.headers()
.get(HOST)
.and_then(|v| v.to_str().ok())
.unwrap_or("")
.to_string();
let server_name = host.split(':').next().unwrap_or("").to_string();
let vars = TemplateVars {
host: host.clone(),
path: req.uri().path().to_string(),
method: req.method().to_string(),
scheme: req.uri().scheme_str().unwrap_or("http").to_string(),
client_ip: client_addr.ip().to_string(),
query: req.uri().query().unwrap_or("").to_string(),
uri: req.uri().to_string(),
remote_addr: client_addr.to_string(),
server_name,
};
ctrl.call_next(req, depot, res).await;
let is_html = res
.headers()
.get(CONTENT_TYPE)
.and_then(|v| v.to_str().ok())
.map(|ct| ct.to_ascii_lowercase().contains("text/html"))
.unwrap_or(false);
if !is_html {
return;
}
let body = res.take_body();
let body_bytes = match super::compress::collect_res_body_bytes(body).await {
Ok(b) => b,
Err(_) => return,
};
if body_bytes.len() > MAX_TEMPLATE_SIZE {
debug!(
size = body_bytes.len(),
max = MAX_TEMPLATE_SIZE,
"response too large for template processing"
);
res.body(body_bytes);
return;
}
let html = match std::str::from_utf8(&body_bytes) {
Ok(s) => s.to_string(),
Err(_) => {
res.body(body_bytes);
return;
}
};
let processed = process_template(&html, &vars, &self.root);
let processed_bytes = Bytes::from(processed);
res.headers_mut().remove(CONTENT_LENGTH);
res.headers_mut()
.insert(CONTENT_LENGTH, processed_bytes.len().into());
debug!("applied template substitution to HTML response");
res.body(processed_bytes.to_vec());
}
}
fn process_template(input: &str, vars: &TemplateVars, root: &Option<PathBuf>) -> String {
let mut result = String::with_capacity(input.len());
let bytes = input.as_bytes();
let len = bytes.len();
let mut i = 0;
while i < len {
if i + 1 < len && bytes[i] == b'{' && bytes[i + 1] == b'{' {
if let Some(close_pos) = find_closing_braces(input, i + 2) {
let tag_content = &input[i + 2..close_pos];
let tag_content = tag_content.trim();
let replacement = resolve_tag(tag_content, vars, root);
result.push_str(&replacement);
i = close_pos + 2; continue;
}
}
result.push(bytes[i] as char);
i += 1;
}
result
}
fn find_closing_braces(s: &str, start: usize) -> Option<usize> {
let bytes = s.as_bytes();
let mut i = start;
while i + 1 < bytes.len() {
if bytes[i] == b'}' && bytes[i + 1] == b'}' {
return Some(i);
}
i += 1;
}
None
}
fn resolve_tag(tag: &str, vars: &TemplateVars, root: &Option<PathBuf>) -> String {
match tag {
"host" => return vars.host.clone(),
"path" => return vars.path.clone(),
"method" => return vars.method.clone(),
"scheme" => return vars.scheme.clone(),
"client_ip" => return vars.client_ip.clone(),
"query" => return vars.query.clone(),
"uri" => return vars.uri.clone(),
"remote_addr" => return vars.remote_addr.clone(),
"server_name" => return vars.server_name.clone(),
_ => {}
}
if let Some(var_name) = tag.strip_prefix(".Env.") {
let var_name = var_name.trim();
return std::env::var(var_name).unwrap_or_default();
}
if let Some(rest) = tag.strip_prefix("include ") {
let rest = rest.trim();
let include_path = rest.trim_matches('"').trim_matches('\'');
return resolve_include(include_path, root);
}
format!("{{{{{tag}}}}}")
}
fn resolve_include(path_str: &str, root: &Option<PathBuf>) -> String {
let path = Path::new(path_str);
let full_path = if path.is_relative() {
match root {
Some(root_dir) => root_dir.join(path),
None => PathBuf::from(path),
}
} else {
PathBuf::from(path)
};
let path_str_normalized = full_path.to_string_lossy().replace('\\', "/");
for component in path_str_normalized.split('/') {
if component == ".." {
debug!(path = %full_path.display(), "blocked include with path traversal");
return String::new();
}
}
match std::fs::read_to_string(&full_path) {
Ok(contents) => {
debug!(path = %full_path.display(), "included template file");
contents
}
Err(e) => {
debug!(path = %full_path.display(), error = %e, "failed to include template file");
String::new()
}
}
}