#[cfg(feature = "apicurio")]
use std::collections::HashMap;
use std::time::Duration;
use bytes::{Bytes, BytesMut};
use reqwest::Client;
use tokio::time::sleep;
use crate::error::{Result, SchemaRegError};
/// Largest response body (16 MiB) that will be buffered before the request is failed.
const MAX_BODY_BYTES: usize = 16 << 20;
/// Largest request body (4 MiB) that will be sent.
const MAX_REQUEST_BODY_BYTES: usize = 4 << 20;
/// Number of retries allowed after the initial attempt.
const MAX_RETRIES: u32 = 3;
/// Base backoff delay in milliseconds; the computed delay doubles with each retry.
const BACKOFF_BASE_MS: u64 = 100;
/// Ceiling, in milliseconds, applied to every retry delay.
const MAX_BACKOFF_MS: u64 = 60_000;
/// A fully buffered HTTP response produced by [`HttpClient`].
pub(crate) struct HttpResponse {
    /// Numeric HTTP status code.
    pub status: u16,
    /// Delay requested by the server's `Retry-After` header, in milliseconds.
    ///
    /// `None` when the header is absent, is not an integer number of seconds, or was not
    /// parsed for this response's status.
    pub retry_after_ms: Option<u64>,
    /// Media type from `Content-Type`, with any parameters (e.g. `; charset=...`) removed
    /// and surrounding whitespace trimmed.
    pub content_type: Option<String>,
    /// The buffered response body.
    pub body: Bytes,
    /// Response headers keyed by lowercase name. Only one value is kept per name, and
    /// values that are not valid visible-ASCII strings are skipped.
    #[cfg(feature = "apicurio")]
    pub headers: HashMap<String, String>,
}
/// Returns `true` for statuses worth retrying: 429 (Too Many Requests) and any 5xx.
fn is_retryable_status(status: u16) -> bool {
    matches!(status, 429 | 500..=599)
}
/// Settings consumed by [`HttpClient::with_config`] to build the underlying `reqwest` client.
///
/// A `None` or empty field leaves the corresponding `reqwest` builder default untouched.
pub(crate) struct HttpClientConfig {
    /// Additional trusted root certificates, each passed to `add_root_certificate`.
    pub root_certificates: Vec<reqwest::Certificate>,
    /// Client identity handed to the builder's `identity` setting (presumably for mutual
    /// TLS — confirm against how callers construct it).
    pub identity: Option<reqwest::Identity>,
    /// Overall per-request timeout.
    pub timeout: Option<Duration>,
    /// Timeout applied to the connect phase only.
    pub connect_timeout: Option<Duration>,
    /// Maximum number of idle pooled connections kept per host.
    pub pool_max_idle_per_host: Option<usize>,
}
pub(crate) struct HttpClient {
client: Client,
}
impl HttpClient {
pub fn with_webpki_roots(timeout: Option<Duration>) -> Result<Self> {
Self::with_config(HttpClientConfig {
timeout,
connect_timeout: None,
root_certificates: Vec::new(),
identity: None,
pool_max_idle_per_host: None,
})
}
pub fn with_config(config: HttpClientConfig) -> Result<Self> {
let mut builder = Client::builder();
if let Some(t) = config.timeout {
builder = builder.timeout(t);
}
if let Some(ct) = config.connect_timeout {
builder = builder.connect_timeout(ct);
}
for cert in config.root_certificates {
builder = builder.add_root_certificate(cert);
}
if let Some(identity) = config.identity {
builder = builder.identity(identity);
}
if let Some(n) = config.pool_max_idle_per_host {
builder = builder.pool_max_idle_per_host(n);
}
let client = builder
.build()
.map_err(|e| SchemaRegError::config(format!("failed to build HTTP client: {e}")))?;
Ok(Self { client })
}
pub async fn request(
&self,
method: &str,
url: &str,
extra_headers: &[(&str, &str)],
body: Option<&[u8]>,
auth_header: Option<&str>,
) -> Result<HttpResponse> {
let mut attempt = 0u32;
loop {
match self
.request_once(method, url, extra_headers, body, auth_header)
.await
{
Ok(resp) if is_retryable_status(resp.status) && attempt < MAX_RETRIES => {
let delay = resp
.retry_after_ms
.unwrap_or_else(|| BACKOFF_BASE_MS * (1u64 << attempt))
.min(MAX_BACKOFF_MS);
tracing::warn!(
status = resp.status,
attempt,
delay_ms = delay,
url,
"transient HTTP error — retrying"
);
sleep(Duration::from_millis(delay)).await;
attempt += 1;
}
Ok(resp) => return Ok(resp),
Err(e) if attempt < MAX_RETRIES => {
let delay = (BACKOFF_BASE_MS * (1u64 << attempt)).min(MAX_BACKOFF_MS);
tracing::warn!(
error = %e,
attempt,
delay_ms = delay,
url,
"transient network error — retrying"
);
sleep(Duration::from_millis(delay)).await;
attempt += 1;
}
Err(e) => return Err(e),
}
}
}
async fn request_once(
&self,
method: &str,
url: &str,
extra_headers: &[(&str, &str)],
body: Option<&[u8]>,
auth_header: Option<&str>,
) -> Result<HttpResponse> {
let method = reqwest::Method::from_bytes(method.as_bytes())
.map_err(|_| SchemaRegError::config(format!("invalid HTTP method: {method}")))?;
let mut builder = self.client.request(method, url);
for (name, value) in extra_headers {
builder = builder.header(*name, *value);
}
if let Some(auth) = auth_header {
builder = builder.header("Authorization", auth);
}
if let Some(b) = body {
if b.len() > MAX_REQUEST_BODY_BYTES {
return Err(SchemaRegError::config(format!(
"request body ({} bytes) exceeds the {MAX_REQUEST_BODY_BYTES}-byte limit",
b.len()
)));
}
builder = builder.body(b.to_vec());
}
let response = builder.send().await.map_err(SchemaRegError::network)?;
let status = response.status().as_u16();
let content_type = response
.headers()
.get(reqwest::header::CONTENT_TYPE)
.and_then(|v| v.to_str().ok())
.and_then(|ct| ct.split(';').next())
.map(|s| s.trim().to_string());
let retry_after_ms = if status == 429 {
response
.headers()
.get(reqwest::header::RETRY_AFTER)
.and_then(|v| v.to_str().ok())
.and_then(|s| s.trim().parse::<u64>().ok())
.map(|secs| secs.saturating_mul(1_000))
} else {
None
};
#[cfg(feature = "apicurio")]
let headers: HashMap<String, String> = response
.headers()
.iter()
.filter_map(|(name, value)| {
value
.to_str()
.ok()
.map(|v| (name.as_str().to_lowercase(), v.to_string()))
})
.collect();
if let Some(len) = response.content_length()
&& len as usize > MAX_BODY_BYTES
{
return Err(SchemaRegError::invalid_state(format!(
"response Content-Length ({len} bytes) exceeds the {MAX_BODY_BYTES}-byte limit"
)));
}
let mut buf = BytesMut::with_capacity(4096);
let mut total = 0usize;
let mut response = response;
loop {
let chunk = response.chunk().await.map_err(SchemaRegError::network)?;
match chunk {
Some(bytes) => {
total += bytes.len();
if total > MAX_BODY_BYTES {
return Err(SchemaRegError::invalid_state(format!(
"response body exceeds the {MAX_BODY_BYTES}-byte limit"
)));
}
buf.extend_from_slice(&bytes);
}
None => break,
}
}
Ok(HttpResponse {
status,
content_type,
body: buf.freeze(),
retry_after_ms,
#[cfg(feature = "apicurio")]
headers,
})
}
}
/// Bytes percent-encoded when a value is embedded as a single URL path segment.
///
/// The set starts from `CONTROLS` and adds space and a set of punctuation (listed below in
/// ASCII order). It includes `/` so that a value cannot split into several segments, and
/// `%` so that pre-existing escapes are preserved literally.
pub(crate) static PATH_SEGMENT_ENCODE_SET: percent_encoding::AsciiSet = percent_encoding::CONTROLS
    .add(b' ')
    .add(b'"')
    .add(b'#')
    .add(b'%')
    .add(b'/')
    .add(b'<')
    .add(b'>')
    .add(b'?')
    .add(b'@')
    .add(b'[')
    .add(b'\\')
    .add(b']')
    .add(b'^')
    .add(b'`')
    .add(b'{')
    .add(b'}');
