1use crate::resource_limits::{
24 ConcurrencyConfig, ConcurrencyLimits, LimitEnforcer, MemoryAwareCache, PerClientRateLimiter,
25 RateLimitConfig, RequestSizeLimits, ResponseSizeLimits, TimeoutLimits,
26};
27use serde::{Deserialize, Serialize};
28use std::path::Path;
29use std::time::Duration;
30use tracing::info;
31
32#[derive(Debug, Clone, Serialize, Deserialize)]
34pub struct ResourceLimitConfig {
35 #[serde(default = "default_enabled")]
37 pub enabled: bool,
38
39 #[serde(default)]
41 pub request: RequestSizeConfig,
42
43 #[serde(default)]
45 pub response: ResponseSizeConfig,
46
47 #[serde(default)]
49 pub rate_limiting: RateLimitingConfig,
50
51 #[serde(default)]
53 pub memory: MemoryConfig,
54
55 #[serde(default)]
57 pub concurrency: ConcurrencyConfigTOML,
58
59 #[serde(default)]
61 pub timeouts: TimeoutConfig,
62}
63
64impl Default for ResourceLimitConfig {
65 fn default() -> Self {
66 Self {
67 enabled: default_enabled(),
68 request: RequestSizeConfig::default(),
69 response: ResponseSizeConfig::default(),
70 rate_limiting: RateLimitingConfig::default(),
71 memory: MemoryConfig::default(),
72 concurrency: ConcurrencyConfigTOML::default(),
73 timeouts: TimeoutConfig::default(),
74 }
75 }
76}
77
/// Serde/`Default` value for `ResourceLimitConfig::enabled`: limits are on
/// unless the configuration turns them off.
const fn default_enabled() -> bool {
    true
}
81
82#[derive(Debug, Clone, Serialize, Deserialize)]
84pub struct RequestSizeConfig {
85 #[serde(default = "default_max_request_size")]
87 pub max_total_size_bytes: usize,
88
89 #[serde(default = "default_max_param_size")]
91 pub max_param_size_bytes: usize,
92
93 #[serde(default = "default_max_array_elements")]
95 pub max_array_elements: usize,
96
97 #[serde(default = "default_max_object_depth")]
99 pub max_object_depth: usize,
100}
101
102impl Default for RequestSizeConfig {
103 fn default() -> Self {
104 Self {
105 max_total_size_bytes: default_max_request_size(),
106 max_param_size_bytes: default_max_param_size(),
107 max_array_elements: default_max_array_elements(),
108 max_object_depth: default_max_object_depth(),
109 }
110 }
111}
112
/// Default cap on total request size: 10 MiB.
const fn default_max_request_size() -> usize {
    10 * 1024 * 1024
}

/// Default cap on a single parameter: 5 MiB (half the total request cap).
const fn default_max_param_size() -> usize {
    5 * 1024 * 1024
}

/// Default cap on array elements in a request.
const fn default_max_array_elements() -> usize {
    10_000
}

/// Default cap on object nesting depth in a request.
const fn default_max_object_depth() -> usize {
    32
}
128
129#[derive(Debug, Clone, Serialize, Deserialize)]
131pub struct ResponseSizeConfig {
132 #[serde(default = "default_max_response_size")]
134 pub max_total_size_bytes: usize,
135
136 #[serde(default = "default_max_result_items")]
138 pub max_result_items: usize,
139
140 #[serde(default = "default_enable_streaming")]
142 pub enable_streaming: bool,
143}
144
145impl Default for ResponseSizeConfig {
146 fn default() -> Self {
147 Self {
148 max_total_size_bytes: default_max_response_size(),
149 max_result_items: default_max_result_items(),
150 enable_streaming: default_enable_streaming(),
151 }
152 }
153}
154
/// Default cap on total response size: 50 MB (decimal megabytes, not MiB).
const fn default_max_response_size() -> usize {
    50 * 1_000_000
}

/// Default cap on result items per response.
const fn default_max_result_items() -> usize {
    100_000
}

