use std::sync::RwLock;
use std::time::{Duration, SystemTime, UNIX_EPOCH};
use tracing::{debug, info};
/// Bucket a request is counted against for rate limiting.
///
/// Each variant has its own independently tracked budget inside
/// [`RateLimiter`].
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub(crate) enum RateLimitCategory {
    /// Every request path not classified as market data.
    Account,
    /// Request paths beginning with `/markets`.
    MarketData,
}

impl std::fmt::Display for RateLimitCategory {
    /// Writes the lowercase snake_case label used as the `category` log field.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match *self {
            RateLimitCategory::Account => "account",
            RateLimitCategory::MarketData => "market_data",
        };
        f.write_str(label)
    }
}
/// Most recent rate-limit figures observed for one category.
///
/// Both fields start out as `None`. A field is only overwritten by a header
/// value that parses successfully, so it keeps its last known value when a
/// response omits the corresponding header.
#[derive(Clone, Debug, Default)]
struct CategoryState {
    /// Last `X-RateLimit-Remaining` value seen; `Some(0)` means exhausted.
    remaining: Option<u32>,
    /// Last `X-RateLimit-Reset` value seen, as Unix time in whole seconds.
    reset_epoch: Option<u64>,
}
/// Tracks rate-limit budgets reported in response headers, with one
/// independently locked slot per [`RateLimitCategory`].
///
/// Each slot sits behind its own `RwLock`, so updating one category never
/// contends with reading the other.
pub(crate) struct RateLimiter {
    /// Budget for requests classified as [`RateLimitCategory::Account`].
    account: RwLock<CategoryState>,
    /// Budget for requests classified as [`RateLimitCategory::MarketData`].
    market_data: RwLock<CategoryState>,
}
impl RateLimiter {
    /// Creates a limiter with no recorded state for either category.
    ///
    /// Until [`RateLimiter::update_from_headers`] records an exhausted
    /// budget, [`RateLimiter::wait_duration`] returns `None`.
    pub(crate) fn new() -> Self {
        Self {
            account: RwLock::new(CategoryState::default()),
            market_data: RwLock::new(CategoryState::default()),
        }
    }

    /// Maps a request path to the category whose budget it consumes.
    ///
    /// Paths starting with `/markets` are [`RateLimitCategory::MarketData`];
    /// everything else is [`RateLimitCategory::Account`].
    pub(crate) fn classify(path: &str) -> RateLimitCategory {
        if path.starts_with("/markets") {
            RateLimitCategory::MarketData
        } else {
            RateLimitCategory::Account
        }
    }

    /// Returns the lock guarding `category`'s state.
    fn state_for(&self, category: RateLimitCategory) -> &RwLock<CategoryState> {
        match category {
            RateLimitCategory::Account => &self.account,
            RateLimitCategory::MarketData => &self.market_data,
        }
    }

    /// Current wall-clock time in whole seconds since the Unix epoch.
    ///
    /// Returns `None` (instead of panicking) if the system clock reports a
    /// time before the epoch.
    fn now_epoch_secs() -> Option<u64> {
        SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .ok()
            .map(|elapsed| elapsed.as_secs())
    }

    /// Renders a Unix timestamp (seconds) as RFC 3339 for logging.
    ///
    /// Falls back to the raw number when the timestamp cannot be represented
    /// or formatted.
    fn format_reset_epoch(epoch: u64) -> String {
        // `epoch as i64` would wrap values above `i64::MAX` into negative
        // timestamps and log a bogus pre-1970 date; reject them instead.
        i64::try_from(epoch)
            .ok()
            .and_then(|secs| time::OffsetDateTime::from_unix_timestamp(secs).ok())
            .and_then(|dt| {
                dt.format(&time::format_description::well_known::Rfc3339)
                    .ok()
            })
            .unwrap_or_else(|| epoch.to_string())
    }

    /// How long a caller should wait before sending a request in `category`.
    ///
    /// Returns `Some` only when the last observed remaining count is `0` and
    /// the recorded reset time is still in the future. The duration is the
    /// whole seconds until reset plus a 100 ms margin. Returns `None`
    /// otherwise, including when the system clock is before the Unix epoch.
    ///
    /// # Panics
    ///
    /// Panics if the category's lock is poisoned.
    pub(crate) fn wait_duration(&self, category: RateLimitCategory) -> Option<Duration> {
        // Copy the two fields out so the read lock is released right away.
        let (remaining, reset_epoch) = {
            let state = self
                .state_for(category)
                .read()
                .expect("rate limit lock poisoned");
            (state.remaining, state.reset_epoch)
        };
        if remaining != Some(0) {
            return None;
        }
        let reset_epoch = reset_epoch?;
        let now_epoch = Self::now_epoch_secs()?;
        if reset_epoch > now_epoch {
            Some(Duration::from_secs(reset_epoch - now_epoch) + Duration::from_millis(100))
        } else {
            None
        }
    }

