use std::time::Duration;
use bytes::{Bytes, BytesMut};
use reqwest::Client;
use tokio::time::sleep;
use crate::error::{Result, SchemaRegError};
/// Largest response body (16 MiB) that is buffered before the read is aborted.
const MAX_BODY_BYTES: usize = 16 << 20;
/// Largest request body (4 MiB) that will be sent.
const MAX_REQUEST_BODY_BYTES: usize = 4 << 20;
/// Number of additional attempts allowed after the first one fails transiently.
const MAX_RETRIES: u32 = 3;
/// Initial backoff delay in milliseconds; doubled on every subsequent attempt.
const BACKOFF_BASE_MS: u64 = 100;
/// Ceiling (one minute) applied to any single retry delay, in milliseconds.
const MAX_BACKOFF_MS: u64 = 60_000;
/// A fully buffered HTTP response returned by the crate's HTTP client.
#[derive(Debug)]
pub(crate) struct HttpResponse {
    /// Numeric HTTP status code (e.g. `200`, `429`).
    pub status: u16,
    /// `Content-Type` media type with any `;`-separated parameters stripped
    /// and surrounding whitespace trimmed; `None` if absent or not valid text.
    pub content_type: Option<String>,
    /// The complete response body.
    pub body: Bytes,
    /// Delay in milliseconds requested by the server's `Retry-After` header
    /// (delay-seconds form only), when it was read; `None` otherwise.
    pub retry_after_ms: Option<u64>,
}
/// Reports whether a response status is worth retrying: `429 Too Many Requests`
/// or any `5xx` server error.
fn is_retryable_status(status: u16) -> bool {
    matches!(status, 429 | 500..=599)
}
/// Thin wrapper around a [`reqwest::Client`] adding retries and body-size limits.
///
/// `reqwest::Client` keeps its connection pool behind an `Arc`, so cloning this
/// wrapper is cheap and the clones share that pool.
#[derive(Debug, Clone)]
pub(crate) struct HttpClient {
    client: Client,
}
impl HttpClient {
pub fn with_webpki_roots(timeout: Option<Duration>) -> Result<Self> {
let mut builder = Client::builder();
if let Some(t) = timeout {
builder = builder.timeout(t);
}
let client = builder
.build()
.map_err(|e| SchemaRegError::config(format!("failed to build HTTP client: {e}")))?;
Ok(Self { client })
}
pub async fn request(
&self,
method: &str,
url: &str,
extra_headers: &[(&str, &str)],
body: Option<&[u8]>,
auth_header: Option<&str>,
) -> Result<HttpResponse> {
let mut attempt = 0u32;
loop {
match self
.request_once(method, url, extra_headers, body, auth_header)
.await
{
Ok(resp) if is_retryable_status(resp.status) && attempt < MAX_RETRIES => {
let delay = resp
.retry_after_ms
.unwrap_or_else(|| BACKOFF_BASE_MS * (1u64 << attempt))
.min(MAX_BACKOFF_MS);
tracing::warn!(
status = resp.status,
attempt,
delay_ms = delay,
url,
"transient HTTP error — retrying"
);
sleep(Duration::from_millis(delay)).await;
attempt += 1;
}
Ok(resp) => return Ok(resp),
Err(e) if attempt < MAX_RETRIES => {
let delay = (BACKOFF_BASE_MS * (1u64 << attempt)).min(MAX_BACKOFF_MS);
tracing::warn!(
error = %e,
attempt,
delay_ms = delay,
url,
"transient network error — retrying"
);
sleep(Duration::from_millis(delay)).await;
attempt += 1;
}
Err(e) => return Err(e),
}
}
}
async fn request_once(
&self,
method: &str,
url: &str,
extra_headers: &[(&str, &str)],
body: Option<&[u8]>,
auth_header: Option<&str>,
) -> Result<HttpResponse> {
let method = reqwest::Method::from_bytes(method.as_bytes())
.map_err(|_| SchemaRegError::config(format!("invalid HTTP method: {method}")))?;
let mut builder = self.client.request(method, url);
for (name, value) in extra_headers {
builder = builder.header(*name, *value);
}
if let Some(auth) = auth_header {
builder = builder.header("Authorization", auth);
}
if let Some(b) = body {
if b.len() > MAX_REQUEST_BODY_BYTES {
return Err(SchemaRegError::config(format!(
"request body ({} bytes) exceeds the {MAX_REQUEST_BODY_BYTES}-byte limit",
b.len()
)));
}
builder = builder.body(b.to_vec());
}
let response = builder
.send()
.await
.map_err(|e| SchemaRegError::registry(format!("HTTP request failed: {e}")))?;
let status = response.status().as_u16();
let content_type = response
.headers()
.get(reqwest::header::CONTENT_TYPE)
.and_then(|v| v.to_str().ok())
.and_then(|ct| ct.split(';').next())
.map(|s| s.trim().to_string());
let retry_after_ms = if status == 429 {
response
.headers()
.get(reqwest::header::RETRY_AFTER)
.and_then(|v| v.to_str().ok())
.and_then(|s| s.trim().parse::<u64>().ok())
.map(|secs| secs.saturating_mul(1_000))
} else {
None
};
if let Some(len) = response.content_length()
&& len as usize > MAX_BODY_BYTES
{
return Err(SchemaRegError::registry(format!(
"response Content-Length ({len} bytes) exceeds the {MAX_BODY_BYTES}-byte limit"
)));
}
let mut buf = BytesMut::with_capacity(4096);
let mut total = 0usize;
let mut response = response;
loop {
let chunk = response.chunk().await.map_err(|e| {
SchemaRegError::registry(format!("failed to read response body: {e}"))
})?;
match chunk {
Some(bytes) => {
total += bytes.len();
if total > MAX_BODY_BYTES {
return Err(SchemaRegError::registry(format!(
"response body exceeds the {MAX_BODY_BYTES}-byte limit"
)));
}
buf.extend_from_slice(&bytes);
}
None => break,
}
}
Ok(HttpResponse {
status,
content_type,
body: buf.freeze(),
retry_after_ms,
})
}
}