use std::{error::Error, future::Future};
use std::fmt::{Display};
use std::sync::Arc;
use fastwebsockets::{handshake, Frame, OpCode, WebSocket, WebSocketError, WebSocketRead, WebSocketWrite};
use hyper::{
body::Bytes,
header::{CONNECTION, UPGRADE},
upgrade::Upgraded,
Request,
};
use hyper_util::rt::TokioIo;
use tokio::io::{ReadHalf, WriteHalf};
use tokio::net::TcpStream;
use tokio::sync::mpsc::UnboundedSender;
use tokio::sync::Mutex;
/// Transport scheme used to reach a remote endpoint.
///
/// Its [`Display`] form is the lowercase URL scheme (`"http"`, `"https"`, `"ws"`, `"wss"`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ConnectionTransportProtocol {
    Http,
    Https,
    Ws,
    Wss,
}

impl ConnectionTransportProtocol {
    /// Returns the lowercase URL scheme for this protocol.
    pub fn as_str(self) -> &'static str {
        match self {
            ConnectionTransportProtocol::Http => "http",
            ConnectionTransportProtocol::Https => "https",
            ConnectionTransportProtocol::Ws => "ws",
            ConnectionTransportProtocol::Wss => "wss",
        }
    }
}

impl Display for ConnectionTransportProtocol {
    /// Writes the URL scheme. `pad` (rather than `write!`) honours width,
    /// alignment and precision flags, e.g. `format!("{:>5}", proto)`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.pad(self.as_str())
    }
}
/// Selects how much of a connection URL to render.
///
/// NOTE(review): not referenced by the code in this module; the variants
/// presumably correspond to `host:port`, `scheme://host:port` and the full
/// endpoint including path — confirm against callers.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum UrlFormat {
    HostPort,
    ProtocolHostPort,
    Full,
}
/// Connection target: scheme, host, port and request path.
#[derive(Debug, Clone)]
pub struct ConnectionTransportConfig {
    /// URL scheme used for the connection.
    pub protocol: ConnectionTransportProtocol,
    /// Host name or address, without the port.
    pub host: String,
    /// TCP port.
    pub port: u16,
    /// Request path; `path()` normalises it to a single leading `/`.
    pub path: String,
}

impl Default for ConnectionTransportConfig {
    /// Equivalent to `ws://localhost:0/session`.
    ///
    /// NOTE(review): port `0` is not connectable, so callers are presumably
    /// expected to override it.
    fn default() -> Self {
        ConnectionTransportConfig {
            protocol: ConnectionTransportProtocol::Ws,
            host: "localhost".to_owned(),
            port: 0,
            path: String::from("session"),
        }
    }
}
impl ConnectionTransportConfig {
    /// Full endpoint URL, e.g. `ws://localhost:8080/session`.
    pub fn full_endpoint(&self) -> String {
        format!("{}://{}{}", self.protocol, self.host_port(), self.path())
    }

    /// `host:port` string (used as the TCP connect address and `Host` header).
    pub fn host_port(&self) -> String {
        format!("{}:{}", self.host, self.port)
    }

    /// Request path normalised to start with exactly one `/`.
    pub fn path(&self) -> String {
        format!("/{}", self.path.trim_start_matches('/'))
    }

    /// Parses a `ws://` or `wss://` URL of the form `scheme://host:port[/path][?query]`.
    ///
    /// The resource (path plus any query) is stored with a leading `/`.
    /// A URL without a path yields `/`.
    /// A query directly after the port (`ws://h:1?x=y`) yields `/?x=y`.
    ///
    /// # Errors
    ///
    /// Returns a human-readable description if:
    /// - the `://` separator is missing,
    /// - the scheme is not `ws` or `wss`,
    /// - the port is missing or not a valid `u16`,
    /// - the host is empty.
    pub fn from_ws_url(url: &str) -> Result<Self, String> {
        let (protocol_str, rest) = url
            .split_once("://")
            .ok_or_else(|| format!("missing '://' in URL: {}", url))?;
        let protocol = match protocol_str {
            "ws" => ConnectionTransportProtocol::Ws,
            "wss" => ConnectionTransportProtocol::Wss,
            p => return Err(format!("unsupported WebSocket protocol: {}", p)),
        };
        // The authority ends at the first '/' (path) or '?' (query with an empty
        // path, allowed by RFC 6455 §3); splitting only on '/' would leave the
        // query glued to the port.
        let authority_end = rest
            .find(|c: char| c == '/' || c == '?')
            .unwrap_or(rest.len());
        let (host_port, resource) = rest.split_at(authority_end);
        let (host, port_str) = host_port
            .rsplit_once(':')
            .ok_or_else(|| format!("missing port in URL: {}", url))?;
        let port = port_str
            .parse::<u16>()
            .map_err(|e| format!("invalid port '{}': {}", port_str, e))?;
        if host.is_empty() {
            return Err(format!("missing host in URL: {}", url));
        }
        let path = if resource.starts_with('/') {
            resource.to_string()
        } else {
            format!("/{}", resource)
        };
        Ok(Self {
            protocol,
            host: host.to_string(),
            port,
            path,
        })
    }
}
/// A bidirectional connection that exchanges text (`String`) messages.
pub trait ConnectionTransport {
    /// Sends a single text message; the returned future completes once the
    /// implementation has attempted the write.
    fn send(&mut self, message: String) -> impl Future<Output = ()> + Send;
    /// Starts delivering incoming messages to `listener`.
    fn listen(&self, listener: UnboundedSender<String>);
    /// Requests that the connection be closed.
    fn close(&self) -> impl Future<Output = ()> + Send;
    /// Hook for connection-closed handling.
    // NOTE(review): semantics not established by this module — confirm intent.
    fn on_close(&self);
}
/// Shared, async-locked handle to the write half of the split client WebSocket.
type SharedWsWriter = Arc<Mutex<WebSocketWrite<WriteHalf<TokioIo<Upgraded>>>>>;
/// Shared, async-locked handle to the read half of the split client WebSocket.
type SharedWsReader = Arc<Mutex<WebSocketRead<ReadHalf<TokioIo<Upgraded>>>>>;

