use crate::error::Result;
use crate::recursive::llm::{Llm, LmOutput};
use std::future::Future;
use std::pin::Pin;
use std::sync::Mutex;
use std::task::{Context, Poll};
use std::time::{Duration, Instant};
/// Configuration for a token-bucket rate limiter.
#[derive(Debug, Clone)]
pub struct RateLimitConfig {
    /// Sustained request rate (tokens refilled per second).
    /// Must be finite and > 0, otherwise the bucket can never refill.
    pub requests_per_second: f64,
    /// Bucket capacity: how many requests may be issued back-to-back
    /// before pacing kicks in. Must be >= 1.
    pub burst: u32,
}
impl RateLimitConfig {
    /// Create a config allowing `requests_per_second` sustained throughput
    /// with no bursting (capacity 1).
    ///
    /// Fails fast (debug builds) on a rate that would later make the
    /// limiter unable to refill tokens.
    pub fn new(requests_per_second: f64) -> Self {
        debug_assert!(
            requests_per_second.is_finite() && requests_per_second > 0.0,
            "requests_per_second must be finite and positive"
        );
        Self {
            requests_per_second,
            burst: 1,
        }
    }
    /// Allow up to `burst` requests back-to-back.
    ///
    /// A burst of 0 would leave the bucket permanently empty (capacity 0
    /// means a token count of 1.0 is never reachable), so it is rejected
    /// in debug builds.
    pub fn with_burst(mut self, burst: u32) -> Self {
        debug_assert!(burst >= 1, "burst must be at least 1");
        self.burst = burst;
        self
    }
}
impl Default for RateLimitConfig {
    /// 10 requests per second, no bursting.
    fn default() -> Self {
        Self {
            requests_per_second: 10.0,
            burst: 1,
        }
    }
}
/// Mutable token-bucket state; guarded by the limiter's `Mutex`.
#[derive(Debug)]
struct TokenBucketState {
    // Currently available (fractional) tokens.
    tokens: f64,
    // Bucket capacity — the configured burst size.
    max_tokens: f64,
    // Tokens added per second — the configured requests_per_second.
    refill_rate: f64,
    // Timestamp of the last refill; elapsed time since then is converted
    // into new tokens on each `try_acquire`.
    last_refill: Instant,
}
impl TokenBucketState {
    /// Build a full bucket from `config`, so the first `burst`
    /// acquisitions succeed immediately.
    fn new(config: &RateLimitConfig) -> Self {
        Self {
            tokens: config.burst as f64,
            max_tokens: config.burst as f64,
            refill_rate: config.requests_per_second,
            last_refill: Instant::now(),
        }
    }
    /// Refill the bucket from elapsed wall time, then try to take one token.
    ///
    /// Returns `None` on success, or `Some(wait)` — the estimated time
    /// until a full token becomes available — when the bucket is empty.
    fn try_acquire(&mut self) -> Option<Duration> {
        let now = Instant::now();
        let elapsed = now.duration_since(self.last_refill);
        self.tokens += elapsed.as_secs_f64() * self.refill_rate;
        if self.tokens > self.max_tokens {
            self.tokens = self.max_tokens;
        }
        self.last_refill = now;
        if self.tokens >= 1.0 {
            self.tokens -= 1.0;
            None
        } else {
            let deficit = 1.0 - self.tokens;
            let wait_secs = deficit / self.refill_rate;
            // The config fields are `pub`, so `refill_rate` may be zero,
            // negative, or NaN, making `wait_secs` non-finite or negative;
            // `Duration::from_secs_f64` panics on such values (and on
            // overflow). Fall back to an effectively infinite wait instead
            // of crashing the caller.
            Some(Duration::try_from_secs_f64(wait_secs).unwrap_or(Duration::MAX))
        }
    }
}
/// Decorator that throttles an inner [`Llm`]'s `generate` calls with a
/// token bucket. Composes with the other wrappers (cache, retry) in this
/// module tree.
pub struct RateLimitedLlm<L: Llm> {
    // The wrapped LLM that actually performs generation.
    inner: L,
    // Token-bucket state shared by all in-flight `generate` futures.
    state: Mutex<TokenBucketState>,
}
impl<L: Llm> RateLimitedLlm<L> {
    /// Wrap `inner` with a token-bucket limiter configured by `config`.
    pub fn new(inner: L, config: RateLimitConfig) -> Self {
        let state = Mutex::new(TokenBucketState::new(&config));
        Self { inner, state }
    }

