Skip to main content

hedl_mcp/resource_limits/
rate_limiter.rs

1// Dweve HEDL - Hierarchical Entity Data Language
2//
3// Copyright (c) 2025 Dweve IP B.V. and individual contributors.
4//
5// SPDX-License-Identifier: Apache-2.0
6//
7// Licensed under the Apache License, Version 2.0 (the "License");
8// you may not use this file except in compliance with the License.
9// You may obtain a copy of the License in the LICENSE file at the
10// root of this repository or at: http://www.apache.org/licenses/LICENSE-2.0
11//
12// Unless required by applicable law or agreed to in writing, software
13// distributed under the License is distributed on an "AS IS" BASIS,
14// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15// See the License for the specific language governing permissions and
16// limitations under the License.
17
18//! Per-client rate limiting using token bucket algorithm.
19
20use super::client::ClientId;
21use super::error::ResourceLimitError;
22use crate::rate_limiter::RateLimiter;
23use dashmap::DashMap;
24use std::sync::{Arc, Mutex};
25use std::time::{Duration, Instant};
26use tracing::debug;
27
/// Rate limit configuration for a client or client pattern.
///
/// A configuration is a plain pair of numbers: how many requests may be made
/// in a burst, and how many tokens are added back per second.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RateLimitConfig {
    /// Burst capacity (maximum tokens in bucket).
    pub burst: usize,
    /// Refill rate (tokens added per second).
    pub per_second: usize,
}

impl RateLimitConfig {
    /// Burst capacity used by [`RateLimitConfig::default_config`].
    const DEFAULT_BURST: usize = 200;
    /// Refill rate (tokens per second) used by [`RateLimitConfig::default_config`].
    const DEFAULT_PER_SECOND: usize = 100;

    /// Create a new rate limit configuration.
    ///
    /// # Arguments
    ///
    /// * `burst` - Maximum number of tokens the bucket can hold
    /// * `per_second` - Number of tokens added to the bucket per second
    #[must_use]
    pub fn new(burst: usize, per_second: usize) -> Self {
        Self { burst, per_second }
    }

    /// Get default rate limit configuration.
    ///
    /// The default allows a burst of 200 with a refill rate of 100 per second.
    /// The same value is available through the [`Default`] trait.
    #[must_use]
    pub fn default_config() -> Self {
        Self::new(Self::DEFAULT_BURST, Self::DEFAULT_PER_SECOND)
    }
}

impl Default for RateLimitConfig {
    /// Equivalent to [`RateLimitConfig::default_config`].
    fn default() -> Self {
        Self::default_config()
    }
}
53
54/// Per-client rate limiter using token bucket algorithm.
55///
56/// Tracks independent rate limits for each client, allowing fair resource
57/// distribution across multiple concurrent clients.
58pub struct PerClientRateLimiter {
59    /// Individual rate limiters per client ID.
60    limiters: Arc<DashMap<ClientId, RateLimiter>>,
61
62    /// Default rate limit configuration for unknown clients.
63    default_config: RateLimitConfig,
64
65    /// Client pattern overrides (glob pattern -> config).
66    overrides: Vec<(glob::Pattern, RateLimitConfig)>,
67
68    /// Last cleanup timestamp.
69    last_cleanup: Arc<Mutex<Instant>>,
70
71    /// Cleanup interval for inactive limiters.
72    cleanup_interval: Duration,
73}
74
75impl PerClientRateLimiter {
76    /// Create a new per-client rate limiter.
77    ///
78    /// # Arguments
79    ///
80    /// * `default_config` - Default rate limit for unknown clients
81    /// * `overrides` - Client pattern overrides (glob pattern -> config)
82    /// * `cleanup_interval` - How often to clean up inactive limiters
83    #[must_use]
84    pub fn new(
85        default_config: RateLimitConfig,
86        overrides: Vec<(String, RateLimitConfig)>,
87        cleanup_interval: Duration,
88    ) -> Self {
89        // Parse glob patterns
90        let overrides = overrides
91            .into_iter()
92            .filter_map(|(pattern, config)| glob::Pattern::new(&pattern).ok().map(|p| (p, config)))
93            .collect();
94
95        Self {
96            limiters: Arc::new(DashMap::new()),
97            default_config,
98            overrides,
99            last_cleanup: Arc::new(Mutex::new(Instant::now())),
100            cleanup_interval,
101        }
102    }
103
104    /// Create a per-client rate limiter with default configuration.
105    #[must_use]
106    pub fn with_defaults() -> Self {
107        Self::new(
108            RateLimitConfig::default_config(),
109            vec![],
110            Duration::from_secs(300),
111        )
112    }
113
114    /// Check if a request from the given client is allowed.
115    ///
116    /// # Arguments
117    ///
118    /// * `client_id` - Client identifier
119    ///
120    /// # Returns
121    ///
122    /// `Ok(())` if request is allowed, `Err` if rate limit exceeded.
123    pub fn check_limit(&self, client_id: &ClientId) -> Result<(), ResourceLimitError> {
124        // Get or create rate limiter for this client
125        let limiter = self.limiters.entry(client_id.clone()).or_insert_with(|| {
126            let config = self.get_config_for_client(client_id);
127            RateLimiter::new(config.burst, config.per_second)
128        });
129
130        // Check limit
131        if !limiter.check_limit() {
132            return Err(ResourceLimitError::RateLimitExceeded {
133                client_id: client_id.to_string(),
134                burst: limiter.max_tokens(),
135                rate: limiter.refill_rate(),
136            });
137        }
138
139        // Periodic cleanup of inactive limiters
140        self.maybe_cleanup();
141
142        Ok(())
143    }
144
145    /// Get rate limit configuration for a specific client.
146    ///
147    /// Checks client pattern overrides and returns the first matching config,
148    /// otherwise returns the default config.
149    fn get_config_for_client(&self, client_id: &ClientId) -> RateLimitConfig {
150        for (pattern, config) in &self.overrides {
151            if pattern.matches(&client_id.0) {
152                return config.clone();
153            }
154        }
155        self.default_config.clone()
156    }
157
158    /// Perform cleanup if enough time has passed.
159    ///
160    /// Removes inactive client limiters to prevent unbounded memory growth.
161    fn maybe_cleanup(&self) {
162        let now = Instant::now();
163
164        // Check if cleanup is needed
165        {
166            let last = self
167                .last_cleanup
168                .lock()
169                .unwrap_or_else(std::sync::PoisonError::into_inner);
170            let elapsed = now.duration_since(*last);
171            if elapsed <= self.cleanup_interval {
172                return;
173            }
174        }
175
176        // Clean up limiters not used recently (10 minutes)
177        // Note: For now we keep all limiters since RateLimiter doesn't track last_used
178        // In production, you'd add last_used tracking to RateLimiter
179
180        // Update last cleanup time
181        let mut last = self
182            .last_cleanup
183            .lock()
184            .unwrap_or_else(std::sync::PoisonError::into_inner);
185        *last = now;
186
187        debug!("Cleaned up inactive rate limiters");
188    }
189
190    /// Get the number of active client limiters.
191    #[must_use]
192    pub fn active_limiter_count(&self) -> usize {
193        self.limiters.len()
194    }
195
196    /// Reset all rate limiters (useful for testing).
197    pub fn reset_all(&self) {
198        self.limiters.clear();
199    }
200
201    /// Remove rate limiter for a specific client.
202    pub fn remove_client(&self, client_id: &ClientId) {
203        self.limiters.remove(client_id);
204    }
205}