use std::fmt;
use std::sync::Arc;
use std::time::Duration;
use reqwest::header::HeaderValue;
use reqwest::{Method, Response};
use reqwest_middleware::{ClientBuilder, ClientWithMiddleware, RequestBuilder};
use reqwest_retry::RetryTransientMiddleware;
use reqwest_retry::policies::ExponentialBackoff;
use serde::Serialize;
use url::Url;
use uuid::Uuid;
use super::config::{NvisyRtBuilder, NvisyRtOptions};
#[cfg(feature = "tracing")]
use crate::TRACING_TARGET_CLIENT;
use crate::error::{Error, Result};
use crate::model::ApiError;
/// Name of the header that carries the actor identifier on every request.
const ACTOR_ID_HEADER: &str = "x-actor-id";

/// State shared by every clone of an [`NvisyRt`] handle.
pub(crate) struct NvisyRtInner {
    /// Actor identifier sent in the `x-actor-id` header.
    pub(crate) actor_id: Uuid,
    /// Base URL that request paths are resolved against.
    pub(crate) base_url: Url,
    /// Timeout applied to each request.
    pub(crate) timeout: Duration,
    /// HTTP client wrapped with retry middleware (and tracing middleware when
    /// the `tracing` feature is enabled).
    pub(crate) client: ClientWithMiddleware,
}

/// HTTP client for the Nvisy RT API.
///
/// Cloning is cheap: all clones share the same inner state through an [`Arc`].
#[derive(Clone)]
pub struct NvisyRt {
    pub(crate) inner: Arc<NvisyRtInner>,
}
impl NvisyRt {
    /// Creates a client with the default configuration.
    ///
    /// Equivalent to `NvisyRt::builder().build()`.
    ///
    /// # Panics
    ///
    /// Panics if building the default configuration fails.
    pub fn new() -> Self {
        NvisyRtBuilder::default()
            .build()
            .expect("default config is valid")
    }

    /// Returns a builder for configuring a client.
    pub fn builder() -> NvisyRtBuilder {
        NvisyRtBuilder::default()
    }

    /// Builds a client from configuration options.
    ///
    /// The base URL is parsed first, so an invalid URL fails before any HTTP
    /// client is constructed. If `options.client` is set, it is used as-is
    /// (its own timeout and user agent are kept). Otherwise a new
    /// `reqwest::Client` is built with `options.timeout` and `options.user_agent`.
    /// The client is wrapped with transient-error retry middleware (exponential
    /// backoff, at most `options.max_retries` retries). With the `tracing`
    /// feature, request tracing middleware is added as well.
    ///
    /// # Errors
    ///
    /// Returns an error if `options.base_url` cannot be parsed as a URL, or if
    /// the default `reqwest::Client` cannot be built.
    #[cfg_attr(feature = "tracing", tracing::instrument(skip(options)))]
    pub(crate) fn from_options(options: NvisyRtOptions) -> Result<Self> {
        #[cfg(feature = "tracing")]
        tracing::debug!(target: TRACING_TARGET_CLIENT, "creating client");

        // Parse up front: the "client created" log below needs `base_url`,
        // and an invalid URL should not cost an HTTP client build.
        let base_url = Url::parse(&options.base_url)?;

        let base_client = match options.client {
            Some(custom_client) => custom_client,
            None => reqwest::Client::builder()
                .timeout(options.timeout)
                .user_agent(&options.user_agent)
                .build()?,
        };

        let retry_policy =
            ExponentialBackoff::builder().build_with_max_retries(options.max_retries);
        let builder = ClientBuilder::new(base_client)
            .with(RetryTransientMiddleware::new_with_policy(retry_policy));
        #[cfg(feature = "tracing")]
        let builder = builder.with(reqwest_tracing::TracingMiddleware::default());
        let client = builder.build();

        #[cfg(feature = "tracing")]
        tracing::info!(
            target: TRACING_TARGET_CLIENT,
            base_url = %base_url,
            timeout_secs = options.timeout.as_secs(),
            "client created"
        );

        let inner = Arc::new(NvisyRtInner {
            actor_id: options.actor_id,
            base_url,
            timeout: options.timeout,
            client,
        });
        Ok(Self { inner })
    }

    /// Returns the actor identifier sent with every request.
    pub fn actor_id(&self) -> Uuid {
        self.inner.actor_id
    }

    /// Returns the base URL that request paths are resolved against.
    pub fn base_url(&self) -> &Url {
        &self.inner.base_url
    }

    /// Returns the timeout applied to each request.
    pub fn timeout(&self) -> Duration {
        self.inner.timeout
    }

    /// Resolves `path` against the base URL's path.
    ///
    /// Trailing slashes on the base path are dropped before joining. A `/` is
    /// inserted when `path` is non-empty and does not already start with one.
    /// For example, base `/v1/` with `/docs` or `docs` both yield `/v1/docs`.
    pub(crate) fn resolve_url(&self, path: &str) -> Url {
        let mut url = self.inner.base_url.clone();
        let base_path = url.path().trim_end_matches('/');
        let joined = if path.is_empty() || path.starts_with('/') {
            format!("{}{}", base_path, path)
        } else {
            format!("{}/{}", base_path, path)
        };
        url.set_path(&joined);
        url
    }

