1pub mod generator;
2
3pub use generator::{generate_handler_registration, generate_param_struct, CodegenError};
4
5use pforge_config::ForgeConfig;
6use std::path::Path;
7
/// Convenience alias for results produced by this crate; the error type is
/// always [`CodegenError`].
pub type Result<T> = std::result::Result<T, CodegenError>;
9
10pub fn generate_all(config: &ForgeConfig) -> Result<String> {
12 let mut output = String::new();
13
14 output.push_str("// Auto-generated by pforge\n");
16 output.push_str("// DO NOT EDIT\n\n");
17 output.push_str("use pforge_runtime::*;\n");
18 output.push_str("use serde::{Deserialize, Serialize};\n");
19 output.push_str("use schemars::JsonSchema;\n\n");
20
21 for tool in &config.tools {
23 if let pforge_config::ToolDef::Native { name, params, .. } = tool {
24 output.push_str(&generate_param_struct(name, params)?);
25 output.push_str("\n\n");
26 }
27 }
28
29 output.push_str(&generate_handler_registration(config)?);
31
32 Ok(output)
33}
34
35pub fn write_generated_code(config: &ForgeConfig, output_path: &Path) -> Result<()> {
37 let code = generate_all(config)?;
38 std::fs::write(output_path, code)
39 .map_err(|e| CodegenError::IoError(output_path.to_path_buf(), e))?;
40 Ok(())
41}
42
#[cfg(test)]
mod tests {
    use super::*;
    use pforge_config::*;
    use std::collections::HashMap;

    /// Deletes the wrapped file when dropped, so a failing assertion cannot
    /// leave a stale artifact behind in the temp directory.
    struct TempFileGuard(std::path::PathBuf);

    impl Drop for TempFileGuard {
        fn drop(&mut self) {
            // Best-effort cleanup: the file may never have been created.
            std::fs::remove_file(&self.0).ok();
        }
    }

    /// Builds the server metadata shared by every test config; only the
    /// server name varies between tests.
    fn test_metadata(name: &str) -> ForgeMetadata {
        ForgeMetadata {
            name: name.to_string(),
            version: "1.0.0".to_string(),
            transport: TransportType::Stdio,
            optimization: OptimizationLevel::Debug,
        }
    }

    /// Builds a config with a single native tool (`test_tool`) that takes
    /// one string parameter named `input`.
    fn create_test_config() -> ForgeConfig {
        ForgeConfig {
            forge: test_metadata("test_server"),
            tools: vec![ToolDef::Native {
                name: "test_tool".to_string(),
                description: "Test tool".to_string(),
                handler: HandlerRef {
                    path: "handlers::test_handler".to_string(),
                    inline: None,
                },
                params: ParamSchema {
                    fields: {
                        let mut map = HashMap::new();
                        map.insert("input".to_string(), ParamType::Simple(SimpleType::String));
                        map
                    },
                },
                timeout_ms: None,
            }],
            resources: vec![],
            prompts: vec![],
            state: None,
        }
    }

    #[test]
    fn test_generate_all() {
        let config = create_test_config();
        let code = generate_all(&config).expect("generation should succeed for a valid config");

        // Banner.
        assert!(code.contains("// Auto-generated by pforge"));
        assert!(code.contains("// DO NOT EDIT"));

        // Fixed imports.
        assert!(code.contains("use pforge_runtime::*"));
        assert!(code.contains("use serde::{Deserialize, Serialize}"));
        assert!(code.contains("use schemars::JsonSchema"));

        // Parameter struct for the native tool.
        assert!(code.contains("pub struct TestToolParams"));

        // Handler registration.
        assert!(code.contains("pub fn register_handlers"));
    }

    #[test]
    fn test_generate_all_empty_tools() {
        let config = ForgeConfig {
            forge: test_metadata("empty"),
            tools: vec![],
            resources: vec![],
            prompts: vec![],
            state: None,
        };

        let code = generate_all(&config).expect("generation should succeed without tools");
        assert!(code.contains("pub fn register_handlers"));
    }

    #[test]
    fn test_write_generated_code() {
        let config = create_test_config();

        // Qualify the file name with the process id so concurrent test runs
        // sharing the same temp directory cannot clobber each other's output.
        let file_name = format!("pforge_codegen_test_{}.rs", std::process::id());
        let guard = TempFileGuard(std::env::temp_dir().join(file_name));
        let output_path = guard.0.as_path();

        write_generated_code(&config, output_path)
            .expect("writing into the temp directory should succeed");

        assert!(output_path.exists());

        let content =
            std::fs::read_to_string(output_path).expect("generated file should be readable");
        assert!(content.contains("pub struct TestToolParams"));
        // `guard` removes the file here, and also on an assertion failure above.
    }

    #[test]
    fn test_write_generated_code_invalid_path() {
        let config = create_test_config();
        let invalid_path = Path::new("/nonexistent/directory/test.rs");

        match write_generated_code(&config, invalid_path) {
            // The error must identify the path that could not be written.
            Err(CodegenError::IoError(path, _)) => assert_eq!(path.as_path(), invalid_path),
            other => panic!("expected CodegenError::IoError, got {:?}", other),
        }
    }
}