use std::sync::Arc;
use std::time::Duration;
use bytes::Bytes;
use http::{
HeaderMap, HeaderName, HeaderValue, Method,
header::{AUTHORIZATION, RETRY_AFTER, USER_AGENT},
};
use rand::Rng;
use serde::{Serialize, de::DeserializeOwned};
use crate::error::{ApiError, Error};
use crate::transport::{HttpRequest, SharedTransport};
/// Per-request policy controlling how many times a request is attempted and
/// how long the client waits between attempts.
#[derive(Debug, Clone)]
pub enum RequestStrategy {
    /// Send the request exactly once; never retry.
    Once,
    /// Attach the given idempotency key and retry using the client's default
    /// retry budget, with jittered backoff.
    Idempotent(String),
    /// Attempt up to `max_attempts` times without backoff between attempts.
    Retry { max_attempts: u32 },
    /// Attempt up to `max_attempts` times with exponential backoff,
    /// optionally randomised by `jitter`.
    ExponentialBackoff { max_attempts: u32, jitter: bool },
}

impl RequestStrategy {
    /// Total number of attempts this strategy permits.
    ///
    /// `client_default` is the client's configured retry count; it is only
    /// consulted by [`RequestStrategy::Idempotent`], which allows one initial
    /// attempt plus that many retries (saturating at `u32::MAX`).
    fn max_attempts(&self, client_default: u32) -> u32 {
        match self {
            Self::Once => 1,
            Self::Idempotent(_) => client_default.saturating_add(1),
            Self::Retry { max_attempts } => *max_attempts,
            Self::ExponentialBackoff { max_attempts, .. } => *max_attempts,
        }
    }

    /// The idempotency key carried by this strategy, if it has one.
    fn idempotency_key(&self) -> Option<&str> {
        if let Self::Idempotent(key) = self {
            Some(key.as_str())
        } else {
            None
        }
    }

    /// Whether backoff delays should be randomised.
    fn jitter(&self) -> bool {
        match self {
            Self::Idempotent(_) => true,
            Self::ExponentialBackoff { jitter, .. } => *jitter,
            Self::Once | Self::Retry { .. } => false,
        }
    }

    /// Whether a backoff delay is inserted between attempts.
    ///
    /// Only [`RequestStrategy::Retry`] opts out.
    fn use_backoff(&self) -> bool {
        match self {
            Self::Retry { .. } => false,
            Self::Once | Self::Idempotent(_) | Self::ExponentialBackoff { .. } => true,
        }
    }
}
/// Optional per-request overrides: an idempotency key, additional headers and
/// a retry strategy.
#[derive(Debug, Default, Clone)]
pub struct RequestOptions {
    /// Value for the `idempotency-key` header; a key carried by `strategy`
    /// takes precedence over this one when the request is built.
    pub idempotency_key: Option<String>,
    /// Headers inserted into the request after the client defaults, replacing
    /// any existing header of the same name.
    pub extra_headers: Vec<(HeaderName, HeaderValue)>,
    /// Retry policy for this request; `None` leaves the client's automatic
    /// behaviour in place.
    pub strategy: Option<RequestStrategy>,
}

impl RequestOptions {
    /// Returns options with nothing set.
    pub fn new() -> Self {
        Default::default()
    }

    /// Sets the idempotency key, replacing any previous one.
    pub fn idempotency_key(self, key: impl Into<String>) -> Self {
        Self {
            idempotency_key: Some(key.into()),
            ..self
        }
    }

    /// Appends an extra header to send with the request.
    pub fn header(mut self, name: HeaderName, value: HeaderValue) -> Self {
        let entry = (name, value);
        self.extra_headers.push(entry);
        self
    }

