use crate::sleep_compat::sleep;
use crate::sync_compat::Mutex;
use crate::time_compat::Instant;
use std::collections::HashMap;
use std::sync::Arc;
use std::time::Duration;
/// Async rate limiter holding one token bucket per [`RateLimitCategory`].
///
/// Cloning shares state: the buckets live behind an `Arc`, so every clone
/// draws from the same budgets.
#[derive(Clone, Debug)]
pub struct RateLimiter {
    /// Per-category buckets behind a mutex whose `lock()` is awaited.
    limiters: Arc<Mutex<HashMap<RateLimitCategory, TokenBucket>>>,
}
/// Groups of API endpoints that share a rate-limit budget.
///
/// See [`categorize_endpoint`] for how request paths map onto these groups.
/// The enum is fieldless, so it is `Copy` and can be passed around by value freely.
#[derive(Debug, Clone, Copy, Hash, Eq, PartialEq)]
pub enum RateLimitCategory {
    /// Order placement, cancellation and editing.
    Trading,
    /// Public market data: tickers, order books, recent trades, instruments.
    MarketData,
    /// Account summary, position and subaccount queries.
    Account,
    /// Authentication and logout.
    Auth,
    /// Any endpoint not matched by a more specific category.
    General,
}
/// A single token bucket.
///
/// The bucket holds at most `capacity` whole tokens and regains them
/// continuously at `refill_rate` tokens per second.
#[derive(Debug)]
struct TokenBucket {
    /// Maximum number of whole tokens the bucket can hold.
    capacity: u32,
    /// Whole tokens currently available.
    tokens: u32,
    /// Tokens regained per second. Expected to be non-zero (see `time_until_token`).
    refill_rate: u32,
    /// Fraction of a token, in `[0, 1)`, earned but not yet credited.
    ///
    /// It is carried between refills so sub-token progress is never discarded.
    /// Without it, frequent polling would systematically under-refill the bucket.
    partial: f64,
    /// The moment `refill` last accounted for elapsed time.
    last_refill: Instant,
}

impl TokenBucket {
    /// Creates a full bucket with `capacity` tokens that regains
    /// `refill_rate` tokens per second.
    fn new(capacity: u32, refill_rate: u32) -> Self {
        Self {
            capacity,
            tokens: capacity,
            refill_rate,
            partial: 0.0,
            last_refill: Instant::now(),
        }
    }

    /// Refills the bucket, then takes one token if one is available.
    ///
    /// Returns `true` if a token was taken.
    fn try_consume(&mut self) -> bool {
        self.refill();
        if self.tokens == 0 {
            return false;
        }
        self.tokens -= 1;
        true
    }

    /// Estimates how long until at least one whole token is available.
    ///
    /// Returns zero if a token is already available, or if enough time has
    /// passed since the last refill to earn one. Otherwise returns the time
    /// remaining for the pending fractional progress to reach one token.
    ///
    /// # Panics
    ///
    /// Panics if `refill_rate` is zero and no token is available: the wait is
    /// infinite, and `Duration::from_secs_f64` rejects non-finite values.
    fn time_until_token(&self) -> Duration {
        if self.tokens > 0 {
            return Duration::from_secs(0);
        }
        let rate = f64::from(self.refill_rate);
        let elapsed = Instant::now().duration_since(self.last_refill);
        // Progress toward the next token, including time not yet folded in by `refill`.
        let pending = self.partial + elapsed.as_secs_f64() * rate;
        if pending >= 1.0 {
            Duration::from_secs(0)
        } else {
            Duration::from_secs_f64((1.0 - pending) / rate)
        }
    }

