1use crate::{CliError, McpConfiguration};
4use pulseengine_mcp_protocol::ServerInfo;
5use serde::{Deserialize, Serialize};
6use std::collections::HashMap;
7use std::time::Duration;
8use tracing::info;
9
10pub async fn run_server<C>(_config: C) -> std::result::Result<(), CliError>
12where
13 C: McpConfiguration,
14{
15 info!("Starting MCP server...");
24
25 _config.initialize_logging()?;
27
28 _config.validate()?;
30
31 info!("Server info: {:?}", _config.get_server_info());
32
33 Err(CliError::server_setup(
38 "Server implementation not yet complete",
39 ))
40}
41
42pub fn server_builder() -> ServerBuilder {
44 ServerBuilder::new()
45}
46
47#[derive(Debug, Clone, Serialize, Deserialize)]
49pub enum TransportType {
50 Http {
52 port: u16,
53 host: String,
54 },
55 WebSocket {
57 port: u16,
58 host: String,
59 path: String,
60 },
61 Stdio,
63}
64
65impl Default for TransportType {
66 fn default() -> Self {
67 Self::Http {
68 port: 8080,
69 host: "localhost".to_string(),
70 }
71 }
72}
73
74#[derive(Debug, Clone, Serialize, Deserialize)]
76pub struct CorsPolicy {
77 pub allowed_origins: Vec<String>,
78 pub allowed_methods: Vec<String>,
79 pub allowed_headers: Vec<String>,
80 pub allow_credentials: bool,
81 pub max_age: Option<Duration>,
82}
83
84impl CorsPolicy {
85 pub fn permissive() -> Self {
87 Self {
88 allowed_origins: vec!["*".to_string()],
89 allowed_methods: vec![
90 "GET".to_string(),
91 "POST".to_string(),
92 "PUT".to_string(),
93 "DELETE".to_string(),
94 "OPTIONS".to_string(),
95 ],
96 allowed_headers: vec!["*".to_string()],
97 allow_credentials: false,
98 max_age: Some(Duration::from_secs(3600)),
99 }
100 }
101
102 pub fn strict() -> Self {
104 Self {
105 allowed_origins: vec![],
106 allowed_methods: vec!["GET".to_string(), "POST".to_string()],
107 allowed_headers: vec!["Content-Type".to_string(), "Authorization".to_string()],
108 allow_credentials: true,
109 max_age: Some(Duration::from_secs(300)),
110 }
111 }
112
113 pub fn allow_origin(mut self, origin: impl Into<String>) -> Self {
115 self.allowed_origins.push(origin.into());
116 self
117 }
118
119 pub fn allow_method(mut self, method: impl Into<String>) -> Self {
121 self.allowed_methods.push(method.into());
122 self
123 }
124}
125
/// A user-defined route registered through the server builder.
///
/// All three parts are stored as plain strings; nothing here resolves
/// `handler_name` to actual code.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CustomEndpoint {
    /// Route path, e.g. `/api/v1/users/{id}`.
    pub path: String,
    /// HTTP method string, e.g. `GET`.
    pub method: String,
    /// Name of the handler meant to serve this route.
    pub handler_name: String,
}

impl CustomEndpoint {
    /// Creates an endpoint from anything convertible into owned strings.
    pub fn new(
        path: impl Into<String>,
        method: impl Into<String>,
        handler_name: impl Into<String>,
    ) -> Self {
        Self {
            path: path.into(),
            method: method.into(),
            handler_name: handler_name.into(),
        }
    }
}
143
/// Named middleware with free-form string settings.
///
/// `Debug` output masks the values of credential-like settings (keys
/// containing `password`, `api_key`, `secret` or `token`, case-insensitively)
/// so that logging a configuration does not leak API keys or passwords.
#[derive(Clone, PartialEq, Eq)]
pub struct MiddlewareConfig {
    /// Middleware identifier, e.g. `"auth"` or `"rate_limit"`.
    pub name: String,
    /// Settings for this middleware; later writes to the same key replace earlier ones.
    pub config: HashMap<String, String>,
}

impl MiddlewareConfig {
    /// Creates a middleware entry called `name` with no settings.
    pub fn new(name: impl Into<String>) -> Self {
        Self {
            name: name.into(),
            config: HashMap::new(),
        }
    }

    /// Sets `key` to `value`, replacing any previous value for that key.
    pub fn with_config(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
        self.config.insert(key.into(), value.into());
        self
    }
}

