use crate::{
common::{Id, Message as S2Message, ReceptionStatus, ReceptionStatusValues},
connection::{ConnectionError, S2Connection},
transport::S2Transport,
};
use futures_util::{SinkExt, StreamExt};
use std::str::FromStr;
use thiserror::Error;
use tokio::net::{TcpListener, TcpStream, ToSocketAddrs};
use tokio_tungstenite::{
MaybeTlsStream, WebSocketStream,
tungstenite::{self, client::IntoClientRequest, protocol::Message as TungsteniteMessage},
};
#[derive(Error, Debug)]
pub enum WebsocketTransportError {
#[error("error originating from the internal TCPListener: {0}")]
WebsocketServerError(#[from] tokio::io::Error),
#[error("error from websocket connection: {0}")]
WebsocketError(#[from] tungstenite::Error),
#[error("the websocket has closed")]
WebsocketClosed,
#[error("received a websocket message in a binary format")]
ReceivedBinaryMessage,
#[error("error parsing a received JSON message into a valid S2 message: {0}")]
MessageParseError(#[from] serde_json::Error),
}
pub struct WebsocketServer {
listener: TcpListener,
}
impl WebsocketServer {
pub async fn new(addr: impl ToSocketAddrs) -> Result<Self, ConnectionError<WebsocketTransportError>> {
Ok(Self {
listener: TcpListener::bind(addr)
.await
.map_err(Into::into)
.map_err(ConnectionError::TransportError)?,
})
}
pub async fn accept_connection(&self) -> Result<S2Connection<WebsocketTransport>, ConnectionError<WebsocketTransportError>> {
let (tcp_stream, _) = self
.listener
.accept()
.await
.map_err(Into::into)
.map_err(ConnectionError::TransportError)?;
let ws_stream = tokio_tungstenite::accept_async(tcp_stream)
.await
.map_err(Into::into)
.map_err(ConnectionError::TransportError)?;
let ws_transport = WebsocketTransport::from_server_socket(ws_stream);
Ok(S2Connection::new(ws_transport))
}
}
pub async fn connect_as_client(
url: impl IntoClientRequest + Unpin,
) -> Result<S2Connection<WebsocketTransport>, ConnectionError<WebsocketTransportError>> {
let (socket, _) = tokio_tungstenite::connect_async(url)
.await
.map_err(Into::into)
.map_err(ConnectionError::TransportError)?;
let ws_transport = WebsocketTransport::from_client_socket(socket);
Ok(S2Connection::new(ws_transport))
}
/// A websocket connection used as an S2 transport, in either client or server role.
pub enum WebsocketTransport {
    /// Connection opened with [`connect_as_client`]; the stream type allows TLS.
    ClientSocket(WebSocketStream<MaybeTlsStream<TcpStream>>),
    /// Connection accepted by [`WebsocketServer::accept_connection`] over plain TCP.
    ServerSocket(WebSocketStream<TcpStream>),
}

impl WebsocketTransport {
    /// Wraps a stream obtained by connecting as a client.
    fn from_client_socket(stream: WebSocketStream<MaybeTlsStream<TcpStream>>) -> Self {
        Self::ClientSocket(stream)
    }

    /// Wraps a stream obtained from the server's accept path.
    fn from_server_socket(stream: WebSocketStream<TcpStream>) -> Self {
        Self::ServerSocket(stream)
    }
}
impl S2Transport for WebsocketTransport {
type TransportError = WebsocketTransportError;
async fn send(&mut self, message: S2Message) -> Result<(), WebsocketTransportError> {
let serialized =
serde_json::to_string(&message).expect("unable to seralize `Message` to JSON; if you see this, you've found a bug in s2energy");
let tungstenite_message = TungsteniteMessage::text(serialized);
match self {
Self::ClientSocket(socket) => socket.send(tungstenite_message).await?,
Self::ServerSocket(socket) => socket.send(tungstenite_message).await?,
}
Ok(())
}
async fn receive(&mut self) -> Result<S2Message, WebsocketTransportError> {
let message = loop {
let next = match self {
Self::ClientSocket(socket) => socket.next().await,
Self::ServerSocket(socket) => socket.next().await,
};
let Some(msg) = next else {
return Err(WebsocketTransportError::WebsocketClosed);
};
let msg = msg?;
if msg.is_binary() {
tracing::warn!("Received binary websocket message, which is not supported. Sending ReceptionStatus INVALID_DATA.");
let _ = self
.send(
ReceptionStatus {
diagnostic_label: Some("Binary messages are not supported".to_string()),
status: ReceptionStatusValues::InvalidData,
subject_message_id: Id::from_str("00000000-0000-0000-0000-000000000000").unwrap(),
}
.into(),
)
.await;
return Err(WebsocketTransportError::ReceivedBinaryMessage);
} else if msg.is_close() {
tracing::info!("Received a websocket close message");
return Err(WebsocketTransportError::WebsocketClosed);
} else if msg.is_text() {
let msg_string = msg
.into_text()
.expect("Encountering a panic here should be impossible; please report a bug in s2energy if you encounter this anyway");
let msg_parsed = match serde_json::from_str(&msg_string) {
Ok(msg) => msg,
Err(err) => {
tracing::warn!("Failed to parse incoming message. Message: {msg_string}. Error: {err}");
let _ = self
.send(
ReceptionStatus {
diagnostic_label: Some(format!("Failed to parse message. Error: {err}")),
status: ReceptionStatusValues::InvalidData,
subject_message_id: Id::from_str("00000000-0000-0000-0000-000000000000").unwrap(),
}
.into(),
)
.await;
return Err(err.into());
}
};
break msg_parsed;
}
};
Ok(message)
}
async fn disconnect(self) {
let msg = TungsteniteMessage::Close(None);
let _ = match self {
Self::ClientSocket(mut socket) => socket.send(msg).await,
Self::ServerSocket(mut socket) => socket.send(msg).await,
};
}
}