Skip to main content

oxigdal_websocket/server/
connection.rs

1//! WebSocket connection management
2
3use crate::error::{Error, Result};
4use crate::protocol::ProtocolCodec;
5use crate::protocol::message::Message;
6use futures::{SinkExt, StreamExt};
7use std::net::SocketAddr;
8use std::sync::Arc;
9use std::sync::atomic::{AtomicU64, Ordering};
10use std::time::SystemTime;
11use tokio::net::TcpStream;
12use tokio::sync::{Mutex, mpsc};
13use tokio_tungstenite::WebSocketStream;
14use tokio_tungstenite::tungstenite::Message as WsMessage;
15use uuid::Uuid;
16
17/// Connection ID type
18pub type ConnectionId = Uuid;
19
/// Lifecycle phase of a WebSocket connection.
///
/// `Connection::new` starts every connection in `Connected`. `Connection::receive` moves to
/// `Disconnecting` when a close frame arrives and to `Disconnected` when the stream ends.
/// `Connection::close` passes through `Disconnecting` on its way to `Disconnected`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ConnectionState {
    /// Connection is being established (no code in this module currently sets this state).
    Connecting,
    /// Connection is open and usable.
    Connected,
    /// A close has been initiated or received but not yet completed.
    Disconnecting,
    /// Connection is fully closed.
    Disconnected,
}
32
33/// WebSocket connection
34pub struct Connection {
35    /// Connection ID
36    id: ConnectionId,
37    /// Remote address
38    remote_addr: SocketAddr,
39    /// Connection state
40    state: Arc<Mutex<ConnectionState>>,
41    /// WebSocket stream
42    ws: Arc<Mutex<WebSocketStream<TcpStream>>>,
43    /// Protocol codec
44    codec: Arc<ProtocolCodec>,
45    /// Outgoing message queue
46    tx: mpsc::UnboundedSender<Message>,
47    /// Last activity timestamp
48    last_activity: Arc<AtomicU64>,
49    /// Connection metadata
50    metadata: Arc<Mutex<ConnectionMetadata>>,
51    /// Message statistics
52    stats: Arc<ConnectionStatistics>,
53}
54
55/// Connection metadata
56#[derive(Debug, Default, Clone)]
57pub struct ConnectionMetadata {
58    /// User ID (if authenticated)
59    pub user_id: Option<String>,
60    /// Custom tags
61    pub tags: std::collections::HashMap<String, String>,
62    /// Subscribed topics
63    pub subscriptions: std::collections::HashSet<String>,
64    /// Joined rooms
65    pub rooms: std::collections::HashSet<String>,
66}
67
68/// Connection statistics
69#[derive(Debug, Default)]
70pub struct ConnectionStatistics {
71    /// Messages sent
72    pub messages_sent: AtomicU64,
73    /// Messages received
74    pub messages_received: AtomicU64,
75    /// Bytes sent
76    pub bytes_sent: AtomicU64,
77    /// Bytes received
78    pub bytes_received: AtomicU64,
79    /// Errors encountered
80    pub errors: AtomicU64,
81}
82
83impl Connection {
84    /// Create a new connection
85    pub fn new(
86        ws: WebSocketStream<TcpStream>,
87        remote_addr: SocketAddr,
88        codec: ProtocolCodec,
89    ) -> (Self, mpsc::UnboundedReceiver<Message>) {
90        let (tx, rx) = mpsc::unbounded_channel();
91
92        let connection = Self {
93            id: Uuid::new_v4(),
94            remote_addr,
95            state: Arc::new(Mutex::new(ConnectionState::Connected)),
96            ws: Arc::new(Mutex::new(ws)),
97            codec: Arc::new(codec),
98            tx,
99            last_activity: Arc::new(AtomicU64::new(Self::current_timestamp())),
100            metadata: Arc::new(Mutex::new(ConnectionMetadata::default())),
101            stats: Arc::new(ConnectionStatistics::default()),
102        };
103
104        (connection, rx)
105    }
106
107    /// Get connection ID
108    pub fn id(&self) -> ConnectionId {
109        self.id
110    }
111
112    /// Get remote address
113    pub fn remote_addr(&self) -> SocketAddr {
114        self.remote_addr
115    }
116
117    /// Get connection state
118    pub async fn state(&self) -> ConnectionState {
119        *self.state.lock().await
120    }
121
122    /// Set connection state
123    pub async fn set_state(&self, new_state: ConnectionState) {
124        let mut state = self.state.lock().await;
125        *state = new_state;
126    }
127
128    /// Send a message
129    pub async fn send(&self, message: Message) -> Result<()> {
130        self.tx
131            .send(message)
132            .map_err(|e| Error::Connection(format!("Failed to send message: {}", e)))?;
133        Ok(())
134    }
135
136    /// Receive a message
137    pub async fn receive(&self) -> Result<Option<Message>> {
138        let mut ws = self.ws.lock().await;
139
140        match ws.next().await {
141            Some(Ok(ws_msg)) => {
142                self.update_activity();
143                self.stats.messages_received.fetch_add(1, Ordering::Relaxed);
144
145                match ws_msg {
146                    WsMessage::Binary(data) => {
147                        let bytes: &[u8] = &data;
148                        self.stats
149                            .bytes_received
150                            .fetch_add(bytes.len() as u64, Ordering::Relaxed);
151                        let message = self.codec.decode(bytes)?;
152                        Ok(Some(message))
153                    }
154                    WsMessage::Text(text) => {
155                        let bytes = text.as_bytes();
156                        self.stats
157                            .bytes_received
158                            .fetch_add(bytes.len() as u64, Ordering::Relaxed);
159                        let message = self.codec.decode(bytes)?;
160                        Ok(Some(message))
161                    }
162                    WsMessage::Ping(data) => {
163                        // Respond with pong
164                        ws.send(WsMessage::Pong(data)).await?;
165                        Ok(None)
166                    }
167                    WsMessage::Pong(_) => {
168                        // Update activity on pong
169                        Ok(None)
170                    }
171                    WsMessage::Close(_) => {
172                        self.set_state(ConnectionState::Disconnecting).await;
173                        Ok(None)
174                    }
175                    _ => Ok(None),
176                }
177            }
178            Some(Err(e)) => {
179                self.stats.errors.fetch_add(1, Ordering::Relaxed);
180                Err(Error::WebSocket(e.to_string()))
181            }
182            None => {
183                self.set_state(ConnectionState::Disconnected).await;
184                Ok(None)
185            }
186        }
187    }
188
189    /// Process outgoing messages
190    pub async fn process_outgoing(&self, mut rx: mpsc::UnboundedReceiver<Message>) -> Result<()> {
191        while let Some(message) = rx.recv().await {
192            if let Err(e) = self.send_message(message).await {
193                tracing::error!("Failed to send message: {}", e);
194                self.stats.errors.fetch_add(1, Ordering::Relaxed);
195            }
196        }
197        Ok(())
198    }
199
200    /// Send a message directly to the WebSocket
201    async fn send_message(&self, message: Message) -> Result<()> {
202        let encoded = self.codec.encode(&message)?;
203        self.stats
204            .bytes_sent
205            .fetch_add(encoded.len() as u64, Ordering::Relaxed);
206        self.stats.messages_sent.fetch_add(1, Ordering::Relaxed);
207
208        let mut ws = self.ws.lock().await;
209        ws.send(WsMessage::Binary(encoded.to_vec().into())).await?;
210
211        self.update_activity();
212        Ok(())
213    }
214
215    /// Send a ping
216    pub async fn ping(&self) -> Result<()> {
217        let mut ws = self.ws.lock().await;
218        ws.send(WsMessage::Ping(Vec::new().into())).await?;
219        self.update_activity();
220        Ok(())
221    }
222
223    /// Close the connection
224    pub async fn close(&self) -> Result<()> {
225        self.set_state(ConnectionState::Disconnecting).await;
226        let mut ws = self.ws.lock().await;
227        ws.close(None).await?;
228        self.set_state(ConnectionState::Disconnected).await;
229        Ok(())
230    }
231
232    /// Get metadata
233    pub async fn metadata(&self) -> ConnectionMetadata {
234        self.metadata.lock().await.clone()
235    }
236
237    /// Update metadata
238    pub async fn update_metadata<F>(&self, f: F)
239    where
240        F: FnOnce(&mut ConnectionMetadata),
241    {
242        let mut metadata = self.metadata.lock().await;
243        f(&mut metadata);
244    }
245
246    /// Get last activity timestamp
247    pub fn last_activity(&self) -> u64 {
248        self.last_activity.load(Ordering::Relaxed)
249    }
250
251    /// Update last activity
252    fn update_activity(&self) {
253        self.last_activity
254            .store(Self::current_timestamp(), Ordering::Relaxed);
255    }
256
257    /// Get current timestamp in seconds
258    fn current_timestamp() -> u64 {
259        SystemTime::now()
260            .duration_since(SystemTime::UNIX_EPOCH)
261            .map(|d| d.as_secs())
262            .unwrap_or(0)
263    }
264
265    /// Check if connection is idle
266    pub fn is_idle(&self, timeout_secs: u64) -> bool {
267        let now = Self::current_timestamp();
268        let last = self.last_activity();
269        now.saturating_sub(last) > timeout_secs
270    }
271
272    /// Get statistics
273    pub fn stats(&self) -> ConnectionStats {
274        ConnectionStats {
275            messages_sent: self.stats.messages_sent.load(Ordering::Relaxed),
276            messages_received: self.stats.messages_received.load(Ordering::Relaxed),
277            bytes_sent: self.stats.bytes_sent.load(Ordering::Relaxed),
278            bytes_received: self.stats.bytes_received.load(Ordering::Relaxed),
279            errors: self.stats.errors.load(Ordering::Relaxed),
280        }
281    }
282}
283
/// Plain-value copy of a connection's counters, as produced by `Connection::stats`.
#[derive(Clone, Debug, Default)]
pub struct ConnectionStats {
    /// Count of messages sent.
    pub messages_sent: u64,
    /// Count of messages received.
    pub messages_received: u64,
    /// Total bytes sent.
    pub bytes_sent: u64,
    /// Total bytes received.
    pub bytes_received: u64,
    /// Count of errors encountered.
    pub errors: u64,
}
298
#[cfg(test)]
mod tests {
    use super::*;

