aerosocket_server/
connection.rs

1//! WebSocket connection handling
2//!
3//! This module provides connection management for WebSocket clients.
4
5use aerosocket_core::{Message, Result, transport::TransportStream};
6use aerosocket_core::frame::Frame;
7use aerosocket_core::protocol::Opcode;
8use bytes::{Bytes, BytesMut};
9use std::net::SocketAddr;
10use std::time::Duration;
11
12/// Represents a WebSocket connection
13pub struct Connection {
14    /// Remote address
15    remote_addr: SocketAddr,
16    /// Local address
17    local_addr: SocketAddr,
18    /// Connection state
19    state: ConnectionState,
20    /// Connection metadata
21    metadata: ConnectionMetadata,
22    /// Transport stream
23    stream: Option<Box<dyn TransportStream>>,
24    /// Idle timeout duration
25    idle_timeout: Option<Duration>,
26    /// Last activity timestamp
27    last_activity: std::time::Instant,
28}
29
30impl std::fmt::Debug for Connection {
31    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
32        f.debug_struct("Connection")
33            .field("remote_addr", &self.remote_addr)
34            .field("local_addr", &self.local_addr)
35            .field("state", &self.state)
36            .field("metadata", &self.metadata)
37            .field("stream", &"<stream>")
38            .finish()
39    }
40}
41
/// Lifecycle phase of a [`Connection`].
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ConnectionState {
    /// Created without a transport stream (the state `Connection::new` starts in).
    Connecting,
    /// A transport stream is attached and messages can flow.
    Connected,
    /// A close was initiated locally or a Close frame was received from the peer.
    Closing,
    /// The peer ended the underlying stream.
    Closed,
}
54
/// Descriptive information and traffic statistics for a single connection.
#[derive(Debug, Clone)]
pub struct ConnectionMetadata {
    /// Selected WebSocket subprotocol, if any.
    pub subprotocol: Option<String>,
    /// WebSocket extension names (the constructors in this module leave it empty).
    pub extensions: Vec<String>,
    /// When the connection object was created.
    pub established_at: std::time::Instant,
    /// When activity was last recorded on the connection.
    pub last_activity_at: std::time::Instant,
    /// Number of messages successfully written.
    pub messages_sent: u64,
    /// Number of complete data messages (text/binary) received.
    pub messages_received: u64,
    /// Bytes written, counted as whole encoded frames (headers included).
    pub bytes_sent: u64,
    /// Payload bytes of received data messages (frame headers excluded).
    pub bytes_received: u64,
}
75
76impl Connection {
77    /// Create a new connection
78    pub fn new(remote_addr: SocketAddr, local_addr: SocketAddr) -> Self {
79        let now = std::time::Instant::now();
80        Self {
81            remote_addr,
82            local_addr,
83            state: ConnectionState::Connecting,
84            metadata: ConnectionMetadata {
85                subprotocol: None,
86                extensions: Vec::new(),
87                established_at: now,
88                last_activity_at: now,
89                messages_sent: 0,
90                messages_received: 0,
91                bytes_sent: 0,
92                bytes_received: 0,
93            },
94            stream: None,
95            idle_timeout: None,
96            last_activity: now,
97        }
98    }
99
100    /// Create a new connection with a transport stream
101    pub fn with_stream(remote_addr: SocketAddr, local_addr: SocketAddr, stream: Box<dyn TransportStream>) -> Self {
102        let now = std::time::Instant::now();
103        Self {
104            remote_addr,
105            local_addr,
106            state: ConnectionState::Connected,
107            metadata: ConnectionMetadata {
108                subprotocol: None,
109                extensions: Vec::new(),
110                established_at: now,
111                last_activity_at: now,
112                messages_sent: 0,
113                messages_received: 0,
114                bytes_sent: 0,
115                bytes_received: 0,
116            },
117            stream: Some(stream),
118            idle_timeout: None,
119            last_activity: now,
120        }
121    }
122
123    /// Create a new connection with timeout settings
124    pub fn with_timeout(
125        remote_addr: SocketAddr, 
126        local_addr: SocketAddr, 
127        stream: Box<dyn TransportStream>,
128        idle_timeout: Option<Duration>
129    ) -> Self {
130        let now = std::time::Instant::now();
131        Self {
132            remote_addr,
133            local_addr,
134            state: ConnectionState::Connected,
135            metadata: ConnectionMetadata {
136                subprotocol: None,
137                extensions: Vec::new(),
138                established_at: now,
139                last_activity_at: now,
140                messages_sent: 0,
141                messages_received: 0,
142                bytes_sent: 0,
143                bytes_received: 0,
144            },
145            stream: Some(stream),
146            idle_timeout,
147            last_activity: now,
148        }
149    }
150
151    /// Set the transport stream
152    pub fn set_stream(&mut self, stream: Box<dyn TransportStream>) {
153        self.stream = Some(stream);
154        self.state = ConnectionState::Connected;
155    }
156
157    /// Get the remote address
158    pub fn remote_addr(&self) -> SocketAddr {
159        self.remote_addr
160    }
161
162    /// Get the local address
163    pub fn local_addr(&self) -> SocketAddr {
164        self.local_addr
165    }
166
167    /// Get the connection state
168    pub fn state(&self) -> ConnectionState {
169        self.state
170    }
171
172    /// Get the connection metadata
173    pub fn metadata(&self) -> &ConnectionMetadata {
174        &self.metadata
175    }
176
177    /// Check if the connection has timed out
178    pub fn is_timed_out(&self) -> bool {
179        if let Some(timeout) = self.idle_timeout {
180            self.last_activity.elapsed() > timeout
181        } else {
182            false
183        }
184    }
185
186    /// Get the time until the connection times out
187    pub fn time_until_timeout(&self) -> Option<Duration> {
188        self.idle_timeout.map(|timeout| {
189            let elapsed = self.last_activity.elapsed();
190            if elapsed >= timeout {
191                Duration::ZERO
192            } else {
193                timeout - elapsed
194            }
195        })
196    }
197
198    /// Update the last activity timestamp
199    fn update_activity(&mut self) {
200        self.last_activity = std::time::Instant::now();
201        self.metadata.last_activity_at = self.last_activity;
202    }
203
204    /// Set the idle timeout
205    pub fn set_idle_timeout(&mut self, timeout: Option<Duration>) {
206        self.idle_timeout = timeout;
207    }
208
209    /// Send a message
210    pub async fn send(&mut self, message: Message) -> Result<()> {
211        // Update activity timestamp before borrowing stream
212        self.update_activity();
213        
214        if let Some(stream) = &mut self.stream {
215            // Convert message to WebSocket frame
216            let frame = match message {
217                Message::Text(text) => Frame::text(text.as_bytes().to_vec()),
218                Message::Binary(data) => Frame::binary(data.as_bytes().to_vec()),
219                Message::Ping(data) => Frame::ping(data.as_bytes().to_vec()),
220                Message::Pong(data) => Frame::pong(data.as_bytes().to_vec()),
221                Message::Close(code_and_reason) => {
222                    Frame::close(code_and_reason.code(), Some(code_and_reason.reason()))
223                }
224            };
225
226            // Serialize frame to bytes
227            let frame_bytes = frame.to_bytes();
228            
229            // Send frame
230            stream.write_all(&frame_bytes).await?;
231            stream.flush().await?;
232            
233            // Update metadata
234            self.metadata.messages_sent += 1;
235            self.metadata.bytes_sent += frame_bytes.len() as u64;
236            
237            Ok(())
238        } else {
239            Err(aerosocket_core::Error::Other("Connection not established".to_string()))
240        }
241    }
242
243    /// Send a text message
244    pub async fn send_text(&mut self, text: impl AsRef<str>) -> Result<()> {
245        self.send(Message::text(text.as_ref().to_string())).await
246    }
247
248    /// Send a binary message
249    pub async fn send_binary(&mut self, data: impl Into<Bytes>) -> Result<()> {
250        self.send(Message::binary(data)).await
251    }
252
253    /// Send a ping message
254    pub async fn ping(&mut self, data: Option<&[u8]>) -> Result<()> {
255        self.send(Message::ping(data.map(|d| d.to_vec()))).await
256    }
257
258    /// Send a pong message
259    pub async fn pong(&mut self, data: Option<&[u8]>) -> Result<()> {
260        self.send(Message::pong(data.map(|d| d.to_vec()))).await
261    }
262
263    /// Send a pong message (convenience method)
264    pub async fn send_pong(&mut self) -> Result<()> {
265        self.pong(None).await
266    }
267
268    /// Receive the next message
269    pub async fn next(&mut self) -> Result<Option<Message>> {
270        // Update activity timestamp before borrowing stream
271        self.update_activity();
272        
273        if let Some(stream) = &mut self.stream {
274            let mut message_buffer = Vec::new();
275            let mut final_frame = false;
276            let mut opcode = None;
277
278            // Keep reading frames until we get a complete message
279            while !final_frame {
280                // Read frame data
281                let mut frame_buffer = BytesMut::new();
282                
283                // Read at least the frame header (2 bytes)
284                loop {
285                    let mut temp_buf = [0u8; 2];
286                    let n = stream.read(&mut temp_buf).await?;
287                    if n == 0 {
288                        self.state = ConnectionState::Closed;
289                        return Ok(None);
290                    }
291                    frame_buffer.extend_from_slice(&temp_buf[..n]);
292                    
293                    if frame_buffer.len() >= 2 {
294                        break;
295                    }
296                }
297
298                // Parse the frame to determine how much more data we need
299                match Frame::parse(&mut frame_buffer) {
300                    Ok(frame) => {
301                        // Handle control frames immediately
302                        match frame.opcode {
303                            Opcode::Ping => {
304                                let ping_data = frame.payload.to_vec();
305                                // Send pong response
306                                stream.write_all(&Frame::pong(ping_data).to_bytes()).await?;
307                                stream.flush().await?;
308                                continue;
309                            }
310                            Opcode::Pong => {
311                                // Handle pong response (update activity)
312                                // Note: We can't call update_activity here due to borrowing,
313                                // but activity is already updated at the start of next()
314                                continue;
315                            }
316                            Opcode::Close => {
317                                // Parse close frame
318                                let close_code = if frame.payload.len() >= 2 {
319                                    let code_bytes = &frame.payload[..2];
320                                    u16::from_be_bytes([code_bytes[0], code_bytes[1]])
321                                } else {
322                                    1000 // Normal closure
323                                };
324                                
325                                let close_reason = if frame.payload.len() > 2 {
326                                    String::from_utf8_lossy(&frame.payload[2..]).to_string()
327                                } else {
328                                    String::new()
329                                };
330                                
331                                self.state = ConnectionState::Closing;
332                                return Ok(Some(Message::close(Some(close_code), Some(close_reason))));
333                            }
334                            Opcode::Continuation | Opcode::Text | Opcode::Binary => {
335                                // Handle data frames
336                                if opcode.is_none() {
337                                    opcode = Some(frame.opcode);
338                                }
339                                
340                                message_buffer.extend_from_slice(&frame.payload);
341                                final_frame = frame.fin;
342                                
343                                if !final_frame && frame.opcode != Opcode::Continuation {
344                                    return Err(aerosocket_core::Error::Other("Expected continuation frame".to_string()));
345                                }
346                            }
347                            _ => {
348                                return Err(aerosocket_core::Error::Other("Unsupported opcode".to_string()));
349                            }
350                        }
351                    }
352                    Err(_e) => {
353                        // Need more data - read from stream
354                        let mut temp_buf = [0u8; 1024];
355                        match stream.read(&mut temp_buf).await {
356                            Ok(0) => {
357                                self.state = ConnectionState::Closed;
358                                return Ok(None);
359                            }
360                            Ok(n) => {
361                                frame_buffer.extend_from_slice(&temp_buf[..n]);
362                            }
363                            Err(e) => return Err(e),
364                        }
365                        continue;
366                    }
367                }
368            }
369
370            // Convert the collected message based on opcode
371            let message = match opcode.unwrap_or(Opcode::Text) {
372                Opcode::Text => {
373                    let text = String::from_utf8_lossy(&message_buffer).to_string();
374                    Message::text(text)
375                }
376                Opcode::Binary => {
377                    let data = Bytes::from(message_buffer.clone());
378                    Message::binary(data)
379                }
380                _ => return Err(aerosocket_core::Error::Other("Invalid message opcode".to_string())),
381            };
382
383            // Update metadata
384            self.metadata.messages_received += 1;
385            self.metadata.bytes_received += message_buffer.len() as u64;
386
387            Ok(Some(message))
388        } else {
389            Err(aerosocket_core::Error::Other("Connection not established".to_string()))
390        }
391    }
392
393    /// Close the connection
394    pub async fn close(&mut self, code: Option<u16>, reason: Option<&str>) -> Result<()> {
395        self.state = ConnectionState::Closing;
396        self.send(Message::close(code, reason.map(|s| s.to_string()))).await
397    }
398
399    /// Check if the connection is established
400    pub fn is_connected(&self) -> bool {
401        self.state == ConnectionState::Connected
402    }
403
404    /// Check if the connection is closed
405    pub fn is_closed(&self) -> bool {
406        self.state == ConnectionState::Closed
407    }
408
409    /// Get the connection age
410    pub fn age(&self) -> std::time::Duration {
411        self.metadata.established_at.elapsed()
412    }
413
414    /// Get the time since last activity
415    pub fn idle_time(&self) -> std::time::Duration {
416        self.metadata.last_activity_at.elapsed()
417    }
418}
419
420/// Connection handle for managing connections
421#[derive(Debug, Clone)]
422pub struct ConnectionHandle {
423    /// Connection ID
424    id: u64,
425    /// Connection reference
426    connection: std::sync::Arc<tokio::sync::Mutex<Connection>>,
427}
428
429impl ConnectionHandle {
430    /// Create a new connection handle
431    pub fn new(id: u64, connection: Connection) -> Self {
432        Self {
433            id,
434            connection: std::sync::Arc::new(tokio::sync::Mutex::new(connection)),
435        }
436    }
437
438    /// Get the connection ID
439    pub fn id(&self) -> u64 {
440        self.id
441    }
442
443    /// Try to lock the connection
444    pub async fn try_lock(&self) -> Result<tokio::sync::MutexGuard<'_, Connection>> {
445        self.connection.try_lock().map_err(|_| aerosocket_core::Error::Other("Failed to lock connection".to_string()))
446    }
447}
448
#[cfg(test)]
mod tests {
    use super::*;

    /// Loopback (peer, local) address pair shared by the tests below.
    fn loopback_pair() -> (SocketAddr, SocketAddr) {
        let peer: SocketAddr = "127.0.0.1:12345".parse().unwrap();
        let here: SocketAddr = "127.0.0.1:8080".parse().unwrap();
        (peer, here)
    }

    #[test]
    fn test_connection_creation() {
        let (peer, here) = loopback_pair();
        let connection = Connection::new(peer, here);

        assert_eq!(connection.remote_addr(), peer);
        assert_eq!(connection.local_addr(), here);
        assert_eq!(connection.state(), ConnectionState::Connecting);
        assert!(!connection.is_connected());
        assert!(!connection.is_closed());
    }

    #[tokio::test]
    async fn test_connection_handle() {
        let (peer, here) = loopback_pair();
        let handle = ConnectionHandle::new(1, Connection::new(peer, here));

        assert_eq!(handle.id(), 1);
        assert!(handle.try_lock().await.is_ok());
    }
}