1use crate::{CliError, McpConfiguration};
4use pulseengine_mcp_protocol::ServerInfo;
5use serde::{Deserialize, Serialize};
6use std::collections::HashMap;
7use std::time::Duration;
8use tracing::info;
9
10pub async fn run_server<C>(_config: C) -> std::result::Result<(), CliError>
12where
13 C: McpConfiguration,
14{
15 info!("Starting MCP server...");
24
25 _config.initialize_logging()?;
27
28 _config.validate()?;
30
31 info!("Server info: {:?}", _config.get_server_info());
32
33 Err(CliError::server_setup(
38 "Server implementation not yet complete",
39 ))
40}
41
42pub fn server_builder() -> ServerBuilder {
44 ServerBuilder::new()
45}
46
47#[derive(Debug, Clone, Serialize, Deserialize)]
49pub enum TransportType {
50 Http { port: u16, host: String },
52 WebSocket {
54 port: u16,
55 host: String,
56 path: String,
57 },
58 Stdio,
60}
61
62impl Default for TransportType {
63 fn default() -> Self {
64 Self::Http {
65 port: 8080,
66 host: "localhost".to_string(),
67 }
68 }
69}
70
71#[derive(Debug, Clone, Serialize, Deserialize)]
73pub struct CorsPolicy {
74 pub allowed_origins: Vec<String>,
75 pub allowed_methods: Vec<String>,
76 pub allowed_headers: Vec<String>,
77 pub allow_credentials: bool,
78 pub max_age: Option<Duration>,
79}
80
81impl CorsPolicy {
82 pub fn permissive() -> Self {
84 Self {
85 allowed_origins: vec!["*".to_string()],
86 allowed_methods: vec![
87 "GET".to_string(),
88 "POST".to_string(),
89 "PUT".to_string(),
90 "DELETE".to_string(),
91 "OPTIONS".to_string(),
92 ],
93 allowed_headers: vec!["*".to_string()],
94 allow_credentials: false,
95 max_age: Some(Duration::from_secs(3600)),
96 }
97 }
98
99 pub fn strict() -> Self {
101 Self {
102 allowed_origins: vec![],
103 allowed_methods: vec!["GET".to_string(), "POST".to_string()],
104 allowed_headers: vec!["Content-Type".to_string(), "Authorization".to_string()],
105 allow_credentials: true,
106 max_age: Some(Duration::from_secs(300)),
107 }
108 }
109
110 pub fn allow_origin(mut self, origin: impl Into<String>) -> Self {
112 self.allowed_origins.push(origin.into());
113 self
114 }
115
116 pub fn allow_method(mut self, method: impl Into<String>) -> Self {
118 self.allowed_methods.push(method.into());
119 self
120 }
121}
122
/// An extra route registered on the server, together with the name of its handler.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct CustomEndpoint {
    /// Route path, e.g. `/api/v1/users/{id}`.
    pub path: String,
    /// HTTP method, e.g. `GET`. Stored exactly as given.
    pub method: String,
    /// Name of the handler serving this route. How the name is resolved is not shown
    /// here; presumably the server implementation looks it up.
    pub handler_name: String,
}

impl CustomEndpoint {
    /// Create an endpoint description, converting each argument into an owned `String`.
    pub fn new(
        path: impl Into<String>,
        method: impl Into<String>,
        handler_name: impl Into<String>,
    ) -> Self {
        Self {
            path: path.into(),
            method: method.into(),
            handler_name: handler_name.into(),
        }
    }
}
144
/// A named middleware entry plus free-form string settings.
///
/// `Debug` is implemented by hand rather than derived. Settings whose key looks
/// credential-like are printed as `<redacted>`; this covers the `api_key` and
/// `password` entries produced by `AuthMiddleware`. This keeps secrets out of logs
/// when a builder or built configuration is debug-printed. Settings are printed
/// sorted by key, so the output is deterministic.
#[derive(Clone, PartialEq, Eq)]
pub struct MiddlewareConfig {
    /// Middleware identifier, e.g. `"auth"` or `"rate_limit"`.
    pub name: String,
    /// Settings for this middleware, as string key/value pairs.
    pub config: HashMap<String, String>,
}

impl MiddlewareConfig {
    /// Lower-case substrings that mark a setting key as holding a secret.
    const SENSITIVE_KEY_MARKERS: &'static [&'static str] =
        &["password", "secret", "token", "api_key", "credential"];

    /// Create an entry called `name` with no settings.
    pub fn new(name: impl Into<String>) -> Self {
        Self {
            name: name.into(),
            config: HashMap::new(),
        }
    }

    /// Insert the setting `key` = `value`, overwriting any previous value for `key`.
    pub fn with_config(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
        self.config.insert(key.into(), value.into());
        self
    }

    /// Whether the value stored under `key` must be hidden from `Debug` output.
    /// The match is case-insensitive.
    fn is_sensitive_key(key: &str) -> bool {
        let key = key.to_ascii_lowercase();
        Self::SENSITIVE_KEY_MARKERS
            .iter()
            .any(|&marker| key.contains(marker))
    }
}

