1use std::collections::HashSet;
10use std::sync::Arc;
11use std::sync::atomic::{AtomicUsize, Ordering};
12use std::time::{Duration, Instant};
13
14use parking_lot::Mutex;
15use serde::{Deserialize, Serialize};
16
17pub use turbomcp_core::SUPPORTED_VERSIONS as SUPPORTED_PROTOCOL_VERSIONS;
19pub use turbomcp_types::ProtocolVersion;
20
/// Default value for every per-transport limit in [`ConnectionLimits::default`].
pub const DEFAULT_MAX_CONNECTIONS: usize = 1000;

/// Default number of requests allowed per rate-limit window
/// (see [`RateLimitConfig::default`]).
pub const DEFAULT_RATE_LIMIT: u32 = 100;

/// Default rate-limit window: one second.
pub const DEFAULT_RATE_LIMIT_WINDOW: Duration = Duration::from_secs(1);

/// Default maximum message size in bytes (10 MiB).
pub const DEFAULT_MAX_MESSAGE_SIZE: usize = 10 * 1024 * 1024;
32
/// Origin allow-list settings carried by [`ServerConfig`].
///
/// This type only stores the settings; the code that enforces them is not
/// part of this file.
#[derive(Debug, Clone)]
pub struct OriginValidationConfig {
    /// Origins that have been explicitly allowed.
    pub allowed_origins: HashSet<String>,
    /// Whether localhost origins are allowed (`true` by default).
    pub allow_localhost: bool,
    /// Whether every origin is allowed (`false` by default).
    pub allow_any: bool,
    /// Trusted proxy entries (empty by default).
    // NOTE(review): the expected entry format (IP, CIDR, hostname) is not
    // visible from this file — confirm against the consumer.
    pub trusted_proxies: Vec<String>,
}

impl OriginValidationConfig {
    /// Creates the default configuration: no explicit origins, localhost
    /// allowed, wildcard disabled, no trusted proxies.
    #[must_use]
    pub fn new() -> Self {
        Self {
            allowed_origins: HashSet::default(),
            allow_localhost: true,
            allow_any: false,
            trusted_proxies: Vec::default(),
        }
    }
}

