use crate::mock_server::hyper::run_server;
use crate::mock_set::MockId;
use crate::mock_set::MountedMockSet;
use crate::{mock::Mock, verification::VerificationOutcome, Request};
use std::net::{SocketAddr, TcpListener, TcpStream};
use std::sync::Arc;
use tokio::sync::RwLock;
use tokio::task::LocalSet;
/// A mock HTTP server whose request handling runs on a background thread.
pub(crate) struct BareMockServer {
    /// Server state (mounted mocks and recorded requests); a clone of this `Arc` is handed
    /// to the server thread when the server is started.
    state: Arc<RwLock<MockServerState>>,
    /// Local address of the listener the server accepts connections on.
    server_address: SocketAddr,
    /// Sending half of the shutdown channel whose receiver is passed to the server loop.
    /// It is never sent on; it is only kept alive for as long as this value exists
    /// (presumably its drop is what signals the server to stop — confirm in `run_server`).
    _shutdown_trigger: tokio::sync::oneshot::Sender<()>,
}
/// Mutable state shared between a mock server handle and its server thread.
pub(super) struct MockServerState {
    /// All mocks registered on the server.
    mock_set: MountedMockSet,
    /// Every request received so far, or `None` when request recording is disabled.
    received_requests: Option<Vec<Request>>,
}
impl MockServerState {
    /// Process one incoming `request`.
    ///
    /// If request recording is enabled, a copy of `request` is appended to the recorded
    /// requests first. The request is then handed to the mounted mock set, whose response
    /// and optional delay are returned as-is.
    pub(super) async fn handle_request(
        &mut self,
        request: Request,
    ) -> (http_types::Response, Option<futures_timer::Delay>) {
        // Only pay for the clone when somebody asked for the requests to be kept.
        if let Some(recorded) = self.received_requests.as_mut() {
            recorded.push(request.clone());
        }
        self.mock_set.handle_request(request).await
    }
}
impl BareMockServer {
    /// Start a mock server that serves connections accepted on `listener`.
    ///
    /// The server future produced by `run_server` is driven on a dedicated OS thread that
    /// builds its own single-threaded tokio runtime, so it keeps running independently of
    /// the caller's executor. The returned value holds the sending half of the shutdown
    /// channel; presumably dropping it is what stops the server (confirm in `run_server`).
    ///
    /// `request_recording` decides whether incoming requests are stored so they can be
    /// retrieved later with [`BareMockServer::received_requests`].
    ///
    /// # Panics
    ///
    /// Panics if the local address of `listener` cannot be read. The server thread panics
    /// if its tokio runtime cannot be built.
    pub(super) async fn start(listener: TcpListener, request_recording: RequestRecording) -> Self {
        let (shutdown_trigger, shutdown_receiver) = tokio::sync::oneshot::channel();
        let received_requests = match request_recording {
            RequestRecording::Enabled => Some(Vec::new()),
            RequestRecording::Disabled => None,
        };
        let state = Arc::new(RwLock::new(MockServerState {
            mock_set: MountedMockSet::new(),
            received_requests,
        }));
        let server_address = listener
            .local_addr()
            .expect("Failed to get server address.");

        let server_state = Arc::clone(&state);
        std::thread::spawn(move || {
            let server_future = run_server(listener, server_state, shutdown_receiver);
            let runtime = tokio::runtime::Builder::new_current_thread()
                .enable_all()
                .build()
                .expect("Cannot build local tokio runtime");
            LocalSet::new().block_on(&runtime, server_future)
        });

        // Wait until the server accepts TCP connections, giving up silently after 40
        // attempts. NOTE(review): `connect_timeout` is a blocking call made from async
        // code, so each attempt may block the current executor thread for up to 25ms.
        for _ in 0..40 {
            if TcpStream::connect_timeout(&server_address, std::time::Duration::from_millis(25))
                .is_ok()
            {
                break;
            }
            futures_timer::Delay::new(std::time::Duration::from_millis(25)).await;
        }

        Self {
            state,
            server_address,
            _shutdown_trigger: shutdown_trigger,
        }
    }

