use crate::matchers::Matcher;
use futures_util::stream::SplitSink;
use futures_util::{SinkExt, StreamExt};
use std::sync::Arc;
use std::time::Duration;
use tokio::net::{TcpListener, TcpStream};
use tokio::select;
use tokio::sync::broadcast::Receiver as BroadcastReceiver;
use tokio::sync::broadcast::Sender as BroadcastSender;
use tokio::sync::mpsc::Sender as MpscSender;
use tokio::sync::mpsc::{Receiver as MpscReceiver, Sender};
use tokio::sync::{broadcast, mpsc, Notify, RwLock};
use tokio::time::sleep;
use tokio_tungstenite::tungstenite::Message;
use tokio_tungstenite::{accept_async, WebSocketStream};
use tracing::debug;
/// Panic message raised by [`WsMock::mount`] when a mock has no responses, no expected call
/// count and no forwarding channel — i.e. nothing the server could act on or verify.
const INCOMPLETE_MOCK_PANIC: &str = "A mock must have a response or expected number of calls, or forwarding_channel set. Add `.expect(...)`, `.forward_from_channel()`, or `.respond_with(...)` before mounting the mock.";
/// A single match/response rule that can be registered on a [`WsMockServer`].
///
/// Build one with [`WsMock::new`] and the builder methods, then register it with
/// [`WsMock::mount`].
#[derive(Debug)]
pub struct WsMock {
    /// An incoming message counts as a match only if every matcher accepts it.
    matchers: Vec<Box<dyn Matcher>>,
    /// Messages sent back, in insertion order, each time an incoming message matches.
    response_data: Vec<Message>,
    /// Taken by the server when a client connects; messages received on it are relayed
    /// to that client regardless of matching.
    forwarding_channel: Option<MpscReceiver<Message>>,
    /// Exact number of matches `verify` requires; `None` means the count is not checked.
    expected_calls: Option<usize>,
    /// Number of incoming messages that have matched so far.
    calls: usize,
}
impl Default for WsMock {
fn default() -> Self {
Self::new()
}
}
impl WsMock {
    /// Creates an empty mock: no matchers, no responses, no forwarding channel and no
    /// call expectation.
    ///
    /// At least one of [`WsMock::respond_with`], [`WsMock::expect`] or
    /// [`WsMock::forward_from_channel`] must be configured before [`WsMock::mount`].
    #[must_use]
    pub fn new() -> WsMock {
        WsMock {
            matchers: Vec::new(),
            response_data: Vec::new(),
            forwarding_channel: None,
            expected_calls: None,
            calls: 0,
        }
    }

