use crate::mock_server::hyper::run_server;
use crate::mock_set::MockId;
use crate::mock_set::MountedMockSet;
use crate::request::BodyPrintLimit;
use crate::{ErrorResponse, Request, mock::Mock, verification::VerificationOutcome};
use http_body_util::Full;
use hyper::body::Bytes;
use std::fmt::{Debug, Write};
use std::net::{SocketAddr, TcpListener, TcpStream};
use std::pin::pin;
use std::sync::Arc;
use std::sync::atomic::AtomicBool;
use tokio::sync::Notify;
use tokio::sync::RwLock;
/// Handle to a mock HTTP server whose serving loop (`run_server`) runs on a dedicated
/// background OS thread, started by `BareMockServer::start`.
///
/// All mutable server data lives in `state`, which is shared with that thread.
pub(crate) struct BareMockServer {
    /// Mounted mocks and (optionally) recorded requests, shared with the server thread
    /// and with any `MockGuard`s handed out by `register_as_scoped`.
    state: Arc<RwLock<MockServerState>>,
    /// Local address the server's listener is bound to.
    server_address: SocketAddr,
    /// Sending half of the shutdown `watch` channel; the receiving half is passed to
    /// `run_server`. It is only stored, never sent on here. NOTE(review): presumably
    /// dropping it (together with this handle) is what stops the server — confirm in
    /// `run_server`.
    _shutdown_trigger: tokio::sync::watch::Sender<()>,
}
/// Mutable state of a mock server.
///
/// Shared (behind `Arc<RwLock<_>>`) between the `BareMockServer` handle, the background
/// server thread and any `MockGuard`s.
pub(super) struct MockServerState {
    /// Registered mocks; incoming requests are delegated to this set.
    mock_set: MountedMockSet,
    /// Every request received so far, or `None` when request recording is disabled.
    received_requests: Option<Vec<Request>>,
    /// Limit used when printing request bodies in diagnostics.
    body_print_limit: BodyPrintLimit,
}
impl MockServerState {
    /// Handles a single incoming request.
    ///
    /// When request recording is enabled, a copy of the request is appended to the
    /// recorded history first. The request is then handed to the mounted mock set, whose
    /// result is returned unchanged. That result is either a response plus an optional
    /// `Sleep` future, or an `ErrorResponse`.
    pub(super) async fn handle_request(
        &mut self,
        request: Request,
    ) -> Result<(hyper::Response<Full<Bytes>>, Option<tokio::time::Sleep>), ErrorResponse> {
        if let Some(history) = self.received_requests.as_mut() {
            history.push(request.clone());
        }
        self.mock_set.handle_request(request).await
    }
}
impl BareMockServer {
    /// Starts serving mocks on `listener` and returns a handle to the running server.
    ///
    /// The server future (`run_server`) is driven by a dedicated single-threaded tokio
    /// runtime on a newly spawned OS thread.
    ///
    /// Before returning, the bound address is probed with TCP connects. There are at most
    /// 40 attempts, each with a 25 ms connect timeout followed by a 25 ms sleep after a
    /// failure. The handle is returned even if no probe succeeds.
    ///
    /// # Panics
    ///
    /// Panics if the listener's local address cannot be read. The background thread
    /// panics if its tokio runtime cannot be built.
    pub(super) async fn start(
        listener: TcpListener,
        request_recording: RequestRecording,
        body_print_limit: BodyPrintLimit,
    ) -> Self {
        let server_address = listener
            .local_addr()
            .expect("Failed to get server address.");

        let received_requests = match request_recording {
            RequestRecording::Disabled => None,
            RequestRecording::Enabled => Some(Vec::new()),
        };
        let state = Arc::new(RwLock::new(MockServerState {
            mock_set: MountedMockSet::new(body_print_limit),
            received_requests,
            body_print_limit,
        }));

        // The receiver goes to the server thread; the sender is kept alive in the handle.
        let (shutdown_trigger, shutdown_receiver) = tokio::sync::watch::channel(());
        let thread_state = Arc::clone(&state);
        std::thread::spawn(move || {
            let server_future = run_server(listener, thread_state, shutdown_receiver);
            tokio::runtime::Builder::new_current_thread()
                .enable_all()
                .build()
                .expect("Cannot build local tokio runtime")
                .block_on(server_future);
        });

        // Best-effort readiness check: stop probing as soon as a connection succeeds.
        let probe_interval = std::time::Duration::from_millis(25);
        for _attempt in 0..40 {
            let reachable = TcpStream::connect_timeout(&server_address, probe_interval).is_ok();
            if reachable {
                break;
            }
            tokio::time::sleep(probe_interval).await;
        }

        Self {
            state,
            server_address,
            _shutdown_trigger: shutdown_trigger,
        }
    }

    /// Registers `mock` with the server's mock set.
    ///
    /// The identifier and notifier returned by the mock set are discarded.
    pub(crate) async fn register(&self, mock: Mock) {
        let mut state = self.state.write().await;
        state.mock_set.register(mock);
    }

    /// Registers `mock` and returns a `MockGuard` tied to it.
    ///
    /// The guard verifies the mock when dropped, and deactivates it if verification passes.
    pub async fn register_as_scoped(&self, mock: Mock) -> MockGuard {
        let mut state = self.state.write().await;
        let (notify, mock_id) = state.mock_set.register(mock);
        // Release the write lock before building the guard.
        drop(state);
        MockGuard {
            mock_id,
            server_state: Arc::clone(&self.state),
            notify,
        }
    }