impl std::fmt::Debug for MiddlewareConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Use a BTreeMap so keys print in sorted order; HashMap iteration order is random.
        let settings: std::collections::BTreeMap<&str, &str> = self
            .config
            .iter()
            .map(|(key, value)| {
                let shown = if Self::is_sensitive_key(key) {
                    "<redacted>"
                } else {
                    value.as_str()
                };
                (key.as_str(), shown)
            })
            .collect();
        f.debug_struct("MiddlewareConfig")
            .field("name", &self.name)
            .field("config", &settings)
            .finish()
    }
}
165
166pub struct ServerBuilder {
168 server_info: Option<ServerInfo>,
169 transport: Option<TransportType>,
170 cors_policy: Option<CorsPolicy>,
171 middleware: Vec<MiddlewareConfig>,
172 custom_endpoints: Vec<CustomEndpoint>,
173 metrics_endpoint: Option<String>,
174 health_endpoint: Option<String>,
175 connection_timeout: Option<Duration>,
176 max_connections: Option<usize>,
177 enable_compression: bool,
178 enable_tls: bool,
179 tls_cert_path: Option<String>,
180 tls_key_path: Option<String>,
181}
182
183impl ServerBuilder {
184 pub fn new() -> Self {
185 Self {
186 server_info: None,
187 transport: None,
188 cors_policy: None,
189 middleware: Vec::new(),
190 custom_endpoints: Vec::new(),
191 metrics_endpoint: None,
192 health_endpoint: None,
193 connection_timeout: None,
194 max_connections: None,
195 enable_compression: false,
196 enable_tls: false,
197 tls_cert_path: None,
198 tls_key_path: None,
199 }
200 }
201
202 pub fn with_server_info(mut self, info: ServerInfo) -> Self {
203 self.server_info = Some(info);
204 self
205 }
206
207 pub fn with_port(mut self, port: u16) -> Self {
208 self.transport = Some(TransportType::Http {
209 port,
210 host: "localhost".to_string(),
211 });
212 self
213 }
214
215 pub fn with_transport(mut self, transport: TransportType) -> Self {
216 self.transport = Some(transport);
217 self
218 }
219
220 pub fn with_cors_policy(mut self, cors: CorsPolicy) -> Self {
221 self.cors_policy = Some(cors);
222 self
223 }
224
225 pub fn with_middleware(mut self, middleware: MiddlewareConfig) -> Self {
226 self.middleware.push(middleware);
227 self
228 }
229
230 pub fn with_metrics_endpoint(mut self, path: impl Into<String>) -> Self {
231 self.metrics_endpoint = Some(path.into());
232 self
233 }
234
235 pub fn with_health_endpoint(mut self, path: impl Into<String>) -> Self {
236 self.health_endpoint = Some(path.into());
237 self
238 }
239
240 pub fn with_custom_endpoint(
241 mut self,
242 path: impl Into<String>,
243 method: impl Into<String>,
244 handler_name: impl Into<String>,
245 ) -> Self {
246 self.custom_endpoints
247 .push(CustomEndpoint::new(path, method, handler_name));
248 self
249 }
250
251 pub fn with_connection_timeout(mut self, timeout: Duration) -> Self {
252 self.connection_timeout = Some(timeout);
253 self
254 }
255
256 pub fn with_max_connections(mut self, max: usize) -> Self {
257 self.max_connections = Some(max);
258 self
259 }
260
261 pub fn with_compression(mut self, enable: bool) -> Self {
262 self.enable_compression = enable;
263 self
264 }
265
266 pub fn with_tls(mut self, cert_path: impl Into<String>, key_path: impl Into<String>) -> Self {
267 self.enable_tls = true;
268 self.tls_cert_path = Some(cert_path.into());
269 self.tls_key_path = Some(key_path.into());
270 self
271 }
272
273 pub fn build(self) -> Result<BuiltServerConfig, CliError> {
274 Ok(BuiltServerConfig {
275 server_info: self
276 .server_info
277 .ok_or_else(|| CliError::configuration("Server info is required"))?,
278 transport: self.transport.unwrap_or_default(),
279 cors_policy: self.cors_policy,
280 middleware: self.middleware,
281 custom_endpoints: self.custom_endpoints,
282 metrics_endpoint: self.metrics_endpoint,
283 health_endpoint: self.health_endpoint,
284 connection_timeout: self.connection_timeout.unwrap_or(Duration::from_secs(30)),
285 max_connections: self.max_connections.unwrap_or(1000),
286 enable_compression: self.enable_compression,
287 enable_tls: self.enable_tls,
288 tls_cert_path: self.tls_cert_path,
289 tls_key_path: self.tls_key_path,
290 })
291 }
292}
293
294impl Default for ServerBuilder {
295 fn default() -> Self {
296 Self::new()
297 }
298}
299
300#[derive(Debug, Clone)]
302pub struct BuiltServerConfig {
303 pub server_info: ServerInfo,
304 pub transport: TransportType,
305 pub cors_policy: Option<CorsPolicy>,
306 pub middleware: Vec<MiddlewareConfig>,
307 pub custom_endpoints: Vec<CustomEndpoint>,
308 pub metrics_endpoint: Option<String>,
309 pub health_endpoint: Option<String>,
310 pub connection_timeout: Duration,
311 pub max_connections: usize,
312 pub enable_compression: bool,
313 pub enable_tls: bool,
314 pub tls_cert_path: Option<String>,
315 pub tls_key_path: Option<String>,
316}
317
318impl BuiltServerConfig {
319 pub fn port(&self) -> Option<u16> {
321 match &self.transport {
322 TransportType::Http { port, .. } | TransportType::WebSocket { port, .. } => Some(*port),
323 TransportType::Stdio => None,
324 }
325 }
326
327 pub fn host(&self) -> Option<&str> {
329 match &self.transport {
330 TransportType::Http { host, .. } | TransportType::WebSocket { host, .. } => Some(host),
331 TransportType::Stdio => None,
332 }
333 }
334
335 pub fn is_tls_configured(&self) -> bool {
337 self.enable_tls && self.tls_cert_path.is_some() && self.tls_key_path.is_some()
338 }
339}
340
341pub struct AuthMiddleware;
343
344impl AuthMiddleware {
345 pub fn bearer(api_key: impl Into<String>) -> MiddlewareConfig {
346 MiddlewareConfig::new("auth")
347 .with_config("api_key", api_key)
348 .with_config("type", "bearer")
349 }
350
351 pub fn basic_auth(
352 username: impl Into<String>,
353 password: impl Into<String>,
354 ) -> MiddlewareConfig {
355 MiddlewareConfig::new("auth")
356 .with_config("username", username)
357 .with_config("password", password)
358 .with_config("type", "basic")
359 }
360}
361
362pub struct RateLimitMiddleware;
364
365impl RateLimitMiddleware {
366 pub fn per_second(requests_per_second: u32) -> MiddlewareConfig {
367 MiddlewareConfig::new("rate_limit")
368 .with_config("requests_per_second", requests_per_second.to_string())
369 }
370
371 pub fn with_burst(requests_per_second: u32, burst_size: u32) -> MiddlewareConfig {
372 MiddlewareConfig::new("rate_limit")
373 .with_config("requests_per_second", requests_per_second.to_string())
374 .with_config("burst_size", burst_size.to_string())
375 }
376}
377
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::create_server_info;

    /// Server metadata named `name` at version "1.0.0"; every test needs one.
    fn info_named(name: &str) -> ServerInfo {
        create_server_info(Some(name.to_string()), Some("1.0.0".to_string()))
    }

    #[test]
    fn test_server_builder_basic() {
        let built = server_builder()
            .with_server_info(info_named("test"))
            .with_port(3000)
            .build()
            .unwrap();

        assert_eq!(built.port(), Some(3000));
        assert_eq!(built.server_info.server_info.name, "test");
    }

    #[test]
    fn test_server_builder_advanced() {
        let info = info_named("advanced");
        let wildcard_http = TransportType::Http {
            port: 8080,
            host: "0.0.0.0".to_string(),
        };

        let built = server_builder()
            .with_server_info(info)
            .with_transport(wildcard_http)
            .with_cors_policy(CorsPolicy::permissive())
            .with_middleware(AuthMiddleware::bearer("secret-key"))
            .with_middleware(RateLimitMiddleware::per_second(100))
            .with_metrics_endpoint("/metrics")
            .with_health_endpoint("/health")
            .with_custom_endpoint("/api/v1/custom", "POST", "custom_handler")
            .with_compression(true)
            .build()
            .unwrap();

        assert_eq!(built.port(), Some(8080));
        assert_eq!(built.host(), Some("0.0.0.0"));
        assert!(built.cors_policy.is_some());
        assert_eq!(built.middleware.len(), 2);
        assert_eq!(built.custom_endpoints.len(), 1);
        assert_eq!(built.metrics_endpoint.as_deref(), Some("/metrics"));
        assert_eq!(built.health_endpoint.as_deref(), Some("/health"));
        assert!(built.enable_compression);
    }

    #[test]
    fn test_cors_policy() {
        fn includes(items: &[String], wanted: &str) -> bool {
            items.iter().any(|item| item == wanted)
        }

        let policy = CorsPolicy::permissive()
            .allow_origin("https://example.com")
            .allow_method("PATCH");

        assert!(includes(&policy.allowed_origins, "*"));
        assert!(includes(&policy.allowed_origins, "https://example.com"));
        assert!(includes(&policy.allowed_methods, "PATCH"));
    }

    #[test]
    fn test_transport_types() {
        let http = TransportType::Http {
            port: 8080,
            host: "localhost".to_string(),
        };
        let websocket = TransportType::WebSocket {
            port: 8081,
            host: "localhost".to_string(),
            path: "/ws".to_string(),
        };
        let stdio = TransportType::Stdio;

        assert!(matches!(http, TransportType::Http { .. }));
        assert!(matches!(websocket, TransportType::WebSocket { .. }));
        assert!(matches!(stdio, TransportType::Stdio));
    }

    #[test]
    fn test_tls_configuration() {
        let info = info_named("tls-test");

        let secured = server_builder()
            .with_server_info(info.clone())
            .with_port(443)
            .with_tls("/path/to/cert.pem", "/path/to/key.pem")
            .build()
            .unwrap();

        assert!(secured.enable_tls);
        assert_eq!(secured.tls_cert_path.as_deref(), Some("/path/to/cert.pem"));
        assert_eq!(secured.tls_key_path.as_deref(), Some("/path/to/key.pem"));
        assert!(secured.is_tls_configured());

        // Same port without `with_tls`: TLS must stay off.
        let plain = server_builder()
            .with_server_info(info)
            .with_port(443)
            .build()
            .unwrap();

        assert!(!plain.enable_tls);
        assert!(!plain.is_tls_configured());
    }

    #[test]
    fn test_connection_limits() {
        let info = info_named("limits-test");

        let tuned = server_builder()
            .with_server_info(info.clone())
            .with_max_connections(10000)
            .with_connection_timeout(Duration::from_secs(120))
            .build()
            .unwrap();

        assert_eq!(tuned.max_connections, 10000);
        assert_eq!(tuned.connection_timeout, Duration::from_secs(120));

        let defaulted = server_builder().with_server_info(info).build().unwrap();

        assert_eq!(defaulted.max_connections, 1000);
        assert_eq!(defaulted.connection_timeout, Duration::from_secs(30));
    }

    #[test]
    fn test_custom_endpoints() {
        fn summary(endpoint: &CustomEndpoint) -> (&str, &str, &str) {
            (
                endpoint.path.as_str(),
                endpoint.method.as_str(),
                endpoint.handler_name.as_str(),
            )
        }

        let routes = [
            ("/api/v1/users", "GET", "list_users"),
            ("/api/v1/users", "POST", "create_user"),
            ("/api/v1/users/{id}", "GET", "get_user"),
            ("/api/v1/users/{id}", "PUT", "update_user"),
            ("/api/v1/users/{id}", "DELETE", "delete_user"),
        ];

        let built = routes
            .iter()
            .fold(
                server_builder().with_server_info(info_named("endpoints-test")),
                |builder, &(path, method, handler)| {
                    builder.with_custom_endpoint(path, method, handler)
                },
            )
            .build()
            .unwrap();

        let endpoints = &built.custom_endpoints;
        assert_eq!(endpoints.len(), 5);
        assert_eq!(
            summary(&endpoints[0]),
            ("/api/v1/users", "GET", "list_users")
        );
        assert_eq!(
            summary(&endpoints[4]),
            ("/api/v1/users/{id}", "DELETE", "delete_user")
        );
    }

    #[test]
    fn test_middleware_ordering() {
        let built = server_builder()
            .with_server_info(info_named("middleware-test"))
            .with_middleware(AuthMiddleware::bearer("key1"))
            .with_middleware(RateLimitMiddleware::per_second(50))
            .with_middleware(AuthMiddleware::basic_auth("user", "pass"))
            .with_middleware(RateLimitMiddleware::with_burst(100, 200))
            .build()
            .unwrap();

        let chain = &built.middleware;
        assert_eq!(chain.len(), 4);

        // (position, expected name, setting to inspect, expected value)
        let expectations = [
            (0, "auth", "api_key", "key1"),
            (1, "rate_limit", "requests_per_second", "50"),
            (2, "auth", "type", "basic"),
            (3, "rate_limit", "burst_size", "200"),
        ];
        for &(position, name, key, value) in &expectations {
            let entry = &chain[position];
            assert_eq!(entry.name, name);
            assert_eq!(entry.config.get(key).map(String::as_str), Some(value));
        }
    }
}