use std::collections::HashMap;
use serde::{de::DeserializeOwned, Serialize};
use serde_json::Value;
use crate::{
ApiError, ApiResult, FormLike, JsonResponder, LogConfig, MockServer, RequestBuilder, RequestId,
RequestTraceIdInjector,
};
/// Caller-facing options controlling how one API request is logged and how
/// its response is post-processed.
///
/// Resolved into a `RequestConfig` by `build`, which also consults the
/// request's `LogConfig` and `RequestId` extensions.
#[derive(Debug, Default)]
pub struct RequestConfigurator {
    /// `log` target used for every debug line emitted for this request.
    log_target: &'static str,
    /// Explicit logging switch. `None` defers to the `LogConfig` extension's
    /// `enabled` flag, or to `false` when no `LogConfig` is attached.
    log_enabled: Option<bool>,
    /// When `true`, `send` copies the response headers into the JSON payload
    /// under the `__headers__` key before deserializing it.
    require_headers: bool,
}

impl RequestConfigurator {
    /// Creates a configurator from its three settings.
    pub fn new(log_target: &'static str, log_enabled: Option<bool>, require_headers: bool) -> Self {
        Self {
            log_target,
            log_enabled,
            require_headers,
        }
    }

    /// Returns `self` with `log_target` and `require_headers` replaced,
    /// keeping the existing `log_enabled` setting.
    pub fn merge(self, log_target: &'static str, require_headers: bool) -> Self {
        RequestConfigurator {
            log_target,
            require_headers,
            ..self
        }
    }

    /// Resolves the final per-request configuration.
    ///
    /// - `log_enabled`: the explicit setting if present, otherwise the
    ///   `LogConfig` extension's `enabled` flag, otherwise `false`.
    /// - `require_headers`: always the caller's setting.
    /// - `request_id`: when a `LogConfig` is attached, `Some` of the
    ///   `RequestId` extension's id (empty string if there is none);
    ///   `None` when no `LogConfig` is attached.
    fn build(self, req: &mut RequestBuilder) -> RequestConfig {
        let extensions = req.extensions();
        match extensions.get::<LogConfig>() {
            Some(log_config) => {
                let request_id = extensions
                    .get::<RequestId>()
                    .map(|id| id.request_id.clone())
                    .unwrap_or_default();
                RequestConfig {
                    log_target: self.log_target,
                    log_enabled: self.log_enabled.unwrap_or(log_config.enabled),
                    require_headers: self.require_headers,
                    request_id: Some(request_id),
                }
            }
            None => RequestConfig {
                log_target: self.log_target,
                log_enabled: self.log_enabled.unwrap_or_default(),
                // Header injection is independent of logging, so the caller's
                // choice applies even when no `LogConfig` is attached.
                require_headers: self.require_headers,
                request_id: None,
            },
        }
    }
}
/// Fully resolved settings for a single request, produced by
/// `RequestConfigurator::build` and consumed by `send`.
#[derive(Default, Debug)]
struct RequestConfig {
    /// `log` target for this request's debug lines.
    log_target: &'static str,
    /// Whether debug logging is on for this request.
    log_enabled: bool,
    /// Whether response headers are merged into the payload as `__headers__`.
    require_headers: bool,
    /// Identifier used to prefix log lines, if one was resolved.
    request_id: Option<String>,
}