    /// Adds a matcher. An incoming message matches this mock only if *every* matcher
    /// added accepts its text.
    #[must_use]
    pub fn matcher<T: Matcher + 'static>(mut self, matcher: T) -> Self {
        self.matchers.push(Box::new(matcher));
        self
    }

    /// Queues `data` to be sent to the client each time a message matches this mock.
    ///
    /// May be called repeatedly; responses are sent in the order they were added.
    #[must_use]
    pub fn respond_with(mut self, data: Message) -> Self {
        self.response_data.push(data);
        self
    }

    /// Relays every message received on `receiver` to the connected client, independent
    /// of whether any incoming message matched. Replaces a previously set channel.
    #[must_use]
    pub fn forward_from_channel(mut self, receiver: MpscReceiver<Message>) -> Self {
        self.forwarding_channel = Some(receiver);
        self
    }

    /// Sets the exact number of matching messages [`WsMockServer::verify`] requires.
    #[must_use]
    pub fn expect(mut self, n: usize) -> Self {
        self.expected_calls = Some(n);
        self
    }

    /// Registers this mock with `server`.
    ///
    /// # Panics
    ///
    /// Panics if the mock has no response, no expected call count and no forwarding
    /// channel, since the server would have nothing to do with it.
    pub async fn mount(self, server: &WsMockServer) {
        if self.response_data.is_empty()
            && self.expected_calls.is_none()
            && self.forwarding_channel.is_none()
        {
            panic!("{}", INCOMPLETE_MOCK_PANIC);
        }
        let mut state = server.state.write().await;
        state.mount(self);
    }

    /// Returns `true` when every matcher accepts `text` (vacuously `true` with no matchers).
    #[doc(hidden)]
    fn matches_all(&self, text: &str) -> bool {
        self.matchers.iter().all(|m| m.matches(text))
    }
}
/// Shared, lock-protected state of a running [`WsMockServer`].
#[doc(hidden)]
struct ServerState {
    /// `host:port` bind address; built with port 0 and rewritten with the real port once
    /// the listener has bound.
    connection_string: String,
    /// Notified by the listener setup once `connection_string` holds the bound address.
    ready_notify: Arc<Notify>,
    /// Mounted mocks, in mount order; every incoming message is checked against all of them.
    mocks: Vec<WsMock>,
    /// Text of each incoming message passed to matching, in arrival order; printed by
    /// `verify` on failure.
    calls: Vec<String>,
    /// Broadcasts the shutdown signal; connections subscribe to it for their outbound task.
    close_sender: BroadcastSender<()>,
}
impl ServerState {
    /// Builds fresh state for a server that will bind to `url:port`, with no mocks, no
    /// recorded calls and a new shutdown broadcast channel.
    pub fn new(url: String, port: u16, notify: Arc<Notify>) -> ServerState {
        // Only the sending half is kept; receivers are created later via `subscribe`.
        let close_sender = broadcast::channel::<()>(1).0;
        let connection_string = format!("{}:{}", url, port);
        ServerState {
            connection_string,
            ready_notify: notify,
            mocks: vec![],
            calls: vec![],
            close_sender,
        }
    }

    /// Adds `mock` to the set consulted for every incoming message.
    #[doc(hidden)]
    fn mount(&mut self, mock: WsMock) {
        self.mocks.push(mock)
    }
}
/// A WebSocket mock server listening on `127.0.0.1`.
///
/// Create one with [`WsMockServer::start`], register mocks with [`WsMock::mount`], connect
/// a client to [`WsMockServer::uri`], then check expectations with [`WsMockServer::verify`].
pub struct WsMockServer {
    state: Arc<RwLock<ServerState>>,
}
impl WsMockServer {
    /// Starts a mock server on an OS-assigned port of `127.0.0.1`.
    ///
    /// The listener is bound before this returns, so [`WsMockServer::uri`] reports the real
    /// port immediately. Accepting the client connection happens on a spawned tokio task.
    ///
    /// # Panics
    ///
    /// Panics if the listening socket cannot be bound.
    pub async fn start() -> WsMockServer {
        let state = Arc::new(RwLock::new(ServerState::new(
            "127.0.0.1".to_string(),
            0,
            Arc::new(Notify::new()),
        )));
        // Bind on the caller's task: if binding failed inside a spawned task, that task would
        // panic and `start` would wait forever for a readiness notification that never comes.
        let listener = Self::get_listener(state.clone()).await;
        tokio::spawn(Self::listen(listener, state.clone()));
        WsMockServer::new(state)
    }

    #[doc(hidden)]
    fn new(state: Arc<RwLock<ServerState>>) -> WsMockServer {
        WsMockServer { state }
    }

    /// Returns the `host:port` address the server is listening on.
    pub async fn get_connection_string(&self) -> String {
        let state = self.state.read().await;
        state.connection_string.clone()
    }

    /// Returns the `ws://host:port` URL a client should connect to.
    pub async fn uri(&self) -> String {
        format!("ws://{}", self.get_connection_string().await)
    }

