byokey_provider/
http_util.rs1use byokey_types::{
7 ByokError, ProviderId, RateLimitSnapshot, RateLimitStore,
8 traits::{ByteStream, ProviderResponse, Result},
9};
10use futures_util::StreamExt as _;
11use rquest::{Client, RequestBuilder};
12use serde_json::Value;
13use std::collections::HashMap;
14use std::sync::Arc;
15
16#[derive(Clone)]
18struct RateLimitCtx {
19 store: Arc<RateLimitStore>,
20 provider: ProviderId,
21 account_id: String,
22}
23
24#[derive(Clone)]
27pub struct ProviderHttp {
28 http: Client,
29 rl_ctx: Option<RateLimitCtx>,
30}
31
32impl ProviderHttp {
33 #[must_use]
35 pub fn new(http: Client) -> Self {
36 Self { http, rl_ctx: None }
37 }
38
39 #[must_use]
42 pub fn with_ratelimit(mut self, store: Arc<RateLimitStore>, provider: ProviderId) -> Self {
43 self.rl_ctx = Some(RateLimitCtx {
44 store,
45 provider,
46 account_id: "active".to_string(),
47 });
48 self
49 }
50
51 #[must_use]
53 pub fn client(&self) -> &Client {
54 &self.http
55 }
56
57 fn capture_ratelimit_headers(&self, headers: &rquest::header::HeaderMap) {
60 let Some(ctx) = &self.rl_ctx else { return };
61
62 let mut captured = HashMap::new();
63 for (name, value) in headers {
64 let key = name.as_str();
65 if (key.starts_with("anthropic-ratelimit-")
67 || key.starts_with("x-ratelimit-")
68 || key == "retry-after")
69 && let Ok(v) = value.to_str()
70 {
71 captured.insert(key.to_string(), v.to_string());
72 }
73 }
74
75 if captured.is_empty() {
76 return;
77 }
78
79 let now = std::time::SystemTime::now()
80 .duration_since(std::time::UNIX_EPOCH)
81 .unwrap_or_default()
82 .as_secs();
83
84 ctx.store.update(
85 ctx.provider.clone(),
86 ctx.account_id.clone(),
87 RateLimitSnapshot {
88 headers: captured,
89 captured_at: now,
90 },
91 );
92 }
93
94 pub async fn send(&self, builder: RequestBuilder) -> Result<rquest::Response> {
105 let resp = builder.send().await?;
106 self.capture_ratelimit_headers(resp.headers());
108 let status = resp.status();
109 if status.is_success() {
110 Ok(resp)
111 } else {
112 let retry_after = parse_retry_after_header(resp.headers());
113 let text = resp.text().await.unwrap_or_default();
114 let retry_after = parse_retry_after_body(&text, status.as_u16()).or(retry_after);
115 Err(ByokError::Upstream {
116 status: status.as_u16(),
117 body: text,
118 retry_after,
119 })
120 }
121 }
122
123 pub async fn send_passthrough(
132 &self,
133 builder: RequestBuilder,
134 stream: bool,
135 ) -> Result<ProviderResponse> {
136 let resp = self.send(builder).await?;
137 if stream {
138 Ok(ProviderResponse::Stream(Self::byte_stream(resp)))
139 } else {
140 let json: Value = resp.json().await?;
141 Ok(ProviderResponse::Complete(json))
142 }
143 }
144
145 #[must_use]
147 pub fn byte_stream(resp: rquest::Response) -> ByteStream {
148 Box::pin(resp.bytes_stream().map(|r| r.map_err(ByokError::from)))
149 }
150}
151
152fn parse_retry_after_header(headers: &rquest::header::HeaderMap) -> Option<std::time::Duration> {
154 let val = headers.get("retry-after")?.to_str().ok()?;
155 let secs: u64 = val.parse().ok()?;
156 Some(std::time::Duration::from_secs(secs))
157}
158
159fn parse_retry_after_body(body: &str, status: u16) -> Option<std::time::Duration> {
167 if status != 429 {
168 return None;
169 }
170 let json: serde_json::Value = serde_json::from_str(body).ok()?;
171
172 if let Some(error) = json.get("error")
174 && error.get("type").and_then(serde_json::Value::as_str) == Some("usage_limit_reached")
175 {
176 if let Some(secs) = error
177 .get("resets_in_seconds")
178 .and_then(serde_json::Value::as_u64)
179 {
180 return Some(std::time::Duration::from_secs(secs));
181 }
182 if let Some(ts) = error.get("resets_at").and_then(serde_json::Value::as_u64) {
183 let now = std::time::SystemTime::now()
184 .duration_since(std::time::UNIX_EPOCH)
185 .unwrap_or_default()
186 .as_secs();
187 if ts > now {
188 return Some(std::time::Duration::from_secs(ts - now));
189 }
190 }
191 }
192
193 if let Some(details) = json.pointer("/error/details").and_then(Value::as_array) {
195 for detail in details {
196 if detail.get("@type").and_then(Value::as_str)
197 == Some("type.googleapis.com/google.rpc.RetryInfo")
198 && let Some(delay_str) = detail.get("retryDelay").and_then(Value::as_str)
199 && let Some(d) = parse_google_duration(delay_str)
200 {
201 return Some(d);
202 }
203 }
204 for detail in details {
205 if detail.get("@type").and_then(Value::as_str)
206 == Some("type.googleapis.com/google.rpc.ErrorInfo")
207 && let Some(delay_str) = detail
208 .pointer("/metadata/quotaResetDelay")
209 .and_then(Value::as_str)
210 && let Some(d) = parse_google_duration(delay_str)
211 {
212 return Some(d);
213 }
214 }
215 }
216
217 None
218}
219
/// Parses a duration string with an `ms` or `s` suffix, e.g.
/// `"0.847655010s"` or `"373.801628ms"`.
///
/// The numeric part is parsed as an `f64`. Returns `None` when the suffix is
/// missing, the number does not parse, or the value cannot be represented as
/// a [`std::time::Duration`] (negative, NaN, infinite or too large). The
/// input comes from upstream response bodies, so such values must be
/// rejected rather than allowed to panic.
fn parse_google_duration(s: &str) -> Option<std::time::Duration> {
    // `ms` has to be checked first: every `…ms` string also ends in `s`.
    let (number, units_per_second) = if let Some(millis) = s.strip_suffix("ms") {
        (millis, 1000.0)
    } else {
        (s.strip_suffix('s')?, 1.0)
    };

    let value: f64 = number.parse().ok()?;

    // `Duration::from_secs_f64` panics on negative, NaN and out-of-range
    // input; the fallible constructor reports those as an error instead.
    std::time::Duration::try_from_secs_f64(value / units_per_second).ok()
}
232
/// Returns the `Accept` media type matching the requested response mode:
/// `text/event-stream` when streaming, `application/json` otherwise.
#[must_use]
pub fn accept_for_stream(stream: bool) -> &'static str {
    const EVENT_STREAM: &str = "text/event-stream";
    const JSON: &str = "application/json";

    if !stream {
        return JSON;
    }
    EVENT_STREAM
}
244
245pub fn ensure_stream_options(body: &mut serde_json::Value, stream: bool) {
249 if stream {
250 body["stream_options"] = serde_json::json!({ "include_usage": true });
251 }
252}
253
254pub async fn resolve_bearer_token(
264 api_key: Option<&str>,
265 auth: &Arc<byokey_auth::AuthManager>,
266 provider: &ProviderId,
267) -> byokey_types::Result<String> {
268 if let Some(key) = api_key {
269 return Ok(key.to_string());
270 }
271 let token = auth.get_token(provider).await?;
272 Ok(token.access_token)
273}
274
/// Test-only helper: builds a fresh HTTP client together with an
/// `AuthManager` backed by an in-memory token store.
#[cfg(test)]
#[must_use]
pub fn test_auth() -> (Client, Arc<byokey_auth::AuthManager>) {
    let token_store = Arc::new(byokey_store::InMemoryTokenStore::new());
    let manager = byokey_auth::AuthManager::new(token_store, Client::new());
    let client = Client::new();
    (client, Arc::new(manager))
}
285
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that `d` lies strictly between `lo` and `hi` microseconds.
    fn assert_micros_between(d: std::time::Duration, lo: u128, hi: u128) {
        let micros = d.as_micros();
        assert!(micros > lo && micros < hi);
    }

    // --- ProviderHttp construction -------------------------------------

    #[test]
    fn test_provider_http_clone() {
        let original = ProviderHttp::new(Client::new());
        let _copy = original.clone();
    }

    #[test]
    fn test_with_ratelimit() {
        let store = Arc::new(RateLimitStore::new());
        let base = ProviderHttp::new(Client::new());
        let http = base.with_ratelimit(store, ProviderId::Claude);
        assert!(http.rl_ctx.is_some());
    }

    // --- parse_google_duration ------------------------------------------

    #[test]
    fn test_parse_google_duration_seconds() {
        let parsed = parse_google_duration("0.847655010s").unwrap();
        assert_micros_between(parsed, 847_000, 848_000);
    }

    #[test]
    fn test_parse_google_duration_millis() {
        let parsed = parse_google_duration("373.801628ms").unwrap();
        assert_micros_between(parsed, 373_000, 374_000);
    }

    #[test]
    fn test_parse_google_duration_whole_seconds() {
        let parsed = parse_google_duration("5s").unwrap();
        assert_eq!(parsed.as_secs(), 5);
    }

    #[test]
    fn test_parse_google_duration_invalid() {
        for input in ["abc", ""] {
            assert!(parse_google_duration(input).is_none());
        }
    }

    // --- parse_retry_after_body -----------------------------------------

    #[test]
    fn test_parse_retry_after_body_codex() {
        let body = r#"{"error":{"type":"usage_limit_reached","resets_in_seconds":300}}"#;
        let delay = parse_retry_after_body(body, 429).unwrap();
        assert_eq!(delay.as_secs(), 300);
    }

    #[test]
    fn test_parse_retry_after_body_google_retry_info() {
        let body = r#"{"error":{"code":429,"details":[{"@type":"type.googleapis.com/google.rpc.RetryInfo","retryDelay":"1.5s"}]}}"#;
        let delay = parse_retry_after_body(body, 429).unwrap();
        assert_eq!(delay.as_millis(), 1500);
    }

    #[test]
    fn test_parse_retry_after_body_google_error_info() {
        let body = r#"{"error":{"code":429,"details":[{"@type":"type.googleapis.com/google.rpc.ErrorInfo","metadata":{"quotaResetDelay":"373.8ms"}}]}}"#;
        let delay = parse_retry_after_body(body, 429).unwrap();
        assert_micros_between(delay, 373_000, 374_000);
    }

    #[test]
    fn test_parse_retry_after_body_non_429() {
        let body = r#"{"error":{"type":"usage_limit_reached","resets_in_seconds":300}}"#;
        let delay = parse_retry_after_body(body, 400);
        assert!(delay.is_none());
    }
}