impl RequestConfig {
    /// The request id used in log prefixes; the empty string when unset.
    pub fn request_id(&self) -> &str {
        match &self.request_id {
            Some(id) => id.as_str(),
            None => "",
        }
    }
}
/// Sends `req` as-is (no body is attached here) and deserializes the JSON
/// response into `O`.
///
/// A trace id is injected before the configuration is resolved; when logging
/// is enabled the outgoing request is written at debug level.
pub async fn _send<O>(req: RequestBuilder, config: RequestConfigurator) -> ApiResult<O>
where
    O: DeserializeOwned,
{
    let mut req = RequestTraceIdInjector::inject_to_builder(req);
    let resolved = config.build(&mut req);
    if resolved.log_enabled {
        log::debug!(
            target: resolved.log_target,
            "#[{}] Request => {:?}",
            resolved.request_id(),
            req
        );
    }
    send(req, resolved).await
}
/// Sends `json` as the request's JSON body and deserializes the JSON response
/// into `O`.
///
/// When logging is enabled, both the request and the serialized body are
/// written at debug level (an unserializable body is logged as empty).
pub async fn _send_json<I, O>(
    req: RequestBuilder,
    json: &I,
    config: RequestConfigurator,
) -> ApiResult<O>
where
    I: Serialize + ?Sized,
    O: DeserializeOwned,
{
    let mut req = RequestTraceIdInjector::inject_to_builder(req).json(json);
    let resolved = config.build(&mut req);
    if resolved.log_enabled {
        let id = resolved.request_id();
        log::debug!(target: resolved.log_target, "#[{}] Request => {:?}", id, req);
        log::debug!(
            target: resolved.log_target,
            "#[{}] Json: {}",
            id,
            serde_json::to_string(json).unwrap_or_default()
        );
    }
    send(req, resolved).await
}
/// Sends `form` as either a multipart body or a URL-encoded form, chosen by
/// `form.is_multipart()`, and deserializes the JSON response into `O`.
///
/// If the selected accessor yields `None`, the request is sent without a
/// body. When logging is enabled, the request and the form's metadata are
/// written at debug level.
pub async fn _send_form<I, O>(
    req: RequestBuilder,
    form: I,
    config: RequestConfigurator,
) -> ApiResult<O>
where
    I: FormLike,
    O: DeserializeOwned,
{
    let mut req = RequestTraceIdInjector::inject_to_builder(req);

    // Read the flag and metadata before the body accessors below.
    // NOTE(review): those accessors may take `form` by value — confirm on `FormLike`.
    let is_multipart = form.is_multipart();
    let meta = form.get_meta();

    req = if is_multipart {
        // NOTE(review): `None` here sends no body, whereas `_send_multipart`
        // rejects the same case with `ApiError::NotMultipartForm`.
        match form.get_multipart() {
            Some(multipart) => req.multipart(multipart),
            None => req,
        }
    } else {
        match form.get_form() {
            Some(fields) => req.form(&fields),
            None => req,
        }
    };

    let resolved = config.build(&mut req);
    if resolved.log_enabled {
        let kind = if is_multipart { "Multipart" } else { "Form" };
        log::debug!(
            target: resolved.log_target,
            "#[{}] Request => {:?}",
            resolved.request_id(),
            req
        );
        log::debug!(
            target: resolved.log_target,
            "#[{}] {}: {:?}",
            resolved.request_id(),
            kind,
            meta
        );
    }
    send(req, resolved).await
}
/// Sends `form` as a multipart body and deserializes the JSON response into `O`.
///
/// # Errors
/// Returns `ApiError::NotMultipartForm` when `form.get_multipart()` yields
/// `None`; otherwise any error from `send`.
pub async fn _send_multipart<I, O>(
    req: RequestBuilder,
    form: I,
    config: RequestConfigurator,
) -> ApiResult<O>
where
    I: FormLike,
    O: DeserializeOwned,
{
    let req = RequestTraceIdInjector::inject_to_builder(req);
    let multipart = match form.get_multipart() {
        Some(multipart) => multipart,
        None => return Err(ApiError::NotMultipartForm),
    };
    let mut req = req.multipart(multipart);

    let resolved = config.build(&mut req);
    if resolved.log_enabled {
        log::debug!(
            target: resolved.log_target,
            "#[{}] Request => {:?}",
            resolved.request_id(),
            req
        );
    }
    send(req, resolved).await
}
async fn send<O>(mut req: RequestBuilder, config: RequestConfig) -> ApiResult<O>
where
O: DeserializeOwned,
{
let extensions = req.extensions();
if let Some(mock) = extensions.get::<MockServer>().cloned() {
let req = req.build()?;
if config.log_enabled {
log::debug!(target: config.log_target, "#[{}] Response <= (MOCK)", config.request_id());
}
match mock.handle(req).await {
Ok(json) => {
if config.log_enabled {
log::debug!(target: config.log_target, "#[{}] Payload: {}", config.request_id(), serde_json::to_string(&json).unwrap_or_default());
}
return serde_json::from_value(json).map_err(|e| e.into());
}
Err(e) => {
if config.log_enabled {
log::debug!(target: config.log_target, "#[{}] Error: {}", config.request_id(), e);
}
return Err(ApiError::Middleware(e));
}
}
}
let res = req.send().await?;
if config.log_enabled {
log::debug!(target: config.log_target, "#[{}] Response <= {:?}", config.request_id(), res);
}
let res = match res.error_for_status() {
Ok(res) => res,
Err(e) => {
let e = e.into();
if config.log_enabled {
log::debug!(target: config.log_target, "#[{}] Error: {}", config.request_id(), e);
}
return Err(e);
}
};
let headers = if config.require_headers {
let mut headers = HashMap::new();
for (name, value) in res.headers() {
if let Ok(value) = value.to_str() {
headers.insert(name.to_string(), value.to_string());
}
}
Some(headers)
} else {
None
};
let mut json = match res.json::<Value>().await {
Ok(json) => {
if config.log_enabled {
log::debug!(target: config.log_target, "#[{}] Payload: {}", config.request_id(), serde_json::to_string(&json).unwrap_or_default());
}
json
}
Err(e) => {
let e = e.into();
if config.log_enabled {
log::debug!(target: config.log_target, "#[{}] Error: {}", config.request_id(), e);
}
return Err(e);
}
};
if let Some(headers) = headers {
if let Value::Object(m) = &mut json {
if let Ok(headers) = serde_json::to_value(headers) {
m.insert("__headers__".to_string(), headers);
}
}
}
match serde_json::from_value(json) {
Ok(r) => Ok(r),
Err(e) => {
let e = e.into();
if config.log_enabled {
log::debug!(target: config.log_target, "#[{}] Error: {}", config.request_id(), e);
}
Err(e)
}
}
}