use crate::error::{Error, Result};
use crate::shared::{Transport, TransportMessage};
use async_trait::async_trait;
use futures::{SinkExt, StreamExt};
use std::net::SocketAddr;
use std::sync::Arc;
#[cfg(not(target_arch = "wasm32"))]
use tokio::net::TcpListener;
#[cfg(not(target_arch = "wasm32"))]
use tokio::sync::{mpsc, Mutex};
use tokio_tungstenite::{accept_async, tungstenite::protocol::Message};
use tracing::{error, info, warn};
/// Settings for a [`WebSocketServerTransport`].
#[derive(Debug, Clone)]
pub struct WebSocketServerConfig {
    /// Socket address the TCP listener binds to.
    pub bind_addr: SocketAddr,
    /// Maximum size of a single WebSocket frame in bytes (`None` leaves it unset).
    pub max_frame_size: Option<usize>,
    /// Maximum size of a WebSocket message in bytes (`None` leaves it unset).
    pub max_message_size: Option<usize>,
    /// Whether frames that are not masked should be accepted.
    pub accept_unmasked_frames: bool,
}

impl Default for WebSocketServerConfig {
    /// Loopback `127.0.0.1:9001`, 64 MiB frame and message limits, and
    /// `accept_unmasked_frames` turned off.
    fn default() -> Self {
        const SIXTY_FOUR_MIB: usize = 64 * 1024 * 1024;
        Self {
            bind_addr: SocketAddr::from(([127, 0, 0, 1], 9001)),
            max_frame_size: Some(SIXTY_FOUR_MIB),
            max_message_size: Some(SIXTY_FOUR_MIB),
            accept_unmasked_frames: false,
        }
    }
}
/// Server-side WebSocket [`Transport`] that listens on `config.bind_addr` and
/// exchanges [`TransportMessage`]s with the most recently accepted client
/// (each `accept` replaces the previous connection's channels).
pub struct WebSocketServerTransport {
/// Bind address and WebSocket settings supplied at construction.
config: WebSocketServerConfig,
/// TCP listener; `Some` once `bind` has succeeded.
listener: Option<TcpListener>,
/// Receiving end of the queue that the connection's reader task fills with
/// parsed messages; `None` before `accept` and after `close`.
incoming_rx: Arc<Mutex<Option<mpsc::Receiver<TransportMessage>>>>,
/// Sending end of the queue drained by the connection's writer task;
/// `None` before `accept` and after `close`.
outgoing_tx: Arc<Mutex<Option<mpsc::Sender<TransportMessage>>>>,
}
impl WebSocketServerTransport {
pub fn new(config: WebSocketServerConfig) -> Self {
Self {
config,
listener: None,
incoming_rx: Arc::new(Mutex::new(None)),
outgoing_tx: Arc::new(Mutex::new(None)),
}
}
pub fn default_server() -> Self {
Self::new(WebSocketServerConfig::default())
}
pub async fn bind(&mut self) -> Result<()> {
let listener = TcpListener::bind(&self.config.bind_addr)
.await
.map_err(|e| {
Error::internal(format!(
"Failed to bind to {}: {}",
self.config.bind_addr, e
))
})?;
info!("WebSocket server listening on {}", self.config.bind_addr);
self.listener = Some(listener);
Ok(())
}
pub async fn accept(&mut self) -> Result<()> {
let listener = self
.listener
.as_ref()
.ok_or_else(|| Error::internal("Server not bound"))?;
let (tcp_stream, peer_addr) = listener
.accept()
.await
.map_err(|e| Error::internal(format!("Failed to accept connection: {}", e)))?;
info!("Accepting WebSocket connection from {}", peer_addr);
let ws_stream = accept_async(tcp_stream)
.await
.map_err(|e| Error::internal(format!("WebSocket handshake failed: {}", e)))?;
info!("WebSocket connection established with {}", peer_addr);
let (outgoing_tx, mut outgoing_rx) = mpsc::channel::<TransportMessage>(100);
let (incoming_tx, incoming_rx) = mpsc::channel::<TransportMessage>(100);
*self.incoming_rx.lock().await = Some(incoming_rx);
*self.outgoing_tx.lock().await = Some(outgoing_tx);
let (pong_tx, mut pong_rx) = mpsc::channel::<Vec<u8>>(10);
let (mut ws_sink, mut ws_stream) = ws_stream.split();
tokio::spawn(async move {
loop {
tokio::select! {
Some(msg) = outgoing_rx.recv() => {
let json_bytes = match crate::shared::stdio::StdioTransport::serialize_message(&msg)
{
Ok(bytes) => bytes,
Err(e) => {
error!("Failed to serialize message: {}", e);
continue;
},
};
let json = match String::from_utf8(json_bytes) {
Ok(json) => json,
Err(e) => {
error!("Failed to convert to UTF-8: {}", e);
continue;
},
};
if let Err(e) = ws_sink.send(Message::Text(json.into())).await {
error!("Failed to send WebSocket message: {}", e);
break;
}
}
Some(data) = pong_rx.recv() => {
if let Err(e) = ws_sink.send(Message::Pong(data.into())).await {
error!("Failed to send pong: {}", e);
break;
}
}
else => break,
}
}
});
tokio::spawn(async move {
while let Some(result) = ws_stream.next().await {
match result {
Ok(Message::Text(text)) => {
match crate::shared::stdio::StdioTransport::parse_message(text.as_bytes()) {
Ok(msg) => {
if let Err(e) = incoming_tx.send(msg).await {
error!("Failed to queue incoming message: {}", e);
break;
}
},
Err(e) => {
error!("Failed to parse message: {}", e);
},
}
},
Ok(Message::Binary(_)) => {
warn!("Received binary message, ignoring");
},
Ok(Message::Close(_)) => {
info!("WebSocket closed by peer");
break;
},
Ok(Message::Ping(data)) => {
if let Err(e) = pong_tx.send(data.to_vec()).await {
error!("Failed to queue pong: {}", e);
break;
}
},
Ok(_) => {
},
Err(e) => {
error!("WebSocket error: {}", e);
break;
},
}
}
info!("WebSocket connection closed");
});
Ok(())
}
}
#[async_trait]
impl Transport for WebSocketServerTransport {
async fn send(&mut self, message: TransportMessage) -> Result<()> {
let tx_guard = self.outgoing_tx.lock().await;
let tx = tx_guard
.as_ref()
.ok_or_else(|| Error::internal("No active connection"))?;
let result = tx
.send(message)
.await
.map_err(|_| Error::internal("Failed to send message"));
drop(tx_guard);
result?;
Ok(())
}
async fn receive(&mut self) -> Result<TransportMessage> {
let mut rx_guard = self.incoming_rx.lock().await;
let rx = rx_guard
.as_mut()
.ok_or_else(|| Error::internal("No active connection"))?;
let result = rx
.recv()
.await
.ok_or_else(|| Error::internal("Connection closed"));
drop(rx_guard);
result
}
async fn close(&mut self) -> Result<()> {
*self.incoming_rx.lock().await = None;
*self.outgoing_tx.lock().await = None;
info!("WebSocket server transport closed");
Ok(())
}
fn is_connected(&self) -> bool {
futures::executor::block_on(async { self.outgoing_tx.lock().await.is_some() })
}
fn transport_type(&self) -> &'static str {
"websocket-server"
}
}
impl std::fmt::Debug for WebSocketServerTransport {
    /// Prints the configuration in full. The listener and the live connection
    /// are summarized as booleans, since neither has a useful textual form.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let is_bound = self.listener.is_some();
        let has_active_connection = self.is_connected();

