use crate::email_address::EmailAddress;
use serde::de::{Deserialize, Deserializer, Error as DeError};
use std::{net::IpAddr, num::ParseIntError, str::FromStr, time::Duration};
use thiserror::Error;
/// Errors returned when parsing a textual rate limit specification into a [`LimitConfig`].
#[derive(Debug, Error, Eq, PartialEq)]
pub enum LimitConfigError {
/// The rate segment (the text after the last `:`) contains no `/` between count and window.
#[error("rate limit must contain a '/' fraction separator")]
NoSeparator,
/// The window part (the text after `/`) is empty or consists only of digits.
#[error("rate limit window is missing a time unit")]
NoWindowUnit,
/// The window's unit suffix is not one of the recognized names; carries the offending suffix.
#[error("rate limit window has an invalid unit: {0}")]
InvalidUnit(String),
/// The count before `/` could not be parsed as an integer; carries the underlying parse error.
#[error("could not parse rate limit count as integer: {0}")]
InvalidCount(ParseIntError),
/// A `:`-separated segment before the rate is not a recognized keyword; carries that segment.
#[error("rate limit contains an invalid keyword: {0}")]
InvalidKeyword(String),
}
/// A single rate limit rule, usually produced by parsing a string such as
/// `email:decr_complete:11/2min` through the `FromStr` implementation.
#[derive(Clone, Debug, Default, Eq, PartialEq)]
// Each flag corresponds to one independent keyword of the textual format, hence the many bools.
#[allow(clippy::struct_excessive_bools)]
pub struct LimitConfig {
/// Identifier written right after the prefix in keys built by `LimitInput::build_key`.
/// The parser always sets this to `0`.
pub id: usize,
/// Set by the `email` keyword; the key then includes the full email address.
pub with_email_addr: bool,
/// Set by the `domain` keyword; the key then includes the domain of the email address.
pub with_email_domain: bool,
/// Set by the `origin` keyword; the key then includes the origin.
pub with_origin: bool,
/// Set by the `ip` keyword; the key then includes the IP address.
pub with_ip: bool,
/// Set by the `extend_window` keyword.
/// NOTE(review): not consumed in this file; presumably restarts the window on every hit — confirm.
pub extend_window: bool,
/// Set by the `decr_complete` keyword.
/// NOTE(review): not consumed in this file; presumably decrements the counter once a request
/// completes — confirm.
pub decr_complete: bool,
/// The count parsed from before the `/` of the rate.
pub max_count: usize,
/// The window parsed from after the `/` of the rate, converted to whole seconds.
pub window: Duration,
}
impl FromStr for LimitConfig {
type Err = LimitConfigError;
fn from_str(value: &str) -> Result<Self, Self::Err> {
let mut iter = value.rsplit(':');
let rate = iter.next().unwrap();
let rate_sep = rate.find('/').ok_or(LimitConfigError::NoSeparator)?;
let max_count = rate[..rate_sep]
.parse()
.map_err(LimitConfigError::InvalidCount)?;
let window_str = &rate[(rate_sep + 1)..];
let window_split = window_str
.find(|c: char| !c.is_digit(10))
.ok_or(LimitConfigError::NoWindowUnit)?;
let mut window = if window_split == 0 {
1
} else {
window_str[..window_split].parse().unwrap()
};
match &window_str[window_split..] {
"s" | "sec" | "secs" | "second" | "seconds" => {}
"m" | "min" | "mins" | "minute" | "minutes" => window *= 60,
"h" | "hour" | "hours" => window *= 3600,
"d" | "day" | "days" => window *= 86400,
unit => {
return Err(LimitConfigError::InvalidUnit(unit.to_owned()));
}
}
let mut config = LimitConfig {
id: 0,
with_email_addr: false,
with_email_domain: false,
with_origin: false,
with_ip: false,
extend_window: false,
decr_complete: false,
max_count,
window: Duration::from_secs(window),
};
for keyword in iter {
match keyword {
"ip" => config.with_ip = true,
"email" => config.with_email_addr = true,
"domain" => config.with_email_domain = true,
"origin" => config.with_origin = true,
"extend_window" => config.extend_window = true,
"decr_complete" => config.decr_complete = true,
_ => {
return Err(LimitConfigError::InvalidKeyword(keyword.to_owned()));
}
}
}
Ok(config)
}
}
// NOTE(review): `serde_from_str!` is defined outside this file; presumably it generates a serde
// `Deserialize` impl for `LimitConfig` that delegates to its `FromStr` impl — confirm.
serde_from_str!(LimitConfig);
/// The values from which rate limit keys are built (see `LimitInput::build_key`).
pub struct LimitInput {
/// Email address; both its full form and its domain can be included in a key.
pub email_addr: EmailAddress,
/// Origin string, included in a key verbatim when the config asks for it.
pub origin: String,
/// IP address, included in a key in its `Display` form when the config asks for it.
pub ip: IpAddr,
}
impl LimitInput {
    /// Builds the key under which hits for `config` are tracked for this input.
    ///
    /// The key starts with `prefix` immediately followed by `config.id`. Each component
    /// enabled in `config` is then appended, preceded by `sep`, always in this order:
    /// IP address, full email address, email domain, origin.
    pub fn build_key(&self, config: &LimitConfig, prefix: &str, sep: &str) -> String {
        let mut key = format!("{}{}", prefix, config.id);

        // Appends one separator-delimited component to the key.
        let mut append = |component: &str| {
            key.push_str(sep);
            key.push_str(component);
        };

        if config.with_ip {
            append(&self.ip.to_string());
        }
        if config.with_email_addr {
            append(self.email_addr.as_str());
        }
        if config.with_email_domain {
            append(self.email_addr.domain());
        }
        if config.with_origin {
            append(&self.origin);
        }

        key
    }
}
/// Wrapper around [`LimitConfig`] whose `Deserialize` impl prepends `email:` to the
/// deserialized string before parsing it, so the wrapped config always has
/// `with_email_addr` set.
#[derive(Clone)]
pub struct LegacyLimitPerEmail(pub LimitConfig);
impl<'de> Deserialize<'de> for LegacyLimitPerEmail {
fn deserialize<D: Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
let input = format!("email:{}", String::deserialize(deserializer)?);
Ok(Self(input.parse().map_err(DeError::custom)?))
}
}
#[cfg(test)]
mod tests {
    use super::LimitConfig;
    use std::time::Duration;

