1use parking_lot::Mutex;
4use serde::{Deserialize, Serialize};
5use std::collections::HashMap;
6use std::sync::Arc;
7use std::time::{Duration, SystemTime, UNIX_EPOCH};
8
/// Default request quota per rate-limit window.
pub const DEFAULT_RATE_LIMIT: u32 = 5_000;

/// Request quota handed to callers that are not authenticated.
pub const UNAUTHENTICATED_RATE_LIMIT: u32 = 60;
14
15#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
17#[serde(rename_all = "snake_case")]
18pub enum RateLimitResource {
19 Core,
21 Search,
23 Graphql,
25 Git,
27 CodeScanning,
29}
30
31impl RateLimitResource {
32 pub fn default_limit(&self, authenticated: bool) -> u32 {
34 if !authenticated {
35 return UNAUTHENTICATED_RATE_LIMIT;
36 }
37
38 match self {
39 Self::Core => 5000,
40 Self::Search => 30,
41 Self::Graphql => 5000,
42 Self::Git => 5000,
43 Self::CodeScanning => 1000,
44 }
45 }
46
47 pub fn reset_interval(&self) -> Duration {
49 match self {
50 Self::Core => Duration::from_secs(3600), Self::Search => Duration::from_secs(60), Self::Graphql => Duration::from_secs(3600), Self::Git => Duration::from_secs(3600), Self::CodeScanning => Duration::from_secs(3600), }
56 }
57}
58
59impl std::fmt::Display for RateLimitResource {
60 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
61 match self {
62 Self::Core => write!(f, "core"),
63 Self::Search => write!(f, "search"),
64 Self::Graphql => write!(f, "graphql"),
65 Self::Git => write!(f, "git"),
66 Self::CodeScanning => write!(f, "code_scanning"),
67 }
68 }
69}
70
71#[derive(Debug, Clone, Serialize, Deserialize)]
73pub struct RateLimitState {
74 pub limit: u32,
76 pub remaining: u32,
78 pub reset: u64,
80 pub used: u32,
82 pub resource: RateLimitResource,
84}
85
86impl RateLimitState {
87 pub fn new(limit: u32, resource: RateLimitResource) -> Self {
89 let reset = SystemTime::now()
90 .duration_since(UNIX_EPOCH)
91 .unwrap()
92 .as_secs()
93 + resource.reset_interval().as_secs();
94
95 Self {
96 limit,
97 remaining: limit,
98 reset,
99 used: 0,
100 resource,
101 }
102 }
103
104 pub fn is_exceeded(&self) -> bool {
106 self.remaining == 0 && !self.is_reset()
107 }
108
109 pub fn is_reset(&self) -> bool {
111 let now = SystemTime::now()
112 .duration_since(UNIX_EPOCH)
113 .unwrap()
114 .as_secs();
115 now >= self.reset
116 }
117
118 pub fn consume(&mut self) -> bool {
120 if self.is_reset() {
122 self.reset_window();
123 }
124
125 if self.remaining > 0 {
126 self.remaining -= 1;
127 self.used += 1;
128 true
129 } else {
130 false
131 }
132 }
133
134 pub fn reset_window(&mut self) {
136 self.remaining = self.limit;
137 self.used = 0;
138 self.reset = SystemTime::now()
139 .duration_since(UNIX_EPOCH)
140 .unwrap()
141 .as_secs()
142 + self.resource.reset_interval().as_secs();
143 }
144
145 pub fn time_until_reset(&self) -> u64 {
147 let now = SystemTime::now()
148 .duration_since(UNIX_EPOCH)
149 .unwrap()
150 .as_secs();
151 self.reset.saturating_sub(now)
152 }
153}
154
155#[derive(Debug, Clone, Serialize, Deserialize)]
157pub struct RateLimitResponse {
158 pub resources: RateLimitResources,
160 pub rate: RateLimitInfo,
162}
163
164#[derive(Debug, Clone, Serialize, Deserialize)]
166pub struct RateLimitResources {
167 pub core: RateLimitInfo,
169 pub search: RateLimitInfo,
171 pub graphql: RateLimitInfo,
173}
174
175#[derive(Debug, Clone, Serialize, Deserialize)]
177pub struct RateLimitInfo {
178 pub limit: u32,
180 pub remaining: u32,
182 pub reset: u64,
184 pub used: u32,
186}
187
188impl From<&RateLimitState> for RateLimitInfo {
189 fn from(state: &RateLimitState) -> Self {
190 Self {
191 limit: state.limit,
192 remaining: state.remaining,
193 reset: state.reset,
194 used: state.used,
195 }
196 }
197}
198
/// Rate-limit figures rendered as strings.
///
/// `Default` yields all-empty strings.
#[derive(Debug, Clone, Default)]
pub struct RateLimitHeaders {
    /// Requests allowed per window, as decimal text.
    pub limit: String,
    /// Requests still available, as decimal text.
    pub remaining: String,
    /// Window end as a Unix timestamp in seconds, as decimal text.
    pub reset: String,
    /// Requests already consumed, as decimal text.
    pub used: String,
    /// Name of the resource the figures belong to.
    pub resource: String,
}
213
214impl From<&RateLimitState> for RateLimitHeaders {
215 fn from(state: &RateLimitState) -> Self {
216 Self {
217 limit: state.limit.to_string(),
218 remaining: state.remaining.to_string(),
219 reset: state.reset.to_string(),
220 used: state.used.to_string(),
221 resource: state.resource.to_string(),
222 }
223 }
224}
225
226#[derive(Debug, Clone)]
228pub struct RateLimiter {
229 states: Arc<Mutex<HashMap<(String, RateLimitResource), RateLimitState>>>,
231}
232
233impl Default for RateLimiter {
234 fn default() -> Self {
235 Self::new()
236 }
237}
238
239impl RateLimiter {
240 pub fn new() -> Self {
242 Self {
243 states: Arc::new(Mutex::new(HashMap::new())),
244 }
245 }
246
247 pub fn get_state(
249 &self,
250 user_id: &str,
251 resource: RateLimitResource,
252 authenticated: bool,
253 ) -> RateLimitState {
254 let mut states = self.states.lock();
255 let key = (user_id.to_string(), resource);
256
257 states
258 .entry(key)
259 .or_insert_with(|| {
260 let limit = resource.default_limit(authenticated);
261 RateLimitState::new(limit, resource)
262 })
263 .clone()
264 }
265
266 pub fn check_and_consume(
270 &self,
271 user_id: &str,
272 resource: RateLimitResource,
273 authenticated: bool,
274 ) -> Option<RateLimitState> {
275 let mut states = self.states.lock();
276 let key = (user_id.to_string(), resource);
277
278 let state = states.entry(key).or_insert_with(|| {
279 let limit = resource.default_limit(authenticated);
280 RateLimitState::new(limit, resource)
281 });
282
283 if state.consume() {
284 Some(state.clone())
285 } else {
286 None
287 }
288 }
289
290 pub fn get_response(&self, user_id: &str, authenticated: bool) -> RateLimitResponse {
292 let core = self.get_state(user_id, RateLimitResource::Core, authenticated);
293 let search = self.get_state(user_id, RateLimitResource::Search, authenticated);
294 let graphql = self.get_state(user_id, RateLimitResource::Graphql, authenticated);
295
296 RateLimitResponse {
297 resources: RateLimitResources {
298 core: (&core).into(),
299 search: (&search).into(),
300 graphql: (&graphql).into(),
301 },
302 rate: (&core).into(),
303 }
304 }
305
306 pub fn cleanup(&self) {
308 let now = SystemTime::now()
309 .duration_since(UNIX_EPOCH)
310 .unwrap()
311 .as_secs();
312
313 let mut states = self.states.lock();
314 states.retain(|_, state| {
315 state.reset > now || state.used > 0
317 });
318 }
319}
320
#[cfg(test)]
mod tests {
    use super::*;

