use std::time::Duration;
use serde::{Deserialize, Serialize};
use crate::error::{ConfigError, Result};
/// A rate-limit quota: at most `max_requests` requests per `window`, with an
/// optional burst size and an optional refill-rate override.
///
/// Construct it with [`Quota::new`], [`Quota::try_new`], the `per_*` helpers or
/// [`QuotaBuilder`]. Note that the derived `Deserialize` impl fills the fields
/// directly and does NOT apply the constructors' validation (non-zero
/// `max_requests` and `window`).
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct Quota {
    /// Number of requests allowed per `window`; the constructors require this to be > 0.
    max_requests: u64,
    /// Length of the rate-limiting window; the constructors require it to be non-zero.
    window: Duration,
    /// Optional burst size; when `None`, `max_requests` is used (see `effective_burst`).
    burst: Option<u64>,
    /// Optional refill-rate override; when `None`, `max_requests / window-in-seconds`
    /// is used (see `effective_refill_rate`), so presumably also requests per second.
    refill_rate: Option<f64>,
}
impl Quota {
    /// Creates a quota allowing `max_requests` requests per `window`, with no
    /// custom burst or refill rate.
    ///
    /// # Panics
    ///
    /// Panics if `max_requests` is 0 or `window` is zero. Use [`Quota::try_new`]
    /// for a non-panicking alternative.
    pub fn new(max_requests: u64, window: Duration) -> Self {
        assert!(max_requests > 0, "max_requests must be greater than 0");
        assert!(!window.is_zero(), "window must be non-zero");
        Self::from_validated(max_requests, window)
    }

    /// Quota of `n` requests per second.
    ///
    /// # Panics
    ///
    /// Panics if `n` is 0.
    pub fn per_second(n: u64) -> Self {
        Self::new(n, Duration::from_secs(1))
    }

    /// Quota of `n` requests per minute.
    ///
    /// # Panics
    ///
    /// Panics if `n` is 0.
    pub fn per_minute(n: u64) -> Self {
        Self::new(n, Duration::from_secs(60))
    }

    /// Quota of `n` requests per hour.
    ///
    /// # Panics
    ///
    /// Panics if `n` is 0.
    pub fn per_hour(n: u64) -> Self {
        Self::new(n, Duration::from_secs(3600))
    }

    /// Quota of `n` requests per day (86 400 seconds).
    ///
    /// # Panics
    ///
    /// Panics if `n` is 0.
    pub fn per_day(n: u64) -> Self {
        Self::new(n, Duration::from_secs(86400))
    }

    /// Quota of a single request per `period`.
    ///
    /// # Panics
    ///
    /// Panics if `period` is zero.
    pub fn simple(period: Duration) -> Self {
        Self::new(1, period)
    }

    /// Quota of a single request per `period`, with a burst of `burst` requests.
    ///
    /// # Panics
    ///
    /// Panics if `period` is zero.
    pub fn with_period_and_burst(period: Duration, burst: u64) -> Self {
        Self::new(1, period).with_burst(burst)
    }

    /// Fallible counterpart of [`Quota::new`].
    ///
    /// # Errors
    ///
    /// Returns `ConfigError::InvalidQuota` (converted into the crate error type)
    /// if `max_requests` is 0 or `window` is zero.
    pub fn try_new(max_requests: u64, window: Duration) -> Result<Self> {
        if max_requests == 0 {
            return Err(ConfigError::InvalidQuota("max_requests must be greater than 0".into()).into());
        }
        if window.is_zero() {
            return Err(ConfigError::InvalidQuota("window must be non-zero".into()).into());
        }
        Ok(Self::from_validated(max_requests, window))
    }

    /// Builds the struct once the caller has checked the invariants
    /// (`max_requests > 0`, non-zero `window`).
    fn from_validated(max_requests: u64, window: Duration) -> Self {
        Self {
            max_requests,
            window,
            burst: None,
            refill_rate: None,
        }
    }

    /// Returns this quota with the burst size set to `burst`.
    ///
    /// The value is not validated. A burst of 0 is stored as-is, but
    /// [`Quota::max_tat_offset`] treats it like a burst of 1.
    #[must_use]
    pub fn with_burst(mut self, burst: u64) -> Self {
        self.burst = Some(burst);
        self
    }

    /// Returns this quota with the refill rate overridden to `rate`.
    ///
    /// The value is not validated here (NaN, infinite or negative values are
    /// stored as-is); [`QuotaBuilder::build`] performs validation.
    #[must_use]
    pub fn with_refill_rate(mut self, rate: f64) -> Self {
        self.refill_rate = Some(rate);
        self
    }