    /// Credits tokens earned since the last refill, capped at `capacity`.
    ///
    /// The fractional remainder is kept in `partial`, so frequent calls do not
    /// erode the effective refill rate.
    fn refill(&mut self) {
        let now = Instant::now();
        let elapsed = now.duration_since(self.last_refill);
        self.last_refill = now;

        let earned = self.partial + elapsed.as_secs_f64() * f64::from(self.refill_rate);
        let whole = earned.floor();
        // The `as` cast saturates at u32::MAX after extremely long idle periods.
        // `saturating_add` then prevents overflow before clamping to capacity.
        self.tokens = self.tokens.saturating_add(whole as u32).min(self.capacity);
        // A full bucket banks no extra progress.
        self.partial = if self.tokens == self.capacity {
            0.0
        } else {
            earned - whole
        };
    }
}
impl RateLimiter {
pub fn new() -> Self {
let mut limiters = HashMap::new();
limiters.insert(RateLimitCategory::Trading, TokenBucket::new(250, 200));
limiters.insert(RateLimitCategory::MarketData, TokenBucket::new(500, 400));
limiters.insert(RateLimitCategory::Account, TokenBucket::new(200, 150));
limiters.insert(RateLimitCategory::Auth, TokenBucket::new(50, 30));
limiters.insert(RateLimitCategory::General, TokenBucket::new(300, 200));
Self {
limiters: Arc::new(Mutex::new(limiters)),
}
}
pub async fn wait_for_permission(&self, category: RateLimitCategory) {
loop {
let wait_time = {
let mut limiters = self.limiters.lock().await;
let bucket = limiters
.get_mut(&category)
.expect("Rate limit category should exist");
if bucket.try_consume() {
return; } else {
bucket.time_until_token()
}
};
sleep(wait_time.max(Duration::from_millis(10))).await;
}
}
pub async fn check_permission(&self, category: RateLimitCategory) -> bool {
let mut limiters = self.limiters.lock().await;
let bucket = limiters
.get_mut(&category)
.expect("Rate limit category should exist");
bucket.try_consume()
}
pub async fn get_tokens(&self, category: RateLimitCategory) -> u32 {
let mut limiters = self.limiters.lock().await;
let bucket = limiters
.get_mut(&category)
.expect("Rate limit category should exist");
bucket.refill();
bucket.tokens
}
}
impl Default for RateLimiter {
fn default() -> Self {
Self::new()
}
}
/// Maps a request path to the [`RateLimitCategory`] whose budget it consumes.
///
/// Matching is by substring. Groups are checked in the order Trading,
/// MarketData, Account, Auth, and the first group with a matching fragment
/// wins. Anything unmatched falls back to `General`.
pub fn categorize_endpoint(endpoint: &str) -> RateLimitCategory {
    /// True if `path` contains any of `fragments`.
    fn mentions_any(path: &str, fragments: &[&str]) -> bool {
        fragments.iter().any(|&fragment| path.contains(fragment))
    }

    const TRADING: &[&str] = &[
        "/private/buy",
        "/private/sell",
        "/private/cancel",
        "/private/edit",
    ];
    const MARKET_DATA: &[&str] = &[
        "/public/ticker",
        "/public/get_order_book",
        "/public/get_last_trades",
        "/public/get_instruments",
    ];
    const ACCOUNT: &[&str] = &[
        "/private/get_account_summary",
        "/private/get_positions",
        "/private/get_subaccounts",
    ];
    const AUTH: &[&str] = &["/public/auth", "/private/logout"];

    if mentions_any(endpoint, TRADING) {
        return RateLimitCategory::Trading;
    }
    if mentions_any(endpoint, MARKET_DATA) {
        return RateLimitCategory::MarketData;
    }
    if mentions_any(endpoint, ACCOUNT) {
        return RateLimitCategory::Account;
    }
    if mentions_any(endpoint, AUTH) {
        return RateLimitCategory::Auth;
    }
    RateLimitCategory::General
}
#[cfg(all(test, not(target_arch = "wasm32")))]
mod tests {
    use super::*;
    use crate::sleep_compat::sleep;

    /// A fresh bucket grants exactly `capacity` tokens, then refuses.
    #[tokio::test]
    async fn test_token_bucket_basic() {
        let mut bucket = TokenBucket::new(10, 5);
        assert!((0..10).all(|_| bucket.try_consume()));
        assert!(!bucket.try_consume());
    }

    /// An emptied bucket regains tokens after time passes.
    #[tokio::test]
    async fn test_token_bucket_refill() {
        let mut bucket = TokenBucket::new(5, 10);
        assert!((0..5).all(|_| bucket.try_consume()));
        assert!(!bucket.try_consume());

        sleep(Duration::from_millis(200)).await;
        assert!(bucket.try_consume());
    }

    /// Both the non-blocking and the waiting APIs succeed on a fresh limiter.
    #[tokio::test]
    async fn test_rate_limiter() {
        let limiter = RateLimiter::new();
        let granted = limiter.check_permission(RateLimitCategory::Trading).await;
        assert!(granted);
        limiter
            .wait_for_permission(RateLimitCategory::MarketData)
            .await;
    }

    /// One representative path per category maps to the expected category.
    #[test]
    fn test_endpoint_categorization() {
        let cases = [
            ("/private/buy", RateLimitCategory::Trading),
            ("/public/ticker", RateLimitCategory::MarketData),
            ("/private/get_account_summary", RateLimitCategory::Account),
            ("/public/auth", RateLimitCategory::Auth),
            ("/public/get_time", RateLimitCategory::General),
        ];
        for (path, expected) in cases {
            assert_eq!(categorize_endpoint(path), expected);
        }
    }
}