    /// Records the `X-RateLimit-Remaining` and `X-RateLimit-Reset` headers
    /// from a response into `category`'s state.
    ///
    /// Each header is applied only if it is present and its trimmed value
    /// parses as an unsigned integer. A missing or malformed header leaves the
    /// previously recorded value untouched. Logs at `info` when the merged
    /// remaining count is zero, and at `debug` otherwise.
    ///
    /// # Panics
    ///
    /// Panics if the category's lock is poisoned.
    pub(crate) fn update_from_headers(
        &self,
        category: RateLimitCategory,
        headers: &reqwest::header::HeaderMap,
    ) {
        /// Reads header `name` as trimmed text and parses it; `None` if the
        /// header is absent, not visible ASCII, or fails to parse.
        fn parse_header<T: std::str::FromStr>(
            headers: &reqwest::header::HeaderMap,
            name: &str,
        ) -> Option<T> {
            headers.get(name)?.to_str().ok()?.trim().parse().ok()
        }

        let remaining = parse_header::<u32>(headers, "X-RateLimit-Remaining");
        let reset_epoch = parse_header::<u64>(headers, "X-RateLimit-Reset");
        if remaining.is_none() && reset_epoch.is_none() {
            return;
        }

        // Merge under the write lock, then release it before any clock
        // reads, formatting or logging. A panic or a slow tracing subscriber
        // must not poison or stall the shared state.
        let (merged_remaining, merged_reset) = {
            let mut state = self
                .state_for(category)
                .write()
                .expect("rate limit lock poisoned");
            state.remaining = remaining.or(state.remaining);
            state.reset_epoch = reset_epoch.or(state.reset_epoch);
            (state.remaining, state.reset_epoch)
        };

        if merged_remaining == Some(0) {
            let wait_secs = merged_reset
                .zip(Self::now_epoch_secs())
                .and_then(|(reset, now)| reset.checked_sub(now));
            let reset_at = merged_reset.map(Self::format_reset_epoch);
            info!(
                category = %category,
                reset_at = reset_at.as_deref(),
                wait_secs,
                "rate limit exhausted, will block requests until reset",
            );
        } else {
            debug!(
                category = %category,
                remaining = merged_remaining,
                "rate limit state updated",
            );
        }
    }
}
#[cfg(test)]
mod tests {
    // Unit tests for path classification, wait computation and header merging.
    use super::*;

    /// Current Unix time in whole seconds, used to build reset timestamps.
    fn now_secs() -> u64 {
        SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .unwrap()
            .as_secs()
    }

