Skip to main content

byokey_provider/
http_util.rs

1//! Shared HTTP utilities for provider executors.
2//!
3//! Eliminates duplicated send → status-check → stream-or-complete logic
4//! across all executor implementations.
5
6use byokey_types::{
7    ByokError, ProviderId, RateLimitSnapshot, RateLimitStore,
8    traits::{ByteStream, ProviderResponse, Result},
9};
10use futures_util::StreamExt as _;
11use rquest::{Client, RequestBuilder};
12use serde_json::Value;
13use std::collections::HashMap;
14use std::sync::Arc;
15
/// Optional rate-limit capture context attached to a `ProviderHttp`.
///
/// Groups the destination store with the provider/account pair that
/// captured header snapshots are recorded under.
#[derive(Clone)]
struct RateLimitCtx {
    /// Shared store that receives each captured snapshot.
    store: Arc<RateLimitStore>,
    /// Provider passed to the store alongside every snapshot.
    provider: ProviderId,
    /// Account key passed to the store alongside every snapshot;
    /// `ProviderHttp::with_ratelimit` always sets it to `"active"`.
    account_id: String,
}
23
/// Shared HTTP helper that all executors can use to send requests and
/// handle the common response patterns (status check, stream vs complete).
///
/// Rate-limit header capture is off by default and is enabled by
/// `with_ratelimit`.
#[derive(Clone)]
pub struct ProviderHttp {
    /// Underlying HTTP client, exposed to callers through `client()`.
    http: Client,
    /// Rate-limit capture context; `None` means response headers are not recorded.
    rl_ctx: Option<RateLimitCtx>,
}
31
32impl ProviderHttp {
33    /// Creates a new helper wrapping the given HTTP client.
34    #[must_use]
35    pub fn new(http: Client) -> Self {
36        Self { http, rl_ctx: None }
37    }
38
39    /// Attaches rate-limit capture context. Headers from every response
40    /// sent through this helper will be stored in `store`.
41    #[must_use]
42    pub fn with_ratelimit(mut self, store: Arc<RateLimitStore>, provider: ProviderId) -> Self {
43        self.rl_ctx = Some(RateLimitCtx {
44            store,
45            provider,
46            account_id: "active".to_string(),
47        });
48        self
49    }
50
51    /// Returns a reference to the inner HTTP client for building requests.
52    #[must_use]
53    pub fn client(&self) -> &Client {
54        &self.http
55    }
56
57    /// Extracts rate-limit-related headers from the response and writes
58    /// them into the store (if a context is configured).
59    fn capture_ratelimit_headers(&self, headers: &rquest::header::HeaderMap) {
60        let Some(ctx) = &self.rl_ctx else { return };
61
62        let mut captured = HashMap::new();
63        for (name, value) in headers {
64            let key = name.as_str();
65            // Capture anthropic-ratelimit-*, x-ratelimit-*, retry-after
66            if (key.starts_with("anthropic-ratelimit-")
67                || key.starts_with("x-ratelimit-")
68                || key == "retry-after")
69                && let Ok(v) = value.to_str()
70            {
71                captured.insert(key.to_string(), v.to_string());
72            }
73        }
74
75        if captured.is_empty() {
76            return;
77        }
78
79        let now = std::time::SystemTime::now()
80            .duration_since(std::time::UNIX_EPOCH)
81            .unwrap_or_default()
82            .as_secs();
83
84        ctx.store.update(
85            ctx.provider.clone(),
86            ctx.account_id.clone(),
87            RateLimitSnapshot {
88                headers: captured,
89                captured_at: now,
90            },
91        );
92    }
93
94    /// Sends a request and checks for success status.
95    ///
96    /// On non-2xx responses, reads the body text and returns
97    /// [`ByokError::Upstream`]. Rate limit headers are captured from
98    /// **both** success and error responses.
99    ///
100    /// # Errors
101    ///
102    /// Returns `ByokError::Upstream` on non-success HTTP status codes,
103    /// or a transport error if the request fails to send.
104    pub async fn send(&self, builder: RequestBuilder) -> Result<rquest::Response> {
105        let resp = builder.send().await?;
106        // Capture rate limit headers before consuming the body.
107        self.capture_ratelimit_headers(resp.headers());
108        let status = resp.status();
109        if status.is_success() {
110            Ok(resp)
111        } else {
112            let retry_after = parse_retry_after_header(resp.headers());
113            let text = resp.text().await.unwrap_or_default();
114            let retry_after = parse_retry_after_body(&text, status.as_u16()).or(retry_after);
115            Err(ByokError::Upstream {
116                status: status.as_u16(),
117                body: text,
118                retry_after,
119            })
120        }
121    }
122
123    /// Sends a request and returns a `ProviderResponse` for OpenAI-passthrough
124    /// providers (those that don't need response translation).
125    ///
126    /// If `stream` is true, wraps the bytes stream; otherwise parses JSON.
127    ///
128    /// # Errors
129    ///
130    /// Returns `ByokError::Upstream` on non-success status, or a transport/parse error.
131    pub async fn send_passthrough(
132        &self,
133        builder: RequestBuilder,
134        stream: bool,
135    ) -> Result<ProviderResponse> {
136        let resp = self.send(builder).await?;
137        if stream {
138            Ok(ProviderResponse::Stream(Self::byte_stream(resp)))
139        } else {
140            let json: Value = resp.json().await?;
141            Ok(ProviderResponse::Complete(json))
142        }
143    }
144
145    /// Converts an `rquest::Response` into a `ByteStream`.
146    #[must_use]
147    pub fn byte_stream(resp: rquest::Response) -> ByteStream {
148        Box::pin(resp.bytes_stream().map(|r| r.map_err(ByokError::from)))
149    }
150}
151
152/// Parse `Retry-After` header value (seconds integer).
153fn parse_retry_after_header(headers: &rquest::header::HeaderMap) -> Option<std::time::Duration> {
154    let val = headers.get("retry-after")?.to_str().ok()?;
155    let secs: u64 = val.parse().ok()?;
156    Some(std::time::Duration::from_secs(secs))
157}
158
159/// Parse retry delay from a 429 response body.
160///
161/// Supports multiple provider formats:
162/// - **Codex**: `error.type == "usage_limit_reached"` with `error.resets_in_seconds`
163///   or `error.resets_at` (unix timestamp).
164/// - **Google/Antigravity**: `error.details[]` with `retryDelay` (from `RetryInfo`)
165///   or `metadata.quotaResetDelay` (from `ErrorInfo`).
166fn parse_retry_after_body(body: &str, status: u16) -> Option<std::time::Duration> {
167    if status != 429 {
168        return None;
169    }
170    let json: serde_json::Value = serde_json::from_str(body).ok()?;
171
172    // Codex format: error.type == "usage_limit_reached"
173    if let Some(error) = json.get("error")
174        && error.get("type").and_then(serde_json::Value::as_str) == Some("usage_limit_reached")
175    {
176        if let Some(secs) = error
177            .get("resets_in_seconds")
178            .and_then(serde_json::Value::as_u64)
179        {
180            return Some(std::time::Duration::from_secs(secs));
181        }
182        if let Some(ts) = error.get("resets_at").and_then(serde_json::Value::as_u64) {
183            let now = std::time::SystemTime::now()
184                .duration_since(std::time::UNIX_EPOCH)
185                .unwrap_or_default()
186                .as_secs();
187            if ts > now {
188                return Some(std::time::Duration::from_secs(ts - now));
189            }
190        }
191    }
192
193    // Google/Antigravity format: error.details[] with `RetryInfo` or `ErrorInfo`
194    if let Some(details) = json.pointer("/error/details").and_then(Value::as_array) {
195        for detail in details {
196            if detail.get("@type").and_then(Value::as_str)
197                == Some("type.googleapis.com/google.rpc.RetryInfo")
198                && let Some(delay_str) = detail.get("retryDelay").and_then(Value::as_str)
199                && let Some(d) = parse_google_duration(delay_str)
200            {
201                return Some(d);
202            }
203        }
204        for detail in details {
205            if detail.get("@type").and_then(Value::as_str)
206                == Some("type.googleapis.com/google.rpc.ErrorInfo")
207                && let Some(delay_str) = detail
208                    .pointer("/metadata/quotaResetDelay")
209                    .and_then(Value::as_str)
210                && let Some(d) = parse_google_duration(delay_str)
211            {
212                return Some(d);
213            }
214        }
215    }
216
217    None
218}
219
/// Parse a Google-style duration string like `"0.847655010s"` or `"373.801628ms"`.
///
/// The string must be a decimal number followed by `ms` (milliseconds) or
/// `s` (seconds). Returns `None` when the suffix is missing, the number does
/// not parse, or the value cannot be represented as a [`std::time::Duration`]
/// (negative, NaN, infinite, or too large).
///
/// The input comes from upstream response bodies, so unrepresentable values
/// are rejected instead of being passed to `Duration::from_secs_f64`, which
/// panics on them.
fn parse_google_duration(s: &str) -> Option<std::time::Duration> {
    // `ms` is checked first because every `…ms` string also ends in `s`.
    let (number, divisor) = match s.strip_suffix("ms") {
        Some(millis) => (millis, 1000.0),
        None => (s.strip_suffix('s')?, 1.0),
    };
    let value: f64 = number.parse().ok()?;
    std::time::Duration::try_from_secs_f64(value / divisor).ok()
}
232
/// Picks the `Accept` header value matching the requested response mode.
///
/// Streaming requests get `text/event-stream`; everything else gets
/// `application/json`.
#[must_use]
pub fn accept_for_stream(stream: bool) -> &'static str {
    if !stream {
        return "application/json";
    }
    "text/event-stream"
}
244
245/// Injects `stream_options: { include_usage: true }` into the body when streaming.
246///
247/// Used by OpenAI-passthrough providers (Kimi, Qwen, iFlow, Copilot, Gemini).
248pub fn ensure_stream_options(body: &mut serde_json::Value, stream: bool) {
249    if stream {
250        body["stream_options"] = serde_json::json!({ "include_usage": true });
251    }
252}
253
254/// Resolves a bearer token: returns the API key if present, otherwise fetches
255/// an OAuth token from the [`AuthManager`](byokey_auth::AuthManager).
256///
257/// This is the common pattern used by most providers (Kimi, Qwen, iFlow,
258/// Antigravity, Kiro).
259///
260/// # Errors
261///
262/// Returns an error if the OAuth token fetch fails.
263pub async fn resolve_bearer_token(
264    api_key: Option<&str>,
265    auth: &Arc<byokey_auth::AuthManager>,
266    provider: &ProviderId,
267) -> byokey_types::Result<String> {
268    if let Some(key) = api_key {
269        return Ok(key.to_string());
270    }
271    let token = auth.get_token(provider).await?;
272    Ok(token.access_token)
273}
274
/// Builds the HTTP client / `AuthManager` pair used by executor unit tests.
///
/// The returned `(rquest::Client, Arc<AuthManager>)` uses an in-memory token
/// store; the manager is given its own separately constructed client.
#[cfg(test)]
#[must_use]
pub fn test_auth() -> (Client, Arc<byokey_auth::AuthManager>) {
    let token_store = Arc::new(byokey_store::InMemoryTokenStore::new());
    let manager = byokey_auth::AuthManager::new(token_store, Client::new());
    let auth = Arc::new(manager);
    (Client::new(), auth)
}
285
#[cfg(test)]
mod tests {
    use super::*;

