use std::io;
use std::time::Duration;
use futures_util::{SinkExt, StreamExt};
use roam_session::MessageTransport;
use roam_stream::{
ConnectionError, ConnectionHandle, Driver, FramedClient, HandshakeConfig, Message,
MessageConnector, RetryPolicy, ServiceDispatcher, accept_framed, connect_framed,
connect_framed_with_policy,
};
use tokio::io::{AsyncRead, AsyncWrite};
use tokio_tungstenite::WebSocketStream;
use tokio_tungstenite::tungstenite::protocol::Message as WsMessage;
/// Message transport backed by a WebSocket connection.
///
/// Wraps a [`WebSocketStream`] and remembers the raw payload of the most
/// recently received binary frame so it can be inspected after decoding.
pub struct WsTransport<S> {
/// The wrapped WebSocket stream that frames are sent to and read from.
stream: WebSocketStream<S>,
/// Payload bytes of the last binary frame that was received; empty until
/// the first binary frame arrives.
last_decoded: Vec<u8>,
}
impl<S> WsTransport<S> {
    /// Wraps an existing WebSocket stream.
    ///
    /// The buffer holding the last received payload starts out empty.
    pub fn new(stream: WebSocketStream<S>) -> Self {
        let last_decoded = Vec::new();
        Self {
            stream,
            last_decoded,
        }
    }

    /// Returns a shared reference to the wrapped WebSocket stream.
    pub fn stream(&self) -> &WebSocketStream<S> {
        let Self { stream, .. } = self;
        stream
    }

    /// Returns an exclusive reference to the wrapped WebSocket stream.
    pub fn stream_mut(&mut self) -> &mut WebSocketStream<S> {
        let Self { stream, .. } = self;
        stream
    }

    /// Consumes the transport and hands back the wrapped WebSocket stream,
    /// discarding the stored payload bytes.
    pub fn into_inner(self) -> WebSocketStream<S> {
        let Self { stream, .. } = self;
        stream
    }
}
/// Carries each [`Message`] as one postcard-encoded binary WebSocket frame.
impl<S> MessageTransport for WsTransport<S>
where
    S: AsyncRead + AsyncWrite + Unpin + Send,
{
    /// Encodes `msg` with postcard and sends it as a single binary frame.
    ///
    /// # Errors
    ///
    /// Returns an `InvalidData` error if encoding fails, or an
    /// `io::Error::other` carrying the WebSocket error text if the send fails.
    async fn send(&mut self, msg: &Message) -> io::Result<()> {
        let payload = facet_postcard::to_vec(msg)
            .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e.to_string()))?;
        self.stream
            .send(WsMessage::Binary(payload.into()))
            .await
            .map_err(|e| io::Error::other(e.to_string()))
    }

    /// Like [`recv`](Self::recv), but gives up after `timeout`.
    ///
    /// Returns `Ok(None)` when the timeout elapses; this is the same value
    /// `recv` yields when the peer closes or the stream ends.
    async fn recv_timeout(&mut self, timeout: Duration) -> io::Result<Option<Message>> {
        match tokio::time::timeout(timeout, self.recv()).await {
            Ok(received) => received,
            Err(_elapsed) => Ok(None),
        }
    }

    /// Waits for the next binary frame and decodes it into a [`Message`].
    ///
    /// Ping frames are answered with a Pong; Pong and raw `Frame` items are
    /// skipped. Returns `Ok(None)` on a Close frame or when the stream ends.
    ///
    /// # Errors
    ///
    /// Returns `InvalidData` for a text frame or a payload that fails to
    /// decode, and an `io::Error::other` carrying the WebSocket error text
    /// for stream failures.
    async fn recv(&mut self) -> io::Result<Option<Message>> {
        loop {
            // `None` from the stream means there is nothing more to read.
            let Some(frame) = self.stream.next().await else {
                return Ok(None);
            };

            match frame.map_err(|e| io::Error::other(e.to_string()))? {
                WsMessage::Binary(data) => {
                    // Record the raw payload before decoding so it is
                    // available even if decoding fails. Reuse the existing
                    // buffer instead of allocating a new Vec per frame.
                    self.last_decoded.clear();
                    self.last_decoded.extend_from_slice(&data);

                    let msg: Message = facet_postcard::from_slice(&data).map_err(|e| {
                        io::Error::new(io::ErrorKind::InvalidData, format!("postcard: {e}"))
                    })?;
                    return Ok(Some(msg));
                }
                WsMessage::Close(_) => return Ok(None),
                WsMessage::Ping(data) => {
                    // Best-effort reply: a failed Pong is deliberately ignored
                    // and any real stream failure surfaces on the next read.
                    // NOTE(review): tungstenite may already queue a Pong
                    // automatically on read — confirm this is still needed.
                    let _ = self.stream.send(WsMessage::Pong(data)).await;
                }
                // Neither carries an application message; keep reading.
                WsMessage::Pong(_) | WsMessage::Frame(_) => {}
                WsMessage::Text(_) => {
                    return Err(io::Error::new(
                        io::ErrorKind::InvalidData,
                        "text frames not allowed",
                    ));
                }
            }
        }
    }

    /// Returns the raw payload bytes of the most recently received binary
    /// frame (empty if none has been received yet).
    fn last_decoded(&self) -> &[u8] {
        &self.last_decoded
    }
}
/// Runs the framed handshake over a WebSocket transport.
///
/// Forwards `transport`, `config` and `dispatcher` to [`accept_framed`] and
/// returns its result unchanged: on success a connection handle, the incoming
/// connections, and the driver for this transport.
///
/// # Errors
///
/// Returns whatever [`ConnectionError`] [`accept_framed`] reports.
pub async fn ws_accept<
    S: AsyncRead + AsyncWrite + Unpin + Send,
    D: ServiceDispatcher,
