use async_trait::async_trait;
use azure_core::{
http::{request::Request, AsyncRawResponse, HttpClient},
Result,
};
use futures::{future::BoxFuture, lock::Mutex};
use std::fmt;
/// An [`HttpClient`] test double that forwards every request to a
/// user-supplied closure.
///
/// The closure is stored behind an async [`Mutex`] so that its `FnMut`
/// callback can be invoked through the shared `&self` receiver that
/// [`HttpClient::execute_request`] provides.
pub struct MockHttpClient<C>(Mutex<C>);
impl<C> MockHttpClient<C>
where
C: FnMut(&Request) -> BoxFuture<'_, Result<AsyncRawResponse>> + Send + Sync,
{
pub fn new(client: C) -> Self {
Self(Mutex::new(client))
}
}
/// Formats the mock as its bare type name.
///
/// `C` is not required to implement `Debug`, so the wrapped closure is not
/// printed.
impl<C> fmt::Debug for MockHttpClient<C> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // The previous `stringify!("MockHttpClient")` stringified the string
        // *literal token*, producing `"MockHttpClient"` with embedded quote
        // characters. Emit the plain name instead.
        f.write_str("MockHttpClient")
    }
}
#[async_trait]
impl<C> HttpClient for MockHttpClient<C>
where
    C: FnMut(&Request) -> BoxFuture<'_, Result<AsyncRawResponse>> + Send + Sync,
{
    /// Hands `req` to the wrapped closure and awaits the future it returns.
    ///
    /// The lock guard lives until the end of this function, after the
    /// response future has resolved, so concurrent callers are served one at
    /// a time.
    async fn execute_request(&self, req: &Request) -> Result<AsyncRawResponse> {
        let mut callback = self.0.lock().await;
        let pending = callback(req);
        pending.await
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use futures::FutureExt as _;

    /// Checks that `MockHttpClient` routes each request through the wrapped
    /// closure. Both requests reach the callback, but only the one carrying
    /// the `x-count` header bumps the shared counter.
    #[tokio::test]
    async fn mock_http_client() {
        use azure_core::http::{
            headers::{HeaderName, Headers},
            Method, StatusCode,
        };
        use std::sync::{Arc, Mutex};

        const COUNT_HEADER: HeaderName = HeaderName::from_static("x-count");

        let hits = Arc::new(Mutex::new(0));
        let http = Arc::new(MockHttpClient::new(|req| {
            let hits = Arc::clone(&hits);
            async move {
                assert_eq!(req.url().host_str(), Some("localhost"));
                let has_count_header = req.headers().get_optional_str(&COUNT_HEADER).is_some();
                if has_count_header {
                    // The guard is a temporary dropped at the end of this
                    // statement; no `.await` follows while it is held.
                    *hits.lock().unwrap() += 1;
                }
                Ok(AsyncRawResponse::from_bytes(
                    StatusCode::Ok,
                    Headers::new(),
                    vec![],
                ))
            }
            .boxed()
        })) as Arc<dyn HttpClient>;

        let url = "https://localhost";

        // Without the header the counter must stay untouched.
        let plain = Request::new(url.parse().unwrap(), Method::Get);
        http.execute_request(&plain).await.unwrap();

        // With the header the counter is bumped exactly once.
        let mut tagged = Request::new(url.parse().unwrap(), Method::Get);
        tagged.insert_header(COUNT_HEADER, "true");
        http.execute_request(&tagged).await.unwrap();

        assert_eq!(*hits.lock().unwrap(), 1);
    }
}