use std::{
fmt::Debug,
marker::PhantomData,
sync::{
atomic::{AtomicUsize, Ordering},
Arc,
},
task::{Context, Poll},
time::Duration,
};
use futures::{future::BoxFuture, FutureExt};
use proptest::prelude::*;
use tokio::{
sync::{
broadcast::{self, error::RecvError},
oneshot, Mutex,
},
time::timeout,
};
use tower::{BoxError, Service};
/// Capacity of the broadcast channel that carries intercepted requests to the
/// [`MockService`] handles, used when the builder does not set one.
const DEFAULT_PROXY_CHANNEL_SIZE: usize = 100;

/// How long the request-expecting methods wait for a request before giving up,
/// used when the builder does not set a delay.
pub const DEFAULT_MAX_REQUEST_DELAY: Duration = Duration::from_millis(300);
/// A request in transit from the service to whichever [`MockService`] handle claims it.
///
/// Every handle receives the same shared slot through the broadcast channel. The first
/// handle to `take()` the inner value owns the [`ResponseSender`]; later handles find
/// `None` and move on.
type ProxyItem<Request, Response, Error> =
    Arc<Mutex<Option<ResponseSender<Request, Response, Error>>>>;
/// A mock [`Service`] that hands every request it receives to the test code, which
/// inspects the request and chooses the response.
///
/// The `Assertion` parameter selects how failed expectations are reported:
/// [`PanicAssertion`] panics, while [`PropTestAssertion`] returns a [`TestCaseError`].
pub struct MockService<Request, Response, Assertion, Error = BoxError> {
    /// This handle's end of the proxy channel, read when waiting for requests.
    receiver: broadcast::Receiver<ProxyItem<Request, Response, Error>>,

    /// Publishing end of the proxy channel: `call` sends requests through it, and
    /// `clone` subscribes new receivers from it.
    sender: broadcast::Sender<ProxyItem<Request, Response, Error>>,

    /// Number of `poll_ready` calls, shared by every clone of this service.
    poll_count: Arc<AtomicUsize>,

    /// How long to wait for a request before treating it as missing.
    max_request_delay: Duration,

    /// Type-level marker for the assertion strategy; no value is stored.
    _assertion_type: PhantomData<Assertion>,
}
/// Configures and creates a [`MockService`].
///
/// Any option left unset falls back to its default when the service is built.
#[derive(Default)]
pub struct MockServiceBuilder {
    /// Overrides [`DEFAULT_PROXY_CHANNEL_SIZE`] when `Some`.
    proxy_channel_size: Option<usize>,

    /// Overrides [`DEFAULT_MAX_REQUEST_DELAY`] when `Some`.
    max_request_delay: Option<Duration>,
}
/// Handle for answering a single intercepted request.
///
/// It holds the request together with the one-shot channel that the caller's response
/// future is waiting on. Dropping it without responding makes that future panic.
#[must_use = "Tests may fail if a response is not sent back to the caller"]
pub struct ResponseSender<Request, Response, Error> {
    /// The request that reached the mock service.
    request: Request,

    /// One-shot channel leading back to the caller's response future.
    response_sender: oneshot::Sender<Result<Response, Error>>,
}
impl<Request, Response, Assertion, Error> Service<Request>
    for MockService<Request, Response, Assertion, Error>
where
    Request: Send + 'static,
    Response: Send + 'static,
    Error: Send + 'static,
{
    type Response = Response;
    type Error = Error;
    type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;

    /// Always ready. Each poll is counted so tests can check readiness was requested.
    fn poll_ready(&mut self, _context: &mut Context) -> Poll<Result<(), Self::Error>> {
        self.poll_count.fetch_add(1, Ordering::SeqCst);
        Poll::Ready(Ok(()))
    }

    /// Publishes `request` to every handle of this mock service.
    ///
    /// Returns a future that resolves to whatever the test sends back through the
    /// matching [`ResponseSender`].
    ///
    /// # Panics
    ///
    /// The returned future panics if the [`ResponseSender`] is dropped without sending
    /// a response.
    fn call(&mut self, request: Request) -> Self::Future {
        let (responder, pending_response) = ResponseSender::new(request);
        let shared_slot = Arc::new(Mutex::new(Some(responder)));

        // `self.receiver` keeps the channel subscribed, so there is always at least one
        // receiver; the send result carries no useful information here.
        let _ = self.sender.send(shared_slot);

        async move {
            pending_response
                .await
                .expect("A response was not sent by the `MockService` for a request")
        }
        .boxed()
    }
}
impl MockService<(), (), ()> {
pub fn build() -> MockServiceBuilder {
MockServiceBuilder::default()
}
}
impl MockServiceBuilder {
    /// Sets how many requests the internal broadcast channel can buffer.
    ///
    /// A handle that falls further behind than this skips the oldest buffered requests.
    /// The size must be non-zero: building the service panics otherwise, because tokio's
    /// broadcast channel rejects a zero capacity.
    // Builder methods consume `self`; discarding the result silently loses the setting.
    #[must_use]
    pub fn with_proxy_channel_size(mut self, size: usize) -> Self {
        self.proxy_channel_size = Some(size);
        self
    }

