pulseengine_mcp_cli/
server.rs

1//! Server integration utilities
2
3use crate::{CliError, McpConfiguration};
4use pulseengine_mcp_protocol::ServerInfo;
5use serde::{Deserialize, Serialize};
6use std::collections::HashMap;
7use std::time::Duration;
8use tracing::info;
9
10/// Run an MCP server with the given configuration
11pub async fn run_server<C>(_config: C) -> std::result::Result<(), CliError>
12where
13    C: McpConfiguration,
14{
15    // This is a placeholder implementation
16    // In the full implementation, this would:
17    // 1. Initialize logging
18    // 2. Create server from configuration
19    // 3. Set up signal handling
20    // 4. Start the server
21    // 5. Handle graceful shutdown
22
23    info!("Starting MCP server...");
24
25    // Initialize logging
26    _config.initialize_logging()?;
27
28    // Validate configuration
29    _config.validate()?;
30
31    info!("Server info: {:?}", _config.get_server_info());
32
33    // TODO: Integrate with actual server implementation
34    Err(CliError::server_setup(
35        "Server implementation not yet complete",
36    ))
37}
38
39/// Create server configuration builder
40pub fn server_builder() -> ServerBuilder {
41    ServerBuilder::new()
42}
43
44/// Transport type for the MCP server
45#[derive(Debug, Clone, Serialize, Deserialize)]
46pub enum TransportType {
47    /// HTTP transport
48    Http { port: u16, host: String },
49    /// WebSocket transport
50    WebSocket {
51        port: u16,
52        host: String,
53        path: String,
54    },
55    /// Standard I/O transport
56    Stdio,
57}
58
59impl Default for TransportType {
60    fn default() -> Self {
61        Self::Http {
62            port: 8080,
63            host: "localhost".to_string(),
64        }
65    }
66}
67
68/// CORS policy configuration
69#[derive(Debug, Clone, Serialize, Deserialize)]
70pub struct CorsPolicy {
71    pub allowed_origins: Vec<String>,
72    pub allowed_methods: Vec<String>,
73    pub allowed_headers: Vec<String>,
74    pub allow_credentials: bool,
75    pub max_age: Option<Duration>,
76}
77
78impl CorsPolicy {
79    /// Create a permissive CORS policy (allows all origins)
80    pub fn permissive() -> Self {
81        Self {
82            allowed_origins: vec!["*".to_string()],
83            allowed_methods: vec![
84                "GET".to_string(),
85                "POST".to_string(),
86                "PUT".to_string(),
87                "DELETE".to_string(),
88                "OPTIONS".to_string(),
89            ],
90            allowed_headers: vec!["*".to_string()],
91            allow_credentials: false,
92            max_age: Some(Duration::from_secs(3600)),
93        }
94    }
95
96    /// Create a strict CORS policy
97    pub fn strict() -> Self {
98        Self {
99            allowed_origins: vec![],
100            allowed_methods: vec!["GET".to_string(), "POST".to_string()],
101            allowed_headers: vec!["Content-Type".to_string(), "Authorization".to_string()],
102            allow_credentials: true,
103            max_age: Some(Duration::from_secs(300)),
104        }
105    }
106
107    /// Add allowed origin
108    pub fn allow_origin(mut self, origin: impl Into<String>) -> Self {
109        self.allowed_origins.push(origin.into());
110        self
111    }
112
113    /// Add allowed method
114    pub fn allow_method(mut self, method: impl Into<String>) -> Self {
115        self.allowed_methods.push(method.into());
116        self
117    }
118}
119
/// A user-declared route: an HTTP `method` on `path`, tagged with `handler_name`.
///
/// How `handler_name` is resolved to code is not shown here; this type only
/// stores the three strings.
#[derive(Debug, Clone)]
pub struct CustomEndpoint {
    pub path: String,
    pub method: String,
    pub handler_name: String,
}

impl CustomEndpoint {
    /// Build an endpoint description from anything convertible into `String`.
    pub fn new(
        path: impl Into<String>,
        method: impl Into<String>,
        handler_name: impl Into<String>,
    ) -> Self {
        let path: String = path.into();
        let method: String = method.into();
        let handler_name: String = handler_name.into();
        CustomEndpoint {
            path,
            method,
            handler_name,
        }
    }
}
141
/// A named middleware together with its string key/value settings.
///
/// `Debug` is implemented by hand rather than derived. Settings such as API keys
/// and passwords (as stored by `AuthMiddleware`) would otherwise be printed
/// verbatim wherever a configuration is logged with `{:?}`. Values whose key
/// looks credential-like are masked, and entries are shown in sorted key order
/// so the output is deterministic.
#[derive(Clone)]
pub struct MiddlewareConfig {
    pub name: String,
    pub config: HashMap<String, String>,
}

