1use pforge_config::{ForgeConfig, ParamSchema, ParamType, SimpleType};
2use std::path::PathBuf;
3use thiserror::Error;
4
/// Errors produced while generating Rust source from a forge configuration.
///
/// NOTE(review): `IoError`'s message interpolates the `#[source]` error (`{1}`), so reporters
/// that also walk `Error::source()` chains will print the I/O error text twice.
#[derive(Debug, Error)]
pub enum CodegenError {
    /// An I/O operation on the given path failed; the underlying `std::io::Error` is the source.
    #[error("IO error: {0}: {1}")]
    IoError(PathBuf, #[source] std::io::Error),

    /// Code generation itself failed; the string describes what went wrong.
    #[error("Generation error: {0}")]
    GenerationError(String),
}

/// Convenience alias: a `Result` whose error type is [`CodegenError`].
pub type Result<T> = std::result::Result<T, CodegenError>;
15
16pub fn generate_param_struct(tool_name: &str, params: &ParamSchema) -> Result<String> {
18 let struct_name = format!("{}Params", to_pascal_case(tool_name));
19 let mut output = String::new();
20
21 output.push_str("#[derive(Debug, Deserialize, JsonSchema)]\n");
23 output.push_str(&format!("pub struct {} {{\n", struct_name));
24
25 for (field_name, param_type) in ¶ms.fields {
26 let (ty, required, description) = match param_type {
28 ParamType::Simple(simple_ty) => (rust_type_from_simple(simple_ty), true, None),
29 ParamType::Complex {
30 ty,
31 required,
32 description,
33 ..
34 } => (rust_type_from_simple(ty), *required, description.clone()),
35 };
36
37 if let Some(desc) = description {
39 output.push_str(&format!(" /// {}\n", desc));
40 }
41
42 if required {
44 output.push_str(&format!(" pub {}: {},\n", field_name, ty));
45 } else {
46 output.push_str(&format!(" pub {}: Option<{}>,\n", field_name, ty));
47 }
48 }
49
50 output.push_str("}\n");
51
52 Ok(output)
53}
54
55pub fn generate_handler_registration(config: &ForgeConfig) -> Result<String> {
57 let mut output = String::new();
58
59 output.push_str("pub fn register_handlers(registry: &mut HandlerRegistry) {\n");
60
61 for tool in &config.tools {
62 match tool {
63 pforge_config::ToolDef::Native { name, handler, .. } => {
64 let handler_path = &handler.path;
66 output.push_str(&format!(
67 " registry.register(\"{}\", {});\n",
68 name, handler_path
69 ));
70 }
71 pforge_config::ToolDef::Cli {
72 name,
73 command,
74 args,
75 cwd,
76 env: _,
77 stream,
78 ..
79 } => {
80 output.push_str(&format!(
81 " registry.register(\"{}\", CliHandler::new(\n",
82 name
83 ));
84 output.push_str(&format!(" \"{}\".to_string(),\n", command));
85 output.push_str(&format!(" vec![{}],\n", format_string_vec(args)));
86
87 if let Some(cwd_val) = cwd {
88 output.push_str(&format!(" Some(\"{}\".to_string()),\n", cwd_val));
89 } else {
90 output.push_str(" None,\n");
91 }
92
93 output.push_str(" FxHashMap::default(), // env\n");
94 output.push_str(" None, // timeout\n");
95 output.push_str(&format!(" {},\n", stream));
96 output.push_str(" ));\n");
97 }
98 pforge_config::ToolDef::Http {
99 name,
100 endpoint,
101 method,
102 ..
103 } => {
104 output.push_str(&format!(
105 " registry.register(\"{}\", HttpHandler::new(\n",
106 name
107 ));
108 output.push_str(&format!(" \"{}\".to_string(),\n", endpoint));
109 output.push_str(&format!(" HttpMethod::{:?},\n", method));
110 output.push_str(" FxHashMap::default(), // headers\n");
111 output.push_str(" None, // auth\n");
112 output.push_str(" ));\n");
113 }
114 pforge_config::ToolDef::Pipeline {
115 name,
116 steps,
117 description: _,
118 } => {
119 output.push_str(&format!(
120 " registry.register(\"{}\", PipelineHandler::new(vec![\n",
121 name
122 ));
123 for step in steps {
124 output.push_str(" PipelineStep {\n");
125 output.push_str(&format!(
126 " tool: \"{}\".to_string(),\n",
127 step.tool
128 ));
129 if let Some(input_val) = &step.input {
131 output.push_str(&format!(
132 " input: Some(serde_json::json!({})),\n",
133 serde_json::to_string(input_val).unwrap_or_else(|_| "{}".to_string())
134 ));
135 } else {
136 output.push_str(" input: None,\n");
137 }
138 if let Some(var) = &step.output_var {
140 output.push_str(&format!(
141 " output_var: Some(\"{}\".to_string()),\n",
142 var
143 ));
144 } else {
145 output.push_str(" output_var: None,\n");
146 }
147 if let Some(cond) = &step.condition {
149 output.push_str(&format!(
150 " condition: Some(\"{}\".to_string()),\n",
151 cond
152 ));
153 } else {
154 output.push_str(" condition: None,\n");
155 }
156 let policy = match step.error_policy {
158 pforge_config::ErrorPolicy::FailFast => "ErrorPolicy::FailFast",
159 pforge_config::ErrorPolicy::Continue => "ErrorPolicy::Continue",
160 };
161 output.push_str(&format!(" error_policy: {},\n", policy));
162 output.push_str(" },\n");
163 }
164 output.push_str(" ]));\n");
165 }
166 }
167 }
168
169 output.push_str("}\n");
170
171 Ok(output)
172}
173
/// Converts a tool name such as `get_weather` or `get-weather` into PascalCase (`GetWeather`).
///
/// The name is split on every non-alphanumeric character (`_`, `-`, `.`, spaces, ...). Each
/// non-empty segment has its first character upper-cased and the rest kept unchanged.
/// Separators are dropped.
///
/// The result is used as part of a Rust identifier. Splitting only on `_` used to keep
/// characters such as `-` and produce an invalid struct name. Underscore-separated names
/// convert exactly as before.
fn to_pascal_case(s: &str) -> String {
    s.split(|c: char| !c.is_alphanumeric())
        .map(|word| {
            let mut chars = word.chars();
            match chars.next() {
                None => String::new(),
                Some(first) => first.to_uppercase().chain(chars).collect(),
            }
        })
        .collect()
}
185
186fn rust_type_from_simple(ty: &SimpleType) -> &'static str {
187 match ty {
188 SimpleType::String => "String",
189 SimpleType::Integer => "i64",
190 SimpleType::Float => "f64",
191 SimpleType::Boolean => "bool",
192 SimpleType::Array => "Vec<serde_json::Value>",
193 SimpleType::Object => "serde_json::Value",
194 }
195}
196
/// Renders `vec` as the comma-separated body of a `vec![...]` of owned strings.
///
/// Each element becomes `<literal>.to_string()`, and the elements are joined with `", "`.
/// An empty slice yields an empty string.
///
/// Elements are written with `{:?}`, which produces a properly escaped Rust string literal.
/// Arguments containing `"` or `\` (e.g. Windows paths) therefore survive into the generated
/// code unchanged instead of breaking it.
fn format_string_vec(vec: &[String]) -> String {
    vec.iter()
        .map(|s| format!("{:?}.to_string()", s))
        .collect::<Vec<_>>()
        .join(", ")
}
203
#[cfg(test)]
mod tests {
    use super::*;
    use pforge_config::*;
    use rustc_hash::FxHashMap;

