use crate::auth::{AuthStrategy, TokenProvider};
use crate::config::{PortConfig, PortRegion, RetryConfig, TelemetryConfig};
use crate::error::PortError;
use crate::tracking::{new_shared_tracker, ResourceTrackerHandle};
#[cfg(feature = "retry")]
use backoff::backoff::Backoff;
#[cfg(feature = "retry")]
use backoff::ExponentialBackoff;
use httpdate::parse_http_date;
use reqwest::header::HeaderValue;
use reqwest::{Client, Method, Proxy, Request, Response, Url};
use serde::de::DeserializeOwned;
use serde::Serialize;
use std::sync::Arc;
use std::time::{Duration, Instant, SystemTime};
#[cfg(feature = "tracing")]
use tracing::{Instrument, Span};
/// Optional pagination parameters added to a request's query string by
/// `PortClient::get_paginated`.
///
/// Field names are serialized in camelCase (`per_page` becomes `perPage`) and
/// unset fields are omitted entirely. Construct values with [`Pagination::builder`].
#[derive(Debug, Clone, Default, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct Pagination {
    /// Value for the `page` query parameter, if set.
    #[serde(skip_serializing_if = "Option::is_none")]
    page: Option<u32>,
    /// Value for the `perPage` query parameter, if set.
    #[serde(skip_serializing_if = "Option::is_none")]
    per_page: Option<u32>,
    /// Value for the `cursor` query parameter, if set.
    #[serde(skip_serializing_if = "Option::is_none")]
    cursor: Option<String>,
}
impl Pagination {
pub fn builder() -> PaginationBuilder {
PaginationBuilder::default()
}
pub fn is_empty(&self) -> bool {
self.page.is_none() && self.per_page.is_none() && self.cursor.is_none()
}
}
#[cfg(feature = "tracing")]
impl PortClient {
    /// Opens an `info`-level span describing an outgoing request.
    ///
    /// Returns a disabled span (`Span::none()`) when tracing is switched off in the
    /// client's telemetry settings, so callers can instrument unconditionally.
    /// NOTE(review): the `path` field records `request.url()`, i.e. the full URL
    /// including host and query string, not only the path.
    fn start_request_span(&self, request: &Request) -> Span {
        if self.telemetry.enable_tracing {
            tracing::info_span!(
                "port_sdk.request",
                method = %request.method(),
                path = %request.url(),
                retry_enabled = self.retry_policy.is_some()
            )
        } else {
            Span::none()
        }
    }

