use std::collections::HashMap;
use std::sync::Arc;
use parking_lot::Mutex;
use tracing::instrument;
use kraken_types::{
RateLimitCategory, RateLimitConfig, RateLimitResult, TokenBucket, TokenBucketConfig,
};
/// Rate limiter that keeps one token bucket per [`RateLimitCategory`] plus a
/// separate set of per-symbol buckets that are created on first use.
#[derive(Debug)]
pub struct KrakenRateLimiter {
/// Configuration the per-category buckets were created from; also consulted
/// by `get_config`.
config: RateLimitConfig,
/// One bucket per category, each behind its own mutex so that different
/// categories do not contend on a single lock.
buckets: HashMap<RateLimitCategory, Mutex<TokenBucket>>,
/// Per-symbol buckets keyed by symbol string; the whole map sits behind a
/// single mutex.
symbol_buckets: Mutex<HashMap<String, TokenBucket>>,
}
impl Default for KrakenRateLimiter {
fn default() -> Self {
Self::new(RateLimitConfig::kraken_defaults())
}
}
impl KrakenRateLimiter {
pub fn new(config: RateLimitConfig) -> Self {
let mut buckets = HashMap::new();
buckets.insert(
RateLimitCategory::Connection,
Mutex::new(config.connection_limit.create_bucket()),
);
buckets.insert(
RateLimitCategory::RestPublic,
Mutex::new(config.rest_public.create_bucket()),
);
buckets.insert(
RateLimitCategory::RestPrivate,
Mutex::new(config.rest_private.create_bucket()),
);
buckets.insert(
RateLimitCategory::WsOrders,
Mutex::new(config.ws_orders.create_bucket()),
);
buckets.insert(
RateLimitCategory::L3Depth10,
Mutex::new(config.l3_depth_10.create_bucket()),
);
buckets.insert(
RateLimitCategory::L3Depth100,
Mutex::new(config.l3_depth_100.create_bucket()),
);
buckets.insert(
RateLimitCategory::L3Depth1000,
Mutex::new(config.l3_depth_1000.create_bucket()),
);
Self {
config,
buckets,
symbol_buckets: Mutex::new(HashMap::new()),
}
}
pub fn kraken_defaults() -> Self {
Self::new(RateLimitConfig::kraken_defaults())
}
pub fn high_tier() -> Self {
Self::new(RateLimitConfig::high_tier())
}
pub fn permissive() -> Self {
Self::new(RateLimitConfig::permissive())
}
pub fn try_acquire(&self, category: RateLimitCategory) -> RateLimitResult {
self.try_acquire_n(category, 1)
}
pub fn try_acquire_n(&self, category: RateLimitCategory, tokens: u32) -> RateLimitResult {
if let Some(bucket) = self.buckets.get(&category) {
let mut bucket = bucket.lock();
match bucket.try_acquire(tokens) {
Ok(()) => RateLimitResult::Allowed,
Err(wait) => RateLimitResult::Limited { wait, category },
}
} else {
RateLimitResult::Allowed
}
}
pub fn check(&self, category: RateLimitCategory) -> bool {
self.check_n(category, 1)
}
pub fn check_n(&self, category: RateLimitCategory, tokens: u32) -> bool {
if let Some(bucket) = self.buckets.get(&category) {
let mut bucket = bucket.lock();
bucket.check_available(tokens)
} else {
true
}
}
pub fn available(&self, category: RateLimitCategory) -> u32 {
if let Some(bucket) = self.buckets.get(&category) {
let mut bucket = bucket.lock();
bucket.available()
} else {
u32::MAX
}
}
#[instrument(skip(self), level = "debug")]
pub async fn acquire(&self, category: RateLimitCategory) {
self.acquire_n(category, 1).await
}
#[instrument(skip(self), level = "debug")]
pub async fn acquire_n(&self, category: RateLimitCategory, tokens: u32) {
loop {
match self.try_acquire_n(category, tokens) {
RateLimitResult::Allowed => return,
RateLimitResult::Limited { wait, .. } => {
tokio::time::sleep(wait).await;
}
}
}
}
pub fn try_acquire_l3(&self, depth: u32) -> RateLimitResult {
let category = RateLimitCategory::from_l3_depth(depth);
self.try_acquire(category)
}
pub fn try_acquire_ws_order(&self) -> RateLimitResult {
self.try_acquire(RateLimitCategory::WsOrders)
}
pub fn try_acquire_connection(&self) -> RateLimitResult {
self.try_acquire(RateLimitCategory::Connection)
}
pub fn try_acquire_symbol(&self, symbol: &str, config: TokenBucketConfig) -> RateLimitResult {
let mut symbol_buckets = self.symbol_buckets.lock();
let bucket = symbol_buckets
.entry(symbol.to_string())
.or_insert_with(|| config.create_bucket());
match bucket.try_acquire(1) {
Ok(()) => RateLimitResult::Allowed,
Err(wait) => RateLimitResult::Limited {
wait,
category: RateLimitCategory::RestPublic, },
}
}
pub fn reset_all(&self) {
for bucket in self.buckets.values() {
bucket.lock().reset();
}
self.symbol_buckets.lock().clear();
}
pub fn reset(&self, category: RateLimitCategory) {
if let Some(bucket) = self.buckets.get(&category) {
bucket.lock().reset();
}
}
pub fn get_config(&self, category: RateLimitCategory) -> TokenBucketConfig {
category.get_config(&self.config)
}
pub fn utilization(&self, category: RateLimitCategory) -> f64 {
if let Some(bucket) = self.buckets.get(&category) {
let mut bucket = bucket.lock();
1.0 - (bucket.available() as f64 / bucket.capacity() as f64)
} else {
0.0
}
}
}
/// Reference-counted handle to a [`KrakenRateLimiter`], so one limiter can be
/// shared between several owners.
pub type SharedRateLimiter = Arc<KrakenRateLimiter>;
/// Creates a shared limiter configured with the Kraken defaults.
pub fn shared_rate_limiter() -> SharedRateLimiter {
    // `Default` for the limiter is built from the same Kraken default config.
    let limiter = KrakenRateLimiter::default();
    Arc::new(limiter)
}
/// Creates a shared limiter whose category buckets are built from `config`.
pub fn shared_rate_limiter_with_config(config: RateLimitConfig) -> SharedRateLimiter {
    let limiter = KrakenRateLimiter::new(config);
    Arc::new(limiter)
}
#[cfg(test)]
mod tests {
    use super::*;