    /// Builds a `ForgeConfig` with fixed test metadata and the given tools, so each registration
    /// test only spells out the tool under test.
    fn config_with_tools(tools: Vec<ToolDef>) -> ForgeConfig {
        ForgeConfig {
            forge: ForgeMetadata {
                name: "test".to_string(),
                version: "1.0.0".to_string(),
                transport: TransportType::Stdio,
                optimization: OptimizationLevel::Debug,
            },
            tools,
            resources: vec![],
            prompts: vec![],
            state: None,
        }
    }

    #[test]
    fn test_to_pascal_case() {
        assert_eq!(to_pascal_case("hello_world"), "HelloWorld");
        assert_eq!(to_pascal_case("test"), "Test");
        assert_eq!(to_pascal_case("foo_bar_baz"), "FooBarBaz");
    }

    #[test]
    fn test_rust_type_from_simple() {
        assert_eq!(rust_type_from_simple(&SimpleType::String), "String");
        assert_eq!(rust_type_from_simple(&SimpleType::Integer), "i64");
        assert_eq!(rust_type_from_simple(&SimpleType::Float), "f64");
        assert_eq!(rust_type_from_simple(&SimpleType::Boolean), "bool");
        assert_eq!(
            rust_type_from_simple(&SimpleType::Array),
            "Vec<serde_json::Value>"
        );
        assert_eq!(
            rust_type_from_simple(&SimpleType::Object),
            "serde_json::Value"
        );
    }

