// pforge_codegen/generator.rs

1use pforge_config::{ForgeConfig, ParamSchema, ParamType, SimpleType};
2use std::path::PathBuf;
3use thiserror::Error;
4
5#[derive(Debug, Error)]
6pub enum CodegenError {
7    #[error("IO error: {0}: {1}")]
8    IoError(PathBuf, #[source] std::io::Error),
9
10    #[error("Generation error: {0}")]
11    GenerationError(String),
12}
13
14pub type Result<T> = std::result::Result<T, CodegenError>;
15
16/// Generate a parameter struct from ParamSchema
17pub fn generate_param_struct(tool_name: &str, params: &ParamSchema) -> Result<String> {
18    let struct_name = format!("{}Params", to_pascal_case(tool_name));
19    let mut output = String::new();
20
21    // Generate struct
22    output.push_str(&format!("#[derive(Debug, Deserialize, JsonSchema)]\n"));
23    output.push_str(&format!("pub struct {} {{\n", struct_name));
24
25    for (field_name, param_type) in &params.fields {
26        // Generate field
27        let (ty, required, description) = match param_type {
28            ParamType::Simple(simple_ty) => {
29                (rust_type_from_simple(simple_ty), true, None)
30            }
31            ParamType::Complex {
32                ty,
33                required,
34                description,
35                ..
36            } => (rust_type_from_simple(ty), *required, description.clone()),
37        };
38
39        // Add documentation if present
40        if let Some(desc) = description {
41            output.push_str(&format!("    /// {}\n", desc));
42        }
43
44        // Add field
45        if required {
46            output.push_str(&format!("    pub {}: {},\n", field_name, ty));
47        } else {
48            output.push_str(&format!("    pub {}: Option<{}>,\n", field_name, ty));
49        }
50    }
51
52    output.push_str("}\n");
53
54    Ok(output)
55}
56
57/// Generate handler registration code
58pub fn generate_handler_registration(config: &ForgeConfig) -> Result<String> {
59    let mut output = String::new();
60
61    output.push_str("pub fn register_handlers(registry: &mut HandlerRegistry) {\n");
62
63    for tool in &config.tools {
64        match tool {
65            pforge_config::ToolDef::Native { name, handler, .. } => {
66                // Extract handler path
67                let handler_path = &handler.path;
68                output.push_str(&format!(
69                    "    registry.register(\"{}\", {});\n",
70                    name, handler_path
71                ));
72            }
73            pforge_config::ToolDef::Cli {
74                name,
75                command,
76                args,
77                cwd,
78                env,
79                stream,
80                description: _,
81            } => {
82                output.push_str(&format!(
83                    "    registry.register(\"{}\", CliHandler::new(\n",
84                    name
85                ));
86                output.push_str(&format!("        \"{}\".to_string(),\n", command));
87                output.push_str(&format!("        vec![{}],\n", format_string_vec(args)));
88
89                if let Some(cwd_val) = cwd {
90                    output.push_str(&format!(
91                        "        Some(\"{}\".to_string()),\n",
92                        cwd_val
93                    ));
94                } else {
95                    output.push_str("        None,\n");
96                }
97
98                output.push_str(&format!(
99                    "        HashMap::new(), // env\n"
100                ));
101                output.push_str("        None, // timeout\n");
102                output.push_str(&format!("        {},\n", stream));
103                output.push_str("    ));\n");
104            }
105            pforge_config::ToolDef::Http {
106                name,
107                endpoint,
108                method,
109                headers,
110                auth,
111                description: _,
112            } => {
113                output.push_str(&format!(
114                    "    registry.register(\"{}\", HttpHandler::new(\n",
115                    name
116                ));
117                output.push_str(&format!("        \"{}\".to_string(),\n", endpoint));
118                output.push_str(&format!("        HttpMethod::{:?},\n", method));
119                output.push_str("        HashMap::new(), // headers\n");
120                output.push_str("        None, // auth\n");
121                output.push_str("    ));\n");
122            }
123            pforge_config::ToolDef::Pipeline { name, steps, description: _ } => {
124                output.push_str("    // Pipeline handler TBD\n");
125            }
126        }
127    }
128
129    output.push_str("}\n");
130
131    Ok(output)
132}
133
/// Convert a tool name such as `read_file`, `read-file` or `read.file` into
/// PascalCase (`ReadFile`).
///
/// The name is split on every non-alphanumeric character, not just `_`, so
/// hyphenated or dotted tool names still produce a valid Rust type name.
/// Runs of separators are collapsed. Each word's first character is
/// uppercased and the rest of the word is kept as-is.
fn to_pascal_case(s: &str) -> String {
    s.split(|c: char| !c.is_alphanumeric())
        .map(|word| {
            let mut chars = word.chars();
            match chars.next() {
                None => String::new(),
                Some(first) => first.to_uppercase().chain(chars).collect(),
            }
        })
        .collect()
}
145
146fn rust_type_from_simple(ty: &SimpleType) -> &'static str {
147    match ty {
148        SimpleType::String => "String",
149        SimpleType::Integer => "i64",
150        SimpleType::Float => "f64",
151        SimpleType::Boolean => "bool",
152        SimpleType::Array => "Vec<serde_json::Value>",
153        SimpleType::Object => "serde_json::Value",
154    }
155}
156
/// Render `vec` as comma-separated `"<item>".to_string()` expressions, for use
/// inside a generated `vec![...]`.
///
/// Each item is formatted with `{:?}`, which yields a valid, escaped Rust string
/// literal. Quotes, backslashes and control characters in configuration values
/// therefore cannot break the generated code. An empty slice yields `""`.
fn format_string_vec(vec: &[String]) -> String {
    vec.iter()
        .map(|s| format!("{:?}.to_string()", s))
        .collect::<Vec<_>>()
        .join(", ")
}
163
#[cfg(test)]
mod tests {
    use super::*;
    use pforge_config::*;
    use std::collections::HashMap;

