#![forbid(unsafe_code)]
use std::{
sync::atomic::{AtomicI32, AtomicU32, AtomicU64, Ordering},
time::Instant,
};
use crate::hash::SydRandomState;
/// Rate-limit bookkeeping for a single message hash.
struct RlimitState {
    /// Remaining log budget in the current window. It is decremented once per
    /// allowed message and never taken below zero.
    n_left: AtomicI32,
    /// Messages suppressed since this counter was last reported or cleared.
    missed: AtomicU32,
    /// Start of the current rate-limit window.
    begin: Instant,
}
/// Log rate limiter keyed by a caller-supplied 64-bit message hash.
///
/// The configuration is held in atomics so it can be changed at runtime
/// through `set_interval_ns` and `set_burst`.
pub(crate) struct LogRlimit {
    /// Maximum number of messages allowed per hash within one window.
    burst: AtomicU32,
    /// Window length in nanoseconds; zero disables rate limiting.
    interval_ns: AtomicU64,
    /// Per-hash window state.
    map: scc::HashMap<u64, RlimitState, SydRandomState>,
}
impl LogRlimit {
    /// Creates a limiter that allows up to `burst` messages per hash in each
    /// window of `interval_ns` nanoseconds.
    ///
    /// An `interval_ns` of zero disables limiting (every message is allowed).
    /// A `burst` of zero suppresses every message.
    pub(crate) fn new(interval_ns: u64, burst: u32) -> Self {
        Self {
            map: scc::HashMap::with_hasher(SydRandomState::new()),
            interval_ns: AtomicU64::new(interval_ns),
            burst: AtomicU32::new(burst),
        }
    }

    /// Narrows a nanosecond count to `u64`, saturating at `u64::MAX`.
    fn clamp_ns(ns: u128) -> u64 {
        u64::try_from(ns).unwrap_or(u64::MAX)
    }

    /// Nanoseconds from `begin` to `now`, or zero if `begin` is later.
    ///
    /// `now` is sampled before the entry lock is taken. A concurrent window
    /// reset or `set_interval_ns` may therefore have stored a `begin` that is
    /// later than `now`. This saturates explicitly instead of relying on
    /// `Instant::duration_since`, whose non-panicking behaviour in that case
    /// std does not guarantee.
    fn elapsed_ns(now: Instant, begin: Instant) -> u64 {
        Self::clamp_ns(now.saturating_duration_since(begin).as_nanos())
    }

    /// Atomically consumes one unit of budget from `n_left`.
    ///
    /// Returns `false`, leaving the counter unchanged, when the budget is
    /// exhausted. The counter never goes below zero.
    fn take_token(n_left: &AtomicI32) -> bool {
        n_left
            .fetch_update(Ordering::AcqRel, Ordering::Acquire, |n| {
                n.checked_sub(1).filter(|v| *v >= 0)
            })
            .is_ok()
    }

    /// Records one suppressed message.
    ///
    /// The counter saturates at `u32::MAX` instead of wrapping back to a
    /// misleadingly small count.
    fn note_missed(missed: &AtomicU32) {
        // `Err` here only means the counter is already saturated; nothing to do.
        let _ = missed.fetch_update(Ordering::AcqRel, Ordering::Acquire, |m| m.checked_add(1));
    }

