1use aerosocket_core::{Message, Result, transport::TransportStream};
6use aerosocket_core::frame::Frame;
7use aerosocket_core::protocol::Opcode;
8use bytes::{Bytes, BytesMut};
9use std::net::SocketAddr;
10use std::time::Duration;
11
12pub struct Connection {
14 remote_addr: SocketAddr,
16 local_addr: SocketAddr,
18 state: ConnectionState,
20 metadata: ConnectionMetadata,
22 stream: Option<Box<dyn TransportStream>>,
24 idle_timeout: Option<Duration>,
26 last_activity: std::time::Instant,
28}
29
30impl std::fmt::Debug for Connection {
31 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
32 f.debug_struct("Connection")
33 .field("remote_addr", &self.remote_addr)
34 .field("local_addr", &self.local_addr)
35 .field("state", &self.state)
36 .field("metadata", &self.metadata)
37 .field("stream", &"<stream>")
38 .finish()
39 }
40}
41
/// Lifecycle state of a [`Connection`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ConnectionState {
    /// Created without a transport; waiting for one to be attached.
    Connecting,
    /// A transport is attached and the connection can send and receive.
    Connected,
    /// A Close frame has been sent or received; shutdown is in progress.
    Closing,
    /// The peer closed the transport (a read returned end-of-stream).
    Closed,
}
54
/// Descriptive information and traffic counters for a [`Connection`].
#[derive(Debug, Clone)]
pub struct ConnectionMetadata {
    /// Subprotocol associated with the connection, if any.
    pub subprotocol: Option<String>,
    /// Extension names associated with the connection.
    pub extensions: Vec<String>,
    /// When the connection (and this metadata) was created.
    pub established_at: std::time::Instant,
    /// When activity on the connection was last recorded.
    pub last_activity_at: std::time::Instant,
    /// Number of messages successfully written by `send`.
    pub messages_sent: u64,
    /// Number of complete data messages returned by `next`.
    pub messages_received: u64,
    /// Encoded bytes (frame headers included) written by `send`.
    pub bytes_sent: u64,
    /// Payload bytes of the data messages returned by `next`.
    pub bytes_received: u64,
}
75
76impl Connection {
77 pub fn new(remote_addr: SocketAddr, local_addr: SocketAddr) -> Self {
79 let now = std::time::Instant::now();
80 Self {
81 remote_addr,
82 local_addr,
83 state: ConnectionState::Connecting,
84 metadata: ConnectionMetadata {
85 subprotocol: None,
86 extensions: Vec::new(),
87 established_at: now,
88 last_activity_at: now,
89 messages_sent: 0,
90 messages_received: 0,
91 bytes_sent: 0,
92 bytes_received: 0,
93 },
94 stream: None,
95 idle_timeout: None,
96 last_activity: now,
97 }
98 }
99
100 pub fn with_stream(remote_addr: SocketAddr, local_addr: SocketAddr, stream: Box<dyn TransportStream>) -> Self {
102 let now = std::time::Instant::now();
103 Self {
104 remote_addr,
105 local_addr,
106 state: ConnectionState::Connected,
107 metadata: ConnectionMetadata {
108 subprotocol: None,
109 extensions: Vec::new(),
110 established_at: now,
111 last_activity_at: now,
112 messages_sent: 0,
113 messages_received: 0,
114 bytes_sent: 0,
115 bytes_received: 0,
116 },
117 stream: Some(stream),
118 idle_timeout: None,
119 last_activity: now,
120 }
121 }
122
123 pub fn with_timeout(
125 remote_addr: SocketAddr,
126 local_addr: SocketAddr,
127 stream: Box<dyn TransportStream>,
128 idle_timeout: Option<Duration>
129 ) -> Self {
130 let now = std::time::Instant::now();
131 Self {
132 remote_addr,
133 local_addr,
134 state: ConnectionState::Connected,
135 metadata: ConnectionMetadata {
136 subprotocol: None,
137 extensions: Vec::new(),
138 established_at: now,
139 last_activity_at: now,
140 messages_sent: 0,
141 messages_received: 0,
142 bytes_sent: 0,
143 bytes_received: 0,
144 },
145 stream: Some(stream),
146 idle_timeout,
147 last_activity: now,
148 }
149 }
150
151 pub fn set_stream(&mut self, stream: Box<dyn TransportStream>) {
153 self.stream = Some(stream);
154 self.state = ConnectionState::Connected;
155 }
156
157 pub fn remote_addr(&self) -> SocketAddr {
159 self.remote_addr
160 }
161
162 pub fn local_addr(&self) -> SocketAddr {
164 self.local_addr
165 }
166
167 pub fn state(&self) -> ConnectionState {
169 self.state
170 }
171
172 pub fn metadata(&self) -> &ConnectionMetadata {
174 &self.metadata
175 }
176
177 pub fn is_timed_out(&self) -> bool {
179 if let Some(timeout) = self.idle_timeout {
180 self.last_activity.elapsed() > timeout
181 } else {
182 false
183 }
184 }
185
186 pub fn time_until_timeout(&self) -> Option<Duration> {
188 self.idle_timeout.map(|timeout| {
189 let elapsed = self.last_activity.elapsed();
190 if elapsed >= timeout {
191 Duration::ZERO
192 } else {
193 timeout - elapsed
194 }
195 })
196 }
197
198 fn update_activity(&mut self) {
200 self.last_activity = std::time::Instant::now();
201 self.metadata.last_activity_at = self.last_activity;
202 }
203
204 pub fn set_idle_timeout(&mut self, timeout: Option<Duration>) {
206 self.idle_timeout = timeout;
207 }
208
209 pub async fn send(&mut self, message: Message) -> Result<()> {
211 self.update_activity();
213
214 if let Some(stream) = &mut self.stream {
215 let frame = match message {
217 Message::Text(text) => Frame::text(text.as_bytes().to_vec()),
218 Message::Binary(data) => Frame::binary(data.as_bytes().to_vec()),
219 Message::Ping(data) => Frame::ping(data.as_bytes().to_vec()),
220 Message::Pong(data) => Frame::pong(data.as_bytes().to_vec()),
221 Message::Close(code_and_reason) => {
222 Frame::close(code_and_reason.code(), Some(code_and_reason.reason()))
223 }
224 };
225
226 let frame_bytes = frame.to_bytes();
228
229 stream.write_all(&frame_bytes).await?;
231 stream.flush().await?;
232
233 self.metadata.messages_sent += 1;
235 self.metadata.bytes_sent += frame_bytes.len() as u64;
236
237 Ok(())
238 } else {
239 Err(aerosocket_core::Error::Other("Connection not established".to_string()))
240 }
241 }
242
243 pub async fn send_text(&mut self, text: impl AsRef<str>) -> Result<()> {
245 self.send(Message::text(text.as_ref().to_string())).await
246 }
247
248 pub async fn send_binary(&mut self, data: impl Into<Bytes>) -> Result<()> {
250 self.send(Message::binary(data)).await
251 }
252
253 pub async fn ping(&mut self, data: Option<&[u8]>) -> Result<()> {
255 self.send(Message::ping(data.map(|d| d.to_vec()))).await
256 }
257
258 pub async fn pong(&mut self, data: Option<&[u8]>) -> Result<()> {
260 self.send(Message::pong(data.map(|d| d.to_vec()))).await
261 }
262
263 pub async fn send_pong(&mut self) -> Result<()> {
265 self.pong(None).await
266 }
267
268 pub async fn next(&mut self) -> Result<Option<Message>> {
270 self.update_activity();
272
273 if let Some(stream) = &mut self.stream {
274 let mut message_buffer = Vec::new();
275 let mut final_frame = false;
276 let mut opcode = None;
277
278 while !final_frame {
280 let mut frame_buffer = BytesMut::new();
282
283 loop {
285 let mut temp_buf = [0u8; 2];
286 let n = stream.read(&mut temp_buf).await?;
287 if n == 0 {
288 self.state = ConnectionState::Closed;
289 return Ok(None);
290 }
291 frame_buffer.extend_from_slice(&temp_buf[..n]);
292
293 if frame_buffer.len() >= 2 {
294 break;
295 }
296 }
297
298 match Frame::parse(&mut frame_buffer) {
300 Ok(frame) => {
301 match frame.opcode {
303 Opcode::Ping => {
304 let ping_data = frame.payload.to_vec();
305 stream.write_all(&Frame::pong(ping_data).to_bytes()).await?;
307 stream.flush().await?;
308 continue;
309 }
310 Opcode::Pong => {
311 continue;
315 }
316 Opcode::Close => {
317 let close_code = if frame.payload.len() >= 2 {
319 let code_bytes = &frame.payload[..2];
320 u16::from_be_bytes([code_bytes[0], code_bytes[1]])
321 } else {
322 1000 };
324
325 let close_reason = if frame.payload.len() > 2 {
326 String::from_utf8_lossy(&frame.payload[2..]).to_string()
327 } else {
328 String::new()
329 };
330
331 self.state = ConnectionState::Closing;
332 return Ok(Some(Message::close(Some(close_code), Some(close_reason))));
333 }
334 Opcode::Continuation | Opcode::Text | Opcode::Binary => {
335 if opcode.is_none() {
337 opcode = Some(frame.opcode);
338 }
339
340 message_buffer.extend_from_slice(&frame.payload);
341 final_frame = frame.fin;
342
343 if !final_frame && frame.opcode != Opcode::Continuation {
344 return Err(aerosocket_core::Error::Other("Expected continuation frame".to_string()));
345 }
346 }
347 _ => {
348 return Err(aerosocket_core::Error::Other("Unsupported opcode".to_string()));
349 }
350 }
351 }
352 Err(_e) => {
353 let mut temp_buf = [0u8; 1024];
355 match stream.read(&mut temp_buf).await {
356 Ok(0) => {
357 self.state = ConnectionState::Closed;
358 return Ok(None);
359 }
360 Ok(n) => {
361 frame_buffer.extend_from_slice(&temp_buf[..n]);
362 }
363 Err(e) => return Err(e),
364 }
365 continue;
366 }
367 }
368 }
369
370 let message = match opcode.unwrap_or(Opcode::Text) {
372 Opcode::Text => {
373 let text = String::from_utf8_lossy(&message_buffer).to_string();
374 Message::text(text)
375 }
376 Opcode::Binary => {
377 let data = Bytes::from(message_buffer.clone());
378 Message::binary(data)
379 }
380 _ => return Err(aerosocket_core::Error::Other("Invalid message opcode".to_string())),
381 };
382
383 self.metadata.messages_received += 1;
385 self.metadata.bytes_received += message_buffer.len() as u64;
386
387 Ok(Some(message))
388 } else {
389 Err(aerosocket_core::Error::Other("Connection not established".to_string()))
390 }
391 }
392
393 pub async fn close(&mut self, code: Option<u16>, reason: Option<&str>) -> Result<()> {
395 self.state = ConnectionState::Closing;
396 self.send(Message::close(code, reason.map(|s| s.to_string()))).await
397 }
398
399 pub fn is_connected(&self) -> bool {
401 self.state == ConnectionState::Connected
402 }
403
404 pub fn is_closed(&self) -> bool {
406 self.state == ConnectionState::Closed
407 }
408
409 pub fn age(&self) -> std::time::Duration {
411 self.metadata.established_at.elapsed()
412 }
413
414 pub fn idle_time(&self) -> std::time::Duration {
416 self.metadata.last_activity_at.elapsed()
417 }
418}
419
420#[derive(Debug, Clone)]
422pub struct ConnectionHandle {
423 id: u64,
425 connection: std::sync::Arc<tokio::sync::Mutex<Connection>>,
427}
428
429impl ConnectionHandle {
430 pub fn new(id: u64, connection: Connection) -> Self {
432 Self {
433 id,
434 connection: std::sync::Arc::new(tokio::sync::Mutex::new(connection)),
435 }
436 }
437
438 pub fn id(&self) -> u64 {
440 self.id
441 }
442
443 pub async fn try_lock(&self) -> Result<tokio::sync::MutexGuard<'_, Connection>> {
445 self.connection.try_lock().map_err(|_| aerosocket_core::Error::Other("Failed to lock connection".to_string()))
446 }
447}
448
#[cfg(test)]
mod tests {
    use super::*;