    /// Build a stdio/debug `ForgeConfig` named "test" that contains only `tools`.
    fn config_with_tools(tools: Vec<ToolDef>) -> ForgeConfig {
        ForgeConfig {
            forge: ForgeMetadata {
                name: "test".to_string(),
                version: "1.0.0".to_string(),
                transport: TransportType::Stdio,
                optimization: OptimizationLevel::Debug,
            },
            tools,
            resources: vec![],
            prompts: vec![],
            state: None,
        }
    }

    /// Generate registration code for `tools`, failing the test if generation errors.
    fn registration_for(tools: Vec<ToolDef>) -> String {
        generate_handler_registration(&config_with_tools(tools))
            .expect("handler registration should generate")
    }

    #[test]
    fn test_to_pascal_case() {
        let cases = [
            ("hello_world", "HelloWorld"),
            ("test", "Test"),
            ("foo_bar_baz", "FooBarBaz"),
        ];
        for &(input, expected) in cases.iter() {
            assert_eq!(to_pascal_case(input), expected, "input: {}", input);
        }
    }

    #[test]
    fn test_rust_type_from_simple() {
        let expectations = [
            (SimpleType::String, "String"),
            (SimpleType::Integer, "i64"),
            (SimpleType::Float, "f64"),
            (SimpleType::Boolean, "bool"),
            (SimpleType::Array, "Vec<serde_json::Value>"),
            (SimpleType::Object, "serde_json::Value"),
        ];
        for (simple, rust) in expectations.iter() {
            assert_eq!(rust_type_from_simple(simple), *rust);
        }
    }

    #[test]
    fn test_format_string_vec() {
        let words = vec!["foo".to_string(), "bar".to_string()];
        assert_eq!(
            format_string_vec(&words),
            "\"foo\".to_string(), \"bar\".to_string()"
        );
        let none: Vec<String> = Vec::new();
        assert_eq!(format_string_vec(&none), "");
    }

    #[test]
    fn test_generate_param_struct_simple() {
        let fields: HashMap<_, _> = vec![
            ("name".to_string(), ParamType::Simple(SimpleType::String)),
            ("age".to_string(), ParamType::Simple(SimpleType::Integer)),
        ]
        .into_iter()
        .collect();

        let code = generate_param_struct("test_tool", &ParamSchema { fields })
            .expect("param struct should generate");

        assert!(code.contains("pub struct TestToolParams"));
        assert!(code.contains("pub name: String"));
        assert!(code.contains("pub age: i64"));
    }

    #[test]
    fn test_generate_param_struct_complex() {
        let optional = ParamType::Complex {
            ty: SimpleType::String,
            required: false,
            description: Some("An optional field".to_string()),
            default: None,
            validation: None,
        };
        let fields: HashMap<_, _> = vec![("optional_field".to_string(), optional)]
            .into_iter()
            .collect();

        let code = generate_param_struct("my_tool", &ParamSchema { fields })
            .expect("param struct should generate");

        assert!(code.contains("/// An optional field"));
        assert!(code.contains("pub optional_field: Option<String>"));
    }

    #[test]
    fn test_generate_handler_registration_native() {
        let code = registration_for(vec![ToolDef::Native {
            name: "test_tool".to_string(),
            description: "Test".to_string(),
            handler: HandlerRef {
                path: "handlers::test_handler".to_string(),
                inline: None,
            },
            params: ParamSchema {
                fields: HashMap::new(),
            },
            timeout_ms: None,
        }]);

        assert!(code.contains("pub fn register_handlers"));
        assert!(code.contains("registry.register(\"test_tool\", handlers::test_handler)"));
    }

    #[test]
    fn test_generate_handler_registration_cli() {
        let code = registration_for(vec![ToolDef::Cli {
            name: "cli_tool".to_string(),
            description: "CLI Test".to_string(),
            command: "echo".to_string(),
            args: vec!["hello".to_string()],
            cwd: None,
            env: HashMap::new(),
            stream: false,
        }]);

        assert!(code.contains("CliHandler::new"));
        assert!(code.contains("\"echo\""));
        assert!(code.contains("\"hello\""));
    }

    #[test]
    fn test_generate_handler_registration_http() {
        let code = registration_for(vec![ToolDef::Http {
            name: "http_tool".to_string(),
            description: "HTTP Test".to_string(),
            endpoint: "https://api.example.com".to_string(),
            method: HttpMethod::Get,
            headers: HashMap::new(),
            auth: None,
        }]);

        assert!(code.contains("HttpHandler::new"));
        assert!(code.contains("https://api.example.com"));
        assert!(code.contains("HttpMethod::Get"));
    }
}