use std::sync::Arc;
use std::time::{Duration, Instant, SystemTime};
use bytes::Bytes;
use http::Extensions;
use reqwest::header::{HeaderMap, CONTENT_LENGTH, CONTENT_TYPE};
use reqwest::{IntoUrl, Method, Request, RequestBuilder, Response, StatusCode, Url, Version};
use reqwest_middleware::{ClientBuilder, ClientWithMiddleware, Middleware, Next};
use serde::de::DeserializeOwned;
use serde::Serialize;
use crate::error::{Result, RumError};
use crate::event::{PendingResource, RumEvent};
use crate::propagation::{inject_or_extract, TraceInfo};
use crate::sdk::{global, Rum};
/// HTTP client that executes requests through a `reqwest_middleware` stack which may carry
/// RUM instrumentation.
///
/// The middleware stack is shared through an `Arc`, so clones reuse the same stack.
#[derive(Clone)]
pub struct RumReqwestClient {
    /// Plain client used only to *construct* requests (see `RumReqwestClient::request`).
    client: reqwest::Client,
    /// Middleware-wrapped copy of `client` used to *execute* requests.
    middleware_client: Arc<ClientWithMiddleware>,
}
/// Builder for a single request issued through a [`RumReqwestClient`].
///
/// Nothing is sent until [`RumRequestBuilder::send`] is awaited; only then does the request
/// pass through the middleware stack (and get recorded, when instrumentation applies).
/// Marked `#[must_use]` so that accidentally dropping a configured builder is reported.
#[must_use = "RumRequestBuilder does nothing until you `send` it"]
pub struct RumRequestBuilder {
    /// Client whose middleware stack will execute the built request.
    client: RumReqwestClient,
    /// Underlying reqwest builder accumulating method, URL, headers, body, etc.
    builder: RequestBuilder,
}
/// Response returned by [`RumReqwestClient`], wrapping a `reqwest::Response`.
///
/// If the request was recorded, its RUM resource event is emitted exactly once. That happens
/// when the body is read (`bytes`/`text`/`json`), when `error_for_status` fails, or when this
/// value is dropped.
pub struct RumReqwestResponse {
    /// The wrapped response. It is `None` only transiently, inside methods that take `self`
    /// by value.
    response: Option<Response>,
    /// Pending resource event; taken (and emitted) by the first call to `finalize`.
    finalizer: Option<ResourceFinalizer>,
}
/// Partially-filled resource event handed from the middleware to the response wrapper.
///
/// The middleware inserts it into the request `Extensions`; `execute_instrumented` removes it
/// again and attaches it to the [`RumReqwestResponse`].
#[derive(Clone)]
struct ResourceFinalizer {
    /// SDK handle used to enqueue the completed event.
    rum: Rum,
    /// Measurements gathered so far (timing, status code, headers, trace info).
    pending: PendingResource,
}
/// `reqwest_middleware` layer that records RUM resource events and runs
/// `inject_or_extract` on outgoing request headers.
struct RumReqwestMiddleware {
    /// SDK handle consulted for configuration, context snapshots and event delivery.
    rum: Rum,
}
impl RumReqwestClient {
    /// Builds a client around `client`.
    ///
    /// With `Some(rum)`, every request passes through [`RumReqwestMiddleware`]. With `None`,
    /// the middleware stack is empty and requests run without instrumentation.
    pub(crate) fn new(client: reqwest::Client, rum: Option<Rum>) -> Self {
        let base = ClientBuilder::new(client.clone());
        let stack = match rum {
            Some(rum) => base.with(RumReqwestMiddleware { rum }),
            None => base,
        };
        Self {
            middleware_client: Arc::new(stack.build()),
            client,
        }
    }

    /// Starts building a request with an arbitrary HTTP `method` targeting `url`.
    pub fn request<U: IntoUrl>(&self, method: Method, url: U) -> RumRequestBuilder {
        let builder = self.client.request(method, url);
        RumRequestBuilder {
            client: self.clone(),
            builder,
        }
    }

    /// Shorthand for `request(Method::GET, url)`.
    pub fn get<U: IntoUrl>(&self, url: U) -> RumRequestBuilder {
        self.request(Method::GET, url)
    }

