1use aerosocket_core::frame::Frame;
6use aerosocket_core::protocol::Opcode;
7use aerosocket_core::{transport::TransportStream, Message, Result};
8use bytes::{Bytes, BytesMut};
9use std::net::SocketAddr;
10use std::time::Duration;
11
12pub struct Connection {
14 remote_addr: SocketAddr,
16 local_addr: SocketAddr,
18 state: ConnectionState,
20 pub metadata: ConnectionMetadata,
22 stream: Option<Box<dyn TransportStream>>,
24 idle_timeout: Option<Duration>,
26 last_activity: std::time::Instant,
28}
29
30impl std::fmt::Debug for Connection {
31 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
32 f.debug_struct("Connection")
33 .field("remote_addr", &self.remote_addr)
34 .field("local_addr", &self.local_addr)
35 .field("state", &self.state)
36 .field("metadata", &self.metadata)
37 .field("stream", &"<stream>")
38 .finish()
39 }
40}
41
/// Lifecycle stage of a connection.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ConnectionState {
    /// Created without a transport; nothing can be exchanged yet.
    Connecting,
    /// A transport is attached and frames can flow.
    Connected,
    /// The close handshake has started (a Close frame was sent or received).
    Closing,
    /// The connection has shut down, e.g. the peer closed the transport.
    Closed,
}
54
/// Negotiated options plus running traffic statistics for one connection.
#[derive(Clone, Debug)]
pub struct ConnectionMetadata {
    /// Negotiated subprotocol, if any.
    pub subprotocol: Option<String>,
    /// Negotiated extension names.
    pub extensions: Vec<String>,
    /// Instant at which the connection object was created.
    pub established_at: std::time::Instant,
    /// Instant of the most recently recorded activity.
    pub last_activity_at: std::time::Instant,
    /// Number of messages written to the transport.
    pub messages_sent: u64,
    /// Number of data messages received.
    pub messages_received: u64,
    /// Bytes written, counted as encoded frame size.
    pub bytes_sent: u64,
    /// Bytes received, counted as message payload size.
    pub bytes_received: u64,
    /// Whether compression was negotiated; forwarded to frame parsing.
    pub compression_negotiated: bool,
}
77
78impl Connection {
79 pub fn new(remote_addr: SocketAddr, local_addr: SocketAddr) -> Self {
81 let now = std::time::Instant::now();
82 Self {
83 remote_addr,
84 local_addr,
85 state: ConnectionState::Connecting,
86 metadata: ConnectionMetadata {
87 subprotocol: None,
88 extensions: Vec::new(),
89 established_at: now,
90 last_activity_at: now,
91 messages_sent: 0,
92 messages_received: 0,
93 bytes_sent: 0,
94 bytes_received: 0,
95 compression_negotiated: false,
96 },
97 stream: None,
98 idle_timeout: None,
99 last_activity: now,
100 }
101 }
102
103 pub fn with_stream(
105 remote_addr: SocketAddr,
106 local_addr: SocketAddr,
107 stream: Box<dyn TransportStream>,
108 ) -> Self {
109 let now = std::time::Instant::now();
110 Self {
111 remote_addr,
112 local_addr,
113 state: ConnectionState::Connected,
114 metadata: ConnectionMetadata {
115 subprotocol: None,
116 extensions: Vec::new(),
117 established_at: now,
118 last_activity_at: now,
119 messages_sent: 0,
120 messages_received: 0,
121 bytes_sent: 0,
122 bytes_received: 0,
123 compression_negotiated: false,
124 },
125 stream: Some(stream),
126 idle_timeout: None,
127 last_activity: now,
128 }
129 }
130
131 pub fn with_timeout(
133 remote_addr: SocketAddr,
134 local_addr: SocketAddr,
135 stream: Box<dyn TransportStream>,
136 idle_timeout: Option<Duration>,
137 ) -> Self {
138 let now = std::time::Instant::now();
139 Self {
140 remote_addr,
141 local_addr,
142 state: ConnectionState::Connected,
143 metadata: ConnectionMetadata {
144 subprotocol: None,
145 extensions: Vec::new(),
146 established_at: now,
147 last_activity_at: now,
148 messages_sent: 0,
149 messages_received: 0,
150 bytes_sent: 0,
151 bytes_received: 0,
152 compression_negotiated: false,
153 },
154 stream: Some(stream),
155 idle_timeout,
156 last_activity: now,
157 }
158 }
159
160 pub fn set_stream(&mut self, stream: Box<dyn TransportStream>) {
162 self.stream = Some(stream);
163 self.state = ConnectionState::Connected;
164 }
165
166 pub fn remote_addr(&self) -> SocketAddr {
168 self.remote_addr
169 }
170
171 pub fn local_addr(&self) -> SocketAddr {
173 self.local_addr
174 }
175
176 pub fn state(&self) -> ConnectionState {
178 self.state
179 }
180
181 pub fn metadata(&self) -> &ConnectionMetadata {
183 &self.metadata
184 }
185
186 pub fn is_timed_out(&self) -> bool {
188 if let Some(timeout) = self.idle_timeout {
189 self.last_activity.elapsed() > timeout
190 } else {
191 false
192 }
193 }
194
195 pub fn time_until_timeout(&self) -> Option<Duration> {
197 self.idle_timeout.map(|timeout| {
198 let elapsed = self.last_activity.elapsed();
199 if elapsed >= timeout {
200 Duration::ZERO
201 } else {
202 timeout - elapsed
203 }
204 })
205 }
206
207 fn update_activity(&mut self) {
209 self.last_activity = std::time::Instant::now();
210 self.metadata.last_activity_at = self.last_activity;
211 }
212
213 pub fn set_idle_timeout(&mut self, timeout: Option<Duration>) {
215 self.idle_timeout = timeout;
216 }
217
218 pub async fn send(&mut self, message: Message) -> Result<()> {
220 self.update_activity();
222
223 if let Some(stream) = &mut self.stream {
224 let frame = match message {
226 Message::Text(text) => Frame::text(text.as_bytes().to_vec()),
227 Message::Binary(data) => Frame::binary(data.as_bytes().to_vec()),
228 Message::Ping(data) => Frame::ping(data.as_bytes().to_vec()),
229 Message::Pong(data) => Frame::pong(data.as_bytes().to_vec()),
230 Message::Close(code_and_reason) => {
231 Frame::close(code_and_reason.code(), Some(code_and_reason.reason()))
232 }
233 };
234
235 let frame_bytes = frame.to_bytes();
237
238 #[cfg(feature = "metrics")]
239 {
240 metrics::counter!("aerosocket_server_messages_sent_total").increment(1);
241 metrics::counter!("aerosocket_server_bytes_sent_total")
242 .increment(frame_bytes.len() as u64);
243 metrics::histogram!("aerosocket_server_frame_size_bytes")
244 .record(frame_bytes.len() as f64);
245 }
246
247 stream.write_all(&frame_bytes).await?;
249 stream.flush().await?;
250
251 self.metadata.messages_sent += 1;
253 self.metadata.bytes_sent += frame_bytes.len() as u64;
254
255 Ok(())
256 } else {
257 Err(aerosocket_core::Error::Other(
258 "Connection not established".to_string(),
259 ))
260 }
261 }
262
263 pub async fn send_text(&mut self, text: impl AsRef<str>) -> Result<()> {
265 self.send(Message::text(text.as_ref().to_string())).await
266 }
267
268 pub async fn send_binary(&mut self, data: impl Into<Bytes>) -> Result<()> {
270 self.send(Message::binary(data)).await
271 }
272
273 pub async fn ping(&mut self, data: Option<&[u8]>) -> Result<()> {
275 self.send(Message::ping(data.map(|d| d.to_vec()))).await
276 }
277
278 pub async fn pong(&mut self, data: Option<&[u8]>) -> Result<()> {
280 self.send(Message::pong(data.map(|d| d.to_vec()))).await
281 }
282
283 pub async fn send_pong(&mut self) -> Result<()> {
285 self.pong(None).await
286 }
287
288 pub async fn next(&mut self) -> Result<Option<Message>> {
290 self.update_activity();
292
293 if let Some(stream) = &mut self.stream {
294 let mut message_buffer = Vec::new();
295 let mut final_frame = false;
296 let mut opcode = None;
297
298 while !final_frame {
300 let mut frame_buffer = BytesMut::new();
302
303 loop {
305 let mut temp_buf = [0u8; 2];
306 let n = stream.read(&mut temp_buf).await?;
307 if n == 0 {
308 self.state = ConnectionState::Closed;
309 return Ok(None);
310 }
311 frame_buffer.extend_from_slice(&temp_buf[..n]);
312
313 if frame_buffer.len() >= 2 {
314 break;
315 }
316 }
317
318 match Frame::parse(&mut frame_buffer, self.metadata.compression_negotiated) {
320 Ok(frame) => {
321 match frame.opcode {
323 Opcode::Ping => {
324 let ping_data = frame.payload.to_vec();
325 stream.write_all(&Frame::pong(ping_data).to_bytes()).await?;
327 stream.flush().await?;
328 continue;
329 }
330 Opcode::Pong => {
331 continue;
335 }
336 Opcode::Close => {
337 let close_code = if frame.payload.len() >= 2 {
339 let code_bytes = &frame.payload[..2];
340 u16::from_be_bytes([code_bytes[0], code_bytes[1]])
341 } else {
342 1000 };
344
345 let close_reason = if frame.payload.len() > 2 {
346 String::from_utf8_lossy(&frame.payload[2..]).to_string()
347 } else {
348 String::new()
349 };
350
351 self.state = ConnectionState::Closing;
352 return Ok(Some(Message::close(
353 Some(close_code),
354 Some(close_reason),
355 )));
356 }
357 Opcode::Continuation | Opcode::Text | Opcode::Binary => {
358 if opcode.is_none() {
360 opcode = Some(frame.opcode);
361 }
362
363 message_buffer.extend_from_slice(&frame.payload);
364 final_frame = frame.fin;
365
366 if !final_frame && frame.opcode != Opcode::Continuation {
367 return Err(aerosocket_core::Error::Other(
368 "Expected continuation frame".to_string(),
369 ));
370 }
371 }
372 _ => {
373 return Err(aerosocket_core::Error::Other(
374 "Unsupported opcode".to_string(),
375 ));
376 }
377 }
378 }
379 Err(_e) => {
380 let mut temp_buf = [0u8; 1024];
382 match stream.read(&mut temp_buf).await {
383 Ok(0) => {
384 self.state = ConnectionState::Closed;
385 return Ok(None);
386 }
387 Ok(n) => {
388 frame_buffer.extend_from_slice(&temp_buf[..n]);
389 }
390 Err(e) => return Err(e),
391 }
392 continue;
393 }
394 }
395 }
396
397 let message = match opcode.unwrap_or(Opcode::Text) {
399 Opcode::Text => {
400 let text = String::from_utf8_lossy(&message_buffer).to_string();
401 Message::text(text)
402 }
403 Opcode::Binary => {
404 let data = Bytes::from(message_buffer.clone());
405 Message::binary(data)
406 }
407 _ => {
408 return Err(aerosocket_core::Error::Other(
409 "Invalid message opcode".to_string(),
410 ))
411 }
412 };
413
414 self.metadata.messages_received += 1;
416 self.metadata.bytes_received += message_buffer.len() as u64;
417
418 #[cfg(feature = "metrics")]
419 {
420 metrics::counter!("aerosocket_server_messages_received_total").increment(1);
421 metrics::counter!("aerosocket_server_bytes_received_total")
422 .increment(message_buffer.len() as u64);
423 metrics::histogram!("aerosocket_server_message_size_bytes")
424 .record(message_buffer.len() as f64);
425 }
426
427 Ok(Some(message))
428 } else {
429 Err(aerosocket_core::Error::Other(
430 "Connection not established".to_string(),
431 ))
432 }
433 }
434
435 pub async fn close(&mut self, code: Option<u16>, reason: Option<&str>) -> Result<()> {
437 self.state = ConnectionState::Closing;
438 self.send(Message::close(code, reason.map(|s| s.to_string())))
439 .await
440 }
441
442 pub fn is_connected(&self) -> bool {
444 self.state == ConnectionState::Connected
445 }
446
447 pub fn is_closed(&self) -> bool {
449 self.state == ConnectionState::Closed
450 }
451
452 pub fn age(&self) -> std::time::Duration {
454 self.metadata.established_at.elapsed()
455 }
456
457 pub fn idle_time(&self) -> std::time::Duration {
459 self.metadata.last_activity_at.elapsed()
460 }
461}
462
463#[derive(Debug, Clone)]
465pub struct ConnectionHandle {
466 id: u64,
468 connection: std::sync::Arc<tokio::sync::Mutex<Connection>>,
470}
471
472impl ConnectionHandle {
473 pub fn new(id: u64, connection: Connection) -> Self {
475 Self {
476 id,
477 connection: std::sync::Arc::new(tokio::sync::Mutex::new(connection)),
478 }
479 }
480
481 pub fn id(&self) -> u64 {
483 self.id
484 }
485
486 pub async fn try_lock(&self) -> Result<tokio::sync::MutexGuard<'_, Connection>> {
488 self.connection
489 .try_lock()
490 .map_err(|_| aerosocket_core::Error::Other("Failed to lock connection".to_string()))
491 }
492}
493
#[cfg(test)]
mod tests {
    use super::*;

