// oxi_ai/provider_pool.rs

//! Provider pool — rate limiting and concurrency control for provider calls.
//!
//! Wraps any `Provider` with a semaphore (maximum concurrent requests) and
//! a simple sliding-window rate limiter (requests per minute). Used when
//! multiple agents share a single API key and need coordinated access.
6
7use crate::{Context, Model, Provider, ProviderError, ProviderEvent, StreamOptions};
8use async_trait::async_trait;
9use futures::Stream;
10use std::pin::Pin;
11use std::sync::Arc;
12use std::time::{Duration, Instant};
13use tokio::sync::Semaphore;
14
/// Rate-limiting and concurrency policy for a provider pool.
#[derive(Debug, Clone)]
pub struct RateLimitPolicy {
    /// Maximum requests per minute.
    pub rpm: u32,
    /// Maximum number of concurrent in-flight requests.
    pub max_concurrent: usize,
}

impl RateLimitPolicy {
    /// Create a policy with the given RPM.
    ///
    /// Sets `max_concurrent` to `rpm / 6` (the share of the per-minute budget
    /// that fits in a 10-second window), but never less than 1.
    pub fn rpm(rpm: u32) -> Self {
        Self {
            rpm,
            // `u32 -> usize` is a widening conversion on 32- and 64-bit targets.
            max_concurrent: (rpm as usize / 6).max(1),
        }
    }

    /// Create a policy with the given requests-per-second.
    ///
    /// The per-minute budget is `rps * 60`, saturating at `u32::MAX` instead
    /// of overflowing. `max_concurrent` is `rps`, but never less than 1 so the
    /// result agrees with [`RateLimitPolicy::rpm`].
    pub fn per_second(rps: u32) -> Self {
        Self {
            // A plain `rps * 60` panics in debug builds and wraps in release
            // builds once `rps` exceeds `u32::MAX / 60`.
            rpm: rps.saturating_mul(60),
            max_concurrent: (rps as usize).max(1),
        }
    }

    /// Create an unrestricted policy (no rate limiting).
    ///
    /// Both limits are set to the maximum value of their integer type.
    pub fn unlimited() -> Self {
        Self {
            rpm: u32::MAX,
            max_concurrent: usize::MAX,
        }
    }

    /// Set a custom max-concurrency override.
    ///
    /// Consumes the policy and returns it with `max_concurrent` replaced by
    /// `max`; `rpm` is left untouched.
    pub fn with_max_concurrent(mut self, max: usize) -> Self {
        self.max_concurrent = max;
        self
    }
}
57
/// Internal rate-limiter state (sliding window).
struct RateLimiterState {
    /// Maximum number of requests allowed inside one window.
    rpm: u32,
    /// Times at which requests were recorded; stale entries are dropped by
    /// `prune`.
    timestamps: Vec<Instant>,
}

impl RateLimiterState {
    /// Length of the sliding window the `rpm` budget applies to.
    const WINDOW: Duration = Duration::from_secs(60);

    /// Create an empty limiter that allows `rpm` requests per window.
    fn new(rpm: u32) -> Self {
        Self {
            rpm,
            timestamps: Vec::with_capacity(64),
        }
    }

    /// Remove timestamps older than 60 seconds.
    fn prune(&mut self) {
        let now = Instant::now();
        // Compare ages instead of computing `now - WINDOW`: subtracting a
        // `Duration` from an `Instant` panics when the result is not
        // representable, which can happen when the monotonic clock is less
        // than one window past its origin.
        self.timestamps
            .retain(|&stamp| now.saturating_duration_since(stamp) < Self::WINDOW);
    }

    /// Check if a new request can be made within the RPM limit.
    ///
    /// Prunes stale entries first, so the answer reflects the current window.
    fn can_proceed(&mut self) -> bool {
        self.prune();
        // A count that does not even fit in `u32` is necessarily over budget.
        u32::try_from(self.timestamps.len()).map_or(false, |used| used < self.rpm)
    }

