derust 0.2.0

An easy way to start your asynchronous Rust application server using the Tokio and Axum frameworks.
Documentation
use crate::httpx::{AppContext, HttpError, HttpTags};
use opentelemetry::trace::TraceContextExt;
use reqwest::{Client, Method, StatusCode};
use reqwest_middleware::{ClientBuilder, ClientWithMiddleware, RequestBuilder};
use reqwest_tracing::TracingMiddleware;
use serde::Deserialize;
use std::time::Duration;
use tracing::Span;
use tracing_opentelemetry::OpenTelemetrySpanExt;

#[cfg(any(feature = "statsd", feature = "prometheus"))]
use crate::metricx::{timer, MetricTags, Stopwatch};

/// HTTP client bound to a single base URL.
///
/// Every request method joins `base_url` with the path supplied by the caller
/// and sends the request through the wrapped middleware-enabled client.
#[derive(Clone)]
pub struct HttpClient {
    /// Underlying `reqwest` client wrapped with the middleware stack built in `HttpClient::new`.
    client: ClientWithMiddleware,
    /// Prefix prepended verbatim to every request path.
    base_url: String,
}

impl HttpClient {
    pub async fn new(
        user_agent: &str,
        base_url: &str,
        timeout: u64,
        connection_timeout: u64,
    ) -> Result<HttpClient, Box<dyn std::error::Error>> {
        let reqwest_client = Client::builder()
            .user_agent(user_agent)
            .timeout(Duration::from_millis(timeout))
            .connect_timeout(Duration::from_millis(connection_timeout))
            .build()?;

        let client = ClientBuilder::new(reqwest_client)
            .with(TracingMiddleware::default())
            .build();

        Ok(HttpClient {
            client: client,
            base_url: base_url.to_string(),
        })
    }

    pub async fn get<'a, T, S>(
        &self,
        context: &AppContext<S>,
        path: &str,
        query_params: Option<Vec<(&str, &str)>>,
        headers: Option<Vec<(&str, &str)>>,
        tags: &HttpTags,
    ) -> Result<T, HttpError>
    where
        T: for<'de> Deserialize<'de>,
        S: Clone,
    {
        let req = self
            .client
            .get(full_url(&self.base_url, path, query_params));
        let request_context = RequestContext::new(Method::GET, &self.base_url, path);
        send(context, request_context, req, None::<&()>, headers, tags).await
    }

    pub async fn post<'a, T, B, S>(
        &self,
        context: &AppContext<S>,
        path: &str,
        body: &B,
        query_params: Option<Vec<(&str, &str)>>,
        headers: Option<Vec<(&str, &str)>>,
        tags: &HttpTags,
    ) -> Result<T, HttpError>
    where
        T: for<'de> Deserialize<'de>,
        B: serde::Serialize,
        S: Clone,
    {
        let req = self
            .client
            .post(full_url(&self.base_url, path, query_params));
        let request_context = RequestContext::new(Method::POST, &self.base_url, path);
        send(context, request_context, req, Some(body), headers, tags).await
    }

    pub async fn put<'a, T, B, S>(
        &self,
        context: &AppContext<S>,
        path: &str,
        body: &B,
        query_params: Option<Vec<(&str, &str)>>,
        headers: Option<Vec<(&str, &str)>>,
        tags: &HttpTags,
    ) -> Result<T, HttpError>
    where
        T: for<'de> Deserialize<'de>,
        B: serde::Serialize,
        S: Clone,
    {
        let req = self
            .client
            .post(full_url(&self.base_url, path, query_params));
        let request_context = RequestContext::new(Method::PUT, &self.base_url, path);
        send(context, request_context, req, Some(body), headers, tags).await
    }

    pub async fn patch<'a, T, B, S>(
        context: &AppContext<S>,
        client: &HttpClient,
        path: &str,
        body: &B,
        query_params: Option<Vec<(&str, &str)>>,
        headers: Option<Vec<(&str, &str)>>,
        tags: &HttpTags,
    ) -> Result<T, HttpError>
    where
        T: for<'de> Deserialize<'de>,
        B: serde::Serialize,
        S: Clone,
    {
        let req = client
            .client
            .post(full_url(&client.base_url, path, query_params));
        let request_context = RequestContext::new(Method::PATCH, &client.base_url, path);
        send(context, request_context, req, Some(body), headers, tags).await
    }

    pub async fn delete<'a, T, B, S>(
        &self,
        context: &AppContext<S>,
        path: &str,
        query_params: Option<Vec<(&str, &str)>>,
        headers: Option<Vec<(&str, &str)>>,
        tags: &HttpTags,
    ) -> Result<T, HttpError>
    where
        T: for<'de> Deserialize<'de>,
        B: serde::Serialize,
        S: Clone,
    {
        let req = self
            .client
            .post(full_url(&self.base_url, path, query_params));
        let request_context = RequestContext::new(Method::DELETE, &self.base_url, path);
        send(context, request_context, req, None::<&B>, headers, tags).await
    }
}

