1pub mod generator;
2
3pub use generator::{generate_handler_registration, generate_param_struct, CodegenError};
4
5use pforge_config::ForgeConfig;
6use std::path::Path;
7
8pub type Result<T> = std::result::Result<T, CodegenError>;
9
10pub fn generate_all(config: &ForgeConfig) -> Result<String> {
12 let mut output = String::new();
13
14 output.push_str("// Auto-generated by pforge\n");
16 output.push_str("// DO NOT EDIT\n\n");
17 output.push_str("use pforge_runtime::*;\n");
18 output.push_str("use serde::{Deserialize, Serialize};\n");
19 output.push_str("use schemars::JsonSchema;\n\n");
20
21 for tool in &config.tools {
23 if let pforge_config::ToolDef::Native { name, params, .. } = tool {
24 output.push_str(&generate_param_struct(name, params)?);
25 output.push_str("\n\n");
26 }
27 }
28
29 output.push_str(&generate_handler_registration(config)?);
31
32 Ok(output)
33}
34
35pub fn write_generated_code(config: &ForgeConfig, output_path: &Path) -> Result<()> {
37 let code = generate_all(config)?;
38 std::fs::write(output_path, code)
39 .map_err(|e| CodegenError::IoError(output_path.to_path_buf(), e))?;
40 Ok(())
41}
42
#[cfg(test)]
mod tests {
    use super::*;
    use pforge_config::*;
    use rustc_hash::FxHashMap;

    /// Builds the `forge` metadata block shared by the test configurations;
    /// only the server name varies between them.
    fn test_metadata(name: &str) -> ForgeMetadata {
        ForgeMetadata {
            name: name.to_string(),
            version: "1.0.0".to_string(),
            transport: TransportType::Stdio,
            optimization: OptimizationLevel::Debug,
        }
    }

    /// Builds a configuration with a single native tool, `test_tool`, that
    /// takes one string parameter named `input`.
    fn create_test_config() -> ForgeConfig {
        let mut fields = FxHashMap::default();
        fields.insert("input".to_string(), ParamType::Simple(SimpleType::String));

        ForgeConfig {
            forge: test_metadata("test_server"),
            tools: vec![ToolDef::Native {
                name: "test_tool".to_string(),
                description: "Test tool".to_string(),
                handler: HandlerRef {
                    path: "handlers::test_handler".to_string(),
                    inline: None,
                },
                params: ParamSchema { fields },
                timeout_ms: None,
            }],
            resources: vec![],
            prompts: vec![],
            state: None,
        }
    }

    /// The generated code contains the header, the imports, the parameter
    /// struct for the native tool, and the registration function.
    #[test]
    fn test_generate_all() {
        let config = create_test_config();
        let code = generate_all(&config).expect("generate_all should succeed for a valid config");

        // Header comment.
        assert!(code.contains("// Auto-generated by pforge"));
        assert!(code.contains("// DO NOT EDIT"));

        // Imports required by the generated code.
        assert!(code.contains("use pforge_runtime::*"));
        assert!(code.contains("use serde::{Deserialize, Serialize}"));
        assert!(code.contains("use schemars::JsonSchema"));

        // Parameter struct for `test_tool`.
        assert!(code.contains("pub struct TestToolParams"));

        // Handler registration function.
        assert!(code.contains("pub fn register_handlers"));
    }

    /// A configuration without tools still yields the registration function.
    #[test]
    fn test_generate_all_empty_tools() {
        let config = ForgeConfig {
            forge: test_metadata("empty"),
            tools: vec![],
            resources: vec![],
            prompts: vec![],
            state: None,
        };

        let code = generate_all(&config).expect("generate_all should succeed with no tools");
        assert!(code.contains("pub fn register_handlers"));
    }

    /// Writing to a writable location creates a file containing the generated code.
    #[test]
    fn test_write_generated_code() {
        let config = create_test_config();

        // The OS temp directory is shared between processes, so include the
        // process id in the filename: concurrent test runs on the same machine
        // must not overwrite or delete each other's output.
        let file_name = format!("pforge_test_generated_{}.rs", std::process::id());
        let output_path = std::env::temp_dir().join(file_name);

        let result = write_generated_code(&config, &output_path);
        let content = std::fs::read_to_string(&output_path);

        // Clean up before asserting so a failing assertion does not leave the
        // file behind in the temp directory.
        std::fs::remove_file(&output_path).ok();

        assert!(result.is_ok());
        let content = content.expect("generated file should exist and be readable");
        assert!(content.contains("pub struct TestToolParams"));
    }

    /// Writing into a directory that does not exist surfaces as `IoError`.
    #[test]
    fn test_write_generated_code_invalid_path() {
        let config = create_test_config();
        let invalid_path = Path::new("/nonexistent/directory/test.rs");

        let result = write_generated_code(&config, invalid_path);
        assert!(result.is_err());
        assert!(matches!(result.unwrap_err(), CodegenError::IoError(_, _)));
    }
}