impl Default for OriginValidationConfig {
    /// Same as [`OriginValidationConfig::new`].
    fn default() -> Self {
        Self::new()
    }
}
68
69#[derive(Debug, Clone)]
71pub struct ServerConfig {
72 pub protocol: ProtocolConfig,
74 pub rate_limit: Option<RateLimitConfig>,
76 pub connection_limits: ConnectionLimits,
78 pub required_capabilities: RequiredCapabilities,
80 pub max_message_size: usize,
82 pub origin_validation: OriginValidationConfig,
84}
85
86impl Default for ServerConfig {
87 fn default() -> Self {
88 Self {
89 protocol: ProtocolConfig::default(),
90 rate_limit: None,
91 connection_limits: ConnectionLimits::default(),
92 required_capabilities: RequiredCapabilities::default(),
93 max_message_size: DEFAULT_MAX_MESSAGE_SIZE,
94 origin_validation: OriginValidationConfig::default(),
95 }
96 }
97}
98
99impl ServerConfig {
100 #[must_use]
102 pub fn new() -> Self {
103 Self::default()
104 }
105
106 #[must_use]
108 pub fn builder() -> ServerConfigBuilder {
109 ServerConfigBuilder::default()
110 }
111}
112
113#[derive(Debug, Clone, Default)]
115pub struct ServerConfigBuilder {
116 protocol: Option<ProtocolConfig>,
117 rate_limit: Option<RateLimitConfig>,
118 connection_limits: Option<ConnectionLimits>,
119 required_capabilities: Option<RequiredCapabilities>,
120 max_message_size: Option<usize>,
121 origin_validation: Option<OriginValidationConfig>,
122}
123
124impl ServerConfigBuilder {
125 #[must_use]
127 pub fn protocol(mut self, config: ProtocolConfig) -> Self {
128 self.protocol = Some(config);
129 self
130 }
131
132 #[must_use]
134 pub fn rate_limit(mut self, config: RateLimitConfig) -> Self {
135 self.rate_limit = Some(config);
136 self
137 }
138
139 #[must_use]
141 pub fn connection_limits(mut self, limits: ConnectionLimits) -> Self {
142 self.connection_limits = Some(limits);
143 self
144 }
145
146 #[must_use]
148 pub fn required_capabilities(mut self, caps: RequiredCapabilities) -> Self {
149 self.required_capabilities = Some(caps);
150 self
151 }
152
153 #[must_use]
158 pub fn max_message_size(mut self, size: usize) -> Self {
159 self.max_message_size = Some(size);
160 self
161 }
162
163 #[must_use]
165 pub fn origin_validation(mut self, config: OriginValidationConfig) -> Self {
166 self.origin_validation = Some(config);
167 self
168 }
169
170 #[must_use]
172 pub fn allow_origin(mut self, origin: impl Into<String>) -> Self {
173 self.origin_validation
174 .get_or_insert_with(OriginValidationConfig::default)
175 .allowed_origins
176 .insert(origin.into());
177 self
178 }
179
180 #[must_use]
182 pub fn allow_origins<I, S>(mut self, origins: I) -> Self
183 where
184 I: IntoIterator<Item = S>,
185 S: Into<String>,
186 {
187 let config = self
188 .origin_validation
189 .get_or_insert_with(OriginValidationConfig::default);
190 config
191 .allowed_origins
192 .extend(origins.into_iter().map(Into::into));
193 self
194 }
195
196 #[must_use]
198 pub fn allow_localhost_origins(mut self, allow: bool) -> Self {
199 self.origin_validation
200 .get_or_insert_with(OriginValidationConfig::default)
201 .allow_localhost = allow;
202 self
203 }
204
205 #[must_use]
207 pub fn allow_any_origin(mut self, allow: bool) -> Self {
208 self.origin_validation
209 .get_or_insert_with(OriginValidationConfig::default)
210 .allow_any = allow;
211 self
212 }
213
214 #[must_use]
219 pub fn build(self) -> ServerConfig {
220 ServerConfig {
221 protocol: self.protocol.unwrap_or_default(),
222 rate_limit: self.rate_limit,
223 connection_limits: self.connection_limits.unwrap_or_default(),
224 required_capabilities: self.required_capabilities.unwrap_or_default(),
225 max_message_size: self.max_message_size.unwrap_or(DEFAULT_MAX_MESSAGE_SIZE),
226 origin_validation: self.origin_validation.unwrap_or_default(),
227 }
228 }
229
230 pub fn try_build(self) -> Result<ServerConfig, ConfigValidationError> {
256 let max_message_size = self.max_message_size.unwrap_or(DEFAULT_MAX_MESSAGE_SIZE);
257
258 if max_message_size < 1024 {
260 return Err(ConfigValidationError::InvalidMessageSize {
261 size: max_message_size,
262 min: 1024,
263 });
264 }
265
266 if let Some(ref rate_limit) = self.rate_limit {
268 if rate_limit.max_requests == 0 {
269 return Err(ConfigValidationError::InvalidRateLimit {
270 reason: "max_requests cannot be 0".to_string(),
271 });
272 }
273 if rate_limit.window.is_zero() {
274 return Err(ConfigValidationError::InvalidRateLimit {
275 reason: "rate limit window cannot be zero".to_string(),
276 });
277 }
278 }
279
280 let connection_limits = self.connection_limits.unwrap_or_default();
282 if connection_limits.max_tcp_connections == 0
283 && connection_limits.max_websocket_connections == 0
284 && connection_limits.max_http_concurrent == 0
285 && connection_limits.max_unix_connections == 0
286 {
287 return Err(ConfigValidationError::InvalidConnectionLimits {
288 reason: "at least one connection limit must be non-zero".to_string(),
289 });
290 }
291
292 Ok(ServerConfig {
293 protocol: self.protocol.unwrap_or_default(),
294 rate_limit: self.rate_limit,
295 connection_limits,
296 required_capabilities: self.required_capabilities.unwrap_or_default(),
297 max_message_size,
298 origin_validation: self.origin_validation.unwrap_or_default(),
299 })
300 }
301}
302
/// Reasons [`ServerConfigBuilder::try_build`] can reject a configuration.
#[derive(Debug, Clone, thiserror::Error)]
pub enum ConfigValidationError {
    /// The configured maximum message size is below the accepted minimum.
    #[error("Invalid max_message_size: {size} bytes is below minimum of {min} bytes")]
    InvalidMessageSize {
        /// The rejected size, in bytes.
        size: usize,
        /// The smallest accepted size, in bytes.
        min: usize,
    },

