use super::metadata::{ParameterType, PromptMetadata, PromptTemplate};
use anyhow::{Context, Result};
use gray_matter::engine::YAML;
use gray_matter::{Matter, Pod};
use kodegen_mcp_schema::prompt::TemplateParamValue;
use minijinja::Environment;
use std::collections::HashMap;
use std::sync::{LazyLock, OnceLock};
use tokio::time::{timeout, Duration};
/// Shared, always-empty parameter map that `build_context` borrows when the caller passes no
/// parameters, so that path needs no per-call allocation.
static EMPTY_PARAMS: LazyLock<HashMap<String, TemplateParamValue>> = LazyLock::new(Default::default);
/// Read a `usize` setting from environment variable `var`, falling back to `default`.
///
/// Surrounding whitespace in the value is ignored. An unset variable, a non-Unicode value, or
/// a value that does not parse as `usize` all yield `default`, so a bad override never breaks
/// rendering.
fn read_usize_env(var: &str, default: usize) -> usize {
    std::env::var(var)
        .ok()
        .and_then(|raw| raw.trim().parse().ok())
        .unwrap_or(default)
}

/// Maximum size, in bytes, of a single parameter value.
///
/// Configured by `KODEGEN_MAX_PARAM_SIZE` (default 1 000 000). The value is read once per
/// process and cached; later changes to the environment are not observed.
fn get_max_param_size() -> usize {
    static MAX_SIZE: OnceLock<usize> = OnceLock::new();
    *MAX_SIZE.get_or_init(|| read_usize_env("KODEGEN_MAX_PARAM_SIZE", 1_000_000))
}

/// Maximum number of parameters accepted in one render call.
///
/// Configured by `KODEGEN_MAX_PARAM_COUNT` (default 100). The value is read once per process
/// and cached.
fn get_max_param_count() -> usize {
    static MAX_COUNT: OnceLock<usize> = OnceLock::new();
    *MAX_COUNT.get_or_init(|| read_usize_env("KODEGEN_MAX_PARAM_COUNT", 100))
}