    /// Emits a completion event inside `span` carrying the response status.
    ///
    /// Success statuses are logged at `info`, all others at `warn`. Nothing is
    /// emitted when tracing is disabled or `span` itself is disabled.
    fn finish_request_span(&self, span: &Span, response: &Response) {
        let should_emit = self.telemetry.enable_tracing && !span.is_none();
        if !should_emit {
            return;
        }
        let status = response.status();
        span.in_scope(|| {
            if status.is_success() {
                tracing::info!(status = %status, "request completed");
            } else {
                tracing::warn!(status = %status, "request completed with failure");
            }
        });
    }
}
/// Fluent builder for [`Pagination`]; obtain one via [`Pagination::builder`] or
/// [`PaginationBuilder::new`].
#[derive(Debug, Default)]
pub struct PaginationBuilder {
    /// Parameters accumulated so far; returned unchanged by `build`.
    inner: Pagination,
}
impl PaginationBuilder {
pub fn new() -> Self {
Self::default()
}
pub fn page(mut self, page: u32) -> Self {
self.inner.page = Some(page);
self
}
pub fn per_page(mut self, per_page: u32) -> Self {
self.inner.per_page = Some(per_page);
self
}
pub fn cursor(mut self, cursor: impl Into<String>) -> Self {
self.inner.cursor = Some(cursor.into());
self
}
pub fn build(self) -> Pagination {
self.inner
}
}
/// Authenticated HTTP client for the Port API.
///
/// Construct one with [`PortClient::builder`], [`PortClient::from_config`] or
/// [`PortClient::from_env`].
#[derive(Clone)]
pub struct PortClient {
    /// Underlying `reqwest` client used to send every request.
    http: Client,
    /// Base URL that request paths are joined onto.
    base_url: Url,
    /// Supplies the bearer token attached to each request; shared between clones.
    token_provider: Arc<dyn TokenProvider>,
    /// Retry behaviour; `None` sends each request once. Retries only happen when
    /// the crate's `retry` feature is enabled.
    retry_policy: Option<RetryPolicy>,
    /// Records resource creations and deletions reported by callers.
    tracker: ResourceTrackerHandle,
    /// Telemetry settings, e.g. whether request spans are emitted.
    telemetry: TelemetryConfig,
}
impl std::fmt::Debug for PortClient {
    /// Formats a reduced view of the client.
    ///
    /// Only the base URL, retry policy and tracing flag are printed; the HTTP
    /// client, token provider and tracker are not included.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self {
            base_url,
            retry_policy,
            telemetry,
            ..
        } = self;
        let tracing_enabled = telemetry.enable_tracing;
        let mut view = f.debug_struct("PortClient");
        view.field("base_url", base_url)
            .field("retry_policy", retry_policy)
            .field("tracing", &tracing_enabled);
        view.finish()
    }
}
impl PortClient {
pub fn builder() -> PortClientBuilder {
PortClientBuilder::default()
}
pub fn from_config(config: PortConfig) -> Result<Self, PortError> {
PortClientBuilder::from_config(config).build()
}
pub fn from_env() -> Result<Self, PortError> {
let config = PortConfig::from_env()?;
Self::from_config(config)
}
pub fn base_url(&self) -> &Url {
&self.base_url
}
pub fn tracker(&self) -> ResourceTrackerHandle {
Arc::clone(&self.tracker)
}
pub fn telemetry(&self) -> &TelemetryConfig {
&self.telemetry
}
pub fn record_creation(&self, resource_type: &str, identifier: &str) {
self.tracker.record_creation(resource_type.to_string(), identifier.to_string());
}
pub fn record_deletion(&self, resource_type: &str, identifier: &str) {
self.tracker.record_deletion(resource_type.to_string(), identifier.to_string());
}
async fn authenticated_request(
&self,
builder: reqwest::RequestBuilder,
) -> Result<Request, PortError> {
let token = self.token_provider.bearer_token().await?;
let request = builder.bearer_auth(token).build()?;
Ok(request)
}
async fn execute<T>(&self, request: Request) -> Result<T, PortError>
where
T: DeserializeOwned,
{
#[cfg(feature = "retry")]
{
if let Some(policy) = &self.retry_policy {
return self.execute_with_retry(request, policy).await;
}
}
self.execute_once(request).await
}
async fn execute_once<T>(&self, request: Request) -> Result<T, PortError>
where
T: DeserializeOwned,
{
#[cfg(feature = "tracing")]
let span = self.start_request_span(&request);
#[cfg(feature = "tracing")]
let response = self.http.execute(request).instrument(span.clone()).await?;
#[cfg(not(feature = "tracing"))]
let response = self.http.execute(request).await?;
#[cfg(feature = "tracing")]
self.finish_request_span(&span, &response);
Self::deserialize_response(response).await
}
async fn deserialize_response<T>(response: Response) -> Result<T, PortError>
where
T: DeserializeOwned,
{
let status = response.status();
let headers = response.headers().clone();
let body_bytes = response.bytes().await?;
if !status.is_success() {
let message = String::from_utf8_lossy(&body_bytes).trim().to_string();
return Err(PortError::api(status.as_u16(), message, headers));
}
if body_bytes.is_empty() {
return Ok(serde_json::from_str("null")?);
}
Ok(serde_json::from_slice(&body_bytes)?)
}
async fn request_json<T, F>(
&self,
method: Method,
path: &str,
configure: F,
) -> Result<T, PortError>
where
T: DeserializeOwned,
F: Fn(reqwest::RequestBuilder) -> reqwest::RequestBuilder,
{
self.request_with(method, path, configure, |builder| builder).await
}
async fn request_with<T, F, Q>(
&self,
method: Method,
path: &str,
configure: F,
extra: Q,
) -> Result<T, PortError>
where
T: DeserializeOwned,
F: Fn(reqwest::RequestBuilder) -> reqwest::RequestBuilder,
Q: Fn(reqwest::RequestBuilder) -> reqwest::RequestBuilder,
{
let url = self.base_url.join(path)?;
let builder = self.http.request(method, url);
let builder = configure(builder);
let builder = extra(builder);
let request = self.authenticated_request(builder).await?;
self.execute(request).await
}
pub async fn get<T>(&self, path: &str) -> Result<T, PortError>
where
T: DeserializeOwned,
{
self.request_json(Method::GET, path, |builder| builder).await
}
pub async fn get_with_query<T, Q>(&self, path: &str, query: &Q) -> Result<T, PortError>
where
T: DeserializeOwned,
Q: Serialize + ?Sized,
{
self.request_json(Method::GET, path, |builder| builder.query(query)).await
}
pub async fn get_paginated<T, Q>(
&self,
path: &str,
query: &Q,
pagination: &Pagination,
) -> Result<T, PortError>
where
T: DeserializeOwned,
Q: Serialize + ?Sized,
{
self.request_json(Method::GET, path, |builder| {
let builder = builder.query(query);
if pagination.is_empty() {
builder
} else {
builder.query(pagination)
}
})
.await
}
pub async fn post<B, T>(&self, path: &str, body: &B) -> Result<T, PortError>
where
B: Serialize + ?Sized,
T: DeserializeOwned,
{
self.request_json(Method::POST, path, |builder| builder.json(body)).await
}
pub async fn put<B, T>(&self, path: &str, body: &B) -> Result<T, PortError>
where
B: Serialize + ?Sized,
T: DeserializeOwned,
{
self.request_json(Method::PUT, path, |builder| builder.json(body)).await
}
pub async fn patch<B, T>(&self, path: &str, body: &B) -> Result<T, PortError>
where
B: Serialize + ?Sized,
T: DeserializeOwned,
{
self.request_json(Method::PATCH, path, |builder| builder.json(body)).await
}
pub async fn delete<T>(&self, path: &str) -> Result<T, PortError>
where
T: DeserializeOwned,
{
self.request_json(Method::DELETE, path, |builder| builder).await
}
pub async fn delete_with_query<T, Q>(&self, path: &str, query: &Q) -> Result<T, PortError>
where
T: DeserializeOwned,
Q: Serialize + ?Sized,
{
self.request_json(Method::DELETE, path, |builder| builder.query(query)).await
}
#[cfg(feature = "retry")]
async fn execute_with_retry<T>(
&self,
request: Request,
policy: &RetryPolicy,
) -> Result<T, PortError>
where
T: DeserializeOwned,
{
let start = Instant::now();
let mut attempts = 0;
let mut backoff = policy.to_backoff();
let request_template = request;
loop {
attempts += 1;
let attempt_request = request_template.try_clone().ok_or_else(|| {
PortError::Configuration("request body could not be cloned for retry".into())
})?;
match self.execute_once(attempt_request).await {
Ok(value) => return Ok(value),
Err(err) => {
if !policy.should_retry(&err, attempts) {
return Err(err);
}
let delay = match policy.next_delay(&err, &mut backoff) {
Some(delay) => delay,
None => return Err(err),
};
tokio::time::sleep(delay).await;
if let Some(max_elapsed) = policy.max_elapsed_time {
if start.elapsed() >= max_elapsed {
return Err(err);
}
}
}
}
}
}
}
/// Builder for [`PortClient`].
///
/// Every setting is optional except the authentication strategy, which `build`
/// requires. NOTE(review): the derived `Debug` prints `auth` and `proxy`; confirm
/// that `AuthStrategy`'s `Debug` redacts secrets, and note that a proxy URL may embed
/// credentials.
#[derive(Clone, Debug, Default)]
pub struct PortClientBuilder {
    /// Region whose default base URL is used when `base_url` is unset.
    region: PortRegion,
    /// Explicit API base URL; takes precedence over `region`.
    base_url: Option<Url>,
    /// Authentication strategy; `build` fails without it.
    auth: Option<AuthStrategy>,
    /// Proxy URL, applied only when `build` creates the default HTTP client.
    proxy: Option<String>,
    /// Request timeout for the default HTTP client (30 seconds when unset).
    timeout: Option<Duration>,
    /// Retry policy; `None` disables retries.
    retry: Option<RetryPolicy>,
    /// Pre-configured HTTP client; when set, `proxy` and `timeout` are ignored.
    http_client: Option<Client>,
    /// Shared resource tracker; `build` creates a new one when unset.
    tracker: Option<ResourceTrackerHandle>,
    /// Telemetry settings passed through to the client.
    telemetry: TelemetryConfig,
}
impl PortClientBuilder {
    /// Seeds a builder from a loaded [`PortConfig`].
    ///
    /// Retry settings, when present, are converted into a [`RetryPolicy`]. No custom
    /// HTTP client or tracker is set, so [`build`](Self::build) creates both.
    pub fn from_config(config: PortConfig) -> Self {
        Self {
            region: config.region,
            base_url: Some(config.base_url),
            auth: Some(config.auth),
            proxy: config.proxy,
            timeout: Some(config.timeout),
            retry: config.retry.map(RetryPolicy::from),
            http_client: None,
            tracker: None,
            telemetry: config.telemetry,
        }
    }

