riglr_core/util/
rate_limit_strategy.rs

1//! Rate limiting strategies for flexible request rate control
2//!
3//! This module provides trait-based abstractions for different rate limiting algorithms,
4//! allowing users to choose the most appropriate strategy for their use case.
5
6use crate::ToolError;
7use std::time::{Duration, Instant};
8
9/// Trait defining the interface for rate limiting strategies
10///
11/// Different strategies can implement this trait to provide various
12/// rate limiting algorithms such as token bucket, fixed window, sliding window, etc.
13pub trait RateLimitStrategy: Send + Sync {
14    /// Check if a request should be allowed for the given client
15    ///
16    /// # Arguments
17    /// * `client_id` - Unique identifier for the client
18    ///
19    /// # Returns
20    /// * `Ok(())` if the request is allowed
21    /// * `Err(ToolError::RateLimited)` if the rate limit is exceeded
22    fn check_rate_limit(&self, client_id: &str) -> Result<(), ToolError>;
23
24    /// Reset rate limit state for a specific client
25    fn reset_client(&self, client_id: &str);
26
27    /// Clear all rate limit data
28    fn clear_all(&self);
29
30    /// Get current request count for a client
31    fn get_request_count(&self, client_id: &str) -> usize;
32
33    /// Get strategy name for debugging/logging
34    fn strategy_name(&self) -> &str;
35}
36
/// Per-client rate limit state: a history of recent requests plus a token
/// balance that can be refilled over time.
#[derive(Debug, Clone)]
pub struct ClientRateInfo {
    /// Instants at which this client's recent requests were recorded.
    pub request_times: Vec<Instant>,
    /// Tokens currently available for bursting. Stored as `f64` so partial
    /// tokens can accumulate between refills.
    pub burst_tokens: f64,
    /// Instant at which tokens were most recently refilled.
    pub last_refill: Instant,
}

impl ClientRateInfo {
    /// Builds a fresh record.
    ///
    /// The record has an empty request history, `initial_tokens` burst tokens,
    /// and a refill timestamp of the moment of construction.
    pub fn new(initial_tokens: f64) -> Self {
        let created_at = Instant::now();
        Self {
            last_refill: created_at,
            burst_tokens: initial_tokens,
            request_times: Vec::new(),
        }
    }
}
58
59/// Fixed window rate limiting strategy
60///
61/// This strategy divides time into fixed windows and allows a fixed number
62/// of requests per window. When a window expires, the count resets.
63#[derive(Debug)]
64pub struct FixedWindowStrategy {
65    /// Maximum requests per window
66    pub max_requests: usize,
67    /// Duration of each window
68    pub window_duration: Duration,
69    /// Client tracking
70    pub clients: dashmap::DashMap<String, FixedWindowClientInfo>,
71}
72
/// Window state tracked for one client by the fixed window strategy.
#[derive(Debug, Clone)]
pub struct FixedWindowClientInfo {
    /// Instant at which the client's current window began.
    pub window_start: Instant,
    /// Requests admitted so far within the current window.
    pub request_count: usize,
}
81
82impl RateLimitStrategy for FixedWindowStrategy {
83    fn check_rate_limit(&self, client_id: &str) -> Result<(), ToolError> {
84        let now = Instant::now();
85        let mut entry = self
86            .clients
87            .entry(client_id.to_string())
88            .or_insert_with(|| FixedWindowClientInfo {
89                window_start: now,
90                request_count: 0,
91            });
92
93        // Check if we're in a new window
94        if now.duration_since(entry.window_start) >= self.window_duration {
95            // Reset for new window
96            entry.window_start = now;
97            entry.request_count = 0;
98        }
99
100        // Check if limit exceeded
101        if entry.request_count >= self.max_requests {
102            let time_until_reset = self
103                .window_duration
104                .saturating_sub(now.duration_since(entry.window_start));
105
106            return Err(ToolError::RateLimited {
107                source: None,
108                source_message: format!(
109                    "Fixed window rate limit: {} requests per {:?}",
110                    self.max_requests, self.window_duration
111                ),
112                context: format!("Exceeded {} requests in current window", self.max_requests),
113                retry_after: Some(time_until_reset),
114            });
115        }
116
117        entry.request_count += 1;
118        Ok(())
119    }
120
121    fn reset_client(&self, client_id: &str) {
122        self.clients.remove(client_id);
123    }
124
125    fn clear_all(&self) {
126        self.clients.clear();
127    }
128
129    fn get_request_count(&self, client_id: &str) -> usize {
130        self.clients
131            .get(client_id)
132            .map(|entry| entry.request_count)
133            .unwrap_or(0)
134    }
135
136    fn strategy_name(&self) -> &str {
137        "FixedWindow"
138    }
139}
140
141impl FixedWindowStrategy {
142    /// Create a new fixed window rate limiter
143    pub fn new(max_requests: usize, window_duration: Duration) -> Self {
144        Self {
145            max_requests,
146            window_duration,
147            clients: dashmap::DashMap::new(),
148        }
149    }
150}