// pforge_runtime/server.rs

1use crate::{Error, HandlerRegistry, Result};
2use async_trait::async_trait;
3use pforge_config::ForgeConfig;
4use pmcp::server::ToolHandler;
5use serde_json::Value;
6use std::sync::Arc;
7use tokio::sync::RwLock;
8
9/// MCP Server implementation
10pub struct McpServer {
11    config: ForgeConfig,
12    registry: Arc<RwLock<HandlerRegistry>>,
13}
14
15/// Adapter to wrap pforge handlers as pmcp ToolHandler
16struct PforgeToolAdapter {
17    registry: Arc<RwLock<HandlerRegistry>>,
18    tool_name: String,
19    description: Option<String>,
20}
21
22#[async_trait]
23impl ToolHandler for PforgeToolAdapter {
24    async fn handle(
25        &self,
26        args: Value,
27        _extra: pmcp::server::cancellation::RequestHandlerExtra,
28    ) -> pmcp::Result<Value> {
29        // Serialize args to bytes for pforge dispatch
30        let params = serde_json::to_vec(&args)
31            .map_err(|e| pmcp::Error::protocol_msg(format!("Failed to serialize args: {}", e)))?;
32
33        let registry = self.registry.read().await;
34        let result_bytes = registry
35            .dispatch(&self.tool_name, &params)
36            .await
37            .map_err(|e| pmcp::Error::protocol_msg(e.to_string()))?;
38
39        // Deserialize result back to Value
40        let result: Value = serde_json::from_slice(&result_bytes).map_err(|e| {
41            pmcp::Error::protocol_msg(format!("Failed to deserialize result: {}", e))
42        })?;
43
44        Ok(result)
45    }
46
47    fn metadata(&self) -> Option<pmcp::types::ToolInfo> {
48        // Try to get actual schema from registry (may fail if lock is held)
49        let input_schema = if let Ok(guard) = self.registry.try_read() {
50            if let Some(schema) = guard.get_input_schema(&self.tool_name) {
51                // Convert RootSchema to serde_json::Value
52                serde_json::to_value(&schema).unwrap_or_else(|_| {
53                    serde_json::json!({
54                        "type": "object",
55                        "properties": {}
56                    })
57                })
58            } else {
59                serde_json::json!({
60                    "type": "object",
61                    "properties": {}
62                })
63            }
64        } else {
65            // Fallback if lock unavailable
66            serde_json::json!({
67                "type": "object",
68                "properties": {}
69            })
70        };
71
72        Some(pmcp::types::ToolInfo::new(
73            self.tool_name.clone(),
74            self.description.clone(),
75            input_schema,
76        ))
77    }
78}
79
80impl McpServer {
81    /// Create a new MCP server from configuration
82    pub fn new(config: ForgeConfig) -> Self {
83        Self {
84            config,
85            registry: Arc::new(RwLock::new(HandlerRegistry::new())),
86        }
87    }
88
89    /// Register all handlers from configuration
90    pub async fn register_handlers(&self) -> Result<()> {
91        let mut registry = self.registry.write().await;
92
93        for tool in &self.config.tools {
94            match tool {
95                pforge_config::ToolDef::Native { name, .. } => {
96                    // Native handlers will be registered by generated code
97                    eprintln!(
98                        "Note: Native handler '{}' requires handler implementation",
99                        name
100                    );
101                }
102                pforge_config::ToolDef::Cli {
103                    name,
104                    command,
105                    args,
106                    cwd,
107                    env,
108                    stream,
109                    timeout_ms,
110                    ..
111                } => {
112                    use crate::handlers::cli::CliHandler;
113                    let handler = CliHandler::new(
114                        command.clone(),
115                        args.clone(),
116                        cwd.clone(),
117                        env.clone(),
118                        *timeout_ms,
119                        *stream,
120                    );
121                    registry.register(name, handler);
122                    eprintln!("Registered CLI handler: {}", name);
123                }
124                pforge_config::ToolDef::Http {
125                    name,
126                    endpoint,
127                    method,
128                    headers,
129                    auth,
130                    timeout_ms,
131                    ..
132                } => {
133                    use crate::handlers::http::{
134                        AuthConfig as HttpAuthConfig, HttpHandler, HttpMethod as HandlerHttpMethod,
135                    };
136
137                    let handler_method = match method {
138                        pforge_config::HttpMethod::Get => HandlerHttpMethod::Get,
139                        pforge_config::HttpMethod::Post => HandlerHttpMethod::Post,
140                        pforge_config::HttpMethod::Put => HandlerHttpMethod::Put,
141                        pforge_config::HttpMethod::Delete => HandlerHttpMethod::Delete,
142                        pforge_config::HttpMethod::Patch => HandlerHttpMethod::Patch,
143                    };
144
145                    let handler_auth = auth.as_ref().map(|a| match a {
146                        pforge_config::AuthConfig::Bearer { token } => HttpAuthConfig::Bearer {
147                            token: token.clone(),
148                        },
149                        pforge_config::AuthConfig::Basic { username, password } => {
150                            HttpAuthConfig::Basic {
151                                username: username.clone(),
152                                password: password.clone(),
153                            }
154                        }
155                        pforge_config::AuthConfig::ApiKey { key, header } => {
156                            HttpAuthConfig::ApiKey {
157                                key: key.clone(),
158                                header: header.clone(),
159                            }
160                        }
161                    });
162
163                    let handler = HttpHandler::new(
164                        endpoint.clone(),
165                        handler_method,
166                        headers.clone(),
167                        handler_auth,
168                        *timeout_ms,
169                    );
170                    registry.register(name, handler);
171                    eprintln!("Registered HTTP handler: {}", name);
172                }
173                pforge_config::ToolDef::Pipeline { name, steps, .. } => {
174                    use crate::handlers::pipeline::PipelineHandlerAdapter;
175                    let handler =
176                        PipelineHandlerAdapter::from_config_steps(steps, self.registry.clone());
177                    registry.register(name, handler);
178                    eprintln!("Registered Pipeline handler: {}", name);
179                }
180            }
181        }
182
183        Ok(())
184    }
185
186    /// Run the MCP server using pmcp protocol implementation
187    pub async fn run(&self) -> Result<()> {
188        eprintln!(
189            "Starting MCP server: {} v{}",
190            self.config.forge.name, self.config.forge.version
191        );
192        eprintln!("Transport: {:?}", self.config.forge.transport);
193        eprintln!("Tools registered: {}", self.config.tools.len());
194
195        // Register handlers in pforge registry
196        self.register_handlers().await?;
197
198        // Build pmcp server with tool adapters
199        let mut builder = pmcp::Server::builder()
200            .name(&self.config.forge.name)
201            .version(&self.config.forge.version);
202
203        // Add tool adapters for each registered tool
204        for tool in &self.config.tools {
205            let (tool_name, description) = match tool {
206                pforge_config::ToolDef::Native {
207                    name, description, ..
208                } => (name.clone(), Some(description.clone())),
209                pforge_config::ToolDef::Cli {
210                    name, description, ..
211                } => (name.clone(), Some(description.clone())),
212                pforge_config::ToolDef::Http {
213                    name, description, ..
214                } => (name.clone(), Some(description.clone())),
215                pforge_config::ToolDef::Pipeline {
216                    name, description, ..
217                } => (name.clone(), Some(description.clone())),
218            };
219
220            let adapter = PforgeToolAdapter {
221                registry: self.registry.clone(),
222                tool_name: tool_name.clone(),
223                description,
224            };
225            builder = builder.tool(&tool_name, adapter);
226        }
227
228        let server = builder
229            .build()
230            .map_err(|e| Error::Handler(format!("Failed to build MCP server: {}", e)))?;
231
232        eprintln!("MCP server ready, starting protocol loop...");
233
234        // Run the server with appropriate transport
235        match self.config.forge.transport {
236            pforge_config::TransportType::Stdio => {
237                server
238                    .run_stdio()
239                    .await
240                    .map_err(|e| Error::Handler(format!("MCP server error: {}", e)))?;
241            }
242            pforge_config::TransportType::Sse => {
243                use pmcp::shared::{OptimizedSseConfig, OptimizedSseTransport};
244                use std::time::Duration;
245
246                let config = OptimizedSseConfig {
247                    url: "http://localhost:8080/sse".to_string(),
248                    connection_timeout: Duration::from_secs(30),
249                    keepalive_interval: Duration::from_secs(15),
250                    max_reconnects: 5,
251                    reconnect_delay: Duration::from_secs(1),
252                    buffer_size: 100,
253                    flush_interval: Duration::from_millis(100),
254                    enable_pooling: true,
255                    max_connections: 10,
256                    enable_compression: false,
257                };
258                let transport = OptimizedSseTransport::new(config);
259                server
260                    .run(transport)
261                    .await
262                    .map_err(|e| Error::Handler(format!("MCP server error: {}", e)))?;
263            }
264            pforge_config::TransportType::WebSocket => {
265                use pmcp::shared::{WebSocketConfig, WebSocketTransport};
266                use std::time::Duration;
267
268                let url = "ws://localhost:8080/ws"
269                    .parse()
270                    .map_err(|e| Error::Handler(format!("Invalid WebSocket URL: {}", e)))?;
271                let config = WebSocketConfig {
272                    url,
273                    auto_reconnect: true,
274                    reconnect_delay: Duration::from_secs(1),
275                    max_reconnect_delay: Duration::from_secs(30),
276                    max_reconnect_attempts: Some(5),
277                    ping_interval: Some(Duration::from_secs(30)),
278                    request_timeout: Duration::from_secs(10),
279                };
280                let transport = WebSocketTransport::new(config);
281                server
282                    .run(transport)
283                    .await
284                    .map_err(|e| Error::Handler(format!("MCP server error: {}", e)))?;
285            }
286        }
287
288        eprintln!("\nShutting down...");
289        Ok(())
290    }
291
292    /// Get the handler registry (for testing)
293    pub fn registry(&self) -> Arc<RwLock<HandlerRegistry>> {
294        self.registry.clone()
295    }
296}
297
#[cfg(test)]
mod tests {
    use super::*;
    use pforge_config::{ForgeMetadata, ParamSchema, ToolDef, TransportType};

