use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::path::Path;
use thiserror::Error;
/// A code-generator definition as deserialized from a YAML or JSON document.
///
/// The document has three top-level sections: `generator` (metadata), `params`
/// (parameter declarations) and `template` (the code to emit and where to put it).
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
pub struct GeneratorTemplate {
    /// Identifying metadata (id, name, description, optional category).
    pub generator: GeneratorMeta,
    /// Declared parameters; an empty list when the `params` key is absent.
    #[serde(default)]
    pub params: Vec<ParamSpec>,
    /// The code template and its insertion target.
    pub template: TemplateSpec,
}
/// Descriptive metadata for a generator (the `generator` section of a template file).
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
pub struct GeneratorMeta {
    /// Identifier of the generator (e.g. `GEN001` in the test fixtures).
    pub id: String,
    /// Short machine-friendly name (e.g. `domain_struct`).
    pub name: String,
    /// Human-readable description.
    pub description: String,
    /// Optional grouping label; omitted from serialized output when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub category: Option<String>,
}
/// Declaration of a single template parameter, referenced as `{{name}}` in the template.
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
pub struct ParamSpec {
    /// Parameter name; placeholders use the form `{{name}}`.
    pub name: String,
    /// Human-readable description of the parameter.
    pub description: String,
    /// Whether a value must be supplied; `false` when omitted.
    /// A required parameter that also has a `default` never fails validation.
    #[serde(default)]
    pub required: bool,
    /// Value used when the caller does not supply one; omitted from output when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub default: Option<String>,
    /// Example value. Not read by rendering code in this file; presumably shown in
    /// help/listing output — TODO confirm with callers.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub example: Option<String>,
}
/// The code to generate and where it should be placed.
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
pub struct TemplateSpec {
    /// Template source containing `{{param}}` placeholders.
    pub code: String,
    /// Destination file path (may itself contain `{{param}}` placeholders);
    /// defaults to `src/lib.rs` when omitted.
    #[serde(default = "default_target_file")]
    pub target_file: String,
    /// Where to insert the rendered code; defaults to [`InsertPosition::Bottom`].
    #[serde(default)]
    pub position: InsertPosition,
}
/// Serde default for `TemplateSpec::target_file` when the field is omitted.
fn default_target_file() -> String {
    String::from("src/lib.rs")
}
/// Insertion point for rendered code within the target file.
///
/// Serialized in snake_case: `top`, `bottom`, `after`, `before`.
#[derive(Debug, Clone, Default, Serialize, Deserialize, JsonSchema)]
#[serde(rename_all = "snake_case")]
pub enum InsertPosition {
    /// Insert at the start of the file.
    Top,
    /// Insert at the end of the file (the default).
    #[default]
    Bottom,
    /// Insert after an anchor. NOTE(review): the string is presumably a marker or
    /// line to search for — confirm against the code that applies insertions.
    After(String),
    /// Insert before an anchor. NOTE(review): same anchor semantics as `After` — confirm.
    Before(String),
}
/// Errors produced while reading or parsing a generator template.
#[derive(Debug, Error)]
pub enum GeneratorLoadError {
    /// Reading the file or directory failed.
    #[error("IO error: {0}")]
    Io(#[from] std::io::Error),
    /// The content was not a valid YAML generator template.
    #[error("YAML parse error: {0}")]
    Yaml(#[from] serde_yaml::Error),
    /// The content was not a valid JSON generator template.
    #[error("JSON parse error: {0}")]
    Json(#[from] serde_json::Error),
}
/// Errors produced while rendering a template.
#[derive(Debug, Clone, Error)]
pub enum RenderError {
    /// Required parameters that had neither a supplied value nor a default,
    /// listed by name; the message joins them with `", "`.
    #[error("missing required parameters: {}", .0.join(", "))]
    MissingParams(Vec<String>),
}
impl GeneratorTemplate {
pub fn id(&self) -> &str {
&self.generator.id
}
pub fn name(&self) -> &str {
&self.generator.name
}
pub fn description(&self) -> &str {
&self.generator.description
}
pub fn category(&self) -> Option<&str> {
self.generator.category.as_deref()
}
pub fn is_param_required(&self, name: &str) -> bool {
self.params
.iter()
.find(|p| p.name == name)
.map(|p| p.required)
.unwrap_or(false)
}
pub fn get_param_default(&self, name: &str) -> Option<&str> {
self.params
.iter()
.find(|p| p.name == name)
.and_then(|p| p.default.as_deref())
}
pub fn validate_params(&self, params: &HashMap<String, String>) -> Result<(), Vec<String>> {
let missing: Vec<String> = self
.params
.iter()
.filter(|p| p.required && !params.contains_key(&p.name) && p.default.is_none())
.map(|p| p.name.clone())
.collect();
if missing.is_empty() {
Ok(())
} else {
Err(missing)
}
}
pub fn render(&self, params: &HashMap<String, String>) -> Result<String, RenderError> {
if let Err(missing) = self.validate_params(params) {
return Err(RenderError::MissingParams(missing));
}
let mut complete_params = HashMap::new();
for spec in &self.params {
if let Some(value) = params.get(&spec.name) {
complete_params.insert(spec.name.clone(), value.clone());
} else if let Some(default) = &spec.default {
complete_params.insert(spec.name.clone(), default.clone());
}
}
let mut result = self.template.code.clone();
for (key, value) in &complete_params {
let placeholder = format!("{{{{{}}}}}", key);
result = result.replace(&placeholder, value);
}
Ok(result)
}
pub fn render_target_file(&self, params: &HashMap<String, String>) -> String {
let mut result = self.template.target_file.clone();
for (key, value) in params {
let placeholder = format!("{{{{{}}}}}", key);
result = result.replace(&placeholder, value);
}
result
}
}
/// Loads [`GeneratorTemplate`]s from YAML or JSON files and strings.
pub struct GeneratorLoader;

impl GeneratorLoader {
    /// Reads and parses the template file at `path`.
    ///
    /// The format is chosen from the file extension (see [`Self::load_from_str`]).
    ///
    /// # Errors
    ///
    /// Returns [`GeneratorLoadError::Io`] if the file cannot be read, or a parse
    /// error if its content is not a valid template.
    pub fn load_file(path: impl AsRef<Path>) -> Result<GeneratorTemplate, GeneratorLoadError> {
        let path = path.as_ref();
        let content = std::fs::read_to_string(path)?;
        Self::load_from_str(&content, path)
    }

    /// Parses `content`, using `path`'s extension (case-insensitive) to pick the format.
    ///
    /// `.yaml`/`.yml` are parsed as YAML and `.json` as JSON. Any other extension
    /// tries YAML first, then JSON; if both fail, the JSON error is returned.
    ///
    /// # Errors
    ///
    /// Returns [`GeneratorLoadError::Yaml`] or [`GeneratorLoadError::Json`] when
    /// parsing fails.
    pub fn load_from_str(
        content: &str,
        path: impl AsRef<Path>,
    ) -> Result<GeneratorTemplate, GeneratorLoadError> {
        match Self::lowercase_extension(path.as_ref()).as_deref() {
            Some("yaml" | "yml") => Self::from_yaml(content),
            Some("json") => Self::from_json(content),
            _ => Self::from_yaml(content).or_else(|_| Self::from_json(content)),
        }
    }

    /// Parses a YAML template document.
    pub fn from_yaml(yaml: &str) -> Result<GeneratorTemplate, GeneratorLoadError> {
        Ok(serde_yaml::from_str(yaml)?)
    }

    /// Parses a JSON template document.
    pub fn from_json(json: &str) -> Result<GeneratorTemplate, GeneratorLoadError> {
        Ok(serde_json::from_str(json)?)
    }

    /// Loads every `.yaml`/`.yml`/`.json` file directly inside `dir`.
    ///
    /// A missing directory yields an empty list. Files that fail to load are
    /// reported on stderr and skipped. Templates are returned sorted by path,
    /// because `read_dir` yields entries in a platform-dependent order.
    ///
    /// # Errors
    ///
    /// Returns [`GeneratorLoadError::Io`] if the directory or one of its entries
    /// cannot be read.
    pub fn load_dir(dir: impl AsRef<Path>) -> Result<Vec<GeneratorTemplate>, GeneratorLoadError> {
        let dir = dir.as_ref();
        if !dir.exists() {
            return Ok(Vec::new());
        }
        let mut paths = std::fs::read_dir(dir)?
            .map(|entry| entry.map(|e| e.path()))
            .collect::<Result<Vec<_>, _>>()?;
        paths.sort();

        let mut templates = Vec::new();
        for path in paths {
            // Skip directories that happen to have a template-like extension.
            let is_template_file = path.is_file()
                && matches!(
                    Self::lowercase_extension(&path).as_deref(),
                    Some("yaml" | "yml" | "json")
                );
            if !is_template_file {
                continue;
            }
            match Self::load_file(&path) {
                Ok(template) => templates.push(template),
                Err(e) => {
                    eprintln!(
                        "Warning: Failed to load generator from {}: {}",
                        path.display(),
                        e
                    );
                }
            }
        }
        Ok(templates)
    }

    /// Returns `path`'s extension lowercased, or `None` if absent or not UTF-8.
    fn lowercase_extension(path: &Path) -> Option<String> {
        path.extension()
            .and_then(|e| e.to_str())
            .map(str::to_ascii_lowercase)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    // Nesting in this YAML is significant: `generator`/`template` fields must be
    // indented under their parent keys, and `code: |` needs an indented block body.
    const EXAMPLE_YAML: &str = r#"
generator:
  id: GEN001
  name: domain_struct
  description: Generate a domain struct
  category: domain
params:
  - name: name
    description: Struct name
    required: true
  - name: module
    description: Target module
    required: false
    default: src/lib.rs
template:
  code: |
    #[derive(Debug, Clone)]
    pub struct {{name}} {
        pub id: String,
    }
"#;

    #[test]
    fn test_parse_template() {
        let template = GeneratorLoader::from_yaml(EXAMPLE_YAML).unwrap();
        assert_eq!(template.id(), "GEN001");
        assert_eq!(template.name(), "domain_struct");
        assert_eq!(template.params.len(), 2);
        assert!(template.is_param_required("name"));
        assert!(!template.is_param_required("module"));
    }

    #[test]
    fn test_render_template() {
        let template = GeneratorLoader::from_yaml(EXAMPLE_YAML).unwrap();
        let mut params = HashMap::new();
        params.insert("name".to_string(), "Order".to_string());
        let rendered = template.render(&params).unwrap();
        assert!(rendered.contains("pub struct Order"));
    }

    #[test]
    fn test_missing_required_param() {
        let template = GeneratorLoader::from_yaml(EXAMPLE_YAML).unwrap();
        let params = HashMap::new();
        let result = template.render(&params);
        assert!(result.is_err());
    }

    #[test]
    fn test_validate_params() {
        let template = GeneratorLoader::from_yaml(EXAMPLE_YAML).unwrap();
        let params = HashMap::new();
        assert!(template.validate_params(&params).is_err());
        let mut params = HashMap::new();
        params.insert("name".to_string(), "Test".to_string());
        assert!(template.validate_params(&params).is_ok());
    }

    #[test]
    fn test_json_format() {
        let json = r#"{
    "generator": {
        "id": "GEN002",
        "name": "api_endpoint",
        "description": "Generate API endpoint"
    },
    "params": [
        {"name": "resource", "description": "Resource name", "required": true}
    ],
    "template": {
        "code": "pub fn get_{{resource}}() {}"
    }
}"#;
        let template = GeneratorLoader::from_json(json).unwrap();
        assert_eq!(template.id(), "GEN002");
    }
}