use crate::config::GLOBAL_CONFIG;
use axum::extract::ws::{self, Utf8Bytes as AxumUtf8Bytes, WebSocket, WebSocketUpgrade};
use axum::response::IntoResponse;
use futures_util::{SinkExt, StreamExt};
use tokio_tungstenite::connect_async;
use tokio_tungstenite::tungstenite::{Error, Message};
use tungstenite::ClientRequestBuilder;
use tungstenite::client::IntoClientRequest;
/// WebSocket sub-protocol name used for both the browser-facing upgrade and
/// the outgoing connection to the vite server.
const VITE_WS_PROTOCOL: &str = "vite-hmr";
/// Axum handler that accepts a WebSocket upgrade request, advertises the
/// `VITE_WS_PROTOCOL` sub-protocol and hands the upgraded socket over to
/// `handle_socket`.
pub async fn vite_websocket_proxy(ws: WebSocketUpgrade) -> impl IntoResponse {
    let upgrade = ws.protocols([VITE_WS_PROTOCOL]);
    upgrade.on_upgrade(handle_socket)
}
async fn handle_socket(mut tuono_socket: WebSocket) {
if tuono_socket
.send(ws::Message::Ping("tuono connected".into()))
.await
.is_err()
{
return;
}
let config = GLOBAL_CONFIG
.get()
.expect("Failed to get the internal config");
let vite_ws = format!(
"ws://{}:{}/vite-server/",
config.server.host,
config.server.port + 1
);
let vite_ws_request = ClientRequestBuilder::new(vite_ws.parse().unwrap())
.with_sub_protocol(VITE_WS_PROTOCOL)
.into_client_request()
.expect("Failed to create vite WS request");
let vite_socket = match connect_async(vite_ws_request).await {
Ok((stream, _)) => {
stream
}
Err(e) => {
eprintln!("Failed to connect to vite's WebSocket. Error: {e}");
return;
}
};
let (mut vite_sender, mut vite_receiver) = vite_socket.split();
let (mut tuono_sender, mut tuono_receiver) = tuono_socket.split();
tokio::spawn(async move {
while let Some(msg) = tuono_receiver.next().await {
if let Ok(msg) = msg {
let msg_to_vite = match msg.clone() {
ws::Message::Text(str) => Message::Text(str.to_string().into()),
ws::Message::Pong(payload) => Message::Pong(payload),
ws::Message::Ping(payload) => Message::Ping(payload),
ws::Message::Binary(payload) => Message::Binary(payload),
ws::Message::Close(_) => Message::Close(None),
};
vite_sender
.send(msg_to_vite)
.await
.expect("Failed to tunnel msg to vite's WebSocket");
msg
} else {
return;
};
}
});
tokio::spawn(async move {
while let Some(Ok(msg)) = vite_receiver.next().await {
let msg_to_browser = match msg {
Message::Text(str) => ws::Message::Text(AxumUtf8Bytes::from(str.to_string())),
Message::Ping(payload) => ws::Message::Ping(payload),
Message::Pong(payload) => ws::Message::Pong(payload),
Message::Binary(payload) => ws::Message::Binary(payload),
Message::Close(_) => ws::Message::Close(None),
_ => {
eprintln!("Unexpected message from the vite WebSocket to the browser: {msg:?}");
ws::Message::Text("Unhandled".to_string().into())
}
};
if let Err(err) = tuono_sender.send(msg_to_browser).await {
if err.to_string() != Error::AlreadyClosed.to_string() {
eprintln!("Failed to send back message from vite to browser: {err}")
}
}
}
});
}