use std::collections::HashSet;
use std::sync::Arc;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::time::{Duration, Instant};
use parking_lot::Mutex;
use serde::{Deserialize, Serialize};
pub use turbomcp_core::SUPPORTED_VERSIONS as SUPPORTED_PROTOCOL_VERSIONS;
pub use turbomcp_core::types::core::ProtocolVersion;
/// Default connection cap, used for every transport by `ConnectionLimits::default`.
pub const DEFAULT_MAX_CONNECTIONS: usize = 1_000;
/// Default number of requests allowed per [`DEFAULT_RATE_LIMIT_WINDOW`].
pub const DEFAULT_RATE_LIMIT: u32 = 100;
/// Default rate-limit window (one second).
pub const DEFAULT_RATE_LIMIT_WINDOW: Duration = Duration::from_millis(1_000);
/// Default maximum message size in bytes (10 MiB).
pub const DEFAULT_MAX_MESSAGE_SIZE: usize = 10 << 20;
/// Top-level server configuration.
///
/// Obtain one with [`ServerConfig::new`] / `Default`, or customize it through the
/// builder returned by [`ServerConfig::builder`].
#[derive(Debug, Clone)]
pub struct ServerConfig {
    /// Protocol-version negotiation policy.
    pub protocol: ProtocolConfig,
    /// Optional rate-limit settings; `None` by default.
    pub rate_limit: Option<RateLimitConfig>,
    /// Per-transport connection caps.
    pub connection_limits: ConnectionLimits,
    /// Capabilities the client is required to advertise.
    pub required_capabilities: RequiredCapabilities,
    /// Maximum message size, in bytes.
    pub max_message_size: usize,
}

impl Default for ServerConfig {
    /// Identical to building with no builder options set.
    fn default() -> Self {
        Self::builder().build()
    }
}

impl ServerConfig {
    /// Alias for `ServerConfig::default()`.
    #[must_use]
    pub fn new() -> Self {
        Default::default()
    }

    /// Starts a [`ServerConfigBuilder`] with every field unset.
    #[must_use]
    pub fn builder() -> ServerConfigBuilder {
        ServerConfigBuilder::default()
    }
}
/// Builder for [`ServerConfig`]; any field left unset falls back to its default.
#[derive(Debug, Clone, Default)]
pub struct ServerConfigBuilder {
    protocol: Option<ProtocolConfig>,
    rate_limit: Option<RateLimitConfig>,
    connection_limits: Option<ConnectionLimits>,
    required_capabilities: Option<RequiredCapabilities>,
    max_message_size: Option<usize>,
}

impl ServerConfigBuilder {
    /// Smallest `max_message_size` accepted by `try_build`, in bytes.
    const MIN_MESSAGE_SIZE: usize = 1024;

    /// Sets the protocol negotiation policy.
    #[must_use]
    pub fn protocol(self, config: ProtocolConfig) -> Self {
        Self {
            protocol: Some(config),
            ..self
        }
    }

    /// Enables rate limiting with `config`.
    #[must_use]
    pub fn rate_limit(self, config: RateLimitConfig) -> Self {
        Self {
            rate_limit: Some(config),
            ..self
        }
    }

    /// Sets the per-transport connection caps.
    #[must_use]
    pub fn connection_limits(self, limits: ConnectionLimits) -> Self {
        Self {
            connection_limits: Some(limits),
            ..self
        }
    }

    /// Sets the capabilities a client must advertise.
    #[must_use]
    pub fn required_capabilities(self, caps: RequiredCapabilities) -> Self {
        Self {
            required_capabilities: Some(caps),
            ..self
        }
    }

    /// Sets the maximum message size, in bytes.
    #[must_use]
    pub fn max_message_size(self, size: usize) -> Self {
        Self {
            max_message_size: Some(size),
            ..self
        }
    }

