aerosocket_server/
connection.rs

1//! WebSocket connection handling
2//!
3//! This module provides connection management for WebSocket clients.
4
5use aerosocket_core::frame::Frame;
6use aerosocket_core::protocol::Opcode;
7use aerosocket_core::{transport::TransportStream, Message, Result};
8use bytes::{Bytes, BytesMut};
9use std::net::SocketAddr;
10use std::time::Duration;
11
12/// Represents a WebSocket connection
13pub struct Connection {
14    /// Remote address
15    remote_addr: SocketAddr,
16    /// Local address
17    local_addr: SocketAddr,
18    /// Connection state
19    state: ConnectionState,
20    /// Connection metadata
21    metadata: ConnectionMetadata,
22    /// Transport stream
23    stream: Option<Box<dyn TransportStream>>,
24    /// Idle timeout duration
25    idle_timeout: Option<Duration>,
26    /// Last activity timestamp
27    last_activity: std::time::Instant,
28}
29
30impl std::fmt::Debug for Connection {
31    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
32        f.debug_struct("Connection")
33            .field("remote_addr", &self.remote_addr)
34            .field("local_addr", &self.local_addr)
35            .field("state", &self.state)
36            .field("metadata", &self.metadata)
37            .field("stream", &"<stream>")
38            .finish()
39    }
40}
41
/// Lifecycle stage of a [`Connection`].
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ConnectionState {
    /// Created, but no transport stream has been attached yet.
    Connecting,
    /// A transport stream is attached and frames can be exchanged.
    Connected,
    /// A close handshake has been started by either side.
    Closing,
    /// The connection is fully closed (for example, the peer's stream ended).
    Closed,
}
54
/// Negotiated protocol details and traffic statistics for a [`Connection`].
#[derive(Clone, Debug)]
pub struct ConnectionMetadata {
    /// Selected WebSocket subprotocol, if any.
    pub subprotocol: Option<String>,
    /// Names of the WebSocket extensions in use.
    pub extensions: Vec<String>,
    /// Instant at which the connection was created.
    pub established_at: std::time::Instant,
    /// Instant of the most recently recorded activity.
    pub last_activity_at: std::time::Instant,
    /// Number of messages sent so far.
    pub messages_sent: u64,
    /// Number of data messages received so far.
    pub messages_received: u64,
    /// Running total of bytes sent.
    pub bytes_sent: u64,
    /// Running total of bytes received.
    pub bytes_received: u64,
}
75
76impl Connection {
77    /// Create a new connection
78    pub fn new(remote_addr: SocketAddr, local_addr: SocketAddr) -> Self {
79        let now = std::time::Instant::now();
80        Self {
81            remote_addr,
82            local_addr,
83            state: ConnectionState::Connecting,
84            metadata: ConnectionMetadata {
85                subprotocol: None,
86                extensions: Vec::new(),
87                established_at: now,
88                last_activity_at: now,
89                messages_sent: 0,
90                messages_received: 0,
91                bytes_sent: 0,
92                bytes_received: 0,
93            },
94            stream: None,
95            idle_timeout: None,
96            last_activity: now,
97        }
98    }
99
100    /// Create a new connection with a transport stream
101    pub fn with_stream(
102        remote_addr: SocketAddr,
103        local_addr: SocketAddr,
104        stream: Box<dyn TransportStream>,
105    ) -> Self {
106        let now = std::time::Instant::now();
107        Self {
108            remote_addr,
109            local_addr,
110            state: ConnectionState::Connected,
111            metadata: ConnectionMetadata {
112                subprotocol: None,
113                extensions: Vec::new(),
114                established_at: now,
115                last_activity_at: now,
116                messages_sent: 0,
117                messages_received: 0,
118                bytes_sent: 0,
119                bytes_received: 0,
120            },
121            stream: Some(stream),
122            idle_timeout: None,
123            last_activity: now,
124        }
125    }
126
127    /// Create a new connection with timeout settings
128    pub fn with_timeout(
129        remote_addr: SocketAddr,
130        local_addr: SocketAddr,
131        stream: Box<dyn TransportStream>,
132        idle_timeout: Option<Duration>,
133    ) -> Self {
134        let now = std::time::Instant::now();
135        Self {
136            remote_addr,
137            local_addr,
138            state: ConnectionState::Connected,
139            metadata: ConnectionMetadata {
140                subprotocol: None,
141                extensions: Vec::new(),
142                established_at: now,
143                last_activity_at: now,
144                messages_sent: 0,
145                messages_received: 0,
146                bytes_sent: 0,
147                bytes_received: 0,
148            },
149            stream: Some(stream),
150            idle_timeout,
151            last_activity: now,
152        }
153    }
154
155    /// Set the transport stream
156    pub fn set_stream(&mut self, stream: Box<dyn TransportStream>) {
157        self.stream = Some(stream);
158        self.state = ConnectionState::Connected;
159    }
160
161    /// Get the remote address
162    pub fn remote_addr(&self) -> SocketAddr {
163        self.remote_addr
164    }
165
166    /// Get the local address
167    pub fn local_addr(&self) -> SocketAddr {
168        self.local_addr
169    }
170
171    /// Get the connection state
172    pub fn state(&self) -> ConnectionState {
173        self.state
174    }
175
176    /// Get the connection metadata
177    pub fn metadata(&self) -> &ConnectionMetadata {
178        &self.metadata
179    }
180
181    /// Check if the connection has timed out
182    pub fn is_timed_out(&self) -> bool {
183        if let Some(timeout) = self.idle_timeout {
184            self.last_activity.elapsed() > timeout
185        } else {
186            false
187        }
188    }
189
190    /// Get the time until the connection times out
191    pub fn time_until_timeout(&self) -> Option<Duration> {
192        self.idle_timeout.map(|timeout| {
193            let elapsed = self.last_activity.elapsed();
194            if elapsed >= timeout {
195                Duration::ZERO
196            } else {
197                timeout - elapsed
198            }
199        })
200    }
201
202    /// Update the last activity timestamp
203    fn update_activity(&mut self) {
204        self.last_activity = std::time::Instant::now();
205        self.metadata.last_activity_at = self.last_activity;
206    }
207
208    /// Set the idle timeout
209    pub fn set_idle_timeout(&mut self, timeout: Option<Duration>) {
210        self.idle_timeout = timeout;
211    }
212
213    /// Send a message
214    pub async fn send(&mut self, message: Message) -> Result<()> {
215        // Update activity timestamp before borrowing stream
216        self.update_activity();
217
218        if let Some(stream) = &mut self.stream {
219            // Convert message to WebSocket frame
220            let frame = match message {
221                Message::Text(text) => Frame::text(text.as_bytes().to_vec()),
222                Message::Binary(data) => Frame::binary(data.as_bytes().to_vec()),
223                Message::Ping(data) => Frame::ping(data.as_bytes().to_vec()),
224                Message::Pong(data) => Frame::pong(data.as_bytes().to_vec()),
225                Message::Close(code_and_reason) => {
226                    Frame::close(code_and_reason.code(), Some(code_and_reason.reason()))
227                }
228            };
229
230            // Serialize frame to bytes
231            let frame_bytes = frame.to_bytes();
232
233            // Send frame
234            stream.write_all(&frame_bytes).await?;
235            stream.flush().await?;
236
237            // Update metadata
238            self.metadata.messages_sent += 1;
239            self.metadata.bytes_sent += frame_bytes.len() as u64;
240
241            Ok(())
242        } else {
243            Err(aerosocket_core::Error::Other(
244                "Connection not established".to_string(),
245            ))
246        }
247    }
248
249    /// Send a text message
250    pub async fn send_text(&mut self, text: impl AsRef<str>) -> Result<()> {
251        self.send(Message::text(text.as_ref().to_string())).await
252    }
253
254    /// Send a binary message
255    pub async fn send_binary(&mut self, data: impl Into<Bytes>) -> Result<()> {
256        self.send(Message::binary(data)).await
257    }
258
259    /// Send a ping message
260    pub async fn ping(&mut self, data: Option<&[u8]>) -> Result<()> {
261        self.send(Message::ping(data.map(|d| d.to_vec()))).await
262    }
263
264    /// Send a pong message
265    pub async fn pong(&mut self, data: Option<&[u8]>) -> Result<()> {
266        self.send(Message::pong(data.map(|d| d.to_vec()))).await
267    }
268
269    /// Send a pong message (convenience method)
270    pub async fn send_pong(&mut self) -> Result<()> {
271        self.pong(None).await
272    }
273
274    /// Receive the next message
275    pub async fn next(&mut self) -> Result<Option<Message>> {
276        // Update activity timestamp before borrowing stream
277        self.update_activity();
278
279        if let Some(stream) = &mut self.stream {
280            let mut message_buffer = Vec::new();
281            let mut final_frame = false;
282            let mut opcode = None;
283
284            // Keep reading frames until we get a complete message
285            while !final_frame {
286                // Read frame data
287                let mut frame_buffer = BytesMut::new();
288
289                // Read at least the frame header (2 bytes)
290                loop {
291                    let mut temp_buf = [0u8; 2];
292                    let n = stream.read(&mut temp_buf).await?;
293                    if n == 0 {
294                        self.state = ConnectionState::Closed;
295                        return Ok(None);
296                    }
297                    frame_buffer.extend_from_slice(&temp_buf[..n]);
298
299                    if frame_buffer.len() >= 2 {
300                        break;
301                    }
302                }
303
304                // Parse the frame to determine how much more data we need
305                match Frame::parse(&mut frame_buffer) {
306                    Ok(frame) => {
307                        // Handle control frames immediately
308                        match frame.opcode {
309                            Opcode::Ping => {
310                                let ping_data = frame.payload.to_vec();
311                                // Send pong response
312                                stream.write_all(&Frame::pong(ping_data).to_bytes()).await?;
313                                stream.flush().await?;
314                                continue;
315                            }
316                            Opcode::Pong => {
317                                // Handle pong response (update activity)
318                                // Note: We can't call update_activity here due to borrowing,
319                                // but activity is already updated at the start of next()
320                                continue;
321                            }
322                            Opcode::Close => {
323                                // Parse close frame
324                                let close_code = if frame.payload.len() >= 2 {
325                                    let code_bytes = &frame.payload[..2];
326                                    u16::from_be_bytes([code_bytes[0], code_bytes[1]])
327                                } else {
328                                    1000 // Normal closure
329                                };
330
331                                let close_reason = if frame.payload.len() > 2 {
332                                    String::from_utf8_lossy(&frame.payload[2..]).to_string()
333                                } else {
334                                    String::new()
335                                };
336
337                                self.state = ConnectionState::Closing;
338                                return Ok(Some(Message::close(
339                                    Some(close_code),
340                                    Some(close_reason),
341                                )));
342                            }
343                            Opcode::Continuation | Opcode::Text | Opcode::Binary => {
344                                // Handle data frames
345                                if opcode.is_none() {
346                                    opcode = Some(frame.opcode);
347                                }
348
349                                message_buffer.extend_from_slice(&frame.payload);
350                                final_frame = frame.fin;
351
352                                if !final_frame && frame.opcode != Opcode::Continuation {
353                                    return Err(aerosocket_core::Error::Other(
354                                        "Expected continuation frame".to_string(),
355                                    ));
356                                }
357                            }
358                            _ => {
359                                return Err(aerosocket_core::Error::Other(
360                                    "Unsupported opcode".to_string(),
361                                ));
362                            }
363                        }
364                    }
365                    Err(_e) => {
366                        // Need more data - read from stream
367                        let mut temp_buf = [0u8; 1024];
368                        match stream.read(&mut temp_buf).await {
369                            Ok(0) => {
370                                self.state = ConnectionState::Closed;
371                                return Ok(None);
372                            }
373                            Ok(n) => {
374                                frame_buffer.extend_from_slice(&temp_buf[..n]);
375                            }
376                            Err(e) => return Err(e),
377                        }
378                        continue;
379                    }
380                }
381            }
382
383            // Convert the collected message based on opcode
384            let message = match opcode.unwrap_or(Opcode::Text) {
385                Opcode::Text => {
386                    let text = String::from_utf8_lossy(&message_buffer).to_string();
387                    Message::text(text)
388                }
389                Opcode::Binary => {
390                    let data = Bytes::from(message_buffer.clone());
391                    Message::binary(data)
392                }
393                _ => {
394                    return Err(aerosocket_core::Error::Other(
395                        "Invalid message opcode".to_string(),
396                    ))
397                }
398            };
399
400            // Update metadata
401            self.metadata.messages_received += 1;
402            self.metadata.bytes_received += message_buffer.len() as u64;
403
404            Ok(Some(message))
405        } else {
406            Err(aerosocket_core::Error::Other(
407                "Connection not established".to_string(),
408            ))
409        }
410    }
411
412    /// Close the connection
413    pub async fn close(&mut self, code: Option<u16>, reason: Option<&str>) -> Result<()> {
414        self.state = ConnectionState::Closing;
415        self.send(Message::close(code, reason.map(|s| s.to_string())))
416            .await
417    }
418
419    /// Check if the connection is established
420    pub fn is_connected(&self) -> bool {
421        self.state == ConnectionState::Connected
422    }
423
424    /// Check if the connection is closed
425    pub fn is_closed(&self) -> bool {
426        self.state == ConnectionState::Closed
427    }
428
429    /// Get the connection age
430    pub fn age(&self) -> std::time::Duration {
431        self.metadata.established_at.elapsed()
432    }
433
434    /// Get the time since last activity
435    pub fn idle_time(&self) -> std::time::Duration {
436        self.metadata.last_activity_at.elapsed()
437    }
438}
439
440/// Connection handle for managing connections
441#[derive(Debug, Clone)]
442pub struct ConnectionHandle {
443    /// Connection ID
444    id: u64,
445    /// Connection reference
446    connection: std::sync::Arc<tokio::sync::Mutex<Connection>>,
447}
448
449impl ConnectionHandle {
450    /// Create a new connection handle
451    pub fn new(id: u64, connection: Connection) -> Self {
452        Self {
453            id,
454            connection: std::sync::Arc::new(tokio::sync::Mutex::new(connection)),
455        }
456    }
457
458    /// Get the connection ID
459    pub fn id(&self) -> u64 {
460        self.id
461    }
462
463    /// Try to lock the connection
464    pub async fn try_lock(&self) -> Result<tokio::sync::MutexGuard<'_, Connection>> {
465        self.connection
466            .try_lock()
467            .map_err(|_| aerosocket_core::Error::Other("Failed to lock connection".to_string()))
468    }
469}
470
#[cfg(test)]
mod tests {
    use super::*;

