pforge_codegen/
lib.rs

1pub mod generator;
2
3pub use generator::{generate_handler_registration, generate_param_struct, CodegenError};
4
5use pforge_config::ForgeConfig;
6use std::path::Path;
7
8pub type Result<T> = std::result::Result<T, CodegenError>;
9
10/// Generate all code from a ForgeConfig
11pub fn generate_all(config: &ForgeConfig) -> Result<String> {
12    let mut output = String::new();
13
14    // Generate imports
15    output.push_str("// Auto-generated by pforge\n");
16    output.push_str("// DO NOT EDIT\n\n");
17    output.push_str("use pforge_runtime::*;\n");
18    output.push_str("use serde::{Deserialize, Serialize};\n");
19    output.push_str("use schemars::JsonSchema;\n\n");
20
21    // Generate parameter structs for each tool
22    for tool in &config.tools {
23        if let pforge_config::ToolDef::Native { name, params, .. } = tool {
24            output.push_str(&generate_param_struct(name, params)?);
25            output.push_str("\n\n");
26        }
27    }
28
29    // Generate handler registration function
30    output.push_str(&generate_handler_registration(config)?);
31
32    Ok(output)
33}
34
35/// Write generated code to a file
36pub fn write_generated_code(config: &ForgeConfig, output_path: &Path) -> Result<()> {
37    let code = generate_all(config)?;
38    std::fs::write(output_path, code)
39        .map_err(|e| CodegenError::IoError(output_path.to_path_buf(), e))?;
40    Ok(())
41}
42
#[cfg(test)]
mod tests {
    use super::*;
    use pforge_config::*;
    use std::collections::HashMap;
    use std::path::PathBuf;

    /// Builds a minimal config: server `test_server` with a single native tool
    /// `test_tool` that takes one `input: String` parameter.
    fn create_test_config() -> ForgeConfig {
        ForgeConfig {
            forge: ForgeMetadata {
                name: "test_server".to_string(),
                version: "1.0.0".to_string(),
                transport: TransportType::Stdio,
                optimization: OptimizationLevel::Debug,
            },
            tools: vec![ToolDef::Native {
                name: "test_tool".to_string(),
                description: "Test tool".to_string(),
                handler: HandlerRef {
                    path: "handlers::test_handler".to_string(),
                    inline: None,
                },
                params: ParamSchema {
                    fields: {
                        let mut map = HashMap::new();
                        map.insert("input".to_string(), ParamType::Simple(SimpleType::String));
                        map
                    },
                },
                timeout_ms: None,
            }],
            resources: vec![],
            prompts: vec![],
            state: None,
        }
    }

    /// Owns a temporary file path and deletes the file when dropped, so cleanup
    /// still happens if an assertion panics first.
    struct TempFile(PathBuf);

    impl Drop for TempFile {
        fn drop(&mut self) {
            // Best-effort: the file may never have been created.
            let _ = std::fs::remove_file(&self.0);
        }
    }

    /// Returns a path under the system temp dir that is unique per process
    /// (via the PID), so concurrent `cargo test` runs cannot clobber each
    /// other's files. Tests within one process pass distinct `file_name`s.
    fn unique_temp_path(file_name: &str) -> PathBuf {
        std::env::temp_dir().join(format!(
            "pforge_codegen_{}_{}",
            std::process::id(),
            file_name
        ))
    }

    #[test]
    fn test_generate_all() {
        let config = create_test_config();
        let code = generate_all(&config).expect("code generation should succeed");

        // Generated header
        assert!(code.contains("// Auto-generated by pforge"));
        assert!(code.contains("// DO NOT EDIT"));

        // Imports
        assert!(code.contains("use pforge_runtime::*"));
        assert!(code.contains("use serde::{Deserialize, Serialize}"));
        assert!(code.contains("use schemars::JsonSchema"));

        // Param struct for the single native tool
        assert!(code.contains("pub struct TestToolParams"));

        // Registration function
        assert!(code.contains("pub fn register_handlers"));
    }

    #[test]
    fn test_generate_all_empty_tools() {
        let config = ForgeConfig {
            forge: ForgeMetadata {
                name: "empty".to_string(),
                version: "1.0.0".to_string(),
                transport: TransportType::Stdio,
                optimization: OptimizationLevel::Debug,
            },
            tools: vec![],
            resources: vec![],
            prompts: vec![],
            state: None,
        };

        let code = generate_all(&config).expect("code generation should succeed");
        assert!(code.contains("pub fn register_handlers"));
    }

    #[test]
    fn test_write_generated_code() {
        let config = create_test_config();
        // Guard removes the file at end of scope, even on assertion failure.
        let output = TempFile(unique_temp_path("test_generated.rs"));

        write_generated_code(&config, &output.0).expect("writing generated code should succeed");

        // Verify the file was created with the expected content
        assert!(output.0.exists());
        let content = std::fs::read_to_string(&output.0).expect("generated file should be readable");
        assert!(content.contains("pub struct TestToolParams"));
    }

    #[test]
    fn test_write_generated_code_invalid_path() {
        let config = create_test_config();
        // A directory under the temp dir that is never created, so the write
        // fails on every platform (std::fs::write does not create parents).
        let missing_dir = unique_temp_path("missing_dir");
        assert!(
            !missing_dir.exists(),
            "precondition: {} must not exist",
            missing_dir.display()
        );
        let invalid_path = missing_dir.join("test.rs");

        let result = write_generated_code(&config, &invalid_path);
        assert!(matches!(result, Err(CodegenError::IoError(_, _))));
    }
}