    /// Accepts one client connection and handles it on its own task.
    ///
    /// Only the first connection is served: the listener is dropped afterwards, and mock
    /// forwarding channels are handed to that first connection.
    #[doc(hidden)]
    async fn listen(listener: TcpListener, state: Arc<RwLock<ServerState>>) {
        if let Ok((stream, _peer)) = listener.accept().await {
            tokio::spawn(WsMockServer::handle_connection(stream, state));
        }
    }

    /// Binds a TCP listener to the configured address, records the address actually bound
    /// (resolving port `0` to the assigned port) and notifies `ready_notify`.
    ///
    /// # Panics
    ///
    /// Panics if binding fails or the bound socket reports no local address.
    #[doc(hidden)]
    async fn get_listener(state: Arc<RwLock<ServerState>>) -> TcpListener {
        let mut state = state.write().await;
        let listener = TcpListener::bind(state.connection_string.as_str())
            .await
            .expect("Failed to listen to port");
        let listener_addr = listener
            .local_addr()
            .expect("Listener had no local address");
        state.connection_string = format!("{}:{}", listener_addr.ip(), listener_addr.port());
        state.ready_notify.notify_one();
        listener
    }

    /// Performs the WebSocket handshake, wires up outbound delivery and then matches each
    /// incoming data message against the mounted mocks until the client stream ends.
    ///
    /// # Panics
    ///
    /// Panics (inside this task) if the handshake fails or a data message is not valid UTF-8.
    #[doc(hidden)]
    async fn handle_connection(stream: TcpStream, state: Arc<RwLock<ServerState>>) {
        let ws_stream = accept_async(stream)
            .await
            .expect("Failed to accept connection");
        let (send, mut recv) = ws_stream.split();
        let (mpsc_send, mpsc_recv) = mpsc::channel::<Message>(32);
        Self::spawn_forwarding_tasks(state.clone(), mpsc_send.clone()).await;
        {
            let state_guard = state.read().await;
            let broad_recv = state_guard.close_sender.subscribe();
            tokio::spawn(Self::outbound_message_task(send, mpsc_recv, broad_recv));
        }
        while let Some(Ok(msg)) = recv.next().await {
            // Control frames are protocol traffic, not client calls: counting them would make
            // a graceful client close (or a keep-alive ping) register as a matching call.
            // Keep polling rather than breaking so the library can flush its close reply.
            if msg.is_close() || msg.is_ping() || msg.is_pong() {
                continue;
            }
            let text = msg.to_text().expect("Message was not text").to_string();
            debug!("Received: '{:?}'", text);
            Self::match_mocks(state.clone(), mpsc_send.clone(), text.as_str()).await;
        }
    }

    /// Takes each mounted mock's forwarding channel (if any) and spawns a task relaying its
    /// messages into `sender`. Channels are moved out, so later connections receive none.
    #[doc(hidden)]
    async fn spawn_forwarding_tasks(state: Arc<RwLock<ServerState>>, sender: MpscSender<Message>) {
        let mut state_guard = state.write().await;
        for mock in &mut state_guard.mocks {
            if let Some(forwarding_channel) = mock.forwarding_channel.take() {
                tokio::spawn(Self::forward_messages_task(
                    forwarding_channel,
                    sender.clone(),
                ));
            }
        }
    }

    /// Relays messages from `incoming` to `outgoing` until either side closes.
    #[doc(hidden)]
    async fn forward_messages_task(
        mut incoming: MpscReceiver<Message>,
        outgoing: MpscSender<Message>,
    ) {
        while let Some(msg) = incoming.recv().await {
            // The outbound task has exited (shutdown or disconnect); nothing can be delivered.
            if outgoing.send(msg).await.is_err() {
                break;
            }
        }
    }

    /// Writes queued messages to the client until shutdown is broadcast, the queue closes,
    /// or the client can no longer be written to.
    #[doc(hidden)]
    async fn outbound_message_task(
        mut sender: SplitSink<WebSocketStream<TcpStream>, Message>,
        mut receiver: MpscReceiver<Message>,
        mut close: BroadcastReceiver<()>,
    ) {
        loop {
            select! {
                Some(msg) = receiver.recv() => {
                    // A failed write means the client went away; stop instead of panicking.
                    if sender.send(msg).await.is_err() {
                        break;
                    }
                }
                Ok(_) = close.recv() => break,
                else => break
            }
        }
    }

