use std::time::Duration;
use async_trait::async_trait;
use reqwest::header::{HeaderMap, HeaderName, HeaderValue};
use reqwest::multipart::Form;
use reqwest::{Method, Url};
use serde::Serialize;
use tracing::Instrument;
use uuid::Uuid;
pub use body::{Body, GraphQLBody, MultipartFile, MultipartFormFileField};
pub use request_type::{GraphQLMultipart, GraphQLRequest, Request, RestMultipart, RestRequest};
use crate::errors::{PrimaBridgeError, PrimaBridgeResult};
use crate::sealed::Sealed;
use crate::{BridgeClient, BridgeImpl, PrimaRequestBuilder, PrimaRequestBuilderInner, Response};
mod body;
mod request_type;
#[cfg(all(feature = "grpc", feature = "_any_otel_version"))]
pub mod grpc;
#[cfg(feature = "_any_otel_version")]
mod otel;
#[cfg(feature = "tracing_opentelemetry")]
use otel::otel_crates::tracing_opentelemetry::OpenTelemetrySpanExt;
/// Kind of API a request targets.
///
/// `DeliverableRequest::send_request` uses it to decide whether the reply is wrapped with
/// `Response::rest` or `Response::graphql`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RequestType {
    /// A plain REST request.
    Rest,
    /// A GraphQL request.
    // `GraphQL` is the technology's canonical spelling; keep it instead of the
    // camel-cased form clippy's acronym lint would suggest.
    #[allow(clippy::upper_case_acronyms)]
    GraphQL,
}
/// Payload a request is delivered with, as returned by `DeliverableRequest::into_body`.
#[derive(Default)]
pub enum DeliverableRequestBody {
    /// No body is attached; this is the `Default` variant.
    #[default]
    Empty,
    /// Raw bytes, handed to the request builder's `body` method.
    RawBody(Body),
    /// A multipart form, handed to the request builder's `multipart` method.
    Multipart(Form),
}
/// Behaviour shared by every request type that can be delivered through a bridge.
///
/// The trait is sealed: implementors inside this crate supply the request-specific
/// accessors, while the default methods build the final URL and header set and perform the
/// HTTP exchange inside a `prima_bridge.http.client` tracing span.
#[async_trait]
pub trait DeliverableRequest<'a>: Sized + Sealed + 'a {
    /// HTTP client type of the bridge the request is sent through.
    type Client: BridgeClient;

    /// Sets `body` as the raw request payload.
    fn raw_body(self, body: impl Into<Body>) -> Self;

    /// Sets a JSON-serialized `body` as the request payload.
    ///
    /// # Errors
    ///
    /// Implementation-defined; presumably returned when `body` cannot be serialized.
    fn json_body<B: Serialize>(self, body: &B) -> PrimaBridgeResult<Self>;

    /// Sets the HTTP method; expected to be what `get_method` later returns.
    fn method(self, method: Method) -> Self;

