use crate::runtime::RwLock;
use crate::server::http_middleware::{
ServerHttpLoggingMiddleware, ServerHttpMiddleware, ServerHttpMiddlewareChain,
};
use crate::shared::middleware::{EnhancedMiddlewareChain, MetricsMiddleware, RateLimitMiddleware};
use std::sync::Arc;
/// Bundle of server middleware: a lock-protected protocol-level chain plus an
/// optional HTTP-level chain.
///
/// Construct it with [`ServerPreset::new`] (or `Default`) and extend it with the
/// `with_*` builder methods.
pub struct ServerPreset {
    /// Protocol-level middleware, shared behind a lock so handles can be given out.
    protocol_chain: Arc<RwLock<EnhancedMiddlewareChain>>,
    /// HTTP-level middleware; `ServerPreset::new` always populates this.
    http_chain: Option<Arc<ServerHttpMiddlewareChain>>,
    /// Service name this preset was created with.
    service_name: String,
}

// Written by hand instead of derived because the middleware chain types are not
// known to implement `Debug`; only the directly inspectable state is printed.
impl std::fmt::Debug for ServerPreset {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("ServerPreset")
            .field("service_name", &self.service_name)
            .field("has_http_chain", &self.http_chain.is_some())
            .finish_non_exhaustive()
    }
}
impl ServerPreset {
    /// Creates a preset for `service_name` with the default middleware stack.
    ///
    /// The protocol chain starts with a [`MetricsMiddleware`] built from
    /// `service_name`. The HTTP chain starts with a [`ServerHttpLoggingMiddleware`]
    /// configured with `with_level(INFO)` and `with_redact_query(true)`.
    pub fn new(service_name: impl Into<String>) -> Self {
        let service_name = service_name.into();

        let mut protocol_chain = EnhancedMiddlewareChain::new();
        protocol_chain.add(Arc::new(MetricsMiddleware::new(service_name.clone())));

        let mut http_chain = ServerHttpMiddlewareChain::new();
        let logging = ServerHttpLoggingMiddleware::new()
            .with_level(tracing::Level::INFO)
            .with_redact_query(true);
        http_chain.add(Arc::new(logging));

        Self {
            protocol_chain: Arc::new(RwLock::new(protocol_chain)),
            http_chain: Some(Arc::new(http_chain)),
            service_name,
        }
    }

    /// Appends `middleware` to the HTTP middleware chain.
    ///
    /// The chain is extended in place when this preset holds the only reference
    /// to it. If the chain `Arc` has already been shared (for example through
    /// [`Self::http_middleware`]), it cannot be mutated. In that case it is
    /// replaced by a new chain containing only `middleware`, and a warning is
    /// logged because the previously configured HTTP middleware is no longer part
    /// of this preset.
    pub fn with_http_middleware_item(
        mut self,
        middleware: impl ServerHttpMiddleware + 'static,
    ) -> Self {
        // `Arc::get_mut` succeeds only when no other `Arc` to the chain exists.
        if let Some(chain) = self.http_chain.as_mut().and_then(Arc::get_mut) {
            chain.add(Arc::new(middleware));
        } else {
            if self.http_chain.is_some() {
                // Previously the existing middleware was discarded silently; make
                // the loss visible to operators.
                tracing::warn!(
                    service = %self.service_name,
                    "HTTP middleware chain is shared; replacing it with a new chain that \
                     contains only the added middleware"
                );
            }
            let mut new_chain = ServerHttpMiddlewareChain::new();
            new_chain.add(Arc::new(middleware));
            self.http_chain = Some(Arc::new(new_chain));
        }
        self
    }

    /// Appends `rate_limiter` to the protocol middleware chain.
    ///
    /// The chain is modified through a non-blocking `try_write`. If the lock
    /// cannot be acquired (for example because a guard obtained through
    /// [`Self::protocol_middleware`] is still alive), the rate limiter is not
    /// added and a warning is logged.
    pub fn with_rate_limit(self, rate_limiter: RateLimitMiddleware) -> Self {
        match self.protocol_chain.try_write() {
            Ok(mut chain) => {
                chain.add(Arc::new(rate_limiter));
            }
            Err(_) => {
                tracing::warn!(
                    service = %self.service_name,
                    "protocol middleware chain is locked; rate limiter was not added"
                );
            }
        }
        self
    }

    /// Returns a new shared handle to the lock-protected protocol middleware chain.
    pub fn protocol_middleware(&self) -> Arc<RwLock<EnhancedMiddlewareChain>> {
        Arc::clone(&self.protocol_chain)
    }

    /// Returns a shared handle to the HTTP middleware chain, if one is configured.
    ///
    /// While the returned `Arc` is alive, [`Self::with_http_middleware_item`]
    /// cannot extend the chain in place.
    pub fn http_middleware(&self) -> Option<Arc<ServerHttpMiddlewareChain>> {
        self.http_chain.clone()
    }

    /// Returns the service name this preset was created with.
    pub fn service_name(&self) -> &str {
        &self.service_name
    }
}
impl Default for ServerPreset {
fn default() -> Self {
Self::new("mcp-server")
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;

    /// Rate limiter configuration shared by the rate-limit tests.
    fn sample_rate_limiter() -> RateLimitMiddleware {
        RateLimitMiddleware::new(100, 100, Duration::from_secs(60))
    }

    #[test]
    fn test_preset_creation() {
        let subject = ServerPreset::new("test-service");
        let protocol = subject.protocol_middleware();

        assert_eq!(subject.service_name(), "test-service");
        assert!(protocol.try_read().is_ok());
        assert!(subject.http_middleware().is_some());
    }

    #[test]
    fn test_preset_default() {
        assert_eq!(ServerPreset::default().service_name(), "mcp-server");
    }

    #[test]
    fn test_preset_with_rate_limit() {
        let subject = ServerPreset::new("test-service").with_rate_limit(sample_rate_limiter());
        let protocol = subject.protocol_middleware();
        assert!(protocol.try_read().is_ok());
    }

    #[test]
    fn test_preset_with_http_middleware() {
        let subject = ServerPreset::new("test-service")
            .with_http_middleware_item(ServerHttpLoggingMiddleware::new());
        assert!(subject.http_middleware().is_some());
    }

    #[test]
    fn test_preset_chaining() {
        let subject = ServerPreset::new("test-service")
            .with_http_middleware_item(ServerHttpLoggingMiddleware::new())
            .with_rate_limit(sample_rate_limiter());

        let has_http = subject.http_middleware().is_some();
        let protocol_readable = subject.protocol_middleware().try_read().is_ok();

        assert!(has_http);
        assert!(protocol_readable);
    }
}