    /// Builds a header map from `(name, value)` pairs.
    fn header_map(pairs: &[(&'static str, &str)]) -> reqwest::header::HeaderMap {
        let mut map = reqwest::header::HeaderMap::new();
        for &(name, value) in pairs {
            map.insert(name, value.parse().unwrap());
        }
        map
    }

    #[test]
    fn classify_account_endpoints() {
        for &path in &[
            "/time",
            "/accounts/123/positions",
            "/accounts/123/balances",
            "/accounts/123/orders",
            "/accounts/123/executions",
            "/accounts/123/activities",
            "/symbols/search?prefix=AAPL",
            "/symbols/12345",
            "/symbols/12345/options",
        ] {
            assert_eq!(
                RateLimiter::classify(path),
                RateLimitCategory::Account,
                "path {path}",
            );
        }
    }

    #[test]
    fn classify_market_data_endpoints() {
        for &path in &[
            "/markets/quotes/12345",
            "/markets/quotes/options",
            "/markets/candles/12345",
            "/markets",
        ] {
            assert_eq!(
                RateLimiter::classify(path),
                RateLimitCategory::MarketData,
                "path {path}",
            );
        }
    }

    #[test]
    fn wait_duration_none_when_no_state() {
        let rl = RateLimiter::new();
        assert!(rl.wait_duration(RateLimitCategory::Account).is_none());
        assert!(rl.wait_duration(RateLimitCategory::MarketData).is_none());
    }

    #[test]
    fn wait_duration_none_when_remaining_positive() {
        let rl = RateLimiter::new();
        {
            let mut state = rl.account.write().unwrap();
            state.remaining = Some(10);
            state.reset_epoch = Some(u64::MAX);
        }
        assert!(rl.wait_duration(RateLimitCategory::Account).is_none());
    }

    #[test]
    fn wait_duration_some_when_exhausted_and_reset_in_future() {
        let rl = RateLimiter::new();
        let future = now_secs() + 5;
        {
            let mut state = rl.account.write().unwrap();
            state.remaining = Some(0);
            state.reset_epoch = Some(future);
        }
        let wait = rl.wait_duration(RateLimitCategory::Account).unwrap();
        assert!(wait.as_secs() >= 4 && wait.as_secs() <= 6);
    }

    #[test]
    fn wait_duration_none_when_exhausted_but_reset_in_past() {
        let rl = RateLimiter::new();
        let past = now_secs() - 10;
        {
            let mut state = rl.account.write().unwrap();
            state.remaining = Some(0);
            state.reset_epoch = Some(past);
        }
        assert!(rl.wait_duration(RateLimitCategory::Account).is_none());
    }

    #[test]
    fn wait_duration_none_when_exhausted_but_no_reset_epoch() {
        let rl = RateLimiter::new();
        {
            let mut state = rl.market_data.write().unwrap();
            state.remaining = Some(0);
            state.reset_epoch = None;
        }
        assert!(rl.wait_duration(RateLimitCategory::MarketData).is_none());
    }

    #[test]
    fn update_from_headers_parses_both_headers() {
        let rl = RateLimiter::new();
        let headers = header_map(&[
            ("X-RateLimit-Remaining", "42"),
            ("X-RateLimit-Reset", "1700000000"),
        ]);
        rl.update_from_headers(RateLimitCategory::Account, &headers);
        let state = rl.account.read().unwrap();
        assert_eq!(state.remaining, Some(42));
        assert_eq!(state.reset_epoch, Some(1_700_000_000));
    }

    #[test]
    fn update_from_headers_ignores_missing_headers() {
        let rl = RateLimiter::new();
        let headers = reqwest::header::HeaderMap::new();
        rl.update_from_headers(RateLimitCategory::Account, &headers);
        let state = rl.account.read().unwrap();
        assert!(state.remaining.is_none());
        assert!(state.reset_epoch.is_none());
    }

    #[test]
    fn update_from_headers_ignores_malformed_values() {
        let rl = RateLimiter::new();
        let headers = header_map(&[
            ("X-RateLimit-Remaining", "not-a-number"),
            ("X-RateLimit-Reset", "also-bad"),
        ]);
        rl.update_from_headers(RateLimitCategory::Account, &headers);
        let state = rl.account.read().unwrap();
        assert!(state.remaining.is_none());
        assert!(state.reset_epoch.is_none());
    }

    #[test]
    fn update_from_headers_trims_whitespace() {
        let rl = RateLimiter::new();
        let headers = header_map(&[("X-RateLimit-Remaining", " 7 ")]);
        rl.update_from_headers(RateLimitCategory::Account, &headers);
        assert_eq!(rl.account.read().unwrap().remaining, Some(7));
    }

    #[test]
    fn update_from_headers_exhausted_blocks_until_reset() {
        let rl = RateLimiter::new();
        let reset = (now_secs() + 30).to_string();
        let headers = header_map(&[
            ("X-RateLimit-Remaining", "0"),
            ("X-RateLimit-Reset", reset.as_str()),
        ]);
        rl.update_from_headers(RateLimitCategory::MarketData, &headers);
        let wait = rl.wait_duration(RateLimitCategory::MarketData).unwrap();
        assert!(wait.as_secs() >= 29 && wait.as_secs() <= 31);
        assert!(rl.wait_duration(RateLimitCategory::Account).is_none());
    }

    #[test]
    fn update_from_headers_merges_partial_headers() {
        let rl = RateLimiter::new();
        let reset = now_secs() + 30;
        let reset_text = reset.to_string();
        rl.update_from_headers(
            RateLimitCategory::Account,
            &header_map(&[
                ("X-RateLimit-Remaining", "0"),
                ("X-RateLimit-Reset", reset_text.as_str()),
            ]),
        );
        // A later response carrying only the remaining count must not clear
        // the previously recorded reset time.
        rl.update_from_headers(
            RateLimitCategory::Account,
            &header_map(&[("X-RateLimit-Remaining", "5")]),
        );
        {
            let state = rl.account.read().unwrap();
            assert_eq!(state.remaining, Some(5));
            assert_eq!(state.reset_epoch, Some(reset));
        }
        assert!(rl.wait_duration(RateLimitCategory::Account).is_none());
    }

    #[test]
    fn update_from_headers_exhausted_with_past_reset_does_not_block() {
        let rl = RateLimiter::new();
        let reset = (now_secs() - 10).to_string();
        let headers = header_map(&[
            ("X-RateLimit-Remaining", "0"),
            ("X-RateLimit-Reset", reset.as_str()),
        ]);
        rl.update_from_headers(RateLimitCategory::Account, &headers);
        assert!(rl.wait_duration(RateLimitCategory::Account).is_none());
    }

    #[test]
    fn update_from_headers_tolerates_out_of_range_reset() {
        let rl = RateLimiter::new();
        let reset = u64::MAX.to_string();
        let headers = header_map(&[
            ("X-RateLimit-Remaining", "0"),
            ("X-RateLimit-Reset", reset.as_str()),
        ]);
        // Exercises the exhausted logging path with a timestamp that does not
        // fit in i64; it must not panic.
        rl.update_from_headers(RateLimitCategory::Account, &headers);
        assert_eq!(rl.account.read().unwrap().reset_epoch, Some(u64::MAX));
        assert!(rl.wait_duration(RateLimitCategory::Account).is_some());
    }

    #[test]
    fn categories_are_independent() {
        let rl = RateLimiter::new();
        let future = now_secs() + 10;
        {
            let mut state = rl.account.write().unwrap();
            state.remaining = Some(0);
            state.reset_epoch = Some(future);
        }
        assert!(rl.wait_duration(RateLimitCategory::Account).is_some());
        assert!(rl.wait_duration(RateLimitCategory::MarketData).is_none());
    }
}