1use crate::{CliError, McpConfiguration};
4use pulseengine_mcp_protocol::ServerInfo;
5use serde::{Deserialize, Serialize};
6use std::collections::HashMap;
7use std::time::Duration;
8use tracing::info;
9
10pub async fn run_server<C>(_config: C) -> std::result::Result<(), CliError>
12where
13    C: McpConfiguration,
14{
15    info!("Starting MCP server...");
24
25    _config.initialize_logging()?;
27
28    _config.validate()?;
30
31    info!("Server info: {:?}", _config.get_server_info());
32
33    Err(CliError::server_setup(
38        "Server implementation not yet complete",
39    ))
40}
41
42pub fn server_builder() -> ServerBuilder {
44    ServerBuilder::new()
45}
46
47#[derive(Debug, Clone, Serialize, Deserialize)]
49pub enum TransportType {
50    Http { port: u16, host: String },
52    WebSocket {
54        port: u16,
55        host: String,
56        path: String,
57    },
58    Stdio,
60}
61
62impl Default for TransportType {
63    fn default() -> Self {
64        Self::Http {
65            port: 8080,
66            host: "localhost".to_string(),
67        }
68    }
69}
70
71#[derive(Debug, Clone, Serialize, Deserialize)]
73pub struct CorsPolicy {
74    pub allowed_origins: Vec<String>,
75    pub allowed_methods: Vec<String>,
76    pub allowed_headers: Vec<String>,
77    pub allow_credentials: bool,
78    pub max_age: Option<Duration>,
79}
80
81impl CorsPolicy {
82    pub fn permissive() -> Self {
84        Self {
85            allowed_origins: vec!["*".to_string()],
86            allowed_methods: vec![
87                "GET".to_string(),
88                "POST".to_string(),
89                "PUT".to_string(),
90                "DELETE".to_string(),
91                "OPTIONS".to_string(),
92            ],
93            allowed_headers: vec!["*".to_string()],
94            allow_credentials: false,
95            max_age: Some(Duration::from_secs(3600)),
96        }
97    }
98
99    pub fn strict() -> Self {
101        Self {
102            allowed_origins: vec![],
103            allowed_methods: vec!["GET".to_string(), "POST".to_string()],
104            allowed_headers: vec!["Content-Type".to_string(), "Authorization".to_string()],
105            allow_credentials: true,
106            max_age: Some(Duration::from_secs(300)),
107        }
108    }
109
110    pub fn allow_origin(mut self, origin: impl Into<String>) -> Self {
112        self.allowed_origins.push(origin.into());
113        self
114    }
115
116    pub fn allow_method(mut self, method: impl Into<String>) -> Self {
118        self.allowed_methods.push(method.into());
119        self
120    }
121}
122
/// A user-defined route, registered via `ServerBuilder::with_custom_endpoint`.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct CustomEndpoint {
    /// Route path, e.g. `/api/v1/users/{id}`.
    pub path: String,
    /// HTTP method, e.g. `GET`.
    pub method: String,
    /// Name of the handler for this route; presumably resolved by the server
    /// implementation (not shown here).
    pub handler_name: String,
}

impl CustomEndpoint {
    /// Creates an endpoint from its path, method and handler name.
    pub fn new(
        path: impl Into<String>,
        method: impl Into<String>,
        handler_name: impl Into<String>,
    ) -> Self {
        Self {
            path: path.into(),
            method: method.into(),
            handler_name: handler_name.into(),
        }
    }
}
144
/// A named middleware together with its string key/value settings.
///
/// `Debug` output redacts the values of credential-like keys, such as the
/// `api_key` and `password` entries written by the auth helpers, so that
/// secrets do not end up in logs when a configuration is printed with `{:?}`.
#[derive(Clone, PartialEq, Eq)]
pub struct MiddlewareConfig {
    /// Middleware identifier, e.g. `"auth"` or `"rate_limit"`.
    pub name: String,
    /// Settings for this middleware; all values are stored as strings.
    pub config: HashMap<String, String>,
}