    /// Fixed remote/local address pair used by the tests.
    fn addrs() -> (std::net::SocketAddr, std::net::SocketAddr) {
        (
            "127.0.0.1:12345".parse().unwrap(),
            "127.0.0.1:8080".parse().unwrap(),
        )
    }

    #[test]
    fn test_connection_creation() {
        let (remote, local) = addrs();
        let conn = Connection::new(remote, local);

        assert_eq!(conn.remote_addr(), remote);
        assert_eq!(conn.local_addr(), local);
        assert_eq!(conn.state(), ConnectionState::Connecting);
        assert!(!conn.is_connected());
        assert!(!conn.is_closed());
    }

    #[test]
    fn test_idle_timeout_disabled_by_default() {
        let (remote, local) = addrs();
        let conn = Connection::new(remote, local);

        assert!(!conn.is_timed_out());
        assert_eq!(conn.time_until_timeout(), None);
    }

    #[test]
    fn test_idle_timeout_remaining_is_bounded() {
        let (remote, local) = addrs();
        let mut conn = Connection::new(remote, local);
        let timeout = std::time::Duration::from_secs(3600);
        conn.set_idle_timeout(Some(timeout));

        assert!(!conn.is_timed_out());
        let remaining = conn.time_until_timeout().expect("timeout was just set");
        assert!(remaining <= timeout);

        conn.set_idle_timeout(None);
        assert_eq!(conn.time_until_timeout(), None);
    }

    #[tokio::test]
    async fn test_io_without_stream_fails() {
        let (remote, local) = addrs();
        let mut conn = Connection::new(remote, local);

        assert!(conn.send_text("hello").await.is_err());
        assert!(conn.next().await.is_err());
        assert_eq!(conn.metadata().messages_sent, 0);
        assert_eq!(conn.metadata().messages_received, 0);
    }

    #[tokio::test]
    async fn test_connection_handle() {
        let (remote, local) = addrs();
        let conn = Connection::new(remote, local);
        let handle = ConnectionHandle::new(1, conn);

        assert_eq!(handle.id(), 1);
        assert!(handle.try_lock().await.is_ok());
    }
}