/// Percent-encodes `input` (as UTF-8) for use as a single URL path segment, using
/// `PATH_SEGMENT_ENCODE_SET`.
#[inline]
pub(crate) fn percent_encode(input: &str) -> String {
    // `PercentEncode` yields `&str` pieces; gather them into one owned string.
    percent_encoding::utf8_percent_encode(input, &PATH_SEGMENT_ENCODE_SET).collect()
}
/// Strips every trailing `/` from `url`, reusing its allocation.
///
/// A string made only of slashes becomes empty.
pub(crate) fn normalize_url(mut url: String) -> String {
    while url.ends_with('/') {
        url.pop();
    }
    url
}
pub(crate) fn reject_embedded_credentials(url: &str) -> crate::error::Result<()> {
let Some(scheme_end) = url.find("://") else {
return Ok(());
};
let authority_start = scheme_end + 3;
let authority = &url[authority_start..];
let authority_end = authority.find(['/', '?', '#']).unwrap_or(authority.len());
let authority_slice = &authority[..authority_end];
if authority_slice.contains('@') {
return Err(crate::error::SchemaRegError::config(
"registry URL must not contain embedded credentials (user:pass@host); \
use the builder's auth methods instead",
));
}
Ok(())
}
pub(crate) fn validate_subject(subject: &str) -> crate::error::Result<()> {
for segment in subject.split('/') {
if segment == ".." || segment == "." {
return Err(crate::error::SchemaRegError::config(
"subject name must not contain '.' or '..' path segments",
));
}
}
Ok(())
}