    // A freshly built limiter has at least one Connection token available.
    #[test]
    fn test_rate_limiter_creation() {
        let limiter = KrakenRateLimiter::kraken_defaults();
        assert!(limiter.check(RateLimitCategory::Connection));
    }

    // NOTE(review): assumes the default Connection bucket allows exactly 150
    // back-to-back acquisitions — confirm against
    // `RateLimitConfig::kraken_defaults`.
    #[test]
    fn test_rate_limiter_acquire() {
        let limiter = KrakenRateLimiter::kraken_defaults();
        assert!(limiter
            .try_acquire(RateLimitCategory::WsOrders)
            .is_allowed());

        for _ in 0..150 {
            assert!(limiter
                .try_acquire(RateLimitCategory::Connection)
                .is_allowed());
        }

        // The next request must be refused and carry a wait hint.
        let refused = limiter.try_acquire(RateLimitCategory::Connection);
        assert!(!refused.is_allowed());
        assert!(refused.wait_duration().is_some());
    }

    // NOTE(review): assumes 15 acquisitions exhaust the default WsOrders
    // bucket — confirm against the default config.
    #[test]
    fn test_rate_limiter_reset() {
        let limiter = KrakenRateLimiter::kraken_defaults();
        for _ in 0..15 {
            let _ = limiter.try_acquire(RateLimitCategory::WsOrders);
        }
        assert!(!limiter.check(RateLimitCategory::WsOrders));

        limiter.reset(RateLimitCategory::WsOrders);
        assert!(limiter.check(RateLimitCategory::WsOrders));
    }

    // Utilization starts near zero and lands around one half after seven
    // WsOrders acquisitions (presumably out of a default capacity near 15).
    #[test]
    fn test_rate_limiter_utilization() {
        let limiter = KrakenRateLimiter::kraken_defaults();
        let before = limiter.utilization(RateLimitCategory::WsOrders);
        assert!(before < 0.01);

        for _ in 0..7 {
            let _ = limiter.try_acquire(RateLimitCategory::WsOrders);
        }

        let after = limiter.utilization(RateLimitCategory::WsOrders);
        assert!(after > 0.4 && after < 0.6);
    }

    // NOTE(review): assumes the depth-10 L3 bucket allows exactly five
    // back-to-back acquisitions — confirm against the default config.
    #[test]
    fn test_l3_depth_acquire() {
        let limiter = KrakenRateLimiter::kraken_defaults();
        for _ in 0..5 {
            assert!(limiter.try_acquire_l3(10).is_allowed());
        }
        assert!(!limiter.try_acquire_l3(10).is_allowed());
    }

    // Both handles point at the same limiter, so they report the same count.
    #[test]
    fn test_shared_rate_limiter() {
        let owner = shared_rate_limiter();
        let alias = Arc::clone(&owner);
        let _ = owner.try_acquire(RateLimitCategory::WsOrders);

        let seen_by_owner = owner.available(RateLimitCategory::WsOrders);
        let seen_by_alias = alias.available(RateLimitCategory::WsOrders);
        assert_eq!(seen_by_owner, seen_by_alias);
    }

    // The async path completes when a token is available.
    #[tokio::test]
    async fn test_async_acquire() {
        let limiter = KrakenRateLimiter::permissive();
        limiter.acquire(RateLimitCategory::Connection).await;
    }
}