1use pforge_config::{ForgeConfig, ParamSchema, ParamType, SimpleType};
2use std::path::PathBuf;
3use thiserror::Error;
4
5#[derive(Debug, Error)]
6pub enum CodegenError {
7 #[error("IO error: {0}: {1}")]
8 IoError(PathBuf, #[source] std::io::Error),
9
10 #[error("Generation error: {0}")]
11 GenerationError(String),
12}
13
14pub type Result<T> = std::result::Result<T, CodegenError>;
15
16pub fn generate_param_struct(tool_name: &str, params: &ParamSchema) -> Result<String> {
18 let struct_name = format!("{}Params", to_pascal_case(tool_name));
19 let mut output = String::new();
20
21 output.push_str(&format!("#[derive(Debug, Deserialize, JsonSchema)]\n"));
23 output.push_str(&format!("pub struct {} {{\n", struct_name));
24
25 for (field_name, param_type) in ¶ms.fields {
26 let (ty, required, description) = match param_type {
28 ParamType::Simple(simple_ty) => {
29 (rust_type_from_simple(simple_ty), true, None)
30 }
31 ParamType::Complex {
32 ty,
33 required,
34 description,
35 ..
36 } => (rust_type_from_simple(ty), *required, description.clone()),
37 };
38
39 if let Some(desc) = description {
41 output.push_str(&format!(" /// {}\n", desc));
42 }
43
44 if required {
46 output.push_str(&format!(" pub {}: {},\n", field_name, ty));
47 } else {
48 output.push_str(&format!(" pub {}: Option<{}>,\n", field_name, ty));
49 }
50 }
51
52 output.push_str("}\n");
53
54 Ok(output)
55}
56
57pub fn generate_handler_registration(config: &ForgeConfig) -> Result<String> {
59 let mut output = String::new();
60
61 output.push_str("pub fn register_handlers(registry: &mut HandlerRegistry) {\n");
62
63 for tool in &config.tools {
64 match tool {
65 pforge_config::ToolDef::Native { name, handler, .. } => {
66 let handler_path = &handler.path;
68 output.push_str(&format!(
69 " registry.register(\"{}\", {});\n",
70 name, handler_path
71 ));
72 }
73 pforge_config::ToolDef::Cli {
74 name,
75 command,
76 args,
77 cwd,
78 env,
79 stream,
80 description: _,
81 } => {
82 output.push_str(&format!(
83 " registry.register(\"{}\", CliHandler::new(\n",
84 name
85 ));
86 output.push_str(&format!(" \"{}\".to_string(),\n", command));
87 output.push_str(&format!(" vec![{}],\n", format_string_vec(args)));
88
89 if let Some(cwd_val) = cwd {
90 output.push_str(&format!(
91 " Some(\"{}\".to_string()),\n",
92 cwd_val
93 ));
94 } else {
95 output.push_str(" None,\n");
96 }
97
98 output.push_str(&format!(
99 " HashMap::new(), // env\n"
100 ));
101 output.push_str(" None, // timeout\n");
102 output.push_str(&format!(" {},\n", stream));
103 output.push_str(" ));\n");
104 }
105 pforge_config::ToolDef::Http {
106 name,
107 endpoint,
108 method,
109 headers,
110 auth,
111 description: _,
112 } => {
113 output.push_str(&format!(
114 " registry.register(\"{}\", HttpHandler::new(\n",
115 name
116 ));
117 output.push_str(&format!(" \"{}\".to_string(),\n", endpoint));
118 output.push_str(&format!(" HttpMethod::{:?},\n", method));
119 output.push_str(" HashMap::new(), // headers\n");
120 output.push_str(" None, // auth\n");
121 output.push_str(" ));\n");
122 }
123 pforge_config::ToolDef::Pipeline { name, steps, description: _ } => {
124 output.push_str(" // Pipeline handler TBD\n");
125 }
126 }
127 }
128
129 output.push_str("}\n");
130
131 Ok(output)
132}
133
/// Converts a tool name such as `fetch_url` or `fetch-url` to PascalCase (`FetchUrl`).
///
/// Any non-alphanumeric character is treated as a word separator, not just `_`. Names like
/// `my-tool` or `my.tool` therefore still yield a valid identifier fragment. Empty segments
/// (leading, trailing or repeated separators) are dropped. Only the first character of each
/// word is uppercased; the rest is kept as-is.
fn to_pascal_case(s: &str) -> String {
    s.split(|c: char| !c.is_alphanumeric())
        .map(|word| {
            let mut chars = word.chars();
            match chars.next() {
                None => String::new(),
                Some(first) => first.to_uppercase().chain(chars).collect(),
            }
        })
        .collect()
}
145
146fn rust_type_from_simple(ty: &SimpleType) -> &'static str {
147 match ty {
148 SimpleType::String => "String",
149 SimpleType::Integer => "i64",
150 SimpleType::Float => "f64",
151 SimpleType::Boolean => "bool",
152 SimpleType::Array => "Vec<serde_json::Value>",
153 SimpleType::Object => "serde_json::Value",
154 }
155}
156
/// Formats strings as a comma-separated list of `"...".to_string()` Rust expressions,
/// suitable for splicing into a generated `vec![...]`.
///
/// Each string is emitted via its `Debug` representation, which is a valid, escaped Rust
/// string literal. Quotes, backslashes and control characters therefore cannot break the
/// generated code. An empty slice yields an empty string.
fn format_string_vec(vec: &[String]) -> String {
    vec.iter()
        .map(|s| format!("{:?}.to_string()", s))
        .collect::<Vec<_>>()
        .join(", ")
}
163
#[cfg(test)]
mod tests {
    use super::*;
    use pforge_config::*;
    use std::collections::HashMap;

