use crate::{
Error, Result,
websocket::{CloseCode, Message},
};
use futures_util::{SinkExt, StreamExt};
use tokio_tungstenite::{
MaybeTlsStream, WebSocketStream,
tungstenite::{self, protocol::CloseFrame},
};
/// Client WebSocket connection backed by `tokio_tungstenite`.
///
/// `closed` is this wrapper's own bookkeeping flag. The methods on this type set it after a
/// close frame, a stream error or end-of-stream is observed, or after a local close attempt.
pub struct ReqwestWebSocket {
    /// The underlying (plain or TLS) TCP WebSocket stream.
    pub stream: WebSocketStream<MaybeTlsStream<tokio::net::TcpStream>>,
    /// Maximum message size recorded for this connection.
    pub max_message_size: isize,
    /// Whether the wrapper treats the connection as closed.
    pub closed: bool,
}
impl ReqwestWebSocket {
pub async fn new(url: &str) -> Result<Self> {
let (stream, _) =
tokio_tungstenite::connect_async(url)
.await
.map_err(|e| Error::Network {
code: -1,
message: format!("WebSocket connection failed: {}", e),
})?;
Ok(Self {
stream,
closed: false,
max_message_size: 1024 * 1024, })
}
pub async fn new_with_config(url: &str, max_message_size: Option<isize>) -> Result<Self> {
use tokio_tungstenite::tungstenite::protocol::WebSocketConfig;
let config = if let Some(size) = max_message_size {
let mut config = WebSocketConfig::default();
config.max_message_size = Some(size as usize);
Some(config)
} else {
None
};
let (stream, _) = if let Some(config) = config {
tokio_tungstenite::connect_async_with_config(url, Some(config), false)
.await
.map_err(|e| Error::Network {
code: -1,
message: format!("WebSocket connection failed: {}", e),
})?
} else {
tokio_tungstenite::connect_async(url)
.await
.map_err(|e| Error::Network {
code: -1,
message: format!("WebSocket connection failed: {}", e),
})?
};
Ok(Self {
stream,
closed: false,
max_message_size: max_message_size.unwrap_or(1024 * 1024),
})
}
pub async fn send(&mut self, message: Message) -> Result<()> {
if self.closed {
return Err(Error::Internal(
"WebSocket connection is closed".to_string(),
));
}
let tungstenite_message = match message {
Message::Text(text) => tungstenite::Message::Text(text.into()),
Message::Binary(data) => tungstenite::Message::Binary(data.into()),
};
self.stream
.send(tungstenite_message)
.await
.map_err(|e| Error::Network {
code: -1,
message: format!("Failed to send WebSocket message: {}", e),
})
}
pub async fn receive(&mut self) -> Result<Message> {
if self.closed {
return Err(Error::Internal(
"WebSocket connection is closed".to_string(),
));
}
loop {
match self.stream.next().await {
Some(Ok(tungstenite::Message::Text(text))) => {
return Ok(Message::Text(text.to_string()));
}
Some(Ok(tungstenite::Message::Binary(data))) => {
return Ok(Message::Binary(data.to_vec()));
}
Some(Ok(tungstenite::Message::Close(_))) => {
self.closed = true;
return Err(Error::Internal(
"WebSocket connection closed by server".to_string(),
));
}
Some(Ok(tungstenite::Message::Ping(_))) => {
continue;
}
Some(Ok(tungstenite::Message::Pong(_))) => {
continue;
}
Some(Ok(tungstenite::Message::Frame(_))) => {
continue;
}
Some(Err(e)) => {
self.closed = true;
return Err(Error::Network {
code: -1,
message: format!("WebSocket error: {}", e),
});
}
None => {
self.closed = true;
return Err(Error::Internal("WebSocket stream ended".to_string()));
}
}
}
}
pub async fn close(&mut self, code: CloseCode, reason: Option<&str>) -> Result<()> {
if self.closed {
return Ok(());
}
let close_frame = reason.map(|r| CloseFrame {
code: tungstenite::protocol::frame::coding::CloseCode::from(code as u16),
reason: r.to_string().into(),
});
let close_message = tungstenite::Message::Close(close_frame);
let result = self.stream.send(close_message).await;
self.closed = true;
result.map_err(|e| Error::Network {
code: -1,
message: format!("Failed to close WebSocket: {}", e),
})
}
pub fn close_code(&self) -> Option<isize> {
if self.closed {
Some(CloseCode::Normal as isize)
} else {
None
}
}
pub fn close_reason(&self) -> Option<String> {
if self.closed {
Some("Connection closed".to_string())
} else {
None
}
}
pub fn set_maximum_message_size(&mut self, size: isize) {
self.max_message_size = size;
}
pub fn maximum_message_size(&self) -> isize {
self.max_message_size
}
}
/// Builder for `ReqwestWebSocket` connections with optional settings.
///
/// `Default` yields the same empty configuration as `ReqwestWebSocketBuilder::new()`.
#[derive(Debug, Clone, Default)]
pub struct ReqwestWebSocketBuilder {
    /// Maximum message size to configure on the stream; `None` keeps the default configuration.
    pub max_message_size: Option<isize>,
}
impl ReqwestWebSocketBuilder {
    /// Creates a builder with no options set.
    #[must_use]
    pub fn new() -> Self {
        Self {
            max_message_size: None,
        }
    }

    /// Sets the maximum message size to configure on the connection.
    ///
    /// The value is forwarded to `ReqwestWebSocket::new_with_config` by
    /// [`connect`](Self::connect).
    #[must_use]
    pub fn maximum_message_size(mut self, size: isize) -> Self {
        self.max_message_size = Some(size);
        self
    }

    /// Opens a connection to `url` using the configured options.
    ///
    /// # Errors
    ///
    /// Propagates any error from `ReqwestWebSocket::new_with_config`.
    pub async fn connect(self, url: &str) -> Result<ReqwestWebSocket> {
        ReqwestWebSocket::new_with_config(url, self.max_message_size).await
    }
}