elif_http/websocket/
connection.rs

1//! WebSocket connection management - high-performance wrapper around tokio-tungstenite
2
3use super::types::{
4    ConnectionId, ConnectionState, WebSocketConfig, WebSocketError, WebSocketMessage,
5    WebSocketResult,
6};
7use futures_util::{SinkExt, StreamExt};
8use std::collections::HashMap;
9use std::sync::Arc;
10use std::time::{Duration, Instant};
11use tokio::sync::{mpsc, RwLock};
12use tokio::time;
13use tokio_tungstenite::{accept_async, tungstenite, WebSocketStream};
14use tracing::{debug, error, info};
15
16/// WebSocket connection wrapper - clean API over tokio-tungstenite
17#[derive(Clone)]
18pub struct WebSocketConnection {
19    /// Unique connection identifier
20    pub id: ConnectionId,
21    /// Connection state
22    state: Arc<RwLock<ConnectionState>>,
23    /// Connection metadata
24    metadata: Arc<RwLock<ConnectionMetadata>>,
25    /// Message sender channel
26    sender: mpsc::UnboundedSender<WebSocketMessage>,
27    /// Configuration
28    _config: WebSocketConfig,
29}
30
31/// Connection metadata for tracking and debugging
32#[derive(Debug, Clone)]
33pub struct ConnectionMetadata {
34    /// When the connection was established
35    pub connected_at: Instant,
36    /// Remote address if available
37    pub remote_addr: Option<String>,
38    /// User agent if available
39    pub user_agent: Option<String>,
40    /// Custom metadata
41    pub custom: HashMap<String, String>,
42    /// Message statistics
43    pub stats: ConnectionStats,
44}
45
/// Traffic counters for a single connection; `Default` starts every counter at zero.
#[derive(Default, Clone, Debug)]
pub struct ConnectionStats {
    /// Number of messages written to the peer.
    pub messages_sent: u64,
    /// Number of messages read from the peer.
    pub messages_received: u64,
    /// Payload bytes of text/binary messages written (control frames count as 0).
    pub bytes_sent: u64,
    /// Payload bytes of text/binary messages read (control frames count as 0).
    pub bytes_received: u64,
    /// Time of the most recent send or receive, if any has happened yet.
    pub last_activity: Option<Instant>,
}
60
61impl WebSocketConnection {
62    /// Create a new WebSocket connection from a TCP stream
63    pub async fn from_stream<S>(stream: S, config: WebSocketConfig) -> WebSocketResult<Self>
64    where
65        S: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin + Send + 'static,
66    {
67        let id = ConnectionId::new();
68        let ws_stream = accept_async(stream).await?;
69
70        let (sender, receiver) = mpsc::unbounded_channel();
71        let state = Arc::new(RwLock::new(ConnectionState::Connected));
72        let metadata = Arc::new(RwLock::new(ConnectionMetadata {
73            connected_at: Instant::now(),
74            remote_addr: None,
75            user_agent: None,
76            custom: HashMap::new(),
77            stats: ConnectionStats::default(),
78        }));
79
80        // Start the connection handler task
81        let connection = Self {
82            id,
83            state: state.clone(),
84            metadata: metadata.clone(),
85            sender,
86            _config: config.clone(),
87        };
88
89        // Spawn the connection handler
90        tokio::spawn(Self::handle_connection(
91            id, ws_stream, receiver, state, metadata, config,
92        ));
93
94        info!("WebSocket connection established: {}", id);
95        Ok(connection)
96    }
97
98    /// Send a message to the WebSocket
99    pub async fn send(&self, message: WebSocketMessage) -> WebSocketResult<()> {
100        if !self.is_active().await {
101            return Err(WebSocketError::ConnectionClosed);
102        }
103
104        self.sender
105            .send(message)
106            .map_err(|_| WebSocketError::SendQueueFull)?;
107
108        Ok(())
109    }
110
111    /// Send a text message
112    pub async fn send_text<T: Into<String>>(&self, text: T) -> WebSocketResult<()> {
113        self.send(WebSocketMessage::text(text)).await
114    }
115
116    /// Send a binary message
117    pub async fn send_binary<T: Into<Vec<u8>>>(&self, data: T) -> WebSocketResult<()> {
118        self.send(WebSocketMessage::binary(data)).await
119    }
120
121    /// Send a ping
122    pub async fn ping<T: Into<Vec<u8>>>(&self, data: T) -> WebSocketResult<()> {
123        self.send(WebSocketMessage::ping(data)).await
124    }
125
126    /// Close the connection
127    pub async fn close(&self) -> WebSocketResult<()> {
128        self.send(WebSocketMessage::close()).await?;
129
130        let mut state = self.state.write().await;
131        *state = ConnectionState::Closing;
132
133        Ok(())
134    }
135
136    /// Close the connection with a reason
137    pub async fn close_with_reason(&self, code: u16, reason: String) -> WebSocketResult<()> {
138        self.send(WebSocketMessage::close_with_reason(code, reason))
139            .await?;
140
141        let mut state = self.state.write().await;
142        *state = ConnectionState::Closing;
143
144        Ok(())
145    }
146
147    /// Get the current connection state
148    pub async fn state(&self) -> ConnectionState {
149        self.state.read().await.clone()
150    }
151
152    /// Check if the connection is active
153    pub async fn is_active(&self) -> bool {
154        self.state().await.is_active()
155    }
156
157    /// Check if the connection is closed
158    pub async fn is_closed(&self) -> bool {
159        self.state().await.is_closed()
160    }
161
162    /// Get connection metadata
163    pub async fn metadata(&self) -> ConnectionMetadata {
164        self.metadata.read().await.clone()
165    }
166
167    /// Update connection metadata
168    pub async fn set_metadata(&self, key: String, value: String) {
169        let mut metadata = self.metadata.write().await;
170        metadata.custom.insert(key, value);
171    }
172
173    /// Get connection statistics
174    pub async fn stats(&self) -> ConnectionStats {
175        self.metadata.read().await.stats.clone()
176    }
177
178    /// Connection handler - runs the actual WebSocket loop
179    async fn handle_connection<S>(
180        id: ConnectionId,
181        mut ws_stream: WebSocketStream<S>,
182        mut receiver: mpsc::UnboundedReceiver<WebSocketMessage>,
183        state: Arc<RwLock<ConnectionState>>,
184        metadata: Arc<RwLock<ConnectionMetadata>>,
185        config: WebSocketConfig,
186    ) where
187        S: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin,
188    {
189        debug!("Starting WebSocket handler for connection: {}", id);
190
191        // Set up ping interval if configured
192        let mut ping_interval = config
193            .ping_interval
194            .map(|interval| time::interval(Duration::from_secs(interval)));
195
196        loop {
197            tokio::select! {
198                // Handle incoming messages from WebSocket
199                ws_msg = ws_stream.next() => {
200                    match ws_msg {
201                        Some(Ok(msg)) => {
202                            let elif_msg = WebSocketMessage::from(msg);
203
204                            // Update stats
205                            {
206                                let mut meta = metadata.write().await;
207                                meta.stats.messages_received += 1;
208                                meta.stats.last_activity = Some(Instant::now());
209
210                                // Estimate bytes received
211                                let bytes = match &elif_msg {
212                                    WebSocketMessage::Text(s) => s.len() as u64,
213                                    WebSocketMessage::Binary(b) => b.len() as u64,
214                                    _ => 0,
215                                };
216                                meta.stats.bytes_received += bytes;
217                            }
218
219                            // Handle control frames automatically
220                            match &elif_msg {
221                                WebSocketMessage::Ping(data) => {
222                                    if config.auto_pong {
223                                        let pong_msg = tungstenite::Message::Pong(data.clone());
224                                        if let Err(e) = ws_stream.send(pong_msg).await {
225                                            error!("Failed to send pong for {}: {}", id, e);
226                                            break;
227                                        }
228                                    }
229                                }
230                                WebSocketMessage::Close(_) => {
231                                    info!("Received close frame for connection: {}", id);
232                                    break;
233                                }
234                                _ => {
235                                    // For now, we just log other messages
236                                    // In a full implementation, we'd route these to handlers
237                                    debug!("Received message on {}: {:?}", id, elif_msg.message_type());
238                                }
239                            }
240                        }
241                        Some(Err(e)) => {
242                            error!("WebSocket error for {}: {}", id, e);
243                            let mut state_lock = state.write().await;
244                            *state_lock = ConnectionState::Failed(e.to_string());
245                            break;
246                        }
247                        None => {
248                            info!("WebSocket stream ended for connection: {}", id);
249                            break;
250                        }
251                    }
252                }
253
254                // Handle outgoing messages from application
255                app_msg = receiver.recv() => {
256                    match app_msg {
257                        Some(msg) => {
258                            // Update stats
259                            {
260                                let mut meta = metadata.write().await;
261                                meta.stats.messages_sent += 1;
262                                meta.stats.last_activity = Some(Instant::now());
263
264                                // Estimate bytes sent
265                                let bytes = match &msg {
266                                    WebSocketMessage::Text(s) => s.len() as u64,
267                                    WebSocketMessage::Binary(b) => b.len() as u64,
268                                    _ => 0,
269                                };
270                                meta.stats.bytes_sent += bytes;
271                            }
272
273                            let tungstenite_msg = tungstenite::Message::from(msg);
274                            if let Err(e) = ws_stream.send(tungstenite_msg).await {
275                                error!("Failed to send message for {}: {}", id, e);
276                                let mut state_lock = state.write().await;
277                                *state_lock = ConnectionState::Failed(e.to_string());
278                                break;
279                            }
280                        }
281                        None => {
282                            debug!("Application message channel closed for: {}", id);
283                            break;
284                        }
285                    }
286                }
287
288                // Handle ping interval
289                _ = async {
290                    if let Some(ref mut interval) = ping_interval {
291                        interval.tick().await;
292                    } else {
293                        // If no ping interval, wait indefinitely
294                        std::future::pending::<()>().await;
295                    }
296                } => {
297                    // Send ping
298                    let ping_msg = tungstenite::Message::Ping(vec![]);
299                    if let Err(e) = ws_stream.send(ping_msg).await {
300                        error!("Failed to send ping for {}: {}", id, e);
301                        break;
302                    }
303                    debug!("Sent ping to connection: {}", id);
304                }
305            }
306        }
307
308        // Connection cleanup
309        let mut state_lock = state.write().await;
310        if !matches!(*state_lock, ConnectionState::Failed(_)) {
311            *state_lock = ConnectionState::Closed;
312        }
313
314        info!("WebSocket connection handler finished: {}", id);
315    }
316}
317
318impl Drop for WebSocketConnection {
319    fn drop(&mut self) {
320        debug!("Dropping WebSocket connection: {}", self.id);
321    }
322}