use crate::{MetricsError, Result};
use std::sync::atomic::{AtomicU32, AtomicU64, Ordering};
use std::time::{Duration, Instant};
/// Lock-free event counter tracking a lifetime total plus event counts for the current
/// wall-clock second, minute and hour.
///
/// All mutable state is held in atomics and every method takes `&self`, so a meter can be
/// shared between threads (e.g. behind an `Arc`).
// NOTE(review): the 64-byte alignment presumably keeps each meter on its own cache line to
// avoid false sharing — confirm intent.
#[repr(align(64))]
pub struct RateMeter {
    /// Events recorded since construction or the last `reset`.
    total_events: AtomicU64,
    /// Events counted in the second bucket identified by `last_second`.
    current_second_events: AtomicU32,
    /// Events counted in the minute bucket identified by `last_minute`.
    current_minute_events: AtomicU32,
    /// Events counted in the hour bucket identified by `last_hour`.
    current_hour_events: AtomicU32,
    /// Bucket id of the second window: Unix seconds. 0 until the first update.
    last_second: AtomicU64,
    /// Bucket id of the minute window: Unix seconds / 60. 0 until the first update.
    last_minute: AtomicU64,
    /// Bucket id of the hour window: Unix seconds / 3600. 0 until the first update.
    last_hour: AtomicU64,
    /// Configured window length in nanoseconds; in this file only `stats()` reads it
    /// (to compute `window_fill`).
    window_ns: u64,
    /// Construction time; basis for `age()` and `average_rate`. Not changed by `reset`.
    created_at: Instant,
}
/// Point-in-time snapshot of a [`RateMeter`], as returned by `RateMeter::stats`.
///
/// Derives `PartialEq` (not `Eq`, because of the `f64` fields) so snapshots can be compared
/// directly, e.g. in tests.
#[derive(Debug, Clone, PartialEq)]
pub struct RateStats {
    /// Events recorded since construction or the last reset.
    pub total_events: u64,
    /// Events counted in the current wall-clock second.
    pub per_second: f64,
    /// Events counted in the current wall-clock minute.
    pub per_minute: f64,
    /// Events counted in the current wall-clock hour.
    pub per_hour: f64,
    /// `total_events` divided by `age` in seconds (0.0 when `age` is zero).
    pub average_rate: f64,
    /// Time elapsed since the meter was constructed.
    pub age: Duration,
    /// Percentage (0–100) of the configured window that has elapsed since construction.
    pub window_fill: f64,
}
impl RateMeter {
    /// Creates a meter whose configured window is one second.
    #[inline]
    pub fn new() -> Self {
        Self::with_window(Duration::from_secs(1))
    }

    /// Creates a meter with a custom window length.
    ///
    /// The window only affects [`RateStats::window_fill`]. The per-second, per-minute and
    /// per-hour counters always use fixed buckets of 1, 60 and 3600 wall-clock seconds.
    /// A window longer than `u64::MAX` nanoseconds (~584 years) is clamped instead of being
    /// wrapped by the narrowing conversion.
    #[inline]
    pub fn with_window(window: Duration) -> Self {
        Self {
            total_events: AtomicU64::new(0),
            current_second_events: AtomicU32::new(0),
            current_minute_events: AtomicU32::new(0),
            current_hour_events: AtomicU32::new(0),
            last_second: AtomicU64::new(0),
            last_minute: AtomicU64::new(0),
            last_hour: AtomicU64::new(0),
            // `as_nanos` is a u128: saturate rather than truncate modulo 2^64.
            window_ns: window.as_nanos().min(u128::from(u64::MAX)) as u64,
            created_at: Instant::now(),
        }
    }

    /// Records a single event. Equivalent to `tick_n(1)`.
    #[inline(always)]
    pub fn tick(&self) {
        self.tick_n(1);
    }

    /// Records a single event, failing instead of overflowing a counter.
    ///
    /// # Errors
    /// Returns [`MetricsError::Overflow`] under the same conditions as [`Self::try_tick_n`].
    #[inline(always)]
    pub fn try_tick(&self) -> Result<()> {
        self.try_tick_n(1)
    }

