hedl_mcp/resource_limits/rate_limiter.rs
1// Dweve HEDL - Hierarchical Entity Data Language
2//
3// Copyright (c) 2025 Dweve IP B.V. and individual contributors.
4//
5// SPDX-License-Identifier: Apache-2.0
6//
7// Licensed under the Apache License, Version 2.0 (the "License");
8// you may not use this file except in compliance with the License.
9// You may obtain a copy of the License in the LICENSE file at the
10// root of this repository or at: http://www.apache.org/licenses/LICENSE-2.0
11//
12// Unless required by applicable law or agreed to in writing, software
13// distributed under the License is distributed on an "AS IS" BASIS,
14// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15// See the License for the specific language governing permissions and
16// limitations under the License.
17
18//! Per-client rate limiting using token bucket algorithm.
19
20use super::client::ClientId;
21use super::error::ResourceLimitError;
22use crate::rate_limiter::RateLimiter;
23use dashmap::DashMap;
24use std::sync::{Arc, Mutex};
25use std::time::{Duration, Instant};
26use tracing::debug;
27
/// Rate limit configuration for a client or client pattern.
///
/// Describes a token bucket: `burst` is the bucket capacity and `per_second`
/// is the number of tokens added back every second.
///
/// [`Default`] is implemented and yields the same value as
/// [`RateLimitConfig::default_config`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RateLimitConfig {
    /// Burst capacity (maximum tokens in bucket).
    pub burst: usize,
    /// Refill rate (tokens added per second).
    pub per_second: usize,
}

impl RateLimitConfig {
    /// Burst capacity used by [`RateLimitConfig::default_config`].
    const DEFAULT_BURST: usize = 200;

    /// Refill rate used by [`RateLimitConfig::default_config`].
    const DEFAULT_PER_SECOND: usize = 100;

    /// Create a new rate limit configuration.
    ///
    /// # Arguments
    ///
    /// * `burst` - Maximum number of tokens the bucket can hold
    /// * `per_second` - Number of tokens added to the bucket per second
    #[must_use]
    pub const fn new(burst: usize, per_second: usize) -> Self {
        Self { burst, per_second }
    }

    /// Get default rate limit configuration.
    ///
    /// The default allows a burst of 200 requests refilled at 100 per second.
    #[must_use]
    pub const fn default_config() -> Self {
        Self::new(Self::DEFAULT_BURST, Self::DEFAULT_PER_SECOND)
    }
}

impl Default for RateLimitConfig {
    /// Same value as [`RateLimitConfig::default_config`].
    fn default() -> Self {
        Self::default_config()
    }
}
53
54/// Per-client rate limiter using token bucket algorithm.
55///
56/// Tracks independent rate limits for each client, allowing fair resource
57/// distribution across multiple concurrent clients.
58pub struct PerClientRateLimiter {
59 /// Individual rate limiters per client ID.
60 limiters: Arc<DashMap<ClientId, RateLimiter>>,
61
62 /// Default rate limit configuration for unknown clients.
63 default_config: RateLimitConfig,
64
65 /// Client pattern overrides (glob pattern -> config).
66 overrides: Vec<(glob::Pattern, RateLimitConfig)>,
67
68 /// Last cleanup timestamp.
69 last_cleanup: Arc<Mutex<Instant>>,
70
71 /// Cleanup interval for inactive limiters.
72 cleanup_interval: Duration,
73}
74
75impl PerClientRateLimiter {
76 /// Create a new per-client rate limiter.
77 ///
78 /// # Arguments
79 ///
80 /// * `default_config` - Default rate limit for unknown clients
81 /// * `overrides` - Client pattern overrides (glob pattern -> config)
82 /// * `cleanup_interval` - How often to clean up inactive limiters
83 #[must_use]
84 pub fn new(
85 default_config: RateLimitConfig,
86 overrides: Vec<(String, RateLimitConfig)>,
87 cleanup_interval: Duration,
88 ) -> Self {
89 // Parse glob patterns
90 let overrides = overrides
91 .into_iter()
92 .filter_map(|(pattern, config)| glob::Pattern::new(&pattern).ok().map(|p| (p, config)))
93 .collect();
94
95 Self {
96 limiters: Arc::new(DashMap::new()),
97 default_config,
98 overrides,
99 last_cleanup: Arc::new(Mutex::new(Instant::now())),
100 cleanup_interval,
101 }
102 }
103
104 /// Create a per-client rate limiter with default configuration.
105 #[must_use]
106 pub fn with_defaults() -> Self {
107 Self::new(
108 RateLimitConfig::default_config(),
109 vec![],
110 Duration::from_secs(300),
111 )
112 }
113
114 /// Check if a request from the given client is allowed.
115 ///
116 /// # Arguments
117 ///
118 /// * `client_id` - Client identifier
119 ///
120 /// # Returns
121 ///
122 /// `Ok(())` if request is allowed, `Err` if rate limit exceeded.
123 pub fn check_limit(&self, client_id: &ClientId) -> Result<(), ResourceLimitError> {
124 // Get or create rate limiter for this client
125 let limiter = self.limiters.entry(client_id.clone()).or_insert_with(|| {
126 let config = self.get_config_for_client(client_id);
127 RateLimiter::new(config.burst, config.per_second)
128 });
129
130 // Check limit
131 if !limiter.check_limit() {
132 return Err(ResourceLimitError::RateLimitExceeded {
133 client_id: client_id.to_string(),
134 burst: limiter.max_tokens(),
135 rate: limiter.refill_rate(),
136 });
137 }
138
139 // Periodic cleanup of inactive limiters
140 self.maybe_cleanup();
141
142 Ok(())
143 }
144
145 /// Get rate limit configuration for a specific client.
146 ///
147 /// Checks client pattern overrides and returns the first matching config,
148 /// otherwise returns the default config.
149 fn get_config_for_client(&self, client_id: &ClientId) -> RateLimitConfig {
150 for (pattern, config) in &self.overrides {
151 if pattern.matches(&client_id.0) {
152 return config.clone();
153 }
154 }
155 self.default_config.clone()
156 }
157
158 /// Perform cleanup if enough time has passed.
159 ///
160 /// Removes inactive client limiters to prevent unbounded memory growth.
161 fn maybe_cleanup(&self) {
162 let now = Instant::now();
163
164 // Check if cleanup is needed
165 {
166 let last = self
167 .last_cleanup
168 .lock()
169 .unwrap_or_else(std::sync::PoisonError::into_inner);
170 let elapsed = now.duration_since(*last);
171 if elapsed <= self.cleanup_interval {
172 return;
173 }
174 }
175
176 // Clean up limiters not used recently (10 minutes)
177 // Note: For now we keep all limiters since RateLimiter doesn't track last_used
178 // In production, you'd add last_used tracking to RateLimiter
179
180 // Update last cleanup time
181 let mut last = self
182 .last_cleanup
183 .lock()
184 .unwrap_or_else(std::sync::PoisonError::into_inner);
185 *last = now;
186
187 debug!("Cleaned up inactive rate limiters");
188 }
189
190 /// Get the number of active client limiters.
191 #[must_use]
192 pub fn active_limiter_count(&self) -> usize {
193 self.limiters.len()
194 }
195
196 /// Reset all rate limiters (useful for testing).
197 pub fn reset_all(&self) {
198 self.limiters.clear();
199 }
200
201 /// Remove rate limiter for a specific client.
202 pub fn remove_client(&self, client_id: &ClientId) {
203 self.limiters.remove(client_id);
204 }
205}