use async_trait::async_trait;
use mockall::automock;
use reqwest::Response;
/// Minimal asynchronous HTTP client abstraction.
///
/// Implementors must be `Send + Sync`. The `#[automock]` attribute generates a
/// `MockHttpClient` type that can stand in for a real client in tests.
#[automock]
#[async_trait]
pub trait HttpClient: Send + Sync {
/// Issues a POST request to `url` carrying `body`.
///
/// `headers` is a list of `(name, value)` pairs to attach to the request.
///
/// # Errors
///
/// Returns a [`reqwest::Error`] when the request fails.
async fn post(
&self,
url: String,
body: Vec<u8>,
headers: Vec<(String, String)>,
) -> Result<Response, reqwest::Error>;
/// Issues a GET request to `url`.
///
/// `headers` is a list of `(name, value)` pairs to attach to the request.
///
/// # Errors
///
/// Returns a [`reqwest::Error`] when the request fails.
async fn get(
&self,
url: String,
headers: Vec<(String, String)>,
) -> Result<Response, reqwest::Error>;
}
/// [`HttpClient`] implementation backed by a [`reqwest::Client`].
///
/// `Debug` is derived so the type can appear in logs and in `Debug` output of
/// containing types; `Clone` is derived so one instance can be handed to
/// several owners (both are provided by the wrapped `reqwest::Client`).
#[derive(Debug, Clone)]
pub struct ReqwestClient {
    /// Underlying client used to build and send every request.
    client: reqwest::Client,
}
impl ReqwestClient {
pub fn new() -> Self {
Self {
client: reqwest::Client::new(),
}
}
pub fn with_client(client: reqwest::Client) -> Self {
Self { client }
}
}
#[async_trait]
impl HttpClient for ReqwestClient {
    /// Sends a POST request to `url` with `body` as the payload.
    ///
    /// Every `(name, value)` pair in `headers` is applied to the request, in
    /// order, before the body is attached and the request is sent.
    ///
    /// # Errors
    ///
    /// Returns the [`reqwest::Error`] produced by sending the request.
    async fn post(
        &self,
        url: String,
        body: Vec<u8>,
        headers: Vec<(String, String)>,
    ) -> Result<Response, reqwest::Error> {
        headers
            .into_iter()
            .fold(self.client.post(url), |builder, (name, value)| {
                builder.header(name, value)
            })
            .body(body)
            .send()
            .await
    }

    /// Sends a GET request to `url`.
    ///
    /// Every `(name, value)` pair in `headers` is applied to the request, in
    /// order, before it is sent.
    ///
    /// # Errors
    ///
    /// Returns the [`reqwest::Error`] produced by sending the request.
    async fn get(
        &self,
        url: String,
        headers: Vec<(String, String)>,
    ) -> Result<Response, reqwest::Error> {
        headers
            .into_iter()
            .fold(self.client.get(url), |builder, (name, value)| {
                builder.header(name, value)
            })
            .send()
            .await
    }
}
/// Factory abstraction that hands out boxed [`HttpClient`] trait objects.
///
/// The `#[automock]` attribute generates a `MockHttpClientFactory` type for
/// use in tests.
#[automock]
pub trait HttpClientFactory {
/// Returns a boxed [`HttpClient`] for the caller to own.
fn create_client(&self) -> Box<dyn HttpClient>;
}
/// Stateless factory type for reqwest-backed HTTP clients.
///
/// The type carries no data, so it is trivially `Copy`; `Debug` is derived so
/// it can appear in logs and in `Debug` output of containing types.
#[derive(Debug, Clone, Copy)]
pub struct ReqwestClientFactory;
impl HttpClientFactory for ReqwestClientFactory {
    /// Builds a new [`ReqwestClient`] via `ReqwestClient::new()` and returns
    /// it boxed as an [`HttpClient`] trait object.
    ///
    /// Every call constructs a separate `ReqwestClient`; nothing is cached on
    /// the factory.
    fn create_client(&self) -> Box<dyn HttpClient> {
        let boxed: Box<dyn HttpClient> = Box::new(ReqwestClient::new());
        boxed
    }
}
impl Default for ReqwestClientFactory {
fn default() -> Self {
Self
}
}
impl Default for ReqwestClient {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use mockall::predicate;

    /// Builds a status-200 response carrying `body`, for mocks to hand back.
    fn ok_response(body: &'static str) -> reqwest::Response {
        let raw = http::Response::builder()
            .status(200)
            .body(body)
            .unwrap();
        reqwest::Response::from(raw)
    }

    /// A mocked client returns the canned response for the expected URL.
    #[tokio::test]
    async fn test_mock_http_client() {
        let mut client = MockHttpClient::new();
        client
            .expect_post()
            .with(
                predicate::eq("https://test.example.com".to_string()),
                predicate::always(),
                predicate::always(),
            )
            .times(1)
            .returning(|_, _, _| Ok(ok_response("Test Response")));

        let request_headers = vec![("Content-Type".to_string(), "text/plain".to_string())];
        let outcome = client
            .post(
                "https://test.example.com".to_string(),
                b"test body".to_vec(),
                request_headers,
            )
            .await;

        assert!(outcome.is_ok());
        let reply = outcome.unwrap();
        assert_eq!(reply.status(), 200);
    }

    /// A mocked factory hands out a mocked client whose `get` succeeds.
    #[tokio::test]
    async fn test_mock_http_client_factory() {
        let mut factory = MockHttpClientFactory::new();
        factory.expect_create_client().times(1).returning(|| {
            let mut inner = MockHttpClient::new();
            inner
                .expect_get()
                .returning(|_, _| Ok(ok_response("Factory Test")));
            Box::new(inner)
        });

        let client = factory.create_client();
        let outcome = client
            .get("https://example.com".to_string(), vec![])
            .await;
        assert!(outcome.is_ok());
    }
}