    /// Records `n` events; `n == 0` is a no-op.
    ///
    /// The lifetime total is a `u64`. The window counters are `u32` and wrap on overflow, so
    /// use [`Self::try_tick_n`] when that must be detected.
    #[inline(always)]
    pub fn tick_n(&self, n: u32) {
        if n == 0 {
            return;
        }
        self.total_events.fetch_add(u64::from(n), Ordering::Relaxed);
        let now = self.get_unix_timestamp();
        self.update_windows(now, n);
    }

    /// Records `n` events unless that would overflow the total or a current window counter.
    ///
    /// Expired windows are rolled before the check, so a count left over from an earlier
    /// bucket does not cause a spurious overflow. The check and the increments are separate
    /// atomic operations, so concurrent callers can still overflow between them.
    ///
    /// # Errors
    /// Returns [`MetricsError::Overflow`] if adding `n` would overflow `total` (u64) or any of
    /// the current second/minute/hour counters (u32). Nothing is recorded in that case.
    #[inline(always)]
    pub fn try_tick_n(&self, n: u32) -> Result<()> {
        if n == 0 {
            return Ok(());
        }
        let total = self.total_events.load(Ordering::Relaxed);
        if total.checked_add(u64::from(n)).is_none() {
            return Err(MetricsError::Overflow);
        }
        // Without this roll, the error path never advances the windows, so a stale count
        // would keep failing every try_tick call until some other method rolled them.
        let now = self.get_unix_timestamp();
        self.update_windows(now, 0);
        let windows = [
            &self.current_second_events,
            &self.current_minute_events,
            &self.current_hour_events,
        ];
        if windows
            .iter()
            .any(|w| w.load(Ordering::Relaxed).checked_add(n).is_none())
        {
            return Err(MetricsError::Overflow);
        }
        self.total_events.fetch_add(u64::from(n), Ordering::Relaxed);
        self.update_windows(now, n);
        Ok(())
    }

    /// Number of events recorded in the current wall-clock second.
    ///
    /// This is a bucket count, not a smoothed rate. It restarts from zero whenever the Unix
    /// second changes, so it reads low early in each second.
    #[must_use]
    #[inline]
    pub fn rate(&self) -> f64 {
        self.window_count(&self.current_second_events)
    }

    /// Alias for [`Self::rate`].
    #[must_use]
    #[inline]
    pub fn rate_per_second(&self) -> f64 {
        self.rate()
    }

    /// Number of events recorded in the current wall-clock minute (Unix seconds / 60).
    #[must_use]
    #[inline]
    pub fn rate_per_minute(&self) -> f64 {
        self.window_count(&self.current_minute_events)
    }

    /// Number of events recorded in the current wall-clock hour (Unix seconds / 3600).
    #[must_use]
    #[inline]
    pub fn rate_per_hour(&self) -> f64 {
        self.window_count(&self.current_hour_events)
    }

    /// Events recorded since construction or the last [`Self::reset`].
    #[must_use]
    #[inline(always)]
    pub fn total(&self) -> u64 {
        self.total_events.load(Ordering::Relaxed)
    }

    /// Whether this second's count is strictly greater than `limit`.
    #[must_use]
    #[inline]
    pub fn exceeds_rate(&self, limit: f64) -> bool {
        self.rate() > limit
    }

    /// Whether recording `n` more events keeps this second's count at or below `limit`.
    ///
    /// Rolls expired windows as a side effect.
    #[must_use]
    #[inline]
    pub fn can_allow(&self, n: u32, limit: f64) -> bool {
        let current_rate = self.rate();
        current_rate + f64::from(n) <= limit
    }

    /// Records one event if [`Self::can_allow`]`(1, limit)` holds; returns whether it did.
    ///
    /// The check and the tick are separate operations, so concurrent callers can together
    /// exceed `limit` within a second.
    #[must_use]
    #[inline]
    pub fn tick_if_under_limit(&self, limit: f64) -> bool {
        if self.can_allow(1, limit) {
            self.tick();
            true
        } else {
            false
        }
    }

