1use aerosocket_core::frame::Frame;
6use aerosocket_core::protocol::Opcode;
7use aerosocket_core::{transport::TransportStream, Message, Result};
8use bytes::{Bytes, BytesMut};
9use std::net::SocketAddr;
10use std::time::Duration;
11
12pub struct Connection {
14 remote_addr: SocketAddr,
16 local_addr: SocketAddr,
18 state: ConnectionState,
20 metadata: ConnectionMetadata,
22 stream: Option<Box<dyn TransportStream>>,
24 idle_timeout: Option<Duration>,
26 last_activity: std::time::Instant,
28}
29
30impl std::fmt::Debug for Connection {
31 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
32 f.debug_struct("Connection")
33 .field("remote_addr", &self.remote_addr)
34 .field("local_addr", &self.local_addr)
35 .field("state", &self.state)
36 .field("metadata", &self.metadata)
37 .field("stream", &"<stream>")
38 .finish()
39 }
40}
41
/// Lifecycle stage of a [`Connection`].
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ConnectionState {
    /// Created without a transport.
    Connecting,
    /// A transport is attached.
    Connected,
    /// A close was initiated locally or a Close frame was received.
    Closing,
    /// The transport reported end-of-stream.
    Closed,
}
54
/// Descriptive information and traffic counters for a [`Connection`].
#[derive(Clone, Debug)]
pub struct ConnectionMetadata {
    /// Subprotocol in use, if any.
    pub subprotocol: Option<String>,
    /// Names of extensions in use.
    pub extensions: Vec<String>,
    /// When the owning connection value was constructed.
    pub established_at: std::time::Instant,
    /// When the most recent send/receive call started.
    pub last_activity_at: std::time::Instant,
    /// Number of messages written through `Connection::send`.
    pub messages_sent: u64,
    /// Number of complete data messages returned by `Connection::next`.
    pub messages_received: u64,
    /// Encoded frame bytes written through `Connection::send`.
    pub bytes_sent: u64,
    /// Payload bytes of the data messages returned by `Connection::next`.
    pub bytes_received: u64,
}
75
76impl Connection {
77 pub fn new(remote_addr: SocketAddr, local_addr: SocketAddr) -> Self {
79 let now = std::time::Instant::now();
80 Self {
81 remote_addr,
82 local_addr,
83 state: ConnectionState::Connecting,
84 metadata: ConnectionMetadata {
85 subprotocol: None,
86 extensions: Vec::new(),
87 established_at: now,
88 last_activity_at: now,
89 messages_sent: 0,
90 messages_received: 0,
91 bytes_sent: 0,
92 bytes_received: 0,
93 },
94 stream: None,
95 idle_timeout: None,
96 last_activity: now,
97 }
98 }
99
100 pub fn with_stream(
102 remote_addr: SocketAddr,
103 local_addr: SocketAddr,
104 stream: Box<dyn TransportStream>,
105 ) -> Self {
106 let now = std::time::Instant::now();
107 Self {
108 remote_addr,
109 local_addr,
110 state: ConnectionState::Connected,
111 metadata: ConnectionMetadata {
112 subprotocol: None,
113 extensions: Vec::new(),
114 established_at: now,
115 last_activity_at: now,
116 messages_sent: 0,
117 messages_received: 0,
118 bytes_sent: 0,
119 bytes_received: 0,
120 },
121 stream: Some(stream),
122 idle_timeout: None,
123 last_activity: now,
124 }
125 }
126
127 pub fn with_timeout(
129 remote_addr: SocketAddr,
130 local_addr: SocketAddr,
131 stream: Box<dyn TransportStream>,
132 idle_timeout: Option<Duration>,
133 ) -> Self {
134 let now = std::time::Instant::now();
135 Self {
136 remote_addr,
137 local_addr,
138 state: ConnectionState::Connected,
139 metadata: ConnectionMetadata {
140 subprotocol: None,
141 extensions: Vec::new(),
142 established_at: now,
143 last_activity_at: now,
144 messages_sent: 0,
145 messages_received: 0,
146 bytes_sent: 0,
147 bytes_received: 0,
148 },
149 stream: Some(stream),
150 idle_timeout,
151 last_activity: now,
152 }
153 }
154
155 pub fn set_stream(&mut self, stream: Box<dyn TransportStream>) {
157 self.stream = Some(stream);
158 self.state = ConnectionState::Connected;
159 }
160
161 pub fn remote_addr(&self) -> SocketAddr {
163 self.remote_addr
164 }
165
166 pub fn local_addr(&self) -> SocketAddr {
168 self.local_addr
169 }
170
171 pub fn state(&self) -> ConnectionState {
173 self.state
174 }
175
176 pub fn metadata(&self) -> &ConnectionMetadata {
178 &self.metadata
179 }
180
181 pub fn is_timed_out(&self) -> bool {
183 if let Some(timeout) = self.idle_timeout {
184 self.last_activity.elapsed() > timeout
185 } else {
186 false
187 }
188 }
189
190 pub fn time_until_timeout(&self) -> Option<Duration> {
192 self.idle_timeout.map(|timeout| {
193 let elapsed = self.last_activity.elapsed();
194 if elapsed >= timeout {
195 Duration::ZERO
196 } else {
197 timeout - elapsed
198 }
199 })
200 }
201
202 fn update_activity(&mut self) {
204 self.last_activity = std::time::Instant::now();
205 self.metadata.last_activity_at = self.last_activity;
206 }
207
208 pub fn set_idle_timeout(&mut self, timeout: Option<Duration>) {
210 self.idle_timeout = timeout;
211 }
212
213 pub async fn send(&mut self, message: Message) -> Result<()> {
215 self.update_activity();
217
218 if let Some(stream) = &mut self.stream {
219 let frame = match message {
221 Message::Text(text) => Frame::text(text.as_bytes().to_vec()),
222 Message::Binary(data) => Frame::binary(data.as_bytes().to_vec()),
223 Message::Ping(data) => Frame::ping(data.as_bytes().to_vec()),
224 Message::Pong(data) => Frame::pong(data.as_bytes().to_vec()),
225 Message::Close(code_and_reason) => {
226 Frame::close(code_and_reason.code(), Some(code_and_reason.reason()))
227 }
228 };
229
230 let frame_bytes = frame.to_bytes();
232
233 stream.write_all(&frame_bytes).await?;
235 stream.flush().await?;
236
237 self.metadata.messages_sent += 1;
239 self.metadata.bytes_sent += frame_bytes.len() as u64;
240
241 Ok(())
242 } else {
243 Err(aerosocket_core::Error::Other(
244 "Connection not established".to_string(),
245 ))
246 }
247 }
248
249 pub async fn send_text(&mut self, text: impl AsRef<str>) -> Result<()> {
251 self.send(Message::text(text.as_ref().to_string())).await
252 }
253
254 pub async fn send_binary(&mut self, data: impl Into<Bytes>) -> Result<()> {
256 self.send(Message::binary(data)).await
257 }
258
259 pub async fn ping(&mut self, data: Option<&[u8]>) -> Result<()> {
261 self.send(Message::ping(data.map(|d| d.to_vec()))).await
262 }
263
264 pub async fn pong(&mut self, data: Option<&[u8]>) -> Result<()> {
266 self.send(Message::pong(data.map(|d| d.to_vec()))).await
267 }
268
269 pub async fn send_pong(&mut self) -> Result<()> {
271 self.pong(None).await
272 }
273
274 pub async fn next(&mut self) -> Result<Option<Message>> {
276 self.update_activity();
278
279 if let Some(stream) = &mut self.stream {
280 let mut message_buffer = Vec::new();
281 let mut final_frame = false;
282 let mut opcode = None;
283
284 while !final_frame {
286 let mut frame_buffer = BytesMut::new();
288
289 loop {
291 let mut temp_buf = [0u8; 2];
292 let n = stream.read(&mut temp_buf).await?;
293 if n == 0 {
294 self.state = ConnectionState::Closed;
295 return Ok(None);
296 }
297 frame_buffer.extend_from_slice(&temp_buf[..n]);
298
299 if frame_buffer.len() >= 2 {
300 break;
301 }
302 }
303
304 match Frame::parse(&mut frame_buffer) {
306 Ok(frame) => {
307 match frame.opcode {
309 Opcode::Ping => {
310 let ping_data = frame.payload.to_vec();
311 stream.write_all(&Frame::pong(ping_data).to_bytes()).await?;
313 stream.flush().await?;
314 continue;
315 }
316 Opcode::Pong => {
317 continue;
321 }
322 Opcode::Close => {
323 let close_code = if frame.payload.len() >= 2 {
325 let code_bytes = &frame.payload[..2];
326 u16::from_be_bytes([code_bytes[0], code_bytes[1]])
327 } else {
328 1000 };
330
331 let close_reason = if frame.payload.len() > 2 {
332 String::from_utf8_lossy(&frame.payload[2..]).to_string()
333 } else {
334 String::new()
335 };
336
337 self.state = ConnectionState::Closing;
338 return Ok(Some(Message::close(
339 Some(close_code),
340 Some(close_reason),
341 )));
342 }
343 Opcode::Continuation | Opcode::Text | Opcode::Binary => {
344 if opcode.is_none() {
346 opcode = Some(frame.opcode);
347 }
348
349 message_buffer.extend_from_slice(&frame.payload);
350 final_frame = frame.fin;
351
352 if !final_frame && frame.opcode != Opcode::Continuation {
353 return Err(aerosocket_core::Error::Other(
354 "Expected continuation frame".to_string(),
355 ));
356 }
357 }
358 _ => {
359 return Err(aerosocket_core::Error::Other(
360 "Unsupported opcode".to_string(),
361 ));
362 }
363 }
364 }
365 Err(_e) => {
366 let mut temp_buf = [0u8; 1024];
368 match stream.read(&mut temp_buf).await {
369 Ok(0) => {
370 self.state = ConnectionState::Closed;
371 return Ok(None);
372 }
373 Ok(n) => {
374 frame_buffer.extend_from_slice(&temp_buf[..n]);
375 }
376 Err(e) => return Err(e),
377 }
378 continue;
379 }
380 }
381 }
382
383 let message = match opcode.unwrap_or(Opcode::Text) {
385 Opcode::Text => {
386 let text = String::from_utf8_lossy(&message_buffer).to_string();
387 Message::text(text)
388 }
389 Opcode::Binary => {
390 let data = Bytes::from(message_buffer.clone());
391 Message::binary(data)
392 }
393 _ => {
394 return Err(aerosocket_core::Error::Other(
395 "Invalid message opcode".to_string(),
396 ))
397 }
398 };
399
400 self.metadata.messages_received += 1;
402 self.metadata.bytes_received += message_buffer.len() as u64;
403
404 Ok(Some(message))
405 } else {
406 Err(aerosocket_core::Error::Other(
407 "Connection not established".to_string(),
408 ))
409 }
410 }
411
412 pub async fn close(&mut self, code: Option<u16>, reason: Option<&str>) -> Result<()> {
414 self.state = ConnectionState::Closing;
415 self.send(Message::close(code, reason.map(|s| s.to_string())))
416 .await
417 }
418
419 pub fn is_connected(&self) -> bool {
421 self.state == ConnectionState::Connected
422 }
423
424 pub fn is_closed(&self) -> bool {
426 self.state == ConnectionState::Closed
427 }
428
429 pub fn age(&self) -> std::time::Duration {
431 self.metadata.established_at.elapsed()
432 }
433
434 pub fn idle_time(&self) -> std::time::Duration {
436 self.metadata.last_activity_at.elapsed()
437 }
438}
439
440#[derive(Debug, Clone)]
442pub struct ConnectionHandle {
443 id: u64,
445 connection: std::sync::Arc<tokio::sync::Mutex<Connection>>,
447}
448
449impl ConnectionHandle {
450 pub fn new(id: u64, connection: Connection) -> Self {
452 Self {
453 id,
454 connection: std::sync::Arc::new(tokio::sync::Mutex::new(connection)),
455 }
456 }
457
458 pub fn id(&self) -> u64 {
460 self.id
461 }
462
463 pub async fn try_lock(&self) -> Result<tokio::sync::MutexGuard<'_, Connection>> {
465 self.connection
466 .try_lock()
467 .map_err(|_| aerosocket_core::Error::Other("Failed to lock connection".to_string()))
468 }
469}
470
#[cfg(test)]
mod tests {
    use super::*;