    /// Minimal stdio config with no tools, resources, prompts or state.
    fn create_test_config() -> ForgeConfig {
        ForgeConfig {
            forge: ForgeMetadata {
                name: "test-server".to_string(),
                version: "0.1.0".to_string(),
                transport: TransportType::Stdio,
                optimization: pforge_config::OptimizationLevel::Debug,
            },
            tools: vec![],
            resources: vec![],
            prompts: vec![],
            state: None,
        }
    }

    #[test]
    fn test_server_new() {
        let config = create_test_config();
        let server = McpServer::new(config);

        assert_eq!(server.config.forge.name, "test-server");
        assert_eq!(server.config.forge.version, "0.1.0");
    }

    #[tokio::test]
    async fn test_register_handlers_cli() {
        let mut config = create_test_config();
        config.tools.push(ToolDef::Cli {
            name: "test_cli".to_string(),
            description: "Test CLI handler".to_string(),
            command: "echo".to_string(),
            args: vec!["hello".to_string()],
            cwd: None,
            env: rustc_hash::FxHashMap::default(),
            stream: false,
            timeout_ms: None,
        });

        let server = McpServer::new(config);
        let result = server.register_handlers().await;
        assert!(result.is_ok());

        let registry = server.registry();
        let reg = registry.read().await;
        assert_eq!(reg.len(), 1, "CLI handler should be registered");
    }

