#![forbid(unsafe_code)]
use core::time::Duration;
use oorandom::Rand32;
use std::cmp::max;
use std::collections::hash_map::Entry;
use std::collections::HashMap;
use std::fmt::{Display, Formatter};
use std::hash::Hash;
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};
use std::time::Instant;
/// Extracts the IPv4 address embedded in an IPv4-mapped IPv6 address.
///
/// Returns `Some` only when the first ten octets are zero and octets ten and
/// eleven are both `0xff` (the `::ffff:a.b.c.d` layout); every other address
/// yields `None`.
#[must_use]
const fn to_ipv4_mapped(addr: &Ipv6Addr) -> Option<Ipv4Addr> {
    if let [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0xff, 0xff, a, b, c, d] = addr.octets() {
        Some(Ipv4Addr::new(a, b, c, d))
    } else {
        None
    }
}
/// Map key for an IP address.
///
/// Every address is stored as an `Ipv6Addr`; IPv4 addresses are stored in
/// their IPv4-mapped IPv6 form, so an IPv4 address and its mapped IPv6
/// equivalent produce equal keys.
#[derive(Copy, Clone, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
pub struct IpAddrKey(Ipv6Addr);
impl IpAddrKey {
    /// Builds a key from either address family.
    #[must_use]
    pub fn new(ip_addr: IpAddr) -> Self {
        let canonical = match ip_addr {
            IpAddr::V4(v4) => v4.to_ipv6_mapped(),
            IpAddr::V6(v6) => v6,
        };
        Self(canonical)
    }
}
impl From<IpAddr> for IpAddrKey {
    /// Equivalent to [`IpAddrKey::new`].
    fn from(addr: IpAddr) -> Self {
        match addr {
            IpAddr::V4(v4) => Self::from(v4),
            IpAddr::V6(v6) => Self::from(v6),
        }
    }
}
impl From<Ipv4Addr> for IpAddrKey {
    /// Stores the IPv4 address in its IPv4-mapped IPv6 form.
    fn from(addr: Ipv4Addr) -> Self {
        Self::new(IpAddr::V4(addr))
    }
}
impl From<Ipv6Addr> for IpAddrKey {
    /// Stores the IPv6 address unchanged.
    fn from(addr: Ipv6Addr) -> Self {
        Self::new(IpAddr::V6(addr))
    }
}
impl Display for IpAddrKey {
    /// Writes IPv4-mapped keys in IPv4 notation and all other keys in IPv6
    /// notation.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        match to_ipv4_mapped(&self.0) {
            Some(v4) => write!(f, "{}", v4),
            None => write!(f, "{}", self.0),
        }
    }
}
/// In-place addition that never overflows.
trait SaturatingAddAssign<T> {
    /// Adds `rhs` into `self` in place.
    fn saturating_add_assign(&mut self, rhs: T);
}
impl SaturatingAddAssign<u32> for u32 {
    /// Adds `rhs`, clamping the result at `u32::MAX`.
    fn saturating_add_assign(&mut self, rhs: u32) {
        let clamped_sum = u32::saturating_add(*self, rhs);
        *self = clamped_sum;
    }
}
/// Decides whether a request is accepted (`true`) or rejected (`false`).
///
/// The load is `recent_cost / max_cost`:
/// - at or above 100 % load, or when `max_cost` is zero, the request is
///   rejected without calling `rand_float`;
/// - at or below 75 % load it is accepted without calling `rand_float`;
/// - in between, `rand_float` is called once and the request is accepted only
///   if the drawn value is greater than `((load - 0.75) * 4)^2`.
///
/// NOTE(review): `rand_float` is presumably expected to return uniform values
/// in `[0, 1)`, which would make the threshold a rejection probability.
fn decide(recent_cost: u32, max_cost: u32, mut rand_float: impl FnMut() -> f32) -> bool {
    // This also covers `max_cost == 0`, because every `u32` is `>= 0`.
    if recent_cost >= max_cost {
        return false;
    }
    let load = f64::from(recent_cost) / f64::from(max_cost);
    // Rescale the 75 %..100 % load band onto 0..1.
    let ramp = (load - 0.75) * 4.0;
    if ramp <= 0.0 {
        return true;
    }
    let reject_threshold = ramp.powi(2);
    let draw = f64::from(rand_float());
    reject_threshold < draw
}
#[cfg(test)]
#[test]
#[allow(clippy::unreadable_literal)]
fn test_decide() {
    // Outcomes that must be settled without drawing a random number:
    // (recent_cost, max_cost, expected).
    let settled: &[(u32, u32, bool)] = &[
        (0, 0, false),
        (0, 100, true),
        (50, 100, true),
        (75, 100, true),
        (100, 100, false),
        (101, 100, false),
    ];
    for &(recent, limit, expected) in settled {
        assert_eq!(
            expected,
            decide(recent, limit, || unreachable!()),
            "decide({}, {}, _)",
            recent,
            limit
        );
    }
    // Outcomes in the probabilistic band, all against a limit of 100:
    // (recent_cost, drawn value, expected).
    let drawn: &[(u32, f32, bool)] = &[
        (76, 0.999999, true),
        (76, 0.0, false),
        (85, 0.15, false),
        (85, 0.17, true),
        (90, 0.35, false),
        (90, 0.37, true),
        (95, 0.63, false),
        (95, 0.65, true),
        (99, 0.92, false),
        (99, 0.93, true),
    ];
    for &(recent, draw, expected) in drawn {
        assert_eq!(
            expected,
            decide(recent, 100, || draw),
            "decide({}, 100, || {})",
            recent,
            draw
        );
    }
}
/// Computes the recent-cost ceiling applied to a single tracked source.
///
/// `recent_cost / sources_max` is the combined load of all tracked sources:
/// - `sources_max == 0` yields `0`;
/// - no keys, or a load of at most 75 %, yields the full `sources_max`;
/// - a load above 100 % yields an equal share, `sources_max / keys`;
/// - in between, the ceiling shrinks linearly from `sources_max` towards the
///   equal share as the load rises from 75 % to 100 %.
///
/// Fractional results are truncated towards zero.
#[allow(clippy::cast_possible_truncation, clippy::cast_sign_loss)]
fn max_cost(sources_max: u32, recent_cost: u32, keys: u32) -> u32 {
    if sources_max == 0 {
        return 0;
    }
    if keys == 0 {
        return sources_max;
    }
    let budget = f64::from(sources_max);
    let load = f64::from(recent_cost) / budget;
    if load > 1.0 {
        return (budget / f64::from(keys)) as u32;
    }
    if load > 0.75 {
        // Position within the 75 %..100 % band, rescaled onto 0..1.
        let x = (load - 0.75) * 4.0;
        let equal_share = 1.0 / f64::from(keys);
        let reduction = (1.0 - equal_share) * x;
        return (budget * (1.0 - reduction)) as u32;
    }
    sources_max
}
#[cfg(test)]
#[test]
fn test_max_cost() {
    // Each row is (expected, sources_max, recent_cost, keys).
    let cases: &[(u32, u32, u32, u32)] = &[
        (100, 100, 0, 0),
        (100, 100, 0, 1),
        (100, 100, 1, 1),
        (100, 100, 100, 1),
        (100, 100, 0, 2),
        (100, 100, 75, 2),
        (98, 100, 76, 2),
        (70, 100, 90, 2),
        (52, 100, 99, 2),
        (50, 100, 100, 2),
        (50, 100, 150, 2),
        (1000, 1000, 0, 10),
        (1000, 1000, 750, 10),
        (996, 1000, 751, 10),
        (459, 1000, 900, 10),
        (103, 1000, 999, 10),
        (99, 1000, 1000, 10),
        (100, 1000, 1500, 10),
    ];
    for &(expected, sources_max, recent_cost, key_count) in cases {
        assert_eq!(
            expected,
            max_cost(sources_max, recent_cost, key_count),
            "max_cost({}, {}, {})",
            sources_max,
            recent_cost,
            key_count
        );
    }
}
/// Accumulator of recently recorded cost that halves once per elapsed tick.
#[derive(Clone, Copy, Debug)]
struct RecentCosts {
    /// Decayed sum of the costs recorded so far.
    cost: u32,
    /// Instant up to which decay has already been applied.
    last: Instant,
}
impl RecentCosts {
    /// Creates an empty accumulator whose decay clock starts at `now`.
    #[must_use]
    pub fn new(now: Instant) -> Self {
        Self {
            cost: 0_u32,
            last: now,
        }
    }
    /// Returns `true` when no cost remains recorded.
    pub fn is_empty(&self) -> bool {
        self.cost == 0
    }
    /// Records `cost`, saturating at `u32::MAX`.
    pub fn add(&mut self, cost: u32) {
        self.cost.saturating_add_assign(cost);
    }
    /// Applies the decay for every whole tick elapsed between the last update
    /// and `now`: the recorded cost is halved once per tick.
    ///
    /// A partially elapsed tick is carried over to the next call. A `now`
    /// earlier than the last update, or a zero `tick_duration`, leaves the
    /// accumulator unchanged.
    pub fn update(&mut self, tick_duration: Duration, now: Instant) {
        // Work in nanoseconds so that sub-millisecond ticks do not become a
        // zero divisor.
        let tick_nanos = tick_duration.as_nanos();
        if tick_nanos == 0 {
            return;
        }
        let elapsed = now.saturating_duration_since(self.last);
        // Clamp instead of truncating: the cast below is lossless after `min`.
        #[allow(clippy::cast_possible_truncation)]
        let elapsed_ticks = (elapsed.as_nanos() / tick_nanos).min(u128::from(u32::MAX)) as u32;
        // Advance only by whole ticks so the remainder counts next time.
        self.last += tick_duration * elapsed_ticks;
        // `checked_shr` rather than `wrapping_shr`: a shift of 32 or more bits
        // must clear the cost, not wrap the shift amount back around to 0..31.
        self.cost = self.cost.checked_shr(elapsed_ticks).unwrap_or(0);
    }
    /// Returns the recorded cost as of the last `update`/`add`.
    #[must_use]
    pub fn recent_cost(&self) -> u32 {
        self.cost
    }
}
/// A key paired with the recent costs recorded for it.
#[derive(Clone, Copy, Debug)]
struct Source<K> {
    /// Key identifying the source.
    pub key: K,
    /// Recent costs recorded for this source.
    pub costs: RecentCosts,
}
impl<K> Source<K> {
    /// Creates a source for `key` with an empty cost record starting at `now`.
    pub fn new(key: K, now: Instant) -> Self {
        let costs = RecentCosts::new(now);
        Self { key, costs }
    }
}
/// Rate limiter that tracks up to `MAX_KEYS` sources individually and pools
/// every other source into one shared "untracked" budget.
///
/// Accepted costs are recorded in `RecentCosts` accumulators, which are aged
/// by whole ticks of `tick_duration` each time they are consulted.
#[derive(Clone, Debug)]
pub struct FairRateLimiter<K: Clone + Copy + Eq + Hash, const MAX_KEYS: usize> {
    /// Length of one tick, the unit used when aging recorded costs.
    tick_duration: Duration,
    /// Ceiling on the recent cost of tracked sources; twice the per-tick
    /// value passed to `new`.
    sources_max: u32,
    /// Ceiling on the recent cost of untracked sources; twice the per-tick
    /// value passed to `new`.
    other_max: u32,
    /// Random number source for probabilistic decisions and slot selection.
    prng: Rand32,
    /// Recent cost accepted through the tracked-source paths.
    sources_costs: RecentCosts,
    /// Maps each tracked key to the index of its slot in `sources`.
    keys: HashMap<K, usize>,
    /// Fixed table of `MAX_KEYS` slots, each empty or holding one tracked
    /// source.
    sources: Box<[Option<Source<K>>]>,
    /// Recent cost accepted from untracked sources that did not take a slot.
    other_costs: RecentCosts,
}
impl<Key: Clone + Copy + Eq + Hash, const MAX_KEYS: usize> FairRateLimiter<Key, MAX_KEYS> {
    /// Creates a limiter with no tracked sources and empty cost records
    /// starting at `now`.
    ///
    /// Both per-tick maxima are doubled (saturating at `u32::MAX`) to form the
    /// ceilings that recent costs are compared against.
    /// NOTE(review): the doubling presumably matches the cost record halving
    /// every tick, so a steady `c` per tick settles near `2c` — confirm.
    ///
    /// # Errors
    ///
    /// Returns an error when `tick_duration` is shorter than one microsecond,
    /// or when `MAX_KEYS` is zero or does not fit in a `u32`.
    pub fn new(
        tick_duration: Duration,
        max_cost_per_tick_from_tracked_sources: u32,
        max_cost_per_tick_from_untracked_sources: u32,
        prng: Rand32,
        now: Instant,
    ) -> Result<Self, String> {
        if tick_duration.as_micros() == 0 {
            return Err(format!("tick_duration too small: {:?}", tick_duration));
        }
        // `check` picks a slot index in `0..MAX_KEYS as u32` and indexes the
        // slot table with it; an empty table would make that indexing panic,
        // and a count above `u32::MAX` would be truncated by the cast.
        if MAX_KEYS == 0 || (MAX_KEYS as u64) > u64::from(u32::MAX) {
            return Err(format!("MAX_KEYS is out of range: {}", MAX_KEYS));
        }
        Ok(Self {
            tick_duration,
            // Saturate so that very large limits cannot overflow.
            sources_max: max_cost_per_tick_from_tracked_sources.saturating_mul(2),
            other_max: max_cost_per_tick_from_untracked_sources.saturating_mul(2),
            prng,
            sources_costs: RecentCosts::new(now),
            keys: HashMap::with_capacity(MAX_KEYS),
            sources: vec![None; MAX_KEYS].into_boxed_slice(),
            other_costs: RecentCosts::new(now),
        })
    }
    /// Decides whether a request of `cost` from `key` at time `now` is
    /// accepted, recording the cost when it is.
    ///
    /// Returns `true` when the request is accepted and `false` when it is
    /// rejected.
    ///
    /// - A tracked key is judged against a per-source ceiling derived from
    ///   the combined load of all tracked sources and the number of tracked
    ///   keys.
    /// - An untracked key is first judged against the untracked budget. If it
    ///   passes, it competes for a randomly chosen slot: it takes an empty
    ///   slot, or evicts the occupant unless the occupant's recent cost
    ///   exceeds this request's randomly weighted cost.
    #[allow(clippy::missing_panics_doc)]
    pub fn check(&mut self, key: Key, cost: u32, now: Instant) -> bool {
        self.sources_costs.update(self.tick_duration, now);
        #[allow(clippy::cast_possible_truncation)]
        let num_keys = self.keys.len() as u32;
        match self.keys.entry(key) {
            Entry::Occupied(entry) => {
                let index = *entry.get();
                // Invariant: every index stored in `keys` names a filled slot.
                let source = self.sources[index]
                    .as_mut()
                    .expect("tracked key must point at an occupied slot");
                source.costs.update(self.tick_duration, now);
                let max_cost =
                    max_cost(self.sources_max, self.sources_costs.recent_cost(), num_keys);
                let rand_float = || self.prng.rand_float();
                if decide(source.costs.recent_cost(), max_cost, rand_float) {
                    self.sources_costs.add(cost);
                    source.costs.add(cost);
                    true
                } else {
                    // A rejected source with nothing left on record frees its
                    // slot.
                    if source.costs.is_empty() {
                        entry.remove();
                        self.sources[index] = None;
                    }
                    false
                }
            }
            Entry::Vacant(entry) => {
                self.other_costs.update(self.tick_duration, now);
                let recent_cost = self.other_costs.recent_cost();
                if !decide(recent_cost, self.other_max, || self.prng.rand_float()) {
                    return false;
                }
                // From here on the request is accepted; the remaining logic
                // only decides which cost record is charged.
                let mut new_source = Source::new(*entry.key(), now);
                new_source.costs.add(cost);
                #[allow(clippy::cast_possible_truncation)]
                let index = self.prng.rand_range(0..(MAX_KEYS as u32)) as usize;
                if let Some(source) = &mut self.sources[index] {
                    source.costs.update(self.tick_duration, now);
                    // Random weight for the newcomer: 1 in 90 % of draws, then
                    // 10, 100, 1_000 and 10_000 with decreasing frequency.
                    let coefficient: u32 = match self.prng.rand_range(0..10_000u32) {
                        0 => 10_000,
                        x if x < 10 => 1_000,
                        x if x < 100 => 100,
                        x if x < 1000 => 10,
                        _ => 1,
                    };
                    let adjusted_cost = coefficient.saturating_mul(cost);
                    if adjusted_cost < source.costs.recent_cost() {
                        // The occupant keeps its slot; charge the untracked
                        // budget.
                        self.other_costs.add(cost);
                        true
                    } else {
                        // Evict the occupant and hand its slot to the new key.
                        entry.insert(index);
                        self.keys.remove(&source.key);
                        *source = new_source;
                        self.sources_costs.add(cost);
                        true
                    }
                } else {
                    // Empty slot: start tracking the new key.
                    self.sources[index] = Some(new_source);
                    entry.insert(index);
                    self.sources_costs.add(cost);
                    true
                }
            }
        }
    }
}
/// Builds a [`FairRateLimiter`] keyed by [`IpAddrKey`] with 1000 slots and a
/// one-second tick.
///
/// 20 % of `max_cost_per_sec` (truncated, but at least 1) becomes the budget
/// for untracked sources; the remainder of the truncated `max_cost_per_sec`
/// becomes the budget for tracked sources. The random number generator is
/// seeded with the constant `0`, and the limiter's clock starts at
/// `Instant::now()`.
///
/// # Errors
///
/// Returns an error when `max_cost_per_sec` is not `0.0` but leaves no budget
/// for tracked sources, and passes on any error from `FairRateLimiter::new`.
pub fn new_fair_ip_address_rate_limiter(
    max_cost_per_sec: f32,
) -> Result<FairRateLimiter<IpAddrKey, 1000>, String> {
    #[allow(clippy::cast_possible_truncation, clippy::cast_sign_loss)]
    let untracked_share = (max_cost_per_sec * 0.20) as u32;
    let other_max = max(untracked_share, 1);
    #[allow(clippy::cast_possible_truncation, clippy::cast_sign_loss)]
    let total = max_cost_per_sec as u32;
    let sources_max = total.saturating_sub(other_max);
    let no_tracked_budget = sources_max == 0;
    if no_tracked_budget && max_cost_per_sec != 0.0 {
        return Err(format!(
            "max_cost_per_sec is too small: {:?}",
            max_cost_per_sec
        ));
    }
    let tick = Duration::from_secs(1);
    let prng = Rand32::new(0);
    FairRateLimiter::new(tick, sources_max, other_max, prng, Instant::now())
}