    /// Sets the retry strategy, replacing any previous one.
    pub fn strategy(self, strategy: RequestStrategy) -> Self {
        Self {
            strategy: Some(strategy),
            ..self
        }
    }
}
/// Percent-encodes `s` so it can be used as a single URL path segment.
///
/// ASCII letters, ASCII digits and `-`, `.`, `_`, `~` are copied through
/// unchanged. Every other byte of the UTF-8 encoding of `s` is written as
/// `%XX` using upper-case hexadecimal digits.
#[doc(hidden)]
pub fn path_segment(s: &str) -> String {
    // Lookup table for hex digits; avoids allocating a temporary `String`
    // through `format!` for every escaped byte.
    const HEX_UPPER: &[u8; 16] = b"0123456789ABCDEF";

    let mut out = String::with_capacity(s.len());
    for &b in s.as_bytes() {
        if b.is_ascii_alphanumeric() || matches!(b, b'-' | b'.' | b'_' | b'~') {
            out.push(char::from(b));
        } else {
            out.push('%');
            out.push(char::from(HEX_UPPER[usize::from(b >> 4)]));
            out.push(char::from(HEX_UPPER[usize::from(b & 0x0F)]));
        }
    }
    out
}
/// Base URL used by `ClientBuilder::try_build` when no `base_url` was supplied.
pub const DEFAULT_BASE_URL: &str = "https://api.workos.com";
/// Timeout (30 seconds) passed to the default transport when the builder has
/// no explicit timeout.
const DEFAULT_TIMEOUT: Duration = Duration::from_secs(30);
/// Handle to the API client.
///
/// All state lives in a `ClientInner` held behind an `Arc`, so cloning a
/// `Client` only clones the `Arc` and every clone shares the same transport
/// and configuration.
#[derive(Clone)]
pub struct Client {
// Shared transport and configuration.
inner: Arc<ClientInner>,
}
/// State shared by all clones of a `Client`.
pub(crate) struct ClientInner {
/// Transport that executes every outgoing request.
pub(crate) transport: SharedTransport,
/// Prefix prepended verbatim to each request path.
pub(crate) base_url: String,
/// Retry count used when a request is retried automatically; such a request
/// is attempted at most `max_retries + 1` times.
pub(crate) max_retries: u32,
/// API key as given to the builder (empty when none was set).
pub(crate) api_key: String,
/// Client id as given to the builder (empty when none was set).
pub(crate) client_id: String,
/// Headers cloned into every request before per-request headers are applied.
pub(crate) default_headers: HeaderMap,
}
/// Builder for `Client`. Every field is optional; unset fields fall back to
/// the defaults applied in `try_build`.
#[derive(Default)]
pub struct ClientBuilder {
// API key; when non-empty it is sent as `Authorization: Bearer <key>`.
api_key: Option<String>,
// Client id; defaults to the empty string.
client_id: Option<String>,
// Base URL; defaults to `DEFAULT_BASE_URL`.
base_url: Option<String>,
// Timeout for the default transport; defaults to `DEFAULT_TIMEOUT`.
// Not consulted when a custom `transport` is supplied.
timeout: Option<Duration>,
// Retry count for automatic retries; defaults to 3.
max_retries: Option<u32>,
// `User-Agent` header value; defaults to "workos-rust".
user_agent: Option<String>,
// Custom transport; when unset, `default_transport` provides one.
transport: Option<SharedTransport>,
}
impl Client {
    /// Creates a client for `api_key`, leaving every other setting at its
    /// builder default.
    ///
    /// Only compiled when the `native-tls` or `rustls-tls` feature is enabled.
    ///
    /// # Panics
    ///
    /// Panics whenever [`ClientBuilder::build`] does.
    #[cfg(any(feature = "native-tls", feature = "rustls-tls"))]
    pub fn new(api_key: impl Into<String>) -> Self {
        ClientBuilder::default().api_key(api_key).build()
    }

    /// Returns a builder with nothing configured.
    pub fn builder() -> ClientBuilder {
        Default::default()
    }

    /// Sends a request whose parameters are encoded into the query string and
    /// decodes the JSON response into `R`. No per-request options are applied.
    pub(crate) async fn request_with_query<P: Serialize, R: DeserializeOwned>(
        &self,
        method: Method,
        path: &str,
        params: &P,
    ) -> Result<R, Error> {
        self.request_with_query_opts(method, path, params, None)
            .await
    }

    /// Like [`Client::request_with_query`], but applies `opts` when present.
    pub(crate) async fn request_with_query_opts<P: Serialize, R: DeserializeOwned>(
        &self,
        method: Method,
        path: &str,
        params: &P,
        opts: Option<&RequestOptions>,
    ) -> Result<R, Error> {
        let request =
            self.build_request::<P, ()>(method.clone(), path, Some(params), None, opts)?;
        self.send(request, method, opts).await
    }