    /// Decides whether a message with hash `msg_hash` should be logged now.
    ///
    /// Returns `(allowed, missed)`. `missed` is non-zero only when this call
    /// opens a new window for an already-known hash. It then reports how many
    /// messages the previous window suppressed, and the counter is reset.
    /// If the map refuses to reserve capacity, the message is suppressed
    /// without being counted.
    pub(crate) fn should_log(&self, msg_hash: u64) -> (bool, u32) {
        let interval_ns = self.interval_ns.load(Ordering::Acquire);
        if interval_ns == 0 {
            return (true, 0);
        }
        let burst = self.burst.load(Ordering::Acquire);
        if burst == 0 {
            return (false, 0);
        }
        let burst = i32::try_from(burst).unwrap_or(i32::MAX);
        let now = Instant::now();

        // Fast path: shared access to an existing, still-open window.
        // The closure yields `None` when the window has expired, because the
        // reset needs exclusive access. `read_sync` itself yields `None` when
        // the hash has no entry yet. Both cases fall through to the slow path.
        let fast = self.map.read_sync(&msg_hash, |_, state| {
            if Self::elapsed_ns(now, state.begin) > interval_ns {
                return None;
            }
            if Self::take_token(&state.n_left) {
                Some((true, 0))
            } else {
                Self::note_missed(&state.missed);
                Some((false, 0))
            }
        });
        if let Some(Some(result)) = fast {
            return result;
        }

        // Slow path: take exclusive access to the entry, then create it or
        // reset its window.
        let Some(_reserve) = self.map.reserve(1) else {
            return (false, 0);
        };
        match self.map.entry_sync(msg_hash) {
            scc::hash_map::Entry::Occupied(mut occ) => {
                let state = occ.get_mut();
                if Self::elapsed_ns(now, state.begin) > interval_ns {
                    // Open a new window, charge this message to it, and report
                    // what the previous window suppressed.
                    let missed = state.missed.swap(0, Ordering::AcqRel);
                    state.begin = now;
                    state
                        .n_left
                        .store(burst.saturating_sub(1), Ordering::Release);
                    (true, missed)
                } else if Self::take_token(&state.n_left) {
                    // The window is open, e.g. because another thread created or
                    // reset it after our fast-path read.
                    (true, 0)
                } else {
                    Self::note_missed(&state.missed);
                    (false, 0)
                }
            }
            scc::hash_map::Entry::Vacant(vac) => {
                // The first message for this hash opens a window and uses one unit.
                vac.insert_entry(RlimitState {
                    begin: now,
                    n_left: AtomicI32::new(burst.saturating_sub(1)),
                    missed: AtomicU32::new(0),
                });
                (true, 0)
            }
        }
    }

    /// Sets the window length and restarts every tracked window now.
    ///
    /// Each window gets a full `burst` budget. Suppressed-message counters are
    /// kept and are reported when the next window opens.
    pub(crate) fn set_interval_ns(&self, ns: u64) {
        self.interval_ns.store(ns, Ordering::Release);
        let burst = i32::try_from(self.burst.load(Ordering::Acquire)).unwrap_or(i32::MAX);
        // Sample the clock once so every restarted window shares the same start.
        let now = Instant::now();
        self.map.retain_sync(|_, state| {
            state.begin = now;
            state.n_left.store(burst, Ordering::Release);
            true
        });
    }

    /// Sets the per-window burst and refills every tracked budget to it.
    ///
    /// Window start times are left unchanged. Suppressed-message counters are
    /// cleared, so misses recorded before this call are never reported.
    pub(crate) fn set_burst(&self, burst: u32) {
        self.burst.store(burst, Ordering::Release);
        let burst = i32::try_from(burst).unwrap_or(i32::MAX);
        self.map.retain_sync(|_, state| {
            state.n_left.store(burst, Ordering::Release);
            state.missed.store(0, Ordering::Release);
            true
        });
    }
}
#[cfg(test)]
mod tests {
    use std::{
        sync::{Arc, Barrier},
        thread,
        time::Duration,
    };

    use super::*;

    /// Window length for the timing-based tests.
    ///
    /// It is long enough that scheduler preemption between back-to-back calls
    /// on a loaded CI machine does not expire the window by accident.
    const SHORT_INTERVAL_NS: u64 = 100_000_000;

    /// A sleep that reliably outlasts `SHORT_INTERVAL_NS`.
    const PAST_INTERVAL: Duration = Duration::from_millis(150);

    #[test]
    fn test_log_rlimit_1() {
        let rl = LogRlimit::new(5_000_000_000, 5);
        for _ in 0..5 {
            let (allowed, missed) = rl.should_log(42);
            assert!(allowed);
            assert_eq!(missed, 0);
        }
    }

    #[test]
    fn test_log_rlimit_2() {
        let rl = LogRlimit::new(5_000_000_000, 3);
        for _ in 0..3 {
            assert!(rl.should_log(42).0);
        }
        assert!(!rl.should_log(42).0);
    }

    #[test]
    fn test_log_rlimit_3() {
        let rl = LogRlimit::new(SHORT_INTERVAL_NS, 2);
        assert!(rl.should_log(42).0);
        assert!(rl.should_log(42).0);
        assert!(!rl.should_log(42).0);
        assert!(!rl.should_log(42).0);
        thread::sleep(PAST_INTERVAL);
        let (allowed, missed) = rl.should_log(42);
        assert!(allowed);
        assert_eq!(missed, 2);
    }