    /// Resets the mock set and empties the recorded-request history.
    ///
    /// Recording stays enabled or disabled as originally configured.
    pub(crate) async fn reset(&self) {
        let mut state = self.state.write().await;
        state.mock_set.reset();
        if let Some(history) = state.received_requests.as_mut() {
            history.clear();
        }
    }

    /// Verifies every mounted mock, holding a read lock on the server state.
    pub(crate) async fn verify(&self) -> VerificationOutcome {
        self.state.read().await.mock_set.verify_all()
    }

    /// Base URI of the server: `http://` followed by the bound socket address.
    pub(crate) fn uri(&self) -> String {
        format!("http://{}", self.address())
    }

    /// Socket address the server's listener is bound to.
    pub(crate) fn address(&self) -> &SocketAddr {
        &self.server_address
    }

    /// Limit used when printing request bodies in diagnostics.
    pub(crate) async fn body_print_limit(&self) -> BodyPrintLimit {
        let state = self.state.read().await;
        state.body_print_limit
    }

    /// A snapshot of all recorded requests.
    ///
    /// Returns `None` when request recording is disabled.
    pub(crate) async fn received_requests(&self) -> Option<Vec<Request>> {
        self.state.read().await.received_requests.clone()
    }
}
impl Debug for BareMockServer {
    /// Renders the handle as `BareMockServer { address: <socket address> }`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let address = &self.server_address;
        write!(f, "BareMockServer {{ address: {address} }}")
    }
}
/// Whether the mock server keeps a copy of every request it receives.
///
/// With `Enabled`, `BareMockServer::start` initialises an empty request history. With
/// `Disabled`, nothing is recorded and `BareMockServer::received_requests` returns `None`.
pub(super) enum RequestRecording {
    /// Record every incoming request.
    Enabled,
    /// Do not record incoming requests.
    Disabled,
}
/// Guard returned by `BareMockServer::register_as_scoped`.
///
/// When the guard is dropped, the scoped mock's expectations are verified:
/// - on success, the mock is deactivated;
/// - on failure, dropping the guard panics (or only logs at debug level if the thread is
///   already panicking).
#[must_use = "All *_scoped methods return a `MockGuard`.
This guard MUST be bound to a variable (e.g. _mock_guard), \
otherwise the mock will immediately be unmounted (and its expectations checked).
Check `wiremock`'s documentation on scoped mocks for more details."]
pub struct MockGuard {
    /// Identifier of the scoped mock inside the server's `MountedMockSet`.
    mock_id: MockId,
    /// Shared server state, used to read, verify and deactivate the mock.
    server_state: Arc<RwLock<MockServerState>>,
    /// Notifier/flag pair returned when the mock was registered; read by
    /// `wait_until_satisfied`. NOTE(review): presumably the flag is set (and the notifier
    /// fired) once the mock is satisfied — confirm in `MountedMockSet`.
    notify: Arc<(Notify, AtomicBool)>,
}
impl MockGuard {
    /// Returns the requests recorded by the mounted mock this guard refers to.
    pub async fn received_requests(&self) -> Vec<crate::Request> {
        let guard = self.server_state.read().await;
        let entry = &guard.mock_set[self.mock_id];
        entry.0.received_requests()
    }

    /// Waits until the shared "satisfied" flag of this mock is set.
    ///
    /// Returns immediately if the flag is already set. Otherwise it waits for a
    /// notification on the shared `Notify`.
    pub async fn wait_until_satisfied(&self) {
        let (notify, satisfied) = &*self.notify;
        // Enable the `Notified` future *before* reading the flag so this waiter is
        // registered up-front. A notification issued after the flag check is then still
        // observed (see tokio's `Notified::enable`).
        let mut notified = pin!(notify.notified());
        notified.as_mut().enable();
        if !satisfied.load(std::sync::atomic::Ordering::Acquire) {
            notified.await;
        }
    }
}
impl Drop for MockGuard {
    /// Verifies the scoped mock when the guard goes out of scope.
    ///
    /// If its expectations hold, the mock is deactivated. Otherwise a failure report is
    /// built, listing recorded requests when recording is enabled. The report is logged at
    /// debug level if the thread is already panicking, and raised as a panic otherwise.
    /// A mock that fails verification is not deactivated.
    ///
    /// Because `drop` cannot be async, the state lock is acquired by blocking the current
    /// thread with `futures::executor::block_on`.
    fn drop(&mut self) {
        let mock_id = self.mock_id;
        let server_state = &self.server_state;
        futures::executor::block_on(async move {
            // The write lock is held until the end of this block, including while panicking.
            let mut state = server_state.write().await;
            let report = state.mock_set.verify(mock_id);
            if report.is_satisfied() {
                state.mock_set.deactivate(mock_id);
                return;
            }

            let received_requests_message = match &state.received_requests {
                None => "Enable request recording on the mock server to get the list of incoming requests as part of the panic message.".to_string(),
                Some(requests) if requests.is_empty() => {
                    "The server did not receive any request.".to_string()
                }
                Some(requests) => {
                    let mut message = "Received requests:\n".to_string();
                    for (index, request) in requests.iter().enumerate() {
                        // Writing into a `String` is infallible; results are ignored.
                        _ = write!(message, "- Request #{}\n\t", index + 1);
                        _ = request.print_with_limit(&mut message, state.body_print_limit);
                    }
                    message
                }
            };

            let verifications_error = format!("- {}\n", report.error_message());
            let error_message = format!(
                "Verification failed for a scoped mock:\n{}\n{}",
                verifications_error, received_requests_message
            );
            // Panicking while already unwinding would abort the process, so only log then.
            if !std::thread::panicking() {
                panic!("{}", &error_message);
            }
            log::debug!("{}", &error_message);
        });
    }
}