Skip to main content

aerosocket_server/
connection.rs

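//! WebSocket connection handling.
//!
//! This module provides [`Connection`], which tracks a single WebSocket client's transport
//! stream, lifecycle state, and traffic metadata, and [`ConnectionHandle`], a shareable,
//! mutex-guarded wrapper around it.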
4
5use aerosocket_core::frame::Frame;
6use aerosocket_core::protocol::Opcode;
7use aerosocket_core::{transport::TransportStream, Message, Result};
8use bytes::{Bytes, BytesMut};
9use std::net::SocketAddr;
10use std::time::Duration;
11
12/// Represents a WebSocket connection
13pub struct Connection {
14    /// Remote address
15    remote_addr: SocketAddr,
16    /// Local address
17    local_addr: SocketAddr,
18    /// Connection state
19    state: ConnectionState,
20    /// Connection metadata
21    pub metadata: ConnectionMetadata,
22    /// Transport stream
23    stream: Option<Box<dyn TransportStream>>,
24    /// Idle timeout duration
25    idle_timeout: Option<Duration>,
26    /// Last activity timestamp
27    last_activity: std::time::Instant,
28}
29
30impl std::fmt::Debug for Connection {
31    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
32        f.debug_struct("Connection")
33            .field("remote_addr", &self.remote_addr)
34            .field("local_addr", &self.local_addr)
35            .field("state", &self.state)
36            .field("metadata", &self.metadata)
37            .field("stream", &"<stream>")
38            .finish()
39    }
40}
41
/// Lifecycle of a [`Connection`], from setup to teardown.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ConnectionState {
    /// Created, but no transport stream is attached yet.
    Connecting,
    /// A transport stream is attached and messages can flow.
    Connected,
    /// A close has been initiated or received; shutdown is in progress.
    Closing,
    /// The transport has reached end-of-stream.
    Closed,
}
54
/// Negotiated options and traffic counters for a [`Connection`].
#[derive(Debug, Clone)]
pub struct ConnectionMetadata {
    /// Negotiated subprotocol, if any.
    pub subprotocol: Option<String>,
    /// Negotiated extension names.
    pub extensions: Vec<String>,
    /// When the connection object was created.
    pub established_at: std::time::Instant,
    /// When activity was most recently recorded.
    pub last_activity_at: std::time::Instant,
    /// Number of messages written to the peer.
    pub messages_sent: u64,
    /// Number of complete data messages read from the peer.
    pub messages_received: u64,
    /// Total bytes written to the peer.
    pub bytes_sent: u64,
    /// Total bytes of message payload read from the peer.
    pub bytes_received: u64,
    /// `true` when compression was negotiated for this connection.
    pub compression_negotiated: bool,
}
77
78impl Connection {
79    /// Create a new connection
80    pub fn new(remote_addr: SocketAddr, local_addr: SocketAddr) -> Self {
81        let now = std::time::Instant::now();
82        Self {
83            remote_addr,
84            local_addr,
85            state: ConnectionState::Connecting,
86            metadata: ConnectionMetadata {
87                subprotocol: None,
88                extensions: Vec::new(),
89                established_at: now,
90                last_activity_at: now,
91                messages_sent: 0,
92                messages_received: 0,
93                bytes_sent: 0,
94                bytes_received: 0,
95                compression_negotiated: false,
96            },
97            stream: None,
98            idle_timeout: None,
99            last_activity: now,
100        }
101    }
102
103    /// Create a new connection with a transport stream
104    pub fn with_stream(
105        remote_addr: SocketAddr,
106        local_addr: SocketAddr,
107        stream: Box<dyn TransportStream>,
108    ) -> Self {
109        let now = std::time::Instant::now();
110        Self {
111            remote_addr,
112            local_addr,
113            state: ConnectionState::Connected,
114            metadata: ConnectionMetadata {
115                subprotocol: None,
116                extensions: Vec::new(),
117                established_at: now,
118                last_activity_at: now,
119                messages_sent: 0,
120                messages_received: 0,
121                bytes_sent: 0,
122                bytes_received: 0,
123                compression_negotiated: false,
124            },
125            stream: Some(stream),
126            idle_timeout: None,
127            last_activity: now,
128        }
129    }
130
131    /// Create a new connection with timeout settings
132    pub fn with_timeout(
133        remote_addr: SocketAddr,
134        local_addr: SocketAddr,
135        stream: Box<dyn TransportStream>,
136        idle_timeout: Option<Duration>,
137    ) -> Self {
138        let now = std::time::Instant::now();
139        Self {
140            remote_addr,
141            local_addr,
142            state: ConnectionState::Connected,
143            metadata: ConnectionMetadata {
144                subprotocol: None,
145                extensions: Vec::new(),
146                established_at: now,
147                last_activity_at: now,
148                messages_sent: 0,
149                messages_received: 0,
150                bytes_sent: 0,
151                bytes_received: 0,
152                compression_negotiated: false,
153            },
154            stream: Some(stream),
155            idle_timeout,
156            last_activity: now,
157        }
158    }
159
160    /// Set the transport stream
161    pub fn set_stream(&mut self, stream: Box<dyn TransportStream>) {
162        self.stream = Some(stream);
163        self.state = ConnectionState::Connected;
164    }
165
166    /// Get the remote address
167    pub fn remote_addr(&self) -> SocketAddr {
168        self.remote_addr
169    }
170
171    /// Get the local address
172    pub fn local_addr(&self) -> SocketAddr {
173        self.local_addr
174    }
175
176    /// Get the connection state
177    pub fn state(&self) -> ConnectionState {
178        self.state
179    }
180
181    /// Get the connection metadata
182    pub fn metadata(&self) -> &ConnectionMetadata {
183        &self.metadata
184    }
185
186    /// Check if the connection has timed out
187    pub fn is_timed_out(&self) -> bool {
188        if let Some(timeout) = self.idle_timeout {
189            self.last_activity.elapsed() > timeout
190        } else {
191            false
192        }
193    }
194
195    /// Get the time until the connection times out
196    pub fn time_until_timeout(&self) -> Option<Duration> {
197        self.idle_timeout.map(|timeout| {
198            let elapsed = self.last_activity.elapsed();
199            if elapsed >= timeout {
200                Duration::ZERO
201            } else {
202                timeout - elapsed
203            }
204        })
205    }
206
207    /// Update the last activity timestamp
208    fn update_activity(&mut self) {
209        self.last_activity = std::time::Instant::now();
210        self.metadata.last_activity_at = self.last_activity;
211    }
212
213    /// Set the idle timeout
214    pub fn set_idle_timeout(&mut self, timeout: Option<Duration>) {
215        self.idle_timeout = timeout;
216    }
217
218    /// Send a message
219    pub async fn send(&mut self, message: Message) -> Result<()> {
220        // Update activity timestamp before borrowing stream
221        self.update_activity();
222
223        if let Some(stream) = &mut self.stream {
224            // Convert message to WebSocket frame
225            let frame = match message {
226                Message::Text(text) => Frame::text(text.as_bytes().to_vec()),
227                Message::Binary(data) => Frame::binary(data.as_bytes().to_vec()),
228                Message::Ping(data) => Frame::ping(data.as_bytes().to_vec()),
229                Message::Pong(data) => Frame::pong(data.as_bytes().to_vec()),
230                Message::Close(code_and_reason) => {
231                    Frame::close(code_and_reason.code(), Some(code_and_reason.reason()))
232                }
233            };
234
235            // Serialize frame to bytes
236            let frame_bytes = frame.to_bytes();
237
238            #[cfg(feature = "metrics")]
239            {
240                metrics::counter!("aerosocket_server_messages_sent_total").increment(1);
241                metrics::counter!("aerosocket_server_bytes_sent_total")
242                    .increment(frame_bytes.len() as u64);
243                metrics::histogram!("aerosocket_server_frame_size_bytes")
244                    .record(frame_bytes.len() as f64);
245            }
246
247            // Send frame
248            stream.write_all(&frame_bytes).await?;
249            stream.flush().await?;
250
251            // Update metadata
252            self.metadata.messages_sent += 1;
253            self.metadata.bytes_sent += frame_bytes.len() as u64;
254
255            Ok(())
256        } else {
257            Err(aerosocket_core::Error::Other(
258                "Connection not established".to_string(),
259            ))
260        }
261    }
262
263    /// Send a text message
264    pub async fn send_text(&mut self, text: impl AsRef<str>) -> Result<()> {
265        self.send(Message::text(text.as_ref().to_string())).await
266    }
267
268    /// Send a binary message
269    pub async fn send_binary(&mut self, data: impl Into<Bytes>) -> Result<()> {
270        self.send(Message::binary(data)).await
271    }
272
273    /// Send a ping message
274    pub async fn ping(&mut self, data: Option<&[u8]>) -> Result<()> {
275        self.send(Message::ping(data.map(|d| d.to_vec()))).await
276    }
277
278    /// Send a pong message
279    pub async fn pong(&mut self, data: Option<&[u8]>) -> Result<()> {
280        self.send(Message::pong(data.map(|d| d.to_vec()))).await
281    }
282
283    /// Send a pong message (convenience method)
284    pub async fn send_pong(&mut self) -> Result<()> {
285        self.pong(None).await
286    }
287
288    /// Receive the next message
289    pub async fn next(&mut self) -> Result<Option<Message>> {
290        // Update activity timestamp before borrowing stream
291        self.update_activity();
292
293        if let Some(stream) = &mut self.stream {
294            let mut message_buffer = Vec::new();
295            let mut final_frame = false;
296            let mut opcode = None;
297
298            // Keep reading frames until we get a complete message
299            while !final_frame {
300                // Read frame data
301                let mut frame_buffer = BytesMut::new();
302
303                // Read at least the frame header (2 bytes)
304                loop {
305                    let mut temp_buf = [0u8; 2];
306                    let n = stream.read(&mut temp_buf).await?;
307                    if n == 0 {
308                        self.state = ConnectionState::Closed;
309                        return Ok(None);
310                    }
311                    frame_buffer.extend_from_slice(&temp_buf[..n]);
312
313                    if frame_buffer.len() >= 2 {
314                        break;
315                    }
316                }
317
318                // Parse the frame to determine how much more data we need
319                match Frame::parse(&mut frame_buffer, self.metadata.compression_negotiated) {
320                    Ok(frame) => {
321                        // Handle control frames immediately
322                        match frame.opcode {
323                            Opcode::Ping => {
324                                let ping_data = frame.payload.to_vec();
325                                // Send pong response
326                                stream.write_all(&Frame::pong(ping_data).to_bytes()).await?;
327                                stream.flush().await?;
328                                continue;
329                            }
330                            Opcode::Pong => {
331                                // Handle pong response (update activity)
332                                // Note: We can't call update_activity here due to borrowing,
333                                // but activity is already updated at the start of next()
334                                continue;
335                            }
336                            Opcode::Close => {
337                                // Parse close frame
338                                let close_code = if frame.payload.len() >= 2 {
339                                    let code_bytes = &frame.payload[..2];
340                                    u16::from_be_bytes([code_bytes[0], code_bytes[1]])
341                                } else {
342                                    1000 // Normal closure
343                                };
344
345                                let close_reason = if frame.payload.len() > 2 {
346                                    String::from_utf8_lossy(&frame.payload[2..]).to_string()
347                                } else {
348                                    String::new()
349                                };
350
351                                self.state = ConnectionState::Closing;
352                                return Ok(Some(Message::close(
353                                    Some(close_code),
354                                    Some(close_reason),
355                                )));
356                            }
357                            Opcode::Continuation | Opcode::Text | Opcode::Binary => {
358                                // Handle data frames
359                                if opcode.is_none() {
360                                    opcode = Some(frame.opcode);
361                                }
362
363                                message_buffer.extend_from_slice(&frame.payload);
364                                final_frame = frame.fin;
365
366                                if !final_frame && frame.opcode != Opcode::Continuation {
367                                    return Err(aerosocket_core::Error::Other(
368                                        "Expected continuation frame".to_string(),
369                                    ));
370                                }
371                            }
372                            _ => {
373                                return Err(aerosocket_core::Error::Other(
374                                    "Unsupported opcode".to_string(),
375                                ));
376                            }
377                        }
378                    }
379                    Err(_e) => {
380                        // Need more data - read from stream
381                        let mut temp_buf = [0u8; 1024];
382                        match stream.read(&mut temp_buf).await {
383                            Ok(0) => {
384                                self.state = ConnectionState::Closed;
385                                return Ok(None);
386                            }
387                            Ok(n) => {
388                                frame_buffer.extend_from_slice(&temp_buf[..n]);
389                            }
390                            Err(e) => return Err(e),
391                        }
392                        continue;
393                    }
394                }
395            }
396
397            // Convert the collected message based on opcode
398            let message = match opcode.unwrap_or(Opcode::Text) {
399                Opcode::Text => {
400                    let text = String::from_utf8_lossy(&message_buffer).to_string();
401                    Message::text(text)
402                }
403                Opcode::Binary => {
404                    let data = Bytes::from(message_buffer.clone());
405                    Message::binary(data)
406                }
407                _ => {
408                    return Err(aerosocket_core::Error::Other(
409                        "Invalid message opcode".to_string(),
410                    ))
411                }
412            };
413
414            // Update metadata
415            self.metadata.messages_received += 1;
416            self.metadata.bytes_received += message_buffer.len() as u64;
417
418            #[cfg(feature = "metrics")]
419            {
420                metrics::counter!("aerosocket_server_messages_received_total").increment(1);
421                metrics::counter!("aerosocket_server_bytes_received_total")
422                    .increment(message_buffer.len() as u64);
423                metrics::histogram!("aerosocket_server_message_size_bytes")
424                    .record(message_buffer.len() as f64);
425            }
426
427            Ok(Some(message))
428        } else {
429            Err(aerosocket_core::Error::Other(
430                "Connection not established".to_string(),
431            ))
432        }
433    }
434
435    /// Close the connection
436    pub async fn close(&mut self, code: Option<u16>, reason: Option<&str>) -> Result<()> {
437        self.state = ConnectionState::Closing;
438        self.send(Message::close(code, reason.map(|s| s.to_string())))
439            .await
440    }
441
442    /// Check if the connection is established
443    pub fn is_connected(&self) -> bool {
444        self.state == ConnectionState::Connected
445    }
446
447    /// Check if the connection is closed
448    pub fn is_closed(&self) -> bool {
449        self.state == ConnectionState::Closed
450    }
451
452    /// Get the connection age
453    pub fn age(&self) -> std::time::Duration {
454        self.metadata.established_at.elapsed()
455    }
456
457    /// Get the time since last activity
458    pub fn idle_time(&self) -> std::time::Duration {
459        self.metadata.last_activity_at.elapsed()
460    }
461}
462
463/// Connection handle for managing connections
464#[derive(Debug, Clone)]
465pub struct ConnectionHandle {
466    /// Connection ID
467    id: u64,
468    /// Connection reference
469    connection: std::sync::Arc<tokio::sync::Mutex<Connection>>,
470}
471
472impl ConnectionHandle {
473    /// Create a new connection handle
474    pub fn new(id: u64, connection: Connection) -> Self {
475        Self {
476            id,
477            connection: std::sync::Arc::new(tokio::sync::Mutex::new(connection)),
478        }
479    }
480
481    /// Get the connection ID
482    pub fn id(&self) -> u64 {
483        self.id
484    }
485
486    /// Try to lock the connection
487    pub async fn try_lock(&self) -> Result<tokio::sync::MutexGuard<'_, Connection>> {
488        self.connection
489            .try_lock()
490            .map_err(|_| aerosocket_core::Error::Other("Failed to lock connection".to_string()))
491    }
492}
493
#[cfg(test)]
mod tests {
    use super::*;

    /// Fixed remote/local address pair shared by the tests.
    fn addrs() -> (SocketAddr, SocketAddr) {
        (
            "127.0.0.1:12345".parse().unwrap(),
            "127.0.0.1:8080".parse().unwrap(),
        )
    }

    #[test]
    fn test_connection_creation() {
        let (remote, local) = addrs();
        let conn = Connection::new(remote, local);

        assert_eq!(conn.remote_addr(), remote);
        assert_eq!(conn.local_addr(), local);
        assert_eq!(conn.state(), ConnectionState::Connecting);
        assert!(!conn.is_connected());
        assert!(!conn.is_closed());
    }

    #[test]
    fn test_idle_timeout_disabled_by_default() {
        let (remote, local) = addrs();
        let conn = Connection::new(remote, local);

        assert!(!conn.is_timed_out());
        assert_eq!(conn.time_until_timeout(), None);
    }

    #[test]
    fn test_idle_timeout_remaining() {
        let (remote, local) = addrs();
        let mut conn = Connection::new(remote, local);

        let timeout = Duration::from_secs(3600);
        conn.set_idle_timeout(Some(timeout));
        assert!(!conn.is_timed_out());
        let remaining = conn
            .time_until_timeout()
            .expect("an idle timeout was just configured");
        assert!(remaining <= timeout);

        // A zero timeout has no time left regardless of clock resolution.
        conn.set_idle_timeout(Some(Duration::ZERO));
        assert_eq!(conn.time_until_timeout(), Some(Duration::ZERO));

        conn.set_idle_timeout(None);
        assert_eq!(conn.time_until_timeout(), None);
    }

    #[tokio::test]
    async fn test_io_without_stream_fails() {
        let (remote, local) = addrs();
        let mut conn = Connection::new(remote, local);

        assert!(conn.send_text("hello").await.is_err());
        assert!(conn.next().await.is_err());
    }

    #[tokio::test]
    async fn test_connection_handle() {
        let (remote, local) = addrs();
        let conn = Connection::new(remote, local);
        let handle = ConnectionHandle::new(1, conn);

        assert_eq!(handle.id(), 1);
        assert!(handle.try_lock().await.is_ok());
    }
}