    // --- ProviderHttp construction -------------------------------------

    #[test]
    fn test_provider_http_clone() {
        let original = ProviderHttp::new(Client::new());
        let _copy = original.clone();
    }

    #[test]
    fn test_with_ratelimit() {
        let helper = ProviderHttp::new(Client::new())
            .with_ratelimit(Arc::new(RateLimitStore::new()), ProviderId::Claude);
        assert!(helper.rl_ctx.is_some());
    }

    // --- parse_google_duration -----------------------------------------

    #[test]
    fn test_parse_google_duration_seconds() {
        let micros = parse_google_duration("0.847655010s").unwrap().as_micros();
        assert!((847_001..848_000).contains(&micros));
    }

    #[test]
    fn test_parse_google_duration_millis() {
        let micros = parse_google_duration("373.801628ms").unwrap().as_micros();
        assert!((373_001..374_000).contains(&micros));
    }

    #[test]
    fn test_parse_google_duration_whole_seconds() {
        let parsed = parse_google_duration("5s").unwrap();
        assert_eq!(parsed.as_secs(), 5);
    }

    #[test]
    fn test_parse_google_duration_invalid() {
        for input in ["abc", ""] {
            assert!(parse_google_duration(input).is_none());
        }
    }

    // --- parse_retry_after_body ----------------------------------------

