use std::fmt;
use std::str::FromStr;
use std::{num::NonZeroU32, time::Duration};
use super::{RateUnit, TimeUnit};
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct QuotaValue {
pub rate: NonZeroU32,
pub rate_unit: RateUnit,
pub time_unit: TimeUnit,
}
impl From<QuotaValue> for governor::Quota {
fn from(value: QuotaValue) -> Self {
let rate_count = value.rate.get();
let rate_unit = value.rate_unit.multiplier().get();
let rate = NonZeroU32::new(rate_count * rate_unit)
.expect("Is always non-zero because rate count and rate unit multiplier are non-zero");
let time_unit = Duration::from_secs(value.time_unit.multiplier_in_seconds().get() as u64);
let replenish_1_per = time_unit / rate.get();
let base_quota = governor::Quota::with_period(replenish_1_per)
.expect("Is always non-zero because replenish_1_per is non-zero");
base_quota.allow_burst(rate)
}
}
impl fmt::Display for QuotaValue {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "{}{}/{}", self.rate, self.rate_unit, self.time_unit)
}
}
impl FromStr for QuotaValue {
type Err = String;
fn from_str(s: &str) -> Result<Self, Self::Err> {
let rate_parts: Vec<&str> = s.split('/').collect();
if rate_parts.len() != 2 {
return Err(format!(
"Invalid rate format: '{}', expected {{rate}}{{unit}}/{{time}}",
s
));
}
let rate_with_unit = rate_parts[0];
let time_unit = TimeUnit::from_str(rate_parts[1])?;
let rate_digit_end = rate_with_unit
.chars()
.position(|c| !c.is_ascii_digit())
.unwrap_or(rate_with_unit.len());
if rate_digit_end == 0 {
return Err(format!("Missing rate value in '{}'", rate_with_unit));
}
let rate_str = &rate_with_unit[..rate_digit_end];
let rate_unit_str = &rate_with_unit[rate_digit_end..];
let rate = rate_str
.parse::<NonZeroU32>()
.map_err(|_| format!("Failed to parse rate from '{}'", rate_str))?;
let rate_unit = RateUnit::from_str(rate_unit_str)?;
Ok(QuotaValue {
rate,
rate_unit,
time_unit,
})
}
}
impl serde::Serialize for QuotaValue {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: serde::Serializer,
{
serializer.serialize_str(&self.to_string())
}
}
impl<'de> serde::Deserialize<'de> for QuotaValue {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: serde::Deserializer<'de>,
{
let s = String::deserialize(deserializer)?;
QuotaValue::from_str(&s).map_err(serde::de::Error::custom)
}
}
#[cfg(test)]
mod tests {
use crate::quota_config::rate_unit::SpeedRateUnit;
use super::*;
#[test]
fn test_get_quota() {
let quota = QuotaValue::from_str("5r/s").unwrap();
assert_eq!(
governor::Quota::from(quota),
governor::Quota::per_second(NonZeroU32::new(5).unwrap())
);
let quota = QuotaValue::from_str("5r/m").unwrap();
assert_eq!(
governor::Quota::from(quota),
governor::Quota::per_minute(NonZeroU32::new(5).unwrap())
);
let quota = QuotaValue::from_str("5kb/s").unwrap();
assert_eq!(
governor::Quota::from(quota),
governor::Quota::per_second(NonZeroU32::new(5).unwrap())
);
let quota = QuotaValue::from_str("5mb/m").unwrap();
assert_eq!(
governor::Quota::from(quota),
governor::Quota::per_minute(NonZeroU32::new(5 * 1024).unwrap())
);
}
#[test]
fn test_quota_value_from_str() {
let quota = QuotaValue::from_str("5r/s").unwrap();
assert_eq!(quota.rate, NonZeroU32::new(5).unwrap());
assert_eq!(quota.rate_unit, RateUnit::Request);
assert_eq!(quota.time_unit, TimeUnit::Second);
let quota = QuotaValue::from_str("10mb/m").unwrap();
assert_eq!(quota.rate, NonZeroU32::new(10).unwrap());
assert_eq!(
quota.rate_unit,
RateUnit::SpeedRateUnit(SpeedRateUnit::Megabyte)
);
assert_eq!(quota.time_unit, TimeUnit::Minute);
}
#[test]
fn test_quota_value_display() {
let quota = QuotaValue {
rate: NonZeroU32::new(5).unwrap(),
rate_unit: RateUnit::Request,
time_unit: TimeUnit::Second,
};
assert_eq!(quota.to_string(), "5r/s");
let quota = QuotaValue {
rate: NonZeroU32::new(10).unwrap(),
rate_unit: RateUnit::SpeedRateUnit(SpeedRateUnit::Megabyte),
time_unit: TimeUnit::Minute,
};
assert_eq!(quota.to_string(), "10mb/m");
}
#[test]
fn test_quota_value_invalid_formats() {
assert!(QuotaValue::from_str("5rs").is_err());
assert!(QuotaValue::from_str("r/s").is_err());
assert!(QuotaValue::from_str("5x/s").is_err());
assert!(QuotaValue::from_str("5r/x").is_err());
assert!(QuotaValue::from_str("5r/s-2burst").is_err());
}
}