Skip to main content

guts_compat/
rate_limit.rs

1//! Rate limiting with GitHub-compatible headers.
2
3use parking_lot::Mutex;
4use serde::{Deserialize, Serialize};
5use std::collections::HashMap;
6use std::sync::Arc;
7use std::time::{Duration, SystemTime, UNIX_EPOCH};
8
/// Request budget per hour granted by default (authenticated core API).
pub const DEFAULT_RATE_LIMIT: u32 = 5_000;

/// Request budget applied to callers that are not authenticated.
pub const UNAUTHENTICATED_RATE_LIMIT: u32 = 60;
14
15/// Rate limit resource types.
16#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
17#[serde(rename_all = "snake_case")]
18pub enum RateLimitResource {
19    /// Core API endpoints.
20    Core,
21    /// Search endpoints (lower limit).
22    Search,
23    /// GraphQL endpoint.
24    Graphql,
25    /// Git operations (clone, fetch, push).
26    Git,
27    /// Code scanning.
28    CodeScanning,
29}
30
31impl RateLimitResource {
32    /// Get the default limit for this resource.
33    pub fn default_limit(&self, authenticated: bool) -> u32 {
34        if !authenticated {
35            return UNAUTHENTICATED_RATE_LIMIT;
36        }
37
38        match self {
39            Self::Core => 5000,
40            Self::Search => 30,
41            Self::Graphql => 5000,
42            Self::Git => 5000,
43            Self::CodeScanning => 1000,
44        }
45    }
46
47    /// Get the reset interval for this resource.
48    pub fn reset_interval(&self) -> Duration {
49        match self {
50            Self::Core => Duration::from_secs(3600),         // 1 hour
51            Self::Search => Duration::from_secs(60),         // 1 minute
52            Self::Graphql => Duration::from_secs(3600),      // 1 hour
53            Self::Git => Duration::from_secs(3600),          // 1 hour
54            Self::CodeScanning => Duration::from_secs(3600), // 1 hour
55        }
56    }
57}
58
59impl std::fmt::Display for RateLimitResource {
60    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
61        match self {
62            Self::Core => write!(f, "core"),
63            Self::Search => write!(f, "search"),
64            Self::Graphql => write!(f, "graphql"),
65            Self::Git => write!(f, "git"),
66            Self::CodeScanning => write!(f, "code_scanning"),
67        }
68    }
69}
70
71/// Rate limit state for a specific user/resource.
72#[derive(Debug, Clone, Serialize, Deserialize)]
73pub struct RateLimitState {
74    /// Maximum requests allowed.
75    pub limit: u32,
76    /// Remaining requests in current window.
77    pub remaining: u32,
78    /// Unix timestamp when the limit resets.
79    pub reset: u64,
80    /// Requests used in current window.
81    pub used: u32,
82    /// Resource type.
83    pub resource: RateLimitResource,
84}
85
86impl RateLimitState {
87    /// Create a new rate limit state.
88    pub fn new(limit: u32, resource: RateLimitResource) -> Self {
89        let reset = SystemTime::now()
90            .duration_since(UNIX_EPOCH)
91            .unwrap()
92            .as_secs()
93            + resource.reset_interval().as_secs();
94
95        Self {
96            limit,
97            remaining: limit,
98            reset,
99            used: 0,
100            resource,
101        }
102    }
103
104    /// Check if the rate limit is exceeded.
105    pub fn is_exceeded(&self) -> bool {
106        self.remaining == 0 && !self.is_reset()
107    }
108
109    /// Check if the window has reset.
110    pub fn is_reset(&self) -> bool {
111        let now = SystemTime::now()
112            .duration_since(UNIX_EPOCH)
113            .unwrap()
114            .as_secs();
115        now >= self.reset
116    }
117
118    /// Consume one request from the limit.
119    pub fn consume(&mut self) -> bool {
120        // Reset if window expired
121        if self.is_reset() {
122            self.reset_window();
123        }
124
125        if self.remaining > 0 {
126            self.remaining -= 1;
127            self.used += 1;
128            true
129        } else {
130            false
131        }
132    }
133
134    /// Reset the window.
135    pub fn reset_window(&mut self) {
136        self.remaining = self.limit;
137        self.used = 0;
138        self.reset = SystemTime::now()
139            .duration_since(UNIX_EPOCH)
140            .unwrap()
141            .as_secs()
142            + self.resource.reset_interval().as_secs();
143    }
144
145    /// Get the time until reset in seconds.
146    pub fn time_until_reset(&self) -> u64 {
147        let now = SystemTime::now()
148            .duration_since(UNIX_EPOCH)
149            .unwrap()
150            .as_secs();
151        self.reset.saturating_sub(now)
152    }
153}
154
155/// Rate limit response for API.
156#[derive(Debug, Clone, Serialize, Deserialize)]
157pub struct RateLimitResponse {
158    /// Resources with their limits.
159    pub resources: RateLimitResources,
160    /// Rate limit state for the primary resource.
161    pub rate: RateLimitInfo,
162}
163
164/// Rate limit resources in response.
165#[derive(Debug, Clone, Serialize, Deserialize)]
166pub struct RateLimitResources {
167    /// Core API limit.
168    pub core: RateLimitInfo,
169    /// Search API limit.
170    pub search: RateLimitInfo,
171    /// GraphQL API limit.
172    pub graphql: RateLimitInfo,
173}
174
175/// Rate limit info for a single resource.
176#[derive(Debug, Clone, Serialize, Deserialize)]
177pub struct RateLimitInfo {
178    /// Maximum requests allowed.
179    pub limit: u32,
180    /// Remaining requests.
181    pub remaining: u32,
182    /// Unix timestamp when limit resets.
183    pub reset: u64,
184    /// Requests used.
185    pub used: u32,
186}
187
188impl From<&RateLimitState> for RateLimitInfo {
189    fn from(state: &RateLimitState) -> Self {
190        Self {
191            limit: state.limit,
192            remaining: state.remaining,
193            reset: state.reset,
194            used: state.used,
195        }
196    }
197}
198
/// Pre-rendered `X-RateLimit-*` header values for an HTTP response.
///
/// The `Default` value has an empty string for every header.
#[derive(Clone, Debug, Default)]
pub struct RateLimitHeaders {
    /// Value for `X-RateLimit-Limit`.
    pub limit: String,
    /// Value for `X-RateLimit-Remaining`.
    pub remaining: String,
    /// Value for `X-RateLimit-Reset`.
    pub reset: String,
    /// Value for `X-RateLimit-Used`.
    pub used: String,
    /// Value for `X-RateLimit-Resource`.
    pub resource: String,
}
213
214impl From<&RateLimitState> for RateLimitHeaders {
215    fn from(state: &RateLimitState) -> Self {
216        Self {
217            limit: state.limit.to_string(),
218            remaining: state.remaining.to_string(),
219            reset: state.reset.to_string(),
220            used: state.used.to_string(),
221            resource: state.resource.to_string(),
222        }
223    }
224}
225
226/// Rate limiter that tracks limits per user and resource.
227#[derive(Debug, Clone)]
228pub struct RateLimiter {
229    /// User states keyed by (user_id, resource).
230    states: Arc<Mutex<HashMap<(String, RateLimitResource), RateLimitState>>>,
231}
232
233impl Default for RateLimiter {
234    fn default() -> Self {
235        Self::new()
236    }
237}
238
239impl RateLimiter {
240    /// Create a new rate limiter.
241    pub fn new() -> Self {
242        Self {
243            states: Arc::new(Mutex::new(HashMap::new())),
244        }
245    }
246
247    /// Get or create a rate limit state for a user/resource.
248    pub fn get_state(
249        &self,
250        user_id: &str,
251        resource: RateLimitResource,
252        authenticated: bool,
253    ) -> RateLimitState {
254        let mut states = self.states.lock();
255        let key = (user_id.to_string(), resource);
256
257        states
258            .entry(key)
259            .or_insert_with(|| {
260                let limit = resource.default_limit(authenticated);
261                RateLimitState::new(limit, resource)
262            })
263            .clone()
264    }
265
266    /// Check and consume a request for a user/resource.
267    ///
268    /// Returns the updated state if allowed, or None if rate limited.
269    pub fn check_and_consume(
270        &self,
271        user_id: &str,
272        resource: RateLimitResource,
273        authenticated: bool,
274    ) -> Option<RateLimitState> {
275        let mut states = self.states.lock();
276        let key = (user_id.to_string(), resource);
277
278        let state = states.entry(key).or_insert_with(|| {
279            let limit = resource.default_limit(authenticated);
280            RateLimitState::new(limit, resource)
281        });
282
283        if state.consume() {
284            Some(state.clone())
285        } else {
286            None
287        }
288    }
289
290    /// Get the rate limit response for a user.
291    pub fn get_response(&self, user_id: &str, authenticated: bool) -> RateLimitResponse {
292        let core = self.get_state(user_id, RateLimitResource::Core, authenticated);
293        let search = self.get_state(user_id, RateLimitResource::Search, authenticated);
294        let graphql = self.get_state(user_id, RateLimitResource::Graphql, authenticated);
295
296        RateLimitResponse {
297            resources: RateLimitResources {
298                core: (&core).into(),
299                search: (&search).into(),
300                graphql: (&graphql).into(),
301            },
302            rate: (&core).into(),
303        }
304    }
305
306    /// Clean up expired states to free memory.
307    pub fn cleanup(&self) {
308        let now = SystemTime::now()
309            .duration_since(UNIX_EPOCH)
310            .unwrap()
311            .as_secs();
312
313        let mut states = self.states.lock();
314        states.retain(|_, state| {
315            // Keep states that haven't expired yet or have been used
316            state.reset > now || state.used > 0
317        });
318    }
319}
320
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_rate_limit_state() {
        let mut core = RateLimitState::new(100, RateLimitResource::Core);

        assert_eq!((core.limit, core.remaining, core.used), (100, 100, 0));
        assert!(!core.is_exceeded());

        // One request comes off the remaining budget and is counted as used.
        assert!(core.consume());
        assert_eq!((core.remaining, core.used), (99, 1));
    }

    #[test]
    fn test_rate_limit_exceeded() {
        let mut tiny = RateLimitState::new(2, RateLimitResource::Core);

        // Two requests fit in the budget; the third is rejected.
        let outcomes: Vec<bool> = (0..3).map(|_| tiny.consume()).collect();
        assert_eq!(outcomes, [true, true, false]);

        assert!(tiny.is_exceeded());
    }

    #[test]
    fn test_rate_limiter() {
        let limiter = RateLimiter::new();

        let first = limiter.check_and_consume("user1", RateLimitResource::Core, true);
        assert!(first.is_some(), "first request must be allowed");

        let snapshot = limiter.get_state("user1", RateLimitResource::Core, true);
        assert_eq!(snapshot.used, 1);
    }

    #[test]
    fn test_unauthenticated_limit() {
        let anon = RateLimiter::new().get_state("anon", RateLimitResource::Core, false);

        assert_eq!(anon.limit, UNAUTHENTICATED_RATE_LIMIT);
    }

    #[test]
    fn test_rate_limit_headers() {
        let headers = RateLimitHeaders::from(&RateLimitState::new(5000, RateLimitResource::Core));

        assert_eq!(headers.limit, "5000");
        assert_eq!(headers.remaining, "5000");
        assert_eq!(headers.resource, "core");
    }

    #[test]
    fn test_resource_default_limits() {
        let cases = [
            (RateLimitResource::Core, true, 5000),
            (RateLimitResource::Search, true, 30),
            (RateLimitResource::Core, false, 60),
        ];

        for (resource, authenticated, expected) in cases {
            assert_eq!(
                resource.default_limit(authenticated),
                expected,
                "{} (authenticated: {})",
                resource,
                authenticated
            );
        }
    }
}
384        assert_eq!(RateLimitResource::Core.default_limit(true), 5000);
385        assert_eq!(RateLimitResource::Search.default_limit(true), 30);
386        assert_eq!(RateLimitResource::Core.default_limit(false), 60);
387    }
388}