use serde::{Deserialize, Serialize};
fn default_rpm() -> u32 {
1000
}
fn default_tpm() -> u32 {
100_000
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RateLimitConfig {
#[serde(default)]
pub enabled: bool,
#[serde(default, alias = "algorithm")]
pub strategy: RateLimitStrategy,
#[serde(default = "default_rpm")]
pub default_rpm: u32,
#[serde(default = "default_tpm")]
pub default_tpm: u32,
#[serde(skip_serializing_if = "Option::is_none")]
pub requests_per_second: Option<u32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub requests_per_minute: Option<u32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub tokens_per_minute: Option<u32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub burst_size: Option<u32>,
}
impl Default for RateLimitConfig {
fn default() -> Self {
Self {
enabled: false,
strategy: RateLimitStrategy::default(),
default_rpm: default_rpm(),
default_tpm: default_tpm(),
requests_per_second: None,
requests_per_minute: None,
tokens_per_minute: None,
burst_size: None,
}
}
}
impl RateLimitConfig {
pub fn merge(mut self, other: Self) -> Self {
if other.enabled {
self.enabled = other.enabled;
}
if other.default_rpm != default_rpm() {
self.default_rpm = other.default_rpm;
}
if other.default_tpm != default_tpm() {
self.default_tpm = other.default_tpm;
}
self.strategy = other.strategy;
if other.requests_per_second.is_some() {
self.requests_per_second = other.requests_per_second;
}
if other.requests_per_minute.is_some() {
self.requests_per_minute = other.requests_per_minute;
}
if other.tokens_per_minute.is_some() {
self.tokens_per_minute = other.tokens_per_minute;
}
if other.burst_size.is_some() {
self.burst_size = other.burst_size;
}
self
}
}
#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq)]
#[serde(rename_all = "snake_case")]
pub enum RateLimitStrategy {
#[default]
TokenBucket,
FixedWindow,
SlidingWindow,
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_rate_limit_strategy_default() {
let strategy = RateLimitStrategy::default();
assert_eq!(strategy, RateLimitStrategy::TokenBucket);
}
#[test]
fn test_rate_limit_strategy_serialization() {
assert_eq!(
serde_json::to_string(&RateLimitStrategy::TokenBucket).unwrap(),
"\"token_bucket\""
);
assert_eq!(
serde_json::to_string(&RateLimitStrategy::SlidingWindow).unwrap(),
"\"sliding_window\""
);
assert_eq!(
serde_json::to_string(&RateLimitStrategy::FixedWindow).unwrap(),
"\"fixed_window\""
);
}
#[test]
fn test_rate_limit_strategy_deserialization() {
let tb: RateLimitStrategy = serde_json::from_str("\"token_bucket\"").unwrap();
assert_eq!(tb, RateLimitStrategy::TokenBucket);
let sw: RateLimitStrategy = serde_json::from_str("\"sliding_window\"").unwrap();
assert_eq!(sw, RateLimitStrategy::SlidingWindow);
let fw: RateLimitStrategy = serde_json::from_str("\"fixed_window\"").unwrap();
assert_eq!(fw, RateLimitStrategy::FixedWindow);
}
#[test]
fn test_rate_limit_config_default() {
let config = RateLimitConfig::default();
assert!(!config.enabled);
assert_eq!(config.default_rpm, 1000);
assert_eq!(config.default_tpm, 100_000);
assert_eq!(config.strategy, RateLimitStrategy::TokenBucket);
assert!(config.requests_per_second.is_none());
assert!(config.burst_size.is_none());
}
#[test]
fn test_rate_limit_config_full() {
let config = RateLimitConfig {
enabled: true,
strategy: RateLimitStrategy::SlidingWindow,
default_rpm: 500,
default_tpm: 50_000,
requests_per_second: Some(10),
requests_per_minute: None,
tokens_per_minute: Some(100_000),
burst_size: Some(20),
};
assert!(config.enabled);
assert_eq!(config.requests_per_second, Some(10));
assert_eq!(config.burst_size, Some(20));
}
#[test]
fn test_rate_limit_config_serialization() {
let config = RateLimitConfig {
enabled: true,
strategy: RateLimitStrategy::FixedWindow,
default_rpm: 600,
default_tpm: 60_000,
requests_per_second: Some(50),
requests_per_minute: None,
tokens_per_minute: None,
burst_size: None,
};
let json = serde_json::to_value(&config).unwrap();
assert_eq!(json["enabled"], true);
assert_eq!(json["strategy"], "fixed_window");
assert_eq!(json["default_rpm"], 600);
assert_eq!(json["requests_per_second"], 50);
}
#[test]
fn test_rate_limit_config_deserialization_defaults() {
let json = r#"{}"#;
let config: RateLimitConfig = serde_json::from_str(json).unwrap();
assert!(!config.enabled);
assert_eq!(config.default_rpm, 1000);
assert_eq!(config.default_tpm, 100_000);
}
#[test]
fn test_rate_limit_config_deserialization_algorithm_alias() {
let json = r#"{"algorithm": "sliding_window"}"#;
let config: RateLimitConfig = serde_json::from_str(json).unwrap();
assert_eq!(config.strategy, RateLimitStrategy::SlidingWindow);
}
#[test]
fn test_rate_limit_config_merge() {
let base = RateLimitConfig::default();
let other = RateLimitConfig {
enabled: true,
strategy: RateLimitStrategy::SlidingWindow,
default_rpm: 500,
default_tpm: 100_000,
requests_per_second: Some(10),
requests_per_minute: None,
tokens_per_minute: None,
burst_size: Some(20),
};
let merged = base.merge(other);
assert!(merged.enabled);
assert_eq!(merged.default_rpm, 500);
assert_eq!(merged.strategy, RateLimitStrategy::SlidingWindow);
assert_eq!(merged.requests_per_second, Some(10));
assert_eq!(merged.burst_size, Some(20));
}
#[test]
fn test_rate_limit_config_merge_no_change() {
let base = RateLimitConfig::default();
let other = RateLimitConfig::default();
let merged = base.merge(other);
assert!(!merged.enabled);
assert_eq!(merged.default_rpm, 1000);
assert_eq!(merged.default_tpm, 100_000);
}
#[test]
fn test_rate_limit_config_skip_none_serialization() {
let config = RateLimitConfig {
enabled: false,
strategy: RateLimitStrategy::TokenBucket,
default_rpm: 1000,
default_tpm: 100_000,
requests_per_second: None,
requests_per_minute: Some(1000),
tokens_per_minute: None,
burst_size: None,
};
let json = serde_json::to_value(&config).unwrap();
let obj = json.as_object().unwrap();
assert!(!obj.contains_key("requests_per_second"));
assert!(!obj.contains_key("tokens_per_minute"));
assert!(!obj.contains_key("burst_size"));
assert!(obj.contains_key("requests_per_minute"));
}
}