    /// Like [`Self::tick_if_under_limit`], but records via [`Self::try_tick`].
    ///
    /// # Errors
    /// Propagates [`MetricsError::Overflow`] from `try_tick`.
    #[inline]
    pub fn try_tick_if_under_limit(&self, limit: f64) -> Result<bool> {
        if self.can_allow(1, limit) {
            self.try_tick()?;
            Ok(true)
        } else {
            Ok(false)
        }
    }

    /// Records `n` events as a unit if [`Self::can_allow`]`(n, limit)` holds. Otherwise it
    /// records nothing. Returns whether the events were recorded.
    ///
    /// Not atomic with respect to other callers (see [`Self::tick_if_under_limit`]).
    #[must_use]
    #[inline]
    pub fn tick_burst_if_under_limit(&self, n: u32, limit: f64) -> bool {
        if self.can_allow(n, limit) {
            self.tick_n(n);
            true
        } else {
            false
        }
    }

    /// Clears the total and all window counters and stamps every window with the current
    /// time.
    ///
    /// `created_at` is not changed, so `age()` and `stats().average_rate` still measure from
    /// construction. Each field is stored independently, so ticks racing with a reset may be
    /// partially retained.
    #[inline]
    pub fn reset(&self) {
        let now = self.get_unix_timestamp();
        self.total_events.store(0, Ordering::SeqCst);
        self.current_second_events.store(0, Ordering::SeqCst);
        self.current_minute_events.store(0, Ordering::SeqCst);
        self.current_hour_events.store(0, Ordering::SeqCst);
        self.last_second.store(now, Ordering::SeqCst);
        self.last_minute.store(now / 60, Ordering::SeqCst);
        self.last_hour.store(now / 3600, Ordering::SeqCst);
    }

    /// Rolls expired windows, then snapshots the counters into a [`RateStats`].
    ///
    /// `average_rate` is total / age in seconds (0.0 at zero age). `window_fill` is the
    /// percentage of the configured window elapsed since construction, capped at 100 (and
    /// 100 for a zero-length window).
    #[must_use]
    pub fn stats(&self) -> RateStats {
        let now = self.get_unix_timestamp();
        self.update_windows(now, 0);
        let total_events = self.total();
        let per_second = f64::from(self.current_second_events.load(Ordering::Relaxed));
        let per_minute = f64::from(self.current_minute_events.load(Ordering::Relaxed));
        let per_hour = f64::from(self.current_hour_events.load(Ordering::Relaxed));
        let age = self.created_at.elapsed();
        let age_secs = age.as_secs_f64();
        let average_rate = if age_secs > 0.0 {
            total_events as f64 / age_secs
        } else {
            0.0
        };
        let window_fill = if self.window_ns > 0 {
            let window_seconds = self.window_ns as f64 / 1_000_000_000.0;
            let elapsed_in_window = age_secs.min(window_seconds);
            (elapsed_in_window / window_seconds * 100.0).min(100.0)
        } else {
            100.0
        };
        RateStats {
            total_events,
            per_second,
            per_minute,
            per_hour,
            average_rate,
            age,
            window_fill,
        }
    }

    /// Time elapsed since construction (unaffected by [`Self::reset`]).
    #[must_use]
    #[inline]
    pub fn age(&self) -> Duration {
        self.created_at.elapsed()
    }

    /// Whether no events have been recorded since construction or the last reset.
    #[must_use]
    #[inline]
    pub fn is_empty(&self) -> bool {
        self.total() == 0
    }

    /// Wall-clock time in whole seconds since the Unix epoch (0 if the clock reads earlier).
    ///
    /// `SystemTime` is not monotonic, so successive calls may go backwards.
    #[inline(always)]
    fn get_unix_timestamp(&self) -> u64 {
        std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .unwrap_or_default()
            .as_secs()
    }

    /// Rolls expired windows forward, then reads `counter` as an `f64`.
    #[inline]
    fn window_count(&self, counter: &AtomicU32) -> f64 {
        let now = self.get_unix_timestamp();
        self.update_windows(now, 0);
        f64::from(counter.load(Ordering::Relaxed))
    }