    /// Sends a request with query parameters and an optional JSON body, then
    /// decodes the JSON response into `R`.
    pub(crate) async fn request_with_body_opts<P: Serialize, B: Serialize, R: DeserializeOwned>(
        &self,
        method: Method,
        path: &str,
        params: &P,
        body: Option<&B>,
        opts: Option<&RequestOptions>,
    ) -> Result<R, Error> {
        let request = self.build_request(method.clone(), path, Some(params), body, opts)?;
        self.send(request, method, opts).await
    }

    /// Sends `body` as JSON (no query parameters, no options) and decodes the
    /// JSON response into `R`.
    pub(crate) async fn request_json<B: Serialize, R: DeserializeOwned>(
        &self,
        method: Method,
        path: &str,
        body: &B,
    ) -> Result<R, Error> {
        let request =
            self.build_request::<(), B>(method.clone(), path, None, Some(body), None)?;
        self.send(request, method, None).await
    }

    /// Sends an optional JSON body and discards the response body; only the
    /// status code decides between `Ok(())` and an API error.
    pub(crate) async fn request_empty<B: Serialize>(
        &self,
        method: Method,
        path: &str,
        body: Option<&B>,
    ) -> Result<(), Error> {
        let request = self.build_request::<(), B>(method.clone(), path, None, body, None)?;
        self.send_no_body(request, method, None).await
    }

    /// Query-parameter request with options whose response body is discarded.
    pub(crate) async fn request_with_query_opts_empty<P: Serialize>(
        &self,
        method: Method,
        path: &str,
        params: &P,
        opts: Option<&RequestOptions>,
    ) -> Result<(), Error> {
        let request =
            self.build_request::<P, ()>(method.clone(), path, Some(params), None, opts)?;
        self.send_no_body(request, method, opts).await
    }

    /// Query-parameter request with an optional JSON body and options whose
    /// response body is discarded.
    pub(crate) async fn request_with_body_opts_empty<P: Serialize, B: Serialize>(
        &self,
        method: Method,
        path: &str,
        params: &P,
        body: Option<&B>,
        opts: Option<&RequestOptions>,
    ) -> Result<(), Error> {
        let request = self.build_request(method.clone(), path, Some(params), body, opts)?;
        self.send_no_body(request, method, opts).await
    }

    /// Base URL that is prepended to every request path.
    pub fn base_url(&self) -> &str {
        self.inner.base_url.as_str()
    }

    /// Client id this client was configured with (empty if none was set).
    pub fn client_id(&self) -> &str {
        self.inner.client_id.as_str()
    }

    /// API key this client was configured with (empty if none was set).
    pub fn api_key(&self) -> &str {
        self.inner.api_key.as_str()
    }

    /// Returns a clone of the shared transport handle.
    pub fn transport(&self) -> SharedTransport {
        self.inner.transport.clone()
    }

    /// Headers that are added to every request sent by this client.
    pub fn default_headers(&self) -> &HeaderMap {
        &self.inner.default_headers
    }

