use std::cell::RefCell;
use std::collections::{HashMap, VecDeque};
use std::time::{Duration, Instant};
/// Length, in seconds, of the window every limiter counts requests over.
const WINDOW_SECS: u64 = 60;

/// Sliding-window request counter for a single provider.
///
/// Remembers the instants of recently admitted requests; a new request may
/// proceed while fewer than `max_requests` of them fall inside `window`.
struct SlidingWindow {
    /// Span of time over which admitted requests are counted.
    window: Duration,
    /// Maximum number of requests admitted per `window`.
    max_requests: u32,
    /// Admission instants, oldest at the front (appended at the back).
    timestamps: VecDeque<Instant>,
}
impl SlidingWindow {
    /// Creates an empty window admitting up to `rpm` requests per `WINDOW_SECS` seconds.
    ///
    /// An `rpm` of 0 is accepted and treated as "no limit" by [`SlidingWindow::check`].
    fn new(rpm: u32) -> Self {
        Self {
            max_requests: rpm,
            window: Duration::from_secs(WINDOW_SECS),
            // Pre-size for the expected number of entries, capped so a very large
            // configured rpm does not reserve a huge buffer up front.
            timestamps: VecDeque::with_capacity(rpm.min(1024) as usize),
        }
    }

    /// Drops expired entries and reports whether another request may proceed now.
    ///
    /// Returns `None` when the request may proceed. This includes the case where
    /// `max_requests` is 0, which means "unlimited". Otherwise returns `Some(wait)`:
    /// the time until the oldest recorded request leaves the window.
    ///
    /// This does not record anything itself. Call [`SlidingWindow::record`] once the
    /// request is actually admitted.
    fn check(&mut self) -> Option<Duration> {
        let now = Instant::now();
        // `Instant - Duration` panics when the result would precede the clock's origin
        // (e.g. within `window` of boot on some platforms). In that case no recorded
        // instant can be older than the window yet, so there is nothing to prune.
        if let Some(cutoff) = now.checked_sub(self.window) {
            // An entry exactly `window` old has expired; keeping it would report a
            // zero-length wait and make callers spin.
            while self.timestamps.front().is_some_and(|t| *t <= cutoff) {
                self.timestamps.pop_front();
            }
        }
        // A zero limit must not reach the "full" branch: the queue may be empty there.
        if self.max_requests == 0 || self.timestamps.len() < self.max_requests as usize {
            return None;
        }
        // Non-empty here because len >= max_requests >= 1.
        self.timestamps
            .front()
            .map(|oldest| (*oldest + self.window).saturating_duration_since(now))
    }