    /// Sets the path that `get_url` appends to the endpoint; expected to be what
    /// `get_path` later returns.
    fn to(self, path: &'a str) -> Self;

    /// Makes non-2xx responses succeed instead of failing with `WrongStatusCode`;
    /// expected to make `get_ignore_status_code` return `true`.
    fn ignore_status_code(self) -> Self;

    /// Sets the per-request timeout; expected to be what `get_timeout` later returns.
    fn set_timeout(self, timeout: Duration) -> Self;

    /// Timeout applied to the underlying request builder by `send`.
    fn get_timeout(&self) -> Duration;

    /// Inserts a single custom header, replacing any existing values for `name`.
    fn with_custom_header(mut self, name: HeaderName, value: HeaderValue) -> Self {
        self.get_custom_headers_mut().insert(name, value);
        self
    }

    /// Adds every `(name, value)` pair to the custom headers via `HeaderMap::extend`.
    fn with_custom_headers(mut self, headers: Vec<(HeaderName, HeaderValue)>) -> Self {
        self.get_custom_headers_mut().extend(headers);
        self
    }

    /// Appends one query pair; `get_url` adds the pairs to the URL in insertion order.
    fn with_query_pair(mut self, name: &'a str, value: &'a str) -> Self {
        self.get_query_pairs_mut().push((name, value));
        self
    }

    /// Appends several query pairs, keeping their order.
    fn with_query_pairs(mut self, pairs: Vec<(&'a str, &'a str)>) -> Self {
        self.get_query_pairs_mut().extend(pairs);
        self
    }

    /// Identifier sent as the `x-request-id` header and recorded on the client span.
    fn get_id(&self) -> Uuid;

    #[doc(hidden)]
    fn get_bridge(&self) -> &BridgeImpl<Self::Client>;

    #[doc(hidden)]
    fn get_path(&self) -> Option<&str>;

    /// Base URL that `get_url` extends with the path and query pairs.
    #[doc(hidden)]
    fn endpoint(&self) -> Url;

    #[doc(hidden)]
    fn get_query_pairs(&self) -> &[(&'a str, &'a str)];

    #[doc(hidden)]
    fn get_query_pairs_mut(&mut self) -> &mut Vec<(&'a str, &'a str)>;

    #[doc(hidden)]
    fn get_ignore_status_code(&self) -> bool;

    #[doc(hidden)]
    fn get_method(&self) -> Method;

    #[doc(hidden)]
    fn get_custom_headers(&self) -> &HeaderMap;

    #[doc(hidden)]
    fn get_custom_headers_mut(&mut self) -> &mut HeaderMap;

    #[cfg(feature = "auth0")]
    #[doc(hidden)]
    fn get_auth0(&self) -> &Option<crate::auth0::RefreshingToken>;

    /// `Authorization` header whose value is the auth0 token's `to_bearer()` string, or an
    /// empty map when `get_auth0` is `None`.
    ///
    /// # Panics
    ///
    /// Panics if the bearer string is not a valid header value.
    #[cfg(feature = "auth0")]
    #[doc(hidden)]
    fn get_auth0_headers(&self) -> HeaderMap {
        match self.get_auth0().as_ref().map(|auth0| auth0.token()) {
            None => HeaderMap::new(),
            Some(token) => {
                let mut header_map: HeaderMap = HeaderMap::new();
                let header_value: HeaderValue = HeaderValue::from_str(token.to_bearer().as_str())
                    .expect("Failed to create bearer header");
                header_map.append(reqwest::header::AUTHORIZATION, header_value);
                header_map
            }
        }
    }

    /// Every header to attach to the outgoing request.
    ///
    /// Starts from the custom headers, then merges the tracing propagation headers (otel
    /// features) and the auth0 header (`auth0` feature). Because another `HeaderMap` is
    /// merged with `HeaderMap::extend`, a later source replaces an earlier one's values
    /// for the same header name.
    fn get_all_headers(&self) -> HeaderMap {
        let mut additional_headers = self.get_custom_headers().clone();
        #[cfg(feature = "_any_otel_version")]
        additional_headers.extend(self.tracing_headers());
        #[cfg(feature = "auth0")]
        additional_headers.extend(self.get_auth0_headers());
        additional_headers
    }

    #[doc(hidden)]
    fn get_request_type(&self) -> RequestType;

    /// Consumes the request and yields the payload to deliver.
    #[doc(hidden)]
    fn into_body(self) -> PrimaBridgeResult<DeliverableRequestBody>;

    /// Request body bytes, if the implementor exposes them.
    fn get_body(&self) -> Option<&[u8]>;

    /// Sends the request and returns the [`Response`].
    ///
    /// The exchange runs inside a `prima_bridge.http.client` span. The span carries:
    /// - the HTTP method;
    /// - the server address and port;
    /// - the URL with credentials stripped;
    /// - the scheme;
    /// - the request id;
    /// - the response status code, once known.
    ///
    /// An `x-request-id` header with `get_id()` is set, followed by the headers from
    /// `get_all_headers`.
    ///
    /// # Errors
    ///
    /// Propagates any error from `send_request`. With the `tracing_opentelemetry` feature,
    /// the error is also reported as the span's status.
    async fn send(self) -> PrimaBridgeResult<Response> {
        let request_id = self.get_id();
        let url = self.get_url();
        let method = self.get_method();
        let client_span = tracing::info_span!(
            "prima_bridge.http.client",
            "otel.kind" = "client",
            "otel.name" = %method.as_str(),
            "http.request.method" = %method.as_str(),
            // Must be declared up front: `tracing` silently ignores `record` calls for
            // fields that were not part of the span's metadata at creation, so without
            // this `send_request` could never attach the status code.
            "http.response.status_code" = tracing::field::Empty,
            "server.address" = %url.host().map(|h| h.to_string()).unwrap_or_default(),
            "server.port" = %url.port_or_known_default().map(|p| p.to_string()).unwrap_or_default(),
            "url.full" = %strip_url_credentials(&url),
            "url.scheme" = %url.scheme(),
            request_id = %request_id
        );
        // Collect headers inside the client span so the context injected by
        // `tracing_headers` presumably refers to this span rather than its parent.
        #[cfg(feature = "_any_otel_version")]
        let headers = client_span.in_scope(|| self.get_all_headers());
        #[cfg(not(feature = "_any_otel_version"))]
        let headers = self.get_all_headers();
        #[cfg(feature = "tracing_opentelemetry")]
        client_span.set_status(otel::otel_crates::opentelemetry::trace::Status::Unset);
        // Span fields were captured at creation, so `url` can be moved from here on.
        let request_builder = self
            .get_bridge()
            .inner_client
            .request(method, url)
            .timeout(self.get_timeout())
            .header(HeaderName::from_static("x-request-id"), &request_id.to_string())
            .headers(headers);
        let result = self.send_request(request_builder).instrument(client_span.clone()).await;
        #[cfg(feature = "tracing_opentelemetry")]
        if let Err(ref reason) = result {
            client_span.set_status(otel::otel_crates::opentelemetry::trace::Status::Error {
                description: reason.to_string().into(),
            });
        }
        result
    }

    /// Attaches the body to `request`, sends it, and wraps the reply in a [`Response`].
    ///
    /// The status code is recorded on the current span under
    /// `http.response.status_code`. With `tracing_opentelemetry`, 4xx/5xx responses also
    /// mark the span as an error.
    ///
    /// # Errors
    ///
    /// - `WrongStatusCode` when the status is not 2xx and `get_ignore_status_code()` is
    ///   `false`.
    /// - `HttpError` when reading the response body fails.
    /// - Any error from `into_body` or from sending, propagated with `?`.
    async fn send_request<T>(self, request: PrimaRequestBuilder<T>) -> PrimaBridgeResult<Response>
    where
        T: PrimaRequestBuilderInner,
    {
        let request_id = self.get_id();
        let url = self.get_url();
        let ignore_status_code = self.get_ignore_status_code();
        let request_type = self.get_request_type();
        let response = match self.into_body()? {
            DeliverableRequestBody::Empty => request,
            DeliverableRequestBody::RawBody(body) => request.body(body.inner),
            DeliverableRequestBody::Multipart(form) => request.multipart(form),
        }
        .send()
        .await?;
        let status_code = response.status();
        // When called from `send`, this is the instrumented client span.
        let span = tracing::Span::current();
        span.record("http.response.status_code", status_code.as_u16());
        #[cfg(feature = "tracing_opentelemetry")]
        if status_code.is_client_error() || status_code.is_server_error() {
            span.set_status(otel::otel_crates::opentelemetry::trace::Status::Error { description: "".into() });
        }
        if !ignore_status_code && !status_code.is_success() {
            return Err(PrimaBridgeError::WrongStatusCode(url, status_code));
        }
        let response_headers = response.headers().clone();
        let body = match response.bytes().await {
            Ok(bytes) => bytes.to_vec(),
            Err(source) => return Err(PrimaBridgeError::HttpError { source, url }),
        };
        Ok(match request_type {
            RequestType::Rest => Response::rest(url, body, status_code, response_headers, request_id),
            RequestType::GraphQL => Response::graphql(url, body, status_code, response_headers, request_id),
        })
    }

    /// Final request URL: `endpoint()` with `get_path()` appended as a trailing segment,
    /// followed by every query pair in insertion order.
    ///
    /// Empty segments of the endpoint's path are dropped before the path is appended, so
    /// a trailing `/` on the endpoint does not produce `//`.
    fn get_url(&self) -> Url {
        let mut url = self.endpoint();
        if let Some(path) = self.get_path() {
            let joined = url
                .path_segments()
                .into_iter()
                .flatten()
                .filter(|segment| !segment.is_empty())
                .chain(std::iter::once(path))
                .collect::<Vec<_>>()
                .join("/");
            url.set_path(&joined);
        }
        // Only touch the query when there are pairs, so no empty `?` is ever added.
        for (name, value) in self.get_query_pairs() {
            url.query_pairs_mut().append_pair(name, value);
        }
        url
    }

    /// Propagation headers produced by `otel::inject_context`; entries that are not a
    /// valid header name/value are dropped.
    #[cfg(feature = "_any_otel_version")]
    fn tracing_headers(&self) -> HeaderMap {
        use std::collections::HashMap;
        let mut propagated: HashMap<String, String> = HashMap::new();
        otel::inject_context(&mut propagated);
        propagated
            .iter()
            .filter_map(|(name, value)| {
                let name = HeaderName::from_bytes(name.as_bytes()).ok()?;
                let value = HeaderValue::from_bytes(value.as_bytes()).ok()?;
                Some((name, value))
            })
            .collect()
    }

    /// Without an OpenTelemetry feature there is nothing to propagate.
    #[cfg(not(feature = "_any_otel_version"))]
    fn tracing_headers(&self) -> Vec<(HeaderName, HeaderValue)> {
        vec![]
    }
}
/// Renders `url` as a string with its username and password removed, so it can be
/// attached to telemetry without leaking credentials.
///
/// A URL with neither a username nor a password is rendered unchanged. Otherwise a copy
/// has both cleared before rendering; failures of the setters are deliberately ignored
/// (best-effort redaction).
fn strip_url_credentials(url: &reqwest::Url) -> String {
    let carries_credentials = !url.username().is_empty() || url.password().is_some();
    if carries_credentials {
        let mut sanitized = url.clone();
        let _ = sanitized.set_username("");
        let _ = sanitized.set_password(None);
        sanitized.to_string()
    } else {
        url.as_str().to_owned()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use reqwest::Url;

    /// Parses `raw` and returns what `strip_url_credentials` renders for it.
    fn redact(raw: &str) -> String {
        let url = Url::parse(raw).expect("test URL must parse");
        strip_url_credentials(&url)
    }

    #[test]
    fn preserves_url_without_credentials() {
        assert_eq!(redact("https://example.com?xx=yy#123"), "https://example.com/?xx=yy#123");
    }

    #[test]
    fn strips_username_and_password() {
        assert_eq!(
            redact("https://myuser:secret@example.com?xx=yy#123"),
            "https://example.com/?xx=yy#123"
        );
    }

    #[test]
    fn strips_username_without_password() {
        assert_eq!(
            redact("https://myuser@example.com/api/v1/users"),
            "https://example.com/api/v1/users"
        );
    }

    /// Covers the `password().is_some()` branch with an empty username.
    #[test]
    fn strips_password_without_username() {
        assert_eq!(redact("https://:secret@example.com/path"), "https://example.com/path");
    }

    #[test]
    fn preserves_port_path_query_and_fragment() {
        assert_eq!(
            redact("https://user:pass@example.com:8443/api/test?q=1#frag"),
            "https://example.com:8443/api/test?q=1#frag"
        );
    }

    #[test]
    fn handles_special_characters_in_credentials() {
        assert_eq!(
            redact("https://user%40mail.com:p%40ss@example.com/path"),
            "https://example.com/path"
        );
    }
}