pulseengine_mcp_cli/
server.rs

1//! Server integration utilities
2
3use crate::{CliError, McpConfiguration};
4use pulseengine_mcp_protocol::ServerInfo;
5use serde::{Deserialize, Serialize};
6use std::collections::HashMap;
7use std::time::Duration;
8use tracing::info;
9
10/// Run an MCP server with the given configuration
11pub async fn run_server<C>(_config: C) -> std::result::Result<(), CliError>
12where
13    C: McpConfiguration,
14{
15    // This is a placeholder implementation
16    // In the full implementation, this would:
17    // 1. Initialize logging
18    // 2. Create server from configuration
19    // 3. Set up signal handling
20    // 4. Start the server
21    // 5. Handle graceful shutdown
22
23    info!("Starting MCP server...");
24
25    // Initialize logging
26    _config.initialize_logging()?;
27
28    // Validate configuration
29    _config.validate()?;
30
31    info!("Server info: {:?}", _config.get_server_info());
32
33    // TODO: Integrate with actual server implementation
34    // let server = create_server_from_config(config).await?;
35    // server.run().await?;
36
37    Err(CliError::server_setup(
38        "Server implementation not yet complete",
39    ))
40}
41
42/// Create server configuration builder
43pub fn server_builder() -> ServerBuilder {
44    ServerBuilder::new()
45}
46
47/// Transport type for the MCP server
48#[derive(Debug, Clone, Serialize, Deserialize)]
49pub enum TransportType {
50    /// HTTP transport
51    Http {
52        port: u16,
53        host: String,
54    },
55    /// WebSocket transport
56    WebSocket {
57        port: u16,
58        host: String,
59        path: String,
60    },
61    /// Standard I/O transport
62    Stdio,
63}
64
65impl Default for TransportType {
66    fn default() -> Self {
67        Self::Http {
68            port: 8080,
69            host: "localhost".to_string(),
70        }
71    }
72}
73
74/// CORS policy configuration
75#[derive(Debug, Clone, Serialize, Deserialize)]
76pub struct CorsPolicy {
77    pub allowed_origins: Vec<String>,
78    pub allowed_methods: Vec<String>,
79    pub allowed_headers: Vec<String>,
80    pub allow_credentials: bool,
81    pub max_age: Option<Duration>,
82}
83
84impl CorsPolicy {
85    /// Create a permissive CORS policy (allows all origins)
86    pub fn permissive() -> Self {
87        Self {
88            allowed_origins: vec!["*".to_string()],
89            allowed_methods: vec![
90                "GET".to_string(),
91                "POST".to_string(),
92                "PUT".to_string(),
93                "DELETE".to_string(),
94                "OPTIONS".to_string(),
95            ],
96            allowed_headers: vec!["*".to_string()],
97            allow_credentials: false,
98            max_age: Some(Duration::from_secs(3600)),
99        }
100    }
101
102    /// Create a strict CORS policy
103    pub fn strict() -> Self {
104        Self {
105            allowed_origins: vec![],
106            allowed_methods: vec!["GET".to_string(), "POST".to_string()],
107            allowed_headers: vec!["Content-Type".to_string(), "Authorization".to_string()],
108            allow_credentials: true,
109            max_age: Some(Duration::from_secs(300)),
110        }
111    }
112
113    /// Add allowed origin
114    pub fn allow_origin(mut self, origin: impl Into<String>) -> Self {
115        self.allowed_origins.push(origin.into());
116        self
117    }
118
119    /// Add allowed method
120    pub fn allow_method(mut self, method: impl Into<String>) -> Self {
121        self.allowed_methods.push(method.into());
122        self
123    }
124}
125
/// A user-defined route that maps an HTTP method and path to a named handler.
#[derive(Debug, Clone)]
pub struct CustomEndpoint {
    /// Route path, e.g. `/api/v1/users/{id}`.
    pub path: String,
    /// HTTP method, stored exactly as given (no case normalization).
    pub method: String,
    /// Name of the handler registered for this route.
    pub handler_name: String,
}

impl CustomEndpoint {
    /// Create an endpoint from any values convertible into `String`.
    pub fn new(
        path: impl Into<String>,
        method: impl Into<String>,
        handler_name: impl Into<String>,
    ) -> Self {
        let path: String = path.into();
        let method: String = method.into();
        let handler_name: String = handler_name.into();
        CustomEndpoint {
            path,
            method,
            handler_name,
        }
    }
}
143
/// A named middleware together with its string key/value settings.
#[derive(Debug, Clone)]
pub struct MiddlewareConfig {
    /// Middleware identifier, e.g. `"auth"` or `"rate_limit"`.
    pub name: String,
    /// Free-form settings; writing an existing key replaces its value.
    pub config: HashMap<String, String>,
}