    /// Returns a `PasswordlessApi` helper borrowing this client.
    pub fn passwordless(&self) -> crate::helpers::PasswordlessApi<'_> {
        crate::helpers::PasswordlessApi { client: self }
    }

    /// Returns a `VaultApi` helper borrowing this client.
    pub fn vault(&self) -> crate::helpers::VaultApi<'_> {
        crate::helpers::VaultApi { client: self }
    }

    /// Returns an `AuthKitHelper` borrowing this client.
    pub fn authkit(&self) -> crate::helpers::AuthKitHelper<'_> {
        crate::helpers::AuthKitHelper { client: self }
    }

    /// Returns an `SsoHelper` borrowing this client.
    pub fn sso_helpers(&self) -> crate::helpers::SsoHelper<'_> {
        crate::helpers::SsoHelper { client: self }
    }

    /// Returns a `JwksHelper` constructed from this client.
    pub fn jwks(&self) -> crate::helpers::JwksHelper {
        crate::helpers::JwksHelper::from_client(self)
    }

    /// Returns a `SessionManager` for the `sealed` session data and its
    /// `password`, attached to this client.
    pub fn session<'a>(
        &'a self,
        sealed: impl Into<String>,
        password: impl Into<String>,
    ) -> crate::helpers::SessionManager<'a> {
        crate::helpers::SessionManager::new(Some(self), sealed, password)
    }

    /// Assembles the transport-level request.
    ///
    /// * The URL is `base_url` followed by `path`; `query`, when present and
    ///   non-empty after URL-encoding, is appended with `?` (or `&` if the
    ///   URL already contains a `?`).
    /// * Headers start from the client defaults. A JSON `body` adds
    ///   `Content-Type: application/json`. An idempotency key (the strategy's
    ///   key wins over `opts.idempotency_key`) becomes the `idempotency-key`
    ///   header, and `opts.extra_headers` are inserted last.
    ///
    /// # Errors
    ///
    /// Returns `Error::Builder` when the query cannot be encoded or the
    /// idempotency key is not a valid header value, and the converted
    /// `serde_json` error when the body cannot be serialized.
    fn build_request<P: Serialize, B: Serialize>(
        &self,
        method: Method,
        path: &str,
        query: Option<&P>,
        body: Option<&B>,
        opts: Option<&RequestOptions>,
    ) -> Result<HttpRequest, Error> {
        let mut url = [self.inner.base_url.as_str(), path].concat();
        if let Some(params) = query {
            let encoded = serde_urlencoded::to_string(params)
                .map_err(|e| Error::Builder(format!("query encode failed: {e}")))?;
            if !encoded.is_empty() {
                let separator = if url.contains('?') { '&' } else { '?' };
                url.push(separator);
                url.push_str(&encoded);
            }
        }

        let mut headers = self.inner.default_headers.clone();
        let body = match body {
            Some(payload) => {
                let json = serde_json::to_vec(payload)?;
                headers.insert(
                    http::header::CONTENT_TYPE,
                    HeaderValue::from_static("application/json"),
                );
                Some(Bytes::from(json))
            }
            None => None,
        };

        if let Some(options) = opts {
            let key = options
                .strategy
                .as_ref()
                .and_then(|s| s.idempotency_key())
                .or(options.idempotency_key.as_deref());
            if let Some(key) = key {
                let value = HeaderValue::from_str(key)
                    .map_err(|e| Error::Builder(format!("invalid idempotency key: {e}")))?;
                headers.insert(HeaderName::from_static("idempotency-key"), value);
            }
            // Inserted last so callers can override anything set above.
            for (name, value) in options.extra_headers.iter() {
                headers.insert(name.clone(), value.clone());
            }
        }

        Ok(HttpRequest {
            method,
            url,
            headers,
            body,
        })
    }

    /// Executes `req` (with retries) and maps the status code to a result:
    /// 2xx is `Ok(())`, anything else becomes `Error::Api`.
    async fn send_no_body(
        &self,
        req: HttpRequest,
        method: Method,
        opts: Option<&RequestOptions>,
    ) -> Result<(), Error> {
        let resp = self.execute_with_retry(req, &method, opts).await?;
        match resp.status.as_u16() {
            200..=299 => Ok(()),
            status => Err(Error::Api(Box::new(ApiError::from_response(
                status,
                &resp.headers,
                &resp.body,
            )))),
        }
    }

    /// Executes `req` (with retries); a 2xx response body is decoded from
    /// JSON into `R`, any other status becomes `Error::Api`.
    async fn send<R: DeserializeOwned>(
        &self,
        req: HttpRequest,
        method: Method,
        opts: Option<&RequestOptions>,
    ) -> Result<R, Error> {
        let resp = self.execute_with_retry(req, &method, opts).await?;
        match resp.status.as_u16() {
            200..=299 => serde_json::from_slice::<R>(&resp.body).map_err(Error::from),
            status => Err(Error::Api(Box::new(ApiError::from_response(
                status,
                &resp.headers,
                &resp.body,
            )))),
        }
    }

    /// Runs `req` through the transport, retrying according to `opts`.
    ///
    /// Attempt budget:
    /// * with an explicit strategy, whatever that strategy allows;
    /// * otherwise `max_retries + 1` for GET/HEAD/OPTIONS or when an
    ///   idempotency key is present, and a single attempt for everything else.
    ///
    /// A response is retried when its status is 429 or 5xx; a transport error
    /// is retried when `is_retryable()` says so. Between attempts the client
    /// waits for the response's `Retry-After` if it parses, else a backoff
    /// delay (when the strategy uses backoff), else not at all.
    ///
    /// Returns the last response (even a failing one) once the budget is
    /// exhausted, or `Error::Network` for a final transport error.
    async fn execute_with_retry(
        &self,
        req: HttpRequest,
        method: &Method,
        opts: Option<&RequestOptions>,
    ) -> Result<crate::transport::HttpResponse, Error> {
        let strategy = opts.and_then(|o| o.strategy.as_ref());
        let has_idempotency_key = opts.is_some_and(|o| {
            o.idempotency_key.is_some()
                || o.strategy
                    .as_ref()
                    .and_then(|s| s.idempotency_key())
                    .is_some()
        });

        let max_attempts = match strategy {
            Some(s) => s.max_attempts(self.inner.max_retries),
            None if is_safe_method(method) || has_idempotency_key => {
                self.inner.max_retries.saturating_add(1)
            }
            None => 1,
        };
        // Without an explicit strategy the client backs off with jitter.
        let use_backoff = strategy.map_or(true, |s| s.use_backoff());
        let jitter = strategy.map_or(true, |s| s.jitter());

        let mut attempt: u32 = 0;
        loop {
            let outcome = self.inner.transport.execute(req.clone()).await;
            // Number of attempts made so far, including the one that just ran.
            let attempts_used = attempt + 1;

            // Each arm either returns the final outcome or yields the delay
            // to wait before the next attempt.
            let delay = match outcome {
                Ok(resp) => {
                    let code = resp.status.as_u16();
                    let retryable = code == 429 || (500..=599).contains(&code);
                    if !(retryable && attempts_used < max_attempts) {
                        return Ok(resp);
                    }
                    match retry_after(&resp.headers) {
                        Some(server_delay) => server_delay,
                        None if use_backoff => backoff_delay(attempts_used, jitter),
                        None => Duration::ZERO,
                    }
                }
                Err(e) => {
                    if !(e.is_retryable() && attempts_used < max_attempts) {
                        return Err(Error::Network(e));
                    }
                    if use_backoff {
                        backoff_delay(attempts_used, jitter)
                    } else {
                        Duration::ZERO
                    }
                }
            };

            attempt = attempts_used;
            if !delay.is_zero() {
                tokio::time::sleep(delay).await;
            }
        }
    }
}
/// Returns `true` for GET, HEAD and OPTIONS — the methods the client retries
/// automatically even without an idempotency key.
fn is_safe_method(m: &Method) -> bool {
    *m == Method::GET || *m == Method::HEAD || *m == Method::OPTIONS
}
/// Reads the `Retry-After` header as a whole number of seconds.
///
/// Returns `None` when the header is absent, is not visible ASCII, or does not
/// parse as a `u64` after trimming (so the HTTP-date form is not understood).
fn retry_after(headers: &HeaderMap) -> Option<Duration> {
    let value = headers.get(RETRY_AFTER)?;
    let text = value.to_str().ok()?;
    let seconds: u64 = text.trim().parse().ok()?;
    Some(Duration::from_secs(seconds))
}
/// Computes the wait before the next attempt.
///
/// The un-jittered delay is `100ms * 2^min(attempt, 6)`, capped at 5000ms.
/// With `jitter`, the result is drawn uniformly from the upper half of that
/// window: `[delay / 2, delay / 2 + delay / 2]`.
fn backoff_delay(attempt: u32, jitter: bool) -> Duration {
    const BASE_MS: u64 = 100;
    const MAX_EXPONENT: u32 = 6;
    const CEILING_MS: u64 = 5_000;

    let exponent = attempt.min(MAX_EXPONENT);
    let delay_ms = BASE_MS.saturating_mul(1u64 << exponent).min(CEILING_MS);

    let millis = if jitter {
        let floor = delay_ms / 2;
        let spread: u64 = rand::rng().random_range(0..=floor);
        floor + spread
    } else {
        delay_ms
    };
    Duration::from_millis(millis)
}
impl ClientBuilder {
    /// Sets the API key. A non-empty key is sent as
    /// `Authorization: Bearer <key>` on every request.
    pub fn api_key(mut self, key: impl Into<String>) -> Self {
        self.api_key = Some(key.into());
        self
    }