/// Streaming is enabled unless configured otherwise.
const fn default_enable_streaming() -> bool {
    true
}
166
167#[derive(Debug, Clone, Serialize, Deserialize)]
169pub struct RateLimitingConfig {
170 #[serde(default = "default_rate_limit_mode")]
172 pub mode: String,
173
174 #[serde(default = "default_burst")]
176 pub default_burst: usize,
177
178 #[serde(default = "default_per_second")]
180 pub default_per_second: usize,
181
182 #[serde(default = "default_cleanup_interval")]
184 pub cleanup_interval_seconds: usize,
185
186 #[serde(default)]
188 pub overrides: Vec<RateLimitOverride>,
189}
190
191impl Default for RateLimitingConfig {
192 fn default() -> Self {
193 Self {
194 mode: default_rate_limit_mode(),
195 default_burst: default_burst(),
196 default_per_second: default_per_second(),
197 cleanup_interval_seconds: default_cleanup_interval(),
198 overrides: Vec::new(),
199 }
200 }
201}
202
/// Default rate limiting mode name.
fn default_rate_limit_mode() -> String {
    String::from("per_client")
}

/// Default burst size for clients without an override.
const fn default_burst() -> usize {
    200
}

/// Default sustained rate (per second) for clients without an override.
const fn default_per_second() -> usize {
    100
}

/// Default cleanup interval: 5 minutes, expressed in seconds.
const fn default_cleanup_interval() -> usize {
    5 * 60
}
218
219#[derive(Debug, Clone, Serialize, Deserialize)]
221pub struct RateLimitOverride {
222 pub client_pattern: String,
224
225 pub burst: usize,
227
228 pub per_second: usize,
230}
231
232#[derive(Debug, Clone, Serialize, Deserialize)]
234pub struct MemoryConfig {
235 #[serde(default = "default_max_cache_memory")]
237 pub max_cache_memory_bytes: usize,
238
239 #[serde(default = "default_max_operation_memory")]
241 pub max_operation_memory_bytes: usize,
242
243 #[serde(default = "default_enable_memory_tracking")]
245 pub enable_memory_tracking: bool,
246}
247
248impl Default for MemoryConfig {
249 fn default() -> Self {
250 Self {
251 max_cache_memory_bytes: default_max_cache_memory(),
252 max_operation_memory_bytes: default_max_operation_memory(),
253 enable_memory_tracking: default_enable_memory_tracking(),
254 }
255 }
256}
257
/// Default cache budget: 100 MiB.
const fn default_max_cache_memory() -> usize {
    100 * 1024 * 1024
}

/// Default per-operation budget: 50 MiB.
const fn default_max_operation_memory() -> usize {
    50 * 1024 * 1024
}

/// Memory tracking is on unless configured otherwise.
const fn default_enable_memory_tracking() -> bool {
    true
}
269
270#[derive(Debug, Clone, Serialize, Deserialize)]
272pub struct ConcurrencyConfigTOML {
273 #[serde(default = "default_max_concurrent_requests")]
275 pub max_concurrent_requests: usize,
276
277 #[serde(default = "default_max_concurrent_per_client")]
279 pub max_concurrent_per_client: usize,
280
281 #[serde(default = "default_max_concurrent_per_tool")]
283 pub max_concurrent_per_tool: usize,
284
285 #[serde(default = "default_queue_timeout")]
287 pub queue_timeout_ms: usize,
288}
289
290impl Default for ConcurrencyConfigTOML {
291 fn default() -> Self {
292 Self {
293 max_concurrent_requests: default_max_concurrent_requests(),
294 max_concurrent_per_client: default_max_concurrent_per_client(),
295 max_concurrent_per_tool: default_max_concurrent_per_tool(),
296 queue_timeout_ms: default_queue_timeout(),
297 }
298 }
299}
300
/// Default global concurrency cap.
const fn default_max_concurrent_requests() -> usize {
    100
}

/// Default per-client concurrency cap.
const fn default_max_concurrent_per_client() -> usize {
    10
}

/// Default per-tool concurrency cap.
const fn default_max_concurrent_per_tool() -> usize {
    50
}