    /// Seeds a builder from [`PortConfig::from_env`].
    ///
    /// # Errors
    /// Returns the error produced while loading the environment configuration.
    pub fn from_env() -> Result<Self, PortError> {
        Ok(Self::from_config(PortConfig::from_env()?))
    }

    /// Selects the region whose default base URL is used when none is set explicitly.
    pub fn region(self, region: PortRegion) -> Self {
        Self { region, ..self }
    }

    /// Sets an explicit API base URL, overriding the region default.
    pub fn base_url(self, base_url: Url) -> Self {
        Self {
            base_url: Some(base_url),
            ..self
        }
    }

    /// Sets the authentication strategy (required by `build`).
    pub fn auth(self, auth: AuthStrategy) -> Self {
        Self {
            auth: Some(auth),
            ..self
        }
    }

    /// Routes traffic through `proxy`; ignored when a custom HTTP client is supplied.
    pub fn proxy(self, proxy: impl Into<String>) -> Self {
        Self {
            proxy: Some(proxy.into()),
            ..self
        }
    }

    /// Sets the request timeout; ignored when a custom HTTP client is supplied.
    pub fn timeout(self, timeout: Duration) -> Self {
        Self {
            timeout: Some(timeout),
            ..self
        }
    }

    /// Replaces the retry policy; `None` disables retries.
    pub fn retry(self, retry: Option<RetryPolicy>) -> Self {
        Self { retry, ..self }
    }

