use bytes::Bytes;
use core::{
error::Error,
fmt::{Debug, Display, Formatter, self},
hash::BuildHasher,
pin::Pin,
};
use futures_util::stream::{Stream, self};
#[allow(clippy::useless_attribute, reason = "Not useless! Here for the false positive")]
#[allow(clippy::allow_attributes, reason = "False positive lint")]
#[allow(unused_imports, reason = "False positive due to mocks")]
use mockall::{Sequence, concretize, mock};
use reqwest::{
Body,
IntoUrl,
StatusCode,
Url,
header::{HeaderMap, HeaderName, CONTENT_LENGTH, CONTENT_TYPE},
};
use serde::{Serialize, de::DeserializeOwned};
use serde_json::from_slice as from_json_slice;
use std::{
collections::HashMap,
sync::Arc,
};
// Mock of `reqwest::Client`; mockall generates it as `MockClient`.
// Every HTTP-verb method returns a `MockRequestBuilder`. Tests register the calls they
// expect through the generated `expect_*` methods (`create_mock_client` does this for `get`).
mock! {
pub Client {
// `#[concretize]` is mockall's mechanism for mocking methods with generic arguments, so
// these `IntoUrl`-generic verbs can be expected without naming a concrete URL type.
// NOTE(review): confirm the exact argument type mockall hands to `withf` closures.
#[concretize]
pub fn delete<U: IntoUrl>(&self, url: U) -> MockRequestBuilder;
#[concretize]
pub fn get<U: IntoUrl>(&self, url: U) -> MockRequestBuilder;
#[concretize]
pub fn patch<U: IntoUrl>(&self, url: U) -> MockRequestBuilder;
#[concretize]
pub fn post<U: IntoUrl>(&self, url: U) -> MockRequestBuilder;
#[concretize]
pub fn put<U: IntoUrl>(&self, url: U) -> MockRequestBuilder;
}
// Trait impls so the mock can stand in where a cloneable, debuggable client is required.
// NOTE(review): confirm whether mockall actually uses the bodies written here, or still
// generates `expect_clone` / `expect_fmt` expectations that a test must configure.
impl Clone for Client {
fn clone(&self) -> Self {
Self {}
}
}
impl Debug for Client {
// NOTE(review): `fmt` has no generic parameters, so `#[concretize]` looks redundant — confirm.
#[concretize]
fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
write!(f, "Mocked Reqwest client")
}
}
}
// Mock of `reqwest::RequestBuilder`; mockall generates it as `MockRequestBuilder`.
// Only `send` is mocked here. The chaining methods (`body`, `form`, `headers`, `json`)
// are hand-written no-ops in the `impl MockRequestBuilder` block.
// NOTE(review): the real Reqwest `send` consumes the builder, whereas this mock borrows it.
mock! {
pub RequestBuilder {
pub async fn send(&self) -> Result<MockResponse, MockError>;
}
}
/// No-op chaining methods that mirror `reqwest::RequestBuilder`.
///
/// Each method ignores its argument and returns the builder unchanged. Code written
/// against Reqwest's builder API therefore compiles against the mock, but the mock
/// never records what was passed.
impl MockRequestBuilder {
/// Accepts a request body and discards it, returning the builder unchanged.
#[expect(unused_mut, reason = "Needed for compatibility with the real Reqwest")]
#[expect(clippy::needless_pass_by_value, reason = "Needed for compatibility with the real Reqwest")]
#[must_use]
pub fn body<T: Into<Body>>(mut self, _body: T) -> Self {
self
}
/// Accepts a form payload (never serialized) and returns the builder unchanged.
#[expect(unused_mut, reason = "Needed for compatibility with the real Reqwest")]
#[must_use]
pub const fn form<T: Serialize + ?Sized>(mut self, _form: &T) -> Self {
self
}
/// Accepts a header map and discards it, returning the builder unchanged.
#[expect(unused_mut, reason = "Needed for compatibility with the real Reqwest")]
#[expect(clippy::needless_pass_by_value, reason = "Needed for compatibility with the real Reqwest")]
#[must_use]
pub fn headers(mut self, _headers: HeaderMap) -> Self {
self
}
/// Accepts a JSON payload (never serialized) and returns the builder unchanged.
#[expect(unused_mut, reason = "Needed for compatibility with the real Reqwest")]
#[must_use]
pub const fn json<T: Serialize + ?Sized>(mut self, _json: &T) -> Self {
self
}
}
/// Stand-in for `reqwest::Error` whose properties are all set directly through public fields.
///
/// `Default` yields an error with every flag `false` and no status or URL. Set only the
/// fields of interest with struct-update syntax (`MockError { is_status: true, ..Default::default() }`).
#[expect(clippy::struct_excessive_bools, reason = "Acceptable here")]
#[derive(Clone, Debug, Default, Eq, PartialEq)]
#[non_exhaustive]
pub struct MockError {
/// Value returned by `is_body()`.
pub is_body: bool,
/// Value returned by `is_builder()`.
pub is_builder: bool,
/// Value returned by `is_connect()`.
pub is_connect: bool,
/// Value returned by `is_decode()`.
pub is_decode: bool,
/// Value returned by `is_redirect()`.
pub is_redirect: bool,
/// Value returned by `is_request()`.
pub is_request: bool,
/// Value returned by `is_status()`.
pub is_status: bool,
/// Value returned by `is_timeout()`.
pub is_timeout: bool,
/// HTTP status attached to the error, returned by `status()`.
pub status: Option<StatusCode>,
/// URL attached to the error, returned by `url()`.
pub url: Option<Url>,
}
impl MockError {
    /// Reports whether the error was flagged as concerning a request or response body.
    #[must_use]
    pub const fn is_body(&self) -> bool {
        self.is_body
    }