    /// Record a request timestamp.
    fn record(&mut self) {
        self.timestamps.push(Instant::now());
    }
}
89
90/// A `Provider` wrapper that enforces rate limits and concurrency.
91///
92/// Multiple agents sharing an API key should go through a single
93/// `ProviderPool` instance to avoid exceeding rate limits.
94pub struct ProviderPool {
95    inner: Arc<dyn Provider>,
96    semaphore: Arc<Semaphore>,
97    limiter: Arc<tokio::sync::Mutex<RateLimiterState>>,
98    pool_name: String,
99}
100
101impl ProviderPool {
102    /// Create a new pool wrapping the given provider with the specified policy.
103    pub fn new(
104        provider: Arc<dyn Provider>,
105        policy: RateLimitPolicy,
106        name: impl Into<String>,
107    ) -> Self {
108        Self {
109            inner: provider,
110            semaphore: Arc::new(Semaphore::new(policy.max_concurrent)),
111            limiter: Arc::new(tokio::sync::Mutex::new(RateLimiterState::new(policy.rpm))),
112            pool_name: name.into(),
113        }
114    }
115}
116
117#[async_trait]
118impl Provider for ProviderPool {
119    fn name(&self) -> &str {
120        &self.pool_name
121    }
122
123    async fn stream(
124        &self,
125        model: &Model,
126        context: &Context,
127        options: Option<StreamOptions>,
128    ) -> Result<Pin<Box<dyn Stream<Item = ProviderEvent> + Send>>, ProviderError> {
129        // 1. Acquire concurrency permit
130        let _permit = self
131            .semaphore
132            .acquire()
133            .await
134            .map_err(|_| ProviderError::RateLimited {
135                retry_after: Some(Duration::from_secs(5)),
136            })?;
137
138        // 2. Rate-limit check with simple backoff
139        {
140            let mut limiter = self.limiter.lock().await;
141            if !limiter.can_proceed() {
142                // Simple: wait 1 second and retry once
143                drop(limiter);
144                tokio::time::sleep(Duration::from_secs(1)).await;
145                let mut limiter = self.limiter.lock().await;
146                if !limiter.can_proceed() {
147                    return Err(ProviderError::RateLimited {
148                        retry_after: Some(Duration::from_secs(5)),
149                    });
150                }
151                limiter.record();
152            } else {
153                limiter.record();
154            }
155        }
156
157        // 3. Delegate to inner provider
158        self.inner.stream(model, context, options).await
159    }
160}
161
#[cfg(test)]
mod tests {
    use super::*;

    /// `rpm(n)` stores `n` and derives the concurrency cap as `n / 6`.
    #[test]
    fn test_rate_limit_policy_rpm() {
        let RateLimitPolicy {
            rpm,
            max_concurrent,
        } = RateLimitPolicy::rpm(60);
        assert_eq!(rpm, 60);
        assert_eq!(max_concurrent, 10); // 60/6
    }

    /// The unlimited policy uses the maximum value of each field's type.
    #[test]
    fn test_rate_limit_policy_unlimited() {
        let RateLimitPolicy {
            rpm,
            max_concurrent,
        } = RateLimitPolicy::unlimited();
        assert_eq!(rpm, u32::MAX);
        assert_eq!(max_concurrent, usize::MAX);
    }

    /// An explicit override replaces the derived concurrency cap.
    #[test]
    fn test_rate_limit_policy_custom_concurrency() {
        let overridden = RateLimitPolicy::rpm(60).with_max_concurrent(3);
        assert_eq!(overridden.max_concurrent, 3);
    }

    /// One recorded request out of a budget of five leaves room for more.
    #[tokio::test]
    async fn test_rate_limiter_state_allows_within_limit() {
        let mut limiter = RateLimiterState::new(5);
        assert!(limiter.can_proceed());
        limiter.record();
        assert!(limiter.can_proceed());
    }

    /// Once the whole budget is recorded, further requests are refused.
    #[tokio::test]
    async fn test_rate_limiter_state_blocks_at_limit() {
        let mut limiter = RateLimiterState::new(2);
        for _ in 0..2 {
            assert!(limiter.can_proceed());
            limiter.record();
        }
        assert!(!limiter.can_proceed()); // at limit
    }
}