use std::io;
use std::path::{Path, PathBuf};
use tracing::{debug, error, info};
use super::{ClientTemplateKind, ServerTemplateKind};
use crate::core::protocol::Protocol;
/// Source of the user-configured template base directory.
///
/// Abstracted behind a trait so that discovery can be driven either by the process
/// environment (`EnvTemplateConfigReader`) or by a fixed value in tests
/// (`MockTemplateConfigReader`).
pub trait TemplateConfigReader {
    /// Returns the configured template base directory, or `None` when none is configured.
    fn get_template_dir(&self) -> Option<String>;
}
/// Template configuration reader backed by the `AGENTERRA_TEMPLATE_DIR` environment variable.
///
/// It is a stateless unit struct, so it is cheap to copy and has a trivial default.
#[derive(Debug, Clone, Copy, Default)]
pub struct EnvTemplateConfigReader;
impl TemplateConfigReader for EnvTemplateConfigReader {
    /// Reads `AGENTERRA_TEMPLATE_DIR`.
    ///
    /// An unset variable, or one whose value is not valid Unicode, yields `None`.
    fn get_template_dir(&self) -> Option<String> {
        const TEMPLATE_DIR_VAR: &str = "AGENTERRA_TEMPLATE_DIR";
        let Ok(configured) = std::env::var(TEMPLATE_DIR_VAR) else {
            return None;
        };
        Some(configured)
    }
}
/// Test double for `TemplateConfigReader` that reports a fixed, preconfigured value.
#[cfg(test)]
pub struct MockTemplateConfigReader(Option<String>);

#[cfg(test)]
impl MockTemplateConfigReader {
    /// Creates a mock that reports `template_dir` as the configured template directory.
    pub fn new(template_dir: Option<String>) -> Self {
        MockTemplateConfigReader(template_dir)
    }
}
#[cfg(test)]
impl TemplateConfigReader for MockTemplateConfigReader {
    /// Returns an owned copy of the preconfigured value; the mock itself is left unchanged.
    fn get_template_dir(&self) -> Option<String> {
        self.0.as_deref().map(String::from)
    }
}
/// A template directory that existed when it was discovered, together with the template
/// kind and protocol it was discovered for.
#[derive(Debug, Clone)]
pub struct TemplateDir {
    /// Path of the template directory itself (not the template base directory).
    template_path: PathBuf,
    /// Template kind. Client templates discovered via `discover_client_with_protocol`
    /// are recorded as `ServerTemplateKind::Custom`.
    kind: ServerTemplateKind,
    /// Protocol the template was discovered for.
    protocol: Protocol,
}
impl TemplateDir {
    /// Discovers the directory of a server template.
    ///
    /// If `custom_dir` is `Some`, that directory is used verbatim as the template
    /// directory; no `templates/<protocol>/<role>/<kind>` suffix is appended. Otherwise a
    /// template base directory is auto-discovered. The template is then expected at
    /// `<base>/templates/<protocol>/<role>/<kind>`, where `role` and `kind` come from
    /// `kind.role().as_str()` and `kind.as_str()`.
    ///
    /// # Errors
    ///
    /// Returns an error of kind [`io::ErrorKind::NotFound`] in three cases:
    /// - `custom_dir` does not exist;
    /// - no template base directory can be discovered;
    /// - the resolved template directory does not exist.
    pub fn discover_with_protocol(
        protocol: Protocol,
        kind: ServerTemplateKind,
        custom_dir: Option<&Path>,
    ) -> io::Result<Self> {
        debug!(
            "TemplateDir::discover_with_protocol - protocol: {:?}, kind: {:?}, custom_dir: {:?}",
            protocol, kind, custom_dir
        );
        let template_path =
            Self::resolve_template_path(protocol, kind.role().as_str(), kind.as_str(), custom_dir)?;
        debug!("Resolved template path: {}", template_path.display());
        let exists = template_path.exists();
        debug!("Template path exists: {}", exists);
        if !exists {
            error!(
                "Template directory not found at resolved path: {}",
                template_path.display()
            );
            return Err(io::Error::new(
                io::ErrorKind::NotFound,
                format!("Template directory not found: {}", template_path.display()),
            ));
        }
        info!(
            "Successfully created TemplateDir for: {}",
            template_path.display()
        );
        Ok(Self {
            template_path,
            kind,
            protocol,
        })
    }