/// Default queue timeout: 5 seconds, expressed in milliseconds.
const fn default_queue_timeout() -> usize {
    5 * 1_000
}
316
317#[derive(Debug, Clone, Serialize, Deserialize)]
319pub struct TimeoutConfig {
320 #[serde(default = "default_timeout")]
322 pub default_timeout_ms: usize,
323
324 #[serde(default)]
326 pub per_tool: std::collections::HashMap<String, usize>,
327}
328
329impl Default for TimeoutConfig {
330 fn default() -> Self {
331 let mut per_tool = std::collections::HashMap::new();
332 per_tool.insert("hedl_validate".to_string(), 5_000);
333 per_tool.insert("hedl_query".to_string(), 10_000);
334 per_tool.insert("hedl_convert_to".to_string(), 60_000);
335 per_tool.insert("hedl_stream".to_string(), 120_000);
336
337 Self {
338 default_timeout_ms: default_timeout(),
339 per_tool,
340 }
341 }
342}
343
/// Default timeout: 30 seconds, expressed in milliseconds.
const fn default_timeout() -> usize {
    30 * 1_000
}
347
348impl ResourceLimitConfig {
349 pub fn from_file<P: AsRef<Path>>(path: P) -> Result<Self, ConfigError> {
359 let content = std::fs::read_to_string(path.as_ref()).map_err(|e| ConfigError::Io {
360 path: path.as_ref().display().to_string(),
361 source: e,
362 })?;
363
364 let config: ResourceLimitConfig =
365 toml::from_str(&content).map_err(|e| ConfigError::Parse {
366 path: path.as_ref().display().to_string(),
367 source: e,
368 })?;
369
370 info!(
371 "Loaded resource limit config from {}",
372 path.as_ref().display()
373 );
374
375 Ok(config)
376 }
377
378 pub fn parse_toml(content: &str) -> Result<Self, ConfigError> {
388 let config: ResourceLimitConfig =
389 toml::from_str(content).map_err(|e| ConfigError::Parse {
390 path: "<string>".to_string(),
391 source: e,
392 })?;
393
394 info!("Loaded resource limit config from string");
395
396 Ok(config)
397 }
398
399 #[must_use]
405 pub fn to_manager(&self) -> LimitEnforcer {
406 let request_limits = RequestSizeLimits::new(
408 self.request.max_total_size_bytes,
409 self.request.max_param_size_bytes,
410 self.request.max_array_elements,
411 self.request.max_object_depth,
412 );
413
414 let response_limits = ResponseSizeLimits::new(
416 self.response.max_total_size_bytes,
417 self.response.max_result_items,
418 self.response.enable_streaming,
419 );
420
421 let default_config = RateLimitConfig::new(
423 self.rate_limiting.default_burst,
424 self.rate_limiting.default_per_second,
425 );
426
427 let overrides = self
428 .rate_limiting
429 .overrides
430 .iter()
431 .map(|o| {
432 (
433 o.client_pattern.to_string(),
434 RateLimitConfig::new(o.burst, o.per_second),
435 )
436 })
437 .collect();
438
439 let cleanup_interval =
440 Duration::from_secs(self.rate_limiting.cleanup_interval_seconds as u64);
441 let rate_limiter = PerClientRateLimiter::new(default_config, overrides, cleanup_interval);
442
443 let memory_cache = if self.memory.enable_memory_tracking {
445 Some(MemoryAwareCache::new(self.memory.max_cache_memory_bytes))
446 } else {
447 None
448 };
449
450 let concurrency_config = ConcurrencyConfig {
452 max_concurrent_requests: self.concurrency.max_concurrent_requests,
453 max_concurrent_per_client: self.concurrency.max_concurrent_per_client,
454 max_concurrent_per_tool: self.concurrency.max_concurrent_per_tool,
455 queue_timeout: Duration::from_millis(self.concurrency.queue_timeout_ms as u64),
456 };
457 let concurrency_limits = ConcurrencyLimits::new(concurrency_config);
458
459 let timeout_limits = TimeoutLimits::new(Duration::from_millis(
461 self.timeouts.default_timeout_ms as u64,
462 ));
463 LimitEnforcer::new(
467 request_limits,
468 response_limits,
469 rate_limiter,
470 memory_cache,
471 concurrency_limits,
472 timeout_limits,
473 )
474 }
475
476 pub fn validate(&self) -> Result<(), ConfigError> {
482 if self.request.max_total_size_bytes == 0 {
484 return Err(ConfigError::Validation(
485 "max_total_size_bytes must be greater than 0".to_string(),
486 ));
487 }
488
489 if self.request.max_param_size_bytes > self.request.max_total_size_bytes {
490 return Err(ConfigError::Validation(
491 "max_param_size_bytes cannot exceed max_total_size_bytes".to_string(),
492 ));
493 }
494
495 if self.response.max_total_size_bytes == 0 {
497 return Err(ConfigError::Validation(
498 "max_total_size_bytes (response) must be greater than 0".to_string(),
499 ));
500 }
501
502 if self.rate_limiting.default_burst == 0 {
504 return Err(ConfigError::Validation(
505 "default_burst must be greater than 0".to_string(),
506 ));
507 }
508
509 if self.rate_limiting.default_per_second == 0 {
510 return Err(ConfigError::Validation(
511 "default_per_second must be greater than 0".to_string(),
512 ));
513 }
514
515 if self.concurrency.max_concurrent_requests == 0 {
517 return Err(ConfigError::Validation(
518 "max_concurrent_requests must be greater than 0".to_string(),
519 ));
520 }
521
522 if self.concurrency.max_concurrent_per_client == 0 {
523 return Err(ConfigError::Validation(
524 "max_concurrent_per_client must be greater than 0".to_string(),
525 ));
526 }
527
528 if self.concurrency.max_concurrent_per_tool == 0 {
529 return Err(ConfigError::Validation(
530 "max_concurrent_per_tool must be greater than 0".to_string(),
531 ));
532 }
533
534 if self.timeouts.default_timeout_ms == 0 {
536 return Err(ConfigError::Validation(
537 "default_timeout_ms must be greater than 0".to_string(),
538 ));
539 }
540
541 Ok(())
542 }
543}
544
545#[derive(Debug, thiserror::Error)]
547pub enum ConfigError {
548 #[error("IO error reading config from '{path}': {source}")]
550 Io {
551 path: String,
553 #[source]
555 source: std::io::Error,
556 },
557
558 #[error("Failed to parse TOML from '{path}': {source}")]
560 Parse {
561 path: String,
563 #[source]
565 source: toml::de::Error,
566 },
567
568 #[error("Configuration validation failed: {0}")]
570 Validation(
571 String,
573 ),
574}
575
#[cfg(test)]
mod tests {
    use super::*;

