use reqwest::RequestBuilder;
use serde::de::DeserializeOwned;
use tracing::{event, Level};
use super::{
network_env::NetworkTrait, NetAuth, NetRequest, NetRequestLogging, NetResponse, Network,
NetworkError, RequestMethod, ResponseType, StatusCode,
};
/// Extension for [`reqwest::RequestBuilder`] that applies optional HTTP Basic
/// authentication taken from a [`NetAuth`].
trait WithAuthentication {
    /// Adds Basic credentials when `auth` is `Some`; otherwise returns the
    /// builder unchanged.
    fn auth(self, auth: Option<NetAuth>) -> reqwest::RequestBuilder;
}
impl WithAuthentication for reqwest::RequestBuilder {
    fn auth(self, auth: Option<NetAuth>) -> reqwest::RequestBuilder {
        match auth {
            Some(auth) => self.basic_auth(
                auth.username().to_string(),
                Some(auth.password().expose_password()),
            ),
            None => self,
        }
    }
}
/// Extension for [`reqwest::RequestBuilder`] that sets the `Accept` header
/// matching the expected [`ResponseType`].
trait WithResponseType {
    /// Adds an `Accept` header for JSON, XML or plain-text responses; the
    /// builder is returned untouched for [`ResponseType::None`].
    fn response_type(self, response_type: ResponseType) -> Self;
}
impl WithResponseType for reqwest::RequestBuilder {
    fn response_type(self, response_type: ResponseType) -> Self {
        let accept = match response_type {
            ResponseType::Json => "application/json",
            ResponseType::Xml => "application/xml",
            ResponseType::Text => "text/plain",
            // No preference: leave the request without an Accept header.
            ResponseType::None => return self,
        };
        self.header(reqwest::header::ACCEPT, accept)
    }
}
/// Network backend that performs real HTTP requests through a
/// [`reqwest::Client`].
#[derive(Clone, Debug)]
pub struct RealNetwork {
    /// HTTP client used to build and send every request.
    client: reqwest::Client,
}
impl RealNetwork {
#[cfg(not(tarpaulin_include))]
pub fn new() -> Self {
Self {
client: reqwest::Client::new(),
}
}
async fn send<T: DeserializeOwned + std::fmt::Debug>(
&self,
request: RequestBuilder,
net_request: NetRequest,
) -> Result<NetResponse<T>, NetworkError> {
let response = request.send().await?;
let status_code = response.status();
let text = response.text().await?;
if status_code.is_success() {
tracing::trace!(status = ?status_code);
} else {
tracing::error!("Request failed");
self.log_request(&net_request);
self.log_response(&net_request, &status_code, text.as_str());
return Err(NetworkError::RequestFailed(
net_request.method(),
status_code,
net_request.url().clone(),
));
}
match net_request.response_type() {
ResponseType::Json => serde_json::from_str(text.as_str())
.map(Some)
.map_err(NetworkError::from),
ResponseType::Xml => serde_xml_rs::from_str(text.as_str())
.map(Some)
.map_err(NetworkError::from),
ResponseType::Text => {
panic!("text response type not implemented - use send_for_text(...) instead")
}
ResponseType::None => Ok(None),
}
.map(|response_body| {
match net_request.log() {
NetRequestLogging::None => (),
NetRequestLogging::Request => self.log_request(&net_request),
NetRequestLogging::Response => {
self.log_response(&net_request, &status_code, text.as_str());
}
NetRequestLogging::Both => {
self.log_request(&net_request);
self.log_response(&net_request, &status_code, text.as_str());
}
};
NetResponse::new(
net_request.method(),
net_request.url(),
status_code,
response_body,
)
})
}
async fn send_for_text(
&self,
request: RequestBuilder,
net_request: NetRequest,
) -> Result<NetResponse<String>, NetworkError> {
let response = request.send().await?;
let status_code = response.status();
if status_code.is_success() {
tracing::trace!(status = ?status_code);
} else {
tracing::error!(status = ?status_code, request = ?net_request);
}
match net_request.response_type() {
ResponseType::Text => {
let text = response.text().await?;
Ok(Some(text))
}
_ => panic!("text response type not implemented - use send(...) instead"),
}
.map(|response_body| {
NetResponse::new(
net_request.method(),
net_request.url(),
status_code,
response_body,
)
})
}
fn log_request(&self, net_request: &NetRequest) {
tracing::info!(?net_request, "RealNetworkEnv::request");
}
fn log_response(
&self,
net_request: &NetRequest,
status_code: &StatusCode,
response_body: &str,
) {
tracing::info!(
?net_request,
status = ?status_code,
?response_body,
"RealNetworkEnv::response"
);
}
}
impl Default for RealNetwork {
fn default() -> Self {
Self::new()
}
}
impl From<RealNetwork> for Network {
fn from(real: RealNetwork) -> Self {
Self::Real(real)
}
}
#[cfg(not(tarpaulin_include))]
#[async_trait::async_trait]
impl NetworkTrait for RealNetwork {
#[tracing::instrument(skip_all, fields(request = %net_request.as_trace()), level = "trace")]
async fn get<T: DeserializeOwned + std::fmt::Debug>(
&self,
net_request: NetRequest,
) -> Result<NetResponse<T>, NetworkError> {
tracing::debug!("RealNetworkEnv::get({:?})", net_request);
let url = net_request.url();
let auth = net_request.auth();
let response_type = net_request.response_type();
let headers = net_request.headers();
let request = self
.client
.get(url.to_string())
.auth(auth)
.response_type(response_type)
.headers(headers);
self.send(request, net_request).await
}
#[tracing::instrument(skip_all, fields(request = %net_request.as_trace()), level = "trace")]
async fn post_json<Reply: DeserializeOwned + std::fmt::Debug>(
&self,
net_request: NetRequest,
) -> Result<NetResponse<Reply>, NetworkError> {
let url = net_request.url();
let body = net_request.body();
let auth = net_request.auth();
let response_type = net_request.response_type();
let headers = net_request.headers();
let body = String::try_from(body)?;
let request = self
.client
.post(url.to_string())
.body(body)
.auth(auth)
.response_type(response_type)
.headers(headers);
self.send(request, net_request).await
}
#[tracing::instrument(skip_all, fields(request = %net_request.as_trace()), level = "trace")]
async fn post_string(
&self,
net_request: NetRequest,
) -> Result<NetResponse<String>, NetworkError> {
let url = net_request.url();
let body = net_request.body();
let auth = net_request.auth();
let response_type = net_request.response_type();
let headers = net_request.headers();
let body = String::try_from(body)?;
let request = self
.client
.post(url.to_string())
.body(body)
.auth(auth)
.response_type(response_type)
.headers(headers);
self.send_for_text(request, net_request).await
}
#[tracing::instrument(skip_all, fields(request = %net_request.as_trace()), level = "trace")]
async fn put_json<Reply: DeserializeOwned + std::fmt::Debug>(
&self,
net_request: NetRequest,
) -> Result<NetResponse<Reply>, NetworkError> {
let url = net_request.url();
let body = net_request.body();
let auth = net_request.auth();
let response_type = net_request.response_type();
let headers = net_request.headers();
let body = String::try_from(body)?;
let request = self
.client
.put(url.to_string())
.body(body)
.auth(auth)
.response_type(response_type)
.headers(headers);
self.send(request, net_request).await
}
#[tracing::instrument(skip_all, fields(request = %net_request.as_trace()), level = "trace")]
async fn put_string(
&self,
net_request: NetRequest,
) -> Result<NetResponse<String>, NetworkError> {
let url = net_request.url();
let body = net_request.body();
let auth = net_request.auth();
let response_type = net_request.response_type();
let headers = net_request.headers();
let body = String::try_from(body)?;
let request = self
.client
.put(url.to_string())
.body(body)
.auth(auth)
.response_type(response_type)
.headers(headers);
self.send_for_text(request, net_request).await
}
#[tracing::instrument(skip_all, fields(request = %net_request.as_trace()), level = "trace")]
async fn patch_json<Reply: DeserializeOwned + std::fmt::Debug>(
&self,
net_request: NetRequest,
) -> Result<NetResponse<Reply>, NetworkError> {
let url = net_request.url();
let body = net_request.body();
let auth = net_request.auth();
let response_type = net_request.response_type();
let headers = net_request.headers();
let body = String::try_from(body)?;
let request = self
.client
.patch(url.to_string())
.body(body)
.auth(auth)
.response_type(response_type)
.headers(headers);
self.send(request, net_request).await
}
#[tracing::instrument(skip_all, fields(request = %net_request.as_trace()), level = "trace")]
async fn delete(&self, net_request: NetRequest) -> Result<NetResponse<()>, NetworkError> {
let url = net_request.url();
let auth = net_request.auth();
let response_type = net_request.response_type();
let headers = net_request.headers();
let request = self
.client
.delete(url.to_string())
.auth(auth)
.response_type(response_type)
.headers(headers);
self.send(request, net_request).await
}
#[tracing::instrument(skip_all, fields(request = %net_request.as_trace()), level = "trace")]
async fn propfind<Reply: DeserializeOwned + std::fmt::Debug>(
&self,
net_request: NetRequest,
) -> Result<NetResponse<Reply>, NetworkError> {
let url = net_request.url();
let auth = net_request.auth();
let response_type = net_request.response_type();
let headers = net_request.headers();
let request = self
.client
.request(reqwest::Method::from_bytes(b"PROPFIND")?, url.to_string())
.auth(auth)
.response_type(response_type)
.headers(headers);
let response = request.send().await?;
let status_code = response.status();
event!(Level::TRACE, status = %status_code);
let body = match response_type {
ResponseType::Xml => serde_xml_rs::from_str(response.text().await?.as_str()),
_ => Err(NetworkError::InvalidResponseType)?,
};
body.map_err(Into::into).map(|response_body| {
NetResponse::new(RequestMethod::Propfind, url, status_code, response_body)
})
}
#[tracing::instrument(skip_all, fields(request = %net_request.as_trace()), level = "trace")]
async fn get_string(
&self,
net_request: NetRequest,
) -> Result<NetResponse<String>, NetworkError> {
let url = net_request.url();
let auth = net_request.auth();
let response_type = net_request.response_type();
let headers = net_request.headers();
let request = self
.client
.get(url.to_string())
.auth(auth)
.response_type(response_type)
.headers(headers);
self.send_for_text(request, net_request).await
}
#[tracing::instrument(skip_all, fields(request = %net_request.as_trace()), level = "trace")]
async fn propfind_string(
&self,
net_request: NetRequest,
) -> Result<NetResponse<String>, NetworkError> {
let url = net_request.url();
let auth = net_request.auth();
let response_type = net_request.response_type();
let headers = net_request.headers();
let request = self
.client
.request(reqwest::Method::from_bytes(b"PROPFIND")?, url.to_string())
.auth(auth)
.response_type(response_type)
.headers(headers);
self.send_for_text(request, net_request).await
}
}
#[cfg(test)]
mod tests {
    //! Builder-extension tests run offline. The `#[ignore]`d tests talk to
    //! httpbin.org and must be run explicitly.
    use super::*;
    use crate::network::{
        net_auth::{NetAuthPassword, NetAuthUsername},
        NetUrl, NetworkError, StatusCode,
    };
    use assert2::let_assert;
    use pretty_assertions::assert_eq;
    use std::collections::HashMap;
    use tokio_test::block_on;
    #[derive(Debug, PartialEq, Eq, serde::Deserialize)]
    pub struct GetResponse {
        pub args: HashMap<String, String>,
        pub url: String,
    }
    #[derive(Debug, PartialEq, Eq, serde::Deserialize)]
    pub struct PostResponse {
        pub args: HashMap<String, String>,
        pub url: String,
        pub data: String,
    }
    #[derive(Debug, PartialEq, Eq, serde::Deserialize)]
    pub struct PutResponse {
        pub args: HashMap<String, String>,
        pub url: String,
        pub data: String,
    }
    #[test_log::test]
    fn test_with_authentication_none() {
        let client = reqwest::Client::new();
        let request = client.get("https://httpbin.org/get").auth(None);
        let_assert!(Ok(build) = request.build());
        assert!(build.headers().is_empty());
    }
    #[test_log::test]
    fn test_with_authentication_some() {
        let client = reqwest::Client::new();
        let request = client
            .get("https://httpbin.org/get")
            .auth(Some(NetAuth::new(
                NetAuthUsername::new("user".into()),
                NetAuthPassword::new("pass".into()),
            )));
        let_assert!(Ok(build) = request.build());
        let headers = build.headers();
        let_assert!(Some(authorization) = headers.get(reqwest::header::AUTHORIZATION));
        assert_eq!(authorization, "Basic dXNlcjpwYXNz");
    }
    #[test_log::test]
    fn test_with_response_type_json() {
        let client = reqwest::Client::new();
        let request = client
            .get("https://httpbin.org/get")
            .response_type(ResponseType::Json);
        let_assert!(Ok(request) = request.build());
        let headers = request.headers();
        let_assert!(Some(accept) = headers.get(reqwest::header::ACCEPT));
        assert_eq!(accept, "application/json");
    }
    #[test_log::test]
    fn test_with_response_type_xml() {
        let client = reqwest::Client::new();
        let request = client
            .get("https://httpbin.org/get")
            .response_type(ResponseType::Xml);
        let_assert!(Ok(request) = request.build());
        let headers = request.headers();
        let_assert!(Some(accept) = headers.get(reqwest::header::ACCEPT));
        assert_eq!(accept, "application/xml");
    }
    #[test_log::test]
    fn test_with_response_type_text() {
        let client = reqwest::Client::new();
        let request = client
            .get("https://httpbin.org/get")
            .response_type(ResponseType::Text);
        let_assert!(Ok(request) = request.build());
        let headers = request.headers();
        let_assert!(Some(accept) = headers.get(reqwest::header::ACCEPT));
        assert_eq!(accept, "text/plain");
    }
    #[test_log::test]
    fn test_with_response_type_none() {
        let client = reqwest::Client::new();
        let request = client
            .get("https://httpbin.org/get")
            .response_type(ResponseType::None);
        let_assert!(Ok(request) = request.build());
        let headers = request.headers();
        assert!(headers.get(reqwest::header::ACCEPT).is_none());
    }
    #[test_log::test]
    #[ignore]
    fn test_real_network_env_get() -> Result<(), NetworkError> {
        let env = RealNetwork::new();
        let net_request =
            NetRequest::get(NetUrl::new("https://httpbin.org/get?arg=baz".into())).build();
        let response: NetResponse<GetResponse> = block_on(env.get(net_request))?;
        let_assert!(Some(body) = response.response_body());
        assert_eq!(
            body.args.get("arg"),
            Some(&"baz".to_string()),
            "args from body"
        );
        Ok(())
    }
    #[test_log::test]
    #[ignore]
    fn test_real_network_env_get_error() {
        let env = RealNetwork::new();
        let net_request =
            NetRequest::get(NetUrl::new("https://httpbin.org/status/400".into())).build();
        let result: Result<NetResponse<String>, NetworkError> = block_on(env.get(net_request));
        assert!(result.is_err(), "a 400 status must produce an error");
    }
    #[test_log::test]
    #[ignore]
    fn test_real_network_env_post_json() -> Result<(), NetworkError> {
        let env = RealNetwork::new();
        let body = serde_json::json!({"foo":"bar"});
        let net_request = NetRequest::post(NetUrl::new("https://httpbin.org/post?arg=baz".into()))
            .json_body(body)?
            .build();
        let response: NetResponse<PostResponse> = block_on(env.post_json(net_request))?;
        let_assert!(Some(body) = response.response_body());
        assert_eq!(
            body.args.get("arg"),
            Some(&"baz".to_string()),
            "args from body"
        );
        assert_eq!(body.data, "{\"foo\":\"bar\"}".to_string(), "data from body");
        Ok(())
    }
    #[test_log::test]
    #[ignore]
    fn test_real_network_env_post_json_error() -> Result<(), NetworkError> {
        let env = RealNetwork::new();
        let net_request =
            NetRequest::post(NetUrl::new("https://httpbin.org/status/400".into())).build();
        let response: Result<NetResponse<PostResponse>, NetworkError> =
            block_on(env.post_json(net_request));
        match response {
            Ok(_) => panic!("expected error"),
            // `RealNetwork::send` reports non-2xx replies as `RequestFailed`;
            // `MockError` is never produced by the real network.
            Err(NetworkError::RequestFailed(method, status, _url)) => {
                assert_eq!(method, RequestMethod::Post);
                assert_eq!(status, StatusCode::BAD_REQUEST);
            }
            Err(other) => panic!("unexpected error type: {:?}", other),
        }
        Ok(())
    }
    #[test_log::test]
    #[ignore]
    fn test_real_network_env_put() -> Result<(), NetworkError> {
        let env = RealNetwork::new();
        let body = serde_json::json!({"foo":"bar"});
        let net_request = NetRequest::put(NetUrl::new("https://httpbin.org/put?arg=baz".into()))
            .json_body(body)?
            .build();
        let response: NetResponse<PutResponse> = block_on(env.put_json(net_request))?;
        let_assert!(Some(body) = response.response_body());
        assert_eq!(
            body.args.get("arg"),
            Some(&"baz".to_string()),
            "args from body"
        );
        assert_eq!(body.data, "{\"foo\":\"bar\"}".to_string(), "data from body");
        Ok(())
    }
    #[test_log::test]
    #[ignore]
    fn test_real_network_env_put_error() {
        let env = RealNetwork::new();
        let net_request =
            NetRequest::put(NetUrl::new("https://httpbin.org/status/400".into())).build();
        let result: Result<NetResponse<String>, NetworkError> = block_on(env.put_json(net_request));
        assert!(result.is_err(), "a 400 status must produce an error");
    }
    #[test_log::test]
    #[ignore]
    fn test_real_network_env_delete() {
        let env = RealNetwork::new();
        let net_request =
            NetRequest::delete(NetUrl::new("https://httpbin.org/delete?arg=baz".into())).build();
        let result = block_on(env.delete(net_request));
        assert!(result.is_ok());
    }
}