#[cfg(feature = "realtime")]
use crate::error::{Error, Result};
/// Target-independent interface over a text-oriented WebSocket client.
///
/// On native targets the trait has `Send + Sync` supertraits, so a
/// `Box<dyn WebSocketConnection>` can be moved across threads.
#[cfg(all(feature = "realtime", not(target_arch = "wasm32")))]
#[async_trait::async_trait]
pub trait WebSocketConnection: Send + Sync {
    /// Reports whether the implementation currently considers itself connected.
    fn is_connected(&self) -> bool;

    /// Opens a connection to `url`.
    async fn connect(&mut self, url: &str) -> Result<()>;

    /// Sends `message` as a text message.
    async fn send(&mut self, message: &str) -> Result<()>;

    /// Fetches the next text message.
    ///
    /// What `Ok(None)` signals is implementation-specific (for example a closed
    /// connection, or simply no buffered message).
    async fn receive(&mut self) -> Result<Option<String>>;

    /// Closes the connection.
    async fn close(&mut self) -> Result<()>;
}
/// Target-independent interface over a text-oriented WebSocket client (`wasm32` flavour).
///
/// This variant uses `#[async_trait(?Send)]` and has no `Send + Sync` supertraits,
/// because browser-backed implementations hold single-threaded (`Rc`) state.
#[cfg(all(feature = "realtime", target_arch = "wasm32"))]
#[async_trait::async_trait(?Send)]
pub trait WebSocketConnection {
    /// Reports whether the implementation currently considers itself connected.
    fn is_connected(&self) -> bool;

    /// Opens a connection to `url`.
    async fn connect(&mut self, url: &str) -> Result<()>;

    /// Sends `message` as a text message.
    async fn send(&mut self, message: &str) -> Result<()>;

    /// Fetches the next text message.
    ///
    /// What `Ok(None)` signals is implementation-specific (for example a closed
    /// connection, or simply no buffered message).
    async fn receive(&mut self) -> Result<Option<String>>;

    /// Closes the connection.
    async fn close(&mut self) -> Result<()>;
}
/// Stream type returned by `tokio_tungstenite::connect_async` for a plain or TLS TCP socket.
#[cfg(all(feature = "realtime", not(target_arch = "wasm32")))]
type NativeStream =
    tokio_tungstenite::WebSocketStream<tokio_tungstenite::MaybeTlsStream<tokio::net::TcpStream>>;

/// WebSocket client for native targets, built on `tokio-tungstenite`.
#[cfg(all(feature = "realtime", not(target_arch = "wasm32")))]
pub struct NativeWebSocket {
    /// The open stream, or `None` before `connect` succeeds / after `close`.
    connection: Option<NativeStream>,
    /// Connection flag reported by `is_connected`, held behind an `Arc`.
    is_connected: std::sync::Arc<std::sync::atomic::AtomicBool>,
}
#[cfg(all(feature = "realtime", not(target_arch = "wasm32")))]
impl NativeWebSocket {
    /// Creates a disconnected client.
    ///
    /// No network activity happens until `connect` is called.
    pub fn new() -> Self {
        Self {
            connection: None,
            is_connected: std::sync::Arc::new(std::sync::atomic::AtomicBool::new(false)),
        }
    }
}

/// Equivalent to [`NativeWebSocket::new`]; provided so the type works with
/// `Default`-based construction (and to satisfy clippy's `new_without_default`).
#[cfg(all(feature = "realtime", not(target_arch = "wasm32")))]
impl Default for NativeWebSocket {
    fn default() -> Self {
        Self::new()
    }
}
#[cfg(all(feature = "realtime", not(target_arch = "wasm32")))]
#[async_trait::async_trait]
impl WebSocketConnection for NativeWebSocket {
    /// Performs the WebSocket handshake with `url` via `tokio_tungstenite::connect_async`.
    ///
    /// On success the new stream replaces any previously held one (the old stream is
    /// dropped without a close handshake) and the connected flag is set.
    ///
    /// # Errors
    /// Returns a network error if the handshake fails; existing state is left untouched.
    async fn connect(&mut self, url: &str) -> Result<()> {
        use std::sync::atomic::Ordering;
        use tokio_tungstenite::connect_async;

        tracing::debug!("Connecting to WebSocket: {}", url);
        let (ws_stream, _) = connect_async(url)
            .await
            .map_err(|e| Error::network(format!("WebSocket connection failed: {}", e)))?;
        self.connection = Some(ws_stream);
        self.is_connected.store(true, Ordering::SeqCst);
        tracing::info!("Connected to WebSocket successfully");
        Ok(())
    }

