use crate::error::HammerworkError;
use std::sync::{Arc, Mutex};
use std::time::{Duration, Instant};
use tokio::time::sleep;
/// Serde helper: writes an `Option<Duration>` as an optional count of
/// whole seconds.
///
/// Sub-second precision is intentionally dropped; this pairs with
/// [`deserialize_duration`] for a seconds-granularity round trip.
fn serialize_duration<S>(duration: &Option<Duration>, serializer: S) -> Result<S::Ok, S::Error>
where
    S: serde::Serializer,
{
    if let Some(d) = duration {
        serializer.serialize_some(&d.as_secs())
    } else {
        serializer.serialize_none()
    }
}
/// Serde helper: reads an optional whole-second count back into an
/// `Option<Duration>`. Inverse of [`serialize_duration`].
fn deserialize_duration<'de, D>(deserializer: D) -> Result<Option<Duration>, D::Error>
where
    D: serde::Deserializer<'de>,
{
    use serde::Deserialize;
    Option::<u64>::deserialize(deserializer).map(|secs| secs.map(Duration::from_secs))
}
/// A token-bucket style rate limit: `rate` operations allowed per `per`
/// window, with short bursts permitted up to `burst_limit`.
#[derive(Debug, Clone)]
pub struct RateLimit {
    /// Number of operations permitted per window.
    pub rate: u64,
    /// Length of the window the rate applies to.
    pub per: Duration,
    /// Maximum tokens the bucket may hold (burst capacity).
    pub burst_limit: u64,
}

impl RateLimit {
    /// Shared constructor body: the burst limit defaults to the steady rate.
    fn for_window(rate: u64, per: Duration) -> Self {
        Self {
            rate,
            per,
            burst_limit: rate,
        }
    }

    /// `rate` operations per second.
    pub fn per_second(rate: u64) -> Self {
        Self::for_window(rate, Duration::from_secs(1))
    }

    /// `rate` operations per minute.
    pub fn per_minute(rate: u64) -> Self {
        Self::for_window(rate, Duration::from_secs(60))
    }

    /// `rate` operations per hour.
    pub fn per_hour(rate: u64) -> Self {
        Self::for_window(rate, Duration::from_secs(3600))
    }

    /// Builder-style override of the burst capacity.
    pub fn with_burst_limit(mut self, burst_limit: u64) -> Self {
        self.burst_limit = burst_limit;
        self
    }

    /// Steady refill rate expressed as tokens per millisecond.
    /// Note: a zero-length `per` would divide by zero and yield NaN/inf;
    /// the provided constructors never produce one.
    fn refill_rate_per_ms(&self) -> f64 {
        self.rate as f64 / self.per.as_millis() as f64
    }
}
/// Throttling configuration: optional concurrency cap, optional rate cap,
/// and an optional error backoff. Serializable; `backoff_on_error` is
/// persisted as whole seconds via the custom serde helpers.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct ThrottleConfig {
    // Upper bound on concurrent work; `None` = unlimited. (Enforced by the
    // consumer of this config, not in this module.)
    pub max_concurrent: Option<u64>,
    // Maximum operations per minute; `None` = unlimited. Convertible to a
    // `RateLimit` via `to_rate_limit`.
    pub rate_per_minute: Option<u64>,
    // Pause applied after an error; `None` = no backoff. Serialized as
    // whole seconds, so sub-second precision is lost on a round trip.
    #[serde(
        serialize_with = "serialize_duration",
        deserialize_with = "deserialize_duration",
        skip_serializing_if = "Option::is_none",
        default
    )]
    pub backoff_on_error: Option<Duration>,
    // Master switch: whether throttling is active at all.
    pub enabled: bool,
}
impl ThrottleConfig {
    /// Creates a config with throttling enabled and no limits set.
    pub fn new() -> Self {
        Self {
            enabled: true,
            max_concurrent: None,
            rate_per_minute: None,
            backoff_on_error: None,
        }
    }

    /// Builder: sets the concurrency cap.
    pub fn max_concurrent(mut self, max: u64) -> Self {
        self.max_concurrent = Some(max);
        self
    }

    /// Builder: sets the per-minute rate cap.
    pub fn rate_per_minute(mut self, rate: u64) -> Self {
        self.rate_per_minute = Some(rate);
        self
    }

    /// Builder: sets the pause applied after an error.
    pub fn backoff_on_error(mut self, duration: Duration) -> Self {
        self.backoff_on_error = Some(duration);
        self
    }

    /// Builder: enables or disables throttling.
    pub fn enabled(mut self, enabled: bool) -> Self {
        self.enabled = enabled;
        self
    }