impl MiddlewareConfig {
    /// Create middleware `name` with no settings.
    pub fn new(name: impl Into<String>) -> Self {
        MiddlewareConfig {
            name: name.into(),
            config: HashMap::default(),
        }
    }

    /// Set `key` to `value`, replacing any previous value, and return the config.
    pub fn with_config(self, key: impl Into<String>, value: impl Into<String>) -> Self {
        let MiddlewareConfig { name, mut config } = self;
        config.insert(key.into(), value.into());
        MiddlewareConfig { name, config }
    }
}
164
165/// Builder for server configuration
166pub struct ServerBuilder {
167    server_info: Option<ServerInfo>,
168    transport: Option<TransportType>,
169    cors_policy: Option<CorsPolicy>,
170    middleware: Vec<MiddlewareConfig>,
171    custom_endpoints: Vec<CustomEndpoint>,
172    metrics_endpoint: Option<String>,
173    health_endpoint: Option<String>,
174    connection_timeout: Option<Duration>,
175    max_connections: Option<usize>,
176    enable_compression: bool,
177    enable_tls: bool,
178    tls_cert_path: Option<String>,
179    tls_key_path: Option<String>,
180}
181
182impl ServerBuilder {
183    pub fn new() -> Self {
184        Self {
185            server_info: None,
186            transport: None,
187            cors_policy: None,
188            middleware: Vec::new(),
189            custom_endpoints: Vec::new(),
190            metrics_endpoint: None,
191            health_endpoint: None,
192            connection_timeout: None,
193            max_connections: None,
194            enable_compression: false,
195            enable_tls: false,
196            tls_cert_path: None,
197            tls_key_path: None,
198        }
199    }
200
201    pub fn with_server_info(mut self, info: ServerInfo) -> Self {
202        self.server_info = Some(info);
203        self
204    }
205
206    pub fn with_port(mut self, port: u16) -> Self {
207        self.transport = Some(TransportType::Http {
208            port,
209            host: "localhost".to_string(),
210        });
211        self
212    }
213
214    pub fn with_transport(mut self, transport: TransportType) -> Self {
215        self.transport = Some(transport);
216        self
217    }
218
219    pub fn with_cors_policy(mut self, cors: CorsPolicy) -> Self {
220        self.cors_policy = Some(cors);
221        self
222    }
223
224    pub fn with_middleware(mut self, middleware: MiddlewareConfig) -> Self {
225        self.middleware.push(middleware);
226        self
227    }
228
229    pub fn with_metrics_endpoint(mut self, path: impl Into<String>) -> Self {
230        self.metrics_endpoint = Some(path.into());
231        self
232    }
233
234    pub fn with_health_endpoint(mut self, path: impl Into<String>) -> Self {
235        self.health_endpoint = Some(path.into());
236        self
237    }
238
239    pub fn with_custom_endpoint(
240        mut self,
241        path: impl Into<String>,
242        method: impl Into<String>,
243        handler_name: impl Into<String>,
244    ) -> Self {
245        self.custom_endpoints.push(CustomEndpoint::new(path, method, handler_name));
246        self
247    }
248
249    pub fn with_connection_timeout(mut self, timeout: Duration) -> Self {
250        self.connection_timeout = Some(timeout);
251        self
252    }
253
254    pub fn with_max_connections(mut self, max: usize) -> Self {
255        self.max_connections = Some(max);
256        self
257    }
258
259    pub fn with_compression(mut self, enable: bool) -> Self {
260        self.enable_compression = enable;
261        self
262    }
263
264    pub fn with_tls(mut self, cert_path: impl Into<String>, key_path: impl Into<String>) -> Self {
265        self.enable_tls = true;
266        self.tls_cert_path = Some(cert_path.into());
267        self.tls_key_path = Some(key_path.into());
268        self
269    }
270
271    pub fn build(self) -> Result<BuiltServerConfig, CliError> {
272        Ok(BuiltServerConfig {
273            server_info: self
274                .server_info
275                .ok_or_else(|| CliError::configuration("Server info is required"))?,
276            transport: self.transport.unwrap_or_default(),
277            cors_policy: self.cors_policy,
278            middleware: self.middleware,
279            custom_endpoints: self.custom_endpoints,
280            metrics_endpoint: self.metrics_endpoint,
281            health_endpoint: self.health_endpoint,
282            connection_timeout: self.connection_timeout.unwrap_or(Duration::from_secs(30)),
283            max_connections: self.max_connections.unwrap_or(1000),
284            enable_compression: self.enable_compression,
285            enable_tls: self.enable_tls,
286            tls_cert_path: self.tls_cert_path,
287            tls_key_path: self.tls_key_path,
288        })
289    }
290}
291
292impl Default for ServerBuilder {
293    fn default() -> Self {
294        Self::new()
295    }
296}
297
298/// Built server configuration
299#[derive(Debug, Clone)]
300pub struct BuiltServerConfig {
301    pub server_info: ServerInfo,
302    pub transport: TransportType,
303    pub cors_policy: Option<CorsPolicy>,
304    pub middleware: Vec<MiddlewareConfig>,
305    pub custom_endpoints: Vec<CustomEndpoint>,
306    pub metrics_endpoint: Option<String>,
307    pub health_endpoint: Option<String>,
308    pub connection_timeout: Duration,
309    pub max_connections: usize,
310    pub enable_compression: bool,
311    pub enable_tls: bool,
312    pub tls_cert_path: Option<String>,
313    pub tls_key_path: Option<String>,
314}
315
316impl BuiltServerConfig {
317    /// Get the server port from transport configuration
318    pub fn port(&self) -> Option<u16> {
319        match &self.transport {
320            TransportType::Http { port, .. } | TransportType::WebSocket { port, .. } => Some(*port),
321            TransportType::Stdio => None,
322        }
323    }
324
325    /// Get the server host from transport configuration
326    pub fn host(&self) -> Option<&str> {
327        match &self.transport {
328            TransportType::Http { host, .. } | TransportType::WebSocket { host, .. } => Some(host),
329            TransportType::Stdio => None,
330        }
331    }
332
333    /// Check if TLS is enabled and properly configured
334    pub fn is_tls_configured(&self) -> bool {
335        self.enable_tls && self.tls_cert_path.is_some() && self.tls_key_path.is_some()
336    }
337}
338
339/// Authentication middleware configuration
340pub struct AuthMiddleware;
341
342impl AuthMiddleware {
343    pub fn bearer(api_key: impl Into<String>) -> MiddlewareConfig {
344        MiddlewareConfig::new("auth")
345            .with_config("api_key", api_key)
346            .with_config("type", "bearer")
347    }
348
349    pub fn basic_auth(username: impl Into<String>, password: impl Into<String>) -> MiddlewareConfig {
350        MiddlewareConfig::new("auth")
351            .with_config("username", username)
352            .with_config("password", password)
353            .with_config("type", "basic")
354    }
355}
356
357/// Rate limiting middleware configuration
358pub struct RateLimitMiddleware;
359
360impl RateLimitMiddleware {
361    pub fn per_second(requests_per_second: u32) -> MiddlewareConfig {
362        MiddlewareConfig::new("rate_limit")
363            .with_config("requests_per_second", requests_per_second.to_string())
364    }
365
366    pub fn with_burst(requests_per_second: u32, burst_size: u32) -> MiddlewareConfig {
367        MiddlewareConfig::new("rate_limit")
368            .with_config("requests_per_second", requests_per_second.to_string())
369            .with_config("burst_size", burst_size.to_string())
370    }
371}
372
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::create_server_info;

    /// Build a `ServerInfo` named `name` with the fixed version `1.0.0`.
    fn info(name: &str) -> ServerInfo {
        create_server_info(Some(name.to_string()), Some("1.0.0".to_string()))
    }

    /// Whether `list` contains an entry equal to `item`.
    fn contains(list: &[String], item: &str) -> bool {
        list.iter().any(|entry| entry == item)
    }

    #[test]
    fn test_server_builder_basic() {
        let config = server_builder()
            .with_server_info(info("test"))
            .with_port(3000)
            .build()
            .expect("server info is set, so build must succeed");

        assert_eq!(config.port(), Some(3000));
        assert_eq!(config.server_info.server_info.name, "test");
    }

    #[test]
    fn test_server_builder_advanced() {
        let transport = TransportType::Http {
            port: 8080,
            host: "0.0.0.0".to_string(),
        };

        let config = server_builder()
            .with_server_info(info("advanced"))
            .with_transport(transport)
            .with_cors_policy(CorsPolicy::permissive())
            .with_middleware(AuthMiddleware::bearer("secret-key"))
            .with_middleware(RateLimitMiddleware::per_second(100))
            .with_metrics_endpoint("/metrics")
            .with_health_endpoint("/health")
            .with_custom_endpoint("/api/v1/custom", "POST", "custom_handler")
            .with_compression(true)
            .build()
            .expect("server info is set, so build must succeed");

        assert_eq!(config.port(), Some(8080));
        assert_eq!(config.host(), Some("0.0.0.0"));
        assert!(config.cors_policy.is_some());
        assert_eq!(config.middleware.len(), 2);
        assert_eq!(config.custom_endpoints.len(), 1);
        assert_eq!(config.metrics_endpoint.as_deref(), Some("/metrics"));
        assert_eq!(config.health_endpoint.as_deref(), Some("/health"));
        assert!(config.enable_compression);
    }

    #[test]
    fn test_cors_policy() {
        let cors = CorsPolicy::permissive()
            .allow_origin("https://example.com")
            .allow_method("PATCH");

        assert!(contains(&cors.allowed_origins, "*"));
        assert!(contains(&cors.allowed_origins, "https://example.com"));
        assert!(contains(&cors.allowed_methods, "PATCH"));
    }

    #[test]
    fn test_transport_types() {
        let transports = [
            TransportType::Http {
                port: 8080,
                host: "localhost".to_string(),
            },
            TransportType::WebSocket {
                port: 8081,
                host: "localhost".to_string(),
                path: "/ws".to_string(),
            },
            TransportType::Stdio,
        ];

        assert!(matches!(transports[0], TransportType::Http { .. }));
        assert!(matches!(transports[1], TransportType::WebSocket { .. }));
        assert!(matches!(transports[2], TransportType::Stdio));
    }

    #[test]
    fn test_tls_configuration() {
        let server_info = info("tls-test");

        // Certificate and key supplied: TLS is fully configured.
        let with_tls = server_builder()
            .with_server_info(server_info.clone())
            .with_port(443)
            .with_tls("/path/to/cert.pem", "/path/to/key.pem")
            .build()
            .unwrap();

        assert!(with_tls.enable_tls);
        assert_eq!(with_tls.tls_cert_path.as_deref(), Some("/path/to/cert.pem"));
        assert_eq!(with_tls.tls_key_path.as_deref(), Some("/path/to/key.pem"));
        assert!(with_tls.is_tls_configured());

        // TLS never requested: neither the flag nor the completeness check is set.
        let without_tls = server_builder()
            .with_server_info(server_info)
            .with_port(443)
            .build()
            .unwrap();

        assert!(!without_tls.enable_tls);
        assert!(!without_tls.is_tls_configured());
    }

    #[test]
    fn test_connection_limits() {
        let server_info = info("limits-test");

        // Explicit limits are kept as given.
        let custom = server_builder()
            .with_server_info(server_info.clone())
            .with_max_connections(10000)
            .with_connection_timeout(Duration::from_secs(120))
            .build()
            .unwrap();

        assert_eq!(custom.max_connections, 10000);
        assert_eq!(custom.connection_timeout, Duration::from_secs(120));

        // Unset limits fall back to the builder defaults.
        let defaults = server_builder()
            .with_server_info(server_info)
            .build()
            .unwrap();

        assert_eq!(defaults.max_connections, 1000);
        assert_eq!(defaults.connection_timeout, Duration::from_secs(30));
    }

    #[test]
    fn test_custom_endpoints() {
        let routes = [
            ("/api/v1/users", "GET", "list_users"),
            ("/api/v1/users", "POST", "create_user"),
            ("/api/v1/users/{id}", "GET", "get_user"),
            ("/api/v1/users/{id}", "PUT", "update_user"),
            ("/api/v1/users/{id}", "DELETE", "delete_user"),
        ];

        let mut builder = server_builder().with_server_info(info("endpoints-test"));
        for &(path, method, handler) in routes.iter() {
            builder = builder.with_custom_endpoint(path, method, handler);
        }
        let config = builder.build().unwrap();

        assert_eq!(config.custom_endpoints.len(), 5);

        // The first and last registrations keep their position and fields.
        let first = &config.custom_endpoints[0];
        assert_eq!(first.path, "/api/v1/users");
        assert_eq!(first.method, "GET");
        assert_eq!(first.handler_name, "list_users");

        let last = &config.custom_endpoints[4];
        assert_eq!(last.path, "/api/v1/users/{id}");
        assert_eq!(last.method, "DELETE");
        assert_eq!(last.handler_name, "delete_user");
    }

    #[test]
    fn test_middleware_ordering() {
        let config = server_builder()
            .with_server_info(info("middleware-test"))
            .with_middleware(AuthMiddleware::bearer("key1"))
            .with_middleware(RateLimitMiddleware::per_second(50))
            .with_middleware(AuthMiddleware::basic_auth("user", "pass"))
            .with_middleware(RateLimitMiddleware::with_burst(100, 200))
            .build()
            .unwrap();

        assert_eq!(config.middleware.len(), 4);

        // Insertion order must be preserved.
        let names: Vec<&str> = config.middleware.iter().map(|m| m.name.as_str()).collect();
        assert_eq!(names, ["auth", "rate_limit", "auth", "rate_limit"]);

        let setting = |index: usize, key: &str| config.middleware[index].config.get(key).cloned();
        assert_eq!(setting(0, "api_key"), Some("key1".to_string()));
        assert_eq!(setting(1, "requests_per_second"), Some("50".to_string()));
        assert_eq!(setting(2, "type"), Some("basic".to_string()));
        assert_eq!(setting(3, "burst_size"), Some("200".to_string()));
    }
}