use std::fmt;
use std::num::NonZeroU32;
use std::str::FromStr;
use std::time::Duration;
use super::rate_unit::SpeedRateUnit;
use super::TimeUnit;
/// A bandwidth limit such as `10mb/s`: `rate` amounts of `unit` per `time_unit`.
///
/// The textual form `{rate}{unit}/{time_unit}` is what `Display` produces and
/// what `FromStr` / the serde impls for this type accept.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct BandwidthQuota {
    /// How many `unit`s are allowed per `time_unit`; guaranteed non-zero.
    pub rate: NonZeroU32,
    /// Data-size unit that `rate` is expressed in (parsing accepts e.g. kb, mb, gb).
    pub unit: SpeedRateUnit,
    /// Period over which `rate` applies (e.g. `s`, `d` in the textual form).
    pub time_unit: TimeUnit,
}
impl BandwidthQuota {
    /// Builds a [`governor::Quota`] whose cells are sized by `self.unit`.
    ///
    /// The allowance per `time_unit` is `rate * unit.multiplier()` cells. They are
    /// replenished one at a time at evenly spaced intervals across `time_unit`.
    /// The burst size defaults to that full per-period allowance. `burst_override`,
    /// if given, is expressed in the same unit as `rate` and is scaled by the same
    /// multiplier.
    ///
    /// Extremely large quotas degrade gracefully instead of overflowing or
    /// panicking:
    /// - cell counts saturate at `u32::MAX`;
    /// - the replenish interval is clamped to at least one nanosecond.
    pub fn to_governor_quota(&self, burst_override: Option<NonZeroU32>) -> governor::Quota {
        let rate_unit_mult = self.unit.multiplier().get();

        // Scale a user-facing amount into cells. A plain `*` would panic in debug
        // builds and wrap (possibly to zero) in release. Saturating a product of
        // two non-zero factors always stays non-zero.
        let to_cells = |amount: NonZeroU32| {
            NonZeroU32::new(amount.get().saturating_mul(rate_unit_mult))
                .expect("non-zero * non-zero saturates to a non-zero value")
        };

        let rate_cells = to_cells(self.rate);

        // `Duration / u32` truncates. Once `rate_cells` exceeds the number of
        // nanoseconds in `time_unit`, the period would be zero, which
        // `Quota::with_period` rejects. Clamp to the smallest representable period.
        let replenish_1_per =
            (Duration::from(self.time_unit) / rate_cells.get()).max(Duration::from_nanos(1));

        let burst_cells = burst_override.map_or(rate_cells, to_cells);

        governor::Quota::with_period(replenish_1_per)
            .expect("replenish_1_per is clamped to at least 1ns")
            .allow_burst(burst_cells)
    }
}
impl From<BandwidthQuota> for governor::Quota {
fn from(value: BandwidthQuota) -> Self {
value.to_governor_quota(None)
}
}
impl fmt::Display for BandwidthQuota {
    /// Renders as `{rate}{unit}/{time_unit}`, e.g. `10mb/s`. This is the same
    /// syntax that `FromStr` parses.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Self {
            rate,
            unit,
            time_unit,
        } = self;
        write!(f, "{rate}{unit}/{time_unit}")
    }
}
impl FromStr for BandwidthQuota {
    type Err = String;

    /// Parses `{rate}{unit}/{time}`, e.g. `"10mb/s"` or `"500kb/s"`.
    ///
    /// # Errors
    ///
    /// Returns a human-readable message, checked in this order, when:
    /// - the input does not contain exactly one `/`;
    /// - the time part is rejected by `TimeUnit::from_str`;
    /// - there are no leading digits;
    /// - the rate is zero or does not fit in a `u32`;
    /// - the unit is rejected by `SpeedRateUnit::from_str` (e.g. request units
    ///   such as `r`).
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        // Exactly one '/' is required; zero or several is a format error.
        let (rate_with_unit, time_str) = match s.split_once('/') {
            Some((head, tail)) if !tail.contains('/') => (head, tail),
            _ => {
                return Err(format!(
                    "Invalid bandwidth rate format: '{s}', expected {{rate}}{{unit}}/{{time}}"
                ));
            }
        };

