use std::io::{self, Read, Write};
use tungstenite::{Message, WebSocket};
/// Adapts a tungstenite [`WebSocket`] into a byte stream implementing
/// [`std::io::Read`] and [`std::io::Write`].
///
/// Reading yields the payload of each incoming text/binary message followed by a
/// newline (added only if the payload does not already end with one). Writing is
/// split on newlines, and each line is sent as a separate text message.
#[derive(Debug)]
pub struct WebSocketWrapper<S> {
    /// The wrapped WebSocket connection.
    socket: WebSocket<S>,
    /// Leftover bytes of a received message that did not fit into the caller's
    /// buffer; these are served before any new message is read.
    read_buffer: Vec<u8>,
    /// Bytes written after the last newline that have not been sent yet; they go
    /// out with the next newline or on `flush`.
    write_buffer: Vec<u8>,
}
impl<S> WebSocketWrapper<S> {
pub fn new(socket: WebSocket<S>) -> Self {
Self {
socket,
read_buffer: Vec::new(),
write_buffer: Vec::new(),
}
}
pub fn get_ref(&self) -> &WebSocket<S> {
&self.socket
}
pub fn get_mut(&mut self) -> &mut WebSocket<S> {
&mut self.socket
}
}
impl<S: Read + Write> Read for WebSocketWrapper<S> {
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
if !self.read_buffer.is_empty() {
let len = std::cmp::min(buf.len(), self.read_buffer.len());
buf[..len].copy_from_slice(&self.read_buffer[..len]);
self.read_buffer.drain(..len);
return Ok(len);
}
loop {
match self.socket.read() {
Ok(Message::Text(text)) => {
let text_bytes: &[u8] = text.as_ref();
let mut data = text_bytes.to_vec();
if !data.ends_with(b"\n") {
data.push(b'\n');
}
let len = std::cmp::min(buf.len(), data.len());
buf[..len].copy_from_slice(&data[..len]);
if data.len() > len {
self.read_buffer.extend_from_slice(&data[len..]);
}
return Ok(len);
}
Ok(Message::Binary(data)) => {
let data_bytes: &[u8] = data.as_ref();
let mut data = data_bytes.to_vec();
if !data.ends_with(b"\n") {
data.push(b'\n');
}
let len = std::cmp::min(buf.len(), data.len());
buf[..len].copy_from_slice(&data[..len]);
if data.len() > len {
self.read_buffer.extend_from_slice(&data[len..]);
}
return Ok(len);
}
Ok(Message::Ping(data)) => {
let _ = self.socket.write(Message::Pong(data));
let _ = self.socket.flush();
continue;
}
Ok(Message::Pong(_)) => {
continue;
}
Ok(Message::Close(_)) => {
return Ok(0);
}
Ok(Message::Frame(_)) => {
continue;
}
Err(tungstenite::Error::Io(e)) => {
return Err(e);
}
Err(e) => {
return Err(io::Error::new(io::ErrorKind::Other, e.to_string()));
}
}
}
}
}
impl<S: Read + Write> Write for WebSocketWrapper<S> {
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
self.write_buffer.extend_from_slice(buf);
while let Some(newline_pos) = self.write_buffer.iter().position(|&b| b == b'\n') {
let message: Vec<u8> = self.write_buffer.drain(..=newline_pos).collect();
let message_str = String::from_utf8_lossy(&message[..message.len() - 1]);
self.socket
.write(Message::Text(message_str.into_owned().into()))
.map_err(|e| io::Error::new(io::ErrorKind::Other, e.to_string()))?;
}
Ok(buf.len())
}
fn flush(&mut self) -> io::Result<()> {
if !self.write_buffer.is_empty() {
let message = std::mem::take(&mut self.write_buffer);
let message_str = String::from_utf8_lossy(&message);
self.socket
.write(Message::Text(message_str.into_owned().into()))
.map_err(|e| io::Error::new(io::ErrorKind::Other, e.to_string()))?;
}
self.socket
.flush()
.map_err(|e| io::Error::new(io::ErrorKind::Other, e.to_string()))
}
}