1use crate::{CliError, McpConfiguration};
4use pulseengine_mcp_protocol::ServerInfo;
5use serde::{Deserialize, Serialize};
6use std::collections::HashMap;
7use std::time::Duration;
8use tracing::info;
9
10pub async fn run_server<C>(_config: C) -> std::result::Result<(), CliError>
12where
13 C: McpConfiguration,
14{
15 info!("Starting MCP server...");
24
25 _config.initialize_logging()?;
27
28 _config.validate()?;
30
31 info!("Server info: {:?}", _config.get_server_info());
32
33 Err(CliError::server_setup(
35 "Server implementation not yet complete",
36 ))
37}
38
39pub fn server_builder() -> ServerBuilder {
41 ServerBuilder::new()
42}
43
44#[derive(Debug, Clone, Serialize, Deserialize)]
46pub enum TransportType {
47 Http { port: u16, host: String },
49 WebSocket {
51 port: u16,
52 host: String,
53 path: String,
54 },
55 Stdio,
57}
58
59impl Default for TransportType {
60 fn default() -> Self {
61 Self::Http {
62 port: 8080,
63 host: "localhost".to_string(),
64 }
65 }
66}
67
68#[derive(Debug, Clone, Serialize, Deserialize)]
70pub struct CorsPolicy {
71 pub allowed_origins: Vec<String>,
72 pub allowed_methods: Vec<String>,
73 pub allowed_headers: Vec<String>,
74 pub allow_credentials: bool,
75 pub max_age: Option<Duration>,
76}
77
78impl CorsPolicy {
79 pub fn permissive() -> Self {
81 Self {
82 allowed_origins: vec!["*".to_string()],
83 allowed_methods: vec![
84 "GET".to_string(),
85 "POST".to_string(),
86 "PUT".to_string(),
87 "DELETE".to_string(),
88 "OPTIONS".to_string(),
89 ],
90 allowed_headers: vec!["*".to_string()],
91 allow_credentials: false,
92 max_age: Some(Duration::from_secs(3600)),
93 }
94 }
95
96 pub fn strict() -> Self {
98 Self {
99 allowed_origins: vec![],
100 allowed_methods: vec!["GET".to_string(), "POST".to_string()],
101 allowed_headers: vec!["Content-Type".to_string(), "Authorization".to_string()],
102 allow_credentials: true,
103 max_age: Some(Duration::from_secs(300)),
104 }
105 }
106
107 pub fn allow_origin(mut self, origin: impl Into<String>) -> Self {
109 self.allowed_origins.push(origin.into());
110 self
111 }
112
113 pub fn allow_method(mut self, method: impl Into<String>) -> Self {
115 self.allowed_methods.push(method.into());
116 self
117 }
118}
119
/// Extra route to register on the server, dispatched to a named handler.
#[derive(Debug, Clone)]
pub struct CustomEndpoint {
    /// Route path, e.g. `/api/v1/users/{id}`.
    pub path: String,
    /// HTTP method, stored exactly as given.
    pub method: String,
    /// Name of the handler this route is dispatched to.
    pub handler_name: String,
}

impl CustomEndpoint {
    /// Build an endpoint from its path, method and handler name.
    pub fn new(
        path: impl Into<String>,
        method: impl Into<String>,
        handler_name: impl Into<String>,
    ) -> Self {
        let path: String = path.into();
        let method: String = method.into();
        let handler_name: String = handler_name.into();
        CustomEndpoint {
            path,
            method,
            handler_name,
        }
    }
}
141
/// Named middleware with free-form string settings.
///
/// Settings are stored as plain strings, and some of them may be credentials
/// (`AuthMiddleware` stores `api_key` and `password` here). The `Debug` output
/// therefore masks the values of credential-like keys.
#[derive(Clone)]
pub struct MiddlewareConfig {
    /// Middleware identifier, e.g. `"auth"` or `"rate_limit"`.
    pub name: String,
    /// Settings for this middleware, keyed by setting name.
    pub config: HashMap<String, String>,
}

