1use crate::{Context, Model, Provider, ProviderError, ProviderEvent, StreamOptions};
8use async_trait::async_trait;
9use futures::Stream;
10use std::pin::Pin;
11use std::sync::Arc;
12use std::time::{Duration, Instant};
13use tokio::sync::Semaphore;
14
/// Throughput limits for a provider: a requests-per-minute budget plus a cap
/// on how many requests may be in flight at the same time.
#[derive(Debug, Clone)]
pub struct RateLimitPolicy {
    /// Request budget per minute.
    pub rpm: u32,
    /// Maximum number of concurrently admitted requests.
    pub max_concurrent: usize,
}

impl RateLimitPolicy {
    /// Builds a policy from a requests-per-minute budget.
    ///
    /// The concurrency cap defaults to one sixth of `rpm`, but never less
    /// than 1.
    pub fn rpm(rpm: u32) -> Self {
        Self {
            rpm,
            max_concurrent: (rpm as usize / 6).max(1),
        }
    }

    /// Builds a policy from a requests-per-second budget.
    ///
    /// The per-minute budget is `rps * 60`, saturating at `u32::MAX` instead
    /// of overflowing for very large `rps`. The concurrency cap equals `rps`,
    /// but never less than 1, matching [`RateLimitPolicy::rpm`]: a cap of 0
    /// would leave callers waiting forever for a concurrency slot instead of
    /// receiving a rate-limit error.
    pub fn per_second(rps: u32) -> Self {
        Self {
            rpm: rps.saturating_mul(60),
            max_concurrent: (rps as usize).max(1),
        }
    }

    /// Builds a policy whose budget and concurrency cap are both set to the
    /// maximum representable value.
    pub fn unlimited() -> Self {
        Self {
            rpm: u32::MAX,
            max_concurrent: usize::MAX,
        }
    }

    /// Returns this policy with the concurrency cap replaced by `max`.
    ///
    /// The value is stored as given; no lower bound is applied here.
    pub fn with_max_concurrent(mut self, max: usize) -> Self {
        self.max_concurrent = max;
        self
    }
}
57
/// Sliding-window request counter: remembers when recent requests were
/// admitted and compares their count against a per-minute budget.
struct RateLimiterState {
    /// Maximum number of recorded requests allowed inside the window.
    rpm: u32,
    /// Admission times of recorded requests; stale entries are removed by `prune`.
    timestamps: Vec<Instant>,
}

impl RateLimiterState {
    /// Length of the sliding window the budget applies to.
    const WINDOW: Duration = Duration::from_secs(60);

    /// Creates an empty window with the given per-minute budget.
    fn new(rpm: u32) -> Self {
        Self {
            rpm,
            timestamps: Vec::with_capacity(64),
        }
    }

    /// Drops every timestamp that is no longer inside the window.
    fn prune(&mut self) {
        // `Instant - Duration` panics when the result cannot be represented.
        // If the window start is not representable, no recorded instant can
        // be older than it, so there is nothing to remove.
        if let Some(cutoff) = Instant::now().checked_sub(Self::WINDOW) {
            self.timestamps.retain(|&t| t > cutoff);
        }
    }

    /// Prunes stale entries, then reports whether another request fits in
    /// the budget. Does not record anything itself.
    fn can_proceed(&mut self) -> bool {
        self.prune();
        // A count that does not fit in `u32` is at or above any `u32` budget,
        // so treat it as "full" rather than truncating it.
        u32::try_from(self.timestamps.len()).map_or(false, |used| used < self.rpm)
    }

    /// Records a request admitted at the current instant.
    fn record(&mut self) {
        self.timestamps.push(Instant::now());
    }
}
89
90pub struct ProviderPool {
95 inner: Arc<dyn Provider>,
96 semaphore: Arc<Semaphore>,
97 limiter: Arc<tokio::sync::Mutex<RateLimiterState>>,
98 pool_name: String,
99}
100
101impl ProviderPool {
102 pub fn new(
104 provider: Arc<dyn Provider>,
105 policy: RateLimitPolicy,
106 name: impl Into<String>,
107 ) -> Self {
108 Self {
109 inner: provider,
110 semaphore: Arc::new(Semaphore::new(policy.max_concurrent)),
111 limiter: Arc::new(tokio::sync::Mutex::new(RateLimiterState::new(policy.rpm))),
112 pool_name: name.into(),
113 }
114 }
115}
116
117#[async_trait]
118impl Provider for ProviderPool {
119 fn name(&self) -> &str {
120 &self.pool_name
121 }
122
123 async fn stream(
124 &self,
125 model: &Model,
126 context: &Context,
127 options: Option<StreamOptions>,
128 ) -> Result<Pin<Box<dyn Stream<Item = ProviderEvent> + Send>>, ProviderError> {
129 let _permit = self
131 .semaphore
132 .acquire()
133 .await
134 .map_err(|_| ProviderError::RateLimited {
135 retry_after: Some(Duration::from_secs(5)),
136 })?;
137
138 {
140 let mut limiter = self.limiter.lock().await;
141 if !limiter.can_proceed() {
142 drop(limiter);
144 tokio::time::sleep(Duration::from_secs(1)).await;
145 let mut limiter = self.limiter.lock().await;
146 if !limiter.can_proceed() {
147 return Err(ProviderError::RateLimited {
148 retry_after: Some(Duration::from_secs(5)),
149 });
150 }
151 limiter.record();
152 } else {
153 limiter.record();
154 }
155 }
156
157 self.inner.stream(model, context, options).await
159 }
160}
161
#[cfg(test)]
mod tests {
    use super::*;

    /// `rpm(60)` keeps the budget and derives a concurrency cap of 10.
    #[test]
    fn test_rate_limit_policy_rpm() {
        let RateLimitPolicy {
            rpm,
            max_concurrent,
        } = RateLimitPolicy::rpm(60);
        assert_eq!(rpm, 60);
        assert_eq!(max_concurrent, 10);
    }

    /// `unlimited()` sets both limits to their maximum values.
    #[test]
    fn test_rate_limit_policy_unlimited() {
        let RateLimitPolicy {
            rpm,
            max_concurrent,
        } = RateLimitPolicy::unlimited();
        assert_eq!(rpm, u32::MAX);
        assert_eq!(max_concurrent, usize::MAX);
    }

    /// `with_max_concurrent` overrides the derived concurrency cap.
    #[test]
    fn test_rate_limit_policy_custom_concurrency() {
        let overridden = RateLimitPolicy::rpm(60).with_max_concurrent(3);
        assert_eq!(overridden.max_concurrent, 3);
    }

    /// One recorded request out of a budget of five still leaves room.
    #[tokio::test]
    async fn test_rate_limiter_state_allows_within_limit() {
        let mut limiter = RateLimiterState::new(5);
        assert!(limiter.can_proceed());
        limiter.record();
        assert!(limiter.can_proceed());
    }

    /// Once the budget of two is used up, further requests are refused.
    #[tokio::test]
    async fn test_rate_limiter_state_blocks_at_limit() {
        let mut limiter = RateLimiterState::new(2);
        for _ in 0..2 {
            assert!(limiter.can_proceed());
            limiter.record();
        }
        assert!(!limiter.can_proceed());
    }
}