    /// Records `text` as a call, bumps the counter of every mock whose matchers all accept
    /// it, and queues those mocks' responses (in mount order) for the client.
    #[doc(hidden)]
    async fn match_mocks(state: Arc<RwLock<ServerState>>, mpsc_send: Sender<Message>, text: &str) {
        // Gather responses under the lock, then release it before sending: `send` can wait on
        // a full channel, and holding the write lock there would stall `verify` and others.
        let responses: Vec<Message> = {
            let mut state_guard = state.write().await;
            state_guard.calls.push(text.to_string());
            let mut responses = Vec::new();
            for mock in &mut state_guard.mocks {
                if mock.matches_all(text) {
                    mock.calls += 1;
                    responses.extend(mock.response_data.iter().cloned());
                }
            }
            responses
        };
        for response in responses {
            // The outbound task is gone (shutdown or disconnect): nobody is left to receive.
            if mpsc_send.send(response).await.is_err() {
                break;
            }
        }
    }

    /// Checks every mock that has an `.expect(...)` count against the matches it received.
    ///
    /// Sleeps 100 ms first — presumably to let in-flight messages be matched before counting.
    ///
    /// # Panics
    ///
    /// Panics if any expectation is not met; the message lists each mismatch followed by
    /// every call the server recorded.
    pub async fn verify(&self) {
        sleep(Duration::from_millis(100)).await;
        let state_guard = self.state.read().await;
        let mut results = Vec::new();
        for mock in &state_guard.mocks {
            if let Some(expected) = mock.expected_calls {
                if expected != mock.calls {
                    results.push(format!(
                        "Expected {} matching calls, but received {}\nCalled With:",
                        expected, mock.calls
                    ));
                }
            }
        }
        if !results.is_empty() {
            for mock_call in &state_guard.calls {
                results.push(format!("\t{}", mock_call));
            }
            panic!("{}", results.join("\n"));
        }
    }

    /// Broadcasts the shutdown signal so the connection's outbound task stops.
    ///
    /// The send error returned when no connection has subscribed is deliberately ignored.
    pub async fn shutdown(&mut self) {
        let state_guard = self.state.read().await;
        _ = state_guard.close_sender.send(());
    }
}
// End-to-end tests: each starts a real server on loopback and exchanges messages with it,
// collecting client-side messages for 250 ms before asserting.
#[cfg(test)]
mod tests {
    use crate::matchers::Any;
    use crate::utils::{collect_all_messages, send_to_server};
    use crate::ws_mock_server::{WsMock, WsMockServer};
    use futures_util::StreamExt;
    use std::time::Duration;
    use tokio::sync::mpsc;
    use tokio_tungstenite::connect_async;
    use tokio_tungstenite::tungstenite::Message;

    /// Two catch-all mocks both match the single message; only the one with a response
    /// sends anything back, and both expectations of one call are met.
    #[tokio::test]
    async fn test_wss_mockserver() {
        let server = WsMockServer::start().await;
        WsMock::new()
            .matcher(Any::new())
            .expect(1)
            .mount(&server)
            .await;
        WsMock::default()
            .matcher(Any::new())
            .respond_with(Message::Text("Mock-2".to_string()))
            .expect(1)
            .mount(&server)
            .await;
        let recv = send_to_server(&server, "{ data: [42] }".into()).await;
        let received = collect_all_messages(recv, Duration::from_millis(250)).await;
        server.verify().await;
        assert_eq!(vec![Message::Text("Mock-2".to_string())], received);
    }