    /// Fixed remote/local address pair used by every test.
    fn addrs() -> (SocketAddr, SocketAddr) {
        (
            "127.0.0.1:12345".parse().unwrap(),
            "127.0.0.1:8080".parse().unwrap(),
        )
    }

    #[test]
    fn test_connection_creation() {
        let (remote, local) = addrs();
        let conn = Connection::new(remote, local);

        assert_eq!(conn.remote_addr(), remote);
        assert_eq!(conn.local_addr(), local);
        assert_eq!(conn.state(), ConnectionState::Connecting);
        assert!(!conn.is_connected());
        assert!(!conn.is_closed());
    }

    #[test]
    fn test_new_connection_has_fresh_metadata() {
        let (remote, local) = addrs();
        let conn = Connection::new(remote, local);
        let meta = conn.metadata();

        assert!(meta.subprotocol.is_none());
        assert!(meta.extensions.is_empty());
        assert_eq!(meta.messages_sent, 0);
        assert_eq!(meta.messages_received, 0);
        assert_eq!(meta.bytes_sent, 0);
        assert_eq!(meta.bytes_received, 0);
    }

    #[test]
    fn test_idle_timeout_accessors() {
        let (remote, local) = addrs();
        let mut conn = Connection::new(remote, local);

        // No timeout configured: never timed out, nothing remaining.
        assert!(!conn.is_timed_out());
        assert_eq!(conn.time_until_timeout(), None);

        let hour = Duration::from_secs(3600);
        conn.set_idle_timeout(Some(hour));
        assert!(!conn.is_timed_out());
        let remaining = conn.time_until_timeout().expect("timeout was just set");
        assert!(remaining <= hour);
        assert!(remaining > Duration::ZERO);

        conn.set_idle_timeout(None);
        assert_eq!(conn.time_until_timeout(), None);
    }

