#![allow(dead_code)]
use crate::neo_error::{Neo3Error, NetworkError};
use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::sync::{Mutex, Semaphore};
use tokio::time::sleep;
/// Token-bucket rate limiter that also caps the number of concurrent
/// in-flight requests via a semaphore.
pub struct RateLimiter {
    // Tokens replenished per `window`; also the bucket capacity.
    max_requests: u32,
    // Time window over which `max_requests` tokens are replenished.
    window: Duration,
    // Bounds how many permits may be outstanding at any instant.
    semaphore: Arc<Semaphore>,
    // Shared mutable token state; tokio Mutex because the guard may be
    // held across `.await` points.
    bucket: Arc<Mutex<TokenBucket>>,
}
/// Continuously-refilling token-bucket state, guarded by `RateLimiter::bucket`.
struct TokenBucket {
    // Maximum number of tokens the bucket can hold.
    capacity: u32,
    // Currently available tokens; fractional because refill is continuous.
    tokens: f64,
    // Timestamp of the most recent refill computation.
    last_refill: Instant,
    // Tokens added per second (max_requests / window-in-seconds).
    refill_rate: f64,
}
impl RateLimiter {
pub fn new(max_requests: u32, window: Duration, max_concurrent: usize) -> Self {
let refill_rate = max_requests as f64 / window.as_secs_f64();
Self {
max_requests,
window,
semaphore: Arc::new(Semaphore::new(max_concurrent)),
bucket: Arc::new(Mutex::new(TokenBucket {
capacity: max_requests,
tokens: max_requests as f64,
last_refill: Instant::now(),
refill_rate,
})),
}
}
pub async fn acquire(&self) -> Result<RateLimitPermit<'_>, Neo3Error> {
let _sem_permit = self
.semaphore
.acquire()
.await
.map_err(|_| Neo3Error::Network(NetworkError::RateLimitExceeded))?;
loop {
let mut bucket = self.bucket.lock().await;
let now = Instant::now();
let elapsed = now.duration_since(bucket.last_refill).as_secs_f64();
bucket.tokens =
(bucket.tokens + elapsed * bucket.refill_rate).min(bucket.capacity as f64);
bucket.last_refill = now;
if bucket.tokens >= 1.0 {
bucket.tokens -= 1.0;
return Ok(RateLimitPermit { _semaphore: _sem_permit });
}
let wait_time = Duration::from_secs_f64(1.0 / bucket.refill_rate);
drop(bucket);
sleep(wait_time).await;
}
}
pub async fn try_acquire(&self) -> Result<RateLimitPermit<'_>, Neo3Error> {
let _sem_permit = self
.semaphore
.try_acquire()
.map_err(|_| Neo3Error::Network(NetworkError::RateLimitExceeded))?;
let mut bucket = self.bucket.lock().await;
let now = Instant::now();
let elapsed = now.duration_since(bucket.last_refill).as_secs_f64();
bucket.tokens = (bucket.tokens + elapsed * bucket.refill_rate).min(bucket.capacity as f64);
bucket.last_refill = now;
if bucket.tokens >= 1.0 {
bucket.tokens -= 1.0;
Ok(RateLimitPermit { _semaphore: _sem_permit })
} else {
Err(Neo3Error::Network(NetworkError::RateLimitExceeded))
}
}
pub async fn available_tokens(&self) -> f64 {
let mut bucket = self.bucket.lock().await;
let now = Instant::now();
let elapsed = now.duration_since(bucket.last_refill).as_secs_f64();
bucket.tokens = (bucket.tokens + elapsed * bucket.refill_rate).min(bucket.capacity as f64);
bucket.last_refill = now;
bucket.tokens
}
pub async fn reset(&self) {
let mut bucket = self.bucket.lock().await;
bucket.tokens = bucket.capacity as f64;
bucket.last_refill = Instant::now();
}
}
/// RAII guard representing one rate-limited request slot.
///
/// Holding the permit occupies one slot on the parent `RateLimiter`'s
/// concurrency semaphore; dropping it releases the slot. Marked
/// `#[must_use]` because discarding the permit immediately would release
/// the slot and silently defeat the concurrency limit.
#[must_use = "dropping the permit immediately releases its concurrency slot"]
pub struct RateLimitPermit<'a> {
    // Released on drop, freeing one concurrent-request slot.
    _semaphore: tokio::sync::SemaphorePermit<'a>,
}
/// Fluent builder for [`RateLimiter`].
pub struct RateLimiterBuilder {
    // Requests permitted per `window` (default set in `new()`: 100).
    max_requests: u32,
    // Replenishment window (default set in `new()`: 1 second).
    window: Duration,
    // Maximum in-flight requests (default set in `new()`: 10).
    max_concurrent: usize,
}
impl RateLimiterBuilder {
    /// Starts a builder with defaults: 100 requests per 1-second window,
    /// 10 concurrent requests.
    pub fn new() -> Self {
        Self {
            max_requests: 100,
            window: Duration::from_secs(1),
            max_concurrent: 10,
        }
    }