    /// Sets the client id exposed through `Client::client_id`.
    pub fn client_id(mut self, id: impl Into<String>) -> Self {
        self.client_id = Some(id.into());
        self
    }

    /// Overrides the base URL (default: [`DEFAULT_BASE_URL`]). Request paths
    /// are appended to it verbatim.
    pub fn base_url(mut self, url: impl Into<String>) -> Self {
        self.base_url = Some(url.into());
        self
    }

    /// Sets the timeout handed to the default transport (default: 30 seconds).
    /// It is not consulted when a custom transport is supplied.
    pub fn timeout(mut self, timeout: Duration) -> Self {
        self.timeout = Some(timeout);
        self
    }

    /// Sets the retry count for automatically retried requests (default: 3).
    pub fn max_retries(mut self, max: u32) -> Self {
        self.max_retries = Some(max);
        self
    }

    /// Overrides the `User-Agent` header (default: `workos-rust`).
    pub fn user_agent(mut self, ua: impl Into<String>) -> Self {
        self.user_agent = Some(ua.into());
        self
    }

    /// Supplies a custom transport instead of the built-in default.
    pub fn transport(mut self, transport: SharedTransport) -> Self {
        self.transport = Some(transport);
        self
    }

    /// Builds the client.
    ///
    /// # Panics
    ///
    /// Panics with `ClientBuilder::build: <error>` when
    /// [`ClientBuilder::try_build`] fails, and under the same conditions in
    /// which `try_build` itself panics.
    pub fn build(self) -> Client {
        match self.try_build() {
            Ok(c) => c,
            Err(e) => panic!("ClientBuilder::build: {e}"),
        }
    }