    /// Runs `validate` and returns the message of the resulting
    /// `ConfigError::Validation`, panicking if validation passes or fails with
    /// another variant.
    fn validation_message(config: &ResourceLimitConfig) -> String {
        match config.validate() {
            Err(ConfigError::Validation(msg)) => msg,
            Err(_) => panic!("Expected ValidationError"),
            Ok(()) => panic!("Expected validation to fail"),
        }
    }

    // ---------------------------------------------------------------------
    // Defaults
    // ---------------------------------------------------------------------

    #[test]
    fn test_default_config() {
        let cfg = ResourceLimitConfig::default();
        assert!(cfg.enabled);
        assert_eq!(cfg.request.max_total_size_bytes, 10_485_760);
        assert_eq!(cfg.response.max_total_size_bytes, 50_000_000);
        assert_eq!(cfg.rate_limiting.default_burst, 200);
        assert_eq!(cfg.rate_limiting.default_per_second, 100);
        assert_eq!(cfg.memory.max_cache_memory_bytes, 104_857_600);
        assert_eq!(cfg.concurrency.max_concurrent_requests, 100);
        assert_eq!(cfg.timeouts.default_timeout_ms, 30_000);
    }

    #[test]
    fn test_request_size_config_default() {
        let RequestSizeConfig {
            max_total_size_bytes,
            max_param_size_bytes,
            max_array_elements,
            max_object_depth,
        } = RequestSizeConfig::default();
        assert_eq!(max_total_size_bytes, 10_485_760);
        assert_eq!(max_param_size_bytes, 5_242_880);
        assert_eq!(max_array_elements, 10_000);
        assert_eq!(max_object_depth, 32);
    }