    /// Builds the configuration without validating it.
    #[must_use]
    pub fn build(self) -> ServerConfig {
        let Self {
            protocol,
            rate_limit,
            connection_limits,
            required_capabilities,
            max_message_size,
        } = self;
        ServerConfig {
            protocol: protocol.unwrap_or_default(),
            rate_limit,
            connection_limits: connection_limits.unwrap_or_default(),
            required_capabilities: required_capabilities.unwrap_or_default(),
            max_message_size: max_message_size.unwrap_or(DEFAULT_MAX_MESSAGE_SIZE),
        }
    }

    /// Builds the configuration and validates it.
    ///
    /// # Errors
    ///
    /// Returns the first failing check, in this order:
    /// - [`ConfigValidationError::InvalidMessageSize`] if `max_message_size` is below 1024.
    /// - [`ConfigValidationError::InvalidRateLimit`] if a rate limit has `max_requests == 0`
    ///   or a zero `window`.
    /// - [`ConfigValidationError::InvalidConnectionLimits`] if every connection cap is 0.
    pub fn try_build(self) -> Result<ServerConfig, ConfigValidationError> {
        let config = self.build();
        Self::validate(&config)?;
        Ok(config)
    }

    /// Applies the `try_build` checks to an already-built configuration.
    fn validate(config: &ServerConfig) -> Result<(), ConfigValidationError> {
        if config.max_message_size < Self::MIN_MESSAGE_SIZE {
            return Err(ConfigValidationError::InvalidMessageSize {
                size: config.max_message_size,
                min: Self::MIN_MESSAGE_SIZE,
            });
        }

        if let Some(limit) = &config.rate_limit {
            let problem = if limit.max_requests == 0 {
                Some("max_requests cannot be 0")
            } else if limit.window.is_zero() {
                Some("rate limit window cannot be zero")
            } else {
                None
            };
            if let Some(reason) = problem {
                return Err(ConfigValidationError::InvalidRateLimit {
                    reason: reason.to_string(),
                });
            }
        }

        let caps = &config.connection_limits;
        let every_cap_zero = [
            caps.max_tcp_connections,
            caps.max_websocket_connections,
            caps.max_http_concurrent,
            caps.max_unix_connections,
        ]
        .iter()
        .all(|&cap| cap == 0);
        if every_cap_zero {
            return Err(ConfigValidationError::InvalidConnectionLimits {
                reason: "at least one connection limit must be non-zero".to_string(),
            });
        }

        Ok(())
    }
}
#[derive(Debug, Clone, thiserror::Error)]
pub enum ConfigValidationError {
#[error("Invalid max_message_size: {size} bytes is below minimum of {min} bytes")]
InvalidMessageSize {
size: usize,
min: usize,
},
#[error("Invalid rate limit: {reason}")]
InvalidRateLimit {
reason: String,
},
#[error("Invalid connection limits: {reason}")]
InvalidConnectionLimits {
reason: String,
},
}
/// Protocol-version negotiation policy.
#[derive(Debug, Clone)]
pub struct ProtocolConfig {
    /// Version chosen when the client names none, or (with `allow_fallback`) names an
    /// unsupported one.
    pub preferred_version: ProtocolVersion,
    /// Versions accepted verbatim when the client requests them.
    pub supported_versions: Vec<ProtocolVersion>,
    /// Answer unsupported client versions with `preferred_version` instead of rejecting.
    pub allow_fallback: bool,
}

impl Default for ProtocolConfig {
    /// Accepts only `ProtocolVersion::LATEST`, with no fallback.
    fn default() -> Self {
        Self::strict(ProtocolVersion::LATEST.clone())
    }
}

impl ProtocolConfig {
    /// Accepts exactly `version` and nothing else (no fallback).
    #[must_use]
    pub fn strict(version: impl Into<ProtocolVersion>) -> Self {
        let only: ProtocolVersion = version.into();
        Self {
            supported_versions: vec![only.clone()],
            preferred_version: only,
            allow_fallback: false,
        }
    }

    /// Prefers `ProtocolVersion::LATEST` and accepts every `ProtocolVersion::STABLE`
    /// version (no fallback).
    #[must_use]
    pub fn multi_version() -> Self {
        let supported_versions = ProtocolVersion::STABLE.to_vec();
        Self {
            preferred_version: ProtocolVersion::LATEST.clone(),
            supported_versions,
            allow_fallback: false,
        }
    }

