1use super::types::{
4 ConnectionId, ConnectionState, WebSocketConfig, WebSocketError, WebSocketMessage,
5 WebSocketResult,
6};
7use futures_util::{SinkExt, StreamExt};
8use std::collections::HashMap;
9use std::sync::Arc;
10use std::time::{Duration, Instant};
11use tokio::sync::{mpsc, RwLock};
12use tokio::time;
13use tokio_tungstenite::{accept_async, tungstenite, WebSocketStream};
14use tracing::{debug, error, info};
15
16#[derive(Clone)]
18pub struct WebSocketConnection {
19 pub id: ConnectionId,
21 state: Arc<RwLock<ConnectionState>>,
23 metadata: Arc<RwLock<ConnectionMetadata>>,
25 sender: mpsc::UnboundedSender<WebSocketMessage>,
27 _config: WebSocketConfig,
29}
30
31#[derive(Debug, Clone)]
33pub struct ConnectionMetadata {
34 pub connected_at: Instant,
36 pub remote_addr: Option<String>,
38 pub user_agent: Option<String>,
40 pub custom: HashMap<String, String>,
42 pub stats: ConnectionStats,
44}
45
/// Traffic counters for a connection; every counter starts at zero.
#[derive(Debug, Clone, Default)]
pub struct ConnectionStats {
    /// Number of outbound messages.
    pub messages_sent: u64,
    /// Number of inbound messages.
    pub messages_received: u64,
    /// Payload bytes of outbound messages.
    pub bytes_sent: u64,
    /// Payload bytes of inbound messages.
    pub bytes_received: u64,
    /// Time of the most recent recorded activity, or `None` if there has been none.
    pub last_activity: Option<Instant>,
}
60
61impl WebSocketConnection {
62 pub async fn from_stream<S>(stream: S, config: WebSocketConfig) -> WebSocketResult<Self>
64 where
65 S: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin + Send + 'static,
66 {
67 let id = ConnectionId::new();
68 let ws_stream = accept_async(stream).await?;
69
70 let (sender, receiver) = mpsc::unbounded_channel();
71 let state = Arc::new(RwLock::new(ConnectionState::Connected));
72 let metadata = Arc::new(RwLock::new(ConnectionMetadata {
73 connected_at: Instant::now(),
74 remote_addr: None,
75 user_agent: None,
76 custom: HashMap::new(),
77 stats: ConnectionStats::default(),
78 }));
79
80 let connection = Self {
82 id,
83 state: state.clone(),
84 metadata: metadata.clone(),
85 sender,
86 _config: config.clone(),
87 };
88
89 tokio::spawn(Self::handle_connection(
91 id, ws_stream, receiver, state, metadata, config,
92 ));
93
94 info!("WebSocket connection established: {}", id);
95 Ok(connection)
96 }
97
98 pub async fn send(&self, message: WebSocketMessage) -> WebSocketResult<()> {
100 if !self.is_active().await {
101 return Err(WebSocketError::ConnectionClosed);
102 }
103
104 self.sender
105 .send(message)
106 .map_err(|_| WebSocketError::SendQueueFull)?;
107
108 Ok(())
109 }
110
111 pub async fn send_text<T: Into<String>>(&self, text: T) -> WebSocketResult<()> {
113 self.send(WebSocketMessage::text(text)).await
114 }
115
116 pub async fn send_binary<T: Into<Vec<u8>>>(&self, data: T) -> WebSocketResult<()> {
118 self.send(WebSocketMessage::binary(data)).await
119 }
120
121 pub async fn ping<T: Into<Vec<u8>>>(&self, data: T) -> WebSocketResult<()> {
123 self.send(WebSocketMessage::ping(data)).await
124 }
125
126 pub async fn close(&self) -> WebSocketResult<()> {
128 self.send(WebSocketMessage::close()).await?;
129
130 let mut state = self.state.write().await;
131 *state = ConnectionState::Closing;
132
133 Ok(())
134 }
135
136 pub async fn close_with_reason(&self, code: u16, reason: String) -> WebSocketResult<()> {
138 self.send(WebSocketMessage::close_with_reason(code, reason))
139 .await?;
140
141 let mut state = self.state.write().await;
142 *state = ConnectionState::Closing;
143
144 Ok(())
145 }
146
147 pub async fn state(&self) -> ConnectionState {
149 self.state.read().await.clone()
150 }
151
152 pub async fn is_active(&self) -> bool {
154 self.state().await.is_active()
155 }
156
157 pub async fn is_closed(&self) -> bool {
159 self.state().await.is_closed()
160 }
161
162 pub async fn metadata(&self) -> ConnectionMetadata {
164 self.metadata.read().await.clone()
165 }
166
167 pub async fn set_metadata(&self, key: String, value: String) {
169 let mut metadata = self.metadata.write().await;
170 metadata.custom.insert(key, value);
171 }
172
173 pub async fn stats(&self) -> ConnectionStats {
175 self.metadata.read().await.stats.clone()
176 }
177
178 async fn handle_connection<S>(
180 id: ConnectionId,
181 mut ws_stream: WebSocketStream<S>,
182 mut receiver: mpsc::UnboundedReceiver<WebSocketMessage>,
183 state: Arc<RwLock<ConnectionState>>,
184 metadata: Arc<RwLock<ConnectionMetadata>>,
185 config: WebSocketConfig,
186 ) where
187 S: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin,
188 {
189 debug!("Starting WebSocket handler for connection: {}", id);
190
191 let mut ping_interval = config
193 .ping_interval
194 .map(|interval| time::interval(Duration::from_secs(interval)));
195
196 loop {
197 tokio::select! {
198 ws_msg = ws_stream.next() => {
200 match ws_msg {
201 Some(Ok(msg)) => {
202 let elif_msg = WebSocketMessage::from(msg);
203
204 {
206 let mut meta = metadata.write().await;
207 meta.stats.messages_received += 1;
208 meta.stats.last_activity = Some(Instant::now());
209
210 let bytes = match &elif_msg {
212 WebSocketMessage::Text(s) => s.len() as u64,
213 WebSocketMessage::Binary(b) => b.len() as u64,
214 _ => 0,
215 };
216 meta.stats.bytes_received += bytes;
217 }
218
219 match &elif_msg {
221 WebSocketMessage::Ping(data) => {
222 if config.auto_pong {
223 let pong_msg = tungstenite::Message::Pong(data.clone());
224 if let Err(e) = ws_stream.send(pong_msg).await {
225 error!("Failed to send pong for {}: {}", id, e);
226 break;
227 }
228 }
229 }
230 WebSocketMessage::Close(_) => {
231 info!("Received close frame for connection: {}", id);
232 break;
233 }
234 _ => {
235 debug!("Received message on {}: {:?}", id, elif_msg.message_type());
238 }
239 }
240 }
241 Some(Err(e)) => {
242 error!("WebSocket error for {}: {}", id, e);
243 let mut state_lock = state.write().await;
244 *state_lock = ConnectionState::Failed(e.to_string());
245 break;
246 }
247 None => {
248 info!("WebSocket stream ended for connection: {}", id);
249 break;
250 }
251 }
252 }
253
254 app_msg = receiver.recv() => {
256 match app_msg {
257 Some(msg) => {
258 {
260 let mut meta = metadata.write().await;
261 meta.stats.messages_sent += 1;
262 meta.stats.last_activity = Some(Instant::now());
263
264 let bytes = match &msg {
266 WebSocketMessage::Text(s) => s.len() as u64,
267 WebSocketMessage::Binary(b) => b.len() as u64,
268 _ => 0,
269 };
270 meta.stats.bytes_sent += bytes;
271 }
272
273 let tungstenite_msg = tungstenite::Message::from(msg);
274 if let Err(e) = ws_stream.send(tungstenite_msg).await {
275 error!("Failed to send message for {}: {}", id, e);
276 let mut state_lock = state.write().await;
277 *state_lock = ConnectionState::Failed(e.to_string());
278 break;
279 }
280 }
281 None => {
282 debug!("Application message channel closed for: {}", id);
283 break;
284 }
285 }
286 }
287
288 _ = async {
290 if let Some(ref mut interval) = ping_interval {
291 interval.tick().await;
292 } else {
293 std::future::pending::<()>().await;
295 }
296 } => {
297 let ping_msg = tungstenite::Message::Ping(vec![]);
299 if let Err(e) = ws_stream.send(ping_msg).await {
300 error!("Failed to send ping for {}: {}", id, e);
301 break;
302 }
303 debug!("Sent ping to connection: {}", id);
304 }
305 }
306 }
307
308 let mut state_lock = state.write().await;
310 if !matches!(*state_lock, ConnectionState::Failed(_)) {
311 *state_lock = ConnectionState::Closed;
312 }
313
314 info!("WebSocket connection handler finished: {}", id);
315 }
316}
317
318impl Drop for WebSocketConnection {
319 fn drop(&mut self) {
320 debug!("Dropping WebSocket connection: {}", self.id);
321 }
322}