    /// Starts a request to `url` with the actor header and per-request timeout set.
    pub(crate) fn request(&self, method: Method, url: Url) -> RequestBuilder {
        #[cfg(feature = "tracing")]
        tracing::trace!(
            target: TRACING_TARGET_CLIENT,
            %url,
            %method,
            "building request"
        );
        self.inner
            .client
            .request(method, url)
            .header(
                ACTOR_ID_HEADER,
                // A UUID's string form is ASCII hex digits and hyphens, which
                // is always a valid header value.
                HeaderValue::from_str(&self.inner.actor_id.to_string())
                    .expect("UUID is valid header value"),
            )
            .timeout(self.inner.timeout)
    }

    /// Converts error responses into errors.
    ///
    /// Success (2xx) responses are returned unchanged. Other statuses that
    /// reqwest does not classify as errors (1xx/3xx) are also returned
    /// unchanged, so the caller can decide how to handle them.
    ///
    /// # Errors
    ///
    /// For a 4xx/5xx response, returns `Error::Api` if the body deserializes as
    /// `ApiError`; its `status` is overwritten with the HTTP status. Otherwise
    /// returns `Error::Reqwest` carrying the HTTP status error.
    pub(crate) async fn check_response(&self, response: Response) -> Result<Response> {
        if response.status().is_success() {
            return Ok(response);
        }

        // `error_for_status_ref` yields `Err` only for 4xx/5xx. For any other
        // non-success status, unwrapping it as an error would panic, so the
        // response is handed back instead.
        let reqwest_err = match response.error_for_status_ref().err() {
            Some(err) => err,
            None => return Ok(response),
        };

        let status = response.status().as_u16();
        #[cfg(feature = "tracing")]
        tracing::warn!(
            target: TRACING_TARGET_CLIENT,
            status,
            "api error response"
        );

        match response.json::<ApiError>().await {
            Ok(mut api_error) => {
                api_error.status = status;
                Err(Error::Api(api_error))
            }
            // The body is not a structured API error; report the HTTP status error.
            Err(_) => Err(Error::Reqwest(reqwest_err)),
        }
    }

    /// Sends a bodiless request to `path` (relative to the base URL) and checks
    /// the response status.
    ///
    /// # Errors
    ///
    /// Propagates transport errors, plus any error from the status check
    /// performed on the response.
    pub(crate) async fn send(&self, method: Method, path: &str) -> Result<Response> {
        #[cfg(feature = "tracing")]
        tracing::debug!(target: TRACING_TARGET_CLIENT, %method, path, "sending request");
        let url = self.resolve_url(path);
        let response = self.request(method, url).send().await?;
        #[cfg(feature = "tracing")]
        tracing::debug!(
            target: TRACING_TARGET_CLIENT,
            status = response.status().as_u16(),
            path,
            "response received"
        );
        self.check_response(response).await
    }

    /// Sends `data` serialized as a JSON body to `path` (relative to the base
    /// URL) and checks the response status.
    ///
    /// # Errors
    ///
    /// Propagates serialization and transport errors, plus any error from the
    /// status check performed on the response.
    pub(crate) async fn send_json<T: Serialize>(
        &self,
        method: Method,
        path: &str,
        data: &T,
    ) -> Result<Response> {
        #[cfg(feature = "tracing")]
        tracing::debug!(target: TRACING_TARGET_CLIENT, %method, path, "sending json request");
        let url = self.resolve_url(path);
        let response = self.request(method, url).json(data).send().await?;
        #[cfg(feature = "tracing")]
        tracing::debug!(
            target: TRACING_TARGET_CLIENT,
            status = response.status().as_u16(),
            path,
            "response received"
        );
        self.check_response(response).await
    }
}
impl Default for NvisyRt {
fn default() -> Self {
Self::new()
}
}
impl fmt::Debug for NvisyRt {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("NvisyRt")
.field("actor_id", &self.inner.actor_id)
.field("base_url", &self.inner.base_url)
.field("timeout", &self.inner.timeout)
.finish()
}
}
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use super::*;

    /// The default client points at the local development address.
    #[test]
    fn test_client_creation() -> Result<()> {
        let client = NvisyRt::new();
        let base_url = client.base_url().as_str();
        assert_eq!(base_url, "http://localhost:8080/");
        Ok(())
    }

    /// Builder overrides are reflected in the built client.
    #[test]
    fn test_client_creation_with_custom_config() -> Result<()> {
        let timeout = Duration::from_secs(60);
        let client = NvisyRt::builder()
            .with_base_url("https://custom.rt.api.com")
            .with_timeout(timeout)
            .build()?;

        let base_url = client.base_url().as_str();
        assert_eq!(base_url, "https://custom.rt.api.com/");
        assert_eq!(client.timeout(), timeout);
        Ok(())
    }
}