    /// Sets how long the request-expecting methods wait before reporting that no
    /// request arrived.
    #[must_use]
    pub fn with_max_request_delay(mut self, max_request_delay: Duration) -> Self {
        self.max_request_delay = Some(max_request_delay);
        self
    }

    /// Builds a [`MockService`] whose failed expectations return a proptest
    /// [`TestCaseError`] instead of panicking.
    ///
    /// # Panics
    ///
    /// If the configured proxy channel size is zero.
    #[must_use]
    pub fn for_prop_tests<Request, Response, Error>(
        self,
    ) -> MockService<Request, Response, PropTestAssertion, Error> {
        self.finish()
    }

    /// Builds a [`MockService`] whose failed expectations panic.
    ///
    /// # Panics
    ///
    /// If the configured proxy channel size is zero.
    #[must_use]
    pub fn for_unit_tests<Request, Response, Error>(
        self,
    ) -> MockService<Request, Response, PanicAssertion, Error> {
        self.finish()
    }

    /// Builds a [`MockService`] with an explicitly chosen `Assertion` type.
    ///
    /// Unset options fall back to [`DEFAULT_PROXY_CHANNEL_SIZE`] and
    /// [`DEFAULT_MAX_REQUEST_DELAY`].
    ///
    /// # Panics
    ///
    /// If the proxy channel size is zero, because tokio's `broadcast::channel` requires
    /// a non-zero capacity.
    #[must_use]
    pub fn finish<Request, Response, Assertion, Error>(
        self,
    ) -> MockService<Request, Response, Assertion, Error> {
        let proxy_channel_size = self
            .proxy_channel_size
            .unwrap_or(DEFAULT_PROXY_CHANNEL_SIZE);
        let (sender, receiver) = broadcast::channel(proxy_channel_size);

        MockService {
            receiver,
            sender,
            poll_count: Arc::new(AtomicUsize::new(0)),
            max_request_delay: self.max_request_delay.unwrap_or(DEFAULT_MAX_REQUEST_DELAY),
            _assertion_type: PhantomData,
        }
    }
}
impl<Request, Response, Error> MockService<Request, Response, PanicAssertion, Error> {
    /// Waits for the next request and asserts that it equals `expected`.
    ///
    /// Returns the [`ResponseSender`] so the test can reply to the request.
    ///
    /// # Panics
    ///
    /// If no request arrives in time, or if the received request differs from `expected`.
    pub async fn expect_request(
        &mut self,
        expected: Request,
    ) -> ResponseSender<Request, Response, Error>
    where
        Request: PartialEq + Debug,
    {
        let responder = self.next_request().await;
        let service_type = std::any::type_name::<Self>();

        assert_eq!(
            responder.request, expected,
            "received an unexpected request\n \
             in {}",
            service_type,
        );

        responder
    }

    /// Waits for the next request and asserts that `condition` accepts it.
    ///
    /// Returns the [`ResponseSender`] so the test can reply to the request.
    ///
    /// # Panics
    ///
    /// If no request arrives in time, or if `condition` returns `false`.
    pub async fn expect_request_that(
        &mut self,
        condition: impl FnOnce(&Request) -> bool,
    ) -> ResponseSender<Request, Response, Error>
    where
        Request: Debug,
    {
        let responder = self.next_request().await;
        let accepted = condition(&responder.request);

        assert!(
            accepted,
            "condition was false for request: {:?},\n \
             in {}",
            responder.request,
            std::any::type_name::<Self>(),
        );

        responder
    }