    #[test]
    fn test_parse_retry_after_body_codex() {
        let body = r#"{"error":{"type":"usage_limit_reached","resets_in_seconds":300}}"#;
        let secs = parse_retry_after_body(body, 429).map(|d| d.as_secs());
        assert_eq!(secs, Some(300));
    }

    #[test]
    fn test_parse_retry_after_body_google_retry_info() {
        let body = r#"{"error":{"code":429,"details":[{"@type":"type.googleapis.com/google.rpc.RetryInfo","retryDelay":"1.5s"}]}}"#;
        let millis = parse_retry_after_body(body, 429).map(|d| d.as_millis());
        assert_eq!(millis, Some(1500));
    }

    #[test]
    fn test_parse_retry_after_body_google_error_info() {
        let body = r#"{"error":{"code":429,"details":[{"@type":"type.googleapis.com/google.rpc.ErrorInfo","metadata":{"quotaResetDelay":"373.8ms"}}]}}"#;
        let micros = parse_retry_after_body(body, 429).unwrap().as_micros();
        assert!((373_001..374_000).contains(&micros));
    }

    #[test]
    fn test_parse_retry_after_body_non_429() {
        let body = r#"{"error":{"type":"usage_limit_reached","resets_in_seconds":300}}"#;
        assert_eq!(parse_retry_after_body(body, 400), None);
    }
}