    /// Discovers the directory of a client template.
    ///
    /// Path resolution is the same as [`TemplateDir::discover_with_protocol`], using the
    /// client kind's role and name.
    ///
    /// NOTE(review): `TemplateDir` only stores a `ServerTemplateKind`, so the returned
    /// value records `ServerTemplateKind::Custom`. The client kind is not retained, and
    /// `kind()` cannot tell client templates apart.
    ///
    /// # Errors
    ///
    /// Same conditions as [`TemplateDir::discover_with_protocol`]; the messages mention
    /// the client template.
    pub fn discover_client_with_protocol(
        protocol: Protocol,
        kind: ClientTemplateKind,
        custom_dir: Option<&Path>,
    ) -> io::Result<Self> {
        debug!(
            "TemplateDir::discover_client_with_protocol - protocol: {:?}, kind: {:?}, custom_dir: {:?}",
            protocol, kind, custom_dir
        );
        let template_path =
            Self::resolve_template_path(protocol, kind.role().as_str(), kind.as_str(), custom_dir)?;
        debug!("Resolved client template path: {}", template_path.display());
        let exists = template_path.exists();
        debug!("Client template path exists: {}", exists);
        if !exists {
            error!(
                "Client template directory not found at resolved path: {}",
                template_path.display()
            );
            return Err(io::Error::new(
                io::ErrorKind::NotFound,
                format!(
                    "Client template directory not found: {}",
                    template_path.display()
                ),
            ));
        }
        info!(
            "Successfully created TemplateDir for client: {}",
            template_path.display()
        );
        Ok(Self {
            template_path,
            kind: ServerTemplateKind::Custom,
            protocol,
        })
    }

    /// Template kind this directory was discovered for.
    pub fn kind(&self) -> ServerTemplateKind {
        self.kind
    }

    /// Protocol this directory was discovered for.
    pub fn protocol(&self) -> Protocol {
        self.protocol
    }

    /// Path of the template directory itself.
    pub fn template_path(&self) -> &Path {
        &self.template_path
    }

    /// Computes the template directory path.
    ///
    /// With `custom_dir`, returns it unchanged after checking that it exists. Without it,
    /// returns `<discovered base>/templates/<protocol segment>/<role>/<kind>`. That path
    /// is not checked for existence here.
    fn resolve_template_path(
        protocol: Protocol,
        role: &str,
        kind: &str,
        custom_dir: Option<&Path>,
    ) -> io::Result<PathBuf> {
        if let Some(dir) = custom_dir {
            debug!(
                "Using custom template directory directly: {}",
                dir.display()
            );
            if !dir.exists() {
                error!("Custom template directory not found: {}", dir.display());
                return Err(io::Error::new(
                    io::ErrorKind::NotFound,
                    format!("Template directory not found: {}", dir.display()),
                ));
            }
            Ok(dir.to_path_buf())
        } else {
            debug!("Auto-discovering template directory...");
            let discovered = Self::find_template_base_dir().ok_or_else(|| {
                error!("Could not find template directory in any standard location");
                io::Error::new(
                    io::ErrorKind::NotFound,
                    "Could not find template directory in any standard location",
                )
            })?;
            debug!("Auto-discovered template base: {}", discovered.display());
            let template_path = discovered
                .join("templates")
                .join(protocol.path_segment())
                .join(role)
                .join(kind);
            Ok(template_path)
        }
    }

    /// Finds the template base directory using the `AGENTERRA_TEMPLATE_DIR` environment
    /// variable, falling back to the standard search locations.
    fn find_template_base_dir() -> Option<PathBuf> {
        Self::find_template_base_dir_with_config(&EnvTemplateConfigReader)
    }

    /// Finds the template base directory.
    ///
    /// A directory configured through `config_reader` wins if it passes validation and
    /// exists. A configured path that fails validation aborts discovery and returns
    /// `None`, rather than falling back. Otherwise the first standard search location
    /// that contains a `templates` subdirectory is returned.
    fn find_template_base_dir_with_config(
        config_reader: &dyn TemplateConfigReader,
    ) -> Option<PathBuf> {
        if let Some(dir) = config_reader.get_template_dir() {
            let path = PathBuf::from(dir);
            if let Err(e) = Self::validate_template_path_safely(&path) {
                error!("Template directory validation failed: {}", e);
                return None;
            }
            if path.exists() {
                return Some(path);
            }
        }
        Self::get_template_search_locations()
            .into_iter()
            .find(|location| location.join("templates").exists())
    }

