use std::str::FromStr;
use std::time::Duration;
use chrono::{DateTime, Utc};
use crate::ForgeError;
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub enum RateLimitKey {
#[default]
User,
Ip,
Tenant,
UserAction,
Global,
}
impl RateLimitKey {
pub fn as_str(&self) -> &'static str {
match self {
Self::User => "user",
Self::Ip => "ip",
Self::Tenant => "tenant",
Self::UserAction => "user_action",
Self::Global => "global",
}
}
}
impl FromStr for RateLimitKey {
type Err = std::convert::Infallible;
fn from_str(s: &str) -> Result<Self, Self::Err> {
Ok(match s {
"user" => Self::User,
"ip" => Self::Ip,
"tenant" => Self::Tenant,
"user_action" => Self::UserAction,
"global" => Self::Global,
_ => Self::User,
})
}
}
#[derive(Debug, Clone)]
pub struct RateLimitConfig {
pub requests: u32,
pub per: Duration,
pub key: RateLimitKey,
}
impl RateLimitConfig {
pub fn new(requests: u32, per: Duration) -> Self {
Self {
requests,
per,
key: RateLimitKey::default(),
}
}
pub fn with_key(mut self, key: RateLimitKey) -> Self {
self.key = key;
self
}
pub fn refill_rate(&self) -> f64 {
self.requests as f64 / self.per.as_secs_f64()
}
pub fn parse_duration(s: &str) -> Option<Duration> {
let s = s.trim();
if s.is_empty() {
return None;
}
let (num_str, unit) = s.split_at(s.len() - 1);
let num: u64 = num_str.parse().ok()?;
match unit {
"s" => Some(Duration::from_secs(num)),
"m" => Some(Duration::from_secs(num * 60)),
"h" => Some(Duration::from_secs(num * 3600)),
"d" => Some(Duration::from_secs(num * 86400)),
_ => None,
}
}
}
impl Default for RateLimitConfig {
fn default() -> Self {
Self {
requests: 100,
per: Duration::from_secs(60),
key: RateLimitKey::User,
}
}
}
#[derive(Debug, Clone)]
pub struct RateLimitResult {
pub allowed: bool,
pub remaining: u32,
pub reset_at: DateTime<Utc>,
pub retry_after: Option<Duration>,
}
impl RateLimitResult {
pub fn allowed(remaining: u32, reset_at: DateTime<Utc>) -> Self {
Self {
allowed: true,
remaining,
reset_at,
retry_after: None,
}
}
pub fn denied(remaining: u32, reset_at: DateTime<Utc>, retry_after: Duration) -> Self {
Self {
allowed: false,
remaining,
reset_at,
retry_after: Some(retry_after),
}
}
pub fn to_error(&self, limit: u32) -> Option<ForgeError> {
if self.allowed {
None
} else {
Some(ForgeError::RateLimitExceeded {
retry_after: self.retry_after.unwrap_or(Duration::from_secs(1)),
limit,
remaining: self.remaining,
})
}
}
}
#[derive(Debug, Clone)]
pub struct RateLimitHeaders {
pub limit: u32,
pub remaining: u32,
pub reset: i64,
pub retry_after: Option<u64>,
}
impl RateLimitHeaders {
pub fn from_result(result: &RateLimitResult, limit: u32) -> Self {
Self {
limit,
remaining: result.remaining,
reset: result.reset_at.timestamp(),
retry_after: result.retry_after.map(|d| d.as_secs()),
}
}
}
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::indexing_slicing)]
mod tests {
use super::*;
#[test]
fn test_rate_limit_key() {
assert_eq!(RateLimitKey::User.as_str(), "user");
assert_eq!(RateLimitKey::Ip.as_str(), "ip");
assert_eq!(RateLimitKey::Global.as_str(), "global");
assert_eq!("user".parse::<RateLimitKey>().unwrap(), RateLimitKey::User);
assert_eq!("ip".parse::<RateLimitKey>().unwrap(), RateLimitKey::Ip);
}
#[test]
fn test_rate_limit_config() {
let config = RateLimitConfig::new(100, Duration::from_secs(60));
assert_eq!(config.requests, 100);
assert_eq!(config.per, Duration::from_secs(60));
assert!((config.refill_rate() - 1.666666).abs() < 0.01);
}
#[test]
fn test_parse_duration() {
assert_eq!(
RateLimitConfig::parse_duration("1s"),
Some(Duration::from_secs(1))
);
assert_eq!(
RateLimitConfig::parse_duration("1m"),
Some(Duration::from_secs(60))
);
assert_eq!(
RateLimitConfig::parse_duration("1h"),
Some(Duration::from_secs(3600))
);
assert_eq!(
RateLimitConfig::parse_duration("1d"),
Some(Duration::from_secs(86400))
);
assert_eq!(RateLimitConfig::parse_duration("invalid"), None);
}
#[test]
fn test_rate_limit_result_allowed() {
let result = RateLimitResult::allowed(99, Utc::now());
assert!(result.allowed);
assert!(result.retry_after.is_none());
assert!(result.to_error(100).is_none());
}
#[test]
fn test_rate_limit_result_denied() {
let result = RateLimitResult::denied(0, Utc::now(), Duration::from_secs(30));
assert!(!result.allowed);
assert!(result.retry_after.is_some());
assert!(result.to_error(100).is_some());
}
}