    /// Whether `version` is in `supported_versions`.
    #[must_use]
    pub fn is_supported(&self, version: &ProtocolVersion) -> bool {
        self.supported_versions.iter().any(|candidate| candidate == version)
    }

    /// Picks the version to speak with a client.
    ///
    /// - No client version: `preferred_version`.
    /// - Supported client version: that version.
    /// - Unsupported client version: `preferred_version` if `allow_fallback`, else `None`.
    #[must_use]
    pub fn negotiate(&self, client_version: Option<&str>) -> Option<ProtocolVersion> {
        let requested = match client_version {
            None => return Some(self.preferred_version.clone()),
            Some(raw) => ProtocolVersion::from(raw),
        };
        if self.is_supported(&requested) {
            return Some(requested);
        }
        self.allow_fallback
            .then(|| self.preferred_version.clone())
    }
}
/// Parameters for the token-bucket rate limiter.
#[derive(Debug, Clone)]
pub struct RateLimitConfig {
    /// Requests allowed per `window`; buckets also start with this many tokens.
    pub max_requests: u32,
    /// Time over which `max_requests` tokens are replenished.
    pub window: Duration,
    /// When `true` and a client id is available, each client gets its own bucket.
    pub per_client: bool,
}

impl Default for RateLimitConfig {
    /// `DEFAULT_RATE_LIMIT` requests per `DEFAULT_RATE_LIMIT_WINDOW`, per client.
    fn default() -> Self {
        Self::new(DEFAULT_RATE_LIMIT, DEFAULT_RATE_LIMIT_WINDOW)
    }
}

impl RateLimitConfig {
    /// `max_requests` per `window`, with per-client limiting enabled.
    #[must_use]
    pub fn new(max_requests: u32, window: Duration) -> Self {
        let per_client = true;
        Self {
            max_requests,
            window,
            per_client,
        }
    }

    /// Toggles per-client buckets.
    #[must_use]
    pub fn per_client(self, enabled: bool) -> Self {
        Self {
            per_client: enabled,
            ..self
        }
    }
}
#[derive(Debug, Clone)]
pub struct ConnectionLimits {
pub max_tcp_connections: usize,
pub max_websocket_connections: usize,
pub max_http_concurrent: usize,
pub max_unix_connections: usize,
}
impl Default for ConnectionLimits {
fn default() -> Self {
Self {
max_tcp_connections: DEFAULT_MAX_CONNECTIONS,
max_websocket_connections: DEFAULT_MAX_CONNECTIONS,
max_http_concurrent: DEFAULT_MAX_CONNECTIONS,
max_unix_connections: DEFAULT_MAX_CONNECTIONS,
}
}
}
impl ConnectionLimits {
#[must_use]
pub fn new(max_connections: usize) -> Self {
Self {
max_tcp_connections: max_connections,
max_websocket_connections: max_connections,
max_http_concurrent: max_connections,
max_unix_connections: max_connections,
}
}
}
/// Capabilities the server requires a client to advertise.
///
/// Start from [`RequiredCapabilities::none`], add requirements with the `with_*` methods,
/// and check a client with [`RequiredCapabilities::validate`].
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct RequiredCapabilities {
    /// Require the client to advertise `roots`.
    #[serde(default)]
    pub roots: bool,
    /// Require the client to advertise `sampling`.
    #[serde(default)]
    pub sampling: bool,
    /// Names of experimental capabilities the client must advertise.
    #[serde(default)]
    pub experimental: HashSet<String>,
}

impl RequiredCapabilities {
    /// No requirements; same as `Default::default()`.
    #[must_use]
    pub fn none() -> Self {
        Self::default()
    }

    /// Additionally require `roots`.
    #[must_use]
    pub fn with_roots(mut self) -> Self {
        self.roots = true;
        self
    }

    /// Additionally require `sampling`.
    #[must_use]
    pub fn with_sampling(mut self) -> Self {
        self.sampling = true;
        self
    }

