pforge_codegen/
generator.rs

1use pforge_config::{ForgeConfig, ParamSchema, ParamType, SimpleType};
2use std::path::PathBuf;
3use thiserror::Error;
4
/// Errors reported by the code generator.
#[derive(Debug, Error)]
pub enum CodegenError {
    /// An I/O operation failed. Carries a path (presumably the file being
    /// accessed — confirm at the construction sites) and the underlying
    /// `std::io::Error`, which is exposed as the error source.
    #[error("IO error: {0}: {1}")]
    IoError(PathBuf, #[source] std::io::Error),

    /// Code generation itself failed; carries a human-readable message.
    #[error("Generation error: {0}")]
    GenerationError(String),
}
13
/// Convenience alias for results whose error type is [`CodegenError`].
pub type Result<T> = std::result::Result<T, CodegenError>;
15
16/// Generate a parameter struct from ParamSchema
17pub fn generate_param_struct(tool_name: &str, params: &ParamSchema) -> Result<String> {
18    let struct_name = format!("{}Params", to_pascal_case(tool_name));
19    let mut output = String::new();
20
21    // Generate struct
22    output.push_str("#[derive(Debug, Deserialize, JsonSchema)]\n");
23    output.push_str(&format!("pub struct {} {{\n", struct_name));
24
25    for (field_name, param_type) in &params.fields {
26        // Generate field
27        let (ty, required, description) = match param_type {
28            ParamType::Simple(simple_ty) => (rust_type_from_simple(simple_ty), true, None),
29            ParamType::Complex {
30                ty,
31                required,
32                description,
33                ..
34            } => (rust_type_from_simple(ty), *required, description.clone()),
35        };
36
37        // Add documentation if present
38        if let Some(desc) = description {
39            output.push_str(&format!("    /// {}\n", desc));
40        }
41
42        // Add field
43        if required {
44            output.push_str(&format!("    pub {}: {},\n", field_name, ty));
45        } else {
46            output.push_str(&format!("    pub {}: Option<{}>,\n", field_name, ty));
47        }
48    }
49
50    output.push_str("}\n");
51
52    Ok(output)
53}
54
55/// Generate handler registration code
56pub fn generate_handler_registration(config: &ForgeConfig) -> Result<String> {
57    let mut output = String::new();
58
59    output.push_str("pub fn register_handlers(registry: &mut HandlerRegistry) {\n");
60
61    for tool in &config.tools {
62        match tool {
63            pforge_config::ToolDef::Native { name, handler, .. } => {
64                // Extract handler path
65                let handler_path = &handler.path;
66                output.push_str(&format!(
67                    "    registry.register(\"{}\", {});\n",
68                    name, handler_path
69                ));
70            }
71            pforge_config::ToolDef::Cli {
72                name,
73                command,
74                args,
75                cwd,
76                env: _,
77                stream,
78                description: _,
79            } => {
80                output.push_str(&format!(
81                    "    registry.register(\"{}\", CliHandler::new(\n",
82                    name
83                ));
84                output.push_str(&format!("        \"{}\".to_string(),\n", command));
85                output.push_str(&format!("        vec![{}],\n", format_string_vec(args)));
86
87                if let Some(cwd_val) = cwd {
88                    output.push_str(&format!("        Some(\"{}\".to_string()),\n", cwd_val));
89                } else {
90                    output.push_str("        None,\n");
91                }
92
93                output.push_str("        HashMap::new(), // env\n");
94                output.push_str("        None, // timeout\n");
95                output.push_str(&format!("        {},\n", stream));
96                output.push_str("    ));\n");
97            }
98            pforge_config::ToolDef::Http {
99                name,
100                endpoint,
101                method,
102                headers: _,
103                auth: _,
104                description: _,
105            } => {
106                output.push_str(&format!(
107                    "    registry.register(\"{}\", HttpHandler::new(\n",
108                    name
109                ));
110                output.push_str(&format!("        \"{}\".to_string(),\n", endpoint));
111                output.push_str(&format!("        HttpMethod::{:?},\n", method));
112                output.push_str("        HashMap::new(), // headers\n");
113                output.push_str("        None, // auth\n");
114                output.push_str("    ));\n");
115            }
116            pforge_config::ToolDef::Pipeline {
117                name: _,
118                steps: _,
119                description: _,
120            } => {
121                output.push_str("    // Pipeline handler TBD\n");
122            }
123        }
124    }
125
126    output.push_str("}\n");
127
128    Ok(output)
129}
130
/// Convert an identifier-like string to PascalCase.
///
/// The input is split on every non-alphanumeric character (`_`, `-`, `.`,
/// spaces, ...), the first character of each piece is upper-cased, and the
/// pieces are concatenated. The remaining characters of each piece are left
/// untouched. Splitting on all separators, not just `_`, keeps names such as
/// `my-tool` from leaking punctuation into generated type names.
///
/// Empty pieces (from leading, trailing or repeated separators) contribute
/// nothing, and an empty input yields an empty string.
fn to_pascal_case(s: &str) -> String {
    s.split(|c: char| !c.is_alphanumeric())
        .map(|word| {
            let mut chars = word.chars();
            match chars.next() {
                None => String::new(),
                // `to_uppercase` may yield more than one char, hence the chain.
                Some(first) => first.to_uppercase().chain(chars).collect(),
            }
        })
        .collect()
}
142
143fn rust_type_from_simple(ty: &SimpleType) -> &'static str {
144    match ty {
145        SimpleType::String => "String",
146        SimpleType::Integer => "i64",
147        SimpleType::Float => "f64",
148        SimpleType::Boolean => "bool",
149        SimpleType::Array => "Vec<serde_json::Value>",
150        SimpleType::Object => "serde_json::Value",
151    }
152}
153
/// Render a slice of strings as the comma-separated element list of a
/// `vec![...]` of `String`s, e.g. `"foo".to_string(), "bar".to_string()`.
///
/// Each element is written as an escaped Rust string literal (via `Debug`
/// formatting), so values containing `"`, `\` or control characters still
/// produce valid generated code. An empty slice yields an empty string.
fn format_string_vec(vec: &[String]) -> String {
    vec.iter()
        .map(|s| format!("{:?}.to_string()", s))
        .collect::<Vec<_>>()
        .join(", ")
}
160
#[cfg(test)]
mod tests {
    use super::*;
    use pforge_config::*;
    use std::collections::HashMap;

