pforge_codegen/
generator.rs

1use pforge_config::{ForgeConfig, ParamSchema, ParamType, SimpleType};
2use std::path::PathBuf;
3use thiserror::Error;
4
5#[derive(Debug, Error)]
6pub enum CodegenError {
7    #[error("IO error: {0}: {1}")]
8    IoError(PathBuf, #[source] std::io::Error),
9
10    #[error("Generation error: {0}")]
11    GenerationError(String),
12}
13
14pub type Result<T> = std::result::Result<T, CodegenError>;
15
16/// Generate a parameter struct from ParamSchema
17pub fn generate_param_struct(tool_name: &str, params: &ParamSchema) -> Result<String> {
18    let struct_name = format!("{}Params", to_pascal_case(tool_name));
19    let mut output = String::new();
20
21    // Generate struct
22    output.push_str("#[derive(Debug, Deserialize, JsonSchema)]\n");
23    output.push_str(&format!("pub struct {} {{\n", struct_name));
24
25    for (field_name, param_type) in &params.fields {
26        // Generate field
27        let (ty, required, description) = match param_type {
28            ParamType::Simple(simple_ty) => (rust_type_from_simple(simple_ty), true, None),
29            ParamType::Complex {
30                ty,
31                required,
32                description,
33                ..
34            } => (rust_type_from_simple(ty), *required, description.clone()),
35        };
36
37        // Add documentation if present
38        if let Some(desc) = description {
39            output.push_str(&format!("    /// {}\n", desc));
40        }
41
42        // Add field
43        if required {
44            output.push_str(&format!("    pub {}: {},\n", field_name, ty));
45        } else {
46            output.push_str(&format!("    pub {}: Option<{}>,\n", field_name, ty));
47        }
48    }
49
50    output.push_str("}\n");
51
52    Ok(output)
53}
54
55/// Generate handler registration code
56pub fn generate_handler_registration(config: &ForgeConfig) -> Result<String> {
57    let mut output = String::new();
58
59    output.push_str("pub fn register_handlers(registry: &mut HandlerRegistry) {\n");
60
61    for tool in &config.tools {
62        match tool {
63            pforge_config::ToolDef::Native { name, handler, .. } => {
64                // Extract handler path
65                let handler_path = &handler.path;
66                output.push_str(&format!(
67                    "    registry.register(\"{}\", {});\n",
68                    name, handler_path
69                ));
70            }
71            pforge_config::ToolDef::Cli {
72                name,
73                command,
74                args,
75                cwd,
76                env: _,
77                stream,
78                ..
79            } => {
80                output.push_str(&format!(
81                    "    registry.register(\"{}\", CliHandler::new(\n",
82                    name
83                ));
84                output.push_str(&format!("        \"{}\".to_string(),\n", command));
85                output.push_str(&format!("        vec![{}],\n", format_string_vec(args)));
86
87                if let Some(cwd_val) = cwd {
88                    output.push_str(&format!("        Some(\"{}\".to_string()),\n", cwd_val));
89                } else {
90                    output.push_str("        None,\n");
91                }
92
93                output.push_str("        FxHashMap::default(), // env\n");
94                output.push_str("        None, // timeout\n");
95                output.push_str(&format!("        {},\n", stream));
96                output.push_str("    ));\n");
97            }
98            pforge_config::ToolDef::Http {
99                name,
100                endpoint,
101                method,
102                ..
103            } => {
104                output.push_str(&format!(
105                    "    registry.register(\"{}\", HttpHandler::new(\n",
106                    name
107                ));
108                output.push_str(&format!("        \"{}\".to_string(),\n", endpoint));
109                output.push_str(&format!("        HttpMethod::{:?},\n", method));
110                output.push_str("        FxHashMap::default(), // headers\n");
111                output.push_str("        None, // auth\n");
112                output.push_str("    ));\n");
113            }
114            pforge_config::ToolDef::Pipeline {
115                name,
116                steps,
117                description: _,
118            } => {
119                output.push_str(&format!(
120                    "    registry.register(\"{}\", PipelineHandler::new(vec![\n",
121                    name
122                ));
123                for step in steps {
124                    output.push_str("        PipelineStep {\n");
125                    output.push_str(&format!(
126                        "            tool: \"{}\".to_string(),\n",
127                        step.tool
128                    ));
129                    // Generate input
130                    if let Some(input_val) = &step.input {
131                        output.push_str(&format!(
132                            "            input: Some(serde_json::json!({})),\n",
133                            serde_json::to_string(input_val).unwrap_or_else(|_| "{}".to_string())
134                        ));
135                    } else {
136                        output.push_str("            input: None,\n");
137                    }
138                    // Generate output_var
139                    if let Some(var) = &step.output_var {
140                        output.push_str(&format!(
141                            "            output_var: Some(\"{}\".to_string()),\n",
142                            var
143                        ));
144                    } else {
145                        output.push_str("            output_var: None,\n");
146                    }
147                    // Generate condition
148                    if let Some(cond) = &step.condition {
149                        output.push_str(&format!(
150                            "            condition: Some(\"{}\".to_string()),\n",
151                            cond
152                        ));
153                    } else {
154                        output.push_str("            condition: None,\n");
155                    }
156                    // Generate error policy
157                    let policy = match step.error_policy {
158                        pforge_config::ErrorPolicy::FailFast => "ErrorPolicy::FailFast",
159                        pforge_config::ErrorPolicy::Continue => "ErrorPolicy::Continue",
160                    };
161                    output.push_str(&format!("            error_policy: {},\n", policy));
162                    output.push_str("        },\n");
163                }
164                output.push_str("    ]));\n");
165            }
166        }
167    }
168
169    output.push_str("}\n");
170
171    Ok(output)
172}
173
/// Convert a tool name such as `"read_file"` or `"read-file"` into `PascalCase` (`"ReadFile"`).
///
/// The input is split on every non-alphanumeric character (underscore, hyphen, dot, space,
/// ...), so the result contains only characters valid in a Rust type name. The first
/// character of each segment is upper-cased and the rest of the segment is kept as-is.
/// Empty segments (e.g. from `"a__b"` or a leading `_`) contribute nothing.
fn to_pascal_case(s: &str) -> String {
    let mut out = String::with_capacity(s.len());
    for segment in s.split(|c: char| !c.is_alphanumeric()) {
        let mut chars = segment.chars();
        if let Some(first) = chars.next() {
            // `to_uppercase` may yield several chars (e.g. 'ß' -> "SS").
            out.extend(first.to_uppercase());
            out.push_str(chars.as_str());
        }
    }
    out
}
185
186fn rust_type_from_simple(ty: &SimpleType) -> &'static str {
187    match ty {
188        SimpleType::String => "String",
189        SimpleType::Integer => "i64",
190        SimpleType::Float => "f64",
191        SimpleType::Boolean => "bool",
192        SimpleType::Array => "Vec<serde_json::Value>",
193        SimpleType::Object => "serde_json::Value",
194    }
195}
196
/// Render `vec` as a comma-separated list of `"<item>".to_string()` Rust expressions,
/// ready to be spliced between `vec![` and `]` in generated code.
///
/// Each item is emitted through its `Debug` representation. That representation is a valid,
/// escaped Rust string literal, so items containing quotes, backslashes or newlines cannot
/// break the generated source. An empty slice yields an empty string.
fn format_string_vec(vec: &[String]) -> String {
    vec.iter()
        .map(|s| format!("{:?}.to_string()", s))
        .collect::<Vec<_>>()
        .join(", ")
}
203
#[cfg(test)]
mod tests {
    use super::*;
    use pforge_config::*;
    use rustc_hash::FxHashMap;