    /// Candidate template base directories, in priority order:
    /// 1. the executable's directory and its parent (canonicalized);
    /// 2. the current directory;
    /// 3. `CARGO_MANIFEST_DIR` and its parent, when that variable is set at runtime;
    /// 4. `<config dir>/agenterra`.
    fn get_template_search_locations() -> Vec<PathBuf> {
        let mut locations = Vec::new();
        if let Ok(exe_path) = std::env::current_exe() {
            if let Some(exe_dir) = exe_path.parent() {
                if let Ok(exe_dir_abs) = exe_dir.canonicalize() {
                    locations.push(exe_dir_abs.clone());
                    if let Some(parent_dir) = exe_dir_abs.parent() {
                        locations.push(parent_dir.to_path_buf());
                    }
                }
            }
        }
        if let Ok(current_dir) = std::env::current_dir() {
            locations.push(current_dir);
        }
        if let Ok(manifest_dir) = std::env::var("CARGO_MANIFEST_DIR") {
            let manifest_path = PathBuf::from(manifest_dir);
            locations.push(manifest_path.clone());
            if let Some(workspace_root) = manifest_path.parent() {
                locations.push(workspace_root.to_path_buf());
            }
        }
        if let Some(config_dir) = dirs::config_dir() {
            locations.push(config_dir.join("agenterra"));
        }
        locations
    }

    /// Validates an existing template path.
    ///
    /// The path is canonicalized, which resolves symlinks and `..`. It is then rejected if
    /// it lies inside a system directory.
    ///
    /// # Errors
    ///
    /// - [`io::ErrorKind::InvalidInput`] if canonicalization fails.
    /// - [`io::ErrorKind::PermissionDenied`] for system directories.
    fn validate_template_path(path: &Path) -> Result<(), io::Error> {
        let canonical_path = path.canonicalize().map_err(|e| {
            error!("Failed to canonicalize template path: {}", e);
            io::Error::new(
                io::ErrorKind::InvalidInput,
                format!("Invalid template path: {}", e),
            )
        })?;
        debug!("Validating template path: {}", canonical_path.display());
        Self::validate_unix_system_paths(&canonical_path)?;
        // `is_path_allowed` is advisory only: paths outside the known locations are still
        // accepted, they are merely logged as external.
        if !Self::is_path_allowed(&canonical_path) {
            debug!(
                "Template path validation passed (external location): {}",
                canonical_path.display()
            );
        }
        Ok(())
    }

    /// Reports whether `canonical_path` lies under a "known" location.
    ///
    /// Known locations are:
    /// - the home directory;
    /// - the current directory or one of its first three ancestors;
    /// - `CARGO_MANIFEST_DIR` or its parent.
    ///
    /// The result is advisory: `validate_template_path` accepts the path either way.
    fn is_path_allowed(canonical_path: &Path) -> bool {
        if let Some(home_dir) = dirs::home_dir() {
            if let Ok(home_canonical) = home_dir.canonicalize() {
                if canonical_path.starts_with(&home_canonical) {
                    debug!(
                        "Template path allowed under home directory: {}",
                        canonical_path.display()
                    );
                    return true;
                }
            }
        }
        if let Ok(current_dir) = std::env::current_dir() {
            if let Ok(current_canonical) = current_dir.canonicalize() {
                if Self::is_under_workspace(canonical_path, &current_canonical) {
                    return true;
                }
            }
        }
        if let Ok(manifest_dir) = std::env::var("CARGO_MANIFEST_DIR") {
            let manifest_path = PathBuf::from(manifest_dir);
            if let Ok(manifest_canonical) = manifest_path.canonicalize() {
                if canonical_path.starts_with(&manifest_canonical) {
                    debug!(
                        "Template path allowed under cargo manifest dir: {}",
                        canonical_path.display()
                    );
                    return true;
                }
                if let Some(parent) = manifest_canonical.parent() {
                    if canonical_path.starts_with(parent) {
                        debug!(
                            "Template path allowed under cargo workspace: {}",
                            canonical_path.display()
                        );
                        return true;
                    }
                }
            }
        }
        false
    }

