armature_websocket/
connection.rs

//! WebSocket connection management.
2
3use crate::error::{WebSocketError, WebSocketResult};
4use crate::message::Message;
5use futures_util::stream::SplitSink;
6use futures_util::SinkExt;
7use parking_lot::RwLock;
8use std::net::SocketAddr;
9use std::sync::Arc;
10use tokio::net::TcpStream;
11use tokio::sync::mpsc;
12use tokio_tungstenite::WebSocketStream;
13
/// Identifier that uniquely names a single connection.
///
/// This is a plain owned string alias; any `String` can be used as an id.
pub type ConnectionId = std::string::String;
16
/// Lifecycle phase of a connection.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ConnectionState {
    /// The connection is still being set up.
    Connecting,
    /// The connection is established and accepts outgoing messages.
    Open,
    /// A close has been initiated but the connection is not yet fully closed.
    Closing,
    /// The connection is fully closed.
    Closed,
}
29
30/// A WebSocket connection.
31pub struct Connection {
32    /// Unique connection identifier
33    pub id: ConnectionId,
34    /// Remote address
35    pub remote_addr: Option<SocketAddr>,
36    /// Connection state
37    state: Arc<RwLock<ConnectionState>>,
38    /// Sender for outgoing messages
39    tx: mpsc::UnboundedSender<Message>,
40}
41
42impl Connection {
43    /// Create a new connection.
44    pub(crate) fn new(
45        id: ConnectionId,
46        remote_addr: Option<SocketAddr>,
47        tx: mpsc::UnboundedSender<Message>,
48    ) -> Self {
49        Self {
50            id,
51            remote_addr,
52            state: Arc::new(RwLock::new(ConnectionState::Open)),
53            tx,
54        }
55    }
56
57    /// Get the connection state.
58    pub fn state(&self) -> ConnectionState {
59        *self.state.read()
60    }
61
62    /// Check if the connection is open.
63    pub fn is_open(&self) -> bool {
64        self.state() == ConnectionState::Open
65    }
66
67    /// Send a message to this connection.
68    pub fn send(&self, message: Message) -> WebSocketResult<()> {
69        if !self.is_open() {
70            return Err(WebSocketError::ConnectionClosed);
71        }
72        self.tx
73            .send(message)
74            .map_err(|e| WebSocketError::Send(e.to_string()))
75    }
76
77    /// Send a text message.
78    pub fn send_text<S: Into<String>>(&self, text: S) -> WebSocketResult<()> {
79        self.send(Message::text(text))
80    }
81
82    /// Send a binary message.
83    pub fn send_binary<B: Into<bytes::Bytes>>(&self, data: B) -> WebSocketResult<()> {
84        self.send(Message::binary(data))
85    }
86
87    /// Send a JSON message.
88    pub fn send_json<T: serde::Serialize>(&self, value: &T) -> WebSocketResult<()> {
89        let message = Message::json(value)?;
90        self.send(message)
91    }
92
93    /// Close the connection.
94    pub fn close(&self) {
95        *self.state.write() = ConnectionState::Closing;
96        let _ = self.tx.send(Message::close());
97    }
98
99    /// Set the connection state (internal use).
100    pub(crate) fn set_state(&self, state: ConnectionState) {
101        *self.state.write() = state;
102    }
103}
104
105impl Clone for Connection {
106    fn clone(&self) -> Self {
107        Self {
108            id: self.id.clone(),
109            remote_addr: self.remote_addr,
110            state: Arc::clone(&self.state),
111            tx: self.tx.clone(),
112        }
113    }
114}
115
116/// Manages the write side of a WebSocket connection.
117pub(crate) struct ConnectionWriter {
118    sink: SplitSink<WebSocketStream<TcpStream>, tungstenite::Message>,
119    rx: mpsc::UnboundedReceiver<Message>,
120}
121
122impl ConnectionWriter {
123    /// Create a new connection writer.
124    pub fn new(
125        sink: SplitSink<WebSocketStream<TcpStream>, tungstenite::Message>,
126        rx: mpsc::UnboundedReceiver<Message>,
127    ) -> Self {
128        Self { sink, rx }
129    }
130
131    /// Run the writer loop, sending messages from the channel to the WebSocket.
132    pub async fn run(mut self) -> WebSocketResult<()> {
133        while let Some(message) = self.rx.recv().await {
134            let is_close = message.is_close();
135            let raw_message: tungstenite::Message = message.into();
136
137            if let Err(e) = self.sink.send(raw_message).await {
138                tracing::error!(error = %e, "Failed to send WebSocket message");
139                return Err(WebSocketError::Protocol(e));
140            }
141
142            if is_close {
143                break;
144            }
145        }
146
147        // Gracefully close the sink
148        let _ = self.sink.close().await;
149        Ok(())
150    }
151}
152