#![cfg_attr(docsrs, feature(doc_cfg))]
#![warn(
missing_docs,
rust_2018_idioms,
unreachable_pub,
missing_debug_implementations
)]
#![forbid(unsafe_op_in_unsafe_fn)]
mod rate_limiter;
pub use rate_limiter::{
cpu_relax, current_time_ms, current_time_ns, current_time_us, HealthStatus,
IpRateLimiterManager, ManagerStats, MemoryOrdering, RateLimiter, RateLimiterConfig,
RateLimiterMetrics,
};
/// A [`RateLimiter`] behind an [`Arc`](std::sync::Arc), so a single limiter
/// can be shared by several owners (for example, cloned into worker threads).
pub type SharedRateLimiter = std::sync::Arc<RateLimiter>;
/// An [`IpRateLimiterManager`] behind an [`Arc`](std::sync::Arc), for shared
/// ownership of a single manager instance.
pub type SharedIpManager = std::sync::Arc<IpRateLimiterManager>;
/// Version of this crate, captured from `CARGO_PKG_VERSION` at compile time.
pub const VERSION: &str = env!("CARGO_PKG_VERSION");
/// Minimum supported Rust version (MSRV) declared for this crate.
// NOTE(review): this is a hand-written literal; presumably it should match the
// `rust-version` field in Cargo.toml — keep the two in sync manually.
pub const MSRV: &str = "1.70.0";
/// Glob-importable re-exports of the crate's commonly used types.
///
/// Importing `prelude::*` brings the limiter, its configuration, metrics and
/// health types, the per-IP manager and its stats, and the `Shared*` `Arc`
/// aliases into scope in one line.
pub mod prelude {
    pub use crate::{
        HealthStatus, IpRateLimiterManager, ManagerStats, MemoryOrdering, RateLimiter,
        RateLimiterConfig, RateLimiterMetrics, SharedIpManager, SharedRateLimiter,
    };
}
/// Fluent builder for [`RateLimiter`].
///
/// Starts from [`RateLimiterConfig::default()`] and lets individual settings
/// be overridden before finishing with [`build`](RateLimiterBuilder::build)
/// or the validating [`try_build`](RateLimiterBuilder::try_build).
#[derive(Debug, Clone)]
#[must_use = "a builder does nothing until `build` or `try_build` is called"]
pub struct RateLimiterBuilder {
    /// Configuration accumulated so far; consumed by `build` / `try_build`.
    config: RateLimiterConfig,
}
impl RateLimiterBuilder {
    /// Creates a builder initialised with [`RateLimiterConfig::default()`].
    pub fn new() -> Self {
        Self {
            config: RateLimiterConfig::default(),
        }
    }

    /// Sets the configuration's `max_tokens`.
    ///
    /// A freshly built limiter reports this many available tokens
    /// (presumably the bucket capacity).
    pub fn max_tokens(mut self, tokens: u64) -> Self {
        self.config.max_tokens = tokens;
        self
    }

    /// Sets the configuration's `refill_rate`.
    ///
    /// Presumably the number of tokens restored per refill interval — see
    /// [`refill_interval_ms`](Self::refill_interval_ms).
    pub fn refill_rate(mut self, rate: u32) -> Self {
        self.config.refill_rate = rate;
        self
    }

    /// Sets the configuration's refill interval, in milliseconds.
    pub fn refill_interval_ms(mut self, ms: u64) -> Self {
        self.config.refill_interval_ms = ms;
        self
    }

    /// Sets the [`MemoryOrdering`] stored in the configuration's `ordering`
    /// field (presumably applied to the limiter's atomic operations).
    pub fn memory_ordering(mut self, ordering: MemoryOrdering) -> Self {
        self.config.ordering = ordering;
        self
    }

    /// Builds the limiter from the accumulated configuration **without**
    /// validating it.
    ///
    /// Use [`try_build`](Self::try_build) to have the configuration checked
    /// by `RateLimiterConfig::validate` first.
    #[must_use = "the built limiter is discarded if the result is unused"]
    pub fn build(self) -> RateLimiter {
        RateLimiter::with_config(self.config)
    }

    /// Validates the accumulated configuration, then builds the limiter.
    ///
    /// # Errors
    ///
    /// Returns the message produced by `RateLimiterConfig::validate` when the
    /// configuration is rejected (for example, `max_tokens(0)` or
    /// `refill_rate(0)`).
    pub fn try_build(self) -> Result<RateLimiter, &'static str> {
        self.config.validate()?;
        Ok(RateLimiter::with_config(self.config))
    }
}
impl Default for RateLimiterBuilder {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;
    use std::thread;

    #[test]
    #[cfg_attr(miri, ignore)]
    fn test_basic_functionality() {
        // Presumably (max_tokens, refill_rate, refill_interval_ms); the long
        // interval is meant to keep refills out of the test window.
        let config = RateLimiterConfig::new(10, 1, 600_000);
        let limiter = RateLimiter::with_config(config);
        for _ in 0..10 {
            assert!(limiter.try_acquire());
        }
        assert!(!limiter.try_acquire());
        let metrics = limiter.metrics();
        assert_eq!(metrics.total_acquired, 10);
        assert_eq!(metrics.total_rejected, 1);
    }

    #[test]
    fn test_builder() {
        let limiter = RateLimiterBuilder::new()
            .max_tokens(50)
            .refill_rate(5)
            .refill_interval_ms(1000)
            .build();
        assert_eq!(limiter.available_tokens(), 50);
    }