    /// The rate limit settings are unusable (zero requests or zero window).
    #[error("Invalid rate limit: {reason}")]
    InvalidRateLimit {
        /// Human-readable explanation of what is wrong.
        reason: String,
    },

    /// The connection limits are unusable (every limit is zero).
    #[error("Invalid connection limits: {reason}")]
    InvalidConnectionLimits {
        /// Human-readable explanation of what is wrong.
        reason: String,
    },
}
329
330#[derive(Debug, Clone)]
332pub struct ProtocolConfig {
333 pub preferred_version: ProtocolVersion,
335 pub supported_versions: Vec<ProtocolVersion>,
337 pub allow_fallback: bool,
339}
340
341impl Default for ProtocolConfig {
342 fn default() -> Self {
351 Self {
352 preferred_version: ProtocolVersion::LATEST.clone(),
353 supported_versions: ProtocolVersion::STABLE.to_vec(),
354 allow_fallback: false,
355 }
356 }
357}
358
359impl ProtocolConfig {
360 #[must_use]
362 pub fn strict(version: impl Into<ProtocolVersion>) -> Self {
363 let v = version.into();
364 Self {
365 preferred_version: v.clone(),
366 supported_versions: vec![v],
367 allow_fallback: false,
368 }
369 }
370
371 #[must_use]
376 pub fn multi_version() -> Self {
377 Self {
378 preferred_version: ProtocolVersion::LATEST.clone(),
379 supported_versions: ProtocolVersion::STABLE.to_vec(),
380 allow_fallback: false,
381 }
382 }
383
384 #[must_use]
386 pub fn is_supported(&self, version: &ProtocolVersion) -> bool {
387 self.supported_versions.contains(version)
388 }
389
390 #[must_use]
394 pub fn negotiate(&self, client_version: Option<&str>) -> Option<ProtocolVersion> {
395 match client_version {
396 Some(version_str) => {
397 let version = ProtocolVersion::from(version_str);
398 if self.is_supported(&version) {
399 Some(version)
400 } else if self.allow_fallback {
401 Some(self.preferred_version.clone())
402 } else {
403 None
404 }
405 }
406 None => Some(self.preferred_version.clone()),
407 }
408 }
409}
410
411#[derive(Debug, Clone)]
413pub struct RateLimitConfig {
414 pub max_requests: u32,
416 pub window: Duration,
418 pub per_client: bool,
420}
421
422impl Default for RateLimitConfig {
423 fn default() -> Self {
424 Self {
425 max_requests: DEFAULT_RATE_LIMIT,
426 window: DEFAULT_RATE_LIMIT_WINDOW,
427 per_client: true,
428 }
429 }
430}
431
432impl RateLimitConfig {
433 #[must_use]
435 pub fn new(max_requests: u32, window: Duration) -> Self {
436 Self {
437 max_requests,
438 window,
439 per_client: true,
440 }
441 }
442
443 #[must_use]
445 pub fn per_client(mut self, enabled: bool) -> Self {
446 self.per_client = enabled;
447 self
448 }
449}
450
451#[derive(Debug, Clone)]
453pub struct ConnectionLimits {
454 pub max_tcp_connections: usize,
456 pub max_websocket_connections: usize,
458 pub max_http_concurrent: usize,
460 pub max_unix_connections: usize,
462}
463
464impl Default for ConnectionLimits {
465 fn default() -> Self {
466 Self {
467 max_tcp_connections: DEFAULT_MAX_CONNECTIONS,
468 max_websocket_connections: DEFAULT_MAX_CONNECTIONS,
469 max_http_concurrent: DEFAULT_MAX_CONNECTIONS,
470 max_unix_connections: DEFAULT_MAX_CONNECTIONS,
471 }
472 }
473}
474
475impl ConnectionLimits {
476 #[must_use]
478 pub fn new(max_connections: usize) -> Self {
479 Self {
480 max_tcp_connections: max_connections,
481 max_websocket_connections: max_connections,
482 max_http_concurrent: max_connections,
483 max_unix_connections: max_connections,
484 }
485 }
486}
487
488#[derive(Debug, Clone, Default, Serialize, Deserialize)]
492pub struct RequiredCapabilities {
493 #[serde(default)]
495 pub roots: bool,
496 #[serde(default)]
498 pub sampling: bool,
499 #[serde(default)]
501 pub extensions: HashSet<String>,
502 #[serde(default)]
504 pub experimental: HashSet<String>,
505}
506
507impl RequiredCapabilities {
508 #[must_use]
510 pub fn none() -> Self {
511 Self::default()
512 }
513
514 #[must_use]
516 pub fn with_roots(mut self) -> Self {
517 self.roots = true;
518 self
519 }
520
521 #[must_use]
523 pub fn with_sampling(mut self) -> Self {
524 self.sampling = true;
525 self
526 }
527
528 #[must_use]
530 pub fn with_extension(mut self, name: impl Into<String>) -> Self {
531 self.extensions.insert(name.into());
532 self
533 }
534
535 #[must_use]
537 pub fn with_experimental(mut self, name: impl Into<String>) -> Self {
538 self.experimental.insert(name.into());
539 self
540 }
541
542 #[must_use]
544 pub fn validate(&self, client_caps: &ClientCapabilities) -> CapabilityValidation {
545 let mut missing = Vec::new();
546
547 if self.roots && !client_caps.roots {
548 missing.push("roots".to_string());
549 }
550
551 if self.sampling && !client_caps.sampling {
552 missing.push("sampling".to_string());
553 }
554
555 for extension in &self.extensions {
556 if !client_caps.extensions.contains(extension) {
557 missing.push(format!("extensions/{}", extension));
558 }
559 }
560
561 for exp in &self.experimental {
562 if !client_caps.experimental.contains(exp) {
563 missing.push(format!("experimental/{}", exp));
564 }
565 }
566
567 if missing.is_empty() {
568 CapabilityValidation::Valid
569 } else {
570 CapabilityValidation::Missing(missing)
571 }
572 }
573}
574
575#[derive(Debug, Clone, Default, Serialize, Deserialize)]
577pub struct ClientCapabilities {
578 #[serde(default)]
580 pub roots: bool,
581 #[serde(default)]
583 pub sampling: bool,
584 #[serde(default)]
586 pub extensions: HashSet<String>,
587 #[serde(default)]
589 pub experimental: HashSet<String>,
590}
591
592impl ClientCapabilities {
593 #[must_use]
595 pub fn from_params(params: &serde_json::Value) -> Self {
596 let caps = params.get("capabilities").cloned().unwrap_or_default();
597
598 Self {
599 roots: caps.get("roots").map(|v| !v.is_null()).unwrap_or(false),
600 sampling: caps.get("sampling").map(|v| !v.is_null()).unwrap_or(false),
601 extensions: caps
602 .get("extensions")
603 .and_then(|v| v.as_object())
604 .map(|obj| obj.keys().cloned().collect())
605 .unwrap_or_default(),
606 experimental: caps
607 .get("experimental")
608 .and_then(|v| v.as_object())
609 .map(|obj| obj.keys().cloned().collect())
610 .unwrap_or_default(),
611 }
612 }
613}
614
/// Outcome of [`RequiredCapabilities::validate`].
#[derive(Debug, Clone)]
pub enum CapabilityValidation {
    /// Every required capability was advertised.
    Valid,
    /// At least one required capability is absent; holds their names.
    Missing(Vec<String>),
}

