1use pforge_config::{ForgeConfig, ParamSchema, ParamType, SimpleType};
2use std::path::PathBuf;
3use thiserror::Error;
4
/// Errors reported by the code generation functions in this module.
#[derive(Debug, Error)]
pub enum CodegenError {
    /// An I/O failure, displayed as `IO error: <path>: <io error>`.
    ///
    /// The wrapped `std::io::Error` is also marked as the error's source.
    // NOTE(review): the `PathBuf` is presumably the file that was being read or
    // written; no construction site is visible in this file, so confirm
    // against callers.
    #[error("IO error: {0}: {1}")]
    IoError(PathBuf, #[source] std::io::Error),

    /// A generation failure described by a free-form message, displayed as
    /// `Generation error: <message>`.
    #[error("Generation error: {0}")]
    GenerationError(String),
}
13
/// Convenience alias for results whose error type is [`CodegenError`].
pub type Result<T> = std::result::Result<T, CodegenError>;
15
16pub fn generate_param_struct(tool_name: &str, params: &ParamSchema) -> Result<String> {
18 let struct_name = format!("{}Params", to_pascal_case(tool_name));
19 let mut output = String::new();
20
21 output.push_str("#[derive(Debug, Deserialize, JsonSchema)]\n");
23 output.push_str(&format!("pub struct {} {{\n", struct_name));
24
25 for (field_name, param_type) in ¶ms.fields {
26 let (ty, required, description) = match param_type {
28 ParamType::Simple(simple_ty) => (rust_type_from_simple(simple_ty), true, None),
29 ParamType::Complex {
30 ty,
31 required,
32 description,
33 ..
34 } => (rust_type_from_simple(ty), *required, description.clone()),
35 };
36
37 if let Some(desc) = description {
39 output.push_str(&format!(" /// {}\n", desc));
40 }
41
42 if required {
44 output.push_str(&format!(" pub {}: {},\n", field_name, ty));
45 } else {
46 output.push_str(&format!(" pub {}: Option<{}>,\n", field_name, ty));
47 }
48 }
49
50 output.push_str("}\n");
51
52 Ok(output)
53}
54
55pub fn generate_handler_registration(config: &ForgeConfig) -> Result<String> {
57 let mut output = String::new();
58
59 output.push_str("pub fn register_handlers(registry: &mut HandlerRegistry) {\n");
60
61 for tool in &config.tools {
62 match tool {
63 pforge_config::ToolDef::Native { name, handler, .. } => {
64 let handler_path = &handler.path;
66 output.push_str(&format!(
67 " registry.register(\"{}\", {});\n",
68 name, handler_path
69 ));
70 }
71 pforge_config::ToolDef::Cli {
72 name,
73 command,
74 args,
75 cwd,
76 env: _,
77 stream,
78 description: _,
79 } => {
80 output.push_str(&format!(
81 " registry.register(\"{}\", CliHandler::new(\n",
82 name
83 ));
84 output.push_str(&format!(" \"{}\".to_string(),\n", command));
85 output.push_str(&format!(" vec![{}],\n", format_string_vec(args)));
86
87 if let Some(cwd_val) = cwd {
88 output.push_str(&format!(" Some(\"{}\".to_string()),\n", cwd_val));
89 } else {
90 output.push_str(" None,\n");
91 }
92
93 output.push_str(" HashMap::new(), // env\n");
94 output.push_str(" None, // timeout\n");
95 output.push_str(&format!(" {},\n", stream));
96 output.push_str(" ));\n");
97 }
98 pforge_config::ToolDef::Http {
99 name,
100 endpoint,
101 method,
102 headers: _,
103 auth: _,
104 description: _,
105 } => {
106 output.push_str(&format!(
107 " registry.register(\"{}\", HttpHandler::new(\n",
108 name
109 ));
110 output.push_str(&format!(" \"{}\".to_string(),\n", endpoint));
111 output.push_str(&format!(" HttpMethod::{:?},\n", method));
112 output.push_str(" HashMap::new(), // headers\n");
113 output.push_str(" None, // auth\n");
114 output.push_str(" ));\n");
115 }
116 pforge_config::ToolDef::Pipeline {
117 name: _,
118 steps: _,
119 description: _,
120 } => {
121 output.push_str(" // Pipeline handler TBD\n");
122 }
123 }
124 }
125
126 output.push_str("}\n");
127
128 Ok(output)
129}
130
/// Converts a tool name such as `hello_world` into PascalCase (`HelloWorld`).
///
/// The input is split on every non-alphanumeric character (underscores,
/// hyphens, dots, spaces, ...), so names like `my-tool` also produce a valid
/// identifier fragment (`MyTool`). The first character of each word is
/// upper-cased; the rest of the word is kept unchanged. Empty words, which
/// come from leading, trailing or repeated separators, contribute nothing.
fn to_pascal_case(s: &str) -> String {
    let mut pascal = String::with_capacity(s.len());

    for word in s.split(|c: char| !c.is_alphanumeric()) {
        let mut chars = word.chars();
        if let Some(first) = chars.next() {
            // `to_uppercase` may yield more than one char (e.g. for 'ß').
            pascal.extend(first.to_uppercase());
            pascal.push_str(chars.as_str());
        }
    }

    pascal
}
142
143fn rust_type_from_simple(ty: &SimpleType) -> &'static str {
144 match ty {
145 SimpleType::String => "String",
146 SimpleType::Integer => "i64",
147 SimpleType::Float => "f64",
148 SimpleType::Boolean => "bool",
149 SimpleType::Array => "Vec<serde_json::Value>",
150 SimpleType::Object => "serde_json::Value",
151 }
152}
153
/// Renders a slice of strings as the comma-separated elements of a generated
/// `vec![..]`, each in the form `"<value>".to_string()`.
///
/// Every value is written with `Debug` formatting, which adds the surrounding
/// quotes and escapes embedded quotes, backslashes and control characters, so
/// the result is always a valid Rust string literal. An empty slice yields an
/// empty string.
fn format_string_vec(vec: &[String]) -> String {
    vec.iter()
        .map(|s| format!("{:?}.to_string()", s))
        .collect::<Vec<_>>()
        .join(", ")
}
160
#[cfg(test)]
mod tests {
    use super::*;
    use pforge_config::*;
    use std::collections::HashMap;