    /// Sends `message` as a single text frame.
    ///
    /// # Errors
    /// Returns a network error if there is no stream or the write fails.
    async fn send(&mut self, message: &str) -> Result<()> {
        use futures_util::SinkExt;
        use tokio_tungstenite::tungstenite::Message;

        let ws = self
            .connection
            .as_mut()
            .ok_or_else(|| Error::network("WebSocket not connected"))?;
        ws.send(Message::Text(message.to_string()))
            .await
            .map_err(|e| Error::network(format!("Failed to send WebSocket message: {}", e)))?;
        tracing::debug!("Sent WebSocket message: {}", message);
        Ok(())
    }

    /// Waits for the next text message.
    ///
    /// Returns `Ok(Some(text))` for a text frame and `Ok(None)` once the peer sent a
    /// Close frame or the stream ended; in both `None` cases the connected flag is
    /// cleared. Non-text frames (ping/pong, binary, raw frames) are skipped rather than
    /// reported, so `Ok(None)` unambiguously means the connection is over.
    ///
    /// # Errors
    /// Returns a network error if there is no stream or the underlying read fails
    /// (the connected flag is cleared in the latter case).
    async fn receive(&mut self) -> Result<Option<String>> {
        use futures_util::StreamExt;
        use std::sync::atomic::Ordering;
        use tokio_tungstenite::tungstenite::Message;

        let ws = self
            .connection
            .as_mut()
            .ok_or_else(|| Error::network("WebSocket not connected"))?;

        loop {
            match ws.next().await {
                Some(Ok(Message::Text(text))) => {
                    tracing::debug!("Received WebSocket message: {}", text);
                    return Ok(Some(text));
                }
                Some(Ok(Message::Close(_))) => {
                    tracing::info!("WebSocket connection closed by remote");
                    self.is_connected.store(false, Ordering::SeqCst);
                    return Ok(None);
                }
                Some(Ok(_)) => {
                    // Keep-alive and non-text frames carry no text payload. Returning
                    // `Ok(None)` here would look like a closed connection to callers,
                    // so keep reading instead (this also lets queued pong replies flush).
                    tracing::trace!("Ignoring non-text WebSocket frame");
                }
                Some(Err(e)) => {
                    tracing::error!("WebSocket error: {}", e);
                    self.is_connected.store(false, Ordering::SeqCst);
                    return Err(Error::network(format!("WebSocket error: {}", e)));
                }
                None => {
                    // The stream is exhausted: the connection is gone even though no
                    // Close frame was observed.
                    self.is_connected.store(false, Ordering::SeqCst);
                    return Ok(None);
                }
            }
        }
    }

    /// Sends a close frame (best effort), drops the stream and clears the connected flag.
    ///
    /// Always returns `Ok(())`; a failure while closing is only logged.
    async fn close(&mut self) -> Result<()> {
        use std::sync::atomic::Ordering;

        if let Some(mut ws) = self.connection.take() {
            // The peer may already be gone; the stream is dropped either way.
            if let Err(e) = ws.close(None).await {
                tracing::debug!("Error while closing WebSocket: {}", e);
            }
        }
        self.is_connected.store(false, Ordering::SeqCst);
        tracing::info!("WebSocket connection closed");
        Ok(())
    }

    /// Reads the connected flag maintained by `connect`, `receive` and `close`.
    fn is_connected(&self) -> bool {
        use std::sync::atomic::Ordering;
        self.is_connected.load(Ordering::SeqCst)
    }
}
/// FIFO of text messages shared between the JS `onmessage` callback and `receive`.
#[cfg(all(feature = "realtime", target_arch = "wasm32"))]
type SharedMessageQueue = std::rc::Rc<std::cell::RefCell<Vec<String>>>;