    /// Adds `new_events` to the second, minute and hour windows for timestamp `now` (Unix
    /// seconds), resetting any window whose bucket has changed.
    #[inline]
    fn update_windows(&self, now: u64, new_events: u32) {
        Self::advance_window(&self.last_second, &self.current_second_events, now, new_events);
        Self::advance_window(
            &self.last_minute,
            &self.current_minute_events,
            now / 60,
            new_events,
        );
        Self::advance_window(
            &self.last_hour,
            &self.current_hour_events,
            now / 3600,
            new_events,
        );
    }

    /// Records `new_events` in the window whose current bucket id is stored in `stamp` and
    /// whose count is `count`.
    ///
    /// If `bucket` differs from the stored id, exactly one caller wins the CAS and restarts
    /// the count at `new_events`. Everyone else adds to the count. The scheme is lock-free and
    /// approximate: an add that races with a winner's reset can still be lost.
    #[inline]
    fn advance_window(stamp: &AtomicU64, count: &AtomicU32, bucket: u64, new_events: u32) {
        let seen = stamp.load(Ordering::Relaxed);
        // Exactly one bucket behind means this caller read the clock just before another
        // thread rolled the window. Fold its events into the current bucket instead of
        // rewinding the window and discarding everything already counted there. Larger
        // backwards jumps are treated as a wall-clock step and do move the window back.
        let stale = bucket.checked_add(1) == Some(seen);
        if bucket != seen
            && !stale
            && stamp
                .compare_exchange(seen, bucket, Ordering::Relaxed, Ordering::Relaxed)
                .is_ok()
        {
            count.store(new_events, Ordering::Relaxed);
        } else if new_events > 0 {
            count.fetch_add(new_events, Ordering::Relaxed);
        }
    }
}
impl Default for RateMeter {
#[inline]
fn default() -> Self {
Self::new()
}
}
impl std::fmt::Display for RateMeter {
    /// Renders the current-second count (one decimal) and the lifetime total.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let per_second = self.rate();
        let total = self.total();
        write!(f, "RateMeter({:.1}/s, {} total)", per_second, total)
    }
}
impl std::fmt::Debug for RateMeter {
    /// Formats a snapshot taken with `RateMeter::stats`, which also rolls expired windows.
    /// Every `RateStats` field is shown, including `per_hour` and `window_fill`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let stats = self.stats();
        f.debug_struct("RateMeter")
            .field("total_events", &stats.total_events)
            .field("per_second", &stats.per_second)
            .field("per_minute", &stats.per_minute)
            .field("per_hour", &stats.per_hour)
            .field("average_rate", &stats.average_rate)
            .field("age", &stats.age)
            .field("window_fill", &stats.window_fill)
            .finish()
    }
}
pub mod specialized {
    //! Purpose-built wrappers around [`RateMeter`].
    use super::*;

    /// Per-second request limiter backed by a [`RateMeter`].
    ///
    /// Admission is a check followed by a separate tick, and the two are not atomic together.
    /// Concurrent callers can each pass the check and briefly admit more than `limit`
    /// requests within one second.
    #[repr(align(64))]
    pub struct ApiRateLimiter {
        meter: RateMeter,
        /// Requests admitted per wall-clock second; adjustable at runtime via `set_limit`.
        limit: AtomicU32,
    }

    impl ApiRateLimiter {
        /// Creates a limiter that admits up to `requests_per_second` requests per second.
        #[inline]
        pub fn new(requests_per_second: u32) -> Self {
            Self {
                meter: RateMeter::new(),
                limit: AtomicU32::new(requests_per_second),
            }
        }

        /// Current limit as the `f64` the meter's admission checks take.
        #[inline]
        fn limit_f64(&self) -> f64 {
            f64::from(self.limit.load(Ordering::Relaxed))
        }

        /// Admits and records one request if that keeps this second's count within the
        /// limit.
        #[inline]
        pub fn try_request(&self) -> bool {
            self.meter.tick_if_under_limit(self.limit_f64())
        }

        /// Admits and records `n` requests as a unit if all of them fit within the limit.
        /// Otherwise it records nothing.
        #[inline]
        pub fn try_requests(&self, n: u32) -> bool {
            self.meter.tick_burst_if_under_limit(n, self.limit_f64())
        }

