#![cfg(not(tarpaulin_include))]
use serde::de::DeserializeOwned;
use tracing::{event, Level};
use std::collections::HashMap;
use std::sync::{Arc, Mutex};
use crate::network::StatusCode;
use super::network_env::NetworkTrait;
use super::{
NetRequest, NetResponse, NetUrl, Network, NetworkError, RequestBody, RequestMethod,
ResponseType, SavedRequest,
};
/// Canned replies keyed by URL: the status code and raw body handed back for that URL.
type CannedResponses = HashMap<NetUrl, (StatusCode, String)>;
/// Canned failures keyed by URL: the message that is logged when that URL is requested.
type CannedErrors = HashMap<NetUrl, String>;

/// In-memory stand-in for the real network.
///
/// Every request is appended to a shared log (`requests`). Cloning a `MockNetwork` clones the
/// `Arc`, so all clones append to the same log. Responses and errors are looked up per
/// HTTP method and URL.
#[derive(Debug, Clone)]
pub struct MockNetwork {
    requests: Arc<Mutex<Vec<SavedRequest>>>,
    get_responses: CannedResponses,
    get_errors: CannedErrors,
    post_responses: CannedResponses,
    post_errors: CannedErrors,
    put_responses: CannedResponses,
    put_errors: CannedErrors,
    patch_responses: CannedResponses,
    patch_errors: CannedErrors,
    delete_responses: CannedResponses,
    delete_errors: CannedErrors,
    propfind_responses: CannedResponses,
    propfind_errors: CannedErrors,
}
impl MockNetwork {
pub fn new() -> Self {
Self {
requests: Arc::new(Mutex::new(Vec::new())),
get_responses: HashMap::new(),
get_errors: HashMap::new(),
post_responses: HashMap::new(),
post_errors: HashMap::new(),
put_responses: HashMap::new(),
put_errors: HashMap::new(),
patch_responses: HashMap::new(),
patch_errors: HashMap::new(),
delete_responses: HashMap::new(),
delete_errors: HashMap::new(),
propfind_responses: HashMap::new(),
propfind_errors: HashMap::new(),
}
}
pub fn requests(&self) -> Vec<SavedRequest> {
unsafe { self.requests.lock().unwrap_unchecked().clone() }
}
pub fn add_get_response(&mut self, url: &str, status: StatusCode, body: &str) {
self.get_responses
.insert(NetUrl::new(url.to_string()), (status, body.to_string()));
}
pub fn add_get_error(&mut self, url: &str, error: &str) {
self.get_errors
.insert(NetUrl::new(url.to_string()), error.to_string());
}
pub fn add_post_response(&mut self, url: &str, status: StatusCode, body: &str) {
self.post_responses
.insert(NetUrl::new(url.to_string()), (status, body.to_string()));
}
pub fn add_post_error(&mut self, url: &str, error: &str) {
self.post_errors
.insert(NetUrl::new(url.to_string()), error.to_string());
}
pub fn add_put_response(&mut self, url: &str, status: StatusCode, body: &str) {
self.put_responses
.insert(NetUrl::new(url.to_string()), (status, body.to_string()));
}
pub fn add_put_error(&mut self, url: &str, error: &str) {
self.put_errors
.insert(NetUrl::new(url.to_string()), error.to_string());
}
pub fn add_patch_response(&mut self, url: &str, status: StatusCode, body: &str) {
self.patch_responses
.insert(NetUrl::new(url.to_string()), (status, body.to_string()));
}
pub fn add_patch_error(&mut self, url: &str, error: &str) {
self.patch_errors
.insert(NetUrl::new(url.to_string()), error.to_string());
}
pub fn add_delete_response(&mut self, url: &str, status: StatusCode, body: &str) {
self.delete_responses
.insert(NetUrl::new(url.to_string()), (status, body.to_string()));
}
pub fn add_delete_error(&mut self, url: &str, error: &str) {
self.delete_errors
.insert(NetUrl::new(url.to_string()), error.to_string());
}
pub fn add_propfind_response(&mut self, url: &str, status: StatusCode, body: &str) {
self.propfind_responses
.insert(NetUrl::new(url.to_string()), (status, body.to_string()));
}
pub fn add_propfind_error(&mut self, url: &str, error: &str) {
self.propfind_errors
.insert(NetUrl::new(url.to_string()), error.to_string());
}
fn save_request(&self, method: RequestMethod, url: &str, body: RequestBody) {
unsafe {
self.requests
.lock()
.unwrap_unchecked()
.push(SavedRequest::new(method, url, body));
}
}
#[tracing::instrument(skip_all)]
fn call<Reply: DeserializeOwned>(
&self,
net_request: NetRequest,
) -> Result<NetResponse<Reply>, NetworkError> {
tracing::info!("MockNetworkEnv::call({:?})", net_request);
let method = net_request.method();
let url = net_request.url();
let body = net_request.body();
let response_type = net_request.response_type();
self.save_request(method, url, body.clone());
let errors = match method {
RequestMethod::Get => &self.get_errors,
RequestMethod::Post => &self.post_errors,
RequestMethod::Put => &self.put_errors,
RequestMethod::Patch => &self.patch_errors,
RequestMethod::Propfind => &self.propfind_errors,
RequestMethod::Delete => &self.delete_errors,
};
if let Some(error) = errors.get(url) {
event!(
Level::INFO,
"MockNetworkEnv::{}({}) -> error: {}",
method,
**url,
error
);
Err(NetworkError::RequestError(
method,
StatusCode::INTERNAL_SERVER_ERROR,
url.clone(),
))
} else {
let responses = match method {
RequestMethod::Get => &self.get_responses,
RequestMethod::Post => &self.post_responses,
RequestMethod::Put => &self.put_responses,
RequestMethod::Patch => &self.patch_responses,
RequestMethod::Propfind => &self.propfind_responses,
RequestMethod::Delete => &self.delete_responses,
};
let (status, response) = responses.get(url).ok_or_else(|| {
tracing::error!(?method, ?url, "unexpected request");
NetworkError::MockError(method, StatusCode::NOT_IMPLEMENTED, url.clone())
})?;
if status.is_client_error() || status.is_server_error() {
event!(
Level::INFO,
"MockNetworkEnv::{}({}) -> error: {}",
method,
url,
response
);
Err(NetworkError::RequestError(method, *status, url.clone()))
} else {
event!(
Level::INFO,
"MockNetworkEnv::{}({}) -> response: {}",
method,
url,
response
);
let response_body: Option<Reply> = match response_type {
ResponseType::Json => Some(serde_json::from_str(response)?),
ResponseType::Xml => Some(serde_xml_rs::from_str(response)?),
ResponseType::None => None,
ResponseType::Text => {
Err(NetworkError::UnexpectedMockRequest {
method,
request_url: url.clone(),
})?
}
};
Ok(NetResponse::new(method, url, *status, response_body))
}
}
}
fn call_string(&self, net_request: NetRequest) -> Result<NetResponse<String>, NetworkError> {
let method = net_request.method();
let url = net_request.url();
let body = net_request.body();
match body {
RequestBody::String(_) => Ok(()),
RequestBody::None => Ok(()),
_ => Err(NetworkError::InvalidRequestBody),
}?;
self.save_request(method, url, body.clone());
event!(Level::INFO, "MockNetworkEnv::{}({})", method, url);
self.check_for_error(method, url).map_or_else(
|| self.as_response(method, url, &net_request),
|error| as_server_error(method, url, error),
)
}
fn as_response(
&self,
method: RequestMethod,
url: &NetUrl,
net_request: &NetRequest,
) -> Result<NetResponse<String>, NetworkError> {
event!(Level::INFO, "url: {}", url);
let (status, response) = self.check_for_response(method, url).ok_or_else(|| {
NetworkError::MockError(method, StatusCode::NOT_IMPLEMENTED, url.clone())
})?;
event!(Level::INFO, "status: {:?}", status);
let response_body = match net_request.response_type() {
ResponseType::None => None,
_ => {
let response_body: String = response.to_string();
event!(Level::INFO, "response_body: {:?}", response_body);
Some(response_body)
}
};
Ok(NetResponse::new(method, url, *status, response_body))
}
fn check_for_response(
&self,
method: RequestMethod,
url: &NetUrl,
) -> Option<&(StatusCode, String)> {
(match method {
RequestMethod::Get => &self.get_responses,
RequestMethod::Post => &self.post_responses,
RequestMethod::Put => &self.put_responses,
RequestMethod::Patch => &self.patch_responses,
RequestMethod::Propfind => &self.propfind_responses,
RequestMethod::Delete => &self.delete_responses,
})
.get(url)
}
fn check_for_error(&self, method: RequestMethod, url: &NetUrl) -> Option<&String> {
(match method {
RequestMethod::Get => &self.get_errors,
RequestMethod::Post => &self.post_errors,
RequestMethod::Put => &self.put_errors,
RequestMethod::Patch => &self.patch_errors,
RequestMethod::Propfind => &self.propfind_errors,
RequestMethod::Delete => &self.delete_errors,
})
.get(url)
}
}
/// Logs a canned error for `method` on `url` and turns it into a failed request.
///
/// The `error` text is only logged. The caller always receives
/// [`NetworkError::RequestError`] with status `500 Internal Server Error`.
fn as_server_error(
    method: RequestMethod,
    url: &NetUrl,
    error: &str,
) -> Result<NetResponse<String>, NetworkError> {
    event!(
        Level::INFO,
        "MockNetworkEnv::{}({}) -> error: {}",
        method,
        url,
        error
    );
    Err(NetworkError::RequestError(
        method,
        StatusCode::INTERNAL_SERVER_ERROR,
        url.clone(),
    ))
}
impl Default for MockNetwork {
fn default() -> Self {
Self::new()
}
}
impl From<MockNetwork> for Network {
fn from(mock: MockNetwork) -> Self {
Self::Mock(mock)
}
}
#[async_trait::async_trait]
impl NetworkTrait for MockNetwork {
async fn get<Reply: DeserializeOwned>(
&self,
net_request: NetRequest,
) -> Result<NetResponse<Reply>, NetworkError> {
assert_eq!(
net_request.method(),
RequestMethod::Get,
"get method must be RequestMethod::Get"
);
self.call(net_request)
}
async fn get_string(
&self,
net_request: NetRequest,
) -> Result<NetResponse<String>, NetworkError> {
assert_eq!(
net_request.method(),
RequestMethod::Get,
"get_string method must be RequestMethod::Get"
);
assert_eq!(
net_request.response_type(),
ResponseType::Text,
"get_string response_type must be ResponseType::Text"
);
self.call_string(net_request)
}
async fn post_json<Reply: DeserializeOwned>(
&self,
net_request: NetRequest,
) -> Result<NetResponse<Reply>, NetworkError> {
assert_eq!(
net_request.method(),
RequestMethod::Post,
"post method must be RequestMethod::Post"
);
assert!(
matches!(net_request.body(), RequestBody::Json(_)),
"request body must be RequestBody::Json"
);
self.call(net_request)
}
async fn post_string(
&self,
net_request: NetRequest,
) -> Result<NetResponse<String>, NetworkError> {
assert_eq!(
net_request.method(),
RequestMethod::Post,
"post_string method must be RequestMethod::Post"
);
assert_eq!(
net_request.response_type(),
ResponseType::Text,
"post_string response_type must be ResponseType::Text"
);
self.call_string(net_request)
}
async fn put_json<Reply: DeserializeOwned>(
&self,
net_request: NetRequest,
) -> Result<NetResponse<Reply>, NetworkError> {
assert_eq!(
net_request.method(),
RequestMethod::Put,
"put method must be RequestMethod::Put"
);
assert!(
matches!(net_request.body(), RequestBody::Json(_)),
"request body must be RequestBody::Json"
);
self.call(net_request)
}
async fn put_string(
&self,
net_request: NetRequest,
) -> Result<NetResponse<String>, NetworkError> {
assert_eq!(
net_request.method(),
RequestMethod::Put,
"put_string method must be RequestMethod::Put"
);
assert_eq!(
net_request.response_type(),
ResponseType::Text,
"put_string response_type must be ResponseType::Text"
);
self.call_string(net_request)
}
async fn patch_json<Reply: DeserializeOwned + std::fmt::Debug>(
&self,
net_request: NetRequest,
) -> Result<NetResponse<Reply>, NetworkError> {
assert_eq!(
net_request.method(),
RequestMethod::Patch,
"patch method must be RequestMethod::Patch"
);
assert!(
matches!(net_request.body(), RequestBody::Json(_)),
"request body must be RequestBody::Json"
);
self.call(net_request)
}
#[tracing::instrument(skip(self))]
async fn delete(&self, net_request: NetRequest) -> Result<NetResponse<()>, NetworkError> {
assert_eq!(
net_request.method(),
RequestMethod::Delete,
"delete method must be RequestMethod::Delete"
);
assert_eq!(
net_request.response_type(),
ResponseType::None,
"delete response_type must be ResponseType::None"
);
self.call(net_request)
}
async fn propfind<Reply: DeserializeOwned + std::fmt::Debug>(
&self,
net_request: NetRequest,
) -> Result<NetResponse<Reply>, NetworkError> {
assert_eq!(
net_request.method(),
RequestMethod::Propfind,
"propfind method must be RequestMethod::Propfind"
);
assert_eq!(
net_request.response_type(),
ResponseType::Xml,
"propfind response_type must be ResponseType::Xml"
);
assert_eq!(
net_request.body(),
&RequestBody::None,
"delete body must be RequestBody::None"
);
self.call(net_request)
}
async fn propfind_string(
&self,
net_request: NetRequest,
) -> Result<NetResponse<String>, NetworkError> {
assert_eq!(
net_request.method(),
RequestMethod::Propfind,
"propfind_string method must be RequestMethod::Propfind"
);
assert_eq!(
net_request.response_type(),
ResponseType::Text,
"propfind_string response_type must be ResponseType::Text"
);
assert_eq!(
net_request.body(),
&RequestBody::None,
"delete body must be RequestBody::None"
);
self.call_string(net_request)
}
}
// Unit tests for `MockNetwork`. Each test drives one public entry point through
// `block_on` and then checks both the returned result and the recorded request log.
#[cfg(test)]
mod tests {
    use crate::network::{NetResponse, NetworkError, RequestMethod};

