use std::collections::HashMap;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::{Arc, Mutex};
use std::time::Duration;
use tokio::time::sleep;
/// Monotonic reference point for the nanosecond timestamps kept by the rate limiters.
///
/// `LazyLock` sets it on first access rather than at actual process start. The
/// limiters only measure elapsed time from it, so the exact moment is irrelevant
/// as long as it stays fixed once initialized.
static PROCESS_START: std::sync::LazyLock<std::time::Instant> =
    std::sync::LazyLock::new(std::time::Instant::now);
/// Minimum spacing between two requests to the same OpenRouter free-tier model.
const OPENROUTER_FREE_INTERVAL: Duration = Duration::from_secs(4);

/// Per-model limiter registry for OpenRouter free-tier models.
///
/// Each model name looked up here gets its own [`RateLimiter`] created with
/// [`OPENROUTER_FREE_INTERVAL`].
pub static OPENROUTER_FREE_LIMITERS: std::sync::LazyLock<GlobalRateLimiter> =
    std::sync::LazyLock::new(GlobalRateLimiter::new);

/// Minimum spacing between two Qwen OAuth requests.
const QWEN_OAUTH_INTERVAL: Duration = Duration::from_millis(1_500);

/// One limiter shared by every Qwen OAuth request.
pub static QWEN_OAUTH_LIMITER: std::sync::LazyLock<Arc<RateLimiter>> =
    std::sync::LazyLock::new(|| Arc::new(RateLimiter::new(QWEN_OAUTH_INTERVAL)));
/// Lock-free limiter that spaces successive grants at least `min_interval` apart.
///
/// Each [`RateLimiter::wait`] call reserves the next free slot with a
/// compare-and-swap on `last_granted` and then sleeps until that slot arrives.
/// Concurrent callers therefore queue up one interval apart instead of firing
/// together.
#[derive(Debug)]
pub struct RateLimiter {
    /// Minimum spacing between two consecutive grants.
    pub(crate) min_interval: Duration,
    /// Most recently reserved grant time, in nanoseconds since `PROCESS_START`.
    ///
    /// It may lie in the future while the caller that reserved it is still
    /// sleeping. The value `0` is reserved to mean "nothing granted yet".
    last_granted: AtomicU64,
}

impl RateLimiter {
    /// Creates a limiter that has granted nothing yet.
    ///
    /// The first [`wait`](Self::wait) therefore returns immediately.
    pub const fn new(min_interval: Duration) -> Self {
        Self {
            min_interval,
            last_granted: AtomicU64::new(0),
        }
    }

    /// Nanoseconds elapsed since `PROCESS_START`, saturating at `u64::MAX`.
    fn now_ns() -> u64 {
        u64::try_from(PROCESS_START.elapsed().as_nanos()).unwrap_or(u64::MAX)
    }

    /// Waits until this caller may proceed and returns how long it slept.
    ///
    /// Returns `Duration::ZERO` when a slot is available right away.
    ///
    /// The slot is reserved *before* sleeping. If the returned future is
    /// dropped mid-sleep, the slot is still consumed, and later callers stay
    /// spaced after it.
    pub async fn wait(&self) -> Duration {
        let interval_ns = u64::try_from(self.min_interval.as_nanos()).unwrap_or(u64::MAX);
        loop {
            // Re-read the clock on each attempt: a failed CAS means another
            // caller moved `last_granted`, and time has advanced meanwhile.
            let now_ns = Self::now_ns();
            let last = self.last_granted.load(Ordering::Acquire);

            // The next slot is one interval after the latest reservation. That
            // reservation may itself be in the future, so we must add to
            // `last`, not to `now`. A slot is never scheduled in the past.
            let grant_at = if last == 0 {
                now_ns
            } else {
                last.saturating_add(interval_ns).max(now_ns)
            };

            // Keep 0 free as the "never granted" sentinel, even if this grant
            // falls at offset 0 (possible on the very first call with a coarse
            // clock). A 1 ns shift is negligible.
            let stored = grant_at.max(1);

            if self
                .last_granted
                .compare_exchange(last, stored, Ordering::AcqRel, Ordering::Acquire)
                .is_ok()
            {
                // `grant_at >= now_ns` by construction, so this cannot underflow.
                let sleep_for = Duration::from_nanos(grant_at - now_ns);
                if !sleep_for.is_zero() {
                    sleep(sleep_for).await;
                }
                return sleep_for;
            }
        }
    }
}
pub struct GlobalRateLimiter {
limiters: Arc<Mutex<HashMap<String, Arc<RateLimiter>>>>,
}
impl GlobalRateLimiter {
pub(crate) fn new() -> Self {
Self {
limiters: Arc::new(Mutex::new(HashMap::new())),
}
}
pub fn get(&self, model: &str) -> Arc<RateLimiter> {
{
let map = self.limiters.lock().unwrap();
if let Some(limiter) = map.get(model) {
return Arc::clone(limiter);
}
}
let mut map = self.limiters.lock().unwrap();
if let Some(limiter) = map.get(model) {
return Arc::clone(limiter);
}
let limiter = Arc::new(RateLimiter::new(OPENROUTER_FREE_INTERVAL));
map.insert(model.to_string(), limiter.clone());
limiter
}
}