use super::backend::ThrottleBackend;
use super::{Throttle, ThrottleError, ThrottleResult};
use super::time_provider::{SystemTimeProvider, TimeProvider};
use async_trait::async_trait;
use std::sync::Arc;
/// An inclusive span of hours of the day, `start_hour..=end_hour`.
///
/// When `start_hour > end_hour` the span wraps past midnight, so `TimeRange::new(22, 2)`
/// covers hours 22, 23, 0, 1 and 2.
///
/// The fields are public, so values built with a struct literal skip the range checks
/// performed by [`TimeRange::new`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct TimeRange {
    /// First hour (0-23) that belongs to the span.
    pub start_hour: u8,
    /// Last hour (0-23) that belongs to the span.
    pub end_hour: u8,
}

impl TimeRange {
    /// Builds a span from `start_hour` to `end_hour`, both inclusive.
    ///
    /// # Panics
    ///
    /// Panics if `start_hour` or `end_hour` is 24 or greater; `start_hour` is checked first.
    pub fn new(start_hour: u8, end_hour: u8) -> Self {
        assert!(start_hour < 24, "start_hour must be 0-23");
        assert!(end_hour < 24, "end_hour must be 0-23");
        Self { start_hour, end_hour }
    }

    /// Reports whether `hour` falls inside the span (endpoints included).
    ///
    /// `hour` is not validated: for a wrapping span any value `>= start_hour`
    /// (including values above 23) is reported as contained.
    pub fn contains(&self, hour: u8) -> bool {
        let Self {
            start_hour: first,
            end_hour: last,
        } = *self;
        let wraps_midnight = first > last;
        if wraps_midnight {
            // Split span: the evening tail [first, ..] plus the morning head [.., last].
            hour >= first || hour <= last
        } else {
            (first..=last).contains(&hour)
        }
    }
}
/// Rate limits that switch between a peak and an off-peak setting by hour of day.
///
/// Each rate is a `(limit, period)` pair. The period is handed to the throttle backend
/// as the counting window (presumably in seconds — confirm against the backend).
#[non_exhaustive]
#[derive(Debug, Clone)]
pub struct TimeOfDayConfig {
    /// Hours during which `peak_rate` is in force.
    pub peak_hours: TimeRange,
    /// `(limit, period)` applied while the hour lies inside `peak_hours`.
    pub peak_rate: (usize, u64),
    /// `(limit, period)` applied at every other hour.
    pub off_peak_rate: (usize, u64),
}

impl TimeOfDayConfig {
    /// Bundles a peak window with its peak and off-peak `(limit, period)` pairs.
    pub fn new(
        peak_hours: TimeRange,
        peak_rate: (usize, u64),
        off_peak_rate: (usize, u64),
    ) -> Self {
        Self {
            peak_hours,
            peak_rate,
            off_peak_rate,
        }
    }

    /// Selects the `(limit, period)` pair that applies at `hour`.
    ///
    /// Returns `peak_rate` when `peak_hours` contains `hour`, otherwise `off_peak_rate`.
    pub fn get_rate(&self, hour: u8) -> (usize, u64) {
        let in_peak = self.peak_hours.contains(hour);
        if in_peak {
            self.peak_rate
        } else {
            self.off_peak_rate
        }
    }
}
/// Throttle whose `(limit, period)` switches between a peak and an off-peak rate
/// depending on the hour of day.
///
/// Per-key request counting is delegated to `backend`; the hour is derived with the
/// help of `time_provider`.
pub struct TimeOfDayThrottle<B: ThrottleBackend, T: TimeProvider = SystemTimeProvider> {
    /// Storage for the per-key request counters.
    backend: Arc<B>,
    /// Peak window plus the peak and off-peak rates.
    config: TimeOfDayConfig,
    /// Clock consulted when working out the current hour.
    time_provider: Arc<T>,
}

impl<B: ThrottleBackend> TimeOfDayThrottle<B, SystemTimeProvider> {
    /// Creates a throttle that uses a freshly constructed [`SystemTimeProvider`] as its clock.
    pub fn new(backend: Arc<B>, config: TimeOfDayConfig) -> Self {
        let time_provider = Arc::new(SystemTimeProvider::new());
        Self {
            backend,
            config,
            time_provider,
        }
    }
}
impl<B: ThrottleBackend, T: TimeProvider> TimeOfDayThrottle<B, T> {
    /// Creates a throttle that uses the supplied `time_provider` (for example a mock
    /// in tests) instead of the default system provider.
    pub fn with_time_provider(
        backend: Arc<B>,
        config: TimeOfDayConfig,
        time_provider: Arc<T>,
    ) -> Self {
        Self {
            backend,
            config,
            time_provider,
        }
    }