    use super::*;

    use reqwest::StatusCode;

    use pretty_assertions::assert_eq;
    use serde_json::json;
    use tokio_test::block_on;

    /// A canned JSON GET reply is decoded and the request is logged.
    #[test_log::test]
    fn test_mock_network_env_get() -> Result<(), NetworkError> {
        let mut net = MockNetwork::new();
        net.add_get_response("https://httpbin.org", StatusCode::OK, r#"{"foo": "bar"}"#);
        let net_request = NetRequest::get(NetUrl::new("https://httpbin.org".into())).build();
        let response: NetResponse<HashMap<String, String>> = block_on(net.get(net_request))?;
        assert_eq!(
            response
                .response_body()
                .and_then(|body| body.get("foo").cloned()),
            Some("bar".to_string())
        );
        assert_eq!(
            net.requests(),
            vec![SavedRequest::new(
                RequestMethod::Get,
                "https://httpbin.org",
                RequestBody::None
            )]
        );
        Ok(())
    }

    /// A registered GET error fails the typed call, but the request is still logged.
    #[test_log::test]
    fn test_mock_network_env_get_error() -> Result<(), NetworkError> {
        let mut net = MockNetwork::new();
        net.add_get_error("https://httpbin.org", "error");
        let net_request = NetRequest::get(NetUrl::new("https://httpbin.org".into())).build();
        let result: Result<NetResponse<HashMap<String, String>>, NetworkError> =
            block_on(net.get(net_request));
        assert!(result.is_err());
        assert_eq!(
            net.requests(),
            vec![SavedRequest::new(
                RequestMethod::Get,
                "https://httpbin.org",
                RequestBody::None
            )]
        );
        Ok(())
    }

    /// `get_string` returns the canned body verbatim.
    #[test_log::test]
    fn test_mock_network_env_get_string() -> Result<(), NetworkError> {
        let mut net = MockNetwork::new();
        net.add_get_response("https://httpbin.org", StatusCode::OK, r#"{"foo":"bar"}"#);
        let net_request = NetRequest::get(NetUrl::new("https://httpbin.org".into()))
            .response_type(ResponseType::Text)
            .build();
        let response: NetResponse<String> = block_on(net.get_string(net_request))?;
        assert_eq!(
            response.response_body(),
            Some(json!({"foo":"bar"}).to_string())
        );
        assert_eq!(
            net.requests(),
            vec![SavedRequest::new(
                RequestMethod::Get,
                "https://httpbin.org",
                RequestBody::None
            )]
        );
        Ok(())
    }

    /// A registered GET error fails `get_string`, but the request is still logged.
    #[test_log::test]
    fn test_mock_network_env_get_string_error() {
        let mut net = MockNetwork::new();
        net.add_get_error("https://httpbin.org", "error");
        let net_request = NetRequest::get(NetUrl::new("https://httpbin.org".into()))
            .response_type(ResponseType::Text)
            .build();
        let result: Result<NetResponse<String>, NetworkError> =
            block_on(net.get_string(net_request));
        assert!(result.is_err());
        assert_eq!(
            net.requests(),
            vec![SavedRequest::new(
                RequestMethod::Get,
                "https://httpbin.org",
                RequestBody::None
            )]
        );
    }

    /// A canned JSON POST reply is decoded and the JSON request body is logged.
    #[test_log::test]
    fn test_mock_network_env_post() -> Result<(), NetworkError> {
        let mut net = MockNetwork::new();
        net.add_post_response("https://httpbin.org", StatusCode::OK, r#"{"foo": "bar"}"#);
        let net_request = NetRequest::post(NetUrl::new("https://httpbin.org".into()))
            .json_body(json!({}))?
            .build();
        let response: NetResponse<HashMap<String, String>> = block_on(net.post_json(net_request))?;
        assert_eq!(
            response
                .response_body()
                .and_then(|body| body.get("foo").cloned()),
            Some("bar".to_string())
        );
        assert_eq!(
            net.requests(),
            vec![SavedRequest::new(
                RequestMethod::Post,
                "https://httpbin.org",
                RequestBody::Json(json!({}))
            )]
        );
        Ok(())
    }

    /// A registered POST error fails `post_json`, but the request is still logged.
    #[test_log::test]
    fn test_mock_network_env_post_error() -> Result<(), NetworkError> {
        let mut net = MockNetwork::new();
        net.add_post_error("https://httpbin.org", "error");
        let net_request = NetRequest::post(NetUrl::new("https://httpbin.org".into()))
            .json_body(json!({}))?
            .build();
        let result: Result<NetResponse<HashMap<String, String>>, NetworkError> =
            block_on(net.post_json(net_request));
        assert!(result.is_err());
        assert_eq!(
            net.requests(),
            vec![SavedRequest::new(
                RequestMethod::Post,
                "https://httpbin.org",
                RequestBody::Json(json!({}))
            )]
        );
        Ok(())
    }

    /// A canned JSON PUT reply is decoded and the JSON request body is logged.
    #[test_log::test]
    fn test_mock_network_env_put_json() -> Result<(), NetworkError> {
        let mut net = MockNetwork::new();
        net.add_put_response("https://httpbin.org", StatusCode::OK, r#"{"foo": "bar"}"#);
        let net_request = NetRequest::put(NetUrl::new("https://httpbin.org".into()))
            .json_body(json!({}))?
            .build();
        let response: NetResponse<HashMap<String, String>> = block_on(net.put_json(net_request))?;
        assert_eq!(
            response
                .response_body()
                .and_then(|body| body.get("foo").cloned()),
            Some("bar".to_string())
        );
        assert_eq!(
            net.requests(),
            vec![SavedRequest::new(
                RequestMethod::Put,
                "https://httpbin.org",
                RequestBody::Json(json!({}))
            )]
        );
        Ok(())
    }

    /// A registered PUT error fails `put_json`, but the request is still logged.
    #[test_log::test]
    fn test_mock_network_env_put_json_error() -> Result<(), NetworkError> {
        let mut net = MockNetwork::new();
        net.add_put_error("https://httpbin.org", "error");
        let net_request = NetRequest::put(NetUrl::new("https://httpbin.org".into()))
            .json_body(json!({}))?
            .build();
        let result: Result<NetResponse<HashMap<String, String>>, NetworkError> =
            block_on(net.put_json(net_request));
        assert!(result.is_err());
        assert_eq!(
            net.requests(),
            vec![SavedRequest::new(
                RequestMethod::Put,
                "https://httpbin.org",
                RequestBody::Json(json!({}))
            )]
        );
        Ok(())
    }

    /// `put_string` returns the canned body verbatim and logs the string request body.
    #[test_log::test]
    fn test_mock_network_env_put_string() -> Result<(), NetworkError> {
        let mut net = MockNetwork::new();
        net.add_put_response("https://httpbin.org", StatusCode::OK, r#"{"foo":"bar"}"#);
        let net_request = NetRequest::put(NetUrl::new("https://httpbin.org".into()))
            .string_body("PLAIN-TEXT".to_string())
            .build();
        let response: NetResponse<String> = block_on(net.put_string(net_request))?;
        assert_eq!(
            response.response_body(),
            Some(json!({"foo":"bar"}).to_string())
        );
        assert_eq!(
            net.requests(),
            vec![SavedRequest::new(
                RequestMethod::Put,
                "https://httpbin.org",
                RequestBody::String("PLAIN-TEXT".to_string()),
            )]
        );
        Ok(())
    }

    /// A registered PUT error fails `put_string`, but the request is still logged.
    #[test_log::test]
    fn test_mock_network_env_put_string_error() {
        let mut net = MockNetwork::new();
        net.add_put_error("https://httpbin.org", "error");
        let net_request = NetRequest::put(NetUrl::new("https://httpbin.org".into()))
            .string_body("PLAIN-TEXT".to_string())
            .build();
        let result: Result<NetResponse<String>, NetworkError> =
            block_on(net.put_string(net_request));
        assert!(result.is_err());
        assert_eq!(
            net.requests(),
            vec![SavedRequest::new(
                RequestMethod::Put,
                "https://httpbin.org",
                RequestBody::String("PLAIN-TEXT".to_string())
            )]
        );
    }

    /// Minimal XML payload shape used by the PROPFIND tests.
    #[derive(Debug, PartialEq, Eq, serde::Deserialize)]
    struct PropfindTestResponse {
        pub foo: String,
    }

    /// A canned XML PROPFIND reply is decoded into the requested type.
    #[test_log::test]
    fn test_mock_network_env_propfind() -> Result<(), NetworkError> {
        let mut net = MockNetwork::new();
        net.add_propfind_response(
            "https://caldav.org",
            StatusCode::OK,
            r#"<container><foo>bar</foo></container>"#,
        );
        let net_request = NetRequest::propfind(NetUrl::new("https://caldav.org".into()))
            .response_type(ResponseType::Xml)
            .build();
        let response: NetResponse<PropfindTestResponse> = block_on(net.propfind(net_request))?;
        assert_eq!(
            response.response_body().map(|body| body.foo),
            Some("bar".to_string())
        );
        assert_eq!(
            net.requests(),
            vec![SavedRequest::new(
                RequestMethod::Propfind,
                "https://caldav.org",
                RequestBody::None
            )]
        );
        Ok(())
    }

    /// `propfind_string` returns the canned XML verbatim.
    #[test_log::test]
    fn test_mock_network_env_propfind_string() -> Result<(), NetworkError> {
        let mut net = MockNetwork::new();
        net.add_propfind_response(
            "https://caldav.org",
            StatusCode::OK,
            r#"<container><foo>bar</foo></container>"#,
        );
        let net_request = NetRequest::propfind(NetUrl::new("https://caldav.org".into()))
            .response_type(ResponseType::Text)
            .build();
        let response: NetResponse<String> = block_on(net.propfind_string(net_request))?;
        assert_eq!(
            response.response_body(),
            Some("<container><foo>bar</foo></container>".to_string())
        );
        assert_eq!(
            net.requests(),
            vec![SavedRequest::new(
                RequestMethod::Propfind,
                "https://caldav.org",
                RequestBody::None
            )]
        );
        Ok(())
    }

    /// A registered PROPFIND error fails `propfind`, but the request is still logged.
    #[test_log::test]
    fn test_mock_network_env_propfind_error() -> Result<(), NetworkError> {
        let mut net = MockNetwork::new();
        net.add_propfind_error("https://caldav.org", "error");
        let net_request = NetRequest::propfind(NetUrl::new("https://caldav.org".into()))
            .response_type(ResponseType::Xml)
            .build();
        let result: Result<NetResponse<PropfindTestResponse>, NetworkError> =
            block_on(net.propfind(net_request));
        assert!(result.is_err());
        assert_eq!(
            net.requests(),
            vec![SavedRequest::new(
                RequestMethod::Propfind,
                "https://caldav.org",
                RequestBody::None
            )]
        );
        Ok(())
    }

    /// A registered PROPFIND error fails `propfind_string`, but the request is still logged.
    #[test_log::test]
    fn test_mock_network_env_propfind_string_error() -> Result<(), NetworkError> {
        let mut net = MockNetwork::new();
        net.add_propfind_error("https://caldav.org", "error");
        let net_request = NetRequest::propfind(NetUrl::new("https://caldav.org".into()))
            .response_type(ResponseType::Text)
            .build();
        let result: Result<NetResponse<String>, NetworkError> =
            block_on(net.propfind_string(net_request));
        assert!(result.is_err());
        assert_eq!(
            net.requests(),
            vec![SavedRequest::new(
                RequestMethod::Propfind,
                "https://caldav.org",
                RequestBody::None
            )]
        );
        Ok(())
    }

    /// A canned 4xx status makes a typed call fail, but the request is still logged.
    #[test_log::test]
    fn test_mock_network_env_get_error_status() {
        let mut net = MockNetwork::new();
        net.add_get_response("https://httpbin.org", StatusCode::NOT_FOUND, "not found");
        let net_request = NetRequest::get(NetUrl::new("https://httpbin.org".into())).build();
        let result: Result<NetResponse<HashMap<String, String>>, NetworkError> =
            block_on(net.get(net_request));
        assert!(result.is_err());
        assert_eq!(
            net.requests(),
            vec![SavedRequest::new(
                RequestMethod::Get,
                "https://httpbin.org",
                RequestBody::None
            )]
        );
    }

    /// A typed call to a URL with nothing registered is rejected, and the request is logged.
    #[test_log::test]
    fn test_mock_network_env_get_unregistered_url() {
        let net = MockNetwork::new();
        let net_request = NetRequest::get(NetUrl::new("https://httpbin.org".into())).build();
        let result: Result<NetResponse<HashMap<String, String>>, NetworkError> =
            block_on(net.get(net_request));
        assert!(result.is_err());
        assert_eq!(
            net.requests(),
            vec![SavedRequest::new(
                RequestMethod::Get,
                "https://httpbin.org",
                RequestBody::None
            )]
        );
    }

    /// A text call to a URL with nothing registered is rejected as well.
    #[test_log::test]
    fn test_mock_network_env_get_string_unregistered_url() {
        let net = MockNetwork::new();
        let net_request = NetRequest::get(NetUrl::new("https://httpbin.org".into()))
            .response_type(ResponseType::Text)
            .build();
        let result: Result<NetResponse<String>, NetworkError> =
            block_on(net.get_string(net_request));
        assert!(result.is_err());
    }

    /// When both an error and a response are registered for a URL, the error wins.
    #[test_log::test]
    fn test_mock_network_env_error_takes_precedence_over_response() {
        let mut net = MockNetwork::new();
        net.add_get_response("https://httpbin.org", StatusCode::OK, r#"{"foo": "bar"}"#);
        net.add_get_error("https://httpbin.org", "error");
        let net_request = NetRequest::get(NetUrl::new("https://httpbin.org".into())).build();
        let result: Result<NetResponse<HashMap<String, String>>, NetworkError> =
            block_on(net.get(net_request));
        assert!(result.is_err());
    }

    /// Canned replies are per method: a POST reply does not answer a GET to the same URL.
    #[test_log::test]
    fn test_mock_network_env_responses_are_per_method() {
        let mut net = MockNetwork::new();
        net.add_post_response("https://httpbin.org", StatusCode::OK, r#"{"foo": "bar"}"#);
        let net_request = NetRequest::get(NetUrl::new("https://httpbin.org".into())).build();
        let result: Result<NetResponse<HashMap<String, String>>, NetworkError> =
            block_on(net.get(net_request));
        assert!(result.is_err());
    }

    /// Clones share the request log, so requests made through a clone are visible here.
    #[test_log::test]
    fn test_mock_network_env_clones_share_request_log() -> Result<(), NetworkError> {
        let mut net = MockNetwork::new();
        net.add_get_response("https://httpbin.org", StatusCode::OK, r#"{"foo": "bar"}"#);
        let clone = net.clone();
        let net_request = NetRequest::get(NetUrl::new("https://httpbin.org".into())).build();
        let _response: NetResponse<HashMap<String, String>> = block_on(clone.get(net_request))?;
        assert_eq!(
            net.requests(),
            vec![SavedRequest::new(
                RequestMethod::Get,
                "https://httpbin.org",
                RequestBody::None
            )]
        );
        Ok(())
    }
}