    /// A fresh window starts full and a single consume moves one unit from
    /// `remaining` to `used`.
    #[test]
    fn test_rate_limit_state() {
        let mut window = RateLimitState::new(100, RateLimitResource::Core);

        assert_eq!(window.limit, 100);
        assert_eq!(window.remaining, 100);
        assert_eq!(window.used, 0);
        assert!(!window.is_exceeded());

        let allowed = window.consume();
        assert!(allowed);
        assert_eq!(window.remaining, 99);
        assert_eq!(window.used, 1);
    }

    /// Once the quota is spent, further consumes are refused and the window
    /// reports itself as exceeded.
    #[test]
    fn test_rate_limit_exceeded() {
        let mut window = RateLimitState::new(2, RateLimitResource::Core);

        assert!(window.consume());
        assert!(window.consume());
        assert!(!window.consume());
        assert!(window.is_exceeded());
    }

    /// Consuming through the limiter is visible in a later state lookup.
    #[test]
    fn test_rate_limiter() {
        let limiter = RateLimiter::new();

        let consumed = limiter.check_and_consume("user1", RateLimitResource::Core, true);
        assert!(consumed.is_some());

        let snapshot = limiter.get_state("user1", RateLimitResource::Core, true);
        assert_eq!(snapshot.used, 1);
    }

    /// Unauthenticated callers get the unauthenticated quota.
    #[test]
    fn test_unauthenticated_limit() {
        let limiter = RateLimiter::new();
        let snapshot = limiter.get_state("anon", RateLimitResource::Core, false);

        assert_eq!(snapshot.limit, UNAUTHENTICATED_RATE_LIMIT);
    }

    /// Header conversion stringifies the numeric fields and the resource name.
    #[test]
    fn test_rate_limit_headers() {
        let window = RateLimitState::new(5000, RateLimitResource::Core);
        let rendered = RateLimitHeaders::from(&window);

        assert_eq!(rendered.limit, "5000");
        assert_eq!(rendered.remaining, "5000");
        assert_eq!(rendered.resource, "core");
    }

    /// Default limits depend on the resource and on authentication.
    #[test]
    fn test_resource_default_limits() {
        assert_eq!(RateLimitResource::Core.default_limit(true), 5000);
        assert_eq!(RateLimitResource::Search.default_limit(true), 30);
        assert_eq!(RateLimitResource::Core.default_limit(false), 60);
    }
}