    /// Builds a minimal stdio/debug `ForgeConfig` containing only `tools`.
    fn config_with_tools(tools: Vec<ToolDef>) -> ForgeConfig {
        ForgeConfig {
            forge: ForgeMetadata {
                name: "test".to_string(),
                version: "1.0.0".to_string(),
                transport: TransportType::Stdio,
                optimization: OptimizationLevel::Debug,
            },
            tools,
            resources: vec![],
            prompts: vec![],
            state: None,
        }
    }

    #[test]
    fn test_to_pascal_case() {
        assert_eq!(to_pascal_case("hello_world"), "HelloWorld");
        assert_eq!(to_pascal_case("test"), "Test");
        assert_eq!(to_pascal_case("foo_bar_baz"), "FooBarBaz");
    }

    #[test]
    fn test_rust_type_from_simple() {
        assert_eq!(rust_type_from_simple(&SimpleType::String), "String");
        assert_eq!(rust_type_from_simple(&SimpleType::Integer), "i64");
        assert_eq!(rust_type_from_simple(&SimpleType::Float), "f64");
        assert_eq!(rust_type_from_simple(&SimpleType::Boolean), "bool");
        assert_eq!(rust_type_from_simple(&SimpleType::Array), "Vec<serde_json::Value>");
        assert_eq!(rust_type_from_simple(&SimpleType::Object), "serde_json::Value");
    }

    #[test]
    fn test_format_string_vec() {
        assert_eq!(
            format_string_vec(&["foo".to_string(), "bar".to_string()]),
            "\"foo\".to_string(), \"bar\".to_string()"
        );
        assert_eq!(format_string_vec(&[]), "");
    }

    #[test]
    fn test_generate_param_struct_simple() {
        let mut fields = HashMap::new();
        fields.insert("name".to_string(), ParamType::Simple(SimpleType::String));
        fields.insert("age".to_string(), ParamType::Simple(SimpleType::Integer));
        let params = ParamSchema { fields };

        let code = generate_param_struct("test_tool", &params)
            .expect("generating a simple param struct should succeed");

        assert!(code.contains("pub struct TestToolParams"));
        assert!(code.contains("pub name: String"));
        assert!(code.contains("pub age: i64"));
    }

    #[test]
    fn test_generate_param_struct_complex() {
        let mut fields = HashMap::new();
        fields.insert(
            "optional_field".to_string(),
            ParamType::Complex {
                ty: SimpleType::String,
                required: false,
                description: Some("An optional field".to_string()),
                default: None,
                validation: None,
            },
        );
        let params = ParamSchema { fields };

        let code = generate_param_struct("my_tool", &params)
            .expect("generating a complex param struct should succeed");

        assert!(code.contains("/// An optional field"));
        assert!(code.contains("pub optional_field: Option<String>"));
    }

    #[test]
    fn test_generate_handler_registration_native() {
        let config = config_with_tools(vec![ToolDef::Native {
            name: "test_tool".to_string(),
            description: "Test".to_string(),
            handler: HandlerRef {
                path: "handlers::test_handler".to_string(),
                inline: None,
            },
            params: ParamSchema {
                fields: HashMap::new(),
            },
            timeout_ms: None,
        }]);

        let code = generate_handler_registration(&config)
            .expect("native registration should generate");

        assert!(code.contains("pub fn register_handlers"));
        assert!(code.contains("registry.register(\"test_tool\", handlers::test_handler)"));
    }

    #[test]
    fn test_generate_handler_registration_cli() {
        let config = config_with_tools(vec![ToolDef::Cli {
            name: "cli_tool".to_string(),
            description: "CLI Test".to_string(),
            command: "echo".to_string(),
            args: vec!["hello".to_string()],
            cwd: None,
            env: HashMap::new(),
            stream: false,
        }]);

        let code =
            generate_handler_registration(&config).expect("CLI registration should generate");

        assert!(code.contains("CliHandler::new"));
        assert!(code.contains("\"echo\""));
        assert!(code.contains("\"hello\""));
    }

    #[test]
    fn test_generate_handler_registration_http() {
        let config = config_with_tools(vec![ToolDef::Http {
            name: "http_tool".to_string(),
            description: "HTTP Test".to_string(),
            endpoint: "https://api.example.com".to_string(),
            method: HttpMethod::Get,
            headers: HashMap::new(),
            auth: None,
        }]);

        let code =
            generate_handler_registration(&config).expect("HTTP registration should generate");

        assert!(code.contains("HttpHandler::new"));
        assert!(code.contains("https://api.example.com"));
        assert!(code.contains("HttpMethod::Get"));
    }
}