        /// Replaces the per-second limit; takes effect on the next admission check.
        #[inline]
        pub fn set_limit(&self, requests_per_second: u32) {
            self.limit.store(requests_per_second, Ordering::Relaxed);
        }

        /// Current per-second limit.
        #[inline]
        pub fn get_limit(&self) -> u32 {
            self.limit.load(Ordering::Relaxed)
        }

        /// Requests admitted in the current wall-clock second.
        #[inline]
        pub fn current_rate(&self) -> f64 {
            self.meter.rate()
        }

        /// Requests admitted since construction or the last reset.
        #[inline]
        pub fn total_requests(&self) -> u64 {
            self.meter.total()
        }

        /// Whether this second's admitted count is strictly above the current limit.
        #[inline]
        pub fn is_over_limit(&self) -> bool {
            let limit = self.limit_f64();
            self.meter.rate() > limit
        }

        /// Clears all counts. The configured limit is kept.
        #[inline]
        pub fn reset(&self) {
            self.meter.reset();
        }
    }

    impl Default for ApiRateLimiter {
        /// A limiter admitting 1000 requests per second.
        fn default() -> Self {
            Self::new(1000)
        }
    }

    /// Byte-throughput meter: each byte is one event on an inner [`RateMeter`].
    ///
    /// The lifetime total is a `u64`. The meter's per-window counters are `u32`, so they
    /// wrap once more than `u32::MAX` bytes (~4 GiB) land in a single second, minute or
    /// hour bucket.
    #[repr(align(64))]
    pub struct ThroughputMeter {
        meter: RateMeter,
    }

    impl ThroughputMeter {
        /// Creates an empty throughput meter.
        #[inline]
        pub fn new() -> Self {
            Self {
                meter: RateMeter::new(),
            }
        }

        /// Records `bytes` transferred.
        ///
        /// `RateMeter::tick_n` takes a `u32`, so larger amounts are fed in `u32::MAX`-sized
        /// pieces rather than being truncated modulo 2^32. This keeps `total_bytes` exact.
        /// The cost is one meter update per started ~4 GiB (a single update for any amount
        /// up to `u32::MAX`).
        #[inline]
        pub fn record_bytes(&self, bytes: u64) {
            let mut remaining = bytes;
            while remaining > 0 {
                // Lossless narrowing: `min` caps the value at u32::MAX first.
                let chunk = remaining.min(u64::from(u32::MAX)) as u32;
                self.meter.tick_n(chunk);
                remaining -= u64::from(chunk);
            }
        }

        /// Bytes recorded in the current wall-clock second.
        #[inline]
        pub fn bytes_per_second(&self) -> f64 {
            self.meter.rate()
        }

        /// Current-second bytes divided by 1024 (KiB).
        #[inline]
        pub fn kb_per_second(&self) -> f64 {
            self.meter.rate() / 1024.0
        }

        /// Current-second bytes divided by 1024² (MiB).
        #[inline]
        pub fn mb_per_second(&self) -> f64 {
            self.meter.rate() / (1024.0 * 1024.0)
        }

        /// Current-second bytes divided by 1024³ (GiB).
        #[inline]
        pub fn gb_per_second(&self) -> f64 {
            self.meter.rate() / (1024.0 * 1024.0 * 1024.0)
        }

        /// Bytes recorded since construction or the last reset.
        #[inline]
        pub fn total_bytes(&self) -> u64 {
            self.meter.total()
        }

        /// Clears all counts.
        #[inline]
        pub fn reset(&self) {
            self.meter.reset();
        }
    }