    /// Sets how many requests are allowed per window.
    pub fn max_requests(self, max: u32) -> Self {
        Self { max_requests: max, ..self }
    }

    /// Sets the replenishment window.
    pub fn window(self, window: Duration) -> Self {
        Self { window, ..self }
    }

    /// Sets the maximum number of concurrent in-flight requests.
    pub fn max_concurrent(self, max: usize) -> Self {
        Self { max_concurrent: max, ..self }
    }

    /// Consumes the builder and constructs the configured [`RateLimiter`].
    pub fn build(self) -> RateLimiter {
        let Self { max_requests, window, max_concurrent } = self;
        RateLimiter::new(max_requests, window, max_concurrent)
    }
}
impl Default for RateLimiterBuilder {
fn default() -> Self {
Self::new()
}
}
/// Ready-made [`RateLimiter`] configurations for common throughput tiers.
pub struct RateLimiterPresets;

impl RateLimiterPresets {
    /// Low-volume profile: 10 requests/second, 5 concurrent.
    pub fn conservative() -> RateLimiter {
        Self::custom(10, 5)
    }

    /// Default profile: 100 requests/second, 20 concurrent.
    pub fn standard() -> RateLimiter {
        Self::custom(100, 20)
    }

    /// High-volume profile: 1000 requests/second, 100 concurrent.
    pub fn aggressive() -> RateLimiter {
        Self::custom(1000, 100)
    }

    /// Builds a limiter with a one-second window and the given limits.
    pub fn custom(requests_per_second: u32, max_concurrent: usize) -> RateLimiter {
        RateLimiterBuilder::new()
            .max_requests(requests_per_second)
            .window(Duration::from_secs(1))
            .max_concurrent(max_concurrent)
            .build()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Tokens are exhausted after exactly `max_requests` acquisitions.
    #[tokio::test]
    async fn test_rate_limiter_basic() {
        let limiter = RateLimiter::new(5, Duration::from_secs(1), 2);
        for _ in 0..5 {
            assert!(limiter.try_acquire().await.is_ok());
        }
        // A sixth request inside the same window must be rejected.
        assert!(limiter.try_acquire().await.is_err());
    }

    /// Tokens replenish over time at the configured rate.
    #[tokio::test]
    async fn test_rate_limiter_refill() {
        let limiter = RateLimiter::new(2, Duration::from_secs(1), 10);
        assert!(limiter.try_acquire().await.is_ok());
        assert!(limiter.try_acquire().await.is_ok());
        assert!(limiter.try_acquire().await.is_err());
        // 600 ms at 2 tokens/s accrues ~1.2 tokens — enough for one acquire.
        sleep(Duration::from_millis(600)).await;
        assert!(limiter.try_acquire().await.is_ok());
    }

    /// With concurrency capped below the task count, all tasks still
    /// eventually acquire, and tokens are consumed from the bucket.
    #[tokio::test]
    async fn test_concurrent_limiting() {
        let limiter = Arc::new(RateLimiter::new(100, Duration::from_secs(1), 2));
        let mut handles = vec![];
        for _ in 0..3 {
            let limiter = limiter.clone();
            handles.push(tokio::spawn(async move { limiter.acquire().await.is_ok() }));
        }
        // Await the handles sequentially instead of pulling in
        // futures_util::join_all — the tasks already run concurrently on
        // the runtime, so a plain loop is equivalent here.
        for handle in handles {
            assert!(handle.await.unwrap_or(false));
        }
        let tokens = limiter.available_tokens().await;
        assert!(tokens < 100.0);
    }

    /// The builder forwards its settings into the constructed limiter.
    #[test]
    fn test_builder() {
        let limiter = RateLimiterBuilder::new()
            .max_requests(50)
            .window(Duration::from_secs(2))
            .max_concurrent(15)
            .build();
        assert_eq!(limiter.max_requests, 50);
        assert_eq!(limiter.window, Duration::from_secs(2));
    }
}