use std::collections::HashMap;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::{Arc, Mutex};
use std::time::Duration;
use tokio::time::sleep;
/// Reference instant for the limiter timestamps in this file.
///
/// Initialised lazily on first access (not literally at process start);
/// `RateLimiter::now_ns` reports nanoseconds elapsed since this instant.
static PROCESS_START: std::sync::LazyLock<std::time::Instant> =
std::sync::LazyLock::new(std::time::Instant::now);
/// Minimum spacing (4 seconds) given to every limiter created by
/// `GlobalRateLimiter::get`.
const OPENROUTER_FREE_INTERVAL: Duration = Duration::from_millis(4000);
/// Process-wide registry of per-model limiters, constructed lazily on first
/// access. Look up (or create) a model's limiter with `get`.
pub static OPENROUTER_FREE_LIMITERS: std::sync::LazyLock<GlobalRateLimiter> =
std::sync::LazyLock::new(GlobalRateLimiter::new);
/// Minimum spacing (1.5 seconds) enforced by `QWEN_OAUTH_LIMITER`.
const QWEN_OAUTH_INTERVAL: Duration = Duration::from_millis(1500);
/// Single process-wide limiter, constructed lazily on first access and shared
/// through an `Arc` so callers can clone a handle to it.
pub static QWEN_OAUTH_LIMITER: std::sync::LazyLock<Arc<RateLimiter>> =
std::sync::LazyLock::new(|| Arc::new(RateLimiter::new(QWEN_OAUTH_INTERVAL)));
/// Interval-based limiter whose only mutable state is one atomic timestamp,
/// so it can be shared between tasks behind an `Arc` without a mutex.
#[derive(Debug)]
pub struct RateLimiter {
/// Minimum spacing enforced between consecutive grants.
pub(crate) min_interval: Duration,
/// Timestamp of the most recent grant, in nanoseconds since `PROCESS_START`.
/// `0` is a sentinel meaning "no grant issued yet". The value may lie in the
/// future while a caller that reserved a slot is still sleeping in `wait`.
last_granted: AtomicU64,
}
impl RateLimiter {
    /// Creates a limiter that keeps grants at least `min_interval` apart.
    ///
    /// No grant has been issued yet, so the first call to
    /// [`RateLimiter::wait`] returns without sleeping.
    pub fn new(min_interval: Duration) -> Self {
        Self {
            min_interval,
            last_granted: AtomicU64::new(0),
        }
    }

    /// Nanoseconds elapsed since `PROCESS_START`, never less than 1.
    ///
    /// `0` is reserved as the "no grant issued yet" sentinel in
    /// `last_granted`, so a genuine clock reading must not collide with it.
    /// Values beyond `u64::MAX` saturate instead of silently wrapping.
    fn now_ns() -> u64 {
        u64::try_from(PROCESS_START.elapsed().as_nanos())
            .unwrap_or(u64::MAX)
            .max(1)
    }

