armature_websocket/
connection.rs1use crate::error::{WebSocketError, WebSocketResult};
4use crate::message::Message;
5use futures_util::stream::SplitSink;
6use futures_util::SinkExt;
7use parking_lot::RwLock;
8use std::net::SocketAddr;
9use std::sync::Arc;
10use tokio::net::TcpStream;
11use tokio::sync::mpsc;
12use tokio_tungstenite::WebSocketStream;
13
/// Identifier used to tell connections apart.
pub type ConnectionId = String;

/// Lifecycle state of a [`Connection`].
///
/// A freshly created `Connection` starts in `Open`, and `Connection::send`
/// rejects messages in every state other than `Open`.
///
/// `Hash` is derived alongside `Eq` so states can be used as map/set keys
/// (e.g. when grouping or counting connections by state).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ConnectionState {
    /// Not yet established. Nothing in this module assigns this state.
    Connecting,
    /// Established; outgoing messages are accepted.
    Open,
    /// A close has been requested but the connection has not finished closing.
    Closing,
    /// Finished; no further messages can be sent.
    Closed,
}
29
30pub struct Connection {
32 pub id: ConnectionId,
34 pub remote_addr: Option<SocketAddr>,
36 state: Arc<RwLock<ConnectionState>>,
38 tx: mpsc::UnboundedSender<Message>,
40}
41
42impl Connection {
43 pub(crate) fn new(
45 id: ConnectionId,
46 remote_addr: Option<SocketAddr>,
47 tx: mpsc::UnboundedSender<Message>,
48 ) -> Self {
49 Self {
50 id,
51 remote_addr,
52 state: Arc::new(RwLock::new(ConnectionState::Open)),
53 tx,
54 }
55 }
56
57 pub fn state(&self) -> ConnectionState {
59 *self.state.read()
60 }
61
62 pub fn is_open(&self) -> bool {
64 self.state() == ConnectionState::Open
65 }
66
67 pub fn send(&self, message: Message) -> WebSocketResult<()> {
69 if !self.is_open() {
70 return Err(WebSocketError::ConnectionClosed);
71 }
72 self.tx
73 .send(message)
74 .map_err(|e| WebSocketError::Send(e.to_string()))
75 }
76
77 pub fn send_text<S: Into<String>>(&self, text: S) -> WebSocketResult<()> {
79 self.send(Message::text(text))
80 }
81
82 pub fn send_binary<B: Into<bytes::Bytes>>(&self, data: B) -> WebSocketResult<()> {
84 self.send(Message::binary(data))
85 }
86
87 pub fn send_json<T: serde::Serialize>(&self, value: &T) -> WebSocketResult<()> {
89 let message = Message::json(value)?;
90 self.send(message)
91 }
92
93 pub fn close(&self) {
95 *self.state.write() = ConnectionState::Closing;
96 let _ = self.tx.send(Message::close());
97 }
98
99 pub(crate) fn set_state(&self, state: ConnectionState) {
101 *self.state.write() = state;
102 }
103}
104
105impl Clone for Connection {
106 fn clone(&self) -> Self {
107 Self {
108 id: self.id.clone(),
109 remote_addr: self.remote_addr,
110 state: Arc::clone(&self.state),
111 tx: self.tx.clone(),
112 }
113 }
114}
115
116pub(crate) struct ConnectionWriter {
118 sink: SplitSink<WebSocketStream<TcpStream>, tungstenite::Message>,
119 rx: mpsc::UnboundedReceiver<Message>,
120}
121
122impl ConnectionWriter {
123 pub fn new(
125 sink: SplitSink<WebSocketStream<TcpStream>, tungstenite::Message>,
126 rx: mpsc::UnboundedReceiver<Message>,
127 ) -> Self {
128 Self { sink, rx }
129 }
130
131 pub async fn run(mut self) -> WebSocketResult<()> {
133 while let Some(message) = self.rx.recv().await {
134 let is_close = message.is_close();
135 let raw_message: tungstenite::Message = message.into();
136
137 if let Err(e) = self.sink.send(raw_message).await {
138 tracing::error!(error = %e, "Failed to send WebSocket message");
139 return Err(WebSocketError::Protocol(e));
140 }
141
142 if is_close {
143 break;
144 }
145 }
146
147 let _ = self.sink.close().await;
149 Ok(())
150 }
151}
152