        let mut out = f.debug_struct("WebSocketServerTransport");
        out.field("config", &self.config);
        out.field("listener", &is_bound);
        out.field("has_active_connection", &has_active_connection);
        out.finish()
    }
}
#[derive(Debug)]
pub struct WebSocketServerBuilder {
config: WebSocketServerConfig,
}
impl WebSocketServerBuilder {
pub fn new() -> Self {
Self {
config: WebSocketServerConfig::default(),
}
}
pub fn bind_addr(mut self, addr: SocketAddr) -> Self {
self.config.bind_addr = addr;
self
}
pub fn max_frame_size(mut self, size: usize) -> Self {
self.config.max_frame_size = Some(size);
self
}
pub fn max_message_size(mut self, size: usize) -> Self {
self.config.max_message_size = Some(size);
self
}
pub fn accept_unmasked_frames(mut self, accept: bool) -> Self {
self.config.accept_unmasked_frames = accept;
self
}
pub fn build(self) -> WebSocketServerTransport {
WebSocketServerTransport::new(self.config)
}
}
impl Default for WebSocketServerBuilder {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_config_default() {
        let config = WebSocketServerConfig::default();
        assert_eq!(config.bind_addr.to_string(), "127.0.0.1:9001");
        assert_eq!(config.max_frame_size, Some(64 * 1024 * 1024));
        assert_eq!(config.max_message_size, Some(64 * 1024 * 1024));
        assert!(!config.accept_unmasked_frames);
    }

    #[test]
    fn test_builder() {
        let transport = WebSocketServerBuilder::new()
            .bind_addr("127.0.0.1:9002".parse().unwrap())
            .max_frame_size(1024 * 1024)
            .max_message_size(2 * 1024 * 1024)
            .accept_unmasked_frames(true)
            .build();
        assert_eq!(transport.config.bind_addr.to_string(), "127.0.0.1:9002");
        assert_eq!(transport.config.max_frame_size, Some(1024 * 1024));
        assert_eq!(transport.config.max_message_size, Some(2 * 1024 * 1024));
        assert!(transport.config.accept_unmasked_frames);
    }

    /// A builder with no overrides must yield exactly the default config.
    #[test]
    fn test_builder_defaults_match_config_defaults() {
        let transport = WebSocketServerBuilder::default().build();
        let expected = WebSocketServerConfig::default();
        assert_eq!(transport.config.bind_addr, expected.bind_addr);
        assert_eq!(transport.config.max_frame_size, expected.max_frame_size);
        assert_eq!(transport.config.max_message_size, expected.max_message_size);
        assert_eq!(
            transport.config.accept_unmasked_frames,
            expected.accept_unmasked_frames
        );
    }

    /// A freshly constructed transport has no listener and no connection.
    #[test]
    fn test_new_transport_is_unbound_and_disconnected() {
        let transport = WebSocketServerTransport::default_server();
        assert!(transport.listener.is_none());
        assert!(!transport.is_connected());
        assert_eq!(transport.transport_type(), "websocket-server");
    }
}