    /// Shorthand for `request(Method::POST, url)`.
    pub fn post<U: IntoUrl>(&self, url: U) -> RumRequestBuilder {
        self.request(Method::POST, url)
    }

    /// Shorthand for `request(Method::HEAD, url)`.
    pub fn head<U: IntoUrl>(&self, url: U) -> RumRequestBuilder {
        self.request(Method::HEAD, url)
    }

    /// Shorthand for `request(Method::PUT, url)`.
    pub fn put<U: IntoUrl>(&self, url: U) -> RumRequestBuilder {
        self.request(Method::PUT, url)
    }

    /// Shorthand for `request(Method::PATCH, url)`.
    pub fn patch<U: IntoUrl>(&self, url: U) -> RumRequestBuilder {
        self.request(Method::PATCH, url)
    }

    /// Shorthand for `request(Method::DELETE, url)`.
    pub fn delete<U: IntoUrl>(&self, url: U) -> RumRequestBuilder {
        self.request(Method::DELETE, url)
    }

    /// Executes an already-built `request` through the middleware stack.
    ///
    /// # Errors
    ///
    /// Returns `RumError::Reqwest` for reqwest errors and `RumError::ReqwestMiddleware` for
    /// any other failure reported by the middleware stack.
    pub async fn execute(&self, request: Request) -> Result<RumReqwestResponse> {
        let client = self.clone();
        execute_instrumented(client, request).await
    }
}
impl RumRequestBuilder {
pub fn header<K, V>(mut self, key: K, value: V) -> Self
where
reqwest::header::HeaderName: TryFrom<K>,
<reqwest::header::HeaderName as TryFrom<K>>::Error: Into<http::Error>,
reqwest::header::HeaderValue: TryFrom<V>,
<reqwest::header::HeaderValue as TryFrom<V>>::Error: Into<http::Error>,
{
self.builder = self.builder.header(key, value);
self
}
pub fn headers(mut self, headers: HeaderMap) -> Self {
self.builder = self.builder.headers(headers);
self
}
pub fn query<T: Serialize + ?Sized>(mut self, query: &T) -> Self {
self.builder = self.builder.query(query);
self
}
pub fn json<T: Serialize + ?Sized>(mut self, json: &T) -> Self {
self.builder = self.builder.json(json);
self
}
pub fn form<T: Serialize + ?Sized>(mut self, form: &T) -> Self {
self.builder = self.builder.form(form);
self
}
pub fn body(mut self, body: impl Into<reqwest::Body>) -> Self {
self.builder = self.builder.body(body);
self
}
pub fn version(mut self, version: Version) -> Self {
self.builder = self.builder.version(version);
self
}
pub fn timeout(mut self, timeout: Duration) -> Self {
self.builder = self.builder.timeout(timeout);
self
}
pub fn bearer_auth<T: std::fmt::Display>(mut self, token: T) -> Self {
self.builder = self.builder.bearer_auth(token);
self
}
pub fn basic_auth<U, P>(mut self, username: U, password: Option<P>) -> Self
where
U: std::fmt::Display,
P: std::fmt::Display,
{
self.builder = self.builder.basic_auth(username, password);
self
}
pub fn build(self) -> Result<Request> {
Ok(self.builder.build()?)
}
pub fn try_clone(&self) -> Option<RumRequestBuilder> {
self.builder.try_clone().map(|builder| RumRequestBuilder {
client: self.client.clone(),
builder,
})
}
pub async fn send(self) -> Result<RumReqwestResponse> {
let request = self.builder.build()?;
self.client.execute(request).await
}
}
impl RumReqwestResponse {
    /// Wraps `response`. `finalizer` is `Some` only when the middleware recorded the request.
    fn new(response: Response, finalizer: Option<ResourceFinalizer>) -> Self {
        Self {
            finalizer,
            response: Some(response),
        }
    }

    /// Shared access to the wrapped response.
    ///
    /// # Panics
    ///
    /// Panics if the response was already moved out. Every method that moves it out takes
    /// `self` by value, so public callers cannot observe this.
    fn inner(&self) -> &Response {
        self.response.as_ref().expect("response already consumed")
    }