impl MiddlewareConfig {
    /// Create a middleware config with no settings.
    pub fn new(name: impl Into<String>) -> Self {
        Self {
            name: name.into(),
            config: HashMap::new(),
        }
    }

    /// Set `key` to `value`, replacing any previous value for `key`.
    pub fn with_config(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
        self.config.insert(key.into(), value.into());
        self
    }
}

impl std::fmt::Debug for MiddlewareConfig {
    /// Formats like a derived `Debug`, with two differences:
    /// - Values of credential-like keys (matched case-insensitively) are shown
    ///   as `"<redacted>"`, so secrets do not leak into `{:?}` logs.
    /// - Keys are printed in sorted order, so the output is deterministic.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        const SENSITIVE_KEYS: [&str; 4] = ["api_key", "password", "secret", "token"];

        let config: std::collections::BTreeMap<&str, &str> = self
            .config
            .iter()
            .map(|(key, value)| {
                let is_sensitive = SENSITIVE_KEYS
                    .iter()
                    .any(|sensitive| key.eq_ignore_ascii_case(sensitive));
                let shown = if is_sensitive { "<redacted>" } else { value.as_str() };
                (key.as_str(), shown)
            })
            .collect();

        f.debug_struct("MiddlewareConfig")
            .field("name", &self.name)
            .field("config", &config)
            .finish()
    }
}
162
163pub struct ServerBuilder {
165 server_info: Option<ServerInfo>,
166 transport: Option<TransportType>,
167 cors_policy: Option<CorsPolicy>,
168 middleware: Vec<MiddlewareConfig>,
169 custom_endpoints: Vec<CustomEndpoint>,
170 metrics_endpoint: Option<String>,
171 health_endpoint: Option<String>,
172 connection_timeout: Option<Duration>,
173 max_connections: Option<usize>,
174 enable_compression: bool,
175 enable_tls: bool,
176 tls_cert_path: Option<String>,
177 tls_key_path: Option<String>,
178}
179
180impl ServerBuilder {
181 pub fn new() -> Self {
182 Self {
183 server_info: None,
184 transport: None,
185 cors_policy: None,
186 middleware: Vec::new(),
187 custom_endpoints: Vec::new(),
188 metrics_endpoint: None,
189 health_endpoint: None,
190 connection_timeout: None,
191 max_connections: None,
192 enable_compression: false,
193 enable_tls: false,
194 tls_cert_path: None,
195 tls_key_path: None,
196 }
197 }
198
199 pub fn with_server_info(mut self, info: ServerInfo) -> Self {
200 self.server_info = Some(info);
201 self
202 }
203
204 pub fn with_port(mut self, port: u16) -> Self {
205 self.transport = Some(TransportType::Http {
206 port,
207 host: "localhost".to_string(),
208 });
209 self
210 }
211
212 pub fn with_transport(mut self, transport: TransportType) -> Self {
213 self.transport = Some(transport);
214 self
215 }
216
217 pub fn with_cors_policy(mut self, cors: CorsPolicy) -> Self {
218 self.cors_policy = Some(cors);
219 self
220 }
221
222 pub fn with_middleware(mut self, middleware: MiddlewareConfig) -> Self {
223 self.middleware.push(middleware);
224 self
225 }
226
227 pub fn with_metrics_endpoint(mut self, path: impl Into<String>) -> Self {
228 self.metrics_endpoint = Some(path.into());
229 self
230 }
231
232 pub fn with_health_endpoint(mut self, path: impl Into<String>) -> Self {
233 self.health_endpoint = Some(path.into());
234 self
235 }
236
237 pub fn with_custom_endpoint(
238 mut self,
239 path: impl Into<String>,
240 method: impl Into<String>,
241 handler_name: impl Into<String>,
242 ) -> Self {
243 self.custom_endpoints
244 .push(CustomEndpoint::new(path, method, handler_name));
245 self
246 }
247
248 pub fn with_connection_timeout(mut self, timeout: Duration) -> Self {
249 self.connection_timeout = Some(timeout);
250 self
251 }
252
253 pub fn with_max_connections(mut self, max: usize) -> Self {
254 self.max_connections = Some(max);
255 self
256 }
257
258 pub fn with_compression(mut self, enable: bool) -> Self {
259 self.enable_compression = enable;
260 self
261 }
262
263 pub fn with_tls(mut self, cert_path: impl Into<String>, key_path: impl Into<String>) -> Self {
264 self.enable_tls = true;
265 self.tls_cert_path = Some(cert_path.into());
266 self.tls_key_path = Some(key_path.into());
267 self
268 }
269
270 pub fn build(self) -> Result<BuiltServerConfig, CliError> {
271 Ok(BuiltServerConfig {
272 server_info: self
273 .server_info
274 .ok_or_else(|| CliError::configuration("Server info is required"))?,
275 transport: self.transport.unwrap_or_default(),
276 cors_policy: self.cors_policy,
277 middleware: self.middleware,
278 custom_endpoints: self.custom_endpoints,
279 metrics_endpoint: self.metrics_endpoint,
280 health_endpoint: self.health_endpoint,
281 connection_timeout: self.connection_timeout.unwrap_or(Duration::from_secs(30)),
282 max_connections: self.max_connections.unwrap_or(1000),
283 enable_compression: self.enable_compression,
284 enable_tls: self.enable_tls,
285 tls_cert_path: self.tls_cert_path,
286 tls_key_path: self.tls_key_path,
287 })
288 }
289}
290
291impl Default for ServerBuilder {
292 fn default() -> Self {
293 Self::new()
294 }
295}
296
297#[derive(Debug, Clone)]
299pub struct BuiltServerConfig {
300 pub server_info: ServerInfo,
301 pub transport: TransportType,
302 pub cors_policy: Option<CorsPolicy>,
303 pub middleware: Vec<MiddlewareConfig>,
304 pub custom_endpoints: Vec<CustomEndpoint>,
305 pub metrics_endpoint: Option<String>,
306 pub health_endpoint: Option<String>,
307 pub connection_timeout: Duration,
308 pub max_connections: usize,
309 pub enable_compression: bool,
310 pub enable_tls: bool,
311 pub tls_cert_path: Option<String>,
312 pub tls_key_path: Option<String>,
313}
314
315impl BuiltServerConfig {
316 pub fn port(&self) -> Option<u16> {
318 match &self.transport {
319 TransportType::Http { port, .. } | TransportType::WebSocket { port, .. } => Some(*port),
320 TransportType::Stdio => None,
321 }
322 }
323
324 pub fn host(&self) -> Option<&str> {
326 match &self.transport {
327 TransportType::Http { host, .. } | TransportType::WebSocket { host, .. } => Some(host),
328 TransportType::Stdio => None,
329 }
330 }
331
332 pub fn is_tls_configured(&self) -> bool {
334 self.enable_tls && self.tls_cert_path.is_some() && self.tls_key_path.is_some()
335 }
336}
337
338pub struct AuthMiddleware;
340
341impl AuthMiddleware {
342 pub fn bearer(api_key: impl Into<String>) -> MiddlewareConfig {
343 MiddlewareConfig::new("auth")
344 .with_config("api_key", api_key)
345 .with_config("type", "bearer")
346 }
347
348 pub fn basic_auth(
349 username: impl Into<String>,
350 password: impl Into<String>,
351 ) -> MiddlewareConfig {
352 MiddlewareConfig::new("auth")
353 .with_config("username", username)
354 .with_config("password", password)
355 .with_config("type", "basic")
356 }
357}
358
359pub struct RateLimitMiddleware;
361
362impl RateLimitMiddleware {
363 pub fn per_second(requests_per_second: u32) -> MiddlewareConfig {
364 MiddlewareConfig::new("rate_limit")
365 .with_config("requests_per_second", requests_per_second.to_string())
366 }
367
368 pub fn with_burst(requests_per_second: u32, burst_size: u32) -> MiddlewareConfig {
369 MiddlewareConfig::new("rate_limit")
370 .with_config("requests_per_second", requests_per_second.to_string())
371 .with_config("burst_size", burst_size.to_string())
372 }
373}
374
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::create_server_info;

    /// Server info with the given name and version `1.0.0`, shared by the tests below.
    fn test_server_info(name: &str) -> ServerInfo {
        create_server_info(Some(name.to_string()), Some("1.0.0".to_string()))
    }

    #[test]
    fn test_server_builder_basic() {
        let config = server_builder()
            .with_server_info(test_server_info("test"))
            .with_port(3000)
            .build()
            .unwrap();

        assert_eq!(config.port(), Some(3000));
        assert_eq!(config.server_info.server_info.name, "test");
    }

    #[test]
    fn test_server_builder_requires_server_info() {
        assert!(server_builder().with_port(3000).build().is_err());
    }

    #[test]
    fn test_server_builder_advanced() {
        let config = server_builder()
            .with_server_info(test_server_info("advanced"))
            .with_transport(TransportType::Http {
                port: 8080,
                host: "0.0.0.0".to_string(),
            })
            .with_cors_policy(CorsPolicy::permissive())
            .with_middleware(AuthMiddleware::bearer("secret-key"))
            .with_middleware(RateLimitMiddleware::per_second(100))
            .with_metrics_endpoint("/metrics")
            .with_health_endpoint("/health")
            .with_custom_endpoint("/api/v1/custom", "POST", "custom_handler")
            .with_compression(true)
            .build()
            .unwrap();

        assert_eq!(config.port(), Some(8080));
        assert_eq!(config.host(), Some("0.0.0.0"));
        assert!(config.cors_policy.is_some());
        assert_eq!(config.middleware.len(), 2);
        assert_eq!(config.custom_endpoints.len(), 1);
        assert_eq!(config.metrics_endpoint, Some("/metrics".to_string()));
        assert_eq!(config.health_endpoint, Some("/health".to_string()));
        assert!(config.enable_compression);
    }

    #[test]
    fn test_cors_policy() {
        let cors = CorsPolicy::permissive()
            .allow_origin("https://example.com")
            .allow_method("PATCH");

        assert!(cors.allowed_origins.contains(&"*".to_string()));
        assert!(
            cors.allowed_origins
                .contains(&"https://example.com".to_string())
        );
        assert!(cors.allowed_methods.contains(&"PATCH".to_string()));
    }

    #[test]
    fn test_transport_types() {
        // The default transport is HTTP on localhost:8080.
        match TransportType::default() {
            TransportType::Http { port, host } => {
                assert_eq!(port, 8080);
                assert_eq!(host, "localhost");
            }
            other => panic!("unexpected default transport: {:?}", other),
        }

        // A builder with no transport falls back to that default.
        let defaulted = server_builder()
            .with_server_info(test_server_info("transport-default"))
            .build()
            .unwrap();
        assert_eq!(defaulted.port(), Some(8080));
        assert_eq!(defaulted.host(), Some("localhost"));

        let websocket = server_builder()
            .with_server_info(test_server_info("transport-ws"))
            .with_transport(TransportType::WebSocket {
                port: 8081,
                host: "localhost".to_string(),
                path: "/ws".to_string(),
            })
            .build()
            .unwrap();
        assert_eq!(websocket.port(), Some(8081));
        assert_eq!(websocket.host(), Some("localhost"));
        assert!(matches!(
            &websocket.transport,
            TransportType::WebSocket { path, .. } if path == "/ws"
        ));

        // Stdio has no network address.
        let stdio = server_builder()
            .with_server_info(test_server_info("transport-stdio"))
            .with_transport(TransportType::Stdio)
            .build()
            .unwrap();
        assert_eq!(stdio.port(), None);
        assert_eq!(stdio.host(), None);
    }

    #[test]
    fn test_tls_configuration() {
        let server_info = test_server_info("tls-test");

        let tls_config = server_builder()
            .with_server_info(server_info.clone())
            .with_port(443)
            .with_tls("/path/to/cert.pem", "/path/to/key.pem")
            .build()
            .unwrap();

        assert!(tls_config.enable_tls);
        assert_eq!(
            tls_config.tls_cert_path,
            Some("/path/to/cert.pem".to_string())
        );
        assert_eq!(
            tls_config.tls_key_path,
            Some("/path/to/key.pem".to_string())
        );
        assert!(tls_config.is_tls_configured());

        // Enabled but missing key material is not a usable TLS setup.
        let mut half_configured = tls_config.clone();
        half_configured.tls_key_path = None;
        assert!(!half_configured.is_tls_configured());

        // Without `with_tls`, TLS stays disabled.
        let no_tls = server_builder()
            .with_server_info(server_info)
            .with_port(443)
            .build()
            .unwrap();

        assert!(!no_tls.enable_tls);
        assert!(!no_tls.is_tls_configured());
    }

    #[test]
    fn test_connection_limits() {
        let server_info = test_server_info("limits-test");

        let custom_limits = server_builder()
            .with_server_info(server_info.clone())
            .with_max_connections(10000)
            .with_connection_timeout(Duration::from_secs(120))
            .build()
            .unwrap();

        assert_eq!(custom_limits.max_connections, 10000);
        assert_eq!(custom_limits.connection_timeout, Duration::from_secs(120));

        let default_limits = server_builder()
            .with_server_info(server_info)
            .build()
            .unwrap();

        assert_eq!(default_limits.max_connections, 1000);
        assert_eq!(default_limits.connection_timeout, Duration::from_secs(30));
    }

    #[test]
    fn test_custom_endpoints() {
        let config = server_builder()
            .with_server_info(test_server_info("endpoints-test"))
            .with_custom_endpoint("/api/v1/users", "GET", "list_users")
            .with_custom_endpoint("/api/v1/users", "POST", "create_user")
            .with_custom_endpoint("/api/v1/users/{id}", "GET", "get_user")
            .with_custom_endpoint("/api/v1/users/{id}", "PUT", "update_user")
            .with_custom_endpoint("/api/v1/users/{id}", "DELETE", "delete_user")
            .build()
            .unwrap();

        assert_eq!(config.custom_endpoints.len(), 5);

        // Endpoints keep their registration order.
        let endpoints = &config.custom_endpoints;
        assert_eq!(endpoints[0].path, "/api/v1/users");
        assert_eq!(endpoints[0].method, "GET");
        assert_eq!(endpoints[0].handler_name, "list_users");

        assert_eq!(endpoints[4].path, "/api/v1/users/{id}");
        assert_eq!(endpoints[4].method, "DELETE");
        assert_eq!(endpoints[4].handler_name, "delete_user");
    }

    #[test]
    fn test_middleware_ordering() {
        let config = server_builder()
            .with_server_info(test_server_info("middleware-test"))
            .with_middleware(AuthMiddleware::bearer("key1"))
            .with_middleware(RateLimitMiddleware::per_second(50))
            .with_middleware(AuthMiddleware::basic_auth("user", "pass"))
            .with_middleware(RateLimitMiddleware::with_burst(100, 200))
            .build()
            .unwrap();

        assert_eq!(config.middleware.len(), 4);

        // Middleware keeps the order in which it was added.
        assert_eq!(config.middleware[0].name, "auth");
        assert_eq!(
            config.middleware[0].config.get("api_key"),
            Some(&"key1".to_string())
        );

        assert_eq!(config.middleware[1].name, "rate_limit");
        assert_eq!(
            config.middleware[1].config.get("requests_per_second"),
            Some(&"50".to_string())
        );

        assert_eq!(config.middleware[2].name, "auth");
        assert_eq!(
            config.middleware[2].config.get("type"),
            Some(&"basic".to_string())
        );

        assert_eq!(config.middleware[3].name, "rate_limit");
        assert_eq!(
            config.middleware[3].config.get("burst_size"),
            Some(&"200".to_string())
        );
    }
}