    /// Returns the current hour of the day (0-23) in UTC.
    ///
    /// The provider's `now()` yields an instant that has `elapsed()` but no relation to
    /// the calendar. (The tests build `MockTimeProvider` from `tokio::time::Instant`.)
    /// Calling `elapsed()` on a freshly captured instant therefore measures roughly zero,
    /// not the time since the Unix epoch. Instead, the instant is mapped onto the wall
    /// clock: it lies `elapsed()` behind the real clock, so subtracting that lag from
    /// [`SystemTime::now`] gives the matching calendar time.
    ///
    /// Hours are computed from seconds since [`UNIX_EPOCH`] with no time-zone offset,
    /// so `peak_hours` are interpreted in UTC.
    fn get_current_hour(&self) -> u8 {
        use std::time::{SystemTime, UNIX_EPOCH};

        // Presumably ~0 for the system provider; for a mock frozen in the past it is the
        // gap to real time. Instants ahead of the real clock saturate to zero lag on
        // current std/tokio, i.e. they map to "real now".
        let lag = self.time_provider.now().elapsed();
        let wall_clock = SystemTime::now().checked_sub(lag).unwrap_or(UNIX_EPOCH);
        let secs_since_epoch = wall_clock
            .duration_since(UNIX_EPOCH)
            .unwrap_or_default()
            .as_secs();
        // Whole hours since the epoch reduced modulo one day: always < 24, so the cast
        // cannot truncate.
        ((secs_since_epoch / 3600) % 24) as u8
    }

    /// Looks up the `(limit, period)` pair that applies at the current hour.
    ///
    /// Declared `async` although it currently awaits nothing.
    async fn get_current_rate(&self) -> (usize, u64) {
        let hour = self.get_current_hour();
        self.config.get_rate(hour)
    }
}
#[async_trait]
impl<B: ThrottleBackend, T: TimeProvider> Throttle for TimeOfDayThrottle<B, T> {
async fn allow_request(&self, key: &str) -> ThrottleResult<bool> {
let (rate, period) = self.get_current_rate().await;
let count = self
.backend
.increment(key, period)
.await
.map_err(ThrottleError::ThrottleError)?;
Ok(count <= rate)
}
async fn wait_time(&self, key: &str) -> ThrottleResult<Option<u64>> {
let (rate, period) = self.get_current_rate().await;
let count = self
.backend
.get_count(key)
.await
.map_err(ThrottleError::ThrottleError)?;
if count > rate {
Ok(Some(period))
} else {
Ok(None)
}
}
fn get_rate(&self) -> (usize, u64) {
self.config.peak_rate
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::throttling::backend::MemoryBackend;
    use crate::throttling::time_provider::MockTimeProvider;
    use tokio::time::Instant;

    /// Spends the limit that is active right now on `key`, then asserts that one
    /// further request is rejected.
    async fn assert_limit_enforced<B: ThrottleBackend, T: TimeProvider>(
        throttle: &TimeOfDayThrottle<B, T>,
        key: &str,
    ) {
        let (limit, _period) = throttle.get_current_rate().await;
        for _ in 0..limit {
            assert!(throttle.allow_request(key).await.unwrap());
        }
        assert!(!throttle.allow_request(key).await.unwrap());
    }

    #[test]
    fn test_time_range_normal() {
        let range = TimeRange::new(9, 17);
        for hour in [9, 12, 17] {
            assert!(range.contains(hour), "hour {} should be inside 9-17", hour);
        }
        for hour in [8, 18] {
            assert!(!range.contains(hour), "hour {} should be outside 9-17", hour);
        }
    }

    #[test]
    fn test_time_range_wrapping() {
        let range = TimeRange::new(22, 2);
        for hour in [22, 23, 0, 1, 2] {
            assert!(range.contains(hour), "hour {} should be inside 22-2", hour);
        }
        for hour in [3, 21] {
            assert!(!range.contains(hour), "hour {} should be outside 22-2", hour);
        }
    }

    #[test]
    fn test_time_of_day_config_get_rate() {
        let peak: (usize, u64) = (50, 60);
        let off_peak: (usize, u64) = (100, 60);
        let config = TimeOfDayConfig::new(TimeRange::new(9, 17), peak, off_peak);
        let cases = [
            (9, peak),
            (12, peak),
            (17, peak),
            (8, off_peak),
            (18, off_peak),
            (0, off_peak),
        ];
        for (hour, expected) in cases {
            assert_eq!(config.get_rate(hour), expected, "unexpected rate at hour {}", hour);
        }
    }

    #[tokio::test]
    async fn test_time_of_day_throttle_basic() {
        let config = TimeOfDayConfig::new(TimeRange::new(9, 17), (5, 60), (10, 60));
        let throttle = TimeOfDayThrottle::new(Arc::new(MemoryBackend::new()), config);
        assert_limit_enforced(&throttle, "test_key").await;
    }

    #[tokio::test]
    async fn test_time_of_day_throttle_with_mock_time() {
        let clock = Arc::new(MockTimeProvider::new(Instant::now()));
        let backend = Arc::new(MemoryBackend::with_time_provider(clock.clone()));
        let config = TimeOfDayConfig::new(TimeRange::new(9, 17), (5, 60), (10, 60));
        let throttle = TimeOfDayThrottle::with_time_provider(backend, config, clock);
        assert_limit_enforced(&throttle, "test_key").await;
    }

    #[tokio::test]
    async fn test_time_of_day_throttle_get_rate() {
        let config = TimeOfDayConfig::new(TimeRange::new(9, 17), (50, 60), (100, 60));
        let throttle = TimeOfDayThrottle::new(Arc::new(MemoryBackend::new()), config);
        // `get_rate` reports the peak rate regardless of the hour.
        assert_eq!(throttle.get_rate(), (50, 60));
    }

    #[test]
    #[should_panic(expected = "start_hour must be 0-23")]
    fn test_time_range_invalid_start() {
        TimeRange::new(24, 10);
    }

    #[test]
    #[should_panic(expected = "end_hour must be 0-23")]
    fn test_time_range_invalid_end() {
        TimeRange::new(10, 24);
    }
}