    /// Returns `true` if `path` is inside `workspace_dir` or inside one of its first
    /// `MAX_PARENT_DEPTH` ancestors.
    ///
    /// NOTE(review): when the workspace is shallow (e.g. `/home/u/p`), the ancestor walk
    /// reaches `/`. Every absolute path is then reported as allowed.
    fn is_under_workspace(path: &Path, workspace_dir: &Path) -> bool {
        const MAX_PARENT_DEPTH: usize = 3;
        if path.starts_with(workspace_dir) {
            debug!(
                "Template path allowed under current directory: {}",
                path.display()
            );
            return true;
        }
        let mut parent = workspace_dir;
        for depth in 0..MAX_PARENT_DEPTH {
            if let Some(p) = parent.parent() {
                if path.starts_with(p) {
                    debug!(
                        "Template path allowed under workspace parent (depth {}): {}",
                        depth + 1,
                        path.display()
                    );
                    return true;
                }
                parent = p;
            } else {
                break;
            }
        }
        false
    }

    /// Validates a configured template path, which may or may not exist.
    ///
    /// Existing paths get the full canonical check (`validate_template_path`).
    /// Non-existent paths, which cannot be canonicalized, get two textual checks:
    /// - on Unix, reject well-known system directory prefixes;
    /// - everywhere, reject paths containing `../../../`.
    ///
    /// # Errors
    ///
    /// [`io::ErrorKind::PermissionDenied`] for rejected paths, or the error from
    /// `validate_template_path` for existing ones.
    fn validate_template_path_safely(path: &Path) -> Result<(), io::Error> {
        if path.exists() {
            return Self::validate_template_path(path);
        }
        debug!("Validating non-existent template path: {}", path.display());
        let path_str = path.to_string_lossy();
        #[cfg(unix)]
        {
            const SYSTEM_DIRS: &[&str] = &[
                "/etc/",
                "/usr/bin/",
                "/usr/sbin/",
                "/root/",
                "/boot/",
                "/sys/",
                "/proc/",
            ];
            for sys_dir in SYSTEM_DIRS {
                if path_str.starts_with(sys_dir) {
                    error!("Potentially unsafe template path rejected: {}", path_str);
                    return Err(io::Error::new(
                        io::ErrorKind::PermissionDenied,
                        format!("Template path not allowed: {}", path_str),
                    ));
                }
            }
        }
        if path_str.contains("../../../") {
            error!(
                "Directory traversal detected in template path: {}",
                path_str
            );
            return Err(io::Error::new(
                io::ErrorKind::PermissionDenied,
                "Directory traversal not allowed in template paths".to_string(),
            ));
        }
        debug!(
            "Template path validation passed (non-existent path): {}",
            path.display()
        );
        Ok(())
    }

    /// Rejects canonical paths whose first directory is a system directory.
    ///
    /// Also rejects `/private/<system dir>`, the layout macOS uses behind its `/etc`
    /// symlink. Paths that are not rooted (e.g. Windows paths with a drive prefix) are
    /// accepted unchanged.
    ///
    /// Temporary directories (`/tmp`, `/var/tmp`, `/private/var/...`) never match
    /// `is_system_directory`, so they need no special exception.
    fn validate_unix_system_paths(canonical_path: &Path) -> Result<(), io::Error> {
        use std::path::Component;

        let mut components = canonical_path.components();
        let first_dir = match (components.next(), components.next()) {
            (Some(Component::RootDir), Some(Component::Normal(dir))) => dir.to_str().unwrap_or(""),
            _ => return Ok(()),
        };
        if Self::is_system_directory(first_dir) {
            return Err(Self::system_directory_error(canonical_path));
        }
        if first_dir == "private" {
            if let Some(Component::Normal(next)) = components.next() {
                if Self::is_system_directory(next.to_str().unwrap_or("")) {
                    return Err(Self::system_directory_error(canonical_path));
                }
            }
        }
        Ok(())
    }

    /// Top-level directory names that templates must never be loaded from.
    fn is_system_directory(name: &str) -> bool {
        matches!(name, "etc" | "usr" | "root" | "boot" | "sys" | "proc")
    }