    /// Uses `client` as-is instead of building one (`proxy` and `timeout` are then ignored).
    pub fn http_client(self, client: Client) -> Self {
        Self {
            http_client: Some(client),
            ..self
        }
    }

    /// Shares an existing resource tracker instead of creating a new one.
    pub fn tracker(self, tracker: ResourceTrackerHandle) -> Self {
        Self {
            tracker: Some(tracker),
            ..self
        }
    }

    /// Replaces the telemetry settings.
    pub fn telemetry(self, telemetry: TelemetryConfig) -> Self {
        Self { telemetry, ..self }
    }

    /// Turns request tracing spans on or off.
    #[cfg(feature = "tracing")]
    pub fn enable_tracing(mut self, enable: bool) -> Self {
        self.telemetry.enable_tracing = enable;
        self
    }

    /// Validates the settings and constructs the [`PortClient`].
    ///
    /// The base URL is the explicit one if set, otherwise the region's default. Without
    /// a custom HTTP client, one is created with the configured timeout (30 seconds by
    /// default) and proxy; a supplied client is used unchanged.
    ///
    /// # Errors
    /// Fails when the region's default base URL does not parse, no authentication
    /// strategy was set, the strategy cannot produce a token provider, the proxy is
    /// invalid, or the HTTP client cannot be built.
    pub fn build(self) -> Result<PortClient, PortError> {
        let base_url = if let Some(explicit) = self.base_url {
            explicit
        } else {
            Url::parse(self.region.base_url())?
        };
        let token_provider = self
            .auth
            .ok_or_else(|| {
                PortError::Configuration("authentication strategy missing for PortClient".into())
            })?
            .into_provider()?;
        let tracker = self.tracker.unwrap_or_else(new_shared_tracker);
        let timeout = self.timeout.unwrap_or_else(|| Duration::from_secs(30));
        let http = match self.http_client {
            Some(custom) => custom,
            None => Self::default_http_client(timeout, self.proxy.as_deref())?,
        };
        Ok(PortClient {
            http,
            base_url,
            token_provider,
            retry_policy: self.retry,
            tracker,
            telemetry: self.telemetry,
        })
    }