/// Maximum combined size, in bytes, of all parameter values in one render call.
///
/// Configured by `KODEGEN_MAX_TOTAL_PARAMS_SIZE` (default 10 000 000). The value is read once
/// per process and cached.
fn get_max_total_params_size() -> usize {
    static MAX_TOTAL: OnceLock<usize> = OnceLock::new();
    *MAX_TOTAL.get_or_init(|| read_usize_env("KODEGEN_MAX_TOTAL_PARAMS_SIZE", 10_000_000))
}
pub fn parse_template(filename: &str, file_content: &str) -> Result<PromptTemplate> {
let matter = Matter::<YAML>::new();
let parsed: gray_matter::ParsedEntity<Pod> = matter
.parse(file_content)
.map_err(|e| anyhow::anyhow!("Failed to parse frontmatter: {e}"))?;
let metadata: PromptMetadata = parsed
.data
.ok_or_else(|| anyhow::anyhow!("No frontmatter found in template"))?
.deserialize()
.context("Failed to parse YAML frontmatter")?;
validate_metadata(&metadata)?;
let content = parsed.content;
Ok(PromptTemplate {
filename: filename.to_string(),
metadata,
content,
})
}
/// Check that parsed frontmatter contains everything a usable prompt template needs.
///
/// # Errors
/// Fails when `title`, `description` or `author` is empty or whitespace-only, when no
/// category is listed, when two parameters share a name, or when any single parameter
/// definition is rejected by `validate_parameter_definition`.
fn validate_metadata(metadata: &PromptMetadata) -> Result<()> {
    // Whitespace-only values carry no information, so treat them exactly like empty ones.
    if metadata.title.trim().is_empty() {
        anyhow::bail!("Title cannot be empty");
    }
    if metadata.description.trim().is_empty() {
        anyhow::bail!("Description cannot be empty");
    }
    if metadata.categories.is_empty() {
        anyhow::bail!("At least one category is required");
    }
    if metadata.author.trim().is_empty() {
        anyhow::bail!("Author cannot be empty");
    }

    // Duplicate names make a template ambiguous: defaults would come from whichever
    // definition is seen first while supplied values are type-checked against every one,
    // so a value could never satisfy two differently-typed duplicates.
    let mut seen_names = std::collections::HashSet::with_capacity(metadata.parameters.len());
    for param in &metadata.parameters {
        if !seen_names.insert(param.name.as_str()) {
            anyhow::bail!("Duplicate parameter name '{}' in frontmatter", param.name);
        }
        validate_parameter_definition(param)
            .with_context(|| format!("Invalid parameter definition: '{}'", param.name))?;
    }
    Ok(())
}
/// Sanity-check one parameter declaration from a template's frontmatter.
///
/// A declared default must have the parameter's declared type, and a parameter cannot be
/// both required and defaulted.
///
/// # Errors
/// If the default has the wrong type, returns the type-mismatch error wrapped with an
/// explanatory message. Otherwise, returns an error if the parameter is `required` yet also
/// has a default.
fn validate_parameter_definition(param: &super::metadata::ParameterDefinition) -> Result<()> {
    let Some(default_value) = param.default.as_ref() else {
        // Without a default there is nothing that can contradict the declaration.
        return Ok(());
    };

    validate_parameter_type(param, default_value).with_context(|| {
        let found = match default_value {
            TemplateParamValue::String(_) => "string",
            TemplateParamValue::Number(_) => "number",
            TemplateParamValue::Bool(_) => "boolean",
            TemplateParamValue::StringArray(_) => "array",
        };
        format!(
            "Parameter '{}' has default value type mismatch. \
             Declared as {:?} but default value is {}. \
             Default: {:?}\n\
             \n\
             Fix the template's YAML frontmatter to use the correct type for the default value.",
            param.name, param.param_type, found, default_value
        )
    })?;

    if param.required {
        anyhow::bail!(
            "Parameter '{}' is marked as required but has a default value. \
             This is contradictory - remove 'required: true' or remove the default.",
            param.name
        );
    }
    Ok(())
}
pub async fn render_template(
template: &PromptTemplate,
parameters: Option<&HashMap<String, TemplateParamValue>>,
) -> Result<String> {
let template_content = template.content.clone();
let template_filename = template.filename.clone();
let ctx = build_context(template, parameters)?;
let render_task = tokio::task::spawn_blocking(move || {
let mut env = Environment::new();
env.set_auto_escape_callback(|_| minijinja::AutoEscape::None);
env.add_template(&template_filename, &template_content)?;
let tmpl = env.get_template(&template_filename)?;
tmpl.render(ctx)
});
match timeout(Duration::from_secs(5), render_task).await {
Ok(Ok(Ok(rendered))) => Ok(rendered),
Ok(Ok(Err(e))) => Err(e.into()),
Ok(Err(e)) => Err(anyhow::anyhow!("Render task panicked: {e}")),
Err(_) => Err(anyhow::anyhow!(
"Template rendering timed out after 5 seconds. \
Template may contain infinite loops, deeply nested constructs, \
or expensive operations. Simplify the template and try again."
)),
}
}
/// Build the minijinja rendering context for `template`.
///
/// The caller's parameters are size-checked, then validated against the template's
/// declarations. Declared defaults are merged in for missing values, and the filtered `env`
/// list is added last. A `None` parameter map is treated as empty.
///
/// # Errors
/// Propagates the failure from `validate_parameter_sizes` or `validate_parameters`.
fn build_context(
    template: &PromptTemplate,
    parameters: Option<&HashMap<String, TemplateParamValue>>,
) -> Result<minijinja::Value> {
    let supplied = match parameters {
        Some(map) => map,
        None => &*EMPTY_PARAMS,
    };

    validate_parameter_sizes(supplied)?;
    validate_parameters(template, supplied)?;

    let mut context = apply_defaults(template, supplied);
    add_env_vars(&mut context);
    Ok(minijinja::Value::from_serialize(&context))
}
/// Test an environment variable name against a simple glob-like `pattern` (case-sensitive).
///
/// Supported forms:
/// - `*` matches every name;
/// - `*TEXT*` matches names containing `TEXT`;
/// - `*TEXT` matches names ending with `TEXT`;
/// - `TEXT*` matches names starting with `TEXT`;
/// - any other pattern must equal the name exactly.
///
/// A `*` in the middle of a pattern is treated literally.
fn matches_env_pattern(var_name: &str, pattern: &str) -> bool {
    // A lone "*" needs its own branch: stripping it from both ends below would fail.
    if pattern == "*" {
        return true;
    }
    match (pattern.strip_prefix('*'), pattern.strip_suffix('*')) {
        (Some(after_star), Some(_)) => after_star
            .strip_suffix('*')
            .is_some_and(|needle| var_name.contains(needle)),
        (Some(suffix), None) => var_name.ends_with(suffix),
        (None, Some(prefix)) => var_name.starts_with(prefix),
        (None, None) => var_name == pattern,
    }
}
/// Load the allowlist of environment-variable patterns that templates may see.
///
/// Patterns come from `KODEGEN_ALLOWED_ENV_VARS`, separated by `;` on Windows and `:`
/// elsewhere. Each entry is trimmed and blank entries are dropped. A built-in list of
/// user/shell basics is used when the variable is unset, empty, or not valid Unicode.
fn load_allowed_env_vars_from_env() -> Vec<String> {
    // Fallback allowlist covering common Unix and Windows user/shell variables.
    const DEFAULT_ALLOWED: &[&str] = &[
        "USER",
        "HOME",
        "SHELL",
        "PWD",
        "EDITOR",
        "TERM",
        "USERNAME",
        "USERPROFILE",
        "HOMEDRIVE",
        "HOMEPATH",
    ];

    let separator = if cfg!(windows) { ';' } else { ':' };
    std::env::var("KODEGEN_ALLOWED_ENV_VARS")
        .ok()
        .filter(|raw| !raw.is_empty())
        .map_or_else(
            || DEFAULT_ALLOWED.iter().map(|name| (*name).to_owned()).collect(),
            |raw| {
                raw.split(separator)
                    .map(str::trim)
                    .filter(|entry| !entry.is_empty())
                    .map(String::from)
                    .collect()
            },
        )
}
/// Load the blocklist of environment-variable patterns that must never reach templates.
///
/// When `KODEGEN_BLOCKED_ENV_VARS` is set, its entries replace (not extend) the built-in
/// secret-looking patterns. Entries are separated by `;` on Windows and `:` elsewhere,
/// trimmed, and blank entries are dropped. Setting the variable to an empty string
/// therefore disables blocking entirely.
///
/// The built-in list is used when the variable is unset or not valid Unicode.
fn load_blocked_env_vars_from_env() -> Vec<String> {
    // Fallback blocklist: anything that looks like a secret, plus a few well-known names.
    const DEFAULT_BLOCKED: &[&str] = &[
        "*_SECRET",
        "*SECRET*",
        "*_PASSWORD",
        "*PASSWORD*",
        "*_TOKEN",
        "*TOKEN*",
        "*_KEY",
        "*KEY*",
        "*_CREDENTIAL",
        "*CREDENTIAL*",
        "*_AUTH",
        "*AUTH*",
        "AWS_SECRET_ACCESS_KEY",
        "GITHUB_TOKEN",
        "DATABASE_PASSWORD",
    ];

    let separator = if cfg!(windows) { ';' } else { ':' };
    let Ok(raw) = std::env::var("KODEGEN_BLOCKED_ENV_VARS") else {
        return DEFAULT_BLOCKED
            .iter()
            .map(|pattern| (*pattern).to_owned())
            .collect();
    };

    // An empty value splits into a single blank entry, which the filter removes,
    // yielding an empty blocklist.
    raw.split(separator)
        .map(str::trim)
        .filter(|entry| !entry.is_empty())
        .map(String::from)
        .collect()
}
/// Insert an `env` entry into `params` listing the environment variables templates may see.
///
/// Each exposed variable is formatted as `NAME=value`. A variable is exposed only if it
/// matches at least one allowlist pattern and no blocklist pattern (the blocklist wins).
/// Any existing `env` entry in `params` is replaced.
///
/// Blocklist matching is ASCII-case-insensitive: names are case-insensitive on Windows and
/// may use any case elsewhere, so `github_token` or `Api_Key` must be caught by `*TOKEN*` /
/// `*KEY*` just like their upper-case forms. Allowlist matching stays case-sensitive, so the
/// filter fails closed.
///
/// Variables whose name or value is not valid Unicode are skipped; `std::env::vars` would
/// panic on them instead.
fn add_env_vars(params: &mut HashMap<String, TemplateParamValue>) {
    let allowed_patterns = load_allowed_env_vars_from_env();
    // Normalise once so each variable only needs its own name upper-cased.
    let blocked_patterns: Vec<String> = load_blocked_env_vars_from_env()
        .into_iter()
        .map(|pattern| pattern.to_ascii_uppercase())
        .collect();

    let safe_env_vars: Vec<String> = std::env::vars_os()
        .filter_map(|(key, value)| Some((key.into_string().ok()?, value.into_string().ok()?)))
        .filter(|(key, _)| {
            let upper_key = key.to_ascii_uppercase();
            let is_blocked = blocked_patterns
                .iter()
                .any(|pattern| matches_env_pattern(&upper_key, pattern));
            !is_blocked
                && allowed_patterns
                    .iter()
                    .any(|pattern| matches_env_pattern(key, pattern))
        })
        .map(|(key, value)| format!("{key}={value}"))
        .collect();

    params.insert("env".to_string(), TemplateParamValue::StringArray(safe_env_vars));
}
/// Approximate cost of a parameter value, in bytes, as used by the parameter size limits.
///
/// - Strings count their UTF-8 byte length.
/// - Numbers count 8 and booleans count 1.
/// - Each array element counts its byte length plus one.
///
/// Charging a byte per element bounds the element count as well. Without it, an array of
/// millions of empty strings measured 0 bytes and slipped past every size check.
fn param_value_size(value: &TemplateParamValue) -> usize {
    match value {
        TemplateParamValue::String(s) => s.len(),
        TemplateParamValue::Number(_) => 8,
        TemplateParamValue::Bool(_) => 1,
        TemplateParamValue::StringArray(items) => items.iter().map(|item| item.len() + 1).sum(),
    }
}
/// Enforce the configured limits on caller-supplied parameters.
///
/// Three limits apply, each configurable through the environment:
/// - the number of parameters (`KODEGEN_MAX_PARAM_COUNT`);
/// - the size of each value as measured by `param_value_size` (`KODEGEN_MAX_PARAM_SIZE`);
/// - the combined size of all values (`KODEGEN_MAX_TOTAL_PARAMS_SIZE`).
///
/// # Errors
/// Returns an error describing the first limit exceeded. Messages report the limit actually
/// in force; earlier versions hard-coded "1 MB" / "10 MB" even when a limit was overridden.
fn validate_parameter_sizes(params: &HashMap<String, TemplateParamValue>) -> Result<()> {
    let max_param_size = get_max_param_size();
    let max_param_count = get_max_param_count();
    let max_total_size = get_max_total_params_size();

    if params.len() > max_param_count {
        anyhow::bail!(
            "Too many parameters: {} (max {})\n\
             Consider: Reducing number of parameters or setting KODEGEN_MAX_PARAM_COUNT",
            params.len(),
            max_param_count
        );
    }

    let mut total_size: usize = 0;
    for (name, value) in params {
        let param_size = param_value_size(value);
        if param_size > max_param_size {
            anyhow::bail!(
                "Parameter '{name}' is too large: {param_size} bytes (max {max_param_size} bytes)\n\
                 \n\
                 Consider:\n\
                 - Splitting data into smaller parameters\n\
                 - Using file references instead of inline data\n\
                 - Setting KODEGEN_MAX_PARAM_SIZE environment variable if this is legitimate"
            );
        }
        // Saturate rather than wrap so a huge total can never appear small.
        total_size = total_size.saturating_add(param_size);
    }

    if total_size > max_total_size {
        anyhow::bail!(
            "Total parameter size too large: {total_size} bytes (max {max_total_size} bytes)\n\
             \n\
             Consider:\n\
             - Reducing parameter sizes\n\
             - Removing unnecessary parameters\n\
             - Setting KODEGEN_MAX_TOTAL_PARAMS_SIZE environment variable"
        );
    }
    Ok(())
}
/// Check caller-supplied `params` against the template's parameter declarations.
///
/// The check runs in two passes, so a missing required parameter is always reported before
/// any type mismatch. Parameters the template does not declare are not examined here.
///
/// # Errors
/// Fails on the first missing required parameter. Failing that, fails on the first declared
/// parameter whose supplied value has the wrong type.
fn validate_parameters(
    template: &PromptTemplate,
    params: &HashMap<String, TemplateParamValue>,
) -> Result<()> {
    let declared = &template.metadata.parameters;

    if let Some(missing) = declared
        .iter()
        .find(|def| def.required && !params.contains_key(&def.name))
    {
        anyhow::bail!(
            "Required parameter '{}' not provided. Description: {}",
            missing.name,
            missing.description
        );
    }

    declared
        .iter()
        .filter_map(|def| params.get(&def.name).map(|supplied| (def, supplied)))
        .try_for_each(|(def, supplied)| validate_parameter_type(def, supplied))
}
/// Check that `value` has the type declared by `param_def`.
///
/// Accepted pairings:
/// - `ParameterType::String` with `String`;
/// - `ParameterType::Number` with `Number`;
/// - `ParameterType::Boolean` with `Bool`;
/// - `ParameterType::Array` with `StringArray`.
///
/// # Errors
/// On a mismatch, returns an error naming the parameter, the expected type and the actual
/// type.
fn validate_parameter_type(
    param_def: &super::metadata::ParameterDefinition,
    value: &TemplateParamValue,
) -> Result<()> {
    let expected = &param_def.param_type;
    let type_matches = matches!(
        (expected, value),
        (ParameterType::String, TemplateParamValue::String(_))
            | (ParameterType::Number, TemplateParamValue::Number(_))
            | (ParameterType::Boolean, TemplateParamValue::Bool(_))
            | (ParameterType::Array, TemplateParamValue::StringArray(_))
    );
    if type_matches {
        return Ok(());
    }

    let actual = match value {
        TemplateParamValue::String(_) => "string",
        TemplateParamValue::Number(_) => "number",
        TemplateParamValue::Bool(_) => "boolean",
        TemplateParamValue::StringArray(_) => "array",
    };
    Err(anyhow::anyhow!(
        "Parameter '{}' has wrong type. Expected {:?}, got {}",
        param_def.name,
        expected,
        actual
    ))
}
fn apply_defaults(
template: &PromptTemplate,
params: &HashMap<String, TemplateParamValue>,
) -> HashMap<String, TemplateParamValue> {
let capacity = params.len() + template.metadata.parameters.len();
let mut result = HashMap::with_capacity(capacity);
for (key, value) in params {
result.insert(key.clone(), value.clone());
}
for param_def in &template.metadata.parameters {
if !result.contains_key(¶m_def.name)
&& let Some(default) = param_def.default.as_ref()
{
result.insert(param_def.name.clone(), default.clone());
}
}
result
}