    /// Asserts that `spec` parses successfully into exactly `expected`.
    fn assert_parses_to(spec: &str, expected: LimitConfig) {
        assert_eq!(spec.parse(), Ok(expected));
    }

    /// Covers a bare rate plus every keyword and every unit family accepted by the parser.
    #[test]
    fn test_parse() {
        assert_parses_to(
            "10/s",
            LimitConfig {
                max_count: 10,
                window: Duration::from_secs(1),
                ..Default::default()
            },
        );
        assert_parses_to(
            "email:decr_complete:11/2min",
            LimitConfig {
                with_email_addr: true,
                decr_complete: true,
                max_count: 11,
                window: Duration::from_secs(120),
                ..Default::default()
            },
        );
        assert_parses_to(
            "domain:30/h",
            LimitConfig {
                with_email_domain: true,
                max_count: 30,
                window: Duration::from_secs(3600),
                ..Default::default()
            },
        );
        assert_parses_to(
            "origin:200/day",
            LimitConfig {
                with_origin: true,
                max_count: 200,
                window: Duration::from_secs(86400),
                ..Default::default()
            },
        );
        assert_parses_to(
            "ip:extend_window:5/second",
            LimitConfig {
                with_ip: true,
                extend_window: true,
                max_count: 5,
                window: Duration::from_secs(1),
                ..Default::default()
            },
        );
    }
}