impl CapabilityValidation {
    /// Returns `true` for [`CapabilityValidation::Valid`].
    #[must_use]
    pub fn is_valid(&self) -> bool {
        match self {
            Self::Valid => true,
            Self::Missing(_) => false,
        }
    }

    /// Returns the names of the missing capabilities, or `None` when valid.
    #[must_use]
    pub fn missing(&self) -> Option<&[String]> {
        if let Self::Missing(names) = self {
            Some(names.as_slice())
        } else {
            None
        }
    }
}
640
641#[derive(Debug)]
643pub struct RateLimiter {
644 config: RateLimitConfig,
645 global_bucket: Mutex<TokenBucket>,
647 client_buckets: Mutex<std::collections::HashMap<String, TokenBucket>>,
649 last_cleanup: Mutex<Instant>,
651}
652
653impl RateLimiter {
654 #[must_use]
656 pub fn new(config: RateLimitConfig) -> Self {
657 Self {
658 global_bucket: Mutex::new(TokenBucket::new(config.max_requests, config.window)),
659 client_buckets: Mutex::new(std::collections::HashMap::new()),
660 last_cleanup: Mutex::new(Instant::now()),
661 config,
662 }
663 }
664
665 pub fn check(&self, client_id: Option<&str>) -> bool {
669 let needs_cleanup = {
671 let last = self.last_cleanup.lock();
672 last.elapsed() > Duration::from_secs(60)
673 };
674 if needs_cleanup {
675 self.cleanup(Duration::from_secs(300));
676 *self.last_cleanup.lock() = Instant::now();
677 }
678
679 if self.config.per_client {
680 if let Some(id) = client_id {
681 let mut buckets = self.client_buckets.lock();
682 let bucket = buckets.entry(id.to_string()).or_insert_with(|| {
683 TokenBucket::new(self.config.max_requests, self.config.window)
684 });
685 bucket.try_acquire()
686 } else {
687 self.global_bucket.lock().try_acquire()
689 }
690 } else {
691 self.global_bucket.lock().try_acquire()
692 }
693 }
694
695 pub fn cleanup(&self, max_age: Duration) {
697 let mut buckets = self.client_buckets.lock();
698 let now = Instant::now();
699 buckets.retain(|_, bucket| now.duration_since(bucket.last_access) < max_age);
700 }
701
702 #[must_use]
704 pub fn client_bucket_count(&self) -> usize {
705 self.client_buckets.lock().len()
706 }
707}
708
/// Token bucket: holds up to `max_tokens`, refilled continuously over time.
#[derive(Debug)]
struct TokenBucket {
    /// Tokens currently available (fractional while refilling).
    tokens: f64,
    /// Bucket capacity.
    max_tokens: f64,
    /// Tokens added per second.
    refill_rate: f64,
    /// When tokens were last added.
    last_refill: Instant,
    /// When `try_acquire` was last called.
    last_access: Instant,
}