    /// Binary responses are delivered unchanged.
    #[tokio::test]
    async fn test_mock_other_message_type() {
        let server = WsMockServer::start().await;
        let message = vec![u8::MIN, u8::MAX];
        WsMock::default()
            .matcher(Any::new())
            .respond_with(Message::Binary(message.clone()))
            .expect(1)
            .mount(&server)
            .await;
        let recv = send_to_server(&server, "{ data: [42] }".into()).await;
        let received = collect_all_messages(recv, Duration::from_millis(250)).await;
        server.verify().await;
        assert_eq!(vec![Message::Binary(message)], received);
    }

    /// Several `respond_with` calls on one mock are sent in the order they were added.
    #[tokio::test]
    async fn test_multiple_messages() {
        let server = WsMockServer::start().await;
        WsMock::new()
            .matcher(Any::new())
            .respond_with(Message::Text("message-1".to_string()))
            .respond_with(Message::Text("message-2".to_string()))
            .expect(1)
            .mount(&server)
            .await;
        let recv = send_to_server(&server, "{ data: [42] }".into()).await;
        let received = collect_all_messages(recv, Duration::from_millis(250)).await;
        server.verify().await;
        assert_eq!(
            vec![
                Message::Text("message-1".to_string()),
                Message::Text("message-2".to_string())
            ],
            received
        );
    }

    /// Messages pushed into a mock's forwarding channel reach the connected client, in
    /// order, without the client sending anything.
    #[tokio::test]
    async fn test_forwarding_channel() {
        let server = WsMockServer::start().await;
        let (mpsc_send, mpsc_recv) = mpsc::channel::<Message>(32);
        WsMock::new()
            .matcher(Any::new())
            .forward_from_channel(mpsc_recv)
            .mount(&server)
            .await;
        let (stream, _resp) = connect_async(server.uri().await)
            .await
            .expect("Connecting failed");
        let (_send, ws_recv) = stream.split();
        mpsc_send
            .send(Message::Text("message-1".to_string()))
            .await
            .unwrap();
        mpsc_send
            .send(Message::Text("message-2".into()))
            .await
            .unwrap();
        let received = collect_all_messages(ws_recv, Duration::from_millis(250)).await;
        server.verify().await;
        assert_eq!(
            vec![
                Message::Text("message-1".to_string()),
                Message::Text("message-2".to_string()),
            ],
            received
        );
    }

    /// Verifying and shutting down with a forwarding channel mounted (its sender dropped
    /// immediately, and no client ever connecting) completes without panicking.
    #[tokio::test]
    async fn test_shutdown_with_active_channel() {
        let mut server = WsMockServer::start().await;
        let (_, mpsc_recv) = mpsc::channel::<Message>(32);
        WsMock::new()
            .matcher(Any::new())
            .forward_from_channel(mpsc_recv)
            .mount(&server)
            .await;
        server.verify().await;
        server.shutdown().await;
    }

    /// `verify` panics with the expected/actual counts followed by every recorded call.
    #[should_panic(expected = "Expected 2 matching calls, but received 1\nCalled With:\n\t{}")]
    #[tokio::test]
    async fn test_ws_mockserver_verify_failure() {
        let server = WsMockServer::start().await;
        WsMock::new()
            .matcher(Any::new())
            .respond_with(Message::Text("Mock-1".to_string()))
            .expect(2)
            .mount(&server)
            .await;
        // Keep the client's read half alive so the connection stays open until `verify`.
        let _recv = send_to_server(&server, "{}".into()).await;
        server.verify().await;
    }

    /// Mounting a mock that only has a matcher is rejected with the incomplete-mock panic.
    #[should_panic(
        expected = "A mock must have a response or expected number of calls, or forwarding_channel set. Add `.expect(...)`, `.forward_from_channel()`, or `.respond_with(...)` before mounting the mock."
    )]
    #[tokio::test]
    async fn test_incomplete_mock_failure() {
        let server = WsMockServer::start().await;
        WsMock::new().matcher(Any::new()).mount(&server).await;
    }
}