    /// Converts the per-minute cap, if any, into a [`RateLimit`].
    pub fn to_rate_limit(&self) -> Option<RateLimit> {
        self.rate_per_minute.map(RateLimit::per_minute)
    }
}
impl Default for ThrottleConfig {
fn default() -> Self {
Self::new()
}
}
/// A token bucket for rate limiting.
///
/// Tokens accrue at `refill_rate` tokens per millisecond up to `capacity`;
/// each admitted operation consumes tokens.
#[derive(Debug)]
pub struct TokenBucket {
    // Tokens currently available (fractional values are fine).
    tokens: f64,
    // Maximum tokens the bucket can hold (burst size).
    capacity: f64,
    // Refill rate in tokens per millisecond.
    refill_rate: f64,
    // Instant up to which elapsed time has been credited as tokens.
    last_refill: Instant,
}

impl TokenBucket {
    /// Creates a bucket that starts full.
    pub fn new(capacity: f64, refill_rate: f64) -> Self {
        Self {
            tokens: capacity,
            capacity,
            refill_rate,
            last_refill: Instant::now(),
        }
    }

    /// Consumes `tokens` if that many are available; returns whether the
    /// consumption succeeded. Refills first so waiting callers get credit.
    pub fn try_consume(&mut self, tokens: f64) -> bool {
        self.refill();
        if self.tokens >= tokens {
            self.tokens -= tokens;
            true
        } else {
            false
        }
    }

    /// Credits tokens for the whole milliseconds elapsed since the last
    /// credited instant, capped at `capacity`.
    ///
    /// Bug fix: the previous version reset `last_refill` to `now` on every
    /// call while crediting only whole milliseconds, silently discarding the
    /// sub-millisecond remainder each time. A caller polling faster than
    /// once per millisecond (e.g. a tight acquire loop that calls
    /// `try_consume` and then `time_until_token`, both of which refill)
    /// would credit ~zero time and starve the bucket indefinitely. We now
    /// advance `last_refill` by exactly the milliseconds credited, so the
    /// fractional remainder carries into the next call and no time is lost.
    fn refill(&mut self) {
        let now = Instant::now();
        // u128 -> u64 cast is safe for any realistic uptime.
        let elapsed_ms = now.duration_since(self.last_refill).as_millis() as u64;
        if elapsed_ms > 0 {
            let tokens_to_add = self.refill_rate * elapsed_ms as f64;
            self.tokens = (self.tokens + tokens_to_add).min(self.capacity);
            // Advance only by the time actually credited; the fractional
            // millisecond remains owed to the next refill.
            self.last_refill += Duration::from_millis(elapsed_ms);
        }
    }

    /// Returns the current token count after refilling.
    pub fn available_tokens(&mut self) -> f64 {
        self.refill();
        self.tokens
    }