    /// Wrap a single tool in a config with fixed, minimal server metadata.
    fn config_with_tool(tool: ToolDef) -> ForgeConfig {
        ForgeConfig {
            forge: ForgeMetadata {
                name: "test".to_string(),
                version: "1.0.0".to_string(),
                transport: TransportType::Stdio,
                optimization: OptimizationLevel::Debug,
            },
            tools: vec![tool],
            resources: vec![],
            prompts: vec![],
            state: None,
        }
    }

    /// Run handler-registration codegen, failing the test if it returns an error.
    fn registration_code(config: &ForgeConfig) -> String {
        generate_handler_registration(config).expect("handler registration codegen failed")
    }

    /// Assert that every snippet in `needles` occurs in `code`.
    fn assert_contains_all(code: &str, needles: &[&str]) {
        for needle in needles {
            assert!(code.contains(*needle), "missing `{}` in:\n{}", needle, code);
        }
    }

    #[test]
    fn test_to_pascal_case() {
        let cases = [
            ("hello_world", "HelloWorld"),
            ("test", "Test"),
            ("foo_bar_baz", "FooBarBaz"),
        ];
        for &(input, expected) in &cases {
            assert_eq!(to_pascal_case(input), expected);
        }
    }

    #[test]
    fn test_rust_type_from_simple() {
        let expectations = [
            (SimpleType::String, "String"),
            (SimpleType::Integer, "i64"),
            (SimpleType::Float, "f64"),
            (SimpleType::Boolean, "bool"),
            (SimpleType::Array, "Vec<serde_json::Value>"),
            (SimpleType::Object, "serde_json::Value"),
        ];
        for (ty, expected) in &expectations {
            assert_eq!(rust_type_from_simple(ty), *expected);
        }
    }