    /// Builds the client, reporting invalid configuration as an error.
    ///
    /// # Errors
    ///
    /// Returns `Error::Builder` when the API key or the user agent cannot be
    /// represented as an HTTP header value.
    ///
    /// # Panics
    ///
    /// Panics when no transport was supplied and the crate was compiled
    /// without the `native-tls` / `rustls-tls` features, because no default
    /// transport exists in that configuration.
    pub fn try_build(self) -> Result<Client, Error> {
        let api_key = self.api_key.unwrap_or_default();
        let timeout = self.timeout.unwrap_or(DEFAULT_TIMEOUT);

        let mut headers = HeaderMap::new();
        if !api_key.is_empty() {
            let mut bearer = HeaderValue::from_str(&format!("Bearer {api_key}"))
                .map_err(|e| Error::Builder(format!("invalid API key: {e}")))?;
            // Mark the credential as sensitive so it is redacted from `Debug`
            // output of the header map (e.g. in logs). This does not change
            // the bytes sent or header equality.
            bearer.set_sensitive(true);
            headers.insert(AUTHORIZATION, bearer);
        }

        let ua = self.user_agent.as_deref().unwrap_or("workos-rust");
        let ua_value = HeaderValue::from_str(ua)
            .map_err(|e| Error::Builder(format!("invalid user-agent: {e}")))?;
        headers.insert(USER_AGENT, ua_value);

        // The timeout only applies to the built-in transport; a caller-supplied
        // transport is used as-is.
        let transport = self.transport.unwrap_or_else(|| default_transport(timeout));

        Ok(Client {
            inner: Arc::new(ClientInner {
                transport,
                base_url: self
                    .base_url
                    .unwrap_or_else(|| DEFAULT_BASE_URL.to_string()),
                max_retries: self.max_retries.unwrap_or(3),
                api_key,
                client_id: self.client_id.unwrap_or_default(),
                default_headers: headers,
            }),
        })
    }
}
/// Builds the built-in reqwest-backed transport configured with `timeout`.
#[cfg(any(feature = "native-tls", feature = "rustls-tls"))]
fn default_transport(timeout: Duration) -> SharedTransport {
    let transport = crate::transport::ReqwestTransport::with_timeout(timeout);
    Arc::new(transport)
}

/// Fallback used when no TLS feature is compiled in: there is no built-in
/// transport, so this always panics with instructions for the caller.
#[cfg(not(any(feature = "native-tls", feature = "rustls-tls")))]
fn default_transport(_timeout: Duration) -> SharedTransport {
    const MESSAGE: &str = "no HTTP transport configured: build with --features rustls-tls (or native-tls), or supply one via ClientBuilder::transport(...)";
    panic!("{}", MESSAGE);
}