use futures_util::{SinkExt, StreamExt};
use hyper::upgrade::Upgraded;
use hyper_util::rt::TokioIo;
use std::fmt::Debug;
use tokio::net::TcpStream;
use tokio_tungstenite::{MaybeTlsStream, WebSocketStream, tungstenite::Message};
/// The concrete WebSocket stream backing a [`TestWebSocket`].
///
/// The two variants differ only in the I/O type underneath the
/// `WebSocketStream`; every operation on `TestWebSocket` dispatches on this
/// enum and performs the same call on whichever stream is present.
enum InnerWebSocket {
/// WebSocket over a (possibly TLS-wrapped) `TcpStream`.
/// NOTE(review): presumably produced by an HTTP/1.1 upgrade handshake, going by the variant name; confirm at the call site.
Http1(WebSocketStream<MaybeTlsStream<TcpStream>>),
/// WebSocket over a hyper `Upgraded` connection adapted through `TokioIo`.
/// NOTE(review): presumably produced over HTTP/2, going by the variant name; confirm at the call site.
Http2(WebSocketStream<TokioIo<Upgraded>>),
}
/// A WebSocket connection handle that hides which underlying transport
/// variant is in use.
///
/// Instances are created through the crate-internal `from_http1` /
/// `from_http2` constructors.
pub struct TestWebSocket {
// The transport-specific stream; kept private so callers only interact
// through the methods on `TestWebSocket`.
inner: InnerWebSocket,
}
/// Opaque `Debug` representation: always prints the fixed text
/// `TestWebSocket(..)` and never inspects the wrapped stream.
impl Debug for TestWebSocket {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // A field-less `debug_struct(name).finish()` emits only `name`, so
        // writing the literal directly produces the same output.
        f.write_str("TestWebSocket(..)")
    }
}
impl TestWebSocket {
fn new(inner: InnerWebSocket) -> Self {
Self { inner }
}
pub(crate) fn from_http1(ws: WebSocketStream<MaybeTlsStream<TcpStream>>) -> Self {
Self::new(InnerWebSocket::Http1(ws))
}
pub(crate) fn from_http2(ws: WebSocketStream<TokioIo<Upgraded>>) -> Self {
Self::new(InnerWebSocket::Http2(ws))
}
pub async fn send_text(&mut self, text: &str) {
let msg = Message::Text(text.into());
match &mut self.inner {
InnerWebSocket::Http1(ws) => ws.send(msg).await.unwrap(),
InnerWebSocket::Http2(ws) => ws.send(msg).await.unwrap(),
}
}
pub async fn recv_text(&mut self) -> String {
match &mut self.inner {
InnerWebSocket::Http1(ws) => match ws.next().await {
Some(Ok(Message::Text(t))) => t.to_string(),
other => panic!("Unexpected message: {:?}", other),
},
InnerWebSocket::Http2(ws) => match ws.next().await {
Some(Ok(Message::Text(t))) => t.to_string(),
other => panic!("Unexpected message: {:?}", other),
},
}
}
pub async fn close(self) {
let _ = match self.inner {
InnerWebSocket::Http1(mut ws) => ws.close(None).await,
InnerWebSocket::Http2(mut ws) => ws.close(None).await,
};
}
}