    /// Asserts that no request arrives within the maximum request delay.
    ///
    /// # Panics
    ///
    /// If a request is received.
    pub async fn expect_no_requests(&mut self)
    where
        Request: Debug,
    {
        match self.try_next_request().await {
            None => {}
            Some(responder) => panic!(
                "received an unexpected request: {:?},\n \
                 in {}",
                responder.request,
                std::any::type_name::<Self>(),
            ),
        }
    }

    /// Returns the next request, panicking if none arrives before the timeout.
    async fn next_request(&mut self) -> ResponseSender<Request, Response, Error> {
        self.try_next_request().await.unwrap_or_else(|| {
            panic!(
                "timeout while waiting for a request\n \
                 in {}",
                std::any::type_name::<Self>(),
            )
        })
    }

    /// Returns how many times `poll_ready` was called on this service or any of its clones.
    pub fn poll_count(&self) -> usize {
        self.poll_count.load(Ordering::SeqCst)
    }
}
impl<Request, Response, Error> MockService<Request, Response, PropTestAssertion, Error> {
    /// Waits for the next request and checks that it equals `expected`.
    ///
    /// # Errors
    ///
    /// Returns a [`TestCaseError`] if no request arrives in time, or if the received
    /// request differs from `expected`.
    pub async fn expect_request(
        &mut self,
        expected: Request,
    ) -> Result<ResponseSender<Request, Response, Error>, TestCaseError>
    where
        Request: PartialEq + Debug,
    {
        let responder = self.next_request().await?;
        let service_type = std::any::type_name::<Self>();

        prop_assert_eq!(
            &responder.request,
            &expected,
            "received an unexpected request\n \
             in {}",
            service_type,
        );

        Ok(responder)
    }

    /// Waits for the next request and checks that `condition` accepts it.
    ///
    /// # Errors
    ///
    /// Returns a [`TestCaseError`] if no request arrives in time, or if `condition`
    /// returns `false`.
    pub async fn expect_request_that(
        &mut self,
        condition: impl FnOnce(&Request) -> bool,
    ) -> Result<ResponseSender<Request, Response, Error>, TestCaseError>
    where
        Request: Debug,
    {
        let responder = self.next_request().await?;
        let accepted = condition(&responder.request);

        prop_assert!(
            accepted,
            "condition was false for request: {:?},\n \
             in {}",
            &responder.request,
            std::any::type_name::<Self>(),
        );

        Ok(responder)
    }

    /// Checks that no request arrives within the maximum request delay.
    ///
    /// # Errors
    ///
    /// Returns a [`TestCaseError`] if a request is received.
    pub async fn expect_no_requests(&mut self) -> Result<(), TestCaseError>
    where
        Request: Debug,
    {
        if let Some(responder) = self.try_next_request().await {
            // `prop_assert!(false, ...)` always returns the failure early.
            prop_assert!(
                false,
                "received an unexpected request: {:?},\n \
                 in {}",
                responder.request,
                std::any::type_name::<Self>(),
            );
        }

        Ok(())
    }

    /// Returns the next request, or a [`TestCaseError`] if none arrives before the timeout.
    async fn next_request(
        &mut self,
    ) -> Result<ResponseSender<Request, Response, Error>, TestCaseError> {
        if let Some(responder) = self.try_next_request().await {
            return Ok(responder);
        }

        prop_assert!(
            false,
            "timeout while waiting for a request\n \
             in {}",
            std::any::type_name::<Self>(),
        );
        unreachable!("prop_assert!(false) returns an early error");
    }

