use parking_lot::Mutex;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::Arc;
use std::time::{Duration, SystemTime, UNIX_EPOCH};
/// Default per-window request quota (the authenticated `Core` quota).
pub const DEFAULT_RATE_LIMIT: u32 = 5_000;
/// Per-window request quota applied to every resource for unauthenticated callers.
pub const UNAUTHENTICATED_RATE_LIMIT: u32 = 60;
/// Category of API traffic that is metered in its own, independent bucket.
///
/// `RateLimiter` keys its windows by `(user_id, resource)`, so each variant has a
/// separate quota per user. Serialized in `snake_case`.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq, Deserialize, Serialize)]
#[serde(rename_all = "snake_case")]
pub enum RateLimitResource {
    /// Serialized as `"core"`.
    Core,
    /// Serialized as `"search"`.
    Search,
    /// Serialized as `"graphql"`.
    Graphql,
    /// Serialized as `"git"`.
    Git,
    /// Serialized as `"code_scanning"`.
    CodeScanning,
}
impl RateLimitResource {
pub fn default_limit(&self, authenticated: bool) -> u32 {
if !authenticated {
return UNAUTHENTICATED_RATE_LIMIT;
}
match self {
Self::Core => 5000,
Self::Search => 30,
Self::Graphql => 5000,
Self::Git => 5000,
Self::CodeScanning => 1000,
}
}
pub fn reset_interval(&self) -> Duration {
match self {
Self::Core => Duration::from_secs(3600), Self::Search => Duration::from_secs(60), Self::Graphql => Duration::from_secs(3600), Self::Git => Duration::from_secs(3600), Self::CodeScanning => Duration::from_secs(3600), }
}
}
impl std::fmt::Display for RateLimitResource {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Self::Core => write!(f, "core"),
Self::Search => write!(f, "search"),
Self::Graphql => write!(f, "graphql"),
Self::Git => write!(f, "git"),
Self::CodeScanning => write!(f, "code_scanning"),
}
}
}
/// Counters for a single rate-limit window.
///
/// `RateLimiter` stores one of these per `(user_id, resource)` pair.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct RateLimitState {
    /// Maximum number of requests allowed in one window.
    pub limit: u32,
    /// Requests still available in the current window.
    pub remaining: u32,
    /// Unix timestamp (whole seconds) at which the current window ends.
    pub reset: u64,
    /// Requests consumed in the current window.
    pub used: u32,
    /// Resource bucket this window meters.
    pub resource: RateLimitResource,
}
impl RateLimitState {
    /// Current wall-clock time in whole seconds since the Unix epoch.
    ///
    /// A system clock set before 1970 yields `0` rather than panicking.
    fn now_unix_secs() -> u64 {
        SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .unwrap_or_default()
            .as_secs()
    }

    /// Creates a full window of `limit` requests for `resource`.
    ///
    /// The window ends `resource.reset_interval()` from now.
    pub fn new(limit: u32, resource: RateLimitResource) -> Self {
        let reset = Self::now_unix_secs() + resource.reset_interval().as_secs();
        Self {
            limit,
            remaining: limit,
            reset,
            used: 0,
            resource,
        }
    }

    /// Returns `true` when no requests remain and the window has not yet expired.
    pub fn is_exceeded(&self) -> bool {
        self.remaining == 0 && !self.is_reset()
    }

    /// Returns `true` once the current time has reached the `reset` timestamp.
    pub fn is_reset(&self) -> bool {
        Self::now_unix_secs() >= self.reset
    }

    /// Tries to spend one request.
    ///
    /// If the window has expired, it is rolled over first. Returns `true` on
    /// success. Returns `false` without touching the counters when the window
    /// is exhausted.
    pub fn consume(&mut self) -> bool {
        if self.is_reset() {
            self.reset_window();
        }
        if self.remaining == 0 {
            return false;
        }
        self.remaining -= 1;
        self.used += 1;
        true
    }

    /// Starts a new window.
    ///
    /// Restores `remaining` to `limit`, clears `used`, and moves `reset` one
    /// `reset_interval()` past now.
    pub fn reset_window(&mut self) {
        self.remaining = self.limit;
        self.used = 0;
        self.reset = Self::now_unix_secs() + self.resource.reset_interval().as_secs();
    }

    /// Seconds remaining until `reset`, or `0` if that moment has already passed.
    pub fn time_until_reset(&self) -> u64 {
        self.reset.saturating_sub(Self::now_unix_secs())
    }
}
/// Body of a rate-limit status response.
///
/// Contains a per-resource breakdown plus a top-level `rate` summary.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct RateLimitResponse {
    /// Per-resource details.
    pub resources: RateLimitResources,
    /// Summary window; `RateLimiter::get_response` fills it from the `core` resource.
    pub rate: RateLimitInfo,
}

/// Rate-limit details for each resource reported in a [`RateLimitResponse`].
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct RateLimitResources {
    /// The `core` bucket.
    pub core: RateLimitInfo,
    /// The `search` bucket.
    pub search: RateLimitInfo,
    /// The `graphql` bucket.
    pub graphql: RateLimitInfo,
}