/// WebSocket client for `wasm32` targets, built on the browser `web_sys::WebSocket`.
#[cfg(all(feature = "realtime", target_arch = "wasm32"))]
pub struct WasmWebSocket {
    /// Browser socket handle, or `None` before `connect` / after `close`.
    websocket: Option<web_sys::WebSocket>,
    /// Connection flag, also updated from the socket's JS event callbacks.
    is_connected: std::sync::Arc<std::sync::atomic::AtomicBool>,
    /// Messages received by the JS callback and not yet taken by `receive`.
    message_queue: SharedMessageQueue,
}
#[cfg(all(feature = "realtime", target_arch = "wasm32"))]
impl WasmWebSocket {
    /// Creates a disconnected client with an empty message queue.
    ///
    /// No browser socket is created until `connect` is called.
    pub fn new() -> Self {
        Self {
            websocket: None,
            is_connected: std::sync::Arc::new(std::sync::atomic::AtomicBool::new(false)),
            message_queue: std::rc::Rc::new(std::cell::RefCell::new(Vec::new())),
        }
    }
}

/// Equivalent to [`WasmWebSocket::new`]; provided so the type works with
/// `Default`-based construction (and to satisfy clippy's `new_without_default`).
#[cfg(all(feature = "realtime", target_arch = "wasm32"))]
impl Default for WasmWebSocket {
    fn default() -> Self {
        Self::new()
    }
}
/// Detaches every event handler from `websocket`, then asks the browser to close it.
///
/// Clearing the handlers first means a discarded socket can no longer flip the
/// shared connection flag or push into the shared message queue, for example when
/// its `close` event arrives after a newer connection has already opened.
/// NOTE: the Rust closures installed by `connect` were `forget()`-ed, so their
/// memory is still not reclaimed.
#[cfg(all(feature = "realtime", target_arch = "wasm32"))]
fn shutdown_browser_socket(websocket: &web_sys::WebSocket) {
    websocket.set_onopen(None);
    websocket.set_onmessage(None);
    websocket.set_onerror(None);
    websocket.set_onclose(None);
    // Best effort: a failure to close is deliberately not reported.
    websocket.close().ok();
}