    /// Number of requests allowed per window.
    pub fn max_requests(&self) -> u64 {
        self.max_requests
    }

    /// Length of the rate-limiting window.
    pub fn window(&self) -> Duration {
        self.window
    }

    /// The configured burst size, or `max_requests` when none was set.
    pub fn effective_burst(&self) -> u64 {
        self.burst.unwrap_or(self.max_requests)
    }

    /// The configured refill rate, or `max_requests / window` (in requests per
    /// second) when none was set.
    pub fn effective_refill_rate(&self) -> f64 {
        self.refill_rate
            .unwrap_or_else(|| self.max_requests as f64 / self.window.as_secs_f64())
    }

    /// GCRA emission interval: the time "cost" of one request, i.e.
    /// `window / max_requests`, truncated to whole nanoseconds.
    ///
    /// # Panics
    ///
    /// Panics if `max_requests` is 0. The constructors reject that value, but
    /// the derived `Deserialize` impl does not.
    pub fn period(&self) -> Duration {
        // Exact integer arithmetic. A round-trip through f64 loses precision for
        // long windows.
        Self::duration_from_nanos(self.window.as_nanos() / u128::from(self.max_requests))
    }

    /// Maximum distance the theoretical arrival time (TAT) may run ahead of
    /// "now": `period() * (effective_burst() - 1)`.
    ///
    /// A burst of 0 is treated like a burst of 1, giving an offset of zero. The
    /// result saturates at [`Duration::MAX`] instead of overflowing.
    pub fn max_tat_offset(&self) -> Duration {
        // `saturating_sub` avoids the u64 underflow a plain `burst - 1` would hit
        // for a burst of 0 (debug panic / release wrap-around).
        let extra_slots = u128::from(self.effective_burst().saturating_sub(1));
        let nanos = self.period().as_nanos().saturating_mul(extra_slots);
        Self::duration_from_nanos(nanos)
    }

    /// Time for the limiter to fully replenish; currently always `window`.
    ///
    /// NOTE(review): this ignores any custom burst or refill rate. Under GCRA, a
    /// fully drained burst takes `period() * effective_burst()` to recover. That
    /// differs from `window` whenever `burst != max_requests`, so confirm what
    /// callers expect.
    pub fn full_replenish_time(&self) -> Duration {
        self.window
    }

    /// Converts a nanosecond count into a `Duration`, saturating at
    /// [`Duration::MAX`] when the seconds part does not fit in a `u64`.
    fn duration_from_nanos(nanos: u128) -> Duration {
        const NANOS_PER_SEC: u128 = 1_000_000_000;
        let secs = nanos / NANOS_PER_SEC;
        if secs > u128::from(u64::MAX) {
            return Duration::MAX;
        }
        // Both casts are lossless: `secs` was range-checked above, and the
        // remainder is always below 1e9.
        Duration::new(secs as u64, (nanos % NANOS_PER_SEC) as u32)
    }
}
impl Default for Quota {
fn default() -> Self {
Self::per_minute(60)
}
}
/// Step-by-step builder for [`Quota`].
///
/// `max_requests` and `window` are required; `burst` and `refill_rate` are
/// optional. Call `build` to obtain the validated quota.
#[derive(Debug, Default)]
pub struct QuotaBuilder {
    /// Required: requests allowed per window.
    max_requests: Option<u64>,
    /// Required: length of the rate-limiting window.
    window: Option<Duration>,
    /// Optional burst size, applied via `Quota::with_burst`.
    burst: Option<u64>,
    /// Optional refill-rate override, applied via `Quota::with_refill_rate`.
    refill_rate: Option<f64>,
}
impl QuotaBuilder {
    /// Creates an empty builder; equivalent to `QuotaBuilder::default()`.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Sets the number of requests allowed per window (required).
    #[must_use]
    pub fn max_requests(mut self, n: u64) -> Self {
        self.max_requests = Some(n);
        self
    }

    /// Sets the length of the rate-limiting window (required).
    #[must_use]
    pub fn window(mut self, duration: Duration) -> Self {
        self.window = Some(duration);
        self
    }

    /// Sets the burst size (optional; must be greater than 0).
    #[must_use]
    pub fn burst(mut self, n: u64) -> Self {
        self.burst = Some(n);
        self
    }

    /// Overrides the refill rate (optional; must be finite and greater than 0).
    #[must_use]
    pub fn refill_rate(mut self, rate: f64) -> Self {
        self.refill_rate = Some(rate);
        self
    }