    /// Logs and builds the `PermissionDenied` error for a rejected system-directory path.
    fn system_directory_error(path: &Path) -> io::Error {
        error!("System directory access rejected: {}", path.display());
        io::Error::new(
            io::ErrorKind::PermissionDenied,
            format!(
                "Template path not allowed in system directory: {}",
                path.display()
            ),
        )
    }
}
/// Resolves the absolute directory that the project `project_name` should be generated into.
///
/// The base directory is the first available of:
/// 1. `custom_output_dir`;
/// 2. the `AGENTERRA_OUTPUT_DIR` environment variable;
/// 3. the current working directory.
///
/// `project_name` is joined onto that base. A relative result is then made absolute
/// against the current working directory.
///
/// NOTE(review): the name is joined with [`Path::join`], so an absolute `project_name`
/// (or one containing `..`) is not confined to the base directory.
///
/// # Errors
///
/// Returns an error if the current working directory is needed but cannot be read.
pub fn resolve_output_dir(
    project_name: &str,
    custom_output_dir: Option<&Path>,
) -> io::Result<PathBuf> {
    let working_dir = || {
        std::env::current_dir()
            .map_err(|e| io::Error::other(format!("Failed to get current directory: {}", e)))
    };

    let output_path = match custom_output_dir {
        Some(custom_dir) => {
            debug!("Using custom output directory: {}", custom_dir.display());
            custom_dir.join(project_name)
        }
        None => match std::env::var("AGENTERRA_OUTPUT_DIR") {
            Ok(env_dir) => {
                let env_path = PathBuf::from(env_dir);
                debug!("Using AGENTERRA_OUTPUT_DIR: {}", env_path.display());
                env_path.join(project_name)
            }
            Err(_) => {
                let output_dir = working_dir()?.join(project_name);
                debug!("Using default output directory: {}", output_dir.display());
                output_dir
            }
        },
    };

    let absolute_path = match output_path.is_absolute() {
        true => output_path,
        false => working_dir()?.join(output_path),
    };
    debug!("Resolved output path: {}", absolute_path.display());
    Ok(absolute_path)
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::ffi::OsString;
    use std::fs;
    use std::sync::{Mutex, MutexGuard};
    use tracing_test::traced_test;

    /// Serializes tests that mutate process-global state (current directory, environment
    /// variables).
    ///
    /// Tests run on parallel threads. Without this lock, two tests that each save,
    /// change and restore the current directory can interleave. That leaves the process
    /// in the wrong directory for every later test.
    static PROCESS_STATE_LOCK: Mutex<()> = Mutex::new(());

    /// Acquires `PROCESS_STATE_LOCK`, ignoring poisoning.
    ///
    /// A test that panics while holding the lock poisons it. The guards below restore
    /// whatever they changed when dropped, so later tests can safely continue.
    fn lock_process_state() -> MutexGuard<'static, ()> {
        PROCESS_STATE_LOCK
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner())
    }

    /// Changes the process's current directory and switches back when dropped, including
    /// while unwinding from a failed assertion.
    struct CurrentDirGuard {
        original: PathBuf,
    }

    impl CurrentDirGuard {
        fn change_to(dir: &Path) -> Self {
            let original = std::env::current_dir().expect("Failed to get current directory");
            std::env::set_current_dir(dir).expect("Failed to change current directory");
            Self { original }
        }
    }

    impl Drop for CurrentDirGuard {
        fn drop(&mut self) {
            // Best effort: panicking inside `drop` while already unwinding would abort.
            let _ = std::env::set_current_dir(&self.original);
        }
    }

    /// Sets an environment variable and, when dropped, restores its previous value (or
    /// removes it if it was previously unset).
    struct EnvVarGuard {
        key: &'static str,
        original: Option<OsString>,
    }

    // `set_var`/`remove_var` require `unsafe` in this crate (the original test wrapped them
    // in `unsafe` with `#[allow(unsafe_code)]`), because they race with concurrent
    // environment access from other threads.
    #[allow(unsafe_code)]
    impl EnvVarGuard {
        fn set(key: &'static str, value: &Path) -> Self {
            let original = std::env::var_os(key);
            // SAFETY: callers hold `PROCESS_STATE_LOCK`, which serializes environment
            // mutation among the tests in this module. NOTE(review): other tests may still
            // read the environment concurrently; confirm none do so through C code.
            unsafe {
                std::env::set_var(key, value);
            }
            Self { key, original }
        }
    }

    #[allow(unsafe_code)]
    impl Drop for EnvVarGuard {
        fn drop(&mut self) {
            // SAFETY: same invariant as `EnvVarGuard::set`. Callers drop this guard before
            // releasing `PROCESS_STATE_LOCK`.
            unsafe {
                match &self.original {
                    Some(value) => std::env::set_var(self.key, value),
                    None => std::env::remove_var(self.key),
                }
            }
        }
    }

    /// Creates a fresh, uniquely named directory under `<workspace>/target/tmp/test-workspaces`.
    pub fn create_test_workspace(test_name: &str) -> std::path::PathBuf {
        let workspace_root = std::env::current_dir()
            .expect("Failed to get current directory")
            .ancestors()
            .find(|p| p.join("Cargo.toml").exists())
            .expect("Could not find workspace root")
            .to_path_buf();
        let workspace_dir = workspace_root
            .join("target")
            .join("tmp")
            .join("test-workspaces")
            .join(test_name)
            .join(uuid::Uuid::new_v4().to_string());
        fs::create_dir_all(&workspace_dir).unwrap();
        workspace_dir.canonicalize().unwrap_or(workspace_dir)
    }

    #[test]
    fn test_template_dir_validation() {
        let temp_dir = create_test_workspace("test_template_dir_validation");
        let server_template_dir = temp_dir.join("templates/mcp/server/rust_axum");
        fs::create_dir_all(&server_template_dir).unwrap();
        let server_template = TemplateDir::discover_with_protocol(
            Protocol::Mcp,
            ServerTemplateKind::RustAxum,
            Some(&server_template_dir),
        );
        assert!(server_template.is_ok());
        assert_eq!(
            server_template.unwrap().template_path(),
            server_template_dir.as_path()
        );
        let result = TemplateDir::discover_with_protocol(
            Protocol::Mcp,
            ServerTemplateKind::RustAxum,
            Some(Path::new("/nonexistent")),
        );
        assert!(result.is_err());
    }

    #[test]
    #[traced_test]
    fn test_debug_logging_output() {
        let temp_dir = create_test_workspace("test_debug_logging_output");
        let server_template_dir = temp_dir.join("templates/mcp/server/rust_axum");
        fs::create_dir_all(&server_template_dir).unwrap();
        let _result = TemplateDir::discover_with_protocol(
            Protocol::Mcp,
            ServerTemplateKind::RustAxum,
            Some(&temp_dir),
        );
        assert!(
            logs_contain("Auto-discovering template directory")
                || logs_contain("Resolved template path")
        );
    }

    #[test]
    fn test_find_template_base_dir_uses_absolute_paths() {
        let temp_workspace =
            create_test_workspace("test_find_template_base_dir_uses_absolute_paths");
        let templates_dir = temp_workspace.join("templates");
        let mcp_dir = templates_dir.join("mcp");
        let server_dir = mcp_dir.join("server");
        let client_dir = mcp_dir.join("client");
        fs::create_dir_all(&server_dir).unwrap();
        fs::create_dir_all(&client_dir).unwrap();
        let mock_config =
            MockTemplateConfigReader::new(Some(temp_workspace.to_string_lossy().to_string()));
        let result = TemplateDir::find_template_base_dir_with_config(&mock_config);
        assert!(result.is_some());
        let resolved_path = result.unwrap();
        assert!(resolved_path.is_absolute());
        assert!(resolved_path.exists());
    }

    #[test]
    fn test_find_template_base_dir_executable_location() {
        let temp_workspace =
            create_test_workspace("test_find_template_base_dir_executable_location");
        let bin_dir = temp_workspace.join("bin");
        let templates_dir = temp_workspace.join("templates");
        let mcp_dir = templates_dir.join("mcp");
        let server_dir = mcp_dir.join("server");
        let client_dir = mcp_dir.join("client");
        fs::create_dir_all(&bin_dir).unwrap();
        fs::create_dir_all(&server_dir).unwrap();
        fs::create_dir_all(&client_dir).unwrap();
        let mock_config =
            MockTemplateConfigReader::new(Some(temp_workspace.to_string_lossy().to_string()));
        let result = TemplateDir::find_template_base_dir_with_config(&mock_config);
        assert!(result.is_some());
        let discovered_path = result.unwrap();
        assert!(discovered_path.exists());
    }

    #[test]
    fn test_security_template_dir_validation() {
        let malicious_paths = vec![
            "/etc/passwd",
            "/usr/bin/evil",
            "/root/.ssh/id_rsa",
            "../../../etc/passwd",
            "/usr/local/../../etc/passwd",
        ];
        #[cfg(windows)]
        let windows_paths = vec!["C:\\Windows\\System32", "C:\\Program Files\\evil"];
        #[cfg(windows)]
        let all_paths = [malicious_paths, windows_paths].concat();
        #[cfg(not(windows))]
        let all_paths = malicious_paths;
        for path in all_paths {
            let mock_config = MockTemplateConfigReader::new(Some(path.to_string()));
            let result = TemplateDir::find_template_base_dir_with_config(&mock_config);
            assert!(
                result.is_none(),
                "Malicious path should be rejected: {}",
                path
            );
        }
    }

    #[test]
    fn test_output_directory_traversal_protection() {
        let temp_dir = create_test_workspace("test_output_directory_traversal_protection");
        let server_template_dir = temp_dir.join("templates/mcp/server/rust_axum");
        fs::create_dir_all(&server_template_dir).unwrap();
        // NOTE(review): `resolve_output_dir` does not currently confine the project name to
        // the output directory; an absolute or `..`-laden name escapes it. This test only
        // pins that such names still resolve to absolute paths without error. Tighten it
        // if containment is ever enforced.
        let malicious_output_paths = ["../../../etc", "/etc", "../../sensitive"];
        for path in malicious_output_paths {
            let resolved = resolve_output_dir(path, Some(&temp_dir))
                .expect("resolve_output_dir should not fail with an absolute custom base");
            assert!(
                resolved.is_absolute(),
                "Resolved output path should be absolute: {}",
                resolved.display()
            );
        }
    }

    #[test]
    fn test_environment_variable_template_discovery() {
        let temp_dir = create_test_workspace("env_var_template_discovery");
        let templates_dir = temp_dir.join("templates");
        let mcp_dir = templates_dir.join("mcp");
        let server_dir = mcp_dir.join("server");
        let client_dir = mcp_dir.join("client");
        fs::create_dir_all(&server_dir).unwrap();
        fs::create_dir_all(&client_dir).unwrap();
        let env_config = EnvTemplateConfigReader;

        // Held for the whole test; `env_guard` is declared after it, so it is dropped
        // (restoring the variable) before the lock is released, even on panic.
        let _lock = lock_process_state();
        let before = env_config.get_template_dir();
        let env_guard = EnvVarGuard::set("AGENTERRA_TEMPLATE_DIR", &temp_dir);

        let with_env_result = env_config.get_template_dir();
        assert!(with_env_result.is_some());
        assert_eq!(with_env_result.unwrap(), temp_dir.to_string_lossy());
        let discovery_result = TemplateDir::find_template_base_dir();
        assert!(discovery_result.is_some());

        drop(env_guard);
        assert_eq!(
            env_config.get_template_dir(),
            before,
            "AGENTERRA_TEMPLATE_DIR should be restored to its original value"
        );
    }

    #[test]
    fn test_concurrent_template_discovery() {
        use std::sync::{Arc, Barrier};
        use std::thread;
        let temp_dir = create_test_workspace("test_concurrent_template_discovery");
        let server_template_dir = temp_dir.join("templates/mcp/server/rust_axum");
        let client_template_dir = temp_dir.join("templates/mcp/client/rust_reqwest");
        fs::create_dir_all(&server_template_dir).unwrap();
        fs::create_dir_all(&client_template_dir).unwrap();
        const NUM_THREADS: usize = 10;
        let barrier = Arc::new(Barrier::new(NUM_THREADS));
        let mut handles = vec![];
        for i in 0..NUM_THREADS {
            let barrier_clone = Arc::clone(&barrier);
            let temp_dir_path = temp_dir.to_string_lossy().to_string();
            let handle = thread::spawn(move || {
                barrier_clone.wait();
                let mock_config = MockTemplateConfigReader::new(Some(temp_dir_path));
                let result = TemplateDir::find_template_base_dir_with_config(&mock_config);
                assert!(result.is_some(), "Thread {} failed to discover template", i);
                let base_dir = result.unwrap();
                assert!(base_dir.exists());
            });
            handles.push(handle);
        }
        for handle in handles {
            handle.join().expect("Thread panicked");
        }
    }

    #[test]
    fn test_discover_with_protocol() {
        use crate::core::protocol::Protocol;
        let temp_dir = create_test_workspace("test_discover_with_protocol");
        let server_template_dir = temp_dir.join("templates/mcp/server/rust_axum");
        fs::create_dir_all(&server_template_dir).unwrap();
        let result = TemplateDir::discover_with_protocol(
            Protocol::Mcp,
            ServerTemplateKind::RustAxum,
            Some(&server_template_dir),
        );
        assert!(result.is_ok());
        assert_eq!(
            result.unwrap().template_path(),
            server_template_dir.as_path()
        );
    }

    #[test]
    fn test_path_construction_with_different_protocols() {
        use crate::core::protocol::Protocol;
        let temp_dir = create_test_workspace("test_path_construction_with_different_protocols");
        let mcp_server_dir = temp_dir.join("templates/mcp/server/rust_axum");
        fs::create_dir_all(&mcp_server_dir).unwrap();
        let result = TemplateDir::discover_with_protocol(
            Protocol::Mcp,
            ServerTemplateKind::RustAxum,
            Some(&mcp_server_dir),
        );
        assert!(result.is_ok());
        let template_dir = result.unwrap();
        let path_str = template_dir.template_path().to_string_lossy();
        assert!(
            path_str.contains("mcp"),
            "Path should contain protocol segment: {}",
            path_str
        );
        assert!(
            path_str.contains("server"),
            "Path should contain role: {}",
            path_str
        );
        assert!(
            path_str.contains("rust_axum"),
            "Path should contain template kind: {}",
            path_str
        );
    }

    #[test]
    fn test_backward_compatibility_with_discover() {
        let temp_dir = create_test_workspace("test_backward_compatibility_with_discover");
        let server_template_dir = temp_dir.join("templates/mcp/server/rust_axum");
        fs::create_dir_all(&server_template_dir).unwrap();
        let result = TemplateDir::discover_with_protocol(
            Protocol::Mcp,
            ServerTemplateKind::RustAxum,
            Some(&server_template_dir),
        );
        assert!(result.is_ok());
        assert_eq!(
            result.unwrap().template_path(),
            server_template_dir.as_path()
        );
    }

    #[test]
    fn test_resolve_output_dir_with_custom_path() {
        let temp_dir = create_test_workspace("test_resolve_output_dir_with_custom_path");
        let custom_output = temp_dir.join("custom_output");
        let result = resolve_output_dir("test_project", Some(&custom_output));
        assert!(result.is_ok());
        let resolved_path = result.unwrap();
        assert!(resolved_path.is_absolute());
        assert!(resolved_path.ends_with("custom_output/test_project"));
    }

    #[test]
    fn test_resolve_output_dir_with_default() {
        let temp_dir = create_test_workspace("test_resolve_output_dir_with_default");
        let result = {
            // Drop order (reverse of declaration): cwd is restored before the lock is freed.
            let _lock = lock_process_state();
            let _cwd = CurrentDirGuard::change_to(&temp_dir);
            resolve_output_dir("test_project", None)
        };
        assert!(result.is_ok());
        let resolved_path = result.unwrap();
        assert!(resolved_path.is_absolute());
        assert!(resolved_path.to_string_lossy().ends_with("test_project"));
    }

    #[test]
    fn test_resolve_output_dir_fallback_behavior() {
        let temp_dir = create_test_workspace("test_resolve_output_dir_fallback_behavior");
        let result = {
            // Drop order (reverse of declaration): cwd is restored before the lock is freed.
            let _lock = lock_process_state();
            let _cwd = CurrentDirGuard::change_to(&temp_dir);
            resolve_output_dir("fallback_project", None)
        };
        assert!(result.is_ok());
        let resolved_path = result.unwrap();
        assert!(resolved_path.is_absolute());
        assert!(
            resolved_path
                .to_string_lossy()
                .ends_with("fallback_project")
        );
    }
}