    #[test]
    fn test_format_string_vec() {
        let rendered = format_string_vec(&["foo".to_string(), "bar".to_string()]);
        assert_eq!(rendered, "\"foo\".to_string(), \"bar\".to_string()");
        assert!(format_string_vec(&[]).is_empty());
    }

    #[test]
    fn test_generate_param_struct_simple() {
        let mut fields = FxHashMap::default();
        for (name, ty) in vec![("name", SimpleType::String), ("age", SimpleType::Integer)] {
            fields.insert(name.to_string(), ParamType::Simple(ty));
        }

        let code = generate_param_struct("test_tool", &ParamSchema { fields })
            .expect("param struct codegen failed");

        assert_contains_all(
            &code,
            &["pub struct TestToolParams", "pub name: String", "pub age: i64"],
        );
    }

    #[test]
    fn test_generate_param_struct_complex() {
        let optional = ParamType::Complex {
            ty: SimpleType::String,
            required: false,
            description: Some("An optional field".to_string()),
            default: None,
            validation: None,
        };
        let mut fields = FxHashMap::default();
        fields.insert("optional_field".to_string(), optional);

        let code = generate_param_struct("my_tool", &ParamSchema { fields })
            .expect("param struct codegen failed");

        assert_contains_all(
            &code,
            &["/// An optional field", "pub optional_field: Option<String>"],
        );
    }

    #[test]
    fn test_generate_handler_registration_native() {
        let tool = ToolDef::Native {
            name: "test_tool".to_string(),
            description: "Test".to_string(),
            handler: HandlerRef {
                path: "handlers::test_handler".to_string(),
                inline: None,
            },
            params: ParamSchema {
                fields: FxHashMap::default(),
            },
            timeout_ms: None,
        };

        let code = registration_code(&config_with_tool(tool));

        assert_contains_all(
            &code,
            &[
                "pub fn register_handlers",
                "registry.register(\"test_tool\", handlers::test_handler)",
            ],
        );
    }

    #[test]
    fn test_generate_handler_registration_cli() {
        let tool = ToolDef::Cli {
            name: "cli_tool".to_string(),
            description: "CLI Test".to_string(),
            command: "echo".to_string(),
            args: vec!["hello".to_string()],
            cwd: None,
            env: FxHashMap::default(),
            stream: false,
            timeout_ms: None,
        };

        let code = registration_code(&config_with_tool(tool));

        assert_contains_all(&code, &["CliHandler::new", "\"echo\"", "\"hello\""]);
    }

    #[test]
    fn test_generate_handler_registration_http() {
        let tool = ToolDef::Http {
            name: "http_tool".to_string(),
            description: "HTTP Test".to_string(),
            endpoint: "https://api.example.com".to_string(),
            method: HttpMethod::Get,
            headers: FxHashMap::default(),
            auth: None,
            timeout_ms: None,
        };

        let code = registration_code(&config_with_tool(tool));

        assert_contains_all(
            &code,
            &["HttpHandler::new", "https://api.example.com", "HttpMethod::Get"],
        );
    }

    #[test]
    fn test_generate_handler_registration_pipeline() {
        let first = pforge_config::PipelineStep {
            tool: "step1".to_string(),
            input: Some(serde_json::json!({"key": "value"})),
            output_var: Some("result".to_string()),
            condition: None,
            error_policy: pforge_config::ErrorPolicy::FailFast,
        };
        let second = pforge_config::PipelineStep {
            tool: "step2".to_string(),
            input: None,
            output_var: None,
            condition: Some("result".to_string()),
            error_policy: pforge_config::ErrorPolicy::Continue,
        };
        let tool = ToolDef::Pipeline {
            name: "pipeline_tool".to_string(),
            description: "Pipeline Test".to_string(),
            steps: vec![first, second],
        };

        let code = registration_code(&config_with_tool(tool));

        assert_contains_all(
            &code,
            &[
                "PipelineHandler::new",
                "tool: \"step1\"",
                "tool: \"step2\"",
                "output_var: Some(\"result\"",
                "condition: Some(\"result\"",
                "ErrorPolicy::FailFast",
                "ErrorPolicy::Continue",
            ],
        );
    }
}