1use crate::{Error, HandlerRegistry, Result};
2use async_trait::async_trait;
3use pforge_config::ForgeConfig;
4use pmcp::server::ToolHandler;
5use serde_json::Value;
6use std::sync::Arc;
7use tokio::sync::RwLock;
8
9pub struct McpServer {
11 config: ForgeConfig,
12 registry: Arc<RwLock<HandlerRegistry>>,
13}
14
15struct PforgeToolAdapter {
17 registry: Arc<RwLock<HandlerRegistry>>,
18 tool_name: String,
19 description: Option<String>,
20}
21
22#[async_trait]
23impl ToolHandler for PforgeToolAdapter {
24 async fn handle(
25 &self,
26 args: Value,
27 _extra: pmcp::server::cancellation::RequestHandlerExtra,
28 ) -> pmcp::Result<Value> {
29 let params = serde_json::to_vec(&args)
31 .map_err(|e| pmcp::Error::protocol_msg(format!("Failed to serialize args: {}", e)))?;
32
33 let registry = self.registry.read().await;
34 let result_bytes = registry
35 .dispatch(&self.tool_name, ¶ms)
36 .await
37 .map_err(|e| pmcp::Error::protocol_msg(e.to_string()))?;
38
39 let result: Value = serde_json::from_slice(&result_bytes).map_err(|e| {
41 pmcp::Error::protocol_msg(format!("Failed to deserialize result: {}", e))
42 })?;
43
44 Ok(result)
45 }
46
47 fn metadata(&self) -> Option<pmcp::types::ToolInfo> {
48 let input_schema = if let Ok(guard) = self.registry.try_read() {
50 if let Some(schema) = guard.get_input_schema(&self.tool_name) {
51 serde_json::to_value(&schema).unwrap_or_else(|_| {
53 serde_json::json!({
54 "type": "object",
55 "properties": {}
56 })
57 })
58 } else {
59 serde_json::json!({
60 "type": "object",
61 "properties": {}
62 })
63 }
64 } else {
65 serde_json::json!({
67 "type": "object",
68 "properties": {}
69 })
70 };
71
72 Some(pmcp::types::ToolInfo::new(
73 self.tool_name.clone(),
74 self.description.clone(),
75 input_schema,
76 ))
77 }
78}
79
80impl McpServer {
81 pub fn new(config: ForgeConfig) -> Self {
83 Self {
84 config,
85 registry: Arc::new(RwLock::new(HandlerRegistry::new())),
86 }
87 }
88
89 pub async fn register_handlers(&self) -> Result<()> {
91 let mut registry = self.registry.write().await;
92
93 for tool in &self.config.tools {
94 match tool {
95 pforge_config::ToolDef::Native { name, .. } => {
96 eprintln!(
98 "Note: Native handler '{}' requires handler implementation",
99 name
100 );
101 }
102 pforge_config::ToolDef::Cli {
103 name,
104 command,
105 args,
106 cwd,
107 env,
108 stream,
109 timeout_ms,
110 ..
111 } => {
112 use crate::handlers::cli::CliHandler;
113 let handler = CliHandler::new(
114 command.clone(),
115 args.clone(),
116 cwd.clone(),
117 env.clone(),
118 *timeout_ms,
119 *stream,
120 );
121 registry.register(name, handler);
122 eprintln!("Registered CLI handler: {}", name);
123 }
124 pforge_config::ToolDef::Http {
125 name,
126 endpoint,
127 method,
128 headers,
129 auth,
130 timeout_ms,
131 ..
132 } => {
133 use crate::handlers::http::{
134 AuthConfig as HttpAuthConfig, HttpHandler, HttpMethod as HandlerHttpMethod,
135 };
136
137 let handler_method = match method {
138 pforge_config::HttpMethod::Get => HandlerHttpMethod::Get,
139 pforge_config::HttpMethod::Post => HandlerHttpMethod::Post,
140 pforge_config::HttpMethod::Put => HandlerHttpMethod::Put,
141 pforge_config::HttpMethod::Delete => HandlerHttpMethod::Delete,
142 pforge_config::HttpMethod::Patch => HandlerHttpMethod::Patch,
143 };
144
145 let handler_auth = auth.as_ref().map(|a| match a {
146 pforge_config::AuthConfig::Bearer { token } => HttpAuthConfig::Bearer {
147 token: token.clone(),
148 },
149 pforge_config::AuthConfig::Basic { username, password } => {
150 HttpAuthConfig::Basic {
151 username: username.clone(),
152 password: password.clone(),
153 }
154 }
155 pforge_config::AuthConfig::ApiKey { key, header } => {
156 HttpAuthConfig::ApiKey {
157 key: key.clone(),
158 header: header.clone(),
159 }
160 }
161 });
162
163 let handler = HttpHandler::new(
164 endpoint.clone(),
165 handler_method,
166 headers.clone(),
167 handler_auth,
168 *timeout_ms,
169 );
170 registry.register(name, handler);
171 eprintln!("Registered HTTP handler: {}", name);
172 }
173 pforge_config::ToolDef::Pipeline { name, steps, .. } => {
174 use crate::handlers::pipeline::PipelineHandlerAdapter;
175 let handler =
176 PipelineHandlerAdapter::from_config_steps(steps, self.registry.clone());
177 registry.register(name, handler);
178 eprintln!("Registered Pipeline handler: {}", name);
179 }
180 }
181 }
182
183 Ok(())
184 }
185
186 pub async fn run(&self) -> Result<()> {
188 eprintln!(
189 "Starting MCP server: {} v{}",
190 self.config.forge.name, self.config.forge.version
191 );
192 eprintln!("Transport: {:?}", self.config.forge.transport);
193 eprintln!("Tools registered: {}", self.config.tools.len());
194
195 self.register_handlers().await?;
197
198 let mut builder = pmcp::Server::builder()
200 .name(&self.config.forge.name)
201 .version(&self.config.forge.version);
202
203 for tool in &self.config.tools {
205 let (tool_name, description) = match tool {
206 pforge_config::ToolDef::Native {
207 name, description, ..
208 } => (name.clone(), Some(description.clone())),
209 pforge_config::ToolDef::Cli {
210 name, description, ..
211 } => (name.clone(), Some(description.clone())),
212 pforge_config::ToolDef::Http {
213 name, description, ..
214 } => (name.clone(), Some(description.clone())),
215 pforge_config::ToolDef::Pipeline {
216 name, description, ..
217 } => (name.clone(), Some(description.clone())),
218 };
219
220 let adapter = PforgeToolAdapter {
221 registry: self.registry.clone(),
222 tool_name: tool_name.clone(),
223 description,
224 };
225 builder = builder.tool(&tool_name, adapter);
226 }
227
228 let server = builder
229 .build()
230 .map_err(|e| Error::Handler(format!("Failed to build MCP server: {}", e)))?;
231
232 eprintln!("MCP server ready, starting protocol loop...");
233
234 match self.config.forge.transport {
236 pforge_config::TransportType::Stdio => {
237 server
238 .run_stdio()
239 .await
240 .map_err(|e| Error::Handler(format!("MCP server error: {}", e)))?;
241 }
242 pforge_config::TransportType::Sse => {
243 use pmcp::shared::{OptimizedSseConfig, OptimizedSseTransport};
244 use std::time::Duration;
245
246 let config = OptimizedSseConfig {
247 url: "http://localhost:8080/sse".to_string(),
248 connection_timeout: Duration::from_secs(30),
249 keepalive_interval: Duration::from_secs(15),
250 max_reconnects: 5,
251 reconnect_delay: Duration::from_secs(1),
252 buffer_size: 100,
253 flush_interval: Duration::from_millis(100),
254 enable_pooling: true,
255 max_connections: 10,
256 enable_compression: false,
257 };
258 let transport = OptimizedSseTransport::new(config);
259 server
260 .run(transport)
261 .await
262 .map_err(|e| Error::Handler(format!("MCP server error: {}", e)))?;
263 }
264 pforge_config::TransportType::WebSocket => {
265 use pmcp::shared::{WebSocketConfig, WebSocketTransport};
266 use std::time::Duration;
267
268 let url = "ws://localhost:8080/ws"
269 .parse()
270 .map_err(|e| Error::Handler(format!("Invalid WebSocket URL: {}", e)))?;
271 let config = WebSocketConfig {
272 url,
273 auto_reconnect: true,
274 reconnect_delay: Duration::from_secs(1),
275 max_reconnect_delay: Duration::from_secs(30),
276 max_reconnect_attempts: Some(5),
277 ping_interval: Some(Duration::from_secs(30)),
278 request_timeout: Duration::from_secs(10),
279 };
280 let transport = WebSocketTransport::new(config);
281 server
282 .run(transport)
283 .await
284 .map_err(|e| Error::Handler(format!("MCP server error: {}", e)))?;
285 }
286 }
287
288 eprintln!("\nShutting down...");
289 Ok(())
290 }
291
292 pub fn registry(&self) -> Arc<RwLock<HandlerRegistry>> {
294 self.registry.clone()
295 }
296}
297
#[cfg(test)]
mod tests {
    use super::*;
    use pforge_config::{ForgeMetadata, ParamSchema, ToolDef, TransportType};