    /// Additionally require the experimental capability `name`.
    #[must_use]
    pub fn with_experimental(mut self, name: impl Into<String>) -> Self {
        self.experimental.insert(name.into());
        self
    }

    /// Compares these requirements with what a client advertised.
    ///
    /// Returns [`CapabilityValidation::Valid`] when nothing is missing. Otherwise it returns
    /// [`CapabilityValidation::Missing`], listing `"roots"`, then `"sampling"`, then
    /// `"experimental/<name>"` entries sorted by name. The names are sorted because
    /// `HashSet` iteration order differs between runs, and the report should be stable.
    #[must_use]
    pub fn validate(&self, client_caps: &ClientCapabilities) -> CapabilityValidation {
        let mut missing = Vec::new();
        if self.roots && !client_caps.roots {
            missing.push("roots".to_string());
        }
        if self.sampling && !client_caps.sampling {
            missing.push("sampling".to_string());
        }

        let mut absent: Vec<&String> = self
            .experimental
            .difference(&client_caps.experimental)
            .collect();
        absent.sort_unstable();
        missing.extend(
            absent
                .into_iter()
                .map(|name| format!("experimental/{}", name)),
        );

        if missing.is_empty() {
            CapabilityValidation::Valid
        } else {
            CapabilityValidation::Missing(missing)
        }
    }
}
/// Capabilities a client advertised.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct ClientCapabilities {
    /// Whether the client advertised `roots`.
    #[serde(default)]
    pub roots: bool,
    /// Whether the client advertised `sampling`.
    #[serde(default)]
    pub sampling: bool,
    /// Names of the experimental capabilities the client advertised.
    #[serde(default)]
    pub experimental: HashSet<String>,
}

impl ClientCapabilities {
    /// Extracts capabilities from request params shaped like `{"capabilities": {...}}`.
    ///
    /// `roots` / `sampling` are `true` when the key is present with a non-null value.
    /// `experimental` collects the keys of the `experimental` object. If `capabilities`
    /// (or a sub-key) is absent, or `params` is not an object, the defaults apply:
    /// everything off and no experimental names.
    #[must_use]
    pub fn from_params(params: &serde_json::Value) -> Self {
        // Borrow the capabilities subtree; the original deep-cloned it just to read it.
        let caps = params.get("capabilities");
        let lookup = |key: &str| caps.and_then(|c| c.get(key));
        Self {
            roots: matches!(lookup("roots"), Some(v) if !v.is_null()),
            sampling: matches!(lookup("sampling"), Some(v) if !v.is_null()),
            experimental: lookup("experimental")
                .and_then(|v| v.as_object())
                .map(|obj| obj.keys().cloned().collect())
                .unwrap_or_default(),
        }
    }
}
/// Result of checking a client's capabilities against the server's requirements.
#[derive(Debug, Clone)]
pub enum CapabilityValidation {
    /// Every required capability was advertised.
    Valid,
    /// Names of the required capabilities the client did not advertise.
    Missing(Vec<String>),
}

impl CapabilityValidation {
    /// `true` only for [`CapabilityValidation::Valid`].
    #[must_use]
    pub fn is_valid(&self) -> bool {
        self.missing().is_none()
    }

    /// The missing capability names, or `None` for [`CapabilityValidation::Valid`].
    #[must_use]
    pub fn missing(&self) -> Option<&[String]> {
        if let Self::Missing(names) = self {
            Some(names.as_slice())
        } else {
            None
        }
    }
}
/// Token-bucket rate limiter.
///
/// When `config.per_client` is set and a client id is supplied, each client draws from its
/// own bucket. Otherwise all requests share a single global bucket.
#[derive(Debug)]
pub struct RateLimiter {
    config: RateLimitConfig,
    /// Bucket used when per-client limiting is off or no client id is given.
    global_bucket: Mutex<TokenBucket>,
    /// Per-client buckets keyed by client id.
    client_buckets: Mutex<std::collections::HashMap<String, TokenBucket>>,
    /// When `check` last swept idle client buckets.
    last_cleanup: Mutex<Instant>,
}