    /// Records that a request was admitted at the current instant.
    fn record(&mut self) {
        self.timestamps.push_back(Instant::now());
    }
}
thread_local! {
    /// Registry of active limiters, keyed by provider name.
    ///
    /// Being `thread_local!`, every OS thread owns an independent, initially empty map.
    /// NOTE(review): limits installed on one thread are invisible to others; confirm
    /// callers (including async tasks) always run on the thread that configured them.
    static LIMITERS: RefCell<HashMap<String, SlidingWindow>> = RefCell::new(HashMap::default());
}
/// Installs rate limiters from the LLM provider config and environment overrides.
///
/// First, every provider in `crate::llm_config::load_config()` with a non-zero `rpm`
/// gets a limiter. Then each `HARN_RATE_LIMIT_<PROVIDER>` environment variable is
/// applied on top. The provider name is the lowercased suffix. A positive value
/// installs or replaces that provider's limiter, `0` removes it, and an unparsable
/// value is ignored. Existing limiters not mentioned by either source are left in place.
pub(crate) fn init_from_config() {
    let config = crate::llm_config::load_config();
    LIMITERS.with(|limiters| {
        let mut map = limiters.borrow_mut();

        // Limits declared in the provider config; unset or zero rpm means "no limit".
        for (name, pdef) in &config.providers {
            if let Some(rpm) = pdef.rpm {
                if rpm > 0 {
                    map.insert(name.clone(), SlidingWindow::new(rpm));
                }
            }
        }

        // Environment overrides. `std::env::vars()` panics if *any* variable in the
        // process environment is not valid Unicode. `vars_os()` lets such unrelated
        // entries be skipped instead of aborting initialization.
        for (raw_key, raw_val) in std::env::vars_os() {
            let (Some(key), Some(val)) = (raw_key.to_str(), raw_val.to_str()) else {
                continue;
            };
            let Some(provider) = key.strip_prefix("HARN_RATE_LIMIT_") else {
                continue;
            };
            // Tolerate stray whitespace such as `HARN_RATE_LIMIT_X=" 60"`.
            let Ok(rpm) = val.trim().parse::<u32>() else {
                continue;
            };
            let provider = provider.to_lowercase();
            if rpm > 0 {
                map.insert(provider, SlidingWindow::new(rpm));
            } else {
                map.remove(&provider);
            }
        }
    });
}
/// Installs (or replaces) a limit of `rpm` requests per minute for `provider`.
///
/// Replacing an existing limit starts a fresh, empty window. An `rpm` of 0 removes
/// the provider's limit instead of storing it. This follows the same convention as a
/// zero `HARN_RATE_LIMIT_<PROVIDER>` override in `init_from_config`, and means a
/// zero-capacity window is never registered.
pub(crate) fn set_rate_limit(provider: &str, rpm: u32) {
    LIMITERS.with(|limiters| {
        let mut map = limiters.borrow_mut();
        if rpm > 0 {
            map.insert(provider.to_string(), SlidingWindow::new(rpm));
        } else {
            // Zero means "unlimited": drop any existing limiter for this provider.
            map.remove(provider);
        }
    });
}
/// Removes any limit registered for `provider` on the current thread.
///
/// Does nothing if the provider has no limiter.
pub(crate) fn clear_rate_limit(provider: &str) {
    LIMITERS.with(|cell| {
        let mut registry = cell.borrow_mut();
        // The removed window, if any, is intentionally discarded.
        let _previous = registry.remove(provider);
    });
}
/// Returns the requests-per-minute cap configured for `provider` on this thread.
///
/// Returns `None` if no limiter is registered.
pub(crate) fn get_rate_limit(provider: &str) -> Option<u32> {
    LIMITERS.with(|cell| {
        let registry = cell.borrow();
        registry.get(provider).map(|window| window.max_requests)
    })
}
/// Waits until `provider` may issue another request, then records that request.
///
/// Returns immediately when no limiter is registered for `provider` on the current
/// thread. Otherwise each attempt checks the provider's window:
/// - If the request is allowed, it is recorded and the function returns.
/// - If not, the throttle is logged via `crate::events::log_debug` and the task
///   sleeps for the reported wait before trying again.
///
/// The `RefCell` borrow lives only inside the synchronous `LIMITERS.with` closure, so
/// it is never held across the `.await`.
///
/// NOTE(review): limiter state is thread-local. If this future can resume on a
/// different worker thread, later attempts consult that thread's limiters. Confirm the
/// runtime keeps these tasks on the thread that configured the limits.
pub(crate) async fn acquire_permit(provider: &str) {
    loop {
        let throttle = LIMITERS.with(|cell| {
            let mut registry = cell.borrow_mut();
            // Unknown provider: no limit applies.
            let window = registry.get_mut(provider)?;
            let wait = window.check();
            if wait.is_none() {
                window.record();
            }
            wait
        });

        let Some(duration) = throttle else {
            return;
        };

        crate::events::log_debug(
            "llm.rate_limit",
            &format!(
                "Rate limit for '{}': throttling for {}ms",
                provider,
                duration.as_millis()
            ),
        );
        tokio::time::sleep(duration).await;
    }
}
/// Removes every limiter registered on the current thread.
pub(crate) fn reset_rate_limit_state() {
    LIMITERS.with(|cell| {
        let mut registry = cell.borrow_mut();
        registry.clear();
    });
}
#[cfg(test)]
mod tests {
    use super::*;

    // A window of N admits N recorded requests, then starts throttling.
    #[test]
    fn test_sliding_window_allows_within_limit() {
        let mut window = SlidingWindow::new(3);
        for _ in 0..3 {
            assert!(window.check().is_none());
            window.record();
        }
        assert!(window.check().is_some());
    }

    // Right after one request fills a one-slot window, the wait is close to a full minute.
    #[test]
    fn test_sliding_window_returns_wait_duration() {
        let mut window = SlidingWindow::new(1);
        window.record();
        let secs = window
            .check()
            .expect("a full window must report a wait")
            .as_secs();
        assert!((58..=60).contains(&secs));
    }

    // Round-trip through the thread-local registry.
    #[test]
    fn test_set_and_get_rate_limit() {
        const PROVIDER: &str = "test_provider";
        reset_rate_limit_state();
        assert_eq!(get_rate_limit(PROVIDER), None);
        set_rate_limit(PROVIDER, 100);
        assert_eq!(get_rate_limit(PROVIDER), Some(100));
        clear_rate_limit(PROVIDER);
        assert_eq!(get_rate_limit(PROVIDER), None);
    }

    // Providers without a limiter are never throttled.
    #[tokio::test]
    async fn test_acquire_permit_no_limit() {
        reset_rate_limit_state();
        acquire_permit("unconfigured_provider").await;
    }

    // Requests well under the cap are admitted without waiting.
    #[tokio::test]
    async fn test_acquire_permit_within_limit() {
        const PROVIDER: &str = "test_prov";
        reset_rate_limit_state();
        set_rate_limit(PROVIDER, 100);
        for _ in 0..2 {
            acquire_permit(PROVIDER).await;
        }
    }
}