/// Concatenates `base_url`, `url` and the optional query string into one request URL.
///
/// Each pair is rendered as `key=value` and the pairs are joined with `&`.
/// Keys and values are inserted verbatim: no percent-encoding is applied, so
/// callers must pass text that is already safe to place in a URL.
///
/// `None` and an empty list both yield a URL without a trailing `?`.
fn full_url(base_url: &str, url: &str, query_params: Option<Vec<(&str, &str)>>) -> String {
    let query = query_params
        .unwrap_or_default()
        .iter()
        .map(|(key, value)| format!("{key}={value}"))
        .collect::<Vec<_>>()
        .join("&");

    if query.is_empty() {
        format!("{base_url}{url}")
    } else {
        format!("{base_url}{url}?{query}")
    }
}

/// Finishes building the request, sends it and turns the response into `T`.
///
/// Steps, in order:
/// 1. attach `body` as JSON when present;
/// 2. append every caller-supplied header;
/// 3. append a `traceparent` header when `get_traceparent` yields one;
/// 4. send the request (timed when the `statsd` or `prometheus` feature is enabled);
/// 5. deserialize a successful response as JSON, or wrap the response text of
///    an unsuccessful one in an `HttpError`.
///
/// # Errors
///
/// * the request could not be sent: status from the error, or `502 Bad Gateway` when it has none;
/// * the response status is not a success: that status plus the response body text;
/// * the success body is not valid JSON for `T`: status from the error, or `502 Bad Gateway`.
///
/// NOTE(review): when sending fails the function returns before
/// `stopwatch.record` is reached; whether such requests are still measured
/// depends on `Stopwatch` — confirm.
async fn send<'a, T, B, S>(
    context: &AppContext<S>,
    request_context: RequestContext,
    mut request_builder: RequestBuilder,
    body: Option<&B>,
    headers: Option<Vec<(&str, &str)>>,
    tags: &HttpTags,
) -> Result<T, HttpError>
where
    T: for<'de> Deserialize<'de>,
    B: serde::Serialize,
    S: Clone,
{
    // The body is attached before the caller headers, exactly as callers have
    // always observed.
    if let Some(payload) = body {
        request_builder = request_builder.json(payload);
    }

    if let Some(extra_headers) = headers {
        for (name, value) in extra_headers {
            request_builder = request_builder.header(name, value);
        }
    }

    if let Some(trace_parent) = get_traceparent() {
        request_builder = request_builder.header("traceparent", &trace_parent);
    }

    #[cfg(any(feature = "statsd", feature = "prometheus"))]
    let stopwatch = start_stopwatch(context, request_context);

    let response = match request_builder.send().await {
        Ok(response) => response,
        Err(error) => {
            return Err(HttpError::without_body(
                error.status().unwrap_or(StatusCode::BAD_GATEWAY),
                format!("Failed to send http request: {error}"),
                tags.clone(),
            ));
        }
    };

    let status_code = response.status();

    #[cfg(any(feature = "statsd", feature = "prometheus"))]
    stopwatch.record(MetricTags::from([(
        "status",
        status_code.as_u16().to_string(),
    )]));

    // Error statuses: surface whatever text the server sent back.
    if !status_code.is_success() {
        let response_message = response
            .text()
            .await
            .unwrap_or_else(|_| "Failed to get response body as text".to_string());

        return Err(HttpError::without_body(
            status_code,
            format!("Http response error: {response_message}"),
            tags.clone(),
        ));
    }

    response.json().await.map_err(|error| {
        HttpError::without_body(
            error.status().unwrap_or(StatusCode::BAD_GATEWAY),
            format!("Failed to deserialize response: {error}"),
            tags.clone(),
        )
    })
}

/// Description of an outgoing request, read by `start_stopwatch` to build metric tags.
struct RequestContext {
    /// HTTP verb of the request.
    method: Method,
    /// Base URL with a leading `https://` or `http://` removed.
    url: String,
    /// Request path without its query string, or `<no_path>` when there is no path.
    path: String,
}

