use std::collections::HashMap;
use std::net::TcpStream;
use std::str::FromStr;
use tungstenite::stream::MaybeTlsStream;
use tungstenite::{Bytes, ClientRequestBuilder, Error, Message, Utf8Bytes, WebSocket};
use tungstenite::handshake::client::Response;
use tungstenite::http::Uri;
use tungstenite::protocol::CloseFrame;
use crate::error::{S9Result, S9WebSocketError};
use super::options::{NonBlockingOptions, BlockingOptions};
use super::types::ControlMessage;
/// Tells the caller of [`handle_control_message`] whether to keep running the
/// message loop or to leave it.
///
/// This is a crate-local type; it is unrelated to `std::ops::ControlFlow`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) enum ControlFlow {
    /// Keep processing messages.
    Continue,
    /// Stop the message loop.
    Break,
}
/// Opens a WebSocket connection to `uri`, sending every entry of `headers`
/// as an additional request header.
///
/// On success returns the connected socket together with the handshake
/// response (which is also passed to `trace_on_connected`).
///
/// # Errors
///
/// * `S9WebSocketError::InvalidUri` when `uri` cannot be parsed; the failure
///   is logged at error level first.
/// * Any error returned by `tungstenite::connect`, converted through `?`.
pub(crate) fn connect_socket(uri: &str, headers: &HashMap<String, String>) -> S9Result<(WebSocket<MaybeTlsStream<TcpStream>>, Response)> {
    let parsed_uri = match Uri::from_str(uri) {
        Ok(parsed) => parsed,
        Err(parse_err) => {
            tracing::error!("S9WebSocketClient error connecting to invalid URI: {}", uri);
            return Err(S9WebSocketError::InvalidUri(parse_err.to_string()).into());
        }
    };

    // Thread the builder through all headers; each call consumes and returns it.
    let request = headers
        .iter()
        .fold(ClientRequestBuilder::new(parsed_uri), |request, (name, value)| {
            request.with_header(name, value)
        });

    let (socket, response) = tungstenite::connect(request)?;
    trace_on_connected(&response);
    Ok((socket, response))
}
/// Switches the TCP stream underneath `socket` into non-blocking mode and
/// applies the optional `nodelay` / `ttl` settings from `options.shared`.
///
/// Only the `Plain` and `NativeTls` stream variants are handled. For any
/// other variant the function returns `Ok(())` without touching the stream,
/// so in that case non-blocking mode is NOT enabled.
///
/// # Errors
///
/// Propagates the first I/O error reported by the socket option calls.
pub(crate) fn configure_non_blocking(socket: &mut WebSocket<MaybeTlsStream<TcpStream>>, options: &NonBlockingOptions) -> S9Result<()> {
    let tcp = match socket.get_mut() {
        MaybeTlsStream::Plain(plain) => plain,
        MaybeTlsStream::NativeTls(tls) => tls.get_mut(),
        _ => return Ok(()),
    };

    tcp.set_nonblocking(true)?;

    // Each option is applied only when it is `Some`; `transpose` lets `?`
    // surface the I/O error of the call that was actually made.
    options.shared.nodelay.map(|enabled| tcp.set_nodelay(enabled)).transpose()?;
    options.shared.ttl.map(|hops| tcp.set_ttl(hops)).transpose()?;

    Ok(())
}
/// Applies the blocking-mode settings to the TCP stream underneath `socket`:
/// the optional `nodelay` / `ttl` values from `options.shared`, followed by
/// the read and write timeouts (which are always set, `None` included).
///
/// Only the `Plain` and `NativeTls` stream variants are handled. For any
/// other variant the function returns `Ok(())` without applying anything.
///
/// # Errors
///
/// Propagates the first I/O error reported by the socket option calls.
pub(crate) fn configure_blocking(socket: &mut WebSocket<MaybeTlsStream<TcpStream>>, options: &BlockingOptions) -> S9Result<()> {
    let tcp = match socket.get_mut() {
        MaybeTlsStream::Plain(plain) => plain,
        MaybeTlsStream::NativeTls(tls) => tls.get_mut(),
        _ => return Ok(()),
    };

    // Optional settings: only issued when the caller supplied a value.
    options.shared.nodelay.map(|enabled| tcp.set_nodelay(enabled)).transpose()?;
    options.shared.ttl.map(|hops| tcp.set_ttl(hops)).transpose()?;

    tcp.set_read_timeout(options.read_timeout)?;
    tcp.set_write_timeout(options.write_timeout)?;

    Ok(())
}
/// Executes a single control message against `socket`.
///
/// * `SendText` / `SendBinary` / `SendPing` / `SendPong` forward the payload
///   to the matching `send_*_to_websocket` helper.
/// * `Close` requests a close via `close_websocket_with_logging` and keeps
///   the loop running.
/// * `ForceQuit` returns [`ControlFlow::Break`] immediately.
///
/// # Errors
///
/// Returns a human-readable message (`"Error sending …: <cause>"`) when one
/// of the send helpers fails.
#[inline]
pub(crate) fn handle_control_message(control_msg: ControlMessage, socket: &mut WebSocket<MaybeTlsStream<TcpStream>>) -> Result<ControlFlow, String> {
    let outcome: Result<(), String> = match control_msg {
        ControlMessage::SendText(text) => send_text_message_to_websocket(socket, &text)
            .map_err(|e| format!("Error sending text: {}", e)),
        ControlMessage::SendBinary(data) => send_binary_message_to_websocket(socket, data)
            .map_err(|e| format!("Error sending binary: {}", e)),
        ControlMessage::SendPing(data) => send_ping_to_websocket(socket, data)
            .map_err(|e| format!("Error sending ping: {}", e)),
        ControlMessage::SendPong(data) => send_pong_to_websocket(socket, data)
            .map_err(|e| format!("Error sending pong: {}", e)),
        ControlMessage::Close() => {
            close_websocket_with_logging(socket, "ControlMessage::Close");
            Ok(())
        }
        ControlMessage::ForceQuit() => {
            if tracing::enabled!(tracing::Level::TRACE) {
                tracing::trace!("Forcibly quitting message loop");
            }
            // The only variant that ends the loop.
            return Ok(ControlFlow::Break);
        }
    };

    outcome.map(|()| ControlFlow::Continue)
}
pub(crate) fn handle_read_error(error: Error) -> (Option<String>, bool) {
match error {
Error::Io(io_err) if io_err.kind() == std::io::ErrorKind::WouldBlock => {
(None, false)
},
Error::Io(io_err) if io_err.kind() == std::io::ErrorKind::TimedOut => {
(None, false)
},
Error::ConnectionClosed => {
let reason = "Connection closed normally".to_string();
if tracing::enabled!(tracing::Level::TRACE) {
tracing::trace!(reason);
}
(Some(reason), true)
},
e => {
let error = format!("Failed to read from socket: {:?}", e);
tracing::error!(error);
(Some(error), true)
}
}
}
/// Sends `text` as a WebSocket text message.
///
/// Traces the sent text on success (when TRACE is enabled).
///
/// # Errors
///
/// Logs the failure at error level and returns it converted through
/// `S9WebSocketError`.
#[inline]
pub(crate) fn send_text_message_to_websocket(socket: &mut WebSocket<MaybeTlsStream<TcpStream>>, text: &str) -> S9Result<()> {
    match socket.send(Message::text(text)) {
        Ok(_) => {
            if tracing::enabled!(tracing::Level::TRACE) {
                tracing::trace!("Sent text message: {}", text);
            }
            Ok(())
        }
        Err(e) => {
            tracing::error!("Error sending text message: {}", e);
            Err(S9WebSocketError::from(e).into())
        }
    }
}
/// Sends `data` as a WebSocket binary message.
///
/// Emits a trace event on success (when TRACE is enabled); the payload itself
/// is not logged.
///
/// # Errors
///
/// Logs the failure at error level and returns it converted through
/// `S9WebSocketError`.
#[inline]
pub(crate) fn send_binary_message_to_websocket(socket: &mut WebSocket<MaybeTlsStream<TcpStream>>, data: Vec<u8>) -> S9Result<()> {
    match socket.send(Message::Binary(data.into())) {
        Ok(_) => {
            if tracing::enabled!(tracing::Level::TRACE) {
                tracing::trace!("Sent binary message");
            }
            Ok(())
        }
        Err(e) => {
            tracing::error!("Error sending binary message: {}", e);
            Err(S9WebSocketError::from(e).into())
        }
    }
}
/// Sends a WebSocket ping frame carrying `data` as its payload.
///
/// Emits a trace event on success (when TRACE is enabled).
///
/// # Errors
///
/// Logs the failure at error level and returns it converted through
/// `S9WebSocketError`.
#[inline]
pub(crate) fn send_ping_to_websocket(socket: &mut WebSocket<MaybeTlsStream<TcpStream>>, data: Vec<u8>) -> S9Result<()> {
    match socket.send(Message::Ping(data.into())) {
        Ok(_) => {
            if tracing::enabled!(tracing::Level::TRACE) {
                tracing::trace!("Sent ping");
            }
            Ok(())
        }
        Err(e) => {
            tracing::error!("Error sending ping: {}", e);
            Err(S9WebSocketError::from(e).into())
        }
    }
}
/// Sends a WebSocket pong frame carrying `data` as its payload.
///
/// Emits a trace event on success (when TRACE is enabled).
///
/// # Errors
///
/// Logs the failure at error level and returns it converted through
/// `S9WebSocketError`.
#[inline]
pub(crate) fn send_pong_to_websocket(socket: &mut WebSocket<MaybeTlsStream<TcpStream>>, data: Vec<u8>) -> S9Result<()> {
    match socket.send(Message::Pong(data.into())) {
        Ok(_) => {
            if tracing::enabled!(tracing::Level::TRACE) {
                tracing::trace!("Sent pong");
            }
            Ok(())
        }
        Err(e) => {
            tracing::error!("Error sending pong: {}", e);
            Err(S9WebSocketError::from(e).into())
        }
    }
}
/// Returns `true` when `error_msg` contains the substring `"closed"`.
///
/// The match is case-sensitive, so `"Closed"` alone does not count. Any
/// message containing `"Connection closed"` is covered as well, because that
/// phrase itself contains `"closed"`.
#[inline]
pub(crate) fn is_connection_closed_error(error_msg: &str) -> bool {
    // A separate check for "Connection closed" would be redundant: it can
    // never be true when this one is false.
    error_msg.contains("closed")
}
/// Requests a close of `socket` (without a close frame payload) and logs the
/// outcome, tagging the log line with `context`.
///
/// Does nothing when the socket reports that it can no longer be written to.
/// A failing close request is logged at error level and otherwise ignored.
pub(crate) fn close_websocket_with_logging(socket: &mut WebSocket<MaybeTlsStream<TcpStream>>, context: &str) {
    if !socket.can_write() {
        return;
    }

    match socket.close(None) {
        Ok(_) => {
            tracing::trace!("Connection close successfully requested for context: {}", context);
        }
        Err(e) => {
            tracing::error!("Error on connection close request for context {}: {}", context, e);
        }
    }
}
/// Traces a successful connection: the HTTP status of the handshake response
/// and the name of every response header entry (values are not logged).
///
/// No-op unless TRACE level is enabled.
pub(crate) fn trace_on_connected(response: &Response) {
    if !tracing::enabled!(tracing::Level::TRACE) {
        return;
    }

    tracing::trace!("Connected to the server");
    tracing::trace!("Response HTTP code: {}", response.status());
    tracing::trace!("Response contains the following headers:");
    // Iterate entries (not distinct keys) so a header that appears several
    // times is listed once per occurrence.
    response.headers().iter().for_each(|(header, _value)| {
        tracing::trace!("* {header}");
    });
}
/// Traces the content of a received text message. No-op unless TRACE level
/// is enabled.
#[inline]
pub(crate) fn trace_on_text_message(message: &Utf8Bytes) {
    if !tracing::enabled!(tracing::Level::TRACE) {
        return;
    }
    tracing::trace!("Received text message: {}", message);
}
/// Traces a received binary message using the `Debug` rendering of its
/// payload. No-op unless TRACE level is enabled.
#[inline]
pub(crate) fn trace_on_binary_message(bytes: &Bytes) {
    if !tracing::enabled!(tracing::Level::TRACE) {
        return;
    }
    tracing::trace!("Received binary message: {:?}", bytes);
}
/// Traces a received ping frame, rendering its payload as lossy UTF-8.
/// No-op unless TRACE level is enabled.
#[inline]
pub(crate) fn trace_on_ping_message(bytes: &Bytes) {
    if !tracing::enabled!(tracing::Level::TRACE) {
        return;
    }
    let payload = String::from_utf8_lossy(bytes);
    tracing::trace!("Received ping frame: {}", payload);
}
/// Traces a received pong frame, rendering its payload as lossy UTF-8.
/// No-op unless TRACE level is enabled.
#[inline]
pub(crate) fn trace_on_pong_message(bytes: &Bytes) {
    if !tracing::enabled!(tracing::Level::TRACE) {
        return;
    }
    let payload = String::from_utf8_lossy(bytes);
    tracing::trace!("Received pong frame: {}", payload);
}
/// Traces that the connection was closed, including the close frame when one
/// is present. No-op unless TRACE level is enabled.
pub(crate) fn trace_on_close_frame(close_frame: &Option<CloseFrame>) {
    if !tracing::enabled!(tracing::Level::TRACE) {
        return;
    }

    if let Some(frame) = close_frame {
        tracing::trace!("Connection closed with reason: {}", frame);
    } else {
        tracing::trace!("Connection closed without reason");
    }
}
/// Traces that a frame was received from the server. No-op unless TRACE
/// level is enabled.
#[inline]
pub(crate) fn trace_on_frame() {
    if !tracing::enabled!(tracing::Level::TRACE) {
        return;
    }
    tracing::trace!("Received frame from server");
}