    #[test]
    fn test_format_string_vec() {
        assert_eq!(
            format_string_vec(&["foo".to_string(), "bar".to_string()]),
            "\"foo\".to_string(), \"bar\".to_string()"
        );
        assert_eq!(format_string_vec(&[]), "");
    }

    #[test]
    fn test_generate_param_struct_simple() {
        let mut fields = FxHashMap::default();
        fields.insert("name".to_string(), ParamType::Simple(SimpleType::String));
        fields.insert("age".to_string(), ParamType::Simple(SimpleType::Integer));

        let params = ParamSchema { fields };
        let code = generate_param_struct("test_tool", &params)
            .expect("param struct generation should succeed");

        assert!(code.contains("pub struct TestToolParams"));
        assert!(code.contains("pub name: String"));
        assert!(code.contains("pub age: i64"));
    }

    #[test]
    fn test_generate_param_struct_complex() {
        let mut fields = FxHashMap::default();
        fields.insert(
            "optional_field".to_string(),
            ParamType::Complex {
                ty: SimpleType::String,
                required: false,
                description: Some("An optional field".to_string()),
                default: None,
                validation: None,
            },
        );

        let params = ParamSchema { fields };
        let code = generate_param_struct("my_tool", &params)
            .expect("param struct generation should succeed");

        assert!(code.contains("/// An optional field"));
        assert!(code.contains("pub optional_field: Option<String>"));
    }

    #[test]
    fn test_generate_handler_registration_native() {
        let config = config_with_tools(vec![ToolDef::Native {
            name: "test_tool".to_string(),
            description: "Test".to_string(),
            handler: HandlerRef {
                path: "handlers::test_handler".to_string(),
                inline: None,
            },
            params: ParamSchema {
                fields: FxHashMap::default(),
            },
            timeout_ms: None,
        }]);

        let code = generate_handler_registration(&config)
            .expect("handler registration generation should succeed");
        assert!(code.contains("pub fn register_handlers"));
        assert!(code.contains("registry.register(\"test_tool\", handlers::test_handler)"));
    }

    #[test]
    fn test_generate_handler_registration_cli() {
        let config = config_with_tools(vec![ToolDef::Cli {
            name: "cli_tool".to_string(),
            description: "CLI Test".to_string(),
            command: "echo".to_string(),
            args: vec!["hello".to_string()],
            cwd: None,
            env: FxHashMap::default(),
            stream: false,
            timeout_ms: None,
        }]);

        let code = generate_handler_registration(&config)
            .expect("handler registration generation should succeed");
        assert!(code.contains("CliHandler::new"));
        assert!(code.contains("\"echo\""));
        assert!(code.contains("\"hello\""));
    }

    #[test]
    fn test_generate_handler_registration_http() {
        let config = config_with_tools(vec![ToolDef::Http {
            name: "http_tool".to_string(),
            description: "HTTP Test".to_string(),
            endpoint: "https://api.example.com".to_string(),
            method: HttpMethod::Get,
            headers: FxHashMap::default(),
            auth: None,
            timeout_ms: None,
        }]);

        let code = generate_handler_registration(&config)
            .expect("handler registration generation should succeed");
        assert!(code.contains("HttpHandler::new"));
        assert!(code.contains("https://api.example.com"));
        assert!(code.contains("HttpMethod::Get"));
    }

    #[test]
    fn test_generate_handler_registration_pipeline() {
        let config = config_with_tools(vec![ToolDef::Pipeline {
            name: "pipeline_tool".to_string(),
            description: "Pipeline Test".to_string(),
            steps: vec![
                pforge_config::PipelineStep {
                    tool: "step1".to_string(),
                    input: Some(serde_json::json!({"key": "value"})),
                    output_var: Some("result".to_string()),
                    condition: None,
                    error_policy: pforge_config::ErrorPolicy::FailFast,
                },
                pforge_config::PipelineStep {
                    tool: "step2".to_string(),
                    input: None,
                    output_var: None,
                    condition: Some("result".to_string()),
                    error_policy: pforge_config::ErrorPolicy::Continue,
                },
            ],
        }]);

        let code = generate_handler_registration(&config)
            .expect("handler registration generation should succeed");
        assert!(code.contains("PipelineHandler::new"));
        assert!(code.contains("tool: \"step1\""));
        assert!(code.contains("tool: \"step2\""));
        assert!(code.contains("output_var: Some(\"result\""));
        assert!(code.contains("condition: Some(\"result\""));
        assert!(code.contains("ErrorPolicy::FailFast"));
        assert!(code.contains("ErrorPolicy::Continue"));
    }
}