1use super::types::{
4 ConnectionId, ConnectionState, WebSocketMessage, WebSocketError, WebSocketResult, WebSocketConfig,
5};
6use futures_util::{SinkExt, StreamExt};
7use std::collections::HashMap;
8use std::sync::Arc;
9use std::time::{Duration, Instant};
10use tokio::sync::{mpsc, RwLock};
11use tokio::time;
12use tokio_tungstenite::{accept_async, tungstenite, WebSocketStream};
13use tracing::{debug, error, info};
14
15#[derive(Clone)]
17pub struct WebSocketConnection {
18 pub id: ConnectionId,
20 state: Arc<RwLock<ConnectionState>>,
22 metadata: Arc<RwLock<ConnectionMetadata>>,
24 sender: mpsc::UnboundedSender<WebSocketMessage>,
26 _config: WebSocketConfig,
28}
29
30#[derive(Debug, Clone)]
32pub struct ConnectionMetadata {
33 pub connected_at: Instant,
35 pub remote_addr: Option<String>,
37 pub user_agent: Option<String>,
39 pub custom: HashMap<String, String>,
41 pub stats: ConnectionStats,
43}
44
/// Traffic counters for one connection.
///
/// `Default` yields all-zero counters with no recorded activity.
#[derive(Clone, Debug, Default)]
pub struct ConnectionStats {
    /// Number of outgoing messages counted by the connection handler.
    pub messages_sent: u64,
    /// Number of incoming messages counted by the connection handler.
    pub messages_received: u64,
    /// Payload bytes of outgoing text and binary messages; other frame types
    /// contribute zero.
    pub bytes_sent: u64,
    /// Payload bytes of incoming text and binary messages; other frame types
    /// contribute zero.
    pub bytes_received: u64,
    /// Time of the most recent counted message in either direction, if any.
    pub last_activity: Option<Instant>,
}
59
60impl WebSocketConnection {
61 pub async fn from_stream<S>(
63 stream: S,
64 config: WebSocketConfig,
65 ) -> WebSocketResult<Self>
66 where
67 S: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin + Send + 'static,
68 {
69 let id = ConnectionId::new();
70 let ws_stream = accept_async(stream).await?;
71
72 let (sender, receiver) = mpsc::unbounded_channel();
73 let state = Arc::new(RwLock::new(ConnectionState::Connected));
74 let metadata = Arc::new(RwLock::new(ConnectionMetadata {
75 connected_at: Instant::now(),
76 remote_addr: None,
77 user_agent: None,
78 custom: HashMap::new(),
79 stats: ConnectionStats::default(),
80 }));
81
82 let connection = Self {
84 id,
85 state: state.clone(),
86 metadata: metadata.clone(),
87 sender,
88 _config: config.clone(),
89 };
90
91 tokio::spawn(Self::handle_connection(
93 id,
94 ws_stream,
95 receiver,
96 state,
97 metadata,
98 config,
99 ));
100
101 info!("WebSocket connection established: {}", id);
102 Ok(connection)
103 }
104
105 pub async fn send(&self, message: WebSocketMessage) -> WebSocketResult<()> {
107 if !self.is_active().await {
108 return Err(WebSocketError::ConnectionClosed);
109 }
110
111 self.sender
112 .send(message)
113 .map_err(|_| WebSocketError::SendQueueFull)?;
114
115 Ok(())
116 }
117
118 pub async fn send_text<T: Into<String>>(&self, text: T) -> WebSocketResult<()> {
120 self.send(WebSocketMessage::text(text)).await
121 }
122
123 pub async fn send_binary<T: Into<Vec<u8>>>(&self, data: T) -> WebSocketResult<()> {
125 self.send(WebSocketMessage::binary(data)).await
126 }
127
128 pub async fn ping<T: Into<Vec<u8>>>(&self, data: T) -> WebSocketResult<()> {
130 self.send(WebSocketMessage::ping(data)).await
131 }
132
133 pub async fn close(&self) -> WebSocketResult<()> {
135 self.send(WebSocketMessage::close()).await?;
136
137 let mut state = self.state.write().await;
138 *state = ConnectionState::Closing;
139
140 Ok(())
141 }
142
143 pub async fn close_with_reason(&self, code: u16, reason: String) -> WebSocketResult<()> {
145 self.send(WebSocketMessage::close_with_reason(code, reason)).await?;
146
147 let mut state = self.state.write().await;
148 *state = ConnectionState::Closing;
149
150 Ok(())
151 }
152
153 pub async fn state(&self) -> ConnectionState {
155 self.state.read().await.clone()
156 }
157
158 pub async fn is_active(&self) -> bool {
160 self.state().await.is_active()
161 }
162
163 pub async fn is_closed(&self) -> bool {
165 self.state().await.is_closed()
166 }
167
168 pub async fn metadata(&self) -> ConnectionMetadata {
170 self.metadata.read().await.clone()
171 }
172
173 pub async fn set_metadata(&self, key: String, value: String) {
175 let mut metadata = self.metadata.write().await;
176 metadata.custom.insert(key, value);
177 }
178
179 pub async fn stats(&self) -> ConnectionStats {
181 self.metadata.read().await.stats.clone()
182 }
183
184 async fn handle_connection<S>(
186 id: ConnectionId,
187 mut ws_stream: WebSocketStream<S>,
188 mut receiver: mpsc::UnboundedReceiver<WebSocketMessage>,
189 state: Arc<RwLock<ConnectionState>>,
190 metadata: Arc<RwLock<ConnectionMetadata>>,
191 config: WebSocketConfig,
192 ) where
193 S: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin,
194 {
195 debug!("Starting WebSocket handler for connection: {}", id);
196
197 let mut ping_interval = if let Some(interval) = config.ping_interval {
199 Some(time::interval(Duration::from_secs(interval)))
200 } else {
201 None
202 };
203
204 loop {
205 tokio::select! {
206 ws_msg = ws_stream.next() => {
208 match ws_msg {
209 Some(Ok(msg)) => {
210 let elif_msg = WebSocketMessage::from(msg);
211
212 {
214 let mut meta = metadata.write().await;
215 meta.stats.messages_received += 1;
216 meta.stats.last_activity = Some(Instant::now());
217
218 let bytes = match &elif_msg {
220 WebSocketMessage::Text(s) => s.len() as u64,
221 WebSocketMessage::Binary(b) => b.len() as u64,
222 _ => 0,
223 };
224 meta.stats.bytes_received += bytes;
225 }
226
227 match &elif_msg {
229 WebSocketMessage::Ping(data) => {
230 if config.auto_pong {
231 let pong_msg = tungstenite::Message::Pong(data.clone());
232 if let Err(e) = ws_stream.send(pong_msg).await {
233 error!("Failed to send pong for {}: {}", id, e);
234 break;
235 }
236 }
237 }
238 WebSocketMessage::Close(_) => {
239 info!("Received close frame for connection: {}", id);
240 break;
241 }
242 _ => {
243 debug!("Received message on {}: {:?}", id, elif_msg.message_type());
246 }
247 }
248 }
249 Some(Err(e)) => {
250 error!("WebSocket error for {}: {}", id, e);
251 let mut state_lock = state.write().await;
252 *state_lock = ConnectionState::Failed(e.to_string());
253 break;
254 }
255 None => {
256 info!("WebSocket stream ended for connection: {}", id);
257 break;
258 }
259 }
260 }
261
262 app_msg = receiver.recv() => {
264 match app_msg {
265 Some(msg) => {
266 {
268 let mut meta = metadata.write().await;
269 meta.stats.messages_sent += 1;
270 meta.stats.last_activity = Some(Instant::now());
271
272 let bytes = match &msg {
274 WebSocketMessage::Text(s) => s.len() as u64,
275 WebSocketMessage::Binary(b) => b.len() as u64,
276 _ => 0,
277 };
278 meta.stats.bytes_sent += bytes;
279 }
280
281 let tungstenite_msg = tungstenite::Message::from(msg);
282 if let Err(e) = ws_stream.send(tungstenite_msg).await {
283 error!("Failed to send message for {}: {}", id, e);
284 let mut state_lock = state.write().await;
285 *state_lock = ConnectionState::Failed(e.to_string());
286 break;
287 }
288 }
289 None => {
290 debug!("Application message channel closed for: {}", id);
291 break;
292 }
293 }
294 }
295
296 _ = async {
298 if let Some(ref mut interval) = ping_interval {
299 interval.tick().await;
300 } else {
301 std::future::pending::<()>().await;
303 }
304 } => {
305 let ping_msg = tungstenite::Message::Ping(vec![]);
307 if let Err(e) = ws_stream.send(ping_msg).await {
308 error!("Failed to send ping for {}: {}", id, e);
309 break;
310 }
311 debug!("Sent ping to connection: {}", id);
312 }
313 }
314 }
315
316 let mut state_lock = state.write().await;
318 if !matches!(*state_lock, ConnectionState::Failed(_)) {
319 *state_lock = ConnectionState::Closed;
320 }
321
322 info!("WebSocket connection handler finished: {}", id);
323 }
324}
325
326impl Drop for WebSocketConnection {
327 fn drop(&mut self) {
328 debug!("Dropping WebSocket connection: {}", self.id);
329 }
330}