    impl Default for ThroughputMeter {
        fn default() -> Self {
            Self::new()
        }
    }
}
// Unit tests for `RateMeter` and the `specialized` wrappers.
//
// The window counters are keyed by whole Unix seconds (see `get_unix_timestamp`). Tests that
// assert exact `rate()` / `per_second` values therefore assume all their calls land in the
// same wall-clock second. A second boundary mid-test resets the second window, so those
// assertions can fail spuriously.
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;
    use std::thread;
    #[test]
    fn test_basic_operations() {
        let meter = RateMeter::new();
        assert!(meter.is_empty());
        assert_eq!(meter.total(), 0);
        assert_eq!(meter.rate(), 0.0);
        meter.tick();
        assert!(!meter.is_empty());
        assert_eq!(meter.total(), 1);
        meter.tick_n(5);
        assert_eq!(meter.total(), 6);
    }
    #[test]
    fn test_rate_calculations() {
        let meter = RateMeter::new();
        for _ in 0..100 {
            meter.tick();
        }
        // Assumes all 100 ticks and the reads share one wall-clock second.
        let rate = meter.rate();
        assert_eq!(rate, 100.0);
        assert_eq!(meter.rate_per_second(), 100.0);
    }
    #[test]
    fn test_multiple_windows() {
        let meter = RateMeter::new();
        for _ in 0..60 {
            meter.tick();
        }
        // All three windows should hold every tick when no bucket boundary is crossed.
        let stats = meter.stats();
        assert_eq!(stats.total_events, 60);
        assert_eq!(stats.per_second, 60.0);
        assert_eq!(stats.per_minute, 60.0);
        assert_eq!(stats.per_hour, 60.0);
    }
    #[test]
    fn test_rate_limiting() {
        let meter = RateMeter::new();
        assert!(meter.tick_if_under_limit(10.0));
        assert!(meter.tick_if_under_limit(10.0));
        meter.tick_n(8);
        // Count is now 10: one more would exceed the limit of 10.
        assert!(!meter.tick_if_under_limit(10.0));
        assert!(meter.exceeds_rate(9.0));
        assert!(!meter.exceeds_rate(11.0));
    }
    #[test]
    fn test_burst_rate_limiting() {
        let meter = RateMeter::new();
        assert!(meter.tick_burst_if_under_limit(5, 10.0));
        assert_eq!(meter.total(), 5);
        // 5 + 10 > 10: the burst is rejected and nothing is recorded.
        assert!(!meter.tick_burst_if_under_limit(10, 10.0));
        assert_eq!(meter.total(), 5);
        assert!(meter.tick_burst_if_under_limit(3, 10.0));
        assert_eq!(meter.total(), 8);
    }
    #[test]
    fn test_can_allow() {
        let meter = RateMeter::new();
        meter.tick_n(5);
        // With 5 counted: 5+3 and 5+5 fit within 10 (the bound is inclusive), 5+6 does not.
        assert!(meter.can_allow(3, 10.0)); assert!(!meter.can_allow(6, 10.0)); assert!(meter.can_allow(5, 10.0)); }
    #[test]
    fn test_reset() {
        let meter = RateMeter::new();
        meter.tick_n(100);
        assert_eq!(meter.total(), 100);
        assert!(meter.rate() > 0.0);
        meter.reset();
        assert_eq!(meter.total(), 0);
        assert_eq!(meter.rate(), 0.0);
        assert!(meter.is_empty());
    }
    #[test]
    fn test_statistics() {
        let meter = RateMeter::new();
        meter.tick_n(50);
        let stats = meter.stats();
        assert_eq!(stats.total_events, 50);
        assert_eq!(stats.per_second, 50.0);
        // Relies on a non-zero elapsed time between construction and `stats()`.
        assert!(stats.average_rate > 0.0);
        assert!(stats.age > Duration::from_nanos(0));
        assert!(stats.window_fill >= 0.0);
    }
    #[test]
    fn test_api_rate_limiter() {
        let limiter = specialized::ApiRateLimiter::new(10);
        for _ in 0..10 {
            assert!(limiter.try_request());
        }
        assert!(!limiter.try_request());
        assert_eq!(limiter.current_rate(), 10.0);
        assert_eq!(limiter.total_requests(), 10);
        assert_eq!(limiter.get_limit(), 10);
        limiter.set_limit(20);
        assert_eq!(limiter.get_limit(), 20);
        assert!(!limiter.is_over_limit());
        // reset clears counts but keeps the limit at 20: a burst of 5 fits, then 5 + 20 > 20
        // is rejected without recording anything.
        limiter.reset();
        assert!(limiter.try_requests(5));
        assert_eq!(limiter.total_requests(), 5);
        assert!(!limiter.try_requests(20)); assert_eq!(limiter.total_requests(), 5); }
    #[test]
    fn test_throughput_meter() {
        let meter = specialized::ThroughputMeter::new();
        meter.record_bytes(1024); assert_eq!(meter.bytes_per_second(), 1024.0);
        assert_eq!(meter.kb_per_second(), 1.0);
        assert_eq!(meter.total_bytes(), 1024);
        meter.record_bytes(1024 * 1024); assert_eq!(meter.total_bytes(), 1024 + 1024 * 1024);
        // (1024 + 1024²) / 1024² ≈ 1.00098 MiB in the current second.
        assert!((meter.mb_per_second() - 1.001).abs() < 0.01);
    }
    #[test]
    fn test_high_concurrency() {
        let meter = Arc::new(RateMeter::new());
        let num_threads = 50;
        let ticks_per_thread = 1000;
        let handles: Vec<_> = (0..num_threads)
            .map(|_| {
                let meter = Arc::clone(&meter);
                thread::spawn(move || {
                    for _ in 0..ticks_per_thread {
                        meter.tick();
                    }
                })
            })
            .collect();
        for handle in handles {
            handle.join().unwrap();
        }
        // Only the u64 total is asserted exactly; window counts are approximate under races.
        assert_eq!(meter.total(), num_threads * ticks_per_thread);
        let stats = meter.stats();
        assert!(stats.average_rate > 0.0);
        assert_eq!(stats.total_events, num_threads * ticks_per_thread);
    }
    #[test]
    fn test_concurrent_rate_limiting() {
        let limiter = Arc::new(specialized::ApiRateLimiter::new(100));
        let num_threads = 20;
        let handles: Vec<_> = (0..num_threads)
            .map(|_| {
                let limiter = Arc::clone(&limiter);
                thread::spawn(move || {
                    let mut successful = 0;
                    for _ in 0..10 {
                        if limiter.try_request() {
                            successful += 1;
                        }
                    }
                    successful
                })
            })
            .collect();
        let total_successful: i32 = handles.into_iter().map(|h| h.join().unwrap()).sum();
        let total_attempts = num_threads * 10;
        // Admission is check-then-tick (not atomic), so concurrent threads may overshoot the
        // limit of 100. The upper bound allows up to 2x. Note: strict_cap (200) > 160, so both
        // cfg branches currently evaluate to min(200, 200) = 200.
        let strict_cap = 2 * 100; let upper_bound = if cfg!(coverage) {
            total_attempts.min(strict_cap.max(160))
        } else {
            total_attempts.min(strict_cap)
        };
        assert!(
            total_successful <= upper_bound,
            "total_successful={total_successful} > upper_bound={upper_bound}",
        );
        assert!(
            total_successful >= 90,
            "total_successful={total_successful} < lower_bound=90",
        ); }
    #[test]
    fn test_display_and_debug() {
        let meter = RateMeter::new();
        meter.tick_n(42);
        let display_str = format!("{meter}");
        assert!(display_str.contains("RateMeter"));
        assert!(display_str.contains("42 total"));
        let debug_str = format!("{meter:?}");
        assert!(debug_str.contains("RateMeter"));
        assert!(debug_str.contains("total_events"));
    }
    #[test]
    fn test_custom_window() {
        let meter = RateMeter::with_window(Duration::from_secs(5));
        meter.tick_n(10);
        assert_eq!(meter.total(), 10);
        // The custom window does not change the one-second bucket that `rate()` reads.
        assert_eq!(meter.rate(), 10.0);
        let stats = meter.stats();
        assert!(stats.window_fill >= 0.0);
    }
    #[test]
    fn test_try_tick_and_try_tick_n_ok() {
        let meter = RateMeter::new();
        assert!(meter.try_tick().is_ok());
        assert!(meter.try_tick_n(5).is_ok());
        assert_eq!(meter.total(), 6);
    }
    #[test]
    fn test_try_tick_n_total_overflow() {
        let meter = RateMeter::new();
        // Private field access: this child module can see the parent's private items.
        meter.total_events.store(u64::MAX - 1, Ordering::Relaxed);
        let err = meter.try_tick_n(2).unwrap_err();
        assert_eq!(err, MetricsError::Overflow);
    }
    #[test]
    fn test_try_tick_n_window_overflow() {
        let meter = RateMeter::new();
        // Pin every window to the current bucket so the pre-filled counts are treated as live.
        let now = meter.get_unix_timestamp();
        meter.last_second.store(now, Ordering::Relaxed);
        meter.last_minute.store(now / 60, Ordering::Relaxed);
        meter.last_hour.store(now / 3600, Ordering::Relaxed);
        meter
            .current_second_events
            .store(u32::MAX - 1, Ordering::Relaxed);
        meter
            .current_minute_events
            .store(u32::MAX - 1, Ordering::Relaxed);
        meter
            .current_hour_events
            .store(u32::MAX - 1, Ordering::Relaxed);
        let err = meter.try_tick_n(2).unwrap_err();
        assert_eq!(err, MetricsError::Overflow);
    }
    #[test]
    fn test_try_tick_if_under_limit() {
        let meter = RateMeter::new();
        // 1 + 8 = 9, then one more reaches 10; the 11th is refused.
        assert!(meter.try_tick_if_under_limit(10.0).unwrap());
        assert!(meter.try_tick_n(8).is_ok());
        assert!(meter.try_tick_if_under_limit(10.0).unwrap());
        assert!(!meter.try_tick_if_under_limit(10.0).unwrap());
    }
}
#[cfg(all(test, feature = "bench-tests", not(tarpaulin)))]
#[allow(unused_imports)]
mod benchmarks {
    //! Wall-clock micro-benchmarks with per-operation ceilings. They are built only with the
    //! `bench-tests` feature because the thresholds depend on the host machine.
    use super::*;
    use std::time::Instant;