#[cfg(all(feature = "realtime", target_arch = "wasm32"))]
#[async_trait::async_trait(?Send)]
impl WebSocketConnection for WasmWebSocket {
    /// Opens a browser WebSocket to `url` and polls (every 10 ms, up to 5 s) until
    /// its `open` event fires.
    ///
    /// Any socket left over from an earlier `connect` is shut down first. The
    /// installed handlers update the shared connection flag and push text messages
    /// onto the queue read by `receive`.
    ///
    /// # Errors
    /// Returns a network error if no global `window` is available for the polling
    /// timer, if the socket cannot be created, if it closes before opening, if the
    /// timer promise rejects, or if it has not opened within the timeout. On any
    /// failure after creation the new socket is shut down and not retained, so it
    /// cannot open later behind the caller's back.
    async fn connect(&mut self, url: &str) -> Result<()> {
        use std::sync::atomic::Ordering;
        use wasm_bindgen::{prelude::*, JsCast};
        use web_sys::{CloseEvent, MessageEvent, WebSocket};

        const CONNECT_TIMEOUT_MS: f64 = 5000.0;
        const POLL_INTERVAL_MS: i32 = 10;

        web_sys::console::log_1(&format!("Connecting to WebSocket: {}", url).into());

        // A previous socket would otherwise keep feeding the shared queue and flag.
        if let Some(previous) = self.websocket.take() {
            shutdown_browser_socket(&previous);
            self.is_connected.store(false, Ordering::SeqCst);
        }

        // Needed for the polling timer; fail cleanly instead of panicking later.
        let window = web_sys::window()
            .ok_or_else(|| Error::network("WebSocket connect requires a `window` for its timer"))?;

        let websocket = WebSocket::new(url)
            .map_err(|e| Error::network(format!("Failed to create WebSocket: {:?}", e)))?;
        let is_connected = std::sync::Arc::clone(&self.is_connected);
        let message_queue = std::rc::Rc::clone(&self.message_queue);

        let onopen_callback = {
            let is_connected = std::sync::Arc::clone(&is_connected);
            Closure::wrap(Box::new(move |_event: web_sys::Event| {
                web_sys::console::log_1(&"WebSocket connection opened".into());
                is_connected.store(true, Ordering::SeqCst);
            }) as Box<dyn FnMut(_)>)
        };
        websocket.set_onopen(Some(onopen_callback.as_ref().unchecked_ref()));
        onopen_callback.forget();

        let onmessage_callback = {
            let message_queue = std::rc::Rc::clone(&message_queue);
            Closure::wrap(Box::new(move |event: MessageEvent| {
                // Only string payloads are queued; binary data is ignored.
                if let Ok(text) = event.data().dyn_into::<js_sys::JsString>() {
                    let message = String::from(text);
                    web_sys::console::log_1(
                        &format!("Received WebSocket message: {}", message).into(),
                    );
                    message_queue.borrow_mut().push(message);
                }
            }) as Box<dyn FnMut(_)>)
        };
        websocket.set_onmessage(Some(onmessage_callback.as_ref().unchecked_ref()));
        onmessage_callback.forget();

        let onerror_callback = {
            let is_connected = std::sync::Arc::clone(&is_connected);
            // WebSocket `error` events are plain `Event`s, not `ErrorEvent`s.
            Closure::wrap(Box::new(move |event: web_sys::Event| {
                web_sys::console::log_1(&format!("WebSocket error: {:?}", event).into());
                is_connected.store(false, Ordering::SeqCst);
            }) as Box<dyn FnMut(_)>)
        };
        websocket.set_onerror(Some(onerror_callback.as_ref().unchecked_ref()));
        onerror_callback.forget();

        let onclose_callback = {
            let is_connected = std::sync::Arc::clone(&is_connected);
            Closure::wrap(Box::new(move |event: CloseEvent| {
                web_sys::console::log_1(
                    &format!("WebSocket connection closed: {}", event.reason()).into(),
                );
                is_connected.store(false, Ordering::SeqCst);
            }) as Box<dyn FnMut(_)>)
        };
        websocket.set_onclose(Some(onclose_callback.as_ref().unchecked_ref()));
        onclose_callback.forget();

        self.websocket = Some(websocket.clone());

        // Poll the flag set by `onopen`; stop early if the socket already reached
        // CLOSED (e.g. refused connection), since it can never open from there.
        let start_time = js_sys::Date::now();
        let mut wait_error = None;
        while !self.is_connected.load(Ordering::SeqCst)
            && websocket.ready_state() != WebSocket::CLOSED
            && (js_sys::Date::now() - start_time) < CONNECT_TIMEOUT_MS
        {
            let delay_promise = js_sys::Promise::new(&mut |resolve, reject| {
                if let Err(err) = window
                    .set_timeout_with_callback_and_timeout_and_arguments_0(&resolve, POLL_INTERVAL_MS)
                {
                    // Reject rather than leave the promise pending forever.
                    let _ = reject.call1(&JsValue::NULL, &err);
                }
            });
            if let Err(e) = wasm_bindgen_futures::JsFuture::from(delay_promise).await {
                wait_error = Some(Error::network(format!("Timeout error: {:?}", e)));
                break;
            }
        }

        if !self.is_connected.load(Ordering::SeqCst) {
            // Decide the error before shutting down: shutdown changes `ready_state`.
            let error = match wait_error {
                Some(error) => error,
                None if websocket.ready_state() == WebSocket::CLOSED => {
                    Error::network("WebSocket connection failed before opening")
                }
                None => Error::network("WebSocket connection timeout"),
            };
            // Do not keep a half-open socket: it could open later and set the flag
            // after this call has already reported failure.
            shutdown_browser_socket(&websocket);
            self.websocket = None;
            return Err(error);
        }

        web_sys::console::log_1(&"WebSocket connected successfully".into());
        Ok(())
    }

    /// Sends `message` as a text frame through the browser socket.
    ///
    /// # Errors
    /// Returns a network error if there is no socket or `send` throws (for example
    /// while the socket is not open).
    async fn send(&mut self, message: &str) -> Result<()> {
        let websocket = self
            .websocket
            .as_ref()
            .ok_or_else(|| Error::network("WebSocket not connected"))?;
        websocket
            .send_with_str(message)
            .map_err(|e| Error::network(format!("Failed to send WebSocket message: {:?}", e)))?;
        web_sys::console::log_1(&format!("Sent WebSocket message: {}", message).into());
        Ok(())
    }