impl MiddlewareConfig {
    /// Create a middleware entry called `name` with no settings.
    pub fn new(name: impl Into<String>) -> Self {
        Self {
            name: name.into(),
            config: HashMap::new(),
        }
    }

    /// Set `key` to `value` (overwriting any previous value), builder-style.
    pub fn with_config(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
        self.config.insert(key.into(), value.into());
        self
    }

    /// Whether the value stored under `key` must be hidden from `Debug` output.
    ///
    /// This is a case-insensitive substring test, so `api_key`, `password`,
    /// `client_secret` or `access_token` are all treated as sensitive.
    fn is_sensitive_key(key: &str) -> bool {
        const MARKERS: [&str; 4] = ["api_key", "password", "secret", "token"];
        let lowered = key.to_ascii_lowercase();
        MARKERS.iter().any(|marker| lowered.contains(*marker))
    }
}

impl std::fmt::Debug for MiddlewareConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // BTreeMap gives a stable, sorted rendering, unlike HashMap iteration order.
        let shown: std::collections::BTreeMap<&str, &str> = self
            .config
            .iter()
            .map(|(key, value)| {
                let value = if Self::is_sensitive_key(key) {
                    "<redacted>"
                } else {
                    value.as_str()
                };
                (key.as_str(), value)
            })
            .collect();

        f.debug_struct("MiddlewareConfig")
            .field("name", &self.name)
            .field("config", &shown)
            .finish()
    }
}
162
163/// Builder for server configuration
164pub struct ServerBuilder {
165    server_info: Option<ServerInfo>,
166    transport: Option<TransportType>,
167    cors_policy: Option<CorsPolicy>,
168    middleware: Vec<MiddlewareConfig>,
169    custom_endpoints: Vec<CustomEndpoint>,
170    metrics_endpoint: Option<String>,
171    health_endpoint: Option<String>,
172    connection_timeout: Option<Duration>,
173    max_connections: Option<usize>,
174    enable_compression: bool,
175    enable_tls: bool,
176    tls_cert_path: Option<String>,
177    tls_key_path: Option<String>,
178}
179
180impl ServerBuilder {
181    pub fn new() -> Self {
182        Self {
183            server_info: None,
184            transport: None,
185            cors_policy: None,
186            middleware: Vec::new(),
187            custom_endpoints: Vec::new(),
188            metrics_endpoint: None,
189            health_endpoint: None,
190            connection_timeout: None,
191            max_connections: None,
192            enable_compression: false,
193            enable_tls: false,
194            tls_cert_path: None,
195            tls_key_path: None,
196        }
197    }
198
199    pub fn with_server_info(mut self, info: ServerInfo) -> Self {
200        self.server_info = Some(info);
201        self
202    }
203
204    pub fn with_port(mut self, port: u16) -> Self {
205        self.transport = Some(TransportType::Http {
206            port,
207            host: "localhost".to_string(),
208        });
209        self
210    }
211
212    pub fn with_transport(mut self, transport: TransportType) -> Self {
213        self.transport = Some(transport);
214        self
215    }
216
217    pub fn with_cors_policy(mut self, cors: CorsPolicy) -> Self {
218        self.cors_policy = Some(cors);
219        self
220    }
221
222    pub fn with_middleware(mut self, middleware: MiddlewareConfig) -> Self {
223        self.middleware.push(middleware);
224        self
225    }
226
227    pub fn with_metrics_endpoint(mut self, path: impl Into<String>) -> Self {
228        self.metrics_endpoint = Some(path.into());
229        self
230    }
231
232    pub fn with_health_endpoint(mut self, path: impl Into<String>) -> Self {
233        self.health_endpoint = Some(path.into());
234        self
235    }
236
237    pub fn with_custom_endpoint(
238        mut self,
239        path: impl Into<String>,
240        method: impl Into<String>,
241        handler_name: impl Into<String>,
242    ) -> Self {
243        self.custom_endpoints
244            .push(CustomEndpoint::new(path, method, handler_name));
245        self
246    }
247
248    pub fn with_connection_timeout(mut self, timeout: Duration) -> Self {
249        self.connection_timeout = Some(timeout);
250        self
251    }
252
253    pub fn with_max_connections(mut self, max: usize) -> Self {
254        self.max_connections = Some(max);
255        self
256    }
257
258    pub fn with_compression(mut self, enable: bool) -> Self {
259        self.enable_compression = enable;
260        self
261    }
262
263    pub fn with_tls(mut self, cert_path: impl Into<String>, key_path: impl Into<String>) -> Self {
264        self.enable_tls = true;
265        self.tls_cert_path = Some(cert_path.into());
266        self.tls_key_path = Some(key_path.into());
267        self
268    }
269
270    pub fn build(self) -> Result<BuiltServerConfig, CliError> {
271        Ok(BuiltServerConfig {
272            server_info: self
273                .server_info
274                .ok_or_else(|| CliError::configuration("Server info is required"))?,
275            transport: self.transport.unwrap_or_default(),
276            cors_policy: self.cors_policy,
277            middleware: self.middleware,
278            custom_endpoints: self.custom_endpoints,
279            metrics_endpoint: self.metrics_endpoint,
280            health_endpoint: self.health_endpoint,
281            connection_timeout: self.connection_timeout.unwrap_or(Duration::from_secs(30)),
282            max_connections: self.max_connections.unwrap_or(1000),
283            enable_compression: self.enable_compression,
284            enable_tls: self.enable_tls,
285            tls_cert_path: self.tls_cert_path,
286            tls_key_path: self.tls_key_path,
287        })
288    }
289}
290
291impl Default for ServerBuilder {
292    fn default() -> Self {
293        Self::new()
294    }
295}
296
297/// Built server configuration
298#[derive(Debug, Clone)]
299pub struct BuiltServerConfig {
300    pub server_info: ServerInfo,
301    pub transport: TransportType,
302    pub cors_policy: Option<CorsPolicy>,
303    pub middleware: Vec<MiddlewareConfig>,
304    pub custom_endpoints: Vec<CustomEndpoint>,
305    pub metrics_endpoint: Option<String>,
306    pub health_endpoint: Option<String>,
307    pub connection_timeout: Duration,
308    pub max_connections: usize,
309    pub enable_compression: bool,
310    pub enable_tls: bool,
311    pub tls_cert_path: Option<String>,
312    pub tls_key_path: Option<String>,
313}
314
315impl BuiltServerConfig {
316    /// Get the server port from transport configuration
317    pub fn port(&self) -> Option<u16> {
318        match &self.transport {
319            TransportType::Http { port, .. } | TransportType::WebSocket { port, .. } => Some(*port),
320            TransportType::Stdio => None,
321        }
322    }
323
324    /// Get the server host from transport configuration
325    pub fn host(&self) -> Option<&str> {
326        match &self.transport {
327            TransportType::Http { host, .. } | TransportType::WebSocket { host, .. } => Some(host),
328            TransportType::Stdio => None,
329        }
330    }
331
332    /// Check if TLS is enabled and properly configured
333    pub fn is_tls_configured(&self) -> bool {
334        self.enable_tls && self.tls_cert_path.is_some() && self.tls_key_path.is_some()
335    }
336}
337
338/// Authentication middleware configuration
339pub struct AuthMiddleware;
340
341impl AuthMiddleware {
342    pub fn bearer(api_key: impl Into<String>) -> MiddlewareConfig {
343        MiddlewareConfig::new("auth")
344            .with_config("api_key", api_key)
345            .with_config("type", "bearer")
346    }
347
348    pub fn basic_auth(
349        username: impl Into<String>,
350        password: impl Into<String>,
351    ) -> MiddlewareConfig {
352        MiddlewareConfig::new("auth")
353            .with_config("username", username)
354            .with_config("password", password)
355            .with_config("type", "basic")
356    }
357}
358
359/// Rate limiting middleware configuration
360pub struct RateLimitMiddleware;
361
362impl RateLimitMiddleware {
363    pub fn per_second(requests_per_second: u32) -> MiddlewareConfig {
364        MiddlewareConfig::new("rate_limit")
365            .with_config("requests_per_second", requests_per_second.to_string())
366    }
367
368    pub fn with_burst(requests_per_second: u32, burst_size: u32) -> MiddlewareConfig {
369        MiddlewareConfig::new("rate_limit")
370            .with_config("requests_per_second", requests_per_second.to_string())
371            .with_config("burst_size", burst_size.to_string())
372    }
373}
374
#[cfg(test)]
mod tests {
    // Unit tests for the server builder and the configuration types it assembles.
    use super::*;
    use crate::config::create_server_info;