    #[test]
    fn test_log_rlimit_4() {
        let rl = LogRlimit::new(5_000_000_000, 2);
        assert!(rl.should_log(42).0);
        assert!(rl.should_log(42).0);
        assert!(!rl.should_log(42).0);
        assert!(rl.should_log(99).0);
        assert!(rl.should_log(99).0);
        assert!(!rl.should_log(99).0);
    }

    #[test]
    fn test_log_rlimit_5() {
        let rl = LogRlimit::new(5_000_000_000, 1);
        for h in 0u64..100 {
            assert!(rl.should_log(h).0);
        }
        for h in 0u64..100 {
            assert!(!rl.should_log(h).0);
        }
    }

    #[test]
    fn test_log_rlimit_6() {
        let rl = LogRlimit::new(0, 10);
        for _ in 0..100 {
            assert!(rl.should_log(42).0);
        }
    }

    #[test]
    fn test_log_rlimit_7() {
        let rl = LogRlimit::new(5_000_000_000, 0);
        assert!(!rl.should_log(42).0);
    }

    #[test]
    fn test_log_rlimit_8() {
        let rl = LogRlimit::new(5_000_000_000, 2);
        assert!(rl.should_log(42).0);
        assert!(rl.should_log(42).0);
        assert!(!rl.should_log(42).0);
        rl.set_burst(100);
        rl.set_interval_ns(0);
        assert!(rl.should_log(42).0);
    }

    #[test]
    fn test_log_rlimit_9() {
        let rl = LogRlimit::new(SHORT_INTERVAL_NS, 1);
        assert!(rl.should_log(42).0);
        for _ in 0..5 {
            assert!(!rl.should_log(42).0);
        }
        thread::sleep(PAST_INTERVAL);
        let (allowed, missed) = rl.should_log(42);
        assert!(allowed);
        assert_eq!(missed, 5);
        thread::sleep(PAST_INTERVAL);
        let (allowed, missed) = rl.should_log(42);
        assert!(allowed);
        assert_eq!(missed, 0);
    }

    #[test]
    fn test_log_rlimit_10() {
        const BURST: u32 = 5;
        const THREADS: usize = 32;
        const ITERS: usize = 200;
        let rl = Arc::new(LogRlimit::new(60_000_000_000, BURST));
        let barrier = Arc::new(Barrier::new(THREADS));
        let handles: Vec<_> = (0..THREADS)
            .map(|_| {
                let rl = Arc::clone(&rl);
                let barrier = Arc::clone(&barrier);
                thread::spawn(move || {
                    barrier.wait();
                    let mut allowed = 0i64;
                    for _ in 0..ITERS {
                        if rl.should_log(42).0 {
                            allowed = allowed.saturating_add(1);
                        }
                    }
                    allowed
                })
            })
            .collect();
        let total_allowed: i64 = handles.into_iter().map(|h| h.join().unwrap()).sum();
        assert!(total_allowed <= i64::from(BURST));
        assert!(total_allowed >= 1);
    }

    #[test]
    fn test_log_rlimit_11() {
        const BURST: u32 = 3;
        const THREADS_PER_HASH: usize = 16;
        const HASHES: usize = 4;
        const ITERS: usize = 100;
        let rl = Arc::new(LogRlimit::new(60_000_000_000, BURST));
        let total_threads = THREADS_PER_HASH.saturating_mul(HASHES);
        let barrier = Arc::new(Barrier::new(total_threads));
        let mut handles = Vec::with_capacity(total_threads);
        for h in 0..HASHES {
            for _ in 0..THREADS_PER_HASH {
                let rl = Arc::clone(&rl);
                let barrier = Arc::clone(&barrier);
                handles.push(thread::spawn(move || {
                    barrier.wait();
                    let mut allowed = 0i64;
                    for _ in 0..ITERS {
                        if rl.should_log(h as u64).0 {
                            allowed = allowed.saturating_add(1);
                        }
                    }
                    (h, allowed)
                }));
            }
        }
        let mut per_hash = [0i64; HASHES];
        for h in handles {
            let (hash, allowed) = h.join().unwrap();
            per_hash[hash] = per_hash[hash].saturating_add(allowed);
        }
        for (hash, &count) in per_hash.iter().enumerate() {
            assert!(count <= i64::from(BURST), "hash {hash}: {count} allowed");
            assert!(count >= 1, "hash {hash}: nothing allowed");
        }
    }