impl TokenBucket {
    /// Creates a full bucket that holds `max_requests` tokens and regains
    /// that many over one `window`.
    fn new(max_requests: u32, window: Duration) -> Self {
        let capacity = f64::from(max_requests);
        let per_second = capacity / window.as_secs_f64();
        Self {
            tokens: capacity,
            max_tokens: capacity,
            refill_rate: per_second,
            last_refill: Instant::now(),
            last_access: Instant::now(),
        }
    }

    /// Takes one token if available and reports whether it succeeded.
    ///
    /// Refill is applied only once at least 10 ms have passed since the last
    /// refill; shorter gaps accumulate until the next call that crosses it.
    fn try_acquire(&mut self) -> bool {
        let now = Instant::now();
        let since_refill = now.duration_since(self.last_refill);

        if since_refill >= Duration::from_millis(10) {
            let replenished = self.tokens + since_refill.as_secs_f64() * self.refill_rate;
            self.tokens = replenished.min(self.max_tokens);
            self.last_refill = now;
        }

        self.last_access = now;

        let granted = self.tokens >= 1.0;
        if granted {
            self.tokens -= 1.0;
        }
        granted
    }
}
753
/// Atomic counter that caps the number of simultaneously held connections.
///
/// Slots are handed out as [`ConnectionGuard`]s and returned when the guard
/// is dropped.
#[derive(Debug)]
pub struct ConnectionCounter {
    /// Number of guards currently alive.
    current: AtomicUsize,
    /// Maximum number of guards that may be alive at once.
    max: usize,
}

impl ConnectionCounter {
    /// Creates a counter that allows at most `max` simultaneous guards.
    #[must_use]
    pub fn new(max: usize) -> Self {
        Self {
            current: AtomicUsize::new(0),
            max,
        }
    }

    /// Tries to reserve one connection slot.
    ///
    /// Returns `None` when the counter is already at its maximum. The
    /// returned guard must be kept alive for as long as the connection is in
    /// use; dropping it releases the slot.
    #[must_use = "dropping the guard immediately releases the connection slot"]
    pub fn try_acquire_arc(self: &Arc<Self>) -> Option<ConnectionGuard> {
        // `fetch_update` retries the compare-and-swap until it succeeds or
        // the closure declines because the limit has been reached.
        self.current
            .fetch_update(Ordering::SeqCst, Ordering::Relaxed, |in_use| {
                (in_use < self.max).then(|| in_use + 1)
            })
            .ok()
            .map(|_| ConnectionGuard {
                counter: Arc::clone(self),
            })
    }

    /// Number of slots currently reserved.
    #[must_use]
    pub fn current(&self) -> usize {
        self.current.load(Ordering::Relaxed)
    }