    #[test]
    fn test_server_builder_basic() {
        let server_info = create_server_info(Some("test".to_string()), Some("1.0.0".to_string()));

        let config = server_builder()
            .with_server_info(server_info)
            .with_port(3000)
            .build()
            .unwrap();

        assert_eq!(config.port(), Some(3000));
        assert_eq!(config.server_info.server_info.name, "test");
    }

    #[test]
    fn test_server_builder_advanced() {
        let server_info =
            create_server_info(Some("advanced".to_string()), Some("1.0.0".to_string()));

        let config = server_builder()
            .with_server_info(server_info)
            .with_transport(TransportType::Http {
                port: 8080,
                host: "0.0.0.0".to_string(),
            })
            .with_cors_policy(CorsPolicy::permissive())
            .with_middleware(AuthMiddleware::bearer("secret-key"))
            .with_middleware(RateLimitMiddleware::per_second(100))
            .with_metrics_endpoint("/metrics")
            .with_health_endpoint("/health")
            .with_custom_endpoint("/api/v1/custom", "POST", "custom_handler")
            .with_compression(true)
            .build()
            .unwrap();

        assert_eq!(config.port(), Some(8080));
        assert_eq!(config.host(), Some("0.0.0.0"));
        assert!(config.cors_policy.is_some());
        assert_eq!(config.middleware.len(), 2);
        assert_eq!(config.custom_endpoints.len(), 1);
        assert_eq!(config.metrics_endpoint, Some("/metrics".to_string()));
        assert_eq!(config.health_endpoint, Some("/health".to_string()));
        assert!(config.enable_compression);
    }