    /// Borrow the wrapped LLM.
    pub fn inner(&self) -> &L {
        &self.inner
    }
}
/// Internal state machine for [`RateLimitFut`].
enum RateLimitState<'a, L: Llm + 'a> {
    /// Waiting until the bucket is expected to have a token again.
    /// Native builds sleep on the tokio timer.
    #[cfg(feature = "native")]
    WaitingForSlot(Pin<Box<tokio::time::Sleep>>),
    /// Non-native builds have no timer: they record (start, duration) and
    /// check the deadline manually on every poll.
    #[cfg(not(feature = "native"))]
    WaitingForSlot(Instant, Duration),
    /// A token was acquired; now driving the inner LLM's generate future.
    Generating(Pin<Box<L::GenerateFut<'a>>>),
}
/// Future returned by [`RateLimitedLlm`]'s `generate`.
///
/// Carries everything needed to (re)issue the inner `generate` call once a
/// rate-limit token is finally acquired.
pub struct RateLimitFut<'a, L: Llm + 'a> {
    // The wrapped LLM, used to start generation after any wait.
    llm: &'a L,
    // Arguments held until the inner `generate` is actually invoked.
    prompt: &'a str,
    context: &'a str,
    feedback: Option<&'a str>,
    // Shared token bucket; re-checked each time a wait elapses.
    bucket: &'a Mutex<TokenBucketState>,
    // Current phase: waiting for a slot, or driving the inner future.
    state: RateLimitState<'a, L>,
}
// Sound: every field is either a shared reference or a `Pin<Box<_>>`
// (which is `Unpin` regardless of its contents), so the wrapper needs no
// structural pinning. This lets `poll` use `self.get_mut()`.
impl<'a, L: Llm + 'a> Unpin for RateLimitFut<'a, L> {}
impl<'a, L: Llm + 'a> Future for RateLimitFut<'a, L> {
    type Output = Result<LmOutput>;
    /// Drive the wait-then-generate state machine.
    ///
    /// Loops so that every state transition (slot acquired, or a fresh
    /// wait scheduled) is immediately followed by a poll of the new state,
    /// keeping the waker registered with whatever we are now waiting on.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        // `get_mut` is fine: `RateLimitFut` is `Unpin` (see impl above).
        let this = self.get_mut();
        loop {
            match &mut this.state {
                #[cfg(feature = "native")]
                RateLimitState::WaitingForSlot(sleep) => match sleep.as_mut().poll(cx) {
                    Poll::Ready(()) => match this
                        .bucket
                        .lock()
                        .expect("rate limiter lock poisoned")
                        .try_acquire()
                    {
                        // Token acquired: switch to generating; the loop
                        // polls the new future on the next iteration.
                        None => {
                            let fut = this.llm.generate(this.prompt, this.context, this.feedback);
                            this.state = RateLimitState::Generating(Box::pin(fut));
                        }
                        // Another caller beat us to the token: sleep again
                        // for the newly suggested wait.
                        Some(wait) => {
                            this.state =
                                RateLimitState::WaitingForSlot(Box::pin(tokio::time::sleep(wait)));
                        }
                    },
                    Poll::Pending => return Poll::Pending,
                },
                #[cfg(not(feature = "native"))]
                RateLimitState::WaitingForSlot(start, duration) => {
                    if start.elapsed() >= *duration {
                        match this
                            .bucket
                            .lock()
                            .expect("rate limiter lock poisoned")
                            .try_acquire()
                        {
                            // Token acquired: move on to generating.
                            None => {
                                let fut =
                                    this.llm.generate(this.prompt, this.context, this.feedback);
                                this.state = RateLimitState::Generating(Box::pin(fut));
                            }
                            // Lost the race for the token: restart the
                            // deadline with the new wait and yield.
                            Some(wait) => {
                                *start = Instant::now();
                                *duration = wait;
                                cx.waker().wake_by_ref();
                                return Poll::Pending;
                            }
                        }
                    } else {
                        // No timer without the `native` feature: busy-poll
                        // by waking ourselves immediately.
                        // NOTE(review): this spins the executor until the
                        // deadline passes — presumably acceptable on the
                        // intended single-threaded hosts, but hot.
                        cx.waker().wake_by_ref();
                        return Poll::Pending;
                    }
                }
                // Slot already held: forward directly to the inner future.
                RateLimitState::Generating(fut) => {
                    return fut.as_mut().poll(cx);
                }
            }
        }
    }
}
impl<L: Llm> Llm for RateLimitedLlm<L> {
    type GenerateFut<'a>
        = RateLimitFut<'a, L>
    where
        Self: 'a;
    /// Try to take a rate-limit token up front; return a future that either
    /// generates immediately (fast path) or waits out the suggested delay
    /// before retrying inside `poll`.
    fn generate<'a>(
        &'a self,
        prompt: &'a str,
        context: &'a str,
        feedback: Option<&'a str>,
    ) -> Self::GenerateFut<'a> {
        let acquired = self
            .state
            .lock()
            .expect("rate limiter lock poisoned")
            .try_acquire();
        let state = match acquired {
            // Token available: start the inner generate right away.
            None => RateLimitState::Generating(Box::pin(self.inner.generate(
                prompt,
                context,
                feedback,
            ))),
            // Bucket empty: schedule the wait (tokio timer on native builds,
            // manual deadline otherwise).
            Some(wait) => {
                #[cfg(feature = "native")]
                let waiting = RateLimitState::WaitingForSlot(Box::pin(tokio::time::sleep(wait)));
                #[cfg(not(feature = "native"))]
                let waiting = RateLimitState::WaitingForSlot(Instant::now(), wait);
                waiting
            }
        };
        RateLimitFut {
            llm: &self.inner,
            prompt,
            context,
            feedback,
            bucket: &self.state,
            state,
        }
    }
    /// Delegates to the wrapped LLM.
    fn model_name(&self) -> &str {
        self.inner.model_name()
    }
    /// Delegates to the wrapped LLM.
    fn max_context(&self) -> usize {
        self.inner.max_context()
    }
}
/// Extension methods for wrapping any [`Llm`] in a [`RateLimitedLlm`].
pub trait RateLimitExt: Llm + Sized {
    /// Throttle to `rps` requests per second with a burst of 1.
    fn with_rate_limit(self, rps: f64) -> RateLimitedLlm<Self> {
        self.with_rate_limit_config(RateLimitConfig::new(rps))
    }
    /// Throttle using a fully specified `config`.
    fn with_rate_limit_config(self, config: RateLimitConfig) -> RateLimitedLlm<Self> {
        RateLimitedLlm::new(self, config)
    }
}
impl<L: Llm> RateLimitExt for L {}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::recursive::llm::MockLlm;

    // Default config is 10 rps with no bursting.
    #[test]
    fn test_config_defaults() {
        let config = RateLimitConfig::default();
        assert!((config.requests_per_second - 10.0).abs() < f64::EPSILON);
        assert_eq!(config.burst, 1);
    }

    // Builder methods set the rate and the burst independently.
    #[test]
    fn test_config_builder() {
        let config = RateLimitConfig::new(5.0).with_burst(10);
        assert!((config.requests_per_second - 5.0).abs() < f64::EPSILON);
        assert_eq!(config.burst, 10);
    }

    // A fresh bucket starts full: exactly `burst` immediate acquisitions,
    // then the next one is asked to wait.
    #[test]
    fn test_token_bucket_immediate_acquire() {
        let mut state = TokenBucketState::new(&RateLimitConfig::new(10.0).with_burst(5));
        for _ in 0..5 {
            assert!(state.try_acquire().is_none());
        }
        assert!(state.try_acquire().is_some());
    }

    // At 1000 rps the suggested wait after draining the single token must
    // be under a millisecond (asserted with a 5 ms safety margin).
    #[test]
    fn test_token_bucket_refill() {
        let mut state = TokenBucketState::new(&RateLimitConfig::new(1000.0).with_burst(1));
        assert!(state.try_acquire().is_none());
        let wait = state.try_acquire();
        assert!(wait.is_some());
        let w = wait.unwrap();
        assert!(w < Duration::from_millis(5));
    }

    // The wrapper forwards model_name to the inner LLM.
    #[test]
    fn test_model_name_preserved() {
        let llm = MockLlm::new(|_, _| "ok".to_string());
        let limited = llm.with_rate_limit(10.0);
        assert_eq!(limited.model_name(), "mock");
    }

    // inner() exposes the wrapped LLM directly.
    #[test]
    fn test_inner_accessible() {
        let llm = MockLlm::new(|_, _| "ok".to_string());
        let limited = llm.with_rate_limit(10.0);
        assert_eq!(limited.inner().model_name(), "mock");
    }

    // A burst of 3 lets three calls through without measurable pacing.
    #[tokio::test]
    async fn test_rate_limit_allows_burst() {
        let llm = MockLlm::new(|_, _| "ok".to_string())
            .with_rate_limit_config(RateLimitConfig::new(10.0).with_burst(3));
        let start = Instant::now();
        for _ in 0..3 {
            llm.generate("test", "", None).await.unwrap();
        }
        assert!(start.elapsed() < Duration::from_millis(50));
    }

    // After the burst is spent, the second call at 20 rps must be delayed
    // roughly 50 ms (asserted with a 30 ms lower bound to avoid flakiness).
    #[tokio::test]
    async fn test_rate_limit_paces_after_burst() {
        let llm = MockLlm::new(|_, _| "ok".to_string())
            .with_rate_limit_config(RateLimitConfig::new(20.0).with_burst(1));
        llm.generate("test", "", None).await.unwrap();
        let start = Instant::now();
        llm.generate("test", "", None).await.unwrap();
        assert!(start.elapsed() >= Duration::from_millis(30));
    }

    // Rate limiting composes under the retry decorator.
    #[tokio::test]
    async fn test_rate_limit_composable_with_retry() {
        use crate::recursive::retry::LlmExt;
        let llm = MockLlm::new(|_, _| "ok".to_string())
            .with_rate_limit(10.0)
            .with_retry(3);
        let result = llm.generate("test", "", None).await;
        assert!(result.is_ok());
    }

    // Rate limiting composes over the cache decorator.
    #[tokio::test]
    async fn test_rate_limit_composable_with_cache() {
        use crate::recursive::cache::CacheExt;
        let llm = MockLlm::new(|_, _| "ok".to_string())
            .with_cache(10)
            .with_rate_limit(10.0);
        let result = llm.generate("test", "", None).await;
        assert!(result.is_ok());
    }

    // Full decorator chain: cache -> rate limit -> retry still produces the
    // mock's output unchanged.
    #[tokio::test]
    async fn test_rate_limit_full_chain() {
        use crate::recursive::cache::CacheExt;
        use crate::recursive::retry::LlmExt;
        let llm = MockLlm::new(|_, _| "ok".to_string())
            .with_cache(100)
            .with_rate_limit(10.0)
            .with_retry(3);
        let result = llm.generate("test", "", None).await;
        assert!(result.is_ok());
        assert_eq!(&*result.unwrap().text, "ok");
    }

    // Exercises the non-tokio wait path via the crate's own block_on
    // (the busy-poll branch of RateLimitFut when `native` is off).
    #[test]
    fn test_rate_limit_wait_via_block_on() {
        let llm = MockLlm::new(|_, _| "ok".to_string())
            .with_rate_limit_config(RateLimitConfig::new(10.0).with_burst(1));
        crate::recursive::shared::block_on(async {
            let r1 = llm.generate("test", "", None).await;
            assert!(r1.is_ok());
            let r2 = llm.generate("test", "", None).await;
            assert!(r2.is_ok());
        });
    }
}