use std::sync::Arc;
use std::time::Duration;
use reqwest::header::{HeaderName, HeaderValue, AUTHORIZATION, CONTENT_TYPE};
use reqwest::{Method, StatusCode};
use serde::de::DeserializeOwned;
use serde::Serialize;
use serde_json::Value;
use crate::config::{ResolvedConfig, SDK_LANGUAGE, SDK_VERSION};
use crate::error::{ApiError, ApiErrorKind, Error};
use crate::types::common::{ApiResponse, RateLimitInfo, RequestOptions};
/// Prefix under which relative request paths are nested by `resolve_path`.
const API_PREFIX: &str = "/api/v1";
/// Ceiling applied to delays computed by `backoff_delay` (jitter included).
/// Server-supplied `Retry-After` delays used for rate-limit retries are not
/// capped by this value.
const MAX_BACKOFF: Duration = Duration::from_secs(30);
/// Delay in milliseconds for the first retry; `backoff_delay` doubles it for
/// each further attempt (shift capped at 16) before adding up to 1s of jitter.
const BASE_BACKOFF_MS: u64 = 1_000;
/// HTTP verbs issued by [`HttpClient`]; converted to [`reqwest::Method`]
/// through the `From` impl below.
#[derive(Debug, Clone, Copy)]
pub(crate) enum HttpMethod {
Get,
Post,
Put,
Delete,
// NOTE(review): presumably allowed because nothing calls `HttpClient::patch`
// yet — confirm, and drop the attribute once PATCH is actually used.
#[allow(dead_code)]
Patch,
}
impl From<HttpMethod> for Method {
    /// Maps each SDK verb onto the matching `reqwest` method constant.
    fn from(verb: HttpMethod) -> Self {
        match verb {
            HttpMethod::Get => Self::GET,
            HttpMethod::Post => Self::POST,
            HttpMethod::Put => Self::PUT,
            HttpMethod::Delete => Self::DELETE,
            HttpMethod::Patch => Self::PATCH,
        }
    }
}
/// SDK transport: wraps a `reqwest::Client` and applies the SDK's path
/// resolution, standard headers, JSON encoding/decoding and retry policy.
#[derive(Debug, Clone)]
pub(crate) struct HttpClient {
/// Underlying `reqwest` client used to send every request.
pub(crate) inner: reqwest::Client,
/// Resolved SDK configuration; held in an `Arc`, so clones of this client
/// share the same configuration.
pub(crate) config: Arc<ResolvedConfig>,
}
impl HttpClient {
pub(crate) fn new(config: ResolvedConfig) -> Result<Self, Error> {
let inner = reqwest::Client::builder()
.build()
.map_err(|err| Error::Network {
message: err.to_string(),
source: Some(Box::new(err)),
})?;
Ok(Self {
inner,
config: Arc::new(config),
})
}
pub(crate) async fn get<Resp>(
&self,
path: &str,
query: Option<&[(&str, &str)]>,
opts: Option<&RequestOptions>,
) -> Result<Resp, Error>
where
Resp: DeserializeOwned,
{
let response = self
.request_raw::<()>(HttpMethod::Get, path, query, None, opts)
.await?;
decode_value(response.data)
}
pub(crate) async fn post<Body, Resp>(
&self,
path: &str,
body: &Body,
opts: Option<&RequestOptions>,
) -> Result<Resp, Error>
where
Body: Serialize + ?Sized,
Resp: DeserializeOwned,
{
let response = self
.request_raw(HttpMethod::Post, path, None, Some(body), opts)
.await?;
decode_value(response.data)
}
pub(crate) async fn put<Body, Resp>(
&self,
path: &str,
body: &Body,
opts: Option<&RequestOptions>,
) -> Result<Resp, Error>
where
Body: Serialize + ?Sized,
Resp: DeserializeOwned,
{
let response = self
.request_raw(HttpMethod::Put, path, None, Some(body), opts)
.await?;
decode_value(response.data)
}
pub(crate) async fn delete<Resp>(
&self,
path: &str,
query: Option<&[(&str, &str)]>,
opts: Option<&RequestOptions>,
) -> Result<Resp, Error>
where
Resp: DeserializeOwned,
{
let response = self
.request_raw::<()>(HttpMethod::Delete, path, query, None, opts)
.await?;
decode_value(response.data)
}
pub(crate) async fn patch<Body, Resp>(
&self,
path: &str,
body: &Body,
opts: Option<&RequestOptions>,
) -> Result<Resp, Error>
where
Body: Serialize + ?Sized,
Resp: DeserializeOwned,
{
let response = self
.request_raw(HttpMethod::Patch, path, None, Some(body), opts)
.await?;
decode_value(response.data)
}
pub(crate) async fn post_with_query<Body, Resp>(
&self,
path: &str,
query: Option<&[(&str, &str)]>,
body: &Body,
opts: Option<&RequestOptions>,
) -> Result<Resp, Error>
where
Body: Serialize + ?Sized,
Resp: DeserializeOwned,
{
let response = self
.request_raw(HttpMethod::Post, path, query, Some(body), opts)
.await?;
decode_value(response.data)
}
pub(crate) async fn request_raw<Body>(
&self,
method: HttpMethod,
path: &str,
query: Option<&[(&str, &str)]>,
body: Option<&Body>,
opts: Option<&RequestOptions>,
) -> Result<ApiResponse<Value>, Error>
where
Body: Serialize + ?Sized,
{
let max_retries = opts
.and_then(|o| o.retries)
.unwrap_or(self.config.max_retries);
let timeout = opts.and_then(|o| o.timeout).unwrap_or(self.config.timeout);
let resolved_path = resolve_path(path);
let url = format!("{}{}", self.config.base_url, resolved_path);
let serialized_body = match body {
Some(b) => Some(serde_json::to_vec(b)?),
None => None,
};
let mut attempt: u32 = 0;
loop {
let mut request = self
.inner
.request(method.into(), &url)
.timeout(timeout)
.header(CONTENT_TYPE, "application/json")
.header(HeaderName::from_static("accept"), "application/json")
.header(HeaderName::from_static("x-sdk-version"), SDK_VERSION)
.header(HeaderName::from_static("x-sdk-language"), SDK_LANGUAGE)
.header(
HeaderName::from_static("x-tenant-id"),
&self.config.tenant_id,
);
if let Some(user_id) = opts
.and_then(|o| o.user_id.as_deref())
.or(self.config.user_id.as_deref())
{
request = request.header(HeaderName::from_static("x-user-id"), user_id);
}
if let Some(namespace_id) = opts
.and_then(|o| o.namespace_id.as_deref())
.or(self.config.namespace_id.as_deref())
{
request = request.header(HeaderName::from_static("x-namespace-id"), namespace_id);
}
if let Some(user) = self.config.authenticated_user.as_deref() {
request = request.header(HeaderName::from_static("x-authenticated-user"), user);
}
if let Some(token) = self.config.bearer_token.as_deref() {
let value = HeaderValue::from_str(&format!("Bearer {token}")).map_err(|_| {
Error::validation("bearer_token", "invalid characters in bearer token")
})?;
request = request.header(AUTHORIZATION, value);
}
for (name, value) in &self.config.extra_headers {
request = request.header(name.as_str(), value.as_str());
}
if let Some(q) = query {
request = request.query(q);
}
if let Some(bytes) = &serialized_body {
request = request.body(bytes.clone());
}
match request.send().await {
Ok(response) => {
let status = response.status();
let headers = response.headers().clone();
let rate_limit = RateLimitInfo::from_headers(&headers);
if status.is_success() {
if status == StatusCode::NO_CONTENT {
return Ok(ApiResponse {
data: Value::Null,
status,
headers,
rate_limit,
});
}
let bytes = response
.bytes()
.await
.map_err(|err| map_runtime(err, timeout))?;
let data = if bytes.is_empty() {
Value::Null
} else {
serde_json::from_slice(&bytes)?
};
return Ok(ApiResponse {
data,
status,
headers,
rate_limit,
});
}
let body_bytes = response
.bytes()
.await
.map_err(|err| map_runtime(err, timeout))?;
let body_value: Option<Value> = if body_bytes.is_empty() {
None
} else {
serde_json::from_slice(&body_bytes).ok()
};
let api_error = ApiError::from_response(status, body_value, headers);
if should_retry_status(status, self.config.retry_on_503)
&& attempt < max_retries
{
let delay = retry_delay_from_api_error(&api_error, attempt + 1);
attempt += 1;
tokio::time::sleep(delay).await;
continue;
}
return Err(Error::from(api_error));
}
Err(err) => {
if err.is_timeout() {
return Err(Error::Timeout { timeout });
}
if attempt < max_retries && is_retryable_network_error(&err) {
let delay = backoff_delay(attempt + 1);
attempt += 1;
tokio::time::sleep(delay).await;
continue;
}
return Err(map_runtime(err, timeout));
}
}
}
}
}
fn decode_value<T: DeserializeOwned>(v: Value) -> Result<T, Error> {
serde_json::from_value(v).map_err(Error::Serde)
}
/// Path prefixes that are sent verbatim instead of being nested under
/// `API_PREFIX`. Matching is a plain string-prefix test, so for example
/// `/live` also matches `/livez`.
const ABSOLUTE_PREFIXES: &[&str] = &[
    "/api/",
    "/health",
    "/ready",
    "/live",
    "/metrics",
    "/ws/",
    "/api-docs",
    "/swagger-ui",
];

