use crate::WebSocketStreamType;
use bytes::Bytes;
use futures::{SinkExt, StreamExt};
use serde::{Deserialize, Serialize};
use std::net::SocketAddr;
use std::{fmt::Debug, future::Future};
use tokio::net::TcpListener;
use tokio_tungstenite::{
MaybeTlsStream,
tungstenite::{
self, Message,
protocol::{CloseFrame, frame::coding::CloseCode},
},
};
#[cfg(feature = "tracing")]
use tracing::debug;
/// Control commands understood by `echo_server`.
///
/// A test client sends one of these as a JSON text frame to steer the mock server. No
/// `#[serde]` attributes are applied, so serde's default externally tagged form is used,
/// e.g. `{"Message":"hi"}`, `"SendPing"` or `"Close"`.
///
/// NOTE: variant order matters for the derived `PartialOrd`.
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, PartialOrd)]
pub enum EchoControlMessage {
/// The whole text frame carrying this command is echoed back unchanged.
Message(String),
/// Makes the server send a Ping frame with an empty payload.
SendPing,
/// Makes the server initiate the close handshake with `CloseCode::Normal`.
Close,
}
/// Mock-server handler that echoes text frames and obeys [`EchoControlMessage`] commands.
///
/// Every text frame must contain a JSON-encoded [`EchoControlMessage`]:
/// - `Message(_)`: the received frame is sent back unchanged;
/// - `SendPing`: the server sends a Ping with an empty payload;
/// - `Close`: the server starts the close handshake with [`CloseCode::Normal`].
///
/// Incoming Pings are answered with a Pong carrying the same payload. Pongs are only logged.
/// Any other frame, and any read error, is ignored.
///
/// Returns `Ok(true)` when the server initiated the close handshake (a `Close` command was
/// received). Returns `Ok(false)` when the client closed first or the stream ended.
///
/// # Errors
///
/// Returns the underlying [`tungstenite::Error`] if sending a frame or closing the sink fails.
///
/// # Panics
///
/// Panics if a text frame is not a valid JSON `EchoControlMessage`.
pub async fn echo_server(ws: WebSocketStreamType) -> Result<bool, tungstenite::Error> {
    let (mut sink, mut stream) = ws.split();
    // Set once we have sent our own Close frame; the peer's Close is then just its reply.
    let mut shutting_down = false;

    while let Some(message) = stream.next().await {
        match message {
            Ok(Message::Text(text)) => {
                let control_message: EchoControlMessage = serde_json::from_str(&text)
                    .expect("echo_server: text frame is not a JSON EchoControlMessage");
                match control_message {
                    EchoControlMessage::Message(_) => {
                        // Echo the raw frame (the JSON envelope), not just the inner string.
                        sink.send(Message::Text(text)).await?;
                    }
                    EchoControlMessage::SendPing => {
                        sink.send(Message::Ping(Bytes::new())).await?;
                    }
                    EchoControlMessage::Close => {
                        sink.send(Message::Close(Some(CloseFrame {
                            code: CloseCode::Normal,
                            reason: "".into(),
                        })))
                        .await?;
                        shutting_down = true;
                    }
                }
            }
            Ok(Message::Ping(payload)) => {
                // RFC 6455 §5.5.3: a Pong answering a Ping must carry identical application
                // data. A Pong we send supersedes the one tungstenite queues on its own, so it
                // must echo the payload rather than be empty.
                sink.send(Message::Pong(payload)).await?;
            }
            Ok(Message::Pong(_)) => {
                #[cfg(feature = "tracing")]
                debug!("Received Pong");
            }
            Ok(Message::Close(_)) => {
                // Client-initiated close: finish the handshake from our side. If we sent Close
                // first, this frame is the client's reply and nothing more is required.
                if !shutting_down {
                    sink.close().await?;
                }
                #[cfg(feature = "tracing")]
                debug!("Server received close request");
                break;
            }
            // Binary/raw frames are not part of the echo protocol. Read errors are
            // intentionally ignored rather than returned to the caller.
            _ => {}
        }
    }
    Ok(shutting_down)
}
/// Mock handler that expects an authentication request before behaving like `echo_server`.
///
/// The first frame is inspected. If it is a text frame that deserializes into an object
/// with string fields `action` and `token`, the server replies with one of two messages:
/// - `{"status":"authenticated"}` when `action == "auth"` and `token == "test-token"`;
/// - `{"status":"error"}` otherwise. In this case it then returns `Ok(true)` and drops the
///   connection without echoing.
///
/// Any other first frame (non-text, unparsable text, or an error) is consumed and ignored.
/// In every non-rejected case the reunited stream is handed to `echo_server`.
pub async fn auth_echo_server(ws: WebSocketStreamType) -> Result<bool, tungstenite::Error> {
    /// Expected shape of the authentication request.
    #[derive(serde::Deserialize)]
    struct AuthMsg {
        action: String,
        token: String,
    }

    let (mut sink, mut stream) = ws.split();

    let first_frame = stream.next().await;
    let auth_request = match first_frame {
        Some(Ok(Message::Text(text))) => serde_json::from_str::<AuthMsg>(&text).ok(),
        _ => None,
    };

    if let Some(request) = auth_request {
        let accepted = request.action == "auth" && request.token == "test-token";
        let reply = if accepted {
            r#"{"status":"authenticated"}"#
        } else {
            r#"{"status":"error"}"#
        };
        sink.send(Message::Text(reply.into())).await?;
        if !accepted {
            return Ok(true);
        }
    }

    // Both halves come from the same `split`, so reuniting them cannot fail.
    let ws = sink.reunite(stream).unwrap();
    echo_server(ws).await
}
pub async fn get_mock_address<F, R>(socket_handler: F) -> SocketAddr
where
F: Fn(WebSocketStreamType) -> R + Send + Sync + 'static,
R: Future<Output = Result<bool, tungstenite::Error>> + Send + Sync + 'static,
{
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let server_address = listener.local_addr().unwrap();
tokio::spawn(async move {
loop {
let (tcp_stream, _socket_addr) = listener.accept().await.unwrap();
let ws = tokio_tungstenite::accept_async(MaybeTlsStream::Plain(tcp_stream))
.await
.unwrap();
if socket_handler(ws).await.unwrap() {
break;
}
}
});
server_address
}