impl RequestContext {
    /// Normalizes `url` and `path` into the values stored on the context.
    ///
    /// * `url`: a leading `https://` or `http://` is stripped; any other
    ///   occurrence of those strings is left untouched.
    /// * `path`: everything from the first `?` onwards is dropped; if what
    ///   remains is empty or `/`, the placeholder `<no_path>` is stored.
    fn new(method: Method, url: &str, path: &str) -> Self {
        let url_without_protocol = url
            .strip_prefix("https://")
            .or_else(|| url.strip_prefix("http://"))
            .unwrap_or(url);

        // Cut the query string first so inputs such as "/?a=b" or "?a=b" are
        // recognised as having no path.
        let path_without_query = path.split_once('?').map_or(path, |(before, _)| before);

        let normalized_path = match path_without_query {
            "" | "/" => "<no_path>",
            other => other,
        };

        Self {
            method,
            url: url_without_protocol.to_string(),
            path: normalized_path.to_string(),
        }
    }
}

/// Renders the current tracing span as a W3C `traceparent` header value.
///
/// The value has the form `version-trace_id-span_id-flags`, all lowercase hex:
/// a 2-digit version, a 32-digit trace id, a 16-digit span id and 2-digit
/// trace flags.
///
/// Returns `None` when the current span has no valid span context.
fn get_traceparent() -> Option<String> {
    // The version field is fixed by the Trace Context specification; it must
    // not mirror the trace flags, otherwise sampled spans would announce
    // version `01`.
    const TRACEPARENT_VERSION: u8 = 0;

    let ctx = Span::current().context();
    let binding = ctx.span();
    let span_context = binding.span_context();

    if !span_context.is_valid() {
        return None;
    }

    Some(format!(
        "{:02x}-{:032x}-{:016x}-{:02x}",
        TRACEPARENT_VERSION,
        span_context.trace_id(),
        span_context.span_id(),
        span_context.trace_flags().to_u8()
    ))
}

/// Starts the `http_client_seconds` timer for one outgoing request.
///
/// The request's url, path and method are turned into metric tags through
/// `MetricTags::http_client`; the returned stopwatch is handed back to the
/// caller, which records it once the response status is known.
#[cfg(any(feature = "statsd", feature = "prometheus"))]
fn start_stopwatch<S>(context: &AppContext<S>, req: RequestContext) -> Stopwatch<S>
where
    S: Clone,
{
    const METRIC_NAME: &str = "http_client_seconds";

    let RequestContext { method, url, path } = req;
    let tags = MetricTags::http_client(&url, &path, method.as_str());

    timer::start_stopwatch(&context, METRIC_NAME, tags)
}

// Compiled only for test builds, so the module and its imports never leak into
// a regular build that merely enables the `http_client` feature.
#[cfg(all(test, feature = "http_client"))]
mod test {
    use super::*;

    /// `RequestContext::new` strips the scheme from the url and the query
    /// string from the path, and substitutes `<no_path>` for an empty or root path.
    #[test]
    #[cfg(any(feature = "statsd", feature = "prometheus"))]
    fn should_remove_params_and_split_path_from_url() {
        // (base url, path, expected url, expected path)
        let cases = [
            (
                "https://www.rust-lang.org",
                "",
                "www.rust-lang.org",
                "<no_path>",
            ),
            (
                "https://www.rust-lang.org",
                "/",
                "www.rust-lang.org",
                "<no_path>",
            ),
            (
                "https://www.rust-lang.org",
                "/anything",
                "www.rust-lang.org",
                "/anything",
            ),
            (
                "https://www.rust-lang.org",
                "/anything/",
                "www.rust-lang.org",
                "/anything/",
            ),
            (
                "https://www.rust-lang.org",
                "/anything/123",
                "www.rust-lang.org",
                "/anything/123",
            ),
            (
                "https://www.rust-lang.org",
                "/anything/123/0193a2ce-e912-762e-a66b-5b45d44a3a6e",
                "www.rust-lang.org",
                "/anything/123/0193a2ce-e912-762e-a66b-5b45d44a3a6e",
            ),
            (
                "https://www.rust-lang.org",
                "/anything/123/0193a2ce-e912-762e-a66b-5b45d44a3a6e?foo=bar",
                "www.rust-lang.org",
                "/anything/123/0193a2ce-e912-762e-a66b-5b45d44a3a6e",
            ),
            // Plain-http base urls lose their scheme as well.
            (
                "http://www.rust-lang.org",
                "/anything/123/0193a2ce-e912-762e-a66b-5b45d44a3a6e?foo=bar",
                "www.rust-lang.org",
                "/anything/123/0193a2ce-e912-762e-a66b-5b45d44a3a6e",
            ),
        ];

        for &(url, path, expected_url, expected_path) in &cases {
            let rc = RequestContext::new(Method::GET, url, path);
            assert_eq!(expected_url, rc.url, "url for ({url:?}, {path:?})");
            assert_eq!(expected_path, rc.path, "path for ({url:?}, {path:?})");
        }
    }
}