    /// Creates the default `reqwest` client with `timeout` and, if given, a proxy.
    fn default_http_client(timeout: Duration, proxy: Option<&str>) -> Result<Client, PortError> {
        let builder = Client::builder().timeout(timeout);
        let builder = match proxy {
            Some(proxy_url) => builder.proxy(build_proxy(proxy_url)?),
            None => builder,
        };
        Ok(builder.build()?)
    }
}
/// Builds a proxy that routes all traffic through `proxy_url`.
///
/// If both `PROXY_AUTH_USERNAME` and `PROXY_AUTH_PASSWORD` can be read from the
/// environment and at least one of them is non-empty, they are attached as basic-auth
/// credentials.
///
/// # Errors
/// Returns `PortError::Configuration` when `reqwest::Proxy::all` rejects `proxy_url`.
/// NOTE(review): the message includes `proxy_url` verbatim, so credentials embedded
/// in the URL would appear in it.
fn build_proxy(proxy_url: &str) -> Result<Proxy, PortError> {
    let base = Proxy::all(proxy_url).map_err(|err| {
        PortError::Configuration(format!("failed to configure proxy {proxy_url}: {err}"))
    })?;
    let credentials = (
        std::env::var("PROXY_AUTH_USERNAME"),
        std::env::var("PROXY_AUTH_PASSWORD"),
    );
    let configured = match credentials {
        (Ok(user), Ok(secret)) if !(user.is_empty() && secret.is_empty()) => {
            base.basic_auth(&user, &secret)
        }
        _ => base,
    };
    Ok(configured)
}
/// Runtime retry policy used by `PortClient` when the `retry` feature is enabled.
///
/// Usually built from a [`RetryConfig`] via `From`.
#[derive(Clone, Debug)]
pub struct RetryPolicy {
    /// Maximum total attempts, including the first (clamped to at least 1 by `From<RetryConfig>`).
    max_attempts: u32,
    /// Overall time budget for all attempts; `None` imposes no limit.
    max_elapsed_time: Option<Duration>,
    /// Starting backoff delay used when the server gives no `Retry-After`.
    initial_interval: Duration,
    /// Backoff growth factor passed to the exponential backoff.
    multiplier: f64,
    /// Upper bound for computed backoff delays (a `Retry-After` value is not capped).
    max_interval: Duration,
    /// Status codes of API errors that are retried.
    retry_on_statuses: Vec<u16>,
}
impl From<RetryConfig> for RetryPolicy {
fn from(value: RetryConfig) -> Self {
RetryPolicy {
max_attempts: value.max_attempts.max(1),
max_elapsed_time: value.max_elapsed_time,
initial_interval: value.initial_interval,
multiplier: value.multiplier,
max_interval: value.max_interval,
retry_on_statuses: value.retry_on_statuses,
}
}
}
impl RetryPolicy {
fn should_retry(&self, error: &PortError, attempt: u32) -> bool {
if attempt >= self.max_attempts {
return false;
}
match error {
PortError::Http(_) => true,
PortError::Api { status, .. } => self.retry_on_statuses.contains(status),
_ => false,
}
}
#[cfg(feature = "retry")]
fn next_delay(&self, error: &PortError, backoff: &mut ExponentialBackoff) -> Option<Duration> {
if let Some(duration) = self.retry_after(error) {
return Some(duration);
}
backoff.next_backoff()
}
fn retry_after(&self, error: &PortError) -> Option<Duration> {
match error {
PortError::Api { headers, .. } => {
headers.get("retry-after").and_then(parse_retry_after_header)
}
_ => None,
}
}
#[cfg(feature = "retry")]
fn to_backoff(&self) -> ExponentialBackoff {
let mut backoff = ExponentialBackoff::default();
backoff.max_elapsed_time = self.max_elapsed_time;
backoff.current_interval = self.initial_interval;
backoff.initial_interval = self.initial_interval;
backoff.multiplier = self.multiplier;
backoff.max_interval = self.max_interval;
backoff.randomization_factor = 0.0;
backoff
}
}
/// Parses a `Retry-After` header value into a delay.
///
/// Accepts both forms HTTP allows: an integer number of seconds, or an HTTP-date.
/// A date that is now or already past yields a zero delay. Surrounding whitespace is
/// ignored. Returns `None` when the value is not visible ASCII or matches neither form.
fn parse_retry_after_header(value: &HeaderValue) -> Option<Duration> {
    // Tolerate optional whitespace around the field value (RFC 9110 §5.5).
    let text = value.to_str().ok()?.trim();
    if let Ok(seconds) = text.parse::<u64>() {
        return Some(Duration::from_secs(seconds));
    }
    let target_time = parse_http_date(text).ok()?;
    // `duration_since` errors when the date has already passed: retry immediately.
    Some(
        target_time
            .duration_since(SystemTime::now())
            .unwrap_or(Duration::ZERO),
    )
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Delay-seconds form is taken literally.
    #[test]
    fn parse_retry_after_seconds() {
        let header = HeaderValue::from_static("3");
        let duration = parse_retry_after_header(&header).expect("duration");
        assert_eq!(duration.as_secs(), 3);
    }

    /// An HTTP-date in the past means "retry now".
    #[test]
    fn parse_retry_after_past_http_date_is_zero() {
        let header = HeaderValue::from_static("Sun, 06 Nov 1994 08:49:37 GMT");
        let duration = parse_retry_after_header(&header).expect("duration");
        assert_eq!(duration, Duration::from_secs(0));
    }

    /// An HTTP-date in the future yields roughly the remaining time (second precision).
    #[test]
    fn parse_retry_after_future_http_date() {
        let target = SystemTime::now() + Duration::from_secs(120);
        let formatted = httpdate::fmt_http_date(target);
        let header = HeaderValue::from_str(&formatted).expect("valid header value");
        let duration = parse_retry_after_header(&header).expect("duration");
        assert!(duration <= Duration::from_secs(120), "{duration:?}");
        assert!(duration >= Duration::from_secs(110), "{duration:?}");
    }

    /// Values matching neither form are rejected.
    #[test]
    fn parse_retry_after_rejects_garbage() {
        let header = HeaderValue::from_static("soon");
        assert!(parse_retry_after_header(&header).is_none());
    }
}