    /// Pops the oldest queued text message without waiting.
    ///
    /// Returns `Ok(None)` whenever the queue is empty, including when not connected.
    /// NOTE: `Vec::remove(0)` shifts the remaining messages (O(n) per call).
    async fn receive(&mut self) -> Result<Option<String>> {
        let mut queue = self.message_queue.borrow_mut();
        if queue.is_empty() {
            Ok(None)
        } else {
            Ok(Some(queue.remove(0)))
        }
    }

    /// Detaches handlers from and closes the socket (best effort), clears the
    /// connected flag and discards any unread messages. Always returns `Ok(())`.
    async fn close(&mut self) -> Result<()> {
        use std::sync::atomic::Ordering;

        if let Some(websocket) = self.websocket.take() {
            // Detaching prevents this socket's late `close`/`message` events from
            // affecting a connection opened afterwards.
            shutdown_browser_socket(&websocket);
        }
        self.is_connected.store(false, Ordering::SeqCst);
        self.message_queue.borrow_mut().clear();
        web_sys::console::log_1(&"WebSocket connection closed".into());
        Ok(())
    }

    /// Reads the connected flag maintained by the JS callbacks and `close`.
    fn is_connected(&self) -> bool {
        use std::sync::atomic::Ordering;
        self.is_connected.load(Ordering::SeqCst)
    }
}
/// Returns a boxed, not-yet-connected `WebSocketConnection` for the current target.
///
/// Native builds get a `NativeWebSocket` (tokio-tungstenite); `wasm32` builds get a
/// `WasmWebSocket` backed by the browser `WebSocket` API.
#[cfg(feature = "realtime")]
pub fn create_websocket() -> Box<dyn WebSocketConnection> {
    #[cfg(target_arch = "wasm32")]
    let backend = WasmWebSocket::new();
    #[cfg(not(target_arch = "wasm32"))]
    let backend = NativeWebSocket::new();
    Box::new(backend)
}
#[cfg(all(test, feature = "realtime"))]
mod tests {
    use super::*;

    // On wasm32, plain `#[test]` functions are not run by `wasm-bindgen-test-runner`;
    // wasm-side tests therefore use `#[wasm_bindgen_test]`.

    #[cfg(not(target_arch = "wasm32"))]
    #[test]
    fn test_native_websocket_creation() {
        let ws = NativeWebSocket::new();
        assert!(!ws.is_connected());
    }

    #[cfg(target_arch = "wasm32")]
    #[wasm_bindgen_test::wasm_bindgen_test]
    fn test_wasm_websocket_creation() {
        let ws = WasmWebSocket::new();
        assert!(!ws.is_connected());
    }

    #[cfg_attr(not(target_arch = "wasm32"), test)]
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    fn test_create_websocket() {
        let ws = create_websocket();
        assert!(!ws.is_connected());
    }

    #[cfg(not(target_arch = "wasm32"))]
    #[tokio::test]
    async fn test_websocket_error_handling() {
        let mut ws = NativeWebSocket::new();
        let result = ws.send("test").await;
        assert!(result.is_err());
        let result = ws.receive().await;
        assert!(result.is_err());
    }

    #[cfg(target_arch = "wasm32")]
    #[wasm_bindgen_test::wasm_bindgen_test]
    async fn test_wasm_websocket_error_handling() {
        let mut ws = WasmWebSocket::new();
        let result = ws.send("test").await;
        assert!(result.is_err());
        let result = ws.receive().await;
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), None);
    }

    #[cfg(not(target_arch = "wasm32"))]
    #[tokio::test]
    async fn test_websocket_state_management() {
        let mut ws = NativeWebSocket::new();
        assert!(!ws.is_connected());
        ws.close().await.unwrap();
        assert!(!ws.is_connected());
    }

    #[cfg(not(target_arch = "wasm32"))]
    #[tokio::test]
    async fn test_native_operations_fail_after_close() {
        let mut ws = NativeWebSocket::new();
        ws.close().await.unwrap();
        assert!(ws.send("test").await.is_err());
        assert!(ws.receive().await.is_err());
    }

    #[cfg(target_arch = "wasm32")]
    #[wasm_bindgen_test::wasm_bindgen_test]
    async fn test_wasm_websocket_state_management() {
        let mut ws = WasmWebSocket::new();
        assert!(!ws.is_connected());
        ws.close().await.unwrap();
        assert!(!ws.is_connected());
        assert_eq!(ws.receive().await.unwrap(), None);
    }
}