// Hand-written so secrets stored by auth middleware never reach logs via `{:?}`.
impl std::fmt::Debug for MiddlewareConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        const REDACTED: &str = "<redacted>";

        /// Returns `true` when a setting key looks like it holds a credential.
        fn is_sensitive_key(key: &str) -> bool {
            const MARKERS: [&str; 4] = ["password", "api_key", "secret", "token"];
            let key = key.to_ascii_lowercase();
            MARKERS.iter().any(|&marker| key.contains(marker))
        }

        /// Debug adapter printing the settings map with sensitive values masked.
        struct RedactedSettings<'a>(&'a HashMap<String, String>);

        impl std::fmt::Debug for RedactedSettings<'_> {
            fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
                f.debug_map()
                    .entries(self.0.iter().map(|(key, value)| {
                        let shown = if is_sensitive_key(key) {
                            REDACTED
                        } else {
                            value.as_str()
                        };
                        (key, shown)
                    }))
                    .finish()
            }
        }

        f.debug_struct("MiddlewareConfig")
            .field("name", &self.name)
            .field("config", &RedactedSettings(&self.config))
            .finish()
    }
}
164
165pub struct ServerBuilder {
167 server_info: Option<ServerInfo>,
168 transport: Option<TransportType>,
169 cors_policy: Option<CorsPolicy>,
170 middleware: Vec<MiddlewareConfig>,
171 custom_endpoints: Vec<CustomEndpoint>,
172 metrics_endpoint: Option<String>,
173 health_endpoint: Option<String>,
174 connection_timeout: Option<Duration>,
175 max_connections: Option<usize>,
176 enable_compression: bool,
177 enable_tls: bool,
178 tls_cert_path: Option<String>,
179 tls_key_path: Option<String>,
180}
181
182impl ServerBuilder {
183 pub fn new() -> Self {
184 Self {
185 server_info: None,
186 transport: None,
187 cors_policy: None,
188 middleware: Vec::new(),
189 custom_endpoints: Vec::new(),
190 metrics_endpoint: None,
191 health_endpoint: None,
192 connection_timeout: None,
193 max_connections: None,
194 enable_compression: false,
195 enable_tls: false,
196 tls_cert_path: None,
197 tls_key_path: None,
198 }
199 }
200
201 pub fn with_server_info(mut self, info: ServerInfo) -> Self {
202 self.server_info = Some(info);
203 self
204 }
205
206 pub fn with_port(mut self, port: u16) -> Self {
207 self.transport = Some(TransportType::Http {
208 port,
209 host: "localhost".to_string(),
210 });
211 self
212 }
213
214 pub fn with_transport(mut self, transport: TransportType) -> Self {
215 self.transport = Some(transport);
216 self
217 }
218
219 pub fn with_cors_policy(mut self, cors: CorsPolicy) -> Self {
220 self.cors_policy = Some(cors);
221 self
222 }
223
224 pub fn with_middleware(mut self, middleware: MiddlewareConfig) -> Self {
225 self.middleware.push(middleware);
226 self
227 }
228
229 pub fn with_metrics_endpoint(mut self, path: impl Into<String>) -> Self {
230 self.metrics_endpoint = Some(path.into());
231 self
232 }
233
234 pub fn with_health_endpoint(mut self, path: impl Into<String>) -> Self {
235 self.health_endpoint = Some(path.into());
236 self
237 }
238
239 pub fn with_custom_endpoint(
240 mut self,
241 path: impl Into<String>,
242 method: impl Into<String>,
243 handler_name: impl Into<String>,
244 ) -> Self {
245 self.custom_endpoints.push(CustomEndpoint::new(path, method, handler_name));
246 self
247 }
248
249 pub fn with_connection_timeout(mut self, timeout: Duration) -> Self {
250 self.connection_timeout = Some(timeout);
251 self
252 }
253
254 pub fn with_max_connections(mut self, max: usize) -> Self {
255 self.max_connections = Some(max);
256 self
257 }
258
259 pub fn with_compression(mut self, enable: bool) -> Self {
260 self.enable_compression = enable;
261 self
262 }
263
264 pub fn with_tls(mut self, cert_path: impl Into<String>, key_path: impl Into<String>) -> Self {
265 self.enable_tls = true;
266 self.tls_cert_path = Some(cert_path.into());
267 self.tls_key_path = Some(key_path.into());
268 self
269 }
270
271 pub fn build(self) -> Result<BuiltServerConfig, CliError> {
272 Ok(BuiltServerConfig {
273 server_info: self
274 .server_info
275 .ok_or_else(|| CliError::configuration("Server info is required"))?,
276 transport: self.transport.unwrap_or_default(),
277 cors_policy: self.cors_policy,
278 middleware: self.middleware,
279 custom_endpoints: self.custom_endpoints,
280 metrics_endpoint: self.metrics_endpoint,
281 health_endpoint: self.health_endpoint,
282 connection_timeout: self.connection_timeout.unwrap_or(Duration::from_secs(30)),
283 max_connections: self.max_connections.unwrap_or(1000),
284 enable_compression: self.enable_compression,
285 enable_tls: self.enable_tls,
286 tls_cert_path: self.tls_cert_path,
287 tls_key_path: self.tls_key_path,
288 })
289 }
290}
291
292impl Default for ServerBuilder {
293 fn default() -> Self {
294 Self::new()
295 }
296}
297
298#[derive(Debug, Clone)]
300pub struct BuiltServerConfig {
301 pub server_info: ServerInfo,
302 pub transport: TransportType,
303 pub cors_policy: Option<CorsPolicy>,
304 pub middleware: Vec<MiddlewareConfig>,
305 pub custom_endpoints: Vec<CustomEndpoint>,
306 pub metrics_endpoint: Option<String>,
307 pub health_endpoint: Option<String>,
308 pub connection_timeout: Duration,
309 pub max_connections: usize,
310 pub enable_compression: bool,
311 pub enable_tls: bool,
312 pub tls_cert_path: Option<String>,
313 pub tls_key_path: Option<String>,
314}
315
316impl BuiltServerConfig {
317 pub fn port(&self) -> Option<u16> {
319 match &self.transport {
320 TransportType::Http { port, .. } | TransportType::WebSocket { port, .. } => Some(*port),
321 TransportType::Stdio => None,
322 }
323 }
324
325 pub fn host(&self) -> Option<&str> {
327 match &self.transport {
328 TransportType::Http { host, .. } | TransportType::WebSocket { host, .. } => Some(host),
329 TransportType::Stdio => None,
330 }
331 }
332
333 pub fn is_tls_configured(&self) -> bool {
335 self.enable_tls && self.tls_cert_path.is_some() && self.tls_key_path.is_some()
336 }
337}
338
339pub struct AuthMiddleware;
341
342impl AuthMiddleware {
343 pub fn bearer(api_key: impl Into<String>) -> MiddlewareConfig {
344 MiddlewareConfig::new("auth")
345 .with_config("api_key", api_key)
346 .with_config("type", "bearer")
347 }
348
349 pub fn basic_auth(username: impl Into<String>, password: impl Into<String>) -> MiddlewareConfig {
350 MiddlewareConfig::new("auth")
351 .with_config("username", username)
352 .with_config("password", password)
353 .with_config("type", "basic")
354 }
355}
356
357pub struct RateLimitMiddleware;
359
360impl RateLimitMiddleware {
361 pub fn per_second(requests_per_second: u32) -> MiddlewareConfig {
362 MiddlewareConfig::new("rate_limit")
363 .with_config("requests_per_second", requests_per_second.to_string())
364 }
365
366 pub fn with_burst(requests_per_second: u32, burst_size: u32) -> MiddlewareConfig {
367 MiddlewareConfig::new("rate_limit")
368 .with_config("requests_per_second", requests_per_second.to_string())
369 .with_config("burst_size", burst_size.to_string())
370 }
371}
372
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::create_server_info;

    /// Shorthand for a `ServerInfo` called `name` at version `1.0.0`.
    fn info_named(name: &str) -> ServerInfo {
        create_server_info(Some(name.to_owned()), Some("1.0.0".to_owned()))
    }

    #[test]
    fn test_server_builder_basic() {
        let built = server_builder()
            .with_server_info(info_named("test"))
            .with_port(3000)
            .build()
            .unwrap();

        assert_eq!(built.port(), Some(3000));
        assert_eq!(built.server_info.server_info.name, "test");
    }

    #[test]
    fn test_server_builder_advanced() {
        let transport = TransportType::Http {
            port: 8080,
            host: "0.0.0.0".to_owned(),
        };

        let built = server_builder()
            .with_server_info(info_named("advanced"))
            .with_transport(transport)
            .with_cors_policy(CorsPolicy::permissive())
            .with_middleware(AuthMiddleware::bearer("secret-key"))
            .with_middleware(RateLimitMiddleware::per_second(100))
            .with_metrics_endpoint("/metrics")
            .with_health_endpoint("/health")
            .with_custom_endpoint("/api/v1/custom", "POST", "custom_handler")
            .with_compression(true)
            .build()
            .unwrap();

        assert_eq!(built.port(), Some(8080));
        assert_eq!(built.host(), Some("0.0.0.0"));
        assert!(built.cors_policy.is_some());
        assert_eq!(built.middleware.len(), 2);
        assert_eq!(built.custom_endpoints.len(), 1);
        assert_eq!(built.metrics_endpoint.as_deref(), Some("/metrics"));
        assert_eq!(built.health_endpoint.as_deref(), Some("/health"));
        assert!(built.enable_compression);
    }

    #[test]
    fn test_cors_policy() {
        fn listed(values: &[String], wanted: &str) -> bool {
            values.iter().any(|v| v == wanted)
        }

        let cors = CorsPolicy::permissive()
            .allow_origin("https://example.com")
            .allow_method("PATCH");

        assert!(listed(&cors.allowed_origins, "*"));
        assert!(listed(&cors.allowed_origins, "https://example.com"));
        assert!(listed(&cors.allowed_methods, "PATCH"));
    }

    #[test]
    fn test_transport_types() {
        let http = TransportType::Http {
            port: 8080,
            host: String::from("localhost"),
        };
        let ws = TransportType::WebSocket {
            port: 8081,
            host: String::from("localhost"),
            path: String::from("/ws"),
        };
        let stdio = TransportType::Stdio;

        assert!(matches!(http, TransportType::Http { .. }));
        assert!(matches!(ws, TransportType::WebSocket { .. }));
        assert!(matches!(stdio, TransportType::Stdio));
    }

    #[test]
    fn test_tls_configuration() {
        let secured = server_builder()
            .with_server_info(info_named("tls-test"))
            .with_port(443)
            .with_tls("/path/to/cert.pem", "/path/to/key.pem")
            .build()
            .unwrap();

        assert!(secured.enable_tls);
        assert_eq!(secured.tls_cert_path.as_deref(), Some("/path/to/cert.pem"));
        assert_eq!(secured.tls_key_path.as_deref(), Some("/path/to/key.pem"));
        assert!(secured.is_tls_configured());

        // Without `with_tls` the flag stays off, so TLS is not considered configured.
        let plaintext = server_builder()
            .with_server_info(info_named("tls-test"))
            .with_port(443)
            .build()
            .unwrap();

        assert!(!plaintext.enable_tls);
        assert!(!plaintext.is_tls_configured());
    }

    #[test]
    fn test_connection_limits() {
        let tuned = server_builder()
            .with_server_info(info_named("limits-test"))
            .with_max_connections(10_000)
            .with_connection_timeout(Duration::from_secs(120))
            .build()
            .unwrap();

        assert_eq!(tuned.max_connections, 10_000);
        assert_eq!(tuned.connection_timeout, Duration::from_secs(120));

        let defaults = server_builder()
            .with_server_info(info_named("limits-test"))
            .build()
            .unwrap();

        assert_eq!(defaults.max_connections, 1_000);
        assert_eq!(defaults.connection_timeout, Duration::from_secs(30));
    }

    #[test]
    fn test_custom_endpoints() {
        const ROUTES: [(&str, &str, &str); 5] = [
            ("/api/v1/users", "GET", "list_users"),
            ("/api/v1/users", "POST", "create_user"),
            ("/api/v1/users/{id}", "GET", "get_user"),
            ("/api/v1/users/{id}", "PUT", "update_user"),
            ("/api/v1/users/{id}", "DELETE", "delete_user"),
        ];

        let config = ROUTES
            .iter()
            .fold(
                server_builder().with_server_info(info_named("endpoints-test")),
                |builder, &(path, method, handler)| {
                    builder.with_custom_endpoint(path, method, handler)
                },
            )
            .build()
            .unwrap();

        let endpoints = &config.custom_endpoints;
        assert_eq!(endpoints.len(), 5);

        let first = &endpoints[0];
        assert_eq!(first.path, "/api/v1/users");
        assert_eq!(first.method, "GET");
        assert_eq!(first.handler_name, "list_users");

        let last = &endpoints[4];
        assert_eq!(last.path, "/api/v1/users/{id}");
        assert_eq!(last.method, "DELETE");
        assert_eq!(last.handler_name, "delete_user");
    }

    #[test]
    fn test_middleware_ordering() {
        let config = server_builder()
            .with_server_info(info_named("middleware-test"))
            .with_middleware(AuthMiddleware::bearer("key1"))
            .with_middleware(RateLimitMiddleware::per_second(50))
            .with_middleware(AuthMiddleware::basic_auth("user", "pass"))
            .with_middleware(RateLimitMiddleware::with_burst(100, 200))
            .build()
            .unwrap();

        let chain = &config.middleware;
        assert_eq!(chain.len(), 4);

        // Registration order must be preserved.
        let names: Vec<&str> = chain.iter().map(|m| m.name.as_str()).collect();
        assert_eq!(names, ["auth", "rate_limit", "auth", "rate_limit"]);

        assert_eq!(chain[0].config.get("api_key").map(String::as_str), Some("key1"));
        assert_eq!(
            chain[1].config.get("requests_per_second").map(String::as_str),
            Some("50")
        );
        assert_eq!(chain[2].config.get("type").map(String::as_str), Some("basic"));
        assert_eq!(chain[3].config.get("burst_size").map(String::as_str), Some("200"));
    }
}