    /// Reports whether the error was flagged as a request-builder error.
    #[must_use]
    pub const fn is_builder(&self) -> bool {
        self.is_builder
    }

    /// Reports whether the error was flagged as a connection error.
    #[must_use]
    pub const fn is_connect(&self) -> bool {
        self.is_connect
    }

    /// Reports whether the error was flagged as a decoding error.
    #[must_use]
    pub const fn is_decode(&self) -> bool {
        self.is_decode
    }

    /// Reports whether the error was flagged as a redirect error.
    #[must_use]
    pub const fn is_redirect(&self) -> bool {
        self.is_redirect
    }

    /// Reports whether the error was flagged as a request error.
    #[must_use]
    pub const fn is_request(&self) -> bool {
        self.is_request
    }

    /// Reports whether the error was flagged as coming from an error HTTP status.
    #[must_use]
    pub const fn is_status(&self) -> bool {
        self.is_status
    }

    /// Reports whether the error was flagged as a timeout.
    #[must_use]
    pub const fn is_timeout(&self) -> bool {
        self.is_timeout
    }

    /// Returns the HTTP status attached to the error, if one was set.
    #[must_use]
    pub const fn status(&self) -> Option<StatusCode> {
        self.status
    }

    /// Borrows the URL attached to the error, if one was set.
    #[must_use]
    pub const fn url(&self) -> Option<&Url> {
        self.url.as_ref()
    }

    /// Mutably borrows the URL attached to the error, if one was set.
    pub fn url_mut(&mut self) -> Option<&mut Url> {
        self.url.as_mut()
    }

    /// Consumes the error and returns it with `url` attached (replacing any previous URL).
    #[must_use]
    pub fn with_url(self, url: Url) -> Self {
        Self { url: Some(url), ..self }
    }

    /// Consumes the error and returns it with its URL cleared.
    #[must_use]
    pub fn without_url(self) -> Self {
        Self { url: None, ..self }
    }
}
impl Display for MockError {
    /// Writes a fixed description; neither the flags nor the status/URL are included.
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        f.write_str("Mocked Reqwest error")
    }
}

/// Marks `MockError` as an error type; it keeps the trait's default (absent) `source`.
impl Error for MockError {}
/// Canned stand-in for `reqwest::Response`, yielded by `MockRequestBuilder::send`.
///
/// `create_mock_response` builds one from loose parts.
#[derive(Clone, Debug, Eq, PartialEq)]
#[non_exhaustive]
pub struct MockResponse {
/// URL returned by `url()` and attached to errors from `error_for_status*`.
pub url: Url,
/// HTTP status returned by `status()` and checked by `error_for_status*`.
pub status: StatusCode,
/// Headers returned by `headers()`.
pub headers: HeaderMap,
/// Body bytes, or the error returned by every body accessor (`bytes`, `bytes_stream`, `json`, `text`).
pub body: Result<Arc<Bytes>, MockError>,
}
#[expect(clippy::unused_async, reason = "Needed for compatibility with the real Reqwest")]
impl MockResponse {
    /// Returns the full response body.
    ///
    /// # Errors
    ///
    /// Returns a clone of the configured body error when `body` is `Err`.
    pub async fn bytes(&self) -> Result<Bytes, MockError> {
        self.body.clone().map(|bytes| (*bytes).clone())
    }

    /// Returns the body as a stream that yields exactly one item: the same result
    /// [`bytes`](Self::bytes) would produce.
    #[must_use]
    pub fn bytes_stream(&self) -> Pin<Box<dyn Stream<Item = Result<Bytes, MockError>> + Send>> {
        let body = self.body.clone();
        Box::pin(stream::once(async move { body.map(|bytes| (*bytes).clone()) }))
    }

    /// Consumes the response, turning a client (4xx) or server (5xx) status into an error.
    ///
    /// # Errors
    ///
    /// Same as [`error_for_status_ref`](Self::error_for_status_ref).
    pub fn error_for_status(self) -> Result<Self, MockError> {
        // Delegate so the owned and borrowed variants cannot drift apart.
        self.error_for_status_ref()?;
        Ok(self)
    }