    /// Stdio configuration with no tools, resources, prompts or state.
    fn create_test_config() -> ForgeConfig {
        let forge = ForgeMetadata {
            name: "test-server".to_string(),
            version: "0.1.0".to_string(),
            transport: TransportType::Stdio,
            optimization: pforge_config::OptimizationLevel::Debug,
        };

        ForgeConfig {
            forge,
            tools: vec![],
            resources: vec![],
            prompts: vec![],
            state: None,
        }
    }

    /// Server built from the base test configuration plus the given tools.
    fn server_with_tools(tools: Vec<ToolDef>) -> McpServer {
        let mut config = create_test_config();
        config.tools.extend(tools);
        McpServer::new(config)
    }

    /// CLI tool running `echo` with the given arguments and default settings otherwise.
    fn cli_tool(name: &str, description: &str, args: &[&str]) -> ToolDef {
        ToolDef::Cli {
            name: name.to_string(),
            description: description.to_string(),
            command: "echo".to_string(),
            args: args.iter().map(|arg| (*arg).to_string()).collect(),
            cwd: None,
            env: rustc_hash::FxHashMap::default(),
            stream: false,
            timeout_ms: None,
        }
    }

    /// HTTP GET tool with no headers, auth or timeout.
    fn http_get_tool(name: &str, description: &str, endpoint: &str) -> ToolDef {
        ToolDef::Http {
            name: name.to_string(),
            description: description.to_string(),
            endpoint: endpoint.to_string(),
            method: pforge_config::HttpMethod::Get,
            headers: rustc_hash::FxHashMap::default(),
            auth: None,
            timeout_ms: None,
        }
    }

