use crate::error::{WebSocketError, WebSocketResult};
use crate::message::Message;
use futures_util::stream::SplitSink;
use futures_util::SinkExt;
use parking_lot::RwLock;
use std::net::SocketAddr;
use std::sync::Arc;
use tokio::net::TcpStream;
use tokio::sync::mpsc;
use tokio_tungstenite::WebSocketStream;
/// Identifier attached to every [`Connection`].
pub type ConnectionId = String;

/// Lifecycle state of a [`Connection`].
///
/// The state is shared by all clones of a connection handle. Only
/// [`ConnectionState::Open`] permits sending.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ConnectionState {
    /// Presumably the connection is still being established (not set by the
    /// constructor in this file; only reachable via crate-internal updates).
    Connecting,
    /// The connection accepts outgoing messages.
    Open,
    /// A close has been requested and a close frame queued.
    Closing,
    /// The connection is closed; nothing more can be sent.
    Closed,
}
/// Cheap, cloneable handle to a single WebSocket connection.
///
/// Clones share the same lifecycle state and the same outgoing channel.
/// Sending never touches the socket directly; messages are pushed onto an
/// unbounded channel for a writer task to deliver.
pub struct Connection {
    /// Identifier of this connection.
    pub id: ConnectionId,
    /// Peer address, when it is known.
    pub remote_addr: Option<SocketAddr>,
    /// Lifecycle state shared between all clones of this handle.
    state: Arc<RwLock<ConnectionState>>,
    /// Sending half of the outgoing-message queue.
    tx: mpsc::UnboundedSender<Message>,
}
impl Connection {
    /// Creates a handle for a connection that starts in [`ConnectionState::Open`].
    ///
    /// Every message sent through this handle (and its clones) is pushed onto `tx`.
    pub(crate) fn new(
        id: ConnectionId,
        remote_addr: Option<SocketAddr>,
        tx: mpsc::UnboundedSender<Message>,
    ) -> Self {
        Self {
            id,
            remote_addr,
            state: Arc::new(RwLock::new(ConnectionState::Open)),
            tx,
        }
    }

    /// Returns a snapshot of the current lifecycle state.
    ///
    /// Another clone may change the state right after this returns.
    pub fn state(&self) -> ConnectionState {
        *self.state.read()
    }

    /// Returns `true` if the connection is currently [`ConnectionState::Open`].
    pub fn is_open(&self) -> bool {
        self.state() == ConnectionState::Open
    }

    /// Queues `message` for delivery to the peer.
    ///
    /// # Errors
    ///
    /// - [`WebSocketError::ConnectionClosed`] if the connection is not open.
    /// - [`WebSocketError::Send`] if the receiving half of the outgoing channel
    ///   has been dropped.
    pub fn send(&self, message: Message) -> WebSocketResult<()> {
        // Keep the read guard alive across the enqueue. `close()` needs the
        // write lock to leave `Open`, so it cannot run between the state check
        // and the push. Every message accepted here is therefore queued before
        // the close frame, rather than after it, where the writer (which stops
        // at the close frame) would silently discard it.
        let state = self.state.read();
        if *state != ConnectionState::Open {
            return Err(WebSocketError::ConnectionClosed);
        }
        self.tx
            .send(message)
            .map_err(|e| WebSocketError::Send(e.to_string()))
    }

    /// Queues a text message. See [`Connection::send`] for errors.
    pub fn send_text<S: Into<String>>(&self, text: S) -> WebSocketResult<()> {
        self.send(Message::text(text))
    }

    /// Queues a binary message. See [`Connection::send`] for errors.
    pub fn send_binary<B: Into<bytes::Bytes>>(&self, data: B) -> WebSocketResult<()> {
        self.send(Message::binary(data))
    }

    /// Serializes `value` into a message and queues it.
    ///
    /// # Errors
    ///
    /// Propagates any error from building the message with `Message::json`,
    /// plus the errors documented on [`Connection::send`].
    pub fn send_json<T: serde::Serialize>(&self, value: &T) -> WebSocketResult<()> {
        let message = Message::json(value)?;
        self.send(message)
    }

    /// Requests a graceful close by moving to [`ConnectionState::Closing`] and
    /// queueing a close frame.
    ///
    /// Only the first call on a `Connecting`/`Open` connection has any effect.
    /// Calls on a connection that is already `Closing` or `Closed` do nothing.
    /// This means the state never moves backwards and at most one close frame
    /// is queued.
    pub fn close(&self) {
        {
            let mut state = self.state.write();
            match *state {
                ConnectionState::Closing | ConnectionState::Closed => return,
                ConnectionState::Connecting | ConnectionState::Open => {
                    *state = ConnectionState::Closing;
                }
            }
        }
        // Best-effort: if the channel's receiver is gone, nobody is left to
        // deliver the frame, so the error is intentionally ignored.
        let _ = self.tx.send(Message::close());
    }

    /// Overwrites the lifecycle state unconditionally (crate-internal).
    pub(crate) fn set_state(&self, state: ConnectionState) {
        *self.state.write() = state;
    }
}
impl Clone for Connection {
fn clone(&self) -> Self {
Self {
id: self.id.clone(),
remote_addr: self.remote_addr,
state: Arc::clone(&self.state),
tx: self.tx.clone(),
}
}
}
/// Pump that moves queued outgoing messages onto the WebSocket sink.
pub(crate) struct ConnectionWriter {
    /// Write half of the split WebSocket stream.
    sink: SplitSink<WebSocketStream<TcpStream>, tungstenite::Message>,
    /// Receiving half of the outgoing-message queue.
    rx: mpsc::UnboundedReceiver<Message>,
}
impl ConnectionWriter {
pub fn new(
sink: SplitSink<WebSocketStream<TcpStream>, tungstenite::Message>,
rx: mpsc::UnboundedReceiver<Message>,
) -> Self {
Self { sink, rx }
}
pub async fn run(mut self) -> WebSocketResult<()> {
while let Some(message) = self.rx.recv().await {
let is_close = message.is_close();
let raw_message: tungstenite::Message = message.into();
if let Err(e) = self.sink.send(raw_message).await {
tracing::error!(error = %e, "Failed to send WebSocket message");
return Err(WebSocketError::Protocol(e));
}
if is_close {
break;
}
}
let _ = self.sink.close().await;
Ok(())
}
}