    /// Time to wait until at least one token is available (zero if one is
    /// available now).
    ///
    /// With a zero refill rate the division yields infinity, which the
    /// saturating float->int cast turns into `u64::MAX` milliseconds —
    /// effectively "wait forever", which is the correct answer.
    pub fn time_until_token(&mut self) -> Duration {
        self.refill();
        if self.tokens >= 1.0 {
            Duration::from_millis(0)
        } else {
            let tokens_needed = 1.0 - self.tokens;
            let ms_needed = (tokens_needed / self.refill_rate).ceil() as u64;
            Duration::from_millis(ms_needed)
        }
    }
}
/// Thread-safe rate limiter backed by a shared [`TokenBucket`].
///
/// `Clone` is shallow: clones share the same bucket through the `Arc`, so
/// they all draw from a single token budget.
#[derive(Debug, Clone)]
pub struct RateLimiter {
    // Shared bucket, guarded by a std Mutex (the guard is dropped before
    // any await in `acquire`).
    bucket: Arc<Mutex<TokenBucket>>,
    // The configuration this limiter was built from (see `rate_limit()`).
    rate_limit: RateLimit,
}
impl RateLimiter {
pub fn new(rate_limit: RateLimit) -> Self {
let capacity = rate_limit.burst_limit as f64;
let refill_rate = rate_limit.refill_rate_per_ms();
Self {
bucket: Arc::new(Mutex::new(TokenBucket::new(capacity, refill_rate))),
rate_limit,
}
}
pub fn check(&self) -> bool {
if let Ok(mut bucket) = self.bucket.lock() {
bucket.try_consume(1.0)
} else {
true
}
}
pub async fn acquire(&self) -> Result<(), HammerworkError> {
loop {
let wait_time = {
if let Ok(mut bucket) = self.bucket.lock() {
if bucket.try_consume(1.0) {
return Ok(());
}
bucket.time_until_token()
} else {
return Err(HammerworkError::RateLimit {
message: "Rate limiter lock poisoned".to_string(),
});
}
};
if wait_time > Duration::from_millis(0) {
sleep(wait_time).await;
}
}
}
pub fn try_acquire(&self) -> bool {
self.check()
}
pub fn rate_limit(&self) -> &RateLimit {
&self.rate_limit
}
pub fn available_tokens(&self) -> f64 {
if let Ok(mut bucket) = self.bucket.lock() {
bucket.available_tokens()
} else {
0.0
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use std::time::Duration;
use tokio::time::{Instant, sleep};
#[test]
fn test_rate_limit_creation() {
let rate_limit = RateLimit::per_second(10);
assert_eq!(rate_limit.rate, 10);
assert_eq!(rate_limit.per, Duration::from_secs(1));
assert_eq!(rate_limit.burst_limit, 10);
let rate_limit = RateLimit::per_minute(60).with_burst_limit(100);
assert_eq!(rate_limit.rate, 60);
assert_eq!(rate_limit.per, Duration::from_secs(60));
assert_eq!(rate_limit.burst_limit, 100);
}
#[test]
fn test_throttle_config() {
let config = ThrottleConfig::new()
.max_concurrent(5)
.rate_per_minute(100)
.backoff_on_error(Duration::from_secs(30));
assert_eq!(config.max_concurrent, Some(5));
assert_eq!(config.rate_per_minute, Some(100));
assert_eq!(config.backoff_on_error, Some(Duration::from_secs(30)));
assert!(config.enabled);
let rate_limit = config.to_rate_limit().unwrap();
assert_eq!(rate_limit.rate, 100);
assert_eq!(rate_limit.per, Duration::from_secs(60));
}
#[test]
fn test_token_bucket() {
let mut bucket = TokenBucket::new(10.0, 1.0);
assert_eq!(bucket.available_tokens(), 10.0);
assert!(bucket.try_consume(5.0));
assert_eq!(bucket.available_tokens(), 5.0);
assert!(!bucket.try_consume(10.0));
assert_eq!(bucket.available_tokens(), 5.0);
assert!(bucket.try_consume(5.0));
assert_eq!(bucket.available_tokens(), 0.0);
}
#[tokio::test]
async fn test_token_bucket_refill() {
let mut bucket = TokenBucket::new(10.0, 10.0);
assert!(bucket.try_consume(10.0));
assert_eq!(bucket.available_tokens(), 0.0);
sleep(Duration::from_millis(2)).await;
let tokens = bucket.available_tokens();
assert_eq!(tokens, 10.0);
}
#[test]
fn test_rate_limiter_creation() {
let rate_limit = RateLimit::per_second(10);
let limiter = RateLimiter::new(rate_limit);
assert_eq!(limiter.rate_limit().rate, 10);
assert!(limiter.available_tokens() > 0.0);
}
#[tokio::test]
async fn test_rate_limiter_check() {
let rate_limit = RateLimit::per_second(2); let limiter = RateLimiter::new(rate_limit);
assert!(limiter.check());
assert!(limiter.check());
assert!(!limiter.check());
}
#[tokio::test]
async fn test_rate_limiter_acquire() {
let rate_limit = RateLimit::per_second(1000); let limiter = RateLimiter::new(rate_limit);
let start = Instant::now();
limiter.acquire().await.unwrap();
let elapsed = start.elapsed();
assert!(elapsed < Duration::from_millis(10));
}
#[tokio::test]
async fn test_rate_limiter_try_acquire() {
let rate_limit = RateLimit::per_second(1);
let limiter = RateLimiter::new(rate_limit);
assert!(limiter.try_acquire());
assert!(!limiter.try_acquire());
}
#[test]
fn test_rate_limit_refill_calculation() {
let rate_limit = RateLimit::per_second(10);
let refill_rate = rate_limit.refill_rate_per_ms();
assert!((refill_rate - 0.01).abs() < 0.001);
let rate_limit = RateLimit::per_minute(60);
let refill_rate = rate_limit.refill_rate_per_ms();
assert!((refill_rate - 0.001).abs() < 0.0001);
}
#[tokio::test]
async fn test_rate_limiter_clone() {
let rate_limit = RateLimit::per_second(2);
let limiter1 = RateLimiter::new(rate_limit);
let limiter2 = limiter1.clone();
assert!(limiter1.try_acquire());
assert!(limiter2.try_acquire());
assert!(!limiter1.try_acquire());
assert!(!limiter2.try_acquire());
}
#[test]
fn test_throttle_config_defaults() {
let config = ThrottleConfig::default();
assert!(config.enabled);
assert!(config.max_concurrent.is_none());
assert!(config.rate_per_minute.is_none());
assert!(config.backoff_on_error.is_none());
}
#[test]
fn test_rate_limit_edge_cases() {
let rate_limit = RateLimit::per_second(0);
assert_eq!(rate_limit.rate, 0);
let rate_limit = RateLimit::per_second(1_000_000);
assert_eq!(rate_limit.rate, 1_000_000);
let rate_limit = RateLimit::per_hour(1);
assert_eq!(rate_limit.per, Duration::from_secs(3600));
}
#[tokio::test]
async fn test_token_bucket_time_until_token() {
let mut bucket = TokenBucket::new(1.0, 1.0);
assert_eq!(bucket.time_until_token(), Duration::from_millis(0));
assert!(bucket.try_consume(1.0));
let wait_time = bucket.time_until_token();
assert!(wait_time > Duration::from_millis(0));
assert!(wait_time <= Duration::from_millis(1));
}
}