    /// Waits until this caller may proceed and returns how long it slept
    /// (`Duration::ZERO` when the grant was immediate).
    ///
    /// Each caller atomically reserves the next free slot: "now" when the
    /// previous slot is at least `min_interval` old, otherwise exactly one
    /// interval after the previous slot. Because the previous slot may itself
    /// be a future reservation held by a sleeping caller, concurrent callers
    /// queue one interval apart instead of all waking at the same moment.
    ///
    /// The slot is reserved *before* sleeping, so dropping the returned future
    /// mid-sleep leaves the reservation in place.
    pub async fn wait(&self) -> Duration {
        let interval_ns = u64::try_from(self.min_interval.as_nanos()).unwrap_or(u64::MAX);
        loop {
            // Re-read the clock on every attempt so a retry after a lost CAS
            // does not work from a stale timestamp.
            let now_ns = Self::now_ns();
            let last = self.last_granted.load(Ordering::Acquire);

            let grant_at = if last == 0 {
                // Sentinel: nothing granted yet, so go right away.
                now_ns
            } else {
                // Earliest permitted slot is one interval after the previous
                // one; if that is already in the past, go right away.
                last.saturating_add(interval_ns).max(now_ns)
            };

            if self
                .last_granted
                .compare_exchange(last, grant_at, Ordering::AcqRel, Ordering::Acquire)
                .is_err()
            {
                // Another caller claimed a slot first; recompute against it.
                continue;
            }

            // `grant_at >= now_ns` holds in both branches above.
            let sleep_for = Duration::from_nanos(grant_at - now_ns);
            if !sleep_for.is_zero() {
                sleep(sleep_for).await;
            }
            return sleep_for;
        }
    }
}
/// Registry that hands out one shared [`RateLimiter`] per model name.
///
/// Derives `Debug` so it can be logged and inspected like `RateLimiter`.
#[derive(Debug)]
pub struct GlobalRateLimiter {
    /// Model name → limiter, guarded by a mutex so entries can be created
    /// on demand through a shared reference.
    limiters: Arc<Mutex<HashMap<String, Arc<RateLimiter>>>>,
}
impl GlobalRateLimiter {
    /// Creates an empty registry; limiters are added on demand by
    /// [`GlobalRateLimiter::get`].
    fn new() -> Self {
        Self {
            limiters: Arc::new(Mutex::new(HashMap::new())),
        }
    }

    /// Returns the shared limiter for `model`, creating it with
    /// `OPENROUTER_FREE_INTERVAL` on first use.
    ///
    /// Repeated calls with the same `model` return clones of the same `Arc`.
    ///
    /// # Panics
    ///
    /// Panics if the registry mutex is poisoned.
    pub fn get(&self, model: &str) -> Arc<RateLimiter> {
        // One critical section covers both lookup and insert. The mutex is
        // exclusive, so a separate read-only first pass would only add an
        // extra lock/unlock on the miss path.
        let mut map = self
            .limiters
            .lock()
            .expect("rate limiter registry mutex poisoned");

        if let Some(existing) = map.get(model) {
            return Arc::clone(existing);
        }

        let limiter = Arc::new(RateLimiter::new(OPENROUTER_FREE_INTERVAL));
        map.insert(model.to_owned(), Arc::clone(&limiter));
        limiter
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A fresh limiter has no earlier grant, so the first caller is not delayed.
    #[tokio::test]
    async fn test_rate_limiter_allows_first_request_immediately() {
        let limiter = RateLimiter::new(Duration::from_millis(100));
        let slept = limiter.wait().await;
        assert!(slept.is_zero());
    }

    /// A second request issued right after the first must wait in wall-clock
    /// time. The threshold is half the interval to tolerate timer jitter.
    #[tokio::test]
    async fn test_rate_limiter_enforces_interval() {
        let limiter = RateLimiter::new(Duration::from_millis(100));

        // The first grant is immediate and starts the interval.
        limiter.wait().await;
        let started = std::time::Instant::now();

        // The second grant has to sit out the remainder of the interval.
        limiter.wait().await;
        let waited_ms = started.elapsed().as_millis();

        assert!(
            waited_ms >= 50,
            "Expected ≥50ms wall-clock wait, got {}ms",
            waited_ms
        );
    }

    /// Looking up the same model twice yields the very same limiter instance.
    #[tokio::test]
    async fn test_global_limiter_returns_same_limiter_for_same_model() {
        let registry = GlobalRateLimiter::new();
        let first = registry.get("qwen/qwen3.6-plus:free");
        let second = registry.get("qwen/qwen3.6-plus:free");
        assert!(Arc::ptr_eq(&first, &second));
    }

    /// Distinct models are throttled independently, each with its own limiter.
    #[tokio::test]
    async fn test_global_limiter_returns_different_limiter_for_different_model() {
        let registry = GlobalRateLimiter::new();
        let qwen = registry.get("qwen/qwen3.6-plus:free");
        let gemma = registry.get("google/gemma-3-27b-it:free");
        assert!(!Arc::ptr_eq(&qwen, &gemma));
    }
}