    #[test]
    fn test_debug_output_names_type_and_state() {
        let (remote, local) = addrs();
        let conn = Connection::new(remote, local);
        let rendered = format!("{:?}", conn);

        assert!(rendered.starts_with("Connection"));
        assert!(rendered.contains("Connecting"));
    }

    #[tokio::test]
    async fn test_io_without_stream_fails() {
        let (remote, local) = addrs();
        let mut conn = Connection::new(remote, local);

        assert!(conn.send_text("hello").await.is_err());
        assert!(conn.next().await.is_err());
    }

    #[tokio::test]
    async fn test_connection_handle() {
        let (remote, local) = addrs();
        let handle = ConnectionHandle::new(1, Connection::new(remote, local));

        assert_eq!(handle.id(), 1);
        assert!(handle.try_lock().await.is_ok());
    }

    #[tokio::test]
    async fn test_connection_handle_try_lock_is_exclusive_across_clones() {
        let (remote, local) = addrs();
        let handle = ConnectionHandle::new(7, Connection::new(remote, local));
        let other = handle.clone();

        let guard = handle.try_lock().await;
        assert!(guard.is_ok());
        // Clones share the same mutex, so a second lock attempt must fail.
        assert!(other.try_lock().await.is_err());

        drop(guard);
        assert!(other.try_lock().await.is_ok());
    }
}