    #[tokio::test]
    async fn test_register_handlers_http() {
        let mut config = create_test_config();
        config.tools.push(ToolDef::Http {
            name: "test_http".to_string(),
            description: "Test HTTP handler".to_string(),
            endpoint: "https://api.example.com".to_string(),
            method: pforge_config::HttpMethod::Get,
            headers: rustc_hash::FxHashMap::default(),
            auth: None,
            timeout_ms: None,
        });

        let server = McpServer::new(config);
        let result = server.register_handlers().await;
        assert!(result.is_ok());

        let registry = server.registry();
        let reg = registry.read().await;
        assert_eq!(reg.len(), 1, "HTTP handler should be registered");
    }

    #[tokio::test]
    async fn test_register_handlers_native() {
        let mut config = create_test_config();
        config.tools.push(ToolDef::Native {
            name: "test_native".to_string(),
            description: "Test native handler".to_string(),
            handler: pforge_config::HandlerRef {
                path: "handlers::test::TestHandler".to_string(),
                inline: None,
            },
            params: ParamSchema {
                fields: rustc_hash::FxHashMap::default(),
            },
            timeout_ms: Some(5000),
        });

        let server = McpServer::new(config);
        let result = server.register_handlers().await;

        // Succeeds, but registers nothing: native handlers come from generated code.
        assert!(result.is_ok());
        let registry = server.registry();
        let reg = registry.read().await;
        assert_eq!(
            reg.len(),
            0,
            "register_handlers must not register native handlers itself"
        );
    }