    /// Wrap `tools` in an otherwise minimal `ForgeConfig`: stdio transport,
    /// debug optimization, and no resources, prompts or state.
    fn config_with_tools(tools: Vec<ToolDef>) -> ForgeConfig {
        ForgeConfig {
            forge: ForgeMetadata {
                name: "test".to_string(),
                version: "1.0.0".to_string(),
                transport: TransportType::Stdio,
                optimization: OptimizationLevel::Debug,
            },
            tools,
            resources: vec![],
            prompts: vec![],
            state: None,
        }
    }

    /// Run handler-registration codegen for `tools`, panicking if it fails,
    /// and return the generated source.
    fn registration_code(tools: Vec<ToolDef>) -> String {
        generate_handler_registration(&config_with_tools(tools)).unwrap()
    }

    #[test]
    fn test_to_pascal_case() {
        let cases = [
            ("hello_world", "HelloWorld"),
            ("test", "Test"),
            ("foo_bar_baz", "FooBarBaz"),
        ];
        for (input, expected) in cases.iter() {
            assert_eq!(to_pascal_case(input), *expected);
        }
    }

    #[test]
    fn test_rust_type_from_simple() {
        let cases = [
            (SimpleType::String, "String"),
            (SimpleType::Integer, "i64"),
            (SimpleType::Float, "f64"),
            (SimpleType::Boolean, "bool"),
            (SimpleType::Array, "Vec<serde_json::Value>"),
            (SimpleType::Object, "serde_json::Value"),
        ];
        for (ty, expected) in cases.iter() {
            assert_eq!(rust_type_from_simple(ty), *expected);
        }
    }

    #[test]
    fn test_format_string_vec() {
        let two = vec!["foo".to_string(), "bar".to_string()];
        assert_eq!(
            format_string_vec(&two),
            "\"foo\".to_string(), \"bar\".to_string()"
        );

        let none: Vec<String> = Vec::new();
        assert_eq!(format_string_vec(&none), "");
    }

    #[test]
    fn test_generate_param_struct_simple() {
        let params = ParamSchema {
            fields: vec![
                ("name".to_string(), ParamType::Simple(SimpleType::String)),
                ("age".to_string(), ParamType::Simple(SimpleType::Integer)),
            ]
            .into_iter()
            .collect(),
        };

        let code = generate_param_struct("test_tool", &params).unwrap();

        assert!(code.contains("pub struct TestToolParams"));
        assert!(code.contains("pub name: String"));
        assert!(code.contains("pub age: i64"));
    }

    #[test]
    fn test_generate_param_struct_complex() {
        let optional = ParamType::Complex {
            ty: SimpleType::String,
            required: false,
            description: Some("An optional field".to_string()),
            default: None,
            validation: None,
        };
        let params = ParamSchema {
            fields: vec![("optional_field".to_string(), optional)]
                .into_iter()
                .collect(),
        };

        let code = generate_param_struct("my_tool", &params).unwrap();

        assert!(code.contains("/// An optional field"));
        assert!(code.contains("pub optional_field: Option<String>"));
    }

    #[test]
    fn test_generate_handler_registration_native() {
        let native = ToolDef::Native {
            name: "test_tool".to_string(),
            description: "Test".to_string(),
            handler: HandlerRef {
                path: "handlers::test_handler".to_string(),
                inline: None,
            },
            params: ParamSchema {
                fields: HashMap::new(),
            },
            timeout_ms: None,
        };

        let code = registration_code(vec![native]);

        assert!(code.contains("pub fn register_handlers"));
        assert!(code.contains("registry.register(\"test_tool\", handlers::test_handler)"));
    }

    #[test]
    fn test_generate_handler_registration_cli() {
        let cli = ToolDef::Cli {
            name: "cli_tool".to_string(),
            description: "CLI Test".to_string(),
            command: "echo".to_string(),
            args: vec!["hello".to_string()],
            cwd: None,
            env: HashMap::new(),
            stream: false,
        };

        let code = registration_code(vec![cli]);

        assert!(code.contains("CliHandler::new"));
        assert!(code.contains("\"echo\""));
        assert!(code.contains("\"hello\""));
    }

    #[test]
    fn test_generate_handler_registration_http() {
        let http = ToolDef::Http {
            name: "http_tool".to_string(),
            description: "HTTP Test".to_string(),
            endpoint: "https://api.example.com".to_string(),
            method: HttpMethod::Get,
            headers: HashMap::new(),
            auth: None,
        };

        let code = registration_code(vec![http]);

        assert!(code.contains("HttpHandler::new"));
        assert!(code.contains("https://api.example.com"));
        assert!(code.contains("HttpMethod::Get"));
    }
}