impl RateLimiter {
    /// Minimum time between automatic sweeps performed by `check`.
    const CLEANUP_INTERVAL: Duration = Duration::from_secs(60);
    /// Baseline idle age after which the automatic sweep evicts a client bucket.
    const IDLE_EVICTION_AGE: Duration = Duration::from_secs(300);

    /// Creates a limiter; every bucket starts full (`max_requests` tokens).
    #[must_use]
    pub fn new(config: RateLimitConfig) -> Self {
        Self {
            global_bucket: Mutex::new(TokenBucket::new(config.max_requests, config.window)),
            client_buckets: Mutex::new(std::collections::HashMap::new()),
            last_cleanup: Mutex::new(Instant::now()),
            config,
        }
    }

    /// Takes a token from the applicable bucket and returns `true`. Returns `false` when
    /// that bucket is empty.
    ///
    /// Once more than a minute has passed since the last sweep, this call also evicts client
    /// buckets idle for at least `max(5 minutes, config.window)`.
    pub fn check(&self, client_id: Option<&str>) -> bool {
        self.maybe_cleanup();
        match client_id {
            Some(id) if self.config.per_client => self.check_client(id),
            _ => self.global_bucket.lock().try_acquire(),
        }
    }

    /// Draws from `id`'s bucket, creating a full one on first sight.
    fn check_client(&self, id: &str) -> bool {
        let mut buckets = self.client_buckets.lock();
        // Look up by `&str` first so the common path does not allocate a key.
        if let Some(bucket) = buckets.get_mut(id) {
            return bucket.try_acquire();
        }
        let mut fresh = TokenBucket::new(self.config.max_requests, self.config.window);
        let allowed = fresh.try_acquire();
        buckets.insert(id.to_owned(), fresh);
        allowed
    }

    /// Runs the periodic sweep if it is due; at most one caller per interval does the work.
    fn maybe_cleanup(&self) {
        let due = {
            let mut last = self.last_cleanup.lock();
            let due = last.elapsed() > Self::CLEANUP_INTERVAL;
            if due {
                // Claim the sweep while still holding the lock so concurrent callers skip it.
                *last = Instant::now();
            }
            due
        };
        if due {
            // A bucket only refills to capacity after a full `window` of inactivity
            // (refill_rate = max_requests / window). Evicting it sooner would hand the client
            // a brand-new full bucket, i.e. bypass the limit. So never evict before `window`.
            self.cleanup(Self::IDLE_EVICTION_AGE.max(self.config.window));
        }
    }

    /// Evicts client buckets whose last access was at least `max_age` ago.
    ///
    /// Note: a `max_age` shorter than `config.window` can reset a partially drained
    /// client bucket to full.
    pub fn cleanup(&self, max_age: Duration) {
        let mut buckets = self.client_buckets.lock();
        let now = Instant::now();
        buckets.retain(|_, bucket| now.duration_since(bucket.last_access) < max_age);
    }

    /// Number of per-client buckets currently tracked.
    #[must_use]
    pub fn client_bucket_count(&self) -> usize {
        self.client_buckets.lock().len()
    }
}
/// Single token bucket: starts full and refills continuously at `max_tokens / window`.
#[derive(Debug)]
struct TokenBucket {
    tokens: f64,
    max_tokens: f64,
    /// Tokens added per second.
    refill_rate: f64,
    last_refill: Instant,
    /// Updated on every acquisition attempt; read by the limiter's idle sweep.
    last_access: Instant,
}

impl TokenBucket {
    /// Refills smaller than this are deferred; elapsed time keeps accumulating meanwhile.
    const REFILL_GRANULARITY: Duration = Duration::from_millis(10);

    /// Creates a full bucket holding `max_requests` tokens, refilled over `window`.
    fn new(max_requests: u32, window: Duration) -> Self {
        let capacity = f64::from(max_requests);
        let created = Instant::now();
        Self {
            tokens: capacity,
            max_tokens: capacity,
            refill_rate: capacity / window.as_secs_f64(),
            last_refill: created,
            last_access: created,
        }
    }