    /// Borrowing variant of [`error_for_status`](Self::error_for_status).
    ///
    /// # Errors
    ///
    /// When the status is a client (4xx) or server (5xx) error, returns a [`MockError`]
    /// with `is_status` set and this response's status and URL attached.
    pub fn error_for_status_ref(&self) -> Result<&Self, MockError> {
        let status = self.status();
        if status.is_client_error() || status.is_server_error() {
            Err(MockError {
                is_status: true,
                status: Some(status),
                url: Some(self.url.clone()),
                ..Default::default()
            })
        } else {
            Ok(self)
        }
    }

    /// Returns the response headers.
    #[must_use]
    pub const fn headers(&self) -> &HeaderMap {
        &self.headers
    }

    /// Deserializes the body as JSON.
    ///
    /// # Errors
    ///
    /// Returns the configured body error when `body` is `Err`. When the body is not valid
    /// JSON for `T`, returns a [`MockError`] with `is_decode` set. This mirrors Reqwest,
    /// which reports a decode error rather than panicking.
    pub async fn json<T: DeserializeOwned>(&self) -> Result<T, MockError> {
        let bytes = self.bytes().await?;
        from_json_slice(&bytes).map_err(|_decode_error| MockError {
            is_decode: true,
            ..Default::default()
        })
    }

    /// Returns the response status code.
    #[must_use]
    pub const fn status(&self) -> StatusCode {
        self.status
    }

    /// Returns the body decoded as UTF-8 text.
    ///
    /// Malformed UTF-8 sequences are replaced with U+FFFD instead of failing, matching
    /// Reqwest's `text()`. The `Content-Type` charset is not consulted.
    ///
    /// # Errors
    ///
    /// Returns the configured body error when `body` is `Err`.
    pub async fn text(&self) -> Result<String, MockError> {
        self.bytes().await.map(|bytes| String::from_utf8_lossy(&bytes).into_owned())
    }

    /// Returns the response URL.
    #[must_use]
    pub const fn url(&self) -> &Url {
        &self.url
    }
}
#[must_use]
pub fn create_mock_client<U: IntoUrl>(responses: Vec<(U, Result<MockResponse, MockError>)>) -> MockClient {
let mut mock_client = MockClient::new();
let mut sequence = Sequence::new();
for (mock_url, mock_response) in responses {
let expected_url: Url = mock_url.into_url().unwrap();
_ = mock_client.expect_get()
.withf(move |url| url.as_str() == expected_url.as_str())
.times(1)
.in_sequence(&mut sequence)
.returning(move |_| {
let mut mock_request = MockRequestBuilder::new();
let mock_response_clone = mock_response.clone();
_ = mock_request.expect_send()
.times(1)
.returning(move || mock_response_clone.clone())
;
mock_request
})
;
}
mock_client
}
/// Builds a [`MockResponse`] from loose parts, e.g. for use with [`create_mock_client`].
///
/// * `url` — the URL reported by [`MockResponse::url`].
/// * `status` — the HTTP status code.
/// * `content_type` — when `Some`, stored as the `Content-Type` header.
/// * `content_len` — when `Some`, stored as the `Content-Length` header. It is not checked
///   against the length of `body`.
/// * `extra_headers` — further headers added afterwards with `HeaderMap::extend`. Extend
///   appends, so a name already present (including `Content-Type`/`Content-Length`) gains
///   another value rather than being replaced.
/// * `body` — the body bytes (copied into the response), or the error the response's body
///   accessors should return.
///
/// # Panics
///
/// Panics if `url` cannot be converted into a [`Url`], if `content_type` is not a valid
/// header value, or if any key/value in `extra_headers` is not a valid header name/value.
#[must_use]
pub fn create_mock_response<U, S1, S2, S3, H: BuildHasher>(
    url: U,
    status: StatusCode,
    content_type: Option<S1>,
    content_len: Option<usize>,
    extra_headers: HashMap<S2, S3, H>,
    body: Result<&[u8], MockError>,
) -> MockResponse
where
    U: IntoUrl,
    S1: Into<String>,
    S2: Into<String>,
    S3: Into<String>,
{
    // Convert the URL first so an invalid URL is reported before any header problem.
    let url = url.into_url().unwrap();
    let mut headers = HeaderMap::new();
    if let Some(ct) = content_type {
        drop(headers.insert(CONTENT_TYPE, ct.into().parse().unwrap()));
    }
    if let Some(cl) = content_len {
        // `HeaderValue: From<usize>` is infallible, so there is no format-then-parse round trip.
        drop(headers.insert(CONTENT_LENGTH, cl.into()));
    }
    headers.extend(extra_headers.into_iter().map(|(name, value)| {
        (name.into().parse::<HeaderName>().unwrap(), value.into().parse().unwrap())
    }));
    MockResponse {
        url,
        status,
        headers,
        body: body.map(|bytes| Arc::new(Bytes::copy_from_slice(bytes))),
    }
}