    #[test]
    fn test_server_new() {
        let server = McpServer::new(create_test_config());

        assert_eq!(server.config.forge.name, "test-server");
        assert_eq!(server.config.forge.version, "0.1.0");
    }

    #[tokio::test]
    async fn test_register_handlers_cli() {
        let server = server_with_tools(vec![cli_tool(
            "test_cli",
            "Test CLI handler",
            &["hello"],
        )]);

        let result = server.register_handlers().await;

        assert!(result.is_ok());
    }

    #[tokio::test]
    async fn test_register_handlers_http() {
        let server = server_with_tools(vec![http_get_tool(
            "test_http",
            "Test HTTP handler",
            "https://api.example.com",
        )]);

        let result = server.register_handlers().await;

        assert!(result.is_ok());
    }

    #[tokio::test]
    async fn test_register_handlers_native() {
        let native = ToolDef::Native {
            name: "test_native".to_string(),
            description: "Test native handler".to_string(),
            handler: pforge_config::HandlerRef {
                path: "handlers::test::TestHandler".to_string(),
                inline: None,
            },
            params: ParamSchema {
                fields: rustc_hash::FxHashMap::default(),
            },
            timeout_ms: Some(5000),
        };
        let server = server_with_tools(vec![native]);

        let result = server.register_handlers().await;

        assert!(result.is_ok());
    }

    #[tokio::test]
    async fn test_registry_access() {
        let server = McpServer::new(create_test_config());

        // Taking a read guard is the whole check: the handle must be usable.
        let registry = server.registry();
        let _lock = registry.read().await;
    }

    #[tokio::test]
    async fn test_registry_returns_actual_registry() {
        let server = server_with_tools(vec![cli_tool("test_cli", "Test CLI", &["test"])]);
        server.register_handlers().await.unwrap();

        let registry = server.registry();
        let reg = registry.read().await;

        assert_eq!(reg.len(), 1, "Registry should contain registered handler");
    }

    #[tokio::test]
    async fn test_register_handlers_pipeline() {
        let pipeline = ToolDef::Pipeline {
            name: "test_pipeline".to_string(),
            description: "Test pipeline handler".to_string(),
            steps: vec![],
        };
        let server = server_with_tools(vec![pipeline]);

        let result = server.register_handlers().await;
        assert!(result.is_ok());

        let registry = server.registry();
        let reg = registry.read().await;
        assert_eq!(reg.len(), 1, "Pipeline handler should be registered");
    }

    #[tokio::test]
    async fn test_server_with_multiple_tools() {
        let server = server_with_tools(vec![
            cli_tool("cli1", "CLI 1", &[]),
            http_get_tool("http1", "HTTP 1", "https://example.com"),
        ]);
        assert_eq!(server.config.tools.len(), 2);

        let result = server.register_handlers().await;
        assert!(result.is_ok());
    }
}