    #[tokio::test]
    async fn test_registry_access() {
        let server = McpServer::new(create_test_config());

        let registry = server.registry();
        let reg = registry.read().await;

        // A freshly constructed server starts with an empty registry.
        assert_eq!(reg.len(), 0, "New server should start with no handlers");
    }

    #[tokio::test]
    async fn test_registry_returns_actual_registry() {
        // This test catches mutation: registry() returning a new empty registry
        let mut config = create_test_config();
        config.tools.push(ToolDef::Cli {
            name: "test_cli".to_string(),
            description: "Test CLI".to_string(),
            command: "echo".to_string(),
            args: vec!["test".to_string()],
            cwd: None,
            env: rustc_hash::FxHashMap::default(),
            stream: false,
            timeout_ms: None,
        });

        let server = McpServer::new(config);
        server.register_handlers().await.unwrap();

        // Get registry and verify the handler is registered
        let registry = server.registry();
        let reg = registry.read().await;

        // The CLI handler should be registered - verify via len
        assert_eq!(reg.len(), 1, "Registry should contain registered handler");
    }

    #[tokio::test]
    async fn test_register_handlers_pipeline() {
        let mut config = create_test_config();
        config.tools.push(ToolDef::Pipeline {
            name: "test_pipeline".to_string(),
            description: "Test pipeline handler".to_string(),
            steps: vec![],
        });

        let server = McpServer::new(config);
        let result = server.register_handlers().await;
        assert!(result.is_ok());

        // Verify the pipeline is actually registered
        let registry = server.registry();
        let reg = registry.read().await;
        assert_eq!(reg.len(), 1, "Pipeline handler should be registered");
    }

    #[tokio::test]
    async fn test_server_with_multiple_tools() {
        let mut config = create_test_config();

        config.tools.push(ToolDef::Cli {
            name: "cli1".to_string(),
            description: "CLI 1".to_string(),
            command: "echo".to_string(),
            args: vec![],
            cwd: None,
            env: rustc_hash::FxHashMap::default(),
            stream: false,
            timeout_ms: None,
        });

        config.tools.push(ToolDef::Http {
            name: "http1".to_string(),
            description: "HTTP 1".to_string(),
            endpoint: "https://example.com".to_string(),
            method: pforge_config::HttpMethod::Get,
            headers: rustc_hash::FxHashMap::default(),
            auth: None,
            timeout_ms: None,
        });

        let server = McpServer::new(config);
        assert_eq!(server.config.tools.len(), 2);

        let result = server.register_handlers().await;
        assert!(result.is_ok());

        let registry = server.registry();
        let reg = registry.read().await;
        assert_eq!(reg.len(), 2, "Both CLI and HTTP handlers should be registered");
    }
}