    #[test]
    fn test_builder_validation() {
        let result = RateLimiterBuilder::new().max_tokens(0).try_build();
        assert!(result.is_err());
    }

    #[test]
    #[cfg_attr(miri, ignore)]
    fn test_thread_safety() {
        let config = RateLimiterConfig::new(1000, 100, 60_000);
        let limiter = Arc::new(RateLimiter::with_config(config));
        let mut handles = vec![];
        for _ in 0..10 {
            let limiter_clone = limiter.clone();
            let handle = thread::spawn(move || {
                let mut acquired = 0;
                for _ in 0..200 {
                    if limiter_clone.try_acquire() {
                        acquired += 1;
                    }
                }
                acquired
            });
            handles.push(handle);
        }
        // 10 threads x 200 attempts = 2000 tries; the bounds allow +/-100
        // around the presumed 1000-token capacity.
        let total: u32 = handles.into_iter().map(|h| h.join().unwrap()).sum();
        assert!(total <= 1100, "total={total}");
        assert!(total >= 900, "total={total}");
    }

    #[test]
    fn test_prelude_imports() {
        use crate::prelude::*;
        let _limiter = RateLimiter::new(10, 1);
        let _config = RateLimiterConfig::default();
        let _ordering = MemoryOrdering::AcquireRelease;
        let _status = HealthStatus::Healthy;
    }

    #[test]
    fn test_shared_types() {
        let limiter = RateLimiter::new(10, 1);
        let _shared: SharedRateLimiter = std::sync::Arc::new(limiter);
        let manager = IpRateLimiterManager::new(RateLimiterConfig::default());
        let _shared_manager: SharedIpManager = std::sync::Arc::new(manager);
    }

    #[test]
    fn test_constants() {
        assert!(!VERSION.is_empty());
        assert_eq!(MSRV, "1.70.0");
    }

    #[test]
    fn test_builder_default() {
        let builder = RateLimiterBuilder::default();
        let limiter = builder.build();
        assert!(limiter.available_tokens() > 0);
    }

    #[test]
    fn test_builder_chain() {
        let limiter = RateLimiterBuilder::new()
            .max_tokens(100)
            .refill_rate(10)
            .refill_interval_ms(500)
            .memory_ordering(MemoryOrdering::Sequential)
            .build();
        assert_eq!(limiter.available_tokens(), 100);
    }

    #[test]
    fn test_builder_try_build_refill_rate_zero() {
        let result = RateLimiterBuilder::new().refill_rate(0).try_build();
        assert!(result.is_err());
    }

    #[test]
    fn test_builder_try_build_success() {
        let result = RateLimiterBuilder::new()
            .max_tokens(50)
            .refill_rate(10)
            .refill_interval_ms(1000)
            .try_build();
        assert!(result.is_ok());
        let limiter = result.unwrap();
        assert_eq!(limiter.available_tokens(), 50);
    }

    #[test]
    fn test_builder_debug() {
        let builder = RateLimiterBuilder::new();
        let debug = format!("{builder:?}");
        assert!(debug.contains("RateLimiterBuilder"));
    }

    #[test]
    #[cfg_attr(miri, ignore)]
    fn test_metrics_after_mixed_operations() {
        let config = RateLimiterConfig::new(10, 1, 600_000);
        let limiter = RateLimiter::with_config(config);
        for _ in 0..5 {
            assert!(limiter.try_acquire());
        }
        assert!(limiter.try_acquire_n(3));
        assert!(!limiter.try_acquire_n(5));
        assert!(limiter.try_acquire());
        assert!(limiter.try_acquire());
        assert!(!limiter.try_acquire());
        // 8 successful calls (5 + 1 batch + 2) and 2 rejected calls.
        let metrics = limiter.metrics();
        assert_eq!(metrics.total_acquired, 8);
        assert_eq!(metrics.total_rejected, 2);
    }

    #[test]
    #[cfg_attr(miri, ignore)]
    fn test_shared_limiter_across_threads() {
        let shared: SharedRateLimiter = std::sync::Arc::new(RateLimiter::new(100, 10));
        let mut handles = vec![];
        for _ in 0..4 {
            let s = shared.clone();
            handles.push(thread::spawn(move || {
                let mut count = 0u32;
                for _ in 0..50 {
                    if s.try_acquire() {
                        count += 1;
                    }
                }
                count
            }));
        }
        let total: u32 = handles.into_iter().map(|h| h.join().unwrap()).sum();
        // Separate asserts so a failure reports which bound was violated.
        assert!(total > 0, "total={total}");
        assert!(total <= 100, "total={total}");
    }

    #[test]
    fn test_version_is_valid_semver() {
        // SemVer permits an optional `-pre.release` and `+build.metadata`
        // suffix (e.g. "1.2.0-alpha.1+build.5"); only the MAJOR.MINOR.PATCH
        // core must be three dot-separated integers. Build metadata may itself
        // contain '-', so strip it first.
        let core = VERSION
            .split_once('+')
            .map_or(VERSION, |(core, _build)| core);
        let core = core.split_once('-').map_or(core, |(core, _pre)| core);
        let parts: Vec<&str> = core.split('.').collect();
        assert_eq!(parts.len(), 3, "VERSION should be semver: {VERSION}");
        for part in &parts {
            assert!(part.parse::<u32>().is_ok(), "Invalid semver part: {part}");
        }
    }
}