    /// Fixed remote/local addresses shared by the tests.
    fn addrs() -> (SocketAddr, SocketAddr) {
        (
            "127.0.0.1:12345".parse().unwrap(),
            "127.0.0.1:8080".parse().unwrap(),
        )
    }

    #[test]
    fn test_connection_creation() {
        let (remote, local) = addrs();
        let conn = Connection::new(remote, local);

        assert_eq!(conn.remote_addr(), remote);
        assert_eq!(conn.local_addr(), local);
        assert_eq!(conn.state(), ConnectionState::Connecting);
        assert!(!conn.is_connected());
        assert!(!conn.is_closed());
    }

    #[test]
    fn test_idle_timeout_disabled_by_default() {
        let (remote, local) = addrs();
        let conn = Connection::new(remote, local);

        assert!(!conn.is_timed_out());
        assert_eq!(conn.time_until_timeout(), None);
    }

    #[test]
    fn test_zero_idle_timeout_leaves_no_time() {
        let (remote, local) = addrs();
        let mut conn = Connection::new(remote, local);
        conn.set_idle_timeout(Some(Duration::ZERO));

        assert_eq!(conn.time_until_timeout(), Some(Duration::ZERO));
    }

    #[test]
    fn test_long_idle_timeout_not_expired() {
        let (remote, local) = addrs();
        let mut conn = Connection::new(remote, local);
        conn.set_idle_timeout(Some(Duration::from_secs(3600)));

        assert!(!conn.is_timed_out());
        assert!(matches!(conn.time_until_timeout(), Some(left) if left > Duration::ZERO));
    }

    #[tokio::test]
    async fn test_io_without_stream_fails() {
        let (remote, local) = addrs();
        let mut conn = Connection::new(remote, local);

        assert!(conn.send_text("hello").await.is_err());
        assert!(conn.next().await.is_err());
        assert_eq!(conn.metadata().messages_sent, 0);
        assert_eq!(conn.metadata().messages_received, 0);
    }

    #[tokio::test]
    async fn test_connection_handle() {
        let (remote, local) = addrs();
        let conn = Connection::new(remote, local);
        let handle = ConnectionHandle::new(1, conn);

        assert_eq!(handle.id(), 1);
        assert!(handle.try_lock().await.is_ok());
    }

    #[tokio::test]
    async fn test_try_lock_is_exclusive_across_clones() {
        let (remote, local) = addrs();
        let handle = ConnectionHandle::new(7, Connection::new(remote, local));
        let other = handle.clone();

        let guard = handle.try_lock().await;
        assert!(guard.is_ok());
        assert!(other.try_lock().await.is_err());

        drop(guard);
        assert!(other.try_lock().await.is_ok());
        assert_eq!(other.id(), 7);
    }
}
476}