    /// Mean nanoseconds per operation, used for the human-readable report line.
    fn mean_ns(total_ns: u128, ops: u128) -> f64 {
        total_ns as f64 / ops as f64
    }

    #[cfg_attr(not(feature = "bench-tests"), ignore)]
    #[test]
    fn bench_rate_meter_tick() {
        const OPS: u64 = 10_000_000;
        let meter = RateMeter::new();
        let started = Instant::now();
        (0..OPS).for_each(|_| meter.tick());
        let spent_ns = started.elapsed().as_nanos();
        println!(
            "RateMeter tick: {:.2} ns/op",
            mean_ns(spent_ns, u128::from(OPS))
        );
        assert_eq!(meter.total(), OPS);
        assert!(spent_ns / u128::from(OPS) < 400);
    }

    #[cfg_attr(not(feature = "bench-tests"), ignore)]
    #[test]
    fn bench_rate_meter_tick_n() {
        const OPS: u32 = 1_000_000;
        let meter = RateMeter::new();
        let started = Instant::now();
        (0..OPS).for_each(|step| meter.tick_n(step % 10 + 1));
        let spent_ns = started.elapsed().as_nanos();
        println!(
            "RateMeter tick_n: {:.2} ns/op",
            mean_ns(spent_ns, u128::from(OPS))
        );
        assert!(spent_ns / u128::from(OPS) < 500);
    }

    #[cfg_attr(not(feature = "bench-tests"), ignore)]
    #[test]
    fn bench_rate_calculation() {
        const OPS: u128 = 1_000_000;
        let meter = RateMeter::new();
        meter.tick_n(1000);
        let started = Instant::now();
        for _ in 0..OPS {
            let _ = meter.rate();
        }
        let spent_ns = started.elapsed().as_nanos();
        println!("RateMeter rate: {:.2} ns/op", mean_ns(spent_ns, OPS));
        assert!(spent_ns / OPS < 300);
    }

    #[cfg_attr(not(feature = "bench-tests"), ignore)]
    #[test]
    fn bench_api_rate_limiter() {
        const OPS: u128 = 1_000_000;
        let limiter = specialized::ApiRateLimiter::new(1_000_000);
        let started = Instant::now();
        for _ in 0..OPS {
            let _ = limiter.try_request();
        }
        let spent_ns = started.elapsed().as_nanos();
        println!(
            "ApiRateLimiter try_request: {:.2} ns/op",
            mean_ns(spent_ns, OPS)
        );
        assert!(spent_ns / OPS < 1000);
    }
}