    /// Two freshly generated IDs must never collide.
    #[test]
    fn test_connection_id() {
        let (first, second) = (Uuid::new_v4(), Uuid::new_v4());
        assert_ne!(first, second);
    }

    /// Equality on states behaves as expected.
    #[test]
    fn test_connection_state() {
        let connected = ConnectionState::Connected;
        assert_eq!(connected, ConnectionState::Connected);
        assert_ne!(connected, ConnectionState::Disconnected);
    }

    /// Metadata fields can be populated and read back.
    #[test]
    fn test_connection_metadata() {
        let mut meta = ConnectionMetadata {
            user_id: Some("user123".to_owned()),
            ..ConnectionMetadata::default()
        };
        meta.tags.insert("role".to_owned(), "admin".to_owned());

        assert_eq!(meta.user_id.as_deref(), Some("user123"));
        assert_eq!(meta.tags.get("role").map(String::as_str), Some("admin"));
    }

    /// Atomic counters accumulate the added amounts.
    #[test]
    fn test_connection_stats() {
        let counters = ConnectionStatistics::default();
        counters.messages_sent.fetch_add(5, Ordering::Relaxed);
        counters.bytes_sent.fetch_add(1024, Ordering::Relaxed);

        let sent = counters.messages_sent.load(Ordering::Relaxed);
        let bytes = counters.bytes_sent.load(Ordering::Relaxed);
        assert_eq!(sent, 5);
        assert_eq!(bytes, 1024);
    }
}