/// Serializable view of one window's counters (a [`RateLimitState`] without its resource tag).
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct RateLimitInfo {
    /// Maximum requests per window.
    pub limit: u32,
    /// Requests still available in the window.
    pub remaining: u32,
    /// Unix timestamp (seconds) at which the window resets.
    pub reset: u64,
    /// Requests consumed in the window.
    pub used: u32,
}
impl From<&RateLimitState> for RateLimitInfo {
fn from(state: &RateLimitState) -> Self {
Self {
limit: state.limit,
remaining: state.remaining,
reset: state.reset,
used: state.used,
}
}
}
/// Rate-limit values pre-rendered as strings for placement in response headers.
///
/// NOTE(review): the actual header names are chosen by the caller, not here.
#[derive(Clone, Debug, Default)]
pub struct RateLimitHeaders {
    /// Window quota, as a decimal string.
    pub limit: String,
    /// Requests left in the window, as a decimal string.
    pub remaining: String,
    /// Reset timestamp, as a decimal string.
    pub reset: String,
    /// Requests consumed in the window, as a decimal string.
    pub used: String,
    /// Resource name.
    pub resource: String,
}
impl From<&RateLimitState> for RateLimitHeaders {
fn from(state: &RateLimitState) -> Self {
Self {
limit: state.limit.to_string(),
remaining: state.remaining.to_string(),
reset: state.reset.to_string(),
used: state.used.to_string(),
resource: state.resource.to_string(),
}
}
}
/// Table of rate-limit windows keyed by `(user_id, resource)`, guarded by a mutex.
///
/// The table sits behind an `Arc`, so clones of a `RateLimiter` share it.
#[derive(Clone, Debug)]
pub struct RateLimiter {
    /// Shared window table; every access goes through the mutex.
    states: Arc<Mutex<HashMap<(String, RateLimitResource), RateLimitState>>>,
}

impl Default for RateLimiter {
    /// Same as [`RateLimiter::new`]: a limiter with no tracked windows.
    fn default() -> Self {
        RateLimiter::new()
    }
}
impl RateLimiter {
pub fn new() -> Self {
Self {
states: Arc::new(Mutex::new(HashMap::new())),
}
}
pub fn get_state(
&self,
user_id: &str,
resource: RateLimitResource,
authenticated: bool,
) -> RateLimitState {
let mut states = self.states.lock();
let key = (user_id.to_string(), resource);
states
.entry(key)
.or_insert_with(|| {
let limit = resource.default_limit(authenticated);
RateLimitState::new(limit, resource)
})
.clone()
}
pub fn check_and_consume(
&self,
user_id: &str,
resource: RateLimitResource,
authenticated: bool,
) -> Option<RateLimitState> {
let mut states = self.states.lock();
let key = (user_id.to_string(), resource);
let state = states.entry(key).or_insert_with(|| {
let limit = resource.default_limit(authenticated);
RateLimitState::new(limit, resource)
});
if state.consume() {
Some(state.clone())
} else {
None
}
}
pub fn get_response(&self, user_id: &str, authenticated: bool) -> RateLimitResponse {
let core = self.get_state(user_id, RateLimitResource::Core, authenticated);
let search = self.get_state(user_id, RateLimitResource::Search, authenticated);
let graphql = self.get_state(user_id, RateLimitResource::Graphql, authenticated);
RateLimitResponse {
resources: RateLimitResources {
core: (&core).into(),
search: (&search).into(),
graphql: (&graphql).into(),
},
rate: (&core).into(),
}
}
pub fn cleanup(&self) {
let now = SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap()
.as_secs();
let mut states = self.states.lock();
states.retain(|_, state| {
state.reset > now || state.used > 0
});
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A fresh window starts full; one `consume` moves one request from `remaining` to `used`.
    #[test]
    fn test_rate_limit_state() {
        let mut window = RateLimitState::new(100, RateLimitResource::Core);
        assert_eq!((window.limit, window.remaining, window.used), (100, 100, 0));
        assert!(!window.is_exceeded());

        assert!(window.consume());
        assert_eq!((window.remaining, window.used), (99, 1));
    }

    /// Once the quota is spent, further consumption fails and the window reports exceeded.
    #[test]
    fn test_rate_limit_exceeded() {
        let mut window = RateLimitState::new(2, RateLimitResource::Core);
        let outcomes: Vec<bool> = (0..3).map(|_| window.consume()).collect();
        assert_eq!(outcomes, [true, true, false]);
        assert!(window.is_exceeded());
    }

    /// A successful `check_and_consume` is visible through `get_state`.
    #[test]
    fn test_rate_limiter() {
        let limiter = RateLimiter::new();
        let consumed = limiter.check_and_consume("user1", RateLimitResource::Core, true);
        assert!(consumed.is_some());

        let snapshot = limiter.get_state("user1", RateLimitResource::Core, true);
        assert_eq!(snapshot.used, 1);
    }

    /// Unauthenticated callers are given the unauthenticated quota.
    #[test]
    fn test_unauthenticated_limit() {
        let anon = RateLimiter::new().get_state("anon", RateLimitResource::Core, false);
        assert_eq!(anon.limit, UNAUTHENTICATED_RATE_LIMIT);
    }

    /// Header values are the decimal / display renderings of the state.
    #[test]
    fn test_rate_limit_headers() {
        let window = RateLimitState::new(5000, RateLimitResource::Core);
        let headers: RateLimitHeaders = (&window).into();
        assert_eq!(headers.limit, "5000");
        assert_eq!(headers.remaining, "5000");
        assert_eq!(headers.resource, "core");
    }

    /// Spot-checks a few entries of the default quota table.
    #[test]
    fn test_resource_default_limits() {
        let cases = [
            (RateLimitResource::Core, true, 5000),
            (RateLimitResource::Search, true, 30),
            (RateLimitResource::Core, false, 60),
        ];
        for (resource, authenticated, expected) in cases {
            assert_eq!(
                resource.default_limit(authenticated),
                expected,
                "{} (authenticated = {})",
                resource,
                authenticated
            );
        }
    }
}