use std::net::TcpStream;
use tungstenite::stream::MaybeTlsStream;
use tungstenite::{connect, Message, WebSocket};
/// A WebSocket client connection built on a `std::net::TcpStream`.
///
/// The stream is wrapped in tungstenite's `MaybeTlsStream`, so the same type
/// holds both plain (`ws://`) and TLS-wrapped connections, depending on what
/// `tungstenite::connect` negotiated.
pub struct WsConnection {
socket: WebSocket<MaybeTlsStream<TcpStream>>,
}
/// Errors reported by the WebSocket client wrapper.
///
/// Every variant carries a human-readable description of the failure. This is
/// usually the `Display` text of the underlying library error.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum WsError {
    /// The supplied URI was rejected.
    ///
    /// NOTE(review): `WsConnection::connect` currently reports every failure,
    /// including unparsable URIs, as `ConnectionFailed`. Confirm whether this
    /// variant is constructed anywhere else.
    InvalidUri(String),
    /// Establishing the connection or completing the handshake failed.
    ConnectionFailed(String),
    /// Sending a message over an established connection failed.
    SendFailed(String),
}

impl std::fmt::Display for WsError {
    /// Formats the error as `"<category>: <detail>"`, e.g. `"Send failed: broken pipe"`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let (label, detail) = match self {
            WsError::InvalidUri(detail) => ("Invalid URI", detail),
            WsError::ConnectionFailed(detail) => ("Connection failed", detail),
            WsError::SendFailed(detail) => ("Send failed", detail),
        };
        write!(f, "{}: {}", label, detail)
    }
}

// `Debug` + `Display` are all `std::error::Error` needs; there is no wrapped
// source error to expose because only the message text is kept.
impl std::error::Error for WsError {}
impl WsConnection {
    /// Opens a WebSocket client connection to `uri` via `tungstenite::connect`.
    ///
    /// The HTTP handshake response returned alongside the socket is discarded.
    ///
    /// # Errors
    ///
    /// Every failure from `tungstenite::connect`, including a URI it cannot
    /// use, is returned as [`WsError::ConnectionFailed`] holding the
    /// underlying error's text.
    pub fn connect(uri: &str) -> Result<Self, WsError> {
        match connect(uri) {
            Ok((socket, _handshake_response)) => Ok(Self { socket }),
            Err(err) => Err(WsError::ConnectionFailed(err.to_string())),
        }
    }

    /// Sends `text` as a single WebSocket text message.
    ///
    /// # Errors
    ///
    /// Returns [`WsError::SendFailed`] holding the underlying error's text if
    /// the socket rejects the message.
    pub fn send_text(&mut self, text: &str) -> Result<(), WsError> {
        let payload = Message::Text(text.to_owned().into());
        if let Err(err) = self.socket.send(payload) {
            return Err(WsError::SendFailed(err.to_string()));
        }
        Ok(())
    }

    /// Asks the socket to close with no close-frame payload (`close(None)`).
    ///
    /// This is best-effort: any error from the close attempt is discarded and
    /// nothing is reported to the caller.
    ///
    /// NOTE(review): tungstenite only queues and flushes the close frame here.
    /// Completing the close handshake requires continuing to read until the
    /// peer's close frame arrives, which this method does not do.
    pub fn close(&mut self) {
        let _close_result = self.socket.close(None);
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Returns a loopback address whose port the OS just handed out and then
    /// released.
    ///
    /// Nothing should be listening on it, unlike a hard-coded port that
    /// another process may own. A silent listener would make `connect` block
    /// forever because there is no handshake timeout.
    fn unused_local_addr() -> std::net::SocketAddr {
        let listener =
            std::net::TcpListener::bind("127.0.0.1:0").expect("bind an ephemeral port");
        // `listener` is dropped when this function returns, freeing the port.
        listener.local_addr().expect("read the ephemeral port")
    }

    #[test]
    fn test_connect_invalid_uri() {
        let result = WsConnection::connect("not_a_valid_uri");
        assert!(
            matches!(result, Err(WsError::ConnectionFailed(_))),
            "Expected ConnectionFailed error for invalid URI, got {:?}",
            result.as_ref().err()
        );
    }

    #[test]
    fn test_connect_unreachable_server() {
        let uri = format!("ws://{}", unused_local_addr());
        let result = WsConnection::connect(&uri);
        assert!(
            matches!(result, Err(WsError::ConnectionFailed(_))),
            "Expected ConnectionFailed error, got {:?}",
            result.as_ref().err()
        );
    }

    #[test]
    fn test_ws_error_display() {
        let err = WsError::InvalidUri("bad uri".to_string());
        assert_eq!(err.to_string(), "Invalid URI: bad uri");
        let err = WsError::ConnectionFailed("refused".to_string());
        assert_eq!(err.to_string(), "Connection failed: refused");
        let err = WsError::SendFailed("broken pipe".to_string());
        assert_eq!(err.to_string(), "Send failed: broken pipe");
    }
}