        // The time unit is validated first so its error takes precedence.
        let time_unit = TimeUnit::from_str(time_str)?;

        // Leading ASCII digits form the rate; the remainder is the unit. ASCII
        // digits are single-byte, so the byte offset from `find` is a valid split point.
        let boundary = rate_with_unit
            .find(|c: char| !c.is_ascii_digit())
            .unwrap_or(rate_with_unit.len());
        let (rate_str, unit_str) = rate_with_unit.split_at(boundary);
        if rate_str.is_empty() {
            return Err(format!("Missing rate value in '{rate_with_unit}'"));
        }

        let rate = match rate_str.parse::<NonZeroU32>() {
            Ok(value) => value,
            Err(_) => return Err(format!("Failed to parse rate from '{rate_str}'")),
        };

        let unit = match SpeedRateUnit::from_str(unit_str) {
            Ok(value) => value,
            Err(_) => {
                return Err(format!(
                    "BandwidthQuota does not accept '{unit_str}' as a unit. \
                     Use a bandwidth unit (kb, mb, gb), e.g. \"10mb/s\"."
                ));
            }
        };

        Ok(Self {
            rate,
            unit,
            time_unit,
        })
    }
}
impl serde::Serialize for BandwidthQuota {
    /// Serializes as the `Display` string, e.g. `"10mb/s"`.
    ///
    /// `collect_str` lets serializers that support it stream the `Display`
    /// output directly, instead of first allocating an intermediate `String`.
    /// The default implementation falls back to `serialize_str(&to_string())`.
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: serde::Serializer,
    {
        serializer.collect_str(self)
    }
}
impl<'de> serde::Deserialize<'de> for BandwidthQuota {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: serde::Deserializer<'de>,
{
let s = String::deserialize(deserializer)?;
BandwidthQuota::from_str(&s).map_err(serde::de::Error::custom)
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses `input`, failing the test with context if it is rejected.
    fn parse_ok(input: &str) -> BandwidthQuota {
        BandwidthQuota::from_str(input)
            .unwrap_or_else(|err| panic!("expected '{input}' to parse: {err}"))
    }

    #[test]
    fn test_valid_bandwidth_rates() {
        for input in ["10mb/s", "1gb/s", "500kb/s", "500mb/d"] {
            assert_eq!(parse_ok(input).to_string(), input);
        }
    }

    #[test]
    fn test_request_units_rejected() {
        match BandwidthQuota::from_str("100r/m") {
            Ok(quota) => panic!("request units must be rejected, got {quota}"),
            Err(err) => assert!(err.contains("does not accept"), "error: {err}"),
        }
    }

    #[test]
    fn test_display_roundtrip() {
        let original = parse_ok("10mb/s");
        let rendered = original.to_string();
        assert_eq!(rendered, "10mb/s");
        let reparsed: BandwidthQuota = rendered.parse().unwrap();
        assert_eq!(original, reparsed);
    }

    #[test]
    fn test_serde_roundtrip() {
        let original = parse_ok("10mb/s");
        let encoded = serde_json::to_string(&original).unwrap();
        assert_eq!(encoded, "\"10mb/s\"");
        let decoded = serde_json::from_str::<BandwidthQuota>(&encoded).unwrap();
        assert_eq!(original, decoded);
    }

    #[test]
    fn test_serde_rejects_request_units() {
        assert!(serde_json::from_str::<BandwidthQuota>("\"100r/m\"").is_err());
    }

    #[test]
    fn test_converts_to_governor_quota() {
        let _quota = governor::Quota::from(parse_ok("10mb/s"));
    }
}