    /// Mutable access to the wrapped response; panics under the same condition as `inner`.
    fn inner_mut(&mut self) -> &mut Response {
        self.response.as_mut().expect("response already consumed")
    }

    /// HTTP status code of the response.
    pub fn status(&self) -> StatusCode {
        self.inner().status()
    }

    /// Response headers.
    pub fn headers(&self) -> &HeaderMap {
        self.inner().headers()
    }

    /// Mutable access to the response headers.
    pub fn headers_mut(&mut self) -> &mut HeaderMap {
        self.inner_mut().headers_mut()
    }

    /// Final URL of the response.
    pub fn url(&self) -> &Url {
        self.inner().url()
    }

    /// HTTP version of the response.
    pub fn version(&self) -> Version {
        self.inner().version()
    }

    /// Content length as reported by `reqwest::Response::content_length`.
    pub fn content_length(&self) -> Option<u64> {
        self.inner().content_length()
    }

    /// Reads the whole body and emits the resource event.
    ///
    /// On success the event is completed with the body length offered as the downloaded size.
    /// That size is applied only if the pending size is still 0. On failure the event is
    /// completed as unsuccessful, carrying the error message.
    pub async fn bytes(mut self) -> Result<Bytes> {
        let body = self
            .response
            .take()
            .expect("response already consumed")
            .bytes()
            .await;
        match body {
            Ok(payload) => {
                self.finalize(true, "", Some(payload.len() as u64));
                Ok(payload)
            }
            Err(error) => {
                self.finalize(false, error.to_string(), None);
                Err(error.into())
            }
        }
    }

    /// Reads the body as text.
    ///
    /// The body is always decoded as UTF-8, lossily: invalid sequences become U+FFFD, and any
    /// charset in `Content-Type` is ignored.
    pub async fn text(self) -> Result<String> {
        self.bytes()
            .await
            .map(|raw| String::from_utf8_lossy(&raw).into_owned())
    }

    /// Reads the body and deserializes it as JSON into `T`.
    pub async fn json<T: DeserializeOwned>(self) -> Result<T> {
        let raw = self.bytes().await?;
        let value: T = serde_json::from_slice(&raw)?;
        Ok(value)
    }

    /// Delegates to `reqwest::Response::error_for_status`.
    ///
    /// On error, the resource event is emitted immediately with the error message. Otherwise
    /// the (still unread) response is handed back.
    pub fn error_for_status(mut self) -> Result<Self> {
        let checked = self
            .response
            .take()
            .expect("response already consumed")
            .error_for_status();
        match checked {
            Ok(response) => {
                self.response = Some(response);
                Ok(self)
            }
            Err(error) => {
                self.finalize(false, error.to_string(), None);
                Err(error.into())
            }
        }
    }

    /// Borrowing variant of `error_for_status`. It does not emit the resource event.
    pub fn error_for_status_ref(&self) -> Result<&Self> {
        self.inner().error_for_status_ref()?;
        Ok(self)
    }