    #[test]
    fn test_response_size_config_default() {
        let ResponseSizeConfig {
            max_total_size_bytes,
            max_result_items,
            enable_streaming,
        } = ResponseSizeConfig::default();
        assert_eq!(max_total_size_bytes, 50_000_000);
        assert_eq!(max_result_items, 100_000);
        assert!(enable_streaming);
    }

    #[test]
    fn test_rate_limiting_config_default() {
        let RateLimitingConfig {
            mode,
            default_burst,
            default_per_second,
            cleanup_interval_seconds,
            overrides,
        } = RateLimitingConfig::default();
        assert_eq!(mode, "per_client");
        assert_eq!(default_burst, 200);
        assert_eq!(default_per_second, 100);
        assert_eq!(cleanup_interval_seconds, 300);
        assert!(overrides.is_empty());
    }

    #[test]
    fn test_memory_config_default() {
        let MemoryConfig {
            max_cache_memory_bytes,
            max_operation_memory_bytes,
            enable_memory_tracking,
        } = MemoryConfig::default();
        assert_eq!(max_cache_memory_bytes, 104_857_600);
        assert_eq!(max_operation_memory_bytes, 52_428_800);
        assert!(enable_memory_tracking);
    }

    #[test]
    fn test_concurrency_config_default() {
        let ConcurrencyConfigTOML {
            max_concurrent_requests,
            max_concurrent_per_client,
            max_concurrent_per_tool,
            queue_timeout_ms,
        } = ConcurrencyConfigTOML::default();
        assert_eq!(max_concurrent_requests, 100);
        assert_eq!(max_concurrent_per_client, 10);
        assert_eq!(max_concurrent_per_tool, 50);
        assert_eq!(queue_timeout_ms, 5000);
    }

    #[test]
    fn test_timeout_config_default() {
        let TimeoutConfig {
            default_timeout_ms,
            per_tool,
        } = TimeoutConfig::default();
        assert_eq!(default_timeout_ms, 30_000);

        let expected = [
            ("hedl_validate", 5_000_usize),
            ("hedl_query", 10_000),
            ("hedl_convert_to", 60_000),
            ("hedl_stream", 120_000),
        ];
        for (tool, timeout_ms) in expected {
            assert_eq!(per_tool.get(tool), Some(&timeout_ms), "timeout for {tool}");
        }
    }

    // ---------------------------------------------------------------------
    // Parsing
    // ---------------------------------------------------------------------

    #[test]
    fn test_parse_config_from_str() {
        let toml_str = r#"
enabled = true

[request]
max_total_size_bytes = 2048
max_param_size_bytes = 1024
max_array_elements = 100
max_object_depth = 10

[response]
max_total_size_bytes = 4096
max_result_items = 500
enable_streaming = false

[rate_limiting]
mode = "per_client"
default_burst = 50
default_per_second = 25
cleanup_interval_seconds = 60
"#;

        let config = ResourceLimitConfig::parse_toml(toml_str)
            .unwrap_or_else(|err| panic!("Failed to parse: {err:?}"));
        assert!(config.enabled);
        assert_eq!(config.request.max_total_size_bytes, 2048);
        assert_eq!(config.response.max_total_size_bytes, 4096);
        assert_eq!(config.rate_limiting.default_burst, 50);
    }

    #[test]
    fn test_parse_config_with_overrides() {
        let toml_str = r#"
[rate_limiting]
default_burst = 200
default_per_second = 100

[[rate_limiting.overrides]]
client_pattern = "premium-*"
burst = 1000
per_second = 500

[[rate_limiting.overrides]]
client_pattern = "free-*"
burst = 50
per_second = 10
"#;

        let config = ResourceLimitConfig::parse_toml(toml_str).expect("overrides should parse");
        let overrides = &config.rate_limiting.overrides;
        assert_eq!(overrides.len(), 2);

        let (premium, free) = (&overrides[0], &overrides[1]);
        assert_eq!(premium.client_pattern, "premium-*");
        assert_eq!(premium.burst, 1000);
        assert_eq!(free.client_pattern, "free-*");
        assert_eq!(free.burst, 50);
    }

