oxigdal_websocket/server/
connection.rs1use crate::error::{Error, Result};
4use crate::protocol::ProtocolCodec;
5use crate::protocol::message::Message;
6use futures::{SinkExt, StreamExt};
7use std::net::SocketAddr;
8use std::sync::Arc;
9use std::sync::atomic::{AtomicU64, Ordering};
10use std::time::SystemTime;
11use tokio::net::TcpStream;
12use tokio::sync::{Mutex, mpsc};
13use tokio_tungstenite::WebSocketStream;
14use tokio_tungstenite::tungstenite::Message as WsMessage;
15use uuid::Uuid;
16
17pub type ConnectionId = Uuid;
19
/// Lifecycle stage of a WebSocket connection.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ConnectionState {
    /// The connection is being established. (`Connection::new` does not use
    /// this state; freshly wrapped connections start out as `Connected`.)
    Connecting,
    /// The WebSocket is open and can be read from and written to.
    Connected,
    /// A close frame was received from the peer, or a local close is in progress.
    Disconnecting,
    /// The inbound stream ended or the local close sequence finished.
    Disconnected,
}
32
33pub struct Connection {
35 id: ConnectionId,
37 remote_addr: SocketAddr,
39 state: Arc<Mutex<ConnectionState>>,
41 ws: Arc<Mutex<WebSocketStream<TcpStream>>>,
43 codec: Arc<ProtocolCodec>,
45 tx: mpsc::UnboundedSender<Message>,
47 last_activity: Arc<AtomicU64>,
49 metadata: Arc<Mutex<ConnectionMetadata>>,
51 stats: Arc<ConnectionStatistics>,
53}
54
55#[derive(Debug, Default, Clone)]
57pub struct ConnectionMetadata {
58 pub user_id: Option<String>,
60 pub tags: std::collections::HashMap<String, String>,
62 pub subscriptions: std::collections::HashSet<String>,
64 pub rooms: std::collections::HashSet<String>,
66}
67
68#[derive(Debug, Default)]
70pub struct ConnectionStatistics {
71 pub messages_sent: AtomicU64,
73 pub messages_received: AtomicU64,
75 pub bytes_sent: AtomicU64,
77 pub bytes_received: AtomicU64,
79 pub errors: AtomicU64,
81}
82
83impl Connection {
84 pub fn new(
86 ws: WebSocketStream<TcpStream>,
87 remote_addr: SocketAddr,
88 codec: ProtocolCodec,
89 ) -> (Self, mpsc::UnboundedReceiver<Message>) {
90 let (tx, rx) = mpsc::unbounded_channel();
91
92 let connection = Self {
93 id: Uuid::new_v4(),
94 remote_addr,
95 state: Arc::new(Mutex::new(ConnectionState::Connected)),
96 ws: Arc::new(Mutex::new(ws)),
97 codec: Arc::new(codec),
98 tx,
99 last_activity: Arc::new(AtomicU64::new(Self::current_timestamp())),
100 metadata: Arc::new(Mutex::new(ConnectionMetadata::default())),
101 stats: Arc::new(ConnectionStatistics::default()),
102 };
103
104 (connection, rx)
105 }
106
107 pub fn id(&self) -> ConnectionId {
109 self.id
110 }
111
112 pub fn remote_addr(&self) -> SocketAddr {
114 self.remote_addr
115 }
116
117 pub async fn state(&self) -> ConnectionState {
119 *self.state.lock().await
120 }
121
122 pub async fn set_state(&self, new_state: ConnectionState) {
124 let mut state = self.state.lock().await;
125 *state = new_state;
126 }
127
128 pub async fn send(&self, message: Message) -> Result<()> {
130 self.tx
131 .send(message)
132 .map_err(|e| Error::Connection(format!("Failed to send message: {}", e)))?;
133 Ok(())
134 }
135
136 pub async fn receive(&self) -> Result<Option<Message>> {
138 let mut ws = self.ws.lock().await;
139
140 match ws.next().await {
141 Some(Ok(ws_msg)) => {
142 self.update_activity();
143 self.stats.messages_received.fetch_add(1, Ordering::Relaxed);
144
145 match ws_msg {
146 WsMessage::Binary(data) => {
147 let bytes: &[u8] = &data;
148 self.stats
149 .bytes_received
150 .fetch_add(bytes.len() as u64, Ordering::Relaxed);
151 let message = self.codec.decode(bytes)?;
152 Ok(Some(message))
153 }
154 WsMessage::Text(text) => {
155 let bytes = text.as_bytes();
156 self.stats
157 .bytes_received
158 .fetch_add(bytes.len() as u64, Ordering::Relaxed);
159 let message = self.codec.decode(bytes)?;
160 Ok(Some(message))
161 }
162 WsMessage::Ping(data) => {
163 ws.send(WsMessage::Pong(data)).await?;
165 Ok(None)
166 }
167 WsMessage::Pong(_) => {
168 Ok(None)
170 }
171 WsMessage::Close(_) => {
172 self.set_state(ConnectionState::Disconnecting).await;
173 Ok(None)
174 }
175 _ => Ok(None),
176 }
177 }
178 Some(Err(e)) => {
179 self.stats.errors.fetch_add(1, Ordering::Relaxed);
180 Err(Error::WebSocket(e.to_string()))
181 }
182 None => {
183 self.set_state(ConnectionState::Disconnected).await;
184 Ok(None)
185 }
186 }
187 }
188
189 pub async fn process_outgoing(&self, mut rx: mpsc::UnboundedReceiver<Message>) -> Result<()> {
191 while let Some(message) = rx.recv().await {
192 if let Err(e) = self.send_message(message).await {
193 tracing::error!("Failed to send message: {}", e);
194 self.stats.errors.fetch_add(1, Ordering::Relaxed);
195 }
196 }
197 Ok(())
198 }
199
200 async fn send_message(&self, message: Message) -> Result<()> {
202 let encoded = self.codec.encode(&message)?;
203 self.stats
204 .bytes_sent
205 .fetch_add(encoded.len() as u64, Ordering::Relaxed);
206 self.stats.messages_sent.fetch_add(1, Ordering::Relaxed);
207
208 let mut ws = self.ws.lock().await;
209 ws.send(WsMessage::Binary(encoded.to_vec().into())).await?;
210
211 self.update_activity();
212 Ok(())
213 }
214
215 pub async fn ping(&self) -> Result<()> {
217 let mut ws = self.ws.lock().await;
218 ws.send(WsMessage::Ping(Vec::new().into())).await?;
219 self.update_activity();
220 Ok(())
221 }
222
223 pub async fn close(&self) -> Result<()> {
225 self.set_state(ConnectionState::Disconnecting).await;
226 let mut ws = self.ws.lock().await;
227 ws.close(None).await?;
228 self.set_state(ConnectionState::Disconnected).await;
229 Ok(())
230 }
231
232 pub async fn metadata(&self) -> ConnectionMetadata {
234 self.metadata.lock().await.clone()
235 }
236
237 pub async fn update_metadata<F>(&self, f: F)
239 where
240 F: FnOnce(&mut ConnectionMetadata),
241 {
242 let mut metadata = self.metadata.lock().await;
243 f(&mut metadata);
244 }
245
246 pub fn last_activity(&self) -> u64 {
248 self.last_activity.load(Ordering::Relaxed)
249 }
250
251 fn update_activity(&self) {
253 self.last_activity
254 .store(Self::current_timestamp(), Ordering::Relaxed);
255 }
256
257 fn current_timestamp() -> u64 {
259 SystemTime::now()
260 .duration_since(SystemTime::UNIX_EPOCH)
261 .map(|d| d.as_secs())
262 .unwrap_or(0)
263 }
264
265 pub fn is_idle(&self, timeout_secs: u64) -> bool {
267 let now = Self::current_timestamp();
268 let last = self.last_activity();
269 now.saturating_sub(last) > timeout_secs
270 }
271
272 pub fn stats(&self) -> ConnectionStats {
274 ConnectionStats {
275 messages_sent: self.stats.messages_sent.load(Ordering::Relaxed),
276 messages_received: self.stats.messages_received.load(Ordering::Relaxed),
277 bytes_sent: self.stats.bytes_sent.load(Ordering::Relaxed),
278 bytes_received: self.stats.bytes_received.load(Ordering::Relaxed),
279 errors: self.stats.errors.load(Ordering::Relaxed),
280 }
281 }
282}
283
/// Plain-value snapshot of a connection's traffic counters.
///
/// Produced by `Connection::stats`; unlike the live atomic counters it can be
/// cloned, logged or serialized freely.
#[derive(Clone, Debug, Default)]
pub struct ConnectionStats {
    /// Number of messages sent.
    pub messages_sent: u64,
    /// Number of messages received.
    pub messages_received: u64,
    /// Number of payload bytes sent.
    pub bytes_sent: u64,
    /// Number of payload bytes received.
    pub bytes_received: u64,
    /// Number of errors recorded.
    pub errors: u64,
}
298
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_connection_id() {
        // Random v4 ids: two freshly generated values must differ.
        let (first, second) = (Uuid::new_v4(), Uuid::new_v4());
        assert_ne!(first, second);
    }

    #[test]
    fn test_connection_state() {
        let connected = ConnectionState::Connected;
        assert_eq!(connected, ConnectionState::Connected);
        assert_ne!(connected, ConnectionState::Disconnected);
    }

    #[test]
    fn test_connection_metadata() {
        let mut meta = ConnectionMetadata {
            user_id: Some(String::from("user123")),
            ..ConnectionMetadata::default()
        };
        meta.tags.insert(String::from("role"), String::from("admin"));

        assert_eq!(meta.user_id.as_deref(), Some("user123"));
        assert_eq!(meta.tags.get("role").map(String::as_str), Some("admin"));
    }

    #[test]
    fn test_connection_stats() {
        let counters = ConnectionStatistics::default();
        counters.bytes_sent.fetch_add(1024, Ordering::Relaxed);
        counters.messages_sent.fetch_add(5, Ordering::Relaxed);

        assert_eq!(counters.bytes_sent.load(Ordering::Relaxed), 1024);
        assert_eq!(counters.messages_sent.load(Ordering::Relaxed), 5);
    }
}