    /// Register `mock` with the server's mock set.
    pub(crate) async fn register(&self, mock: Mock) {
        self.state.write().await.mock_set.register(mock);
    }

    /// Register `mock` and return a [`MockGuard`] bound to it.
    ///
    /// What happens to the mock when the guard is dropped is defined by the guard's
    /// `Drop` implementation (it verifies the mock's expectations).
    pub(crate) async fn register_as_scoped(&self, mock: Mock) -> MockGuard {
        let mock_id = self.state.write().await.mock_set.register(mock);
        MockGuard {
            mock_id,
            server_state: Arc::clone(&self.state),
        }
    }

    /// Reset the mock set and, when recording is enabled, forget all recorded requests.
    pub(crate) async fn reset(&self) {
        let mut state = self.state.write().await;
        state.mock_set.reset();
        if let Some(received_requests) = &mut state.received_requests {
            received_requests.clear();
        }
    }

    /// Verify the expectations of every mock in the mock set.
    pub(crate) async fn verify(&self) -> VerificationOutcome {
        // Bind the read guard explicitly so the lock visibly covers the whole verification.
        let state = self.state.read().await;
        state.mock_set.verify_all()
    }

    /// Base URI of the server: `http://` followed by the listener's address.
    pub(crate) fn uri(&self) -> String {
        format!("http://{}", self.server_address)
    }

    /// Socket address the server's listener is bound to.
    pub(crate) fn address(&self) -> &SocketAddr {
        &self.server_address
    }

    /// Snapshot of the requests received so far, or `None` if request recording is disabled.
    pub(crate) async fn received_requests(&self) -> Option<Vec<Request>> {
        let state = self.state.read().await;
        state.received_requests.clone()
    }
}
/// Whether a mock server keeps a copy of every request it receives.
pub(super) enum RequestRecording {
    /// Store every incoming request so it can be inspected later.
    Enabled,
    /// Do not store incoming requests.
    Disabled,
}
/// Handle tying a scoped mock to a lexical scope.
///
/// The guard must be kept alive for as long as the mock should stay mounted; dropping it
/// triggers verification of the mock's expectations.
#[must_use = "All *_scoped methods return a `MockGuard`.
This guard MUST be bound to a variable (e.g. _mock_guard), \
otherwise the mock will immediately be unmounted (and its expectations checked).
Check `wiremock`'s documentation on scoped mocks for more details."]
pub struct MockGuard {
    /// Identifier of the guarded mock inside the server's mock set.
    mock_id: MockId,
    /// Shared server state, used to verify the mock when the guard is dropped.
    server_state: Arc<RwLock<MockServerState>>,
}
impl Drop for MockGuard {
fn drop(&mut self) {
let future = async move {
let MockGuard {
mock_id,
server_state,
} = self;
let mut state = server_state.write().await;
let report = state.mock_set.verify(*mock_id);
if !report.is_satisfied() {
let received_requests_message = if let Some(received_requests) =
&state.received_requests
{
if received_requests.is_empty() {
"The server did not receive any request.".into()
} else {
format!(
"Received requests:\n{}",
received_requests
.iter()
.enumerate()
.map(|(index, request)| {
format!(
"- Request #{}\n{}",
index + 1,
&format!("\t{}", request)
)
})
.collect::<String>()
)
}
} else {
"Enable request recording on the mock server to get the list of incoming requests as part of the panic message.".into()
};
let verifications_error = format!("- {}\n", report.error_message());
let error_message = format!(
"Verification failed for a scoped mock:\n{}\n{}",
verifications_error, received_requests_message
);
if std::thread::panicking() {
log::debug!("{}", &error_message);
} else {
panic!("{}", &error_message);
}
} else {
state.mock_set.deactivate(*mock_id);
}
};
futures::executor::block_on(future)
}
}