    /// Peer and local addresses shared by the tests below.
    fn test_addrs() -> (std::net::SocketAddr, std::net::SocketAddr) {
        (
            "127.0.0.1:12345".parse().unwrap(),
            "127.0.0.1:8080".parse().unwrap(),
        )
    }

    #[test]
    fn test_connection_creation() {
        let (remote, local) = test_addrs();
        let conn = Connection::new(remote, local);

        assert_eq!(conn.remote_addr(), remote);
        assert_eq!(conn.local_addr(), local);
        assert_eq!(conn.state(), ConnectionState::Connecting);
        assert!(!conn.is_connected());
        assert!(!conn.is_closed());
    }

    #[test]
    fn test_new_connection_has_empty_metadata() {
        let (remote, local) = test_addrs();
        let conn = Connection::new(remote, local);
        let meta = conn.metadata();

        assert!(meta.subprotocol.is_none());
        assert!(meta.extensions.is_empty());
        assert_eq!((meta.messages_sent, meta.messages_received), (0, 0));
        assert_eq!((meta.bytes_sent, meta.bytes_received), (0, 0));
        assert!(!meta.compression_negotiated);
        assert_eq!(meta.established_at, meta.last_activity_at);
    }

    #[test]
    fn test_idle_timeout_accessors() {
        let (remote, local) = test_addrs();
        let mut conn = Connection::new(remote, local);

        // No timeout configured: never timed out, nothing to wait for.
        assert!(!conn.is_timed_out());
        assert_eq!(conn.time_until_timeout(), None);

        let minute = std::time::Duration::from_secs(60);
        conn.set_idle_timeout(Some(minute));
        assert!(!conn.is_timed_out());
        assert!(conn.time_until_timeout().unwrap() <= minute);

        // An already-expired timeout saturates at zero instead of underflowing.
        conn.set_idle_timeout(Some(std::time::Duration::ZERO));
        assert_eq!(conn.time_until_timeout(), Some(std::time::Duration::ZERO));

        conn.set_idle_timeout(None);
        assert_eq!(conn.time_until_timeout(), None);
    }

    #[tokio::test]
    async fn test_io_without_stream_is_rejected() {
        let (remote, local) = test_addrs();
        let mut conn = Connection::new(remote, local);

        assert!(conn.send_text("hello").await.is_err());
        assert!(conn.ping(None).await.is_err());
        assert!(conn.next().await.is_err());
        assert_eq!(conn.metadata().messages_sent, 0);
        assert_eq!(conn.metadata().messages_received, 0);
    }

    #[tokio::test]
    async fn test_connection_handle() {
        let (remote, local) = test_addrs();
        let handle = ConnectionHandle::new(1, Connection::new(remote, local));

        assert_eq!(handle.id(), 1);

        let guard = handle.try_lock().await;
        assert!(guard.is_ok());
        // Clones share the same mutex, so a second non-waiting attempt fails
        // while the first guard is alive.
        let other = handle.clone();
        assert_eq!(other.id(), 1);
        assert!(other.try_lock().await.is_err());
        drop(guard);
        assert!(other.try_lock().await.is_ok());
    }
}
521}