    /// Returns how many times `poll_ready` was called on this service or any of its clones.
    pub fn poll_count(&self) -> usize {
        self.poll_count.load(Ordering::SeqCst)
    }
}
impl<Request, Response, Assertion, Error> MockService<Request, Response, Assertion, Error> {
    /// Waits for the next request that no other handle has already claimed.
    ///
    /// Returns `None` if a single wait exceeds `max_request_delay`. Every skipped message
    /// (already claimed by another handle, or a lag notification) starts a fresh wait, so
    /// the total time spent here can exceed `max_request_delay`.
    ///
    /// # Panics
    ///
    /// If the proxy channel reports it is closed. That cannot happen while this handle
    /// still owns a sender.
    pub async fn try_next_request(&mut self) -> Option<ResponseSender<Request, Response, Error>> {
        loop {
            let received = match timeout(self.max_request_delay, self.receiver.recv()).await {
                Err(_elapsed) => return None,
                Ok(received) => received,
            };

            match received {
                Ok(shared_slot) => {
                    // Whichever handle takes the sender first owns the request.
                    let claimed = shared_slot.lock().await.take();
                    if claimed.is_some() {
                        return claimed;
                    }
                }
                // Lagging only means older messages were dropped; keep waiting.
                Err(RecvError::Lagged(_skipped)) => {}
                Err(RecvError::Closed) => unreachable!("sender is never closed"),
            }
        }
    }
}
impl<Request, Response, Assertion, Error> Clone
for MockService<Request, Response, Assertion, Error>
{
fn clone(&self) -> Self {
MockService {
receiver: self.sender.subscribe(),
sender: self.sender.clone(),
poll_count: self.poll_count.clone(),
max_request_delay: self.max_request_delay,
_assertion_type: PhantomData,
}
}
}
impl<Request, Response, Error> ResponseSender<Request, Response, Error> {
    /// Pairs `request` with a new one-shot channel.
    ///
    /// Returns the sender handle together with the receiver the caller will wait on.
    fn new(request: Request) -> (Self, oneshot::Receiver<Result<Response, Error>>) {
        let (tx, rx) = oneshot::channel();
        let handle = ResponseSender {
            request,
            response_sender: tx,
        };
        (handle, rx)
    }

    /// The request that was sent to the mock service.
    pub fn request(&self) -> &Request {
        &self.request
    }

    /// Sends `response` back to the caller.
    ///
    /// `response` may be a plain `Response` or a `Result` whose error converts into `Error`.
    pub fn respond(self, response: impl ResponseResult<Response, Error>) {
        // The caller may already have dropped its future; that is not treated as an error.
        let _ = self.response_sender.send(response.into_result());
    }

    /// Computes a response from the request with `response_fn`, then sends it back.
    pub fn respond_with<F, R>(self, response_fn: F)
    where
        F: FnOnce(&Request) -> R,
        R: ResponseResult<Response, Error>,
    {
        let response = response_fn(&self.request);
        self.respond(response);
    }

    /// Sends `error` back to the caller as a failed response.
    pub fn respond_error(self, error: Error) {
        // Same best-effort delivery as `respond`.
        let _ = self.response_sender.send(Err(error));
    }

    /// Computes an error from the request with `response_fn`, then sends it back.
    pub fn respond_with_error<F>(self, response_fn: F)
    where
        F: FnOnce(&Request) -> Error,
    {
        let error = response_fn(&self.request);
        self.respond_error(error);
    }
}
/// Marker trait for the types that select how a [`MockService`] reports failed
/// expectations.
// The trait is not used as a bound here, so `dead_code` is allowed to silence the
// unused-trait warning.
#[allow(dead_code)]
trait AssertionType {}

/// Failed expectations panic.
///
/// This enum is uninhabited and is only used at the type level.
pub enum PanicAssertion {}

/// Failed expectations return a proptest [`TestCaseError`].
///
/// This enum is uninhabited and is only used at the type level.
pub enum PropTestAssertion {}

impl AssertionType for PanicAssertion {}

impl AssertionType for PropTestAssertion {}
/// Conversion into the `Result` that a [`ResponseSender`] delivers to the caller.
///
/// Tests can respond with either a bare `Response`, which is treated as success, or a
/// `Result` whose error type converts into the service's error type.
pub trait ResponseResult<Response, Error> {
    /// Converts `self` into the result sent back to the caller.
    fn into_result(self) -> Result<Response, Error>;
}

/// A plain response value is always a success.
impl<Response, Error> ResponseResult<Response, Error> for Response {
    fn into_result(self) -> Result<Response, Error> {
        Result::Ok(self)
    }
}

/// A `Result` keeps its success value and converts its error with [`Into`].
impl<Response, SourceError, TargetError> ResponseResult<Response, TargetError>
    for Result<Response, SourceError>
where
    SourceError: Into<TargetError>,
{
    fn into_result(self) -> Result<Response, TargetError> {
        self.map_err(Into::into)
    }
}