    #[test]
    fn test_build_requires_server_info() {
        // `build` must refuse to produce a config without server metadata.
        assert!(server_builder().with_port(3000).build().is_err());
    }

    #[test]
    fn test_default_transport() {
        let server_info =
            create_server_info(Some("default-transport".to_string()), Some("1.0.0".to_string()));

        let config = server_builder()
            .with_server_info(server_info)
            .build()
            .unwrap();

        assert!(matches!(config.transport, TransportType::Http { .. }));
        assert_eq!(config.port(), Some(8080));
        assert_eq!(config.host(), Some("localhost"));
    }

    #[test]
    fn test_cors_policy() {
        let cors = CorsPolicy::permissive()
            .allow_origin("https://example.com")
            .allow_method("PATCH");

        assert!(cors.allowed_origins.contains(&"*".to_string()));
        assert!(
            cors.allowed_origins
                .contains(&"https://example.com".to_string())
        );
        assert!(cors.allowed_methods.contains(&"PATCH".to_string()));
    }

    #[test]
    fn test_cors_policy_strict() {
        let cors = CorsPolicy::strict().allow_origin("https://app.example.com");

        assert_eq!(
            cors.allowed_origins,
            vec!["https://app.example.com".to_string()]
        );
        assert!(cors.allow_credentials);
        assert!(!cors.allowed_headers.contains(&"*".to_string()));
        assert_eq!(cors.max_age, Some(Duration::from_secs(300)));
    }

    #[test]
    fn test_transport_types() {
        let server_info =
            create_server_info(Some("transport-test".to_string()), Some("1.0.0".to_string()));

        // Build a full config around each transport so `port()`/`host()` are exercised.
        let build_with = |transport: TransportType| {
            server_builder()
                .with_server_info(server_info.clone())
                .with_transport(transport)
                .build()
                .unwrap()
        };

        let http = build_with(TransportType::Http {
            port: 8080,
            host: "localhost".to_string(),
        });
        assert!(matches!(http.transport, TransportType::Http { .. }));
        assert_eq!(http.port(), Some(8080));
        assert_eq!(http.host(), Some("localhost"));

        let ws = build_with(TransportType::WebSocket {
            port: 8081,
            host: "localhost".to_string(),
            path: "/ws".to_string(),
        });
        assert!(matches!(
            ws.transport,
            TransportType::WebSocket { ref path, .. } if path == "/ws"
        ));
        assert_eq!(ws.port(), Some(8081));
        assert_eq!(ws.host(), Some("localhost"));

        let stdio = build_with(TransportType::Stdio);
        assert!(matches!(stdio.transport, TransportType::Stdio));
        assert_eq!(stdio.port(), None);
        assert_eq!(stdio.host(), None);
    }