    /// Maximum number of slots.
    #[must_use]
    pub fn max(&self) -> usize {
        self.max
    }

    /// Returns one slot; called only from `ConnectionGuard::drop`.
    fn release(&self) {
        self.current.fetch_sub(1, Ordering::SeqCst);
    }
}

/// Holds one connection slot of a [`ConnectionCounter`]; the slot is
/// released when the guard is dropped.
#[derive(Debug)]
pub struct ConnectionGuard {
    counter: Arc<ConnectionCounter>,
}

impl Drop for ConnectionGuard {
    fn drop(&mut self) {
        self.counter.release();
    }
}
833
/// Unit tests for protocol negotiation, capability validation, rate
/// limiting, connection counting and the configuration builder.
#[cfg(test)]
mod tests {
    use super::*;

    // --- Protocol negotiation -------------------------------------------

    #[test]
    fn test_protocol_negotiation_exact_match() {
        let config = ProtocolConfig::default();
        assert_eq!(
            config.negotiate(Some("2025-11-25")),
            Some(ProtocolVersion::V2025_11_25)
        );
    }

    #[test]
    fn test_protocol_negotiation_default_accepts_stable_versions() {
        let config = ProtocolConfig::default();
        assert_eq!(
            config.negotiate(Some("2025-06-18")),
            Some(ProtocolVersion::V2025_06_18)
        );
    }

    #[test]
    fn test_protocol_negotiation_strict_rejects_older_version() {
        let config = ProtocolConfig::strict(ProtocolVersion::LATEST.clone());
        assert_eq!(config.negotiate(Some("2025-06-18")), None);
    }

    #[test]
    fn test_protocol_negotiation_multi_version_accepts_older() {
        let config = ProtocolConfig::multi_version();
        assert_eq!(
            config.negotiate(Some("2025-06-18")),
            Some(ProtocolVersion::V2025_06_18)
        );
        assert_eq!(
            config.negotiate(Some("2025-11-25")),
            Some(ProtocolVersion::V2025_11_25)
        );
    }

    #[test]
    fn test_protocol_negotiation_none_returns_preferred() {
        let config = ProtocolConfig::default();
        assert_eq!(config.negotiate(None), Some(ProtocolVersion::V2025_11_25));
    }

    #[test]
    fn test_protocol_negotiation_unknown_version() {
        let config = ProtocolConfig::default();
        assert_eq!(config.negotiate(Some("unknown-version")), None);
    }

    #[test]
    fn test_protocol_negotiation_strict() {
        let config = ProtocolConfig::strict("2025-11-25");
        assert_eq!(config.negotiate(Some("2025-06-18")), None);
    }

    // --- Capability validation ------------------------------------------

    #[test]
    fn test_capability_validation() {
        let required = RequiredCapabilities::none().with_roots();
        let client = ClientCapabilities {
            roots: true,
            ..Default::default()
        };
        assert!(required.validate(&client).is_valid());

        let client_missing = ClientCapabilities::default();
        assert!(!required.validate(&client_missing).is_valid());
    }

    #[test]
    fn test_extension_capability_validation() {
        let required = RequiredCapabilities::none().with_extension("trace");
        let client = ClientCapabilities {
            extensions: ["trace".to_string()].into_iter().collect(),
            ..Default::default()
        };
        assert!(required.validate(&client).is_valid());

        let missing = ClientCapabilities::default();
        let validation = required.validate(&missing);
        assert!(!validation.is_valid());
        assert_eq!(
            validation.missing(),
            Some(&["extensions/trace".to_string()][..])
        );
    }

    #[test]
    fn test_client_capabilities_parse_extensions() {
        let params = serde_json::json!({
            "capabilities": {
                "extensions": {
                    "trace": {"version": "1"},
                    "handoff": {}
                }
            }
        });

        let caps = ClientCapabilities::from_params(&params);
        assert!(caps.extensions.contains("trace"));
        assert!(caps.extensions.contains("handoff"));
    }

    // --- Rate limiting and connection counting --------------------------

    #[test]
    fn test_rate_limiter() {
        let config = RateLimitConfig::new(2, Duration::from_secs(1));
        let limiter = RateLimiter::new(config);

        assert!(limiter.check(None));
        assert!(limiter.check(None));
        // Both tokens are spent, so the third request is rejected.
        assert!(!limiter.check(None));
    }