    #[test]
    fn test_parse_invalid_toml() {
        let invalid_toml = r"
[resource_limits
enabled = true
";
        let result = ResourceLimitConfig::parse_toml(invalid_toml);
        assert!(
            matches!(result, Err(ConfigError::Parse { .. })),
            "Expected ParseError"
        );
    }

    // ---------------------------------------------------------------------
    // Validation
    // ---------------------------------------------------------------------

    #[test]
    fn test_validate_valid_config() {
        assert!(ResourceLimitConfig::default().validate().is_ok());
    }

    #[test]
    fn test_validate_zero_total_size() {
        let mut config = ResourceLimitConfig::default();
        config.request.max_total_size_bytes = 0;
        assert!(validation_message(&config).contains("max_total_size_bytes"));
    }

    #[test]
    fn test_validate_param_exceeds_total() {
        let mut config = ResourceLimitConfig::default();
        config.request.max_param_size_bytes = 20_000_000;
        config.request.max_total_size_bytes = 10_000_000;
        assert!(validation_message(&config).contains("max_param_size_bytes"));
    }

    #[test]
    fn test_validate_zero_burst() {
        let mut config = ResourceLimitConfig::default();
        config.rate_limiting.default_burst = 0;
        assert!(validation_message(&config).contains("default_burst"));
    }

    #[test]
    fn test_validate_zero_per_second() {
        let mut config = ResourceLimitConfig::default();
        config.rate_limiting.default_per_second = 0;
        assert!(validation_message(&config).contains("default_per_second"));
    }

    #[test]
    fn test_validate_zero_concurrent_requests() {
        let mut config = ResourceLimitConfig::default();
        config.concurrency.max_concurrent_requests = 0;
        assert!(validation_message(&config).contains("max_concurrent_requests"));
    }

    #[test]
    fn test_validate_zero_timeout() {
        let mut config = ResourceLimitConfig::default();
        config.timeouts.default_timeout_ms = 0;
        assert!(validation_message(&config).contains("default_timeout_ms"));
    }

    // ---------------------------------------------------------------------
    // Conversion to the runtime enforcer
    // ---------------------------------------------------------------------

    #[test]
    fn test_to_manager() {
        let manager = ResourceLimitConfig::default().to_manager();
        assert!(manager.is_enabled());
        assert_eq!(manager.request_limits.max_total_size(), 10_485_760);
        assert_eq!(manager.response_limits.max_total_size(), 50_000_000);
    }

    #[test]
    fn test_to_manager_with_memory_tracking() {
        let mut config = ResourceLimitConfig::default();
        config.memory.enable_memory_tracking = true;
        let manager = config.to_manager();
        let cache = manager
            .memory_cache
            .as_ref()
            .expect("memory cache should exist when tracking is enabled");
        assert_eq!(cache.max_size(), 104_857_600);
    }

    #[test]
    fn test_to_manager_without_memory_tracking() {
        let mut config = ResourceLimitConfig::default();
        config.memory.enable_memory_tracking = false;
        assert!(config.to_manager().memory_cache.is_none());
    }

    // ---------------------------------------------------------------------
    // Miscellaneous
    // ---------------------------------------------------------------------

    #[test]
    fn test_config_error_display() {
        let rendered = ConfigError::Validation("test error".to_string()).to_string();
        assert!(rendered.contains("test error"));
    }

    #[test]
    fn test_rate_limit_override() {
        let RateLimitOverride {
            client_pattern,
            burst,
            per_second,
        } = RateLimitOverride {
            client_pattern: "test-*".to_string(),
            burst: 500,
            per_second: 250,
        };
        assert_eq!(client_pattern, "test-*");
        assert_eq!(burst, 500);
        assert_eq!(per_second, 250);
    }
}