    #[test]
    fn test_log_rlimit_12() {
        const BURST: u32 = 2;
        const THREADS: usize = 16;
        const ITERS: usize = 500;
        let rl = Arc::new(LogRlimit::new(1, BURST));
        let barrier = Arc::new(Barrier::new(THREADS));
        let handles: Vec<_> = (0..THREADS)
            .map(|_| {
                let rl = Arc::clone(&rl);
                let barrier = Arc::clone(&barrier);
                thread::spawn(move || {
                    barrier.wait();
                    let mut allowed = 0u64;
                    let mut missed_total = 0u64;
                    for _ in 0..ITERS {
                        let (a, m) = rl.should_log(42);
                        if a {
                            allowed = allowed.saturating_add(1);
                        }
                        missed_total = missed_total.saturating_add(u64::from(m));
                    }
                    (allowed, missed_total)
                })
            })
            .collect();
        let (total_allowed, _total_missed): (u64, u64) = handles
            .into_iter()
            .map(|h| h.join().unwrap())
            .fold((0, 0), |(a1, m1), (a2, m2)| {
                (a1.saturating_add(a2), m1.saturating_add(m2))
            });
        assert!(total_allowed > 0);
    }

    #[test]
    fn test_log_rlimit_13() {
        const THREADS: usize = 8;
        const ITERS: usize = 500;
        let rl = Arc::new(LogRlimit::new(60_000_000_000, 5));
        let barrier = Arc::new(Barrier::new(THREADS + 1));
        let handles: Vec<_> = (0..THREADS)
            .map(|_| {
                let rl = Arc::clone(&rl);
                let barrier = Arc::clone(&barrier);
                thread::spawn(move || {
                    barrier.wait();
                    for _ in 0..ITERS {
                        let _ = rl.should_log(42);
                    }
                })
            })
            .collect();
        barrier.wait();
        for i in 0..ITERS {
            if i % 2 == 0 {
                rl.set_burst(0);
            } else {
                rl.set_burst(100);
            }
        }
        for h in handles {
            h.join().unwrap();
        }
    }

    #[test]
    fn test_log_rlimit_14() {
        const BURST: u32 = 1;
        const THREADS: usize = 64;
        const ITERS: usize = 100;
        let rl = Arc::new(LogRlimit::new(60_000_000_000, BURST));
        let barrier = Arc::new(Barrier::new(THREADS));
        let handles: Vec<_> = (0..THREADS)
            .map(|_| {
                let rl = Arc::clone(&rl);
                let barrier = Arc::clone(&barrier);
                thread::spawn(move || {
                    barrier.wait();
                    for _ in 0..ITERS {
                        let _ = rl.should_log(42);
                    }
                })
            })
            .collect();
        for h in handles {
            h.join().unwrap();
        }
        // The entry must exist: burst > 0 and a non-zero interval mean the
        // first call inserted it. A missing entry is a failure, not a pass.
        let n_left_ok = rl
            .map
            .read_sync(&42u64, |_, state| state.n_left.load(Ordering::Acquire) >= 0);
        assert_eq!(n_left_ok, Some(true));
    }

    #[test]
    fn test_log_rlimit_15() {
        // set_interval_ns restarts the window with a full burst budget.
        let rl = LogRlimit::new(5_000_000_000, 2);
        assert!(rl.should_log(42).0);
        assert!(rl.should_log(42).0);
        assert!(!rl.should_log(42).0);
        rl.set_interval_ns(5_000_000_000);
        assert!(rl.should_log(42).0);
        assert!(rl.should_log(42).0);
        assert!(!rl.should_log(42).0);
    }

    #[test]
    fn test_log_rlimit_16() {
        // set_burst refills the budget inside the current window.
        let rl = LogRlimit::new(5_000_000_000, 1);
        assert!(rl.should_log(42).0);
        assert!(!rl.should_log(42).0);
        rl.set_burst(3);
        for _ in 0..3 {
            assert!(rl.should_log(42).0);
        }
        assert!(!rl.should_log(42).0);
    }
}