    #[test]
    fn test_connection_counter() {
        let counter = Arc::new(ConnectionCounter::new(2));

        let guard1 = counter.try_acquire_arc();
        assert!(guard1.is_some());
        assert_eq!(counter.current(), 1);

        let guard2 = counter.try_acquire_arc();
        assert!(guard2.is_some());
        assert_eq!(counter.current(), 2);

        // At capacity: no further guard is handed out.
        let guard3 = counter.try_acquire_arc();
        assert!(guard3.is_none());

        // Dropping a guard frees its slot.
        drop(guard1);
        assert_eq!(counter.current(), 1);

        let guard4 = counter.try_acquire_arc();
        assert!(guard4.is_some());
    }

    // --- Builder --------------------------------------------------------

    #[test]
    fn test_builder_default_succeeds() {
        let config = ServerConfig::builder().build();
        assert_eq!(config.max_message_size, DEFAULT_MAX_MESSAGE_SIZE);
        assert!(config.origin_validation.allow_localhost);
        assert!(config.origin_validation.allowed_origins.is_empty());
    }

    #[test]
    fn test_builder_origin_validation_overrides() {
        let config = ServerConfig::builder()
            .allow_origin("https://app.example.com")
            .allow_localhost_origins(false)
            .build();

        assert!(!config.origin_validation.allow_localhost);
        assert!(
            config
                .origin_validation
                .allowed_origins
                .contains("https://app.example.com")
        );
    }

    #[test]
    fn test_builder_try_build_valid() {
        let result = ServerConfig::builder()
            .max_message_size(1024 * 1024)
            .try_build();
        assert!(result.is_ok());
    }

    #[test]
    fn test_builder_try_build_invalid_message_size() {
        // 100 bytes is below the 1024-byte minimum.
        let result = ServerConfig::builder().max_message_size(100).try_build();
        assert!(result.is_err());
        assert!(matches!(
            result.unwrap_err(),
            ConfigValidationError::InvalidMessageSize { .. }
        ));
    }

    #[test]
    fn test_builder_try_build_invalid_rate_limit() {
        let result = ServerConfig::builder()
            .rate_limit(RateLimitConfig {
                // Zero requests per window is rejected.
                max_requests: 0,
                window: Duration::from_secs(1),
                per_client: true,
            })
            .try_build();
        assert!(result.is_err());
        assert!(matches!(
            result.unwrap_err(),
            ConfigValidationError::InvalidRateLimit { .. }
        ));
    }

    #[test]
    fn test_builder_try_build_zero_window() {
        let result = ServerConfig::builder()
            .rate_limit(RateLimitConfig {
                max_requests: 100,
                // A zero-length window is rejected.
                window: Duration::ZERO,
                per_client: true,
            })
            .try_build();
        assert!(result.is_err());
        assert!(matches!(
            result.unwrap_err(),
            ConfigValidationError::InvalidRateLimit { .. }
        ));
    }

    #[test]
    fn test_builder_try_build_invalid_connection_limits() {
        let result = ServerConfig::builder()
            .connection_limits(ConnectionLimits {
                max_tcp_connections: 0,
                max_websocket_connections: 0,
                max_http_concurrent: 0,
                max_unix_connections: 0,
            })
            .try_build();
        assert!(result.is_err());
        assert!(matches!(
            result.unwrap_err(),
            ConfigValidationError::InvalidConnectionLimits { .. }
        ));
    }
}
1073
/// Property-based tests for the builder and the connection counter.
#[cfg(test)]
mod proptest_tests {
    use super::*;
    use proptest::prelude::*;

    proptest! {
        // `try_build` must return a `Result` for any message size, never panic.
        #[test]
        fn config_builder_never_panics(max_msg_size in 0usize..10_000_000) {
            let builder = ServerConfig::builder().max_message_size(max_msg_size);
            let _ = builder.try_build();
        }

        // However many acquisitions are attempted, exactly `max` succeed.
        #[test]
        fn connection_counter_bounded(max in 1usize..10000) {
            let counter = Arc::new(ConnectionCounter::new(max));
            let attempts = max + 10;
            let guards: Vec<ConnectionGuard> = (0..attempts)
                .filter_map(|_| counter.try_acquire_arc())
                .collect();
            assert_eq!(guards.len(), max);
            assert_eq!(counter.current(), max);
        }
    }
}