    /// Wraps `tools` in a minimal stdio/debug config with no resources,
    /// prompts or state.
    fn config_with_tools(tools: Vec<ToolDef>) -> ForgeConfig {
        ForgeConfig {
            forge: ForgeMetadata {
                name: "test".to_string(),
                version: "1.0.0".to_string(),
                transport: TransportType::Stdio,
                optimization: OptimizationLevel::Debug,
            },
            tools,
            resources: vec![],
            prompts: vec![],
            state: None,
        }
    }

    /// Runs handler-registration codegen for `tools` and returns the source.
    fn registration_code(tools: Vec<ToolDef>) -> String {
        generate_handler_registration(&config_with_tools(tools))
            .expect("handler registration codegen should succeed")
    }

    /// Asserts that every fragment occurs somewhere in `code`.
    fn assert_contains_all(code: &str, fragments: &[&str]) {
        for fragment in fragments {
            assert!(
                code.contains(*fragment),
                "expected `{}` in generated code:\n{}",
                fragment,
                code
            );
        }
    }

    #[test]
    fn test_to_pascal_case() {
        let cases = [
            ("hello_world", "HelloWorld"),
            ("test", "Test"),
            ("foo_bar_baz", "FooBarBaz"),
        ];
        for (input, expected) in &cases {
            assert_eq!(to_pascal_case(input), *expected);
        }
    }

    #[test]
    fn test_rust_type_from_simple() {
        let cases = [
            (SimpleType::String, "String"),
            (SimpleType::Integer, "i64"),
            (SimpleType::Float, "f64"),
            (SimpleType::Boolean, "bool"),
            (SimpleType::Array, "Vec<serde_json::Value>"),
            (SimpleType::Object, "serde_json::Value"),
        ];
        for (ty, expected) in &cases {
            assert_eq!(rust_type_from_simple(ty), *expected);
        }
    }

    #[test]
    fn test_format_string_vec() {
        let two = vec!["foo".to_string(), "bar".to_string()];
        assert_eq!(
            format_string_vec(&two),
            "\"foo\".to_string(), \"bar\".to_string()"
        );

        let none: Vec<String> = Vec::new();
        assert_eq!(format_string_vec(&none), "");
    }

    #[test]
    fn test_generate_param_struct_simple() {
        let mut fields = HashMap::new();
        fields.insert("name".to_string(), ParamType::Simple(SimpleType::String));
        fields.insert("age".to_string(), ParamType::Simple(SimpleType::Integer));
        let params = ParamSchema { fields };

        let code = generate_param_struct("test_tool", &params)
            .expect("param struct codegen should succeed");

        assert_contains_all(
            &code,
            &[
                "pub struct TestToolParams",
                "pub name: String",
                "pub age: i64",
            ],
        );
    }

    #[test]
    fn test_generate_param_struct_complex() {
        let optional = ParamType::Complex {
            ty: SimpleType::String,
            required: false,
            description: Some("An optional field".to_string()),
            default: None,
            validation: None,
        };
        let mut fields = HashMap::new();
        fields.insert("optional_field".to_string(), optional);
        let params = ParamSchema { fields };

        let code = generate_param_struct("my_tool", &params)
            .expect("param struct codegen should succeed");

        assert_contains_all(
            &code,
            &[
                "/// An optional field",
                "pub optional_field: Option<String>",
            ],
        );
    }

    #[test]
    fn test_generate_handler_registration_native() {
        let code = registration_code(vec![ToolDef::Native {
            name: "test_tool".to_string(),
            description: "Test".to_string(),
            handler: HandlerRef {
                path: "handlers::test_handler".to_string(),
                inline: None,
            },
            params: ParamSchema {
                fields: HashMap::new(),
            },
            timeout_ms: None,
        }]);

        assert_contains_all(
            &code,
            &[
                "pub fn register_handlers",
                "registry.register(\"test_tool\", handlers::test_handler)",
            ],
        );
    }

    #[test]
    fn test_generate_handler_registration_cli() {
        let code = registration_code(vec![ToolDef::Cli {
            name: "cli_tool".to_string(),
            description: "CLI Test".to_string(),
            command: "echo".to_string(),
            args: vec!["hello".to_string()],
            cwd: None,
            env: HashMap::new(),
            stream: false,
        }]);

        assert_contains_all(&code, &["CliHandler::new", "\"echo\"", "\"hello\""]);
    }

    #[test]
    fn test_generate_handler_registration_http() {
        let code = registration_code(vec![ToolDef::Http {
            name: "http_tool".to_string(),
            description: "HTTP Test".to_string(),
            endpoint: "https://api.example.com".to_string(),
            method: HttpMethod::Get,
            headers: HashMap::new(),
            auth: None,
        }]);

        assert_contains_all(
            &code,
            &[
                "HttpHandler::new",
                "https://api.example.com",
                "HttpMethod::Get",
            ],
        );
    }
}