    /// Refills according to elapsed time, then takes one token if available.
    fn try_acquire(&mut self) -> bool {
        let now = Instant::now();
        let since_refill = now.duration_since(self.last_refill);
        if since_refill >= Self::REFILL_GRANULARITY {
            let refilled = self.tokens + since_refill.as_secs_f64() * self.refill_rate;
            self.tokens = refilled.min(self.max_tokens);
            self.last_refill = now;
        }
        self.last_access = now;

        let granted = self.tokens >= 1.0;
        if granted {
            self.tokens -= 1.0;
        }
        granted
    }
}
/// Atomic counter that bounds how many connections may be open at once.
///
/// A slot is reserved with [`ConnectionCounter::try_acquire_arc`] and released when the
/// returned [`ConnectionGuard`] is dropped.
#[derive(Debug)]
pub struct ConnectionCounter {
    current: AtomicUsize,
    max: usize,
}

impl ConnectionCounter {
    /// Creates a counter allowing at most `max` simultaneously held slots.
    #[must_use]
    pub fn new(max: usize) -> Self {
        Self {
            current: AtomicUsize::new(0),
            max,
        }
    }

    /// Reserves a slot and returns a guard that releases it on drop.
    ///
    /// Returns `None` only when `max` slots are already held. The increment is a
    /// compare-and-swap that retries until it succeeds or observes the counter at capacity.
    /// Contention between threads can therefore delay an acquisition, but never spuriously
    /// reject it. (The previous fixed 1000-retry cap could reject below capacity.)
    pub fn try_acquire_arc(self: &Arc<Self>) -> Option<ConnectionGuard> {
        self.current
            .fetch_update(Ordering::SeqCst, Ordering::Relaxed, |held| {
                // `held < max` guarantees `held + 1` cannot overflow.
                if held < self.max {
                    Some(held + 1)
                } else {
                    None
                }
            })
            .ok()
            .map(|_| ConnectionGuard {
                counter: Arc::clone(self),
            })
    }

    /// Number of slots currently held.
    #[must_use]
    pub fn current(&self) -> usize {
        self.current.load(Ordering::Relaxed)
    }

    /// Maximum number of simultaneously held slots.
    #[must_use]
    pub fn max(&self) -> usize {
        self.max
    }

    /// Returns one slot; called only from `ConnectionGuard::drop`.
    fn release(&self) {
        self.current.fetch_sub(1, Ordering::SeqCst);
    }
}

/// Holds one slot of a [`ConnectionCounter`]; dropping it releases the slot.
#[derive(Debug)]
pub struct ConnectionGuard {
    counter: Arc<ConnectionCounter>,
}