>(
    transport: WsTransport<S>,
    config: HandshakeConfig,
    dispatcher: D,
) -> Result<
    (
        ConnectionHandle,
        roam_session::IncomingConnections,
        Driver<WsTransport<S>, D>,
    ),
    ConnectionError,
> {
    accept_framed(transport, config, dispatcher).await
}
/// Builds a [`FramedClient`] from a message connector.
///
/// Thin wrapper that forwards all three arguments to [`connect_framed`].
pub fn ws_connect<C: MessageConnector, D: ServiceDispatcher + Clone>(
    connector: C,
    config: HandshakeConfig,
    dispatcher: D,
) -> FramedClient<C, D> {
    connect_framed(connector, config, dispatcher)
}
/// Builds a [`FramedClient`] from a message connector with an explicit
/// [`RetryPolicy`].
///
/// Thin wrapper that forwards all four arguments to
/// [`connect_framed_with_policy`].
pub fn ws_connect_with_policy<C: MessageConnector, D: ServiceDispatcher + Clone>(
    connector: C,
    config: HandshakeConfig,
    dispatcher: D,
    retry_policy: RetryPolicy,
) -> FramedClient<C, D> {
    connect_framed_with_policy(connector, config, dispatcher, retry_policy)
}
#[cfg(test)]
mod tests {
    use super::*;
    use roam_stream::NoDispatcher;
    use tokio::net::TcpListener;
    use tokio_tungstenite::{accept_async, connect_async};

    /// Runs the framed handshake on both ends of a loopback WebSocket
    /// connection and checks that each side completes without error.
    #[tokio::test]
    async fn websocket_hello_exchange() {
        // Port 0 lets the OS pick a free port for the listener.
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let url = format!("ws://{}", listener.local_addr().unwrap());

        let client_config = HandshakeConfig::default();
        let server_config = client_config.clone();

        // Server side: accept one TCP connection, upgrade it to WebSocket,
        // then run the handshake through `ws_accept`.
        let server_task = tokio::spawn(async move {
            let (tcp, _) = listener.accept().await.unwrap();
            let ws = accept_async(tcp).await.unwrap();
            ws_accept(WsTransport::new(ws), server_config, NoDispatcher).await
        });

        // Client side: connect to the listener and run the handshake directly.
        let (ws, _) = connect_async(&url).await.unwrap();
        let client_transport = WsTransport::new(ws);
        let (client_conn, _client_incoming, client_driver) =
            accept_framed(client_transport, client_config, NoDispatcher)
                .await
                .unwrap();
        tokio::spawn(client_driver.run());

        // The outer unwrap covers the task join, the inner one the handshake.
        let (server_conn, _server_incoming, server_driver) =
            server_task.await.unwrap().unwrap();
        tokio::spawn(server_driver.run());

        let _ = client_conn;
        let _ = server_conn;
    }
}