pforge_codegen/
lib.rs

1pub mod generator;
2
3pub use generator::{generate_handler_registration, generate_param_struct, CodegenError};
4
5use pforge_config::ForgeConfig;
6use std::path::Path;
7
8pub type Result<T> = std::result::Result<T, CodegenError>;
9
10/// Generate all code from a ForgeConfig
11pub fn generate_all(config: &ForgeConfig) -> Result<String> {
12    let mut output = String::new();
13
14    // Generate imports
15    output.push_str("// Auto-generated by pforge\n");
16    output.push_str("// DO NOT EDIT\n\n");
17    output.push_str("use pforge_runtime::*;\n");
18    output.push_str("use serde::{Deserialize, Serialize};\n");
19    output.push_str("use schemars::JsonSchema;\n\n");
20
21    // Generate parameter structs for each tool
22    for tool in &config.tools {
23        if let pforge_config::ToolDef::Native { name, params, .. } = tool {
24            output.push_str(&generate_param_struct(name, params)?);
25            output.push_str("\n\n");
26        }
27    }
28
29    // Generate handler registration function
30    output.push_str(&generate_handler_registration(config)?);
31
32    Ok(output)
33}
34
35/// Write generated code to a file
36pub fn write_generated_code(config: &ForgeConfig, output_path: &Path) -> Result<()> {
37    let code = generate_all(config)?;
38    std::fs::write(output_path, code)
39        .map_err(|e| CodegenError::IoError(output_path.to_path_buf(), e))?;
40    Ok(())
41}
42
#[cfg(test)]
mod tests {
    use super::*;
    use pforge_config::*;
    use rustc_hash::FxHashMap;

    /// Build a config with a single native tool `test_tool` that takes one
    /// string parameter named `input`.
    fn create_test_config() -> ForgeConfig {
        ForgeConfig {
            forge: ForgeMetadata {
                name: "test_server".to_string(),
                version: "1.0.0".to_string(),
                transport: TransportType::Stdio,
                optimization: OptimizationLevel::Debug,
            },
            tools: vec![ToolDef::Native {
                name: "test_tool".to_string(),
                description: "Test tool".to_string(),
                handler: HandlerRef {
                    path: "handlers::test_handler".to_string(),
                    inline: None,
                },
                params: ParamSchema {
                    fields: {
                        let mut map = FxHashMap::default();
                        map.insert("input".to_string(), ParamType::Simple(SimpleType::String));
                        map
                    },
                },
                timeout_ms: None,
            }],
            resources: vec![],
            prompts: vec![],
            state: None,
        }
    }

    /// Return a path in the system temp directory whose name includes this
    /// process's id. Concurrent `cargo test` runs therefore cannot clobber
    /// each other's files.
    fn unique_temp_path(file_name: &str) -> std::path::PathBuf {
        std::env::temp_dir().join(format!(
            "pforge_codegen_{}_{}",
            std::process::id(),
            file_name
        ))
    }

    #[test]
    fn test_generate_all() {
        let config = create_test_config();
        // `expect` surfaces the actual CodegenError if generation fails.
        let code = generate_all(&config).expect("code generation should succeed");

        // Check for generated header
        assert!(code.contains("// Auto-generated by pforge"));
        assert!(code.contains("// DO NOT EDIT"));

        // Check for imports
        assert!(code.contains("use pforge_runtime::*"));
        assert!(code.contains("use serde::{Deserialize, Serialize}"));
        assert!(code.contains("use schemars::JsonSchema"));

        // Check for param struct
        assert!(code.contains("pub struct TestToolParams"));

        // Check for registration function
        assert!(code.contains("pub fn register_handlers"));
    }

    #[test]
    fn test_generate_all_empty_tools() {
        let config = ForgeConfig {
            forge: ForgeMetadata {
                name: "empty".to_string(),
                version: "1.0.0".to_string(),
                transport: TransportType::Stdio,
                optimization: OptimizationLevel::Debug,
            },
            tools: vec![],
            resources: vec![],
            prompts: vec![],
            state: None,
        };

        let code = generate_all(&config).expect("code generation should succeed with no tools");
        assert!(code.contains("pub fn register_handlers"));
    }

    #[test]
    fn test_write_generated_code() {
        let config = create_test_config();
        let output_path = unique_temp_path("test_generated.rs");

        write_generated_code(&config, &output_path)
            .expect("writing generated code should succeed");

        // Read back, then clean up *before* asserting. A failing assertion
        // then does not leave a stray file behind in the temp directory.
        let content = std::fs::read_to_string(&output_path);
        std::fs::remove_file(&output_path).ok();

        let content = content.expect("generated file should exist and be readable");
        assert!(content.contains("pub struct TestToolParams"));
    }

    #[test]
    fn test_write_generated_code_invalid_path() {
        let config = create_test_config();
        // Target a file inside directories this test never creates. This works
        // on any platform and does not rely on `/nonexistent` being absent.
        let invalid_path = unique_temp_path("missing_dir")
            .join("directory")
            .join("test.rs");

        let err = write_generated_code(&config, &invalid_path)
            .expect_err("writing into a missing directory should fail");
        // The error must identify which path failed.
        assert!(matches!(&err, CodegenError::IoError(p, _) if *p == invalid_path));
    }
}