use byokey_types::{
ByokError, ProviderId, RateLimitSnapshot, RateLimitStore,
traits::{ByteStream, ProviderResponse, Result},
};
use futures_util::StreamExt as _;
use rquest::{Client, RequestBuilder};
use serde_json::Value;
use std::collections::HashMap;
use std::sync::Arc;
/// Destination for rate-limit headers captured by `ProviderHttp`: the shared
/// store plus the (provider, account) key that snapshots are recorded under.
#[derive(Clone)]
struct RateLimitCtx {
/// Shared store that receives a `RateLimitSnapshot` on each captured response.
store: Arc<RateLimitStore>,
/// Provider half of the store key.
provider: ProviderId,
/// Account half of the store key; `ProviderHttp::with_ratelimit` sets this to `"active"`.
account_id: String,
}
/// Thin wrapper around an `rquest::Client` that turns non-success responses
/// into `ByokError::Upstream` and, when configured via `with_ratelimit`,
/// records rate-limit headers from every response passing through `send`.
#[derive(Clone)]
pub struct ProviderHttp {
/// Underlying HTTP client used for all requests.
http: Client,
/// Where to record rate-limit headers; `None` disables capture.
rl_ctx: Option<RateLimitCtx>,
}
impl ProviderHttp {
    /// Wraps `http` with rate-limit header capture disabled.
    #[must_use]
    pub fn new(http: Client) -> Self {
        Self { rl_ctx: None, http }
    }

    /// Consumes `self` and enables rate-limit header capture into `store`.
    ///
    /// Snapshots are keyed by `provider` and the fixed account id `"active"`.
    #[must_use]
    pub fn with_ratelimit(self, store: Arc<RateLimitStore>, provider: ProviderId) -> Self {
        let ctx = RateLimitCtx {
            store,
            provider,
            account_id: "active".to_string(),
        };
        Self {
            rl_ctx: Some(ctx),
            ..self
        }
    }

    /// Borrows the underlying HTTP client.
    #[must_use]
    pub fn client(&self) -> &Client {
        &self.http
    }

    /// Copies rate-limit related headers into the configured store.
    ///
    /// Tracked headers are `anthropic-ratelimit-*`, `x-ratelimit-*` and
    /// `retry-after`; values that are not visible ASCII are skipped. When a
    /// name repeats, the last value wins. Does nothing if capture is disabled
    /// or no tracked header is present.
    fn capture_ratelimit_headers(&self, headers: &rquest::header::HeaderMap) {
        let Some(ctx) = self.rl_ctx.as_ref() else {
            return;
        };

        let is_tracked = |name: &str| {
            name.starts_with("anthropic-ratelimit-")
                || name.starts_with("x-ratelimit-")
                || name == "retry-after"
        };
        let captured: HashMap<_, _> = headers
            .iter()
            .filter_map(|(name, value)| {
                let key = name.as_str();
                if !is_tracked(key) {
                    return None;
                }
                let text = value.to_str().ok()?;
                Some((key.to_string(), text.to_string()))
            })
            .collect();
        if captured.is_empty() {
            return;
        }

        let captured_at = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .unwrap_or_default()
            .as_secs();
        let snapshot = RateLimitSnapshot {
            headers: captured,
            captured_at,
        };
        ctx.store
            .update(ctx.provider.clone(), ctx.account_id.clone(), snapshot);
    }

    /// Sends `builder`, records rate-limit headers, and returns the response
    /// only if its status is a success.
    ///
    /// # Errors
    /// Transport errors from `send()`. For non-success statuses, returns
    /// `ByokError::Upstream` carrying the status code, the body text (empty if
    /// it could not be read), and a retry hint: the one parsed from a 429
    /// body when recognised, otherwise a numeric `Retry-After` header.
    pub async fn send(&self, builder: RequestBuilder) -> Result<rquest::Response> {
        let resp = builder.send().await?;
        self.capture_ratelimit_headers(resp.headers());

        let status = resp.status();
        if status.is_success() {
            return Ok(resp);
        }

        // Read the header hint before `text()` consumes the response.
        let header_hint = parse_retry_after_header(resp.headers());
        let body = resp.text().await.unwrap_or_default();
        let code = status.as_u16();
        let retry_after = parse_retry_after_body(&body, code).or(header_hint);
        Err(ByokError::Upstream {
            status: code,
            body,
            retry_after,
        })
    }