    /// Validates the collected settings and produces a [`Quota`].
    ///
    /// # Errors
    ///
    /// - `ConfigError::MissingRequired` if `max_requests` or `window` was not set.
    /// - `ConfigError::InvalidQuota` in any of these cases:
    ///   - `max_requests` is 0 or `window` is zero (via `Quota::try_new`);
    ///   - `burst` is 0;
    ///   - `refill_rate` is NaN, infinite, zero or negative.
    pub fn build(self) -> Result<Quota> {
        let max_requests = self
            .max_requests
            .ok_or_else(|| ConfigError::MissingRequired("max_requests".into()))?;
        let window = self
            .window
            .ok_or_else(|| ConfigError::MissingRequired("window".into()))?;
        let mut quota = Quota::try_new(max_requests, window)?;
        if let Some(burst) = self.burst {
            // A zero burst would admit no requests at all and makes the GCRA
            // offset (`burst - 1`) meaningless.
            if burst == 0 {
                return Err(ConfigError::InvalidQuota("burst must be greater than 0".into()).into());
            }
            quota = quota.with_burst(burst);
        }
        if let Some(rate) = self.refill_rate {
            // `is_finite` rejects NaN and +/-inf; the comparison rejects zero and
            // negatives. Otherwise `effective_refill_rate` would report a
            // nonsensical rate.
            if !rate.is_finite() || rate <= 0.0 {
                return Err(ConfigError::InvalidQuota(
                    "refill_rate must be a finite number greater than 0".into(),
                )
                .into());
            }
            quota = quota.with_refill_rate(rate);
        }
        Ok(quota)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance used when comparing floating-point refill rates.
    const RATE_TOLERANCE: f64 = 0.001;

    /// True when `actual` is within `RATE_TOLERANCE` of `expected`.
    fn rate_close(actual: f64, expected: f64) -> bool {
        (actual - expected).abs() < RATE_TOLERANCE
    }

    #[test]
    fn test_quota_per_second() {
        let q = Quota::per_second(10);
        assert_eq!(q.max_requests(), 10);
        assert_eq!(q.window(), Duration::from_secs(1));
        assert_eq!(q.effective_burst(), 10);
        assert!(rate_close(q.effective_refill_rate(), 10.0));
    }

    #[test]
    fn test_quota_per_minute() {
        let q = Quota::per_minute(60);
        assert_eq!(q.max_requests(), 60);
        assert_eq!(q.window(), Duration::from_secs(60));
        assert!(rate_close(q.effective_refill_rate(), 1.0));
    }

    #[test]
    fn test_quota_with_burst() {
        let bursty = Quota::per_minute(60).with_burst(100);
        assert_eq!(bursty.max_requests(), 60);
        assert_eq!(bursty.effective_burst(), 100);
    }

    #[test]
    fn test_quota_burst_smaller_than_max() {
        let narrow = Quota::per_minute(60).with_burst(30);
        assert_eq!(narrow.effective_burst(), 30);
    }

    #[test]
    fn test_quota_simple() {
        let tick = Duration::from_millis(100);
        let q = Quota::simple(tick);
        assert_eq!(q.max_requests(), 1);
        assert_eq!(q.window(), tick);
        assert_eq!(q.period(), tick);
    }

    #[test]
    fn test_quota_gcra_period() {
        assert_eq!(Quota::per_second(10).period(), Duration::from_millis(100));
    }

    #[test]
    fn test_quota_max_tat_offset() {
        let q = Quota::per_second(1).with_burst(5);
        assert_eq!(q.max_tat_offset(), Duration::from_secs(4));
    }

    #[test]
    fn test_quota_builder() {
        let built = QuotaBuilder::new()
            .max_requests(100)
            .window(Duration::from_secs(60))
            .burst(150)
            .build()
            .expect("builder with all required fields should succeed");
        assert_eq!(built.max_requests(), 100);
        assert_eq!(built.window(), Duration::from_secs(60));
        assert_eq!(built.effective_burst(), 150);
    }

    #[test]
    fn test_quota_builder_missing_fields() {
        let without_window = QuotaBuilder::new().max_requests(100).build();
        assert!(without_window.is_err());

        let without_max = QuotaBuilder::new().window(Duration::from_secs(60)).build();
        assert!(without_max.is_err());
    }

    #[test]
    #[should_panic]
    fn test_quota_zero_requests_panics() {
        let _ = Quota::new(0, Duration::from_secs(60));
    }

    #[test]
    #[should_panic]
    fn test_quota_zero_window_panics() {
        let _ = Quota::new(100, Duration::ZERO);
    }
}