impl Drop for ConnectionGuard {
    fn drop(&mut self) {
        self.counter.release();
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Runs `try_build` with a per-client rate limit of `max_requests` per `window`.
    fn try_build_with_rate_limit(
        max_requests: u32,
        window: Duration,
    ) -> Result<ServerConfig, ConfigValidationError> {
        ServerConfig::builder()
            .rate_limit(RateLimitConfig {
                max_requests,
                window,
                per_client: true,
            })
            .try_build()
    }

    #[test]
    fn test_protocol_negotiation_exact_match() {
        let negotiated = ProtocolConfig::default().negotiate(Some("2025-11-25"));
        assert_eq!(negotiated, Some(ProtocolVersion::V2025_11_25));
    }

    #[test]
    fn test_protocol_negotiation_default_rejects_older_version() {
        let negotiated = ProtocolConfig::default().negotiate(Some("2025-06-18"));
        assert_eq!(negotiated, None);
    }

    #[test]
    fn test_protocol_negotiation_multi_version_accepts_older() {
        let multi = ProtocolConfig::multi_version();
        let expectations = [
            ("2025-06-18", ProtocolVersion::V2025_06_18),
            ("2025-11-25", ProtocolVersion::V2025_11_25),
        ];
        for (requested, expected) in expectations {
            assert_eq!(multi.negotiate(Some(requested)), Some(expected));
        }
    }

    #[test]
    fn test_protocol_negotiation_none_returns_preferred() {
        let negotiated = ProtocolConfig::default().negotiate(None);
        assert_eq!(negotiated, Some(ProtocolVersion::V2025_11_25));
    }

    #[test]
    fn test_protocol_negotiation_unknown_version() {
        let negotiated = ProtocolConfig::default().negotiate(Some("unknown-version"));
        assert!(negotiated.is_none());
    }

    #[test]
    fn test_protocol_negotiation_strict() {
        let strict = ProtocolConfig::strict("2025-11-25");
        assert!(strict.negotiate(Some("2025-06-18")).is_none());
    }

    #[test]
    fn test_capability_validation() {
        let required = RequiredCapabilities::none().with_roots();
        let with_roots = ClientCapabilities {
            roots: true,
            ..ClientCapabilities::default()
        };
        assert!(required.validate(&with_roots).is_valid());
        assert!(!required.validate(&ClientCapabilities::default()).is_valid());
    }

    #[test]
    fn test_rate_limiter() {
        let limiter = RateLimiter::new(RateLimitConfig::new(2, Duration::from_secs(1)));
        let outcomes: Vec<bool> = (0..3).map(|_| limiter.check(None)).collect();
        assert_eq!(outcomes, [true, true, false]);
    }

    #[test]
    fn test_connection_counter() {
        let counter = Arc::new(ConnectionCounter::new(2));

        let first = counter.try_acquire_arc();
        assert!(first.is_some());
        assert_eq!(counter.current(), 1);

        let second = counter.try_acquire_arc();
        assert!(second.is_some());
        assert_eq!(counter.current(), 2);

        assert!(counter.try_acquire_arc().is_none());

        drop(first);
        assert_eq!(counter.current(), 1);
        assert!(counter.try_acquire_arc().is_some());
    }

    #[test]
    fn test_builder_default_succeeds() {
        let built = ServerConfig::builder().build();
        assert_eq!(built.max_message_size, DEFAULT_MAX_MESSAGE_SIZE);
    }

    #[test]
    fn test_builder_try_build_valid() {
        let outcome = ServerConfig::builder()
            .max_message_size(1024 * 1024)
            .try_build();
        assert!(outcome.is_ok());
    }

    #[test]
    fn test_builder_try_build_invalid_message_size() {
        let outcome = ServerConfig::builder().max_message_size(100).try_build();
        assert!(matches!(
            outcome,
            Err(ConfigValidationError::InvalidMessageSize { .. })
        ));
    }

    #[test]
    fn test_builder_try_build_invalid_rate_limit() {
        let outcome = try_build_with_rate_limit(0, Duration::from_secs(1));
        assert!(matches!(
            outcome,
            Err(ConfigValidationError::InvalidRateLimit { .. })
        ));
    }

    #[test]
    fn test_builder_try_build_zero_window() {
        let outcome = try_build_with_rate_limit(100, Duration::ZERO);
        assert!(matches!(
            outcome,
            Err(ConfigValidationError::InvalidRateLimit { .. })
        ));
    }

    #[test]
    fn test_builder_try_build_invalid_connection_limits() {
        let outcome = ServerConfig::builder()
            .connection_limits(ConnectionLimits::new(0))
            .try_build();
        assert!(matches!(
            outcome,
            Err(ConfigValidationError::InvalidConnectionLimits { .. })
        ));
    }
}
#[cfg(test)]
mod proptest_tests {
    use super::*;
    use proptest::prelude::*;

    proptest! {
        // `try_build` must return a `Result` (never panic) for any message size.
        #[test]
        fn config_builder_never_panics(max_msg_size in 0usize..10_000_000) {
            let _outcome = ServerConfig::builder()
                .max_message_size(max_msg_size)
                .try_build();
        }

        // Asking for more slots than `max` yields exactly `max` live guards.
        #[test]
        fn connection_counter_bounded(max in 1usize..10000) {
            let counter = Arc::new(ConnectionCounter::new(max));
            let held: Vec<ConnectionGuard> = (0..max + 10)
                .filter_map(|_| counter.try_acquire_arc())
                .collect();
            assert_eq!(held.len(), max);
            assert_eq!(counter.current(), max);
        }
    }
}