    #[test]
    fn test_tls_configuration() {
        let server_info =
            create_server_info(Some("tls-test".to_string()), Some("1.0.0".to_string()));

        // TLS with both certificate and key paths.
        let tls_config = server_builder()
            .with_server_info(server_info.clone())
            .with_port(443)
            .with_tls("/path/to/cert.pem", "/path/to/key.pem")
            .build()
            .unwrap();

        assert!(tls_config.enable_tls);
        assert_eq!(
            tls_config.tls_cert_path,
            Some("/path/to/cert.pem".to_string())
        );
        assert_eq!(
            tls_config.tls_key_path,
            Some("/path/to/key.pem".to_string())
        );
        assert!(tls_config.is_tls_configured());

        // Without `with_tls`, TLS stays disabled even on port 443.
        let no_tls = server_builder()
            .with_server_info(server_info)
            .with_port(443)
            .build()
            .unwrap();

        assert!(!no_tls.enable_tls);
        assert!(!no_tls.is_tls_configured());
    }

    #[test]
    fn test_connection_limits() {
        let server_info =
            create_server_info(Some("limits-test".to_string()), Some("1.0.0".to_string()));

        // Explicit limits override the defaults.
        let custom_limits = server_builder()
            .with_server_info(server_info.clone())
            .with_max_connections(10000)
            .with_connection_timeout(Duration::from_secs(120))
            .build()
            .unwrap();

        assert_eq!(custom_limits.max_connections, 10000);
        assert_eq!(custom_limits.connection_timeout, Duration::from_secs(120));

        // Defaults apply when no limits are set.
        let default_limits = server_builder()
            .with_server_info(server_info)
            .build()
            .unwrap();

        assert_eq!(default_limits.max_connections, 1000);
        assert_eq!(default_limits.connection_timeout, Duration::from_secs(30));
    }

    #[test]
    fn test_custom_endpoints() {
        let server_info = create_server_info(
            Some("endpoints-test".to_string()),
            Some("1.0.0".to_string()),
        );

        let config = server_builder()
            .with_server_info(server_info)
            .with_custom_endpoint("/api/v1/users", "GET", "list_users")
            .with_custom_endpoint("/api/v1/users", "POST", "create_user")
            .with_custom_endpoint("/api/v1/users/{id}", "GET", "get_user")
            .with_custom_endpoint("/api/v1/users/{id}", "PUT", "update_user")
            .with_custom_endpoint("/api/v1/users/{id}", "DELETE", "delete_user")
            .build()
            .unwrap();

        assert_eq!(config.custom_endpoints.len(), 5);

        // First and last endpoints keep their registration order.
        let endpoints = &config.custom_endpoints;
        assert_eq!(endpoints[0].path, "/api/v1/users");
        assert_eq!(endpoints[0].method, "GET");
        assert_eq!(endpoints[0].handler_name, "list_users");

        assert_eq!(endpoints[4].path, "/api/v1/users/{id}");
        assert_eq!(endpoints[4].method, "DELETE");
        assert_eq!(endpoints[4].handler_name, "delete_user");
    }

    #[test]
    fn test_middleware_ordering() {
        let server_info = create_server_info(
            Some("middleware-test".to_string()),
            Some("1.0.0".to_string()),
        );

        let config = server_builder()
            .with_server_info(server_info)
            .with_middleware(AuthMiddleware::bearer("key1"))
            .with_middleware(RateLimitMiddleware::per_second(50))
            .with_middleware(AuthMiddleware::basic_auth("user", "pass"))
            .with_middleware(RateLimitMiddleware::with_burst(100, 200))
            .build()
            .unwrap();

        assert_eq!(config.middleware.len(), 4);

        // Verify middleware order is preserved
        assert_eq!(config.middleware[0].name, "auth");
        assert_eq!(
            config.middleware[0].config.get("api_key"),
            Some(&"key1".to_string())
        );

        assert_eq!(config.middleware[1].name, "rate_limit");
        assert_eq!(
            config.middleware[1].config.get("requests_per_second"),
            Some(&"50".to_string())
        );

        assert_eq!(config.middleware[2].name, "auth");
        assert_eq!(
            config.middleware[2].config.get("type"),
            Some(&"basic".to_string())
        );

        assert_eq!(config.middleware[3].name, "rate_limit");
        assert_eq!(
            config.middleware[3].config.get("burst_size"),
            Some(&"200".to_string())
        );
    }
}