impl MiddlewareConfig {
    /// Creates a middleware entry with no settings.
    pub fn new(name: impl Into<String>) -> Self {
        Self {
            name: name.into(),
            config: HashMap::new(),
        }
    }

    /// Sets `key` to `value`, replacing any earlier value for the same key.
    #[must_use]
    pub fn with_config(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
        self.config.insert(key.into(), value.into());
        self
    }

    /// Returns `true` when `key` names a setting whose value must not be printed.
    fn is_sensitive_key(key: &str) -> bool {
        const MARKERS: [&str; 6] = [
            "password",
            "secret",
            "token",
            "api_key",
            "apikey",
            "credential",
        ];
        let lowered = key.to_ascii_lowercase();
        MARKERS.iter().any(|marker| lowered.contains(*marker))
    }
}

impl std::fmt::Debug for MiddlewareConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // BTreeMap gives deterministic key order; sensitive values are masked.
        let shown: std::collections::BTreeMap<&str, &str> = self
            .config
            .iter()
            .map(|(key, value)| {
                let value = if Self::is_sensitive_key(key) {
                    "<redacted>"
                } else {
                    value.as_str()
                };
                (key.as_str(), value)
            })
            .collect();
        f.debug_struct("MiddlewareConfig")
            .field("name", &self.name)
            .field("config", &shown)
            .finish()
    }
}
165
166pub struct ServerBuilder {
168    server_info: Option<ServerInfo>,
169    transport: Option<TransportType>,
170    cors_policy: Option<CorsPolicy>,
171    middleware: Vec<MiddlewareConfig>,
172    custom_endpoints: Vec<CustomEndpoint>,
173    metrics_endpoint: Option<String>,
174    health_endpoint: Option<String>,
175    connection_timeout: Option<Duration>,
176    max_connections: Option<usize>,
177    enable_compression: bool,
178    enable_tls: bool,
179    tls_cert_path: Option<String>,
180    tls_key_path: Option<String>,
181}
182
183impl ServerBuilder {
184    pub fn new() -> Self {
185        Self {
186            server_info: None,
187            transport: None,
188            cors_policy: None,
189            middleware: Vec::new(),
190            custom_endpoints: Vec::new(),
191            metrics_endpoint: None,
192            health_endpoint: None,
193            connection_timeout: None,
194            max_connections: None,
195            enable_compression: false,
196            enable_tls: false,
197            tls_cert_path: None,
198            tls_key_path: None,
199        }
200    }
201
202    pub fn with_server_info(mut self, info: ServerInfo) -> Self {
203        self.server_info = Some(info);
204        self
205    }
206
207    pub fn with_port(mut self, port: u16) -> Self {
208        self.transport = Some(TransportType::Http {
209            port,
210            host: "localhost".to_string(),
211        });
212        self
213    }
214
215    pub fn with_transport(mut self, transport: TransportType) -> Self {
216        self.transport = Some(transport);
217        self
218    }
219
220    pub fn with_cors_policy(mut self, cors: CorsPolicy) -> Self {
221        self.cors_policy = Some(cors);
222        self
223    }
224
225    pub fn with_middleware(mut self, middleware: MiddlewareConfig) -> Self {
226        self.middleware.push(middleware);
227        self
228    }
229
230    pub fn with_metrics_endpoint(mut self, path: impl Into<String>) -> Self {
231        self.metrics_endpoint = Some(path.into());
232        self
233    }
234
235    pub fn with_health_endpoint(mut self, path: impl Into<String>) -> Self {
236        self.health_endpoint = Some(path.into());
237        self
238    }
239
240    pub fn with_custom_endpoint(
241        mut self,
242        path: impl Into<String>,
243        method: impl Into<String>,
244        handler_name: impl Into<String>,
245    ) -> Self {
246        self.custom_endpoints
247            .push(CustomEndpoint::new(path, method, handler_name));
248        self
249    }
250
251    pub fn with_connection_timeout(mut self, timeout: Duration) -> Self {
252        self.connection_timeout = Some(timeout);
253        self
254    }
255
256    pub fn with_max_connections(mut self, max: usize) -> Self {
257        self.max_connections = Some(max);
258        self
259    }
260
261    pub fn with_compression(mut self, enable: bool) -> Self {
262        self.enable_compression = enable;
263        self
264    }
265
266    pub fn with_tls(mut self, cert_path: impl Into<String>, key_path: impl Into<String>) -> Self {
267        self.enable_tls = true;
268        self.tls_cert_path = Some(cert_path.into());
269        self.tls_key_path = Some(key_path.into());
270        self
271    }
272
273    pub fn build(self) -> Result<BuiltServerConfig, CliError> {
274        Ok(BuiltServerConfig {
275            server_info: self
276                .server_info
277                .ok_or_else(|| CliError::configuration("Server info is required"))?,
278            transport: self.transport.unwrap_or_default(),
279            cors_policy: self.cors_policy,
280            middleware: self.middleware,
281            custom_endpoints: self.custom_endpoints,
282            metrics_endpoint: self.metrics_endpoint,
283            health_endpoint: self.health_endpoint,
284            connection_timeout: self.connection_timeout.unwrap_or(Duration::from_secs(30)),
285            max_connections: self.max_connections.unwrap_or(1000),
286            enable_compression: self.enable_compression,
287            enable_tls: self.enable_tls,
288            tls_cert_path: self.tls_cert_path,
289            tls_key_path: self.tls_key_path,
290        })
291    }
292}
293
294impl Default for ServerBuilder {
295    fn default() -> Self {
296        Self::new()
297    }
298}
299
/// Resolved server settings produced by [`ServerBuilder::build`].
///
/// Options that have a builder default (transport, connection timeout,
/// connection limit) are always filled in here; the rest stay optional.
#[derive(Debug, Clone)]
pub struct BuiltServerConfig {
    /// Server info; always present because `build` rejects a missing one.
    pub server_info: ServerInfo,
    /// Transport; `TransportType::default()` if none was chosen on the builder.
    pub transport: TransportType,
    /// CORS settings, if any were configured.
    pub cors_policy: Option<CorsPolicy>,
    /// Middleware in the order it was added to the builder.
    pub middleware: Vec<MiddlewareConfig>,
    /// Extra routes in registration order.
    pub custom_endpoints: Vec<CustomEndpoint>,
    /// Metrics endpoint path, if one was set.
    pub metrics_endpoint: Option<String>,
    /// Health endpoint path, if one was set.
    pub health_endpoint: Option<String>,
    /// Connection timeout (builder default: 30 seconds).
    pub connection_timeout: Duration,
    /// Maximum number of connections (builder default: 1000).
    pub max_connections: usize,
    /// Whether compression was requested.
    pub enable_compression: bool,
    /// Whether TLS was requested; see [`BuiltServerConfig::is_tls_configured`]
    /// for whether the certificate and key paths are also present.
    pub enable_tls: bool,
    /// Path to the TLS certificate, if set.
    pub tls_cert_path: Option<String>,
    /// Path to the TLS private key, if set.
    pub tls_key_path: Option<String>,
}
317
318impl BuiltServerConfig {
319    pub fn port(&self) -> Option<u16> {
321        match &self.transport {
322            TransportType::Http { port, .. } | TransportType::WebSocket { port, .. } => Some(*port),
323            TransportType::Stdio => None,
324        }
325    }
326
327    pub fn host(&self) -> Option<&str> {
329        match &self.transport {
330            TransportType::Http { host, .. } | TransportType::WebSocket { host, .. } => Some(host),
331            TransportType::Stdio => None,
332        }
333    }
334
335    pub fn is_tls_configured(&self) -> bool {
337        self.enable_tls && self.tls_cert_path.is_some() && self.tls_key_path.is_some()
338    }
339}
340
341pub struct AuthMiddleware;
343
344impl AuthMiddleware {
345    pub fn bearer(api_key: impl Into<String>) -> MiddlewareConfig {
346        MiddlewareConfig::new("auth")
347            .with_config("api_key", api_key)
348            .with_config("type", "bearer")
349    }
350
351    pub fn basic_auth(
352        username: impl Into<String>,
353        password: impl Into<String>,
354    ) -> MiddlewareConfig {
355        MiddlewareConfig::new("auth")
356            .with_config("username", username)
357            .with_config("password", password)
358            .with_config("type", "basic")
359    }
360}
361
362pub struct RateLimitMiddleware;
364
365impl RateLimitMiddleware {
366    pub fn per_second(requests_per_second: u32) -> MiddlewareConfig {
367        MiddlewareConfig::new("rate_limit")
368            .with_config("requests_per_second", requests_per_second.to_string())
369    }
370
371    pub fn with_burst(requests_per_second: u32, burst_size: u32) -> MiddlewareConfig {
372        MiddlewareConfig::new("rate_limit")
373            .with_config("requests_per_second", requests_per_second.to_string())
374            .with_config("burst_size", burst_size.to_string())
375    }
376}
377
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::create_server_info;

    /// Server info called `name` at version `1.0.0`, shared by every test here.
    fn info_named(name: &str) -> ServerInfo {
        create_server_info(Some(name.to_string()), Some("1.0.0".to_string()))
    }

    #[test]
    fn test_server_builder_basic() {
        let built = server_builder()
            .with_server_info(info_named("test"))
            .with_port(3000)
            .build()
            .expect("server info was supplied");

        assert_eq!(built.server_info.server_info.name, "test");
        assert_eq!(built.port(), Some(3000));
    }

    #[test]
    fn test_server_builder_advanced() {
        let built = server_builder()
            .with_server_info(info_named("advanced"))
            .with_transport(TransportType::Http {
                port: 8080,
                host: "0.0.0.0".to_string(),
            })
            .with_cors_policy(CorsPolicy::permissive())
            .with_middleware(AuthMiddleware::bearer("secret-key"))
            .with_middleware(RateLimitMiddleware::per_second(100))
            .with_metrics_endpoint("/metrics")
            .with_health_endpoint("/health")
            .with_custom_endpoint("/api/v1/custom", "POST", "custom_handler")
            .with_compression(true)
            .build()
            .expect("server info was supplied");

        assert_eq!((built.host(), built.port()), (Some("0.0.0.0"), Some(8080)));
        assert!(built.cors_policy.is_some());
        assert_eq!(built.middleware.len(), 2);
        assert_eq!(built.custom_endpoints.len(), 1);
        assert_eq!(built.metrics_endpoint.as_deref(), Some("/metrics"));
        assert_eq!(built.health_endpoint.as_deref(), Some("/health"));
        assert!(built.enable_compression);
    }

    #[test]
    fn test_cors_policy() {
        let cors = CorsPolicy::permissive()
            .allow_origin("https://example.com")
            .allow_method("PATCH");

        let listed = |list: &[String], wanted: &str| list.iter().any(|entry| entry == wanted);

        assert!(listed(&cors.allowed_origins, "*"));
        assert!(listed(&cors.allowed_origins, "https://example.com"));
        assert!(listed(&cors.allowed_methods, "PATCH"));
    }

    #[test]
    fn test_transport_types() {
        let transports = [
            TransportType::Http {
                port: 8080,
                host: "localhost".to_string(),
            },
            TransportType::WebSocket {
                port: 8081,
                host: "localhost".to_string(),
                path: "/ws".to_string(),
            },
            TransportType::Stdio,
        ];

        assert!(matches!(transports[0], TransportType::Http { .. }));
        assert!(matches!(transports[1], TransportType::WebSocket { .. }));
        assert!(matches!(transports[2], TransportType::Stdio));
    }

    #[test]
    fn test_tls_configuration() {
        let info = info_named("tls-test");

        let secured = server_builder()
            .with_server_info(info.clone())
            .with_port(443)
            .with_tls("/path/to/cert.pem", "/path/to/key.pem")
            .build()
            .expect("server info was supplied");

        assert!(secured.enable_tls);
        assert_eq!(secured.tls_cert_path.as_deref(), Some("/path/to/cert.pem"));
        assert_eq!(secured.tls_key_path.as_deref(), Some("/path/to/key.pem"));
        assert!(secured.is_tls_configured());

        let plain = server_builder()
            .with_server_info(info)
            .with_port(443)
            .build()
            .expect("server info was supplied");

        assert!(!plain.enable_tls);
        assert!(!plain.is_tls_configured());
    }

    #[test]
    fn test_connection_limits() {
        let info = info_named("limits-test");

        let tuned = server_builder()
            .with_server_info(info.clone())
            .with_max_connections(10000)
            .with_connection_timeout(Duration::from_secs(120))
            .build()
            .expect("server info was supplied");

        assert_eq!(
            (tuned.max_connections, tuned.connection_timeout),
            (10000, Duration::from_secs(120))
        );

        let defaults = server_builder()
            .with_server_info(info)
            .build()
            .expect("server info was supplied");

        assert_eq!(
            (defaults.max_connections, defaults.connection_timeout),
            (1000, Duration::from_secs(30))
        );
    }

    #[test]
    fn test_custom_endpoints() {
        let routes = [
            ("/api/v1/users", "GET", "list_users"),
            ("/api/v1/users", "POST", "create_user"),
            ("/api/v1/users/{id}", "GET", "get_user"),
            ("/api/v1/users/{id}", "PUT", "update_user"),
            ("/api/v1/users/{id}", "DELETE", "delete_user"),
        ];

        let seeded = server_builder().with_server_info(info_named("endpoints-test"));
        let built = routes
            .iter()
            .fold(seeded, |builder, &(path, method, handler)| {
                builder.with_custom_endpoint(path, method, handler)
            })
            .build()
            .expect("server info was supplied");

        assert_eq!(built.custom_endpoints.len(), 5);

        let as_tuple = |endpoint: &CustomEndpoint| {
            (
                endpoint.path.clone(),
                endpoint.method.clone(),
                endpoint.handler_name.clone(),
            )
        };
        let owned = |path: &str, method: &str, handler: &str| {
            (path.to_string(), method.to_string(), handler.to_string())
        };

        assert_eq!(
            as_tuple(&built.custom_endpoints[0]),
            owned("/api/v1/users", "GET", "list_users")
        );
        assert_eq!(
            as_tuple(&built.custom_endpoints[4]),
            owned("/api/v1/users/{id}", "DELETE", "delete_user")
        );
    }

    #[test]
    fn test_middleware_ordering() {
        fn setting<'a>(config: &'a BuiltServerConfig, index: usize, key: &str) -> Option<&'a str> {
            config.middleware[index].config.get(key).map(String::as_str)
        }

        let built = server_builder()
            .with_server_info(info_named("middleware-test"))
            .with_middleware(AuthMiddleware::bearer("key1"))
            .with_middleware(RateLimitMiddleware::per_second(50))
            .with_middleware(AuthMiddleware::basic_auth("user", "pass"))
            .with_middleware(RateLimitMiddleware::with_burst(100, 200))
            .build()
            .expect("server info was supplied");

        let names: Vec<&str> = built.middleware.iter().map(|m| m.name.as_str()).collect();
        assert_eq!(names, ["auth", "rate_limit", "auth", "rate_limit"]);

        assert_eq!(setting(&built, 0, "api_key"), Some("key1"));
        assert_eq!(setting(&built, 1, "requests_per_second"), Some("50"));
        assert_eq!(setting(&built, 2, "type"), Some("basic"));
        assert_eq!(setting(&built, 3, "burst_size"), Some("200"));
    }
}