use crate::{Context, Model, Provider, ProviderError, ProviderEvent, StreamOptions};
use async_trait::async_trait;
use futures::Stream;
use std::pin::Pin;
use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::sync::Semaphore;
/// Throttling configuration for a [`ProviderPool`].
///
/// A policy combines a request budget (`rpm`) with a cap on how many
/// requests may be admitted at the same time (`max_concurrent`).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RateLimitPolicy {
    /// Maximum number of requests admitted per minute.
    pub rpm: u32,
    /// Maximum number of requests admitted concurrently.
    pub max_concurrent: usize,
}

impl RateLimitPolicy {
    /// Builds a policy from a requests-per-minute budget.
    ///
    /// The concurrency cap is derived as `rpm / 6`, but never less than one
    /// so that a small budget still lets a single request through.
    pub fn rpm(rpm: u32) -> Self {
        Self {
            rpm,
            max_concurrent: (rpm as usize / 6).max(1),
        }
    }

    /// Builds a policy from a requests-per-second budget.
    ///
    /// The per-minute budget is `rps * 60`, saturating at `u32::MAX` instead
    /// of overflowing. The concurrency cap equals `rps`, but never less than
    /// one, matching the floor applied by [`RateLimitPolicy::rpm`].
    pub fn per_second(rps: u32) -> Self {
        Self {
            // A plain `rps * 60` panics in debug builds and silently wraps in
            // release builds for large inputs.
            rpm: rps.saturating_mul(60),
            // A zero-permit cap would leave every caller waiting forever.
            max_concurrent: (rps as usize).max(1),
        }
    }

    /// Builds a policy with the largest representable budget and concurrency.
    pub fn unlimited() -> Self {
        Self {
            rpm: u32::MAX,
            max_concurrent: usize::MAX,
        }
    }

    /// Replaces the concurrency cap, keeping the request budget unchanged.
    pub fn with_max_concurrent(mut self, max: usize) -> Self {
        self.max_concurrent = max;
        self
    }
}
/// Sliding-window request log used to enforce a requests-per-minute budget.
struct RateLimiterState {
    /// Maximum number of recorded requests allowed inside the window.
    rpm: u32,
    /// Instants at which requests were recorded and that have not yet been pruned.
    timestamps: Vec<Instant>,
}

impl RateLimiterState {
    /// Length of the sliding window over which `rpm` is enforced.
    const WINDOW: Duration = Duration::from_secs(60);

    /// Creates an empty log that admits at most `rpm` requests per window.
    fn new(rpm: u32) -> Self {
        Self {
            rpm,
            timestamps: Vec::with_capacity(64),
        }
    }

    /// Drops every recorded instant that is a full window old or older.
    fn prune(&mut self) {
        let now = Instant::now();
        // Measure each entry's age rather than computing `now - WINDOW`:
        // subtracting a `Duration` from an `Instant` panics when the result
        // cannot be represented, whereas this form saturates at zero.
        self.timestamps
            .retain(|&t| now.saturating_duration_since(t) < Self::WINDOW);
    }

    /// Prunes expired entries and reports whether another request fits in the
    /// current window. Does not record anything itself.
    fn can_proceed(&mut self) -> bool {
        self.prune();
        // Widen both sides so an oversized log can never wrap around and
        // appear to be under budget.
        (self.timestamps.len() as u64) < u64::from(self.rpm)
    }

    /// Records a request as happening now.
    fn record(&mut self) {
        self.timestamps.push(Instant::now());
    }
}
/// A [`Provider`] wrapper that throttles calls to an inner provider.
pub struct ProviderPool {
    /// The provider every admitted `stream` call is forwarded to.
    inner: Arc<dyn Provider>,
    /// Limits how many `stream` calls are admitted at the same time.
    semaphore: Arc<Semaphore>,
    /// Sliding-window log enforcing the requests-per-minute budget.
    limiter: Arc<tokio::sync::Mutex<RateLimiterState>>,
    /// Name reported by `Provider::name` for this pool.
    pool_name: String,
}

impl ProviderPool {
    /// Wraps `provider` so that calls through the pool obey `policy`.
    ///
    /// `policy.max_concurrent` is clamped to `1..=Semaphore::MAX_PERMITS`
    /// before the semaphore is created:
    ///
    /// * `Semaphore::new` panics above `MAX_PERMITS`, which is what
    ///   `RateLimitPolicy::unlimited()` (`usize::MAX`) would otherwise trigger;
    /// * a semaphore with zero permits would make every caller wait forever.
    pub fn new(
        provider: Arc<dyn Provider>,
        policy: RateLimitPolicy,
        name: impl Into<String>,
    ) -> Self {
        let permits = policy.max_concurrent.clamp(1, Semaphore::MAX_PERMITS);
        Self {
            inner: provider,
            semaphore: Arc::new(Semaphore::new(permits)),
            limiter: Arc::new(tokio::sync::Mutex::new(RateLimiterState::new(policy.rpm))),
            pool_name: name.into(),
        }
    }
}
#[async_trait]
impl Provider for ProviderPool {
fn name(&self) -> &str {
&self.pool_name
}
async fn stream(
&self,
model: &Model,
context: &Context,
options: Option<StreamOptions>,
) -> Result<Pin<Box<dyn Stream<Item = ProviderEvent> + Send>>, ProviderError> {
let _permit = self
.semaphore
.acquire()
.await
.map_err(|_| ProviderError::RateLimited {
retry_after: Some(Duration::from_secs(5)),
})?;
{
let mut limiter = self.limiter.lock().await;
if !limiter.can_proceed() {
drop(limiter);
tokio::time::sleep(Duration::from_secs(1)).await;
let mut limiter = self.limiter.lock().await;
if !limiter.can_proceed() {
return Err(ProviderError::RateLimited {
retry_after: Some(Duration::from_secs(5)),
});
}
limiter.record();
} else {
limiter.record();
}
}
self.inner.stream(model, context, options).await
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A 60 rpm policy keeps its budget and derives ten concurrent slots.
    #[test]
    fn test_rate_limit_policy_rpm() {
        let RateLimitPolicy {
            rpm,
            max_concurrent,
        } = RateLimitPolicy::rpm(60);
        assert_eq!(rpm, 60);
        assert_eq!(max_concurrent, 10);
    }

    /// The unlimited policy uses the maximum value of both field types.
    #[test]
    fn test_rate_limit_policy_unlimited() {
        let RateLimitPolicy {
            rpm,
            max_concurrent,
        } = RateLimitPolicy::unlimited();
        assert_eq!(rpm, u32::MAX);
        assert_eq!(max_concurrent, usize::MAX);
    }

    /// An explicit concurrency cap overrides the derived one.
    #[test]
    fn test_rate_limit_policy_custom_concurrency() {
        let capped = RateLimitPolicy::rpm(60).with_max_concurrent(3);
        assert_eq!(capped.max_concurrent, 3);
    }

    /// One recorded request out of five leaves room for more.
    #[tokio::test]
    async fn test_rate_limiter_state_allows_within_limit() {
        let mut limiter = RateLimiterState::new(5);
        assert!(limiter.can_proceed());
        limiter.record();
        assert!(limiter.can_proceed());
    }

    /// Once the whole budget is recorded, further requests are refused.
    #[tokio::test]
    async fn test_rate_limiter_state_blocks_at_limit() {
        let mut limiter = RateLimiterState::new(2);
        for _ in 0..2 {
            assert!(limiter.can_proceed());
            limiter.record();
        }
        assert!(!limiter.can_proceed());
    }
}