/// Crate-visible wrapper around the private [`resolve_path`].
pub(crate) fn resolve_path_public(path: &str) -> String {
    resolve_path(path)
}

/// Maps a caller-supplied path onto the server's URL layout.
///
/// * A path starting with any entry of [`ABSOLUTE_PREFIXES`] is returned
///   unchanged.
/// * Any other path is nested under `API_PREFIX`. A `/` separator is inserted
///   when the path does not already begin with one, so `""` becomes
///   `"{API_PREFIX}/"`.
fn resolve_path(path: &str) -> String {
    // `starts_with` also covers exact equality with a prefix.
    let is_absolute = ABSOLUTE_PREFIXES
        .iter()
        .any(|prefix| path.starts_with(prefix));
    match (is_absolute, path.starts_with('/')) {
        (true, _) => path.to_owned(),
        (false, true) => format!("{API_PREFIX}{path}"),
        (false, false) => format!("{API_PREFIX}/{path}"),
    }
}
/// Reports whether a non-success HTTP status may be retried: 429 always, 503
/// only when `retry_on_503` is set, and nothing else.
fn should_retry_status(status: StatusCode, retry_on_503: bool) -> bool {
    let rate_limited = status == StatusCode::TOO_MANY_REQUESTS;
    let unavailable = status == StatusCode::SERVICE_UNAVAILABLE;
    rate_limited || (unavailable && retry_on_503)
}
/// Picks the sleep before retry number `attempt` (1-based).
///
/// For rate-limit errors that carry a `retry_after` hint, that many seconds are
/// used as-is; otherwise the exponential [`backoff_delay`] applies.
fn retry_delay_from_api_error(err: &ApiError, attempt: u32) -> Duration {
    let server_hint = (err.kind == ApiErrorKind::RateLimit)
        .then(|| err.rate_limit())
        .flatten()
        .and_then(|rl| rl.retry_after);
    server_hint.map_or_else(|| backoff_delay(attempt), Duration::from_secs)
}
/// Exponential backoff with jitter for retry number `attempt` (1-based).
///
/// The delay is computed as follows:
///
/// * Start from `BASE_BACKOFF_MS * 2^(attempt - 1)`, with the exponent capped
///   at 16.
/// * Add a uniformly random 0–1000 ms of jitter.
/// * Clamp the sum to `MAX_BACKOFF`.
fn backoff_delay(attempt: u32) -> Duration {
    use rand::Rng;
    let exponent = attempt.saturating_sub(1).min(16);
    let base_ms = BASE_BACKOFF_MS.saturating_mul(2u64.pow(exponent));
    let jitter_ms: u64 = rand::thread_rng().gen_range(0..=1_000);
    let ceiling_ms = MAX_BACKOFF.as_millis() as u64;
    Duration::from_millis(base_ms.saturating_add(jitter_ms).min(ceiling_ms))
}
fn is_retryable_network_error(err: &reqwest::Error) -> bool {
err.is_connect() || err.is_request() || err.is_body()
}
fn map_runtime(err: reqwest::Error, timeout: Duration) -> Error {
if err.is_timeout() {
return Error::Timeout { timeout };
}
Error::Network {
message: err.to_string(),
source: Some(Box::new(err)),
}
}