elif_http/websocket/
connection.rs

1//! WebSocket connection management - high-performance wrapper around tokio-tungstenite
2
3use super::types::{
4    ConnectionId, ConnectionState, WebSocketMessage, WebSocketError, WebSocketResult, WebSocketConfig,
5};
6use futures_util::{SinkExt, StreamExt};
7use std::collections::HashMap;
8use std::sync::Arc;
9use std::time::{Duration, Instant};
10use tokio::sync::{mpsc, RwLock};
11use tokio::time;
12use tokio_tungstenite::{accept_async, tungstenite, WebSocketStream};
13use tracing::{debug, error, info};
14
15/// WebSocket connection wrapper - clean API over tokio-tungstenite
16#[derive(Clone)]
17pub struct WebSocketConnection {
18    /// Unique connection identifier
19    pub id: ConnectionId,
20    /// Connection state
21    state: Arc<RwLock<ConnectionState>>,
22    /// Connection metadata
23    metadata: Arc<RwLock<ConnectionMetadata>>,
24    /// Message sender channel
25    sender: mpsc::UnboundedSender<WebSocketMessage>,
26    /// Configuration
27    _config: WebSocketConfig,
28}
29
30/// Connection metadata for tracking and debugging
31#[derive(Debug, Clone)]
32pub struct ConnectionMetadata {
33    /// When the connection was established
34    pub connected_at: Instant,
35    /// Remote address if available
36    pub remote_addr: Option<String>,
37    /// User agent if available
38    pub user_agent: Option<String>,
39    /// Custom metadata
40    pub custom: HashMap<String, String>,
41    /// Message statistics
42    pub stats: ConnectionStats,
43}
44
/// Traffic counters for a single connection; the `Default` value is all zeros / no activity.
#[derive(Debug, Clone, Default)]
pub struct ConnectionStats {
    /// Number of messages written to the peer.
    pub messages_sent: u64,
    /// Number of messages read from the peer.
    pub messages_received: u64,
    /// Approximate payload bytes written (text and binary payloads only).
    pub bytes_sent: u64,
    /// Approximate payload bytes read (text and binary payloads only).
    pub bytes_received: u64,
    /// Time of the most recent send or receive, if any traffic has happened yet.
    pub last_activity: Option<Instant>,
}
59
60impl WebSocketConnection {
61    /// Create a new WebSocket connection from a TCP stream
62    pub async fn from_stream<S>(
63        stream: S,
64        config: WebSocketConfig,
65    ) -> WebSocketResult<Self>
66    where
67        S: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin + Send + 'static,
68    {
69        let id = ConnectionId::new();
70        let ws_stream = accept_async(stream).await?;
71        
72        let (sender, receiver) = mpsc::unbounded_channel();
73        let state = Arc::new(RwLock::new(ConnectionState::Connected));
74        let metadata = Arc::new(RwLock::new(ConnectionMetadata {
75            connected_at: Instant::now(),
76            remote_addr: None,
77            user_agent: None,
78            custom: HashMap::new(),
79            stats: ConnectionStats::default(),
80        }));
81
82        // Start the connection handler task
83        let connection = Self {
84            id,
85            state: state.clone(),
86            metadata: metadata.clone(),
87            sender,
88            _config: config.clone(),
89        };
90
91        // Spawn the connection handler
92        tokio::spawn(Self::handle_connection(
93            id,
94            ws_stream,
95            receiver,
96            state,
97            metadata,
98            config,
99        ));
100
101        info!("WebSocket connection established: {}", id);
102        Ok(connection)
103    }
104
105    /// Send a message to the WebSocket
106    pub async fn send(&self, message: WebSocketMessage) -> WebSocketResult<()> {
107        if !self.is_active().await {
108            return Err(WebSocketError::ConnectionClosed);
109        }
110
111        self.sender
112            .send(message)
113            .map_err(|_| WebSocketError::SendQueueFull)?;
114        
115        Ok(())
116    }
117
118    /// Send a text message
119    pub async fn send_text<T: Into<String>>(&self, text: T) -> WebSocketResult<()> {
120        self.send(WebSocketMessage::text(text)).await
121    }
122
123    /// Send a binary message
124    pub async fn send_binary<T: Into<Vec<u8>>>(&self, data: T) -> WebSocketResult<()> {
125        self.send(WebSocketMessage::binary(data)).await
126    }
127
128    /// Send a ping
129    pub async fn ping<T: Into<Vec<u8>>>(&self, data: T) -> WebSocketResult<()> {
130        self.send(WebSocketMessage::ping(data)).await
131    }
132
133    /// Close the connection
134    pub async fn close(&self) -> WebSocketResult<()> {
135        self.send(WebSocketMessage::close()).await?;
136        
137        let mut state = self.state.write().await;
138        *state = ConnectionState::Closing;
139        
140        Ok(())
141    }
142
143    /// Close the connection with a reason
144    pub async fn close_with_reason(&self, code: u16, reason: String) -> WebSocketResult<()> {
145        self.send(WebSocketMessage::close_with_reason(code, reason)).await?;
146        
147        let mut state = self.state.write().await;
148        *state = ConnectionState::Closing;
149        
150        Ok(())
151    }
152
153    /// Get the current connection state
154    pub async fn state(&self) -> ConnectionState {
155        self.state.read().await.clone()
156    }
157
158    /// Check if the connection is active
159    pub async fn is_active(&self) -> bool {
160        self.state().await.is_active()
161    }
162
163    /// Check if the connection is closed
164    pub async fn is_closed(&self) -> bool {
165        self.state().await.is_closed()
166    }
167
168    /// Get connection metadata
169    pub async fn metadata(&self) -> ConnectionMetadata {
170        self.metadata.read().await.clone()
171    }
172
173    /// Update connection metadata
174    pub async fn set_metadata(&self, key: String, value: String) {
175        let mut metadata = self.metadata.write().await;
176        metadata.custom.insert(key, value);
177    }
178
179    /// Get connection statistics
180    pub async fn stats(&self) -> ConnectionStats {
181        self.metadata.read().await.stats.clone()
182    }
183
184    /// Connection handler - runs the actual WebSocket loop
185    async fn handle_connection<S>(
186        id: ConnectionId,
187        mut ws_stream: WebSocketStream<S>,
188        mut receiver: mpsc::UnboundedReceiver<WebSocketMessage>,
189        state: Arc<RwLock<ConnectionState>>,
190        metadata: Arc<RwLock<ConnectionMetadata>>,
191        config: WebSocketConfig,
192    ) where
193        S: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin,
194    {
195        debug!("Starting WebSocket handler for connection: {}", id);
196
197        // Set up ping interval if configured
198        let mut ping_interval = if let Some(interval) = config.ping_interval {
199            Some(time::interval(Duration::from_secs(interval)))
200        } else {
201            None
202        };
203
204        loop {
205            tokio::select! {
206                // Handle incoming messages from WebSocket
207                ws_msg = ws_stream.next() => {
208                    match ws_msg {
209                        Some(Ok(msg)) => {
210                            let elif_msg = WebSocketMessage::from(msg);
211                            
212                            // Update stats
213                            {
214                                let mut meta = metadata.write().await;
215                                meta.stats.messages_received += 1;
216                                meta.stats.last_activity = Some(Instant::now());
217                                
218                                // Estimate bytes received
219                                let bytes = match &elif_msg {
220                                    WebSocketMessage::Text(s) => s.len() as u64,
221                                    WebSocketMessage::Binary(b) => b.len() as u64,
222                                    _ => 0,
223                                };
224                                meta.stats.bytes_received += bytes;
225                            }
226
227                            // Handle control frames automatically
228                            match &elif_msg {
229                                WebSocketMessage::Ping(data) => {
230                                    if config.auto_pong {
231                                        let pong_msg = tungstenite::Message::Pong(data.clone());
232                                        if let Err(e) = ws_stream.send(pong_msg).await {
233                                            error!("Failed to send pong for {}: {}", id, e);
234                                            break;
235                                        }
236                                    }
237                                }
238                                WebSocketMessage::Close(_) => {
239                                    info!("Received close frame for connection: {}", id);
240                                    break;
241                                }
242                                _ => {
243                                    // For now, we just log other messages
244                                    // In a full implementation, we'd route these to handlers
245                                    debug!("Received message on {}: {:?}", id, elif_msg.message_type());
246                                }
247                            }
248                        }
249                        Some(Err(e)) => {
250                            error!("WebSocket error for {}: {}", id, e);
251                            let mut state_lock = state.write().await;
252                            *state_lock = ConnectionState::Failed(e.to_string());
253                            break;
254                        }
255                        None => {
256                            info!("WebSocket stream ended for connection: {}", id);
257                            break;
258                        }
259                    }
260                }
261
262                // Handle outgoing messages from application
263                app_msg = receiver.recv() => {
264                    match app_msg {
265                        Some(msg) => {
266                            // Update stats
267                            {
268                                let mut meta = metadata.write().await;
269                                meta.stats.messages_sent += 1;
270                                meta.stats.last_activity = Some(Instant::now());
271                                
272                                // Estimate bytes sent
273                                let bytes = match &msg {
274                                    WebSocketMessage::Text(s) => s.len() as u64,
275                                    WebSocketMessage::Binary(b) => b.len() as u64,
276                                    _ => 0,
277                                };
278                                meta.stats.bytes_sent += bytes;
279                            }
280
281                            let tungstenite_msg = tungstenite::Message::from(msg);
282                            if let Err(e) = ws_stream.send(tungstenite_msg).await {
283                                error!("Failed to send message for {}: {}", id, e);
284                                let mut state_lock = state.write().await;
285                                *state_lock = ConnectionState::Failed(e.to_string());
286                                break;
287                            }
288                        }
289                        None => {
290                            debug!("Application message channel closed for: {}", id);
291                            break;
292                        }
293                    }
294                }
295
296                // Handle ping interval
297                _ = async {
298                    if let Some(ref mut interval) = ping_interval {
299                        interval.tick().await;
300                    } else {
301                        // If no ping interval, wait indefinitely
302                        std::future::pending::<()>().await;
303                    }
304                } => {
305                    // Send ping
306                    let ping_msg = tungstenite::Message::Ping(vec![]);
307                    if let Err(e) = ws_stream.send(ping_msg).await {
308                        error!("Failed to send ping for {}: {}", id, e);
309                        break;
310                    }
311                    debug!("Sent ping to connection: {}", id);
312                }
313            }
314        }
315
316        // Connection cleanup
317        let mut state_lock = state.write().await;
318        if !matches!(*state_lock, ConnectionState::Failed(_)) {
319            *state_lock = ConnectionState::Closed;
320        }
321        
322        info!("WebSocket connection handler finished: {}", id);
323    }
324}
325
326impl Drop for WebSocketConnection {
327    fn drop(&mut self) {
328        debug!("Dropping WebSocket connection: {}", self.id);
329    }
330}