    /// Completes and enqueues the pending resource event, at most once.
    ///
    /// `success` is AND-ed with whether the recorded status code lies in `200..400`.
    /// `downloaded_size` replaces the pending size only when that size is 0.
    /// The result of `enqueue_event` is discarded.
    fn finalize(
        &mut self,
        success: bool,
        message: impl Into<String>,
        downloaded_size: Option<u64>,
    ) {
        let Some(mut finalizer) = self.finalizer.take() else {
            return;
        };
        if let Some(size) = downloaded_size.filter(|_| finalizer.pending.size == 0) {
            finalizer.pending.size = size;
        }
        let in_success_range = (200..400).contains(&finalizer.pending.status_code);
        let event = finalizer.pending.complete(
            success && in_success_range,
            message,
            Instant::now(),
        );
        let _ = finalizer
            .rum
            .enqueue_event(RumEvent::Resource(Box::new(event)));
    }
}
impl Drop for RumReqwestResponse {
    /// Emits the pending resource event, unless a consuming method already did.
    ///
    /// The outcome reported here is "success". `finalize` still marks the event as failed
    /// when the recorded status code is outside `200..400`.
    fn drop(&mut self) {
        let (success, message, downloaded_size) = (true, "", None);
        self.finalize(success, message, downloaded_size);
    }
}
/// Wraps an existing `reqwest::Client` so its requests are instrumented by the global RUM SDK.
///
/// # Errors
///
/// Propagates the error returned by `global()`. Presumably this happens when no global SDK
/// instance is available — confirm against `crate::sdk::global`.
pub fn wrap_reqwest(client: reqwest::Client) -> Result<RumReqwestClient> {
    let rum = global()?;
    Ok(RumReqwestClient::new(client, Some(rum)))
}
async fn execute_instrumented(
client: RumReqwestClient,
request: Request,
) -> Result<RumReqwestResponse> {
let mut extensions = Extensions::new();
match client
.middleware_client
.execute_with_extensions(request, &mut extensions)
.await
{
Ok(response) => Ok(RumReqwestResponse::new(
response,
extensions.remove::<ResourceFinalizer>(),
)),
Err(reqwest_middleware::Error::Reqwest(error)) => Err(RumError::Reqwest(error)),
Err(error) => Err(RumError::ReqwestMiddleware(error)),
}
}
#[async_trait::async_trait]
impl Middleware for RumReqwestMiddleware {
    /// Instruments a single request.
    ///
    /// Once the SDK is shut down, the request is forwarded untouched. Otherwise
    /// `network_decision` decides whether to record and/or trace the URL, `inject_or_extract`
    /// is applied to the request headers when either applies, and the request is forwarded.
    ///
    /// When recording, a successful response leaves a [`ResourceFinalizer`] in `extensions`,
    /// so the event completes once the body is consumed. A downstream error instead completes
    /// and enqueues the event immediately. The downstream result is always returned unchanged.
    async fn handle(
        &self,
        mut request: Request,
        extensions: &mut Extensions,
        next: Next<'_>,
    ) -> reqwest_middleware::Result<Response> {
        let rum = self.rum.clone();
        if rum.is_shutdown() {
            return next.run(request, extensions).await;
        }
        let start = Instant::now();
        let timestamp = SystemTime::now();
        let method = request.method().as_str().to_string();
        // Cloned because `inject_or_extract` needs the URL while the headers are mutably borrowed.
        let url = request.url().clone();
        let (record, trace) = network_decision(&rum, &url);
        let context = if record {
            Some(rum.snapshot_context().await)
        } else {
            None
        };
        let trace_info = if trace || record {
            inject_or_extract(
                request.headers_mut(),
                rum.config(),
                &method,
                &url,
                trace,
                rum.instance_id(),
            )
        } else {
            TraceInfo::default()
        };
        match next.run(request, extensions).await {
            Ok(response) => {
                // Measured before any bookkeeping so it covers only the downstream call.
                let first_byte = start.elapsed();
                if let Some(context) = context {
                    // Read headers in place: cloning the whole `HeaderMap` (as before) cost an
                    // allocation per response even when recording was disabled.
                    let headers = response.headers();
                    let mut pending = PendingResource::new(
                        context,
                        timestamp,
                        start,
                        url.to_string(),
                        method,
                        i32::from(response.status().as_u16()),
                        content_type(headers),
                        content_length(headers),
                        trace_info,
                    );
                    pending.first_byte = Some(first_byte);
                    extensions.insert(ResourceFinalizer { rum, pending });
                }
                Ok(response)
            }
            Err(error) => {
                if let Some(context) = context {
                    let pending = PendingResource::new(
                        context,
                        timestamp,
                        start,
                        url.to_string(),
                        method,
                        0,
                        String::new(),
                        0,
                        trace_info,
                    );
                    let event = pending.complete(false, error.to_string(), Instant::now());
                    let _ = rum.enqueue_event(RumEvent::Resource(Box::new(event)));
                }
                Err(error)
            }
        }
    }
}
/// Decides whether a request to `url` should be recorded and/or traced.
///
/// Returns `(record, trace)`. Both are `false` when network instrumentation is disabled or
/// the host is excluded. Otherwise each flag requires its feature switch and its per-URL
/// predicate, which is consulted only if the switch is on.
fn network_decision(rum: &Rum, url: &Url) -> (bool, bool) {
    let config = rum.config();
    let network = &config.network;
    if network.enabled && !is_excluded(network, url) {
        (
            network.record_enabled && (network.should_record_request)(url),
            network.tracing_enabled && (network.should_trace_request)(url),
        )
    } else {
        (false, false)
    }
}
/// Returns `true` when `url` matches an entry in `network.excluded_hosts`.
///
/// Two forms are checked, both ASCII-lowercased: the bare host, then the address returned by
/// `effective_target_address` (host:port form, per the tests). URLs without a host are never
/// excluded.
fn is_excluded(network: &crate::config::NetworkConfig, url: &Url) -> bool {
    match url.host_str() {
        None => false,
        Some(raw_host) => {
            let bare = raw_host.to_ascii_lowercase();
            if network.excluded_hosts.contains(&bare) {
                return true;
            }
            let target =
                crate::propagation::effective_target_address(url).to_ascii_lowercase();
            network.excluded_hosts.contains(&target)
        }
    }
}
/// Returns the `Content-Type` header value.
///
/// Yields an empty string when the header is absent or rejected by `HeaderValue::to_str`.
fn content_type(headers: &HeaderMap) -> String {
    match headers.get(CONTENT_TYPE).map(|value| value.to_str()) {
        Some(Ok(text)) => text.to_owned(),
        _ => String::new(),
    }
}
/// Parses the `Content-Length` header as a byte count.
///
/// Yields 0 when the header is absent, rejected by `HeaderValue::to_str`, or not a valid `u64`.
fn content_length(headers: &HeaderMap) -> u64 {
    let Some(raw) = headers.get(CONTENT_LENGTH) else {
        return 0;
    };
    match raw.to_str().map(str::parse::<u64>) {
        Ok(Ok(length)) => length,
        _ => 0,
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::RumConfig;
    use reqwest::header::HeaderValue;

    #[test]
    fn exclude_host_matches_case_insensitive_host_and_port() {
        let config = RumConfig::builder()
            .config_address("https://collector.example.com/rum")
            .app_id("app")
            .network(|network| {
                network
                    .exclude_host("EXAMPLE.COM")
                    .exclude_host("api.example.com:8443")
            })
            .build()
            .unwrap();
        assert!(is_excluded(
            &config.network,
            &Url::parse("https://example.com/path").unwrap()
        ));
        assert!(is_excluded(
            &config.network,
            &Url::parse("https://api.example.com:8443/path").unwrap()
        ));
        assert!(!is_excluded(
            &config.network,
            &Url::parse("https://api.example.com:443/path").unwrap()
        ));
    }

    #[test]
    fn config_address_target_is_excluded_automatically() {
        let config = RumConfig::builder()
            .config_address("https://collector.example.com/rum")
            .app_id("app")
            .build()
            .unwrap();
        assert!(is_excluded(
            &config.network,
            &Url::parse("https://collector.example.com/rum").unwrap()
        ));
    }

    #[test]
    fn urls_without_host_are_never_excluded() {
        let config = RumConfig::builder()
            .config_address("https://collector.example.com/rum")
            .app_id("app")
            .build()
            .unwrap();
        // `data:` URLs have no host, so `is_excluded` must take its early `false` branch.
        assert!(!is_excluded(
            &config.network,
            &Url::parse("data:text/plain,hello").unwrap()
        ));
    }

    #[test]
    fn header_helpers_fall_back_to_defaults() {
        let mut headers = HeaderMap::new();
        assert_eq!(content_type(&headers), "");
        assert_eq!(content_length(&headers), 0);

        headers.insert(CONTENT_TYPE, HeaderValue::from_static("application/json"));
        headers.insert(CONTENT_LENGTH, HeaderValue::from_static("42"));
        assert_eq!(content_type(&headers), "application/json");
        assert_eq!(content_length(&headers), 42);

        // Malformed values degrade to the defaults instead of failing.
        headers.insert(CONTENT_LENGTH, HeaderValue::from_static("not-a-number"));
        assert_eq!(content_length(&headers), 0);
        headers.insert(CONTENT_TYPE, HeaderValue::from_bytes(b"\xff").unwrap());
        assert_eq!(content_type(&headers), "");
    }
}