    /// Sends `builder` and hands the body back unchanged: as a byte stream
    /// when `stream` is true, otherwise as parsed JSON.
    ///
    /// # Errors
    /// Anything `send` returns, plus JSON decoding errors in non-stream mode.
    pub async fn send_passthrough(
        &self,
        builder: RequestBuilder,
        stream: bool,
    ) -> Result<ProviderResponse> {
        let resp = self.send(builder).await?;
        if stream {
            return Ok(ProviderResponse::Stream(Self::byte_stream(resp)));
        }
        let body = resp.json::<Value>().await?;
        Ok(ProviderResponse::Complete(body))
    }

    /// Converts a response body into a boxed byte stream, mapping transport
    /// errors to `ByokError`.
    #[must_use]
    pub fn byte_stream(resp: rquest::Response) -> ByteStream {
        let chunks = resp
            .bytes_stream()
            .map(|chunk| chunk.map_err(ByokError::from));
        Box::pin(chunks)
    }
}
/// Reads a `Retry-After` header given as delta-seconds.
///
/// Returns `None` if the header is absent, is not visible ASCII, or is not a
/// plain non-negative integer. The HTTP-date form is not handled.
fn parse_retry_after_header(headers: &rquest::header::HeaderMap) -> Option<std::time::Duration> {
    headers
        .get("retry-after")
        .and_then(|value| value.to_str().ok())
        .and_then(|text| text.parse::<u64>().ok())
        .map(std::time::Duration::from_secs)
}
fn parse_retry_after_body(body: &str, status: u16) -> Option<std::time::Duration> {
if status != 429 {
return None;
}
let json: serde_json::Value = serde_json::from_str(body).ok()?;
if let Some(error) = json.get("error")
&& error.get("type").and_then(serde_json::Value::as_str) == Some("usage_limit_reached")
{
if let Some(secs) = error
.get("resets_in_seconds")
.and_then(serde_json::Value::as_u64)
{
return Some(std::time::Duration::from_secs(secs));
}
if let Some(ts) = error.get("resets_at").and_then(serde_json::Value::as_u64) {
let now = std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap_or_default()
.as_secs();
if ts > now {
return Some(std::time::Duration::from_secs(ts - now));
}
}
}
if let Some(details) = json.pointer("/error/details").and_then(Value::as_array) {
for detail in details {
if detail.get("@type").and_then(Value::as_str)
== Some("type.googleapis.com/google.rpc.RetryInfo")
&& let Some(delay_str) = detail.get("retryDelay").and_then(Value::as_str)
&& let Some(d) = parse_google_duration(delay_str)
{
return Some(d);
}
}
for detail in details {
if detail.get("@type").and_then(Value::as_str)
== Some("type.googleapis.com/google.rpc.ErrorInfo")
&& let Some(delay_str) = detail
.pointer("/metadata/quotaResetDelay")
.and_then(Value::as_str)
&& let Some(d) = parse_google_duration(delay_str)
{
return Some(d);
}
}
}
None
}
/// Parses a Google-style duration string such as `"1.5s"` or `"373.8ms"`.
///
/// Returns `None` in these cases:
/// - the string ends in neither `ms` nor `s`;
/// - the numeric part does not parse as `f64`;
/// - the value is negative, NaN, infinite, or too large for a
///   [`std::time::Duration`].
///
/// The input comes from upstream error bodies, so malformed values must be
/// rejected instead of panicking. `Duration::from_secs_f64` would panic on
/// them; `try_from_secs_f64` does not.
fn parse_google_duration(s: &str) -> Option<std::time::Duration> {
    // Check "ms" before "s": "5ms" also ends in 's'.
    let secs = if let Some(ms_str) = s.strip_suffix("ms") {
        ms_str.parse::<f64>().ok()? / 1000.0
    } else if let Some(secs_str) = s.strip_suffix('s') {
        secs_str.parse::<f64>().ok()?
    } else {
        return None;
    };
    std::time::Duration::try_from_secs_f64(secs).ok()
}
/// Chooses the `Accept` header value for a request: `text/event-stream` when
/// streaming, `application/json` otherwise.
#[must_use]
pub fn accept_for_stream(stream: bool) -> &'static str {
    if stream {
        return "text/event-stream";
    }
    "application/json"
}
/// When `stream` is true, sets `stream_options.include_usage = true` on `body`.
///
/// Behaviour by case:
/// - Other keys already present in an object-valued `stream_options` are kept.
/// - A missing or non-object `stream_options` is replaced with
///   `{"include_usage": true}`.
/// - A `null` body is first turned into an empty object, as serde_json's
///   `IndexMut` would do.
/// - Bodies that are neither objects nor `null` are left unchanged. The
///   previous `body["stream_options"] = ...` panicked on them.
///
/// Does nothing when `stream` is false.
pub fn ensure_stream_options(body: &mut serde_json::Value, stream: bool) {
    if !stream {
        return;
    }
    if body.is_null() {
        *body = serde_json::Value::Object(serde_json::Map::new());
    }
    let Some(obj) = body.as_object_mut() else {
        return;
    };
    let options = obj
        .entry("stream_options")
        .or_insert_with(|| serde_json::json!({}));
    if !options.is_object() {
        *options = serde_json::json!({});
    }
    options["include_usage"] = serde_json::Value::Bool(true);
}
/// Picks the bearer token for an upstream request.
///
/// An explicit `api_key` is used as-is. Otherwise the access token for
/// `provider` is fetched from `auth`.
///
/// # Errors
/// Propagates (via `?`) any error from `AuthManager::get_token`.
pub async fn resolve_bearer_token(
    api_key: Option<&str>,
    auth: &Arc<byokey_auth::AuthManager>,
    provider: &ProviderId,
) -> byokey_types::Result<String> {
    match api_key {
        Some(key) => Ok(key.to_string()),
        None => Ok(auth.get_token(provider).await?.access_token),
    }
}
/// Test helper: returns a fresh HTTP client and an `AuthManager` backed by an
/// empty in-memory token store.
#[cfg(test)]
#[must_use]
pub fn test_auth() -> (Client, Arc<byokey_auth::AuthManager>) {
    let token_store = Arc::new(byokey_store::InMemoryTokenStore::new());
    let manager = byokey_auth::AuthManager::new(token_store, Client::new());
    (Client::new(), Arc::new(manager))
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_provider_http_clone() {
        let http = ProviderHttp::new(Client::new());
        let _http2 = http.clone();
    }

    #[test]
    fn test_with_ratelimit() {
        let store = Arc::new(RateLimitStore::new());
        let http = ProviderHttp::new(Client::new()).with_ratelimit(store, ProviderId::Claude);
        assert!(http.rl_ctx.is_some());
    }

    #[test]
    fn test_parse_google_duration_seconds() {
        let d = parse_google_duration("0.847655010s").unwrap();
        assert!(d.as_micros() > 847_000 && d.as_micros() < 848_000);
    }

    #[test]
    fn test_parse_google_duration_millis() {
        let d = parse_google_duration("373.801628ms").unwrap();
        assert!(d.as_micros() > 373_000 && d.as_micros() < 374_000);
    }

    #[test]
    fn test_parse_google_duration_whole_seconds() {
        let d = parse_google_duration("5s").unwrap();
        assert_eq!(d.as_secs(), 5);
    }

    #[test]
    fn test_parse_google_duration_invalid() {
        assert!(parse_google_duration("abc").is_none());
        assert!(parse_google_duration("").is_none());
    }

    #[test]
    fn test_parse_retry_after_body_codex() {
        let body = r#"{"error":{"type":"usage_limit_reached","resets_in_seconds":300}}"#;
        let d = parse_retry_after_body(body, 429).unwrap();
        assert_eq!(d.as_secs(), 300);
    }

    #[test]
    fn test_parse_retry_after_body_codex_resets_at_future() {
        let now = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .unwrap_or_default()
            .as_secs();
        let body = serde_json::json!({
            "error": { "type": "usage_limit_reached", "resets_at": now + 600 }
        })
        .to_string();
        let d = parse_retry_after_body(&body, 429).unwrap();
        assert!(d.as_secs() > 590 && d.as_secs() <= 600);
    }

    #[test]
    fn test_parse_retry_after_body_codex_resets_at_past() {
        let body = r#"{"error":{"type":"usage_limit_reached","resets_at":1}}"#;
        assert!(parse_retry_after_body(body, 429).is_none());
    }

    #[test]
    fn test_parse_retry_after_body_google_retry_info() {
        let body = r#"{"error":{"code":429,"details":[{"@type":"type.googleapis.com/google.rpc.RetryInfo","retryDelay":"1.5s"}]}}"#;
        let d = parse_retry_after_body(body, 429).unwrap();
        assert_eq!(d.as_millis(), 1500);
    }

    #[test]
    fn test_parse_retry_after_body_google_error_info() {
        let body = r#"{"error":{"code":429,"details":[{"@type":"type.googleapis.com/google.rpc.ErrorInfo","metadata":{"quotaResetDelay":"373.8ms"}}]}}"#;
        let d = parse_retry_after_body(body, 429).unwrap();
        assert!(d.as_micros() > 373_000 && d.as_micros() < 374_000);
    }

    #[test]
    fn test_parse_retry_after_body_retry_info_beats_error_info() {
        // ErrorInfo is listed first, but RetryInfo must still win.
        let body = r#"{"error":{"code":429,"details":[{"@type":"type.googleapis.com/google.rpc.ErrorInfo","metadata":{"quotaResetDelay":"373.8ms"}},{"@type":"type.googleapis.com/google.rpc.RetryInfo","retryDelay":"2s"}]}}"#;
        let d = parse_retry_after_body(body, 429).unwrap();
        assert_eq!(d.as_secs(), 2);
    }

    #[test]
    fn test_parse_retry_after_body_unrecognised() {
        assert!(parse_retry_after_body("not json", 429).is_none());
        assert!(parse_retry_after_body(r#"{"error":{"type":"rate_limit_error"}}"#, 429).is_none());
    }

    #[test]
    fn test_parse_retry_after_body_non_429() {
        let body = r#"{"error":{"type":"usage_limit_reached","resets_in_seconds":300}}"#;
        assert!(parse_retry_after_body(body, 400).is_none());
    }

    #[test]
    fn test_parse_retry_after_header_seconds() {
        let mut headers = rquest::header::HeaderMap::new();
        assert!(parse_retry_after_header(&headers).is_none());
        headers.insert(
            "retry-after",
            rquest::header::HeaderValue::from_static("120"),
        );
        assert_eq!(
            parse_retry_after_header(&headers),
            Some(std::time::Duration::from_secs(120))
        );
    }

    #[test]
    fn test_accept_for_stream() {
        assert_eq!(accept_for_stream(true), "text/event-stream");
        assert_eq!(accept_for_stream(false), "application/json");
    }

    #[test]
    fn test_ensure_stream_options() {
        let mut body = serde_json::json!({ "model": "m" });
        ensure_stream_options(&mut body, false);
        assert!(body.get("stream_options").is_none());
        ensure_stream_options(&mut body, true);
        assert_eq!(body["stream_options"]["include_usage"], serde_json::json!(true));
    }
}