/// [`ConnectionTransport`] backed by a client WebSocket.
///
/// The socket is split into read and write halves, each behind its own async
/// mutex, so sending and listening do not contend for the same lock.
pub struct WebsocketConnectionTransport {
    client_tx: SharedWsWriter,
    client_rx: SharedWsReader,
}
impl ConnectionTransport for WebsocketConnectionTransport {
fn send(&mut self, message: String) -> impl Future<Output=()> + Send
{
async move {
let frame = Frame::text(fastwebsockets::Payload::from(message.as_bytes()));
self.client_tx.lock().await.write_frame(frame).await.unwrap();
}
}
fn listen(&self, listener: UnboundedSender<String>) -> () {
WebsocketConnectionTransport::listener_loop(self.client_rx.clone(), self.client_tx.clone(), listener).unwrap();
}
fn close(&self) -> impl Future<Output=()> + Send {
let client_tx = self.client_tx.clone();
async move {
let mut tx = client_tx.lock().await;
let _ = tx.write_frame(Frame::close(1000, b"")).await;
}
}
fn on_close(&self) -> () {
todo!()
}
}
impl WebsocketConnectionTransport {
pub async fn new(connection_config: &ConnectionTransportConfig) -> Result<Self, Box<dyn Error>> {
let addr_host = connection_config.host_port();
let retry_delay_ms = 400;
let mut retries = 3;
tracing::debug!("[WebsocketConnectionTransport]: Connecting to websocket @ url: {}", connection_config.full_endpoint());
let stream = loop {
match TcpStream::connect(&addr_host).await {
Ok(stream) => break stream,
Err(e) if e.kind() == std::io::ErrorKind::ConnectionRefused && retries > 0 => {
tracing::warn!("Connection refused, retrying... ({} attempts remaining)", retries);
retries -= 1;
tokio::time::sleep(tokio::time::Duration::from_millis(retry_delay_ms)).await;
}
Err(e) => return Err(Box::new(e)),
}
};
let uri = connection_config.path();
let req = Request::builder()
.method("GET")
.uri(uri)
.header("Host", &addr_host)
.header(UPGRADE, "websocket")
.header(CONNECTION, "upgrade")
.header(
"Sec-WebSocket-Key",
fastwebsockets::handshake::generate_key(),
)
.header("Sec-WebSocket-Version", "13")
.body(http_body_util::Empty::<Bytes>::new()).unwrap();
let (mut ws, _) = handshake::client(&SpawnExecutor, req, stream).await.unwrap();
ws = Self::configure_client(ws);
let (rx, tx) = ws.split(tokio::io::split);
Ok(Self {
client_rx: Arc::new(Mutex::new(rx)),
client_tx: Arc::new(Mutex::new(tx))
})
}
fn configure_client(mut ws: WebSocket<TokioIo<Upgraded>>) -> WebSocket<TokioIo<Upgraded>> {
ws.set_writev(true);
ws.set_auto_close(true);
ws.set_auto_pong(true);
ws
}
pub fn listener_loop(ws_rx: Arc<Mutex<WebSocketRead<ReadHalf<TokioIo<Upgraded>>>>>, ws_tx: Arc<Mutex<WebSocketWrite<WriteHalf<TokioIo<Upgraded>>>>>, tx: UnboundedSender<String>) -> Result<(), WebSocketError>
{
tokio::spawn(async move {
loop {
let mut ws_rx_half = ws_rx.lock().await;
let frame = match ws_rx_half.read_frame(&mut |frame| async {
let mut ws_write_half = ws_tx.lock().await;
return ws_write_half.write_frame(frame).await;
}).await {
Ok(frame) => frame,
Err(WebSocketError::UnexpectedEOF) => {
tracing::warn!("WebSocket connection closed (unexpected EOF). Exiting listener loop.");
break;
}
Err(e) => {
panic!("Unexpected WebSocket error: {:?}", e);
}
};
match frame.opcode {
OpCode::Close => break,
OpCode::Text | OpCode::Binary => {
let incoming = Frame::new(true, frame.opcode, None, frame.payload);
assert!(incoming.fin);
let string_payload = String::from_utf8(incoming.payload.to_owned());
if let Ok(str_payload) = string_payload {
tx.send(str_payload).unwrap()
}
}
_ => {}
}
}
});
Ok(())
}
}
/// `hyper` executor (passed to `handshake::client` in this module) that runs
/// each submitted future as a detached task on the current Tokio runtime.
struct SpawnExecutor;

impl<F> hyper::rt::Executor<F> for SpawnExecutor
where
    F: Future + Send + 'static,
    F::Output: Send + 'static,
{
    fn execute(&self, fut: F) {
        // Dropping the JoinHandle detaches the task; it keeps running.
        drop(tokio::spawn(fut));
    }
}