1use std::future::Future;
7use std::io;
8use std::net::SocketAddr;
9use std::pin::Pin;
10use std::sync::{
11 Arc,
12 atomic::{AtomicBool, Ordering},
13};
14use std::task::{Context, Poll};
15
16use bytes::Bytes;
17use thiserror::Error;
18use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
19use tokio::sync::{mpsc, oneshot};
20
21use crate::concurrency::FastMutex;
22use crate::server::{PeerDisconnectReason, PeerId, SendOptions};
23
24#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
25pub struct ConnectionId(u64);
27
28impl ConnectionId {
29 pub const fn from_u64(value: u64) -> Self {
31 Self(value)
32 }
33
34 pub const fn as_u64(self) -> u64 {
36 self.0
37 }
38}
39
40impl From<PeerId> for ConnectionId {
41 fn from(value: PeerId) -> Self {
42 Self::from_u64(value.as_u64())
43 }
44}
45
46#[derive(Debug, Clone, Copy, PartialEq, Eq)]
47pub struct ConnectionMetadata {
49 id: ConnectionId,
50 remote_addr: SocketAddr,
51}
52
53impl ConnectionMetadata {
54 pub const fn id(self) -> ConnectionId {
56 self.id
57 }
58
59 pub const fn remote_addr(self) -> SocketAddr {
61 self.remote_addr
62 }
63}
64
/// Why the remote side of a connection went away.
///
/// Mirrors the server's `PeerDisconnectReason` one-to-one (see the `From` impl).
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum RemoteDisconnectReason {
    /// Disconnect was requested.
    Requested,
    /// The remote sent an explicit disconnection notification.
    /// `reason_code` is presumably the optional wire-level code — confirm against the server.
    RemoteDisconnectionNotification { reason_code: Option<u8> },
    /// The connection was detected as lost.
    RemoteDetectLostConnection,
    /// The worker owning the peer stopped; `shard_id` identifies that worker.
    WorkerStopped { shard_id: usize },
}
73
74impl From<PeerDisconnectReason> for RemoteDisconnectReason {
75 fn from(value: PeerDisconnectReason) -> Self {
76 match value {
77 PeerDisconnectReason::Requested => Self::Requested,
78 PeerDisconnectReason::RemoteDisconnectionNotification { reason_code } => {
79 Self::RemoteDisconnectionNotification { reason_code }
80 }
81 PeerDisconnectReason::RemoteDetectLostConnection => Self::RemoteDetectLostConnection,
82 PeerDisconnectReason::WorkerStopped { shard_id } => Self::WorkerStopped { shard_id },
83 }
84 }
85}
86
87#[derive(Debug, Clone, PartialEq, Eq)]
88pub enum ConnectionCloseReason {
90 RequestedByLocal,
91 PeerDisconnected(RemoteDisconnectReason),
92 ListenerStopped,
93 InboundBackpressure,
94 TransportError(String),
95}
96
97#[derive(Debug, Error, Clone, PartialEq, Eq)]
98pub enum RecvError {
100 #[error("connection closed: {reason:?}")]
101 ConnectionClosed { reason: ConnectionCloseReason },
102 #[error("decode error: {message}")]
103 DecodeError { message: String },
104 #[error("connection receive channel closed")]
105 ChannelClosed,
106}
107
108pub mod queue {
109 use thiserror::Error;
110
111 #[derive(Debug, Error, Clone, PartialEq, Eq)]
112 pub enum SendQueueError {
114 #[error("connection command channel closed")]
115 CommandChannelClosed,
116 #[error("connection command response dropped")]
117 ResponseDropped,
118 #[error("transport send failed: {message}")]
119 Transport { message: String },
120 }
121}
122
123#[derive(Debug)]
124pub(crate) enum ConnectionInbound {
125 Packet(Bytes),
126 DecodeError(String),
127 Closed(ConnectionCloseReason),
128}
129
130#[derive(Debug)]
131pub(crate) enum ConnectionCommand {
132 Send {
133 peer_id: PeerId,
134 payload: Bytes,
135 options: SendOptions,
136 response: oneshot::Sender<io::Result<()>>,
137 },
138 Disconnect {
139 peer_id: PeerId,
140 response: oneshot::Sender<io::Result<()>>,
141 },
142 DisconnectNoWait {
143 peer_id: PeerId,
144 },
145 Shutdown {
146 response: oneshot::Sender<io::Result<()>>,
147 },
148}
149
150#[derive(Debug)]
151pub(crate) struct ConnectionSharedState {
152 closed: AtomicBool,
153 close_reason: FastMutex<Option<ConnectionCloseReason>>,
154}
155
156impl ConnectionSharedState {
157 pub(crate) fn new() -> Self {
158 Self {
159 closed: AtomicBool::new(false),
160 close_reason: FastMutex::new(None),
161 }
162 }
163
164 pub(crate) fn mark_closed(&self, reason: ConnectionCloseReason) {
165 self.closed.store(true, Ordering::Release);
166 *self.close_reason.lock() = Some(reason);
167 }
168
169 pub(crate) fn is_closed(&self) -> bool {
170 self.closed.load(Ordering::Acquire)
171 }
172
173 pub(crate) fn close_reason(&self) -> Option<ConnectionCloseReason> {
174 self.close_reason.lock().clone()
175 }
176}
177
178type BoxSendFuture = Pin<Box<dyn Future<Output = Result<(), queue::SendQueueError>> + Send>>;
179type BoxIoFuture = Pin<Box<dyn Future<Output = io::Result<()>> + Send>>;
180
181struct PendingWrite {
182 len: usize,
183 fut: BoxSendFuture,
184}
185
186fn is_eof_close_reason(reason: &ConnectionCloseReason) -> bool {
187 matches!(
188 reason,
189 ConnectionCloseReason::RequestedByLocal
190 | ConnectionCloseReason::PeerDisconnected(_)
191 | ConnectionCloseReason::ListenerStopped
192 )
193}
194
195fn close_reason_to_io_error(reason: ConnectionCloseReason) -> io::Error {
196 if is_eof_close_reason(&reason) {
197 io::Error::new(
198 io::ErrorKind::UnexpectedEof,
199 format!("connection closed: {reason:?}"),
200 )
201 } else {
202 io::Error::new(
203 io::ErrorKind::BrokenPipe,
204 format!("connection closed: {reason:?}"),
205 )
206 }
207}
208
209fn send_queue_error_to_io_error(error: queue::SendQueueError) -> io::Error {
210 match error {
211 queue::SendQueueError::CommandChannelClosed => io::Error::new(
212 io::ErrorKind::BrokenPipe,
213 "connection command channel closed",
214 ),
215 queue::SendQueueError::ResponseDropped => io::Error::new(
216 io::ErrorKind::BrokenPipe,
217 "connection command response dropped",
218 ),
219 queue::SendQueueError::Transport { message } => {
220 io::Error::new(io::ErrorKind::BrokenPipe, message)
221 }
222 }
223}
224
225fn send_command_future(
226 shared: Arc<ConnectionSharedState>,
227 command_tx: mpsc::Sender<ConnectionCommand>,
228 peer_id: PeerId,
229 payload: Bytes,
230 options: SendOptions,
231) -> BoxSendFuture {
232 Box::pin(async move {
233 if shared.is_closed() {
234 return Err(queue::SendQueueError::Transport {
235 message: "connection already closed".to_string(),
236 });
237 }
238
239 let (response_tx, response_rx) = oneshot::channel();
240 command_tx
241 .send(ConnectionCommand::Send {
242 peer_id,
243 payload,
244 options,
245 response: response_tx,
246 })
247 .await
248 .map_err(|_| queue::SendQueueError::CommandChannelClosed)?;
249
250 match response_rx.await {
251 Ok(Ok(())) => Ok(()),
252 Ok(Err(err)) => Err(queue::SendQueueError::Transport {
253 message: err.to_string(),
254 }),
255 Err(_) => Err(queue::SendQueueError::ResponseDropped),
256 }
257 })
258}
259
260fn disconnect_command_future(
261 shared: Arc<ConnectionSharedState>,
262 command_tx: mpsc::Sender<ConnectionCommand>,
263 peer_id: PeerId,
264) -> BoxIoFuture {
265 Box::pin(async move {
266 if shared.is_closed() {
267 return Ok(());
268 }
269
270 let (response_tx, response_rx) = oneshot::channel();
271 command_tx
272 .send(ConnectionCommand::Disconnect {
273 peer_id,
274 response: response_tx,
275 })
276 .await
277 .map_err(|_| {
278 io::Error::new(
279 io::ErrorKind::BrokenPipe,
280 "connection command channel closed",
281 )
282 })?;
283
284 match response_rx.await {
285 Ok(result) => result,
286 Err(_) => Err(io::Error::new(
287 io::ErrorKind::BrokenPipe,
288 "connection command response dropped",
289 )),
290 }
291 })
292}
293
294fn fill_read_buf_from_payload(read_buf: &mut ReadBuf<'_>, payload: &mut Bytes) {
295 let copy_len = payload.len().min(read_buf.remaining());
296 if copy_len == 0 {
297 return;
298 }
299
300 let copied = payload.split_to(copy_len);
301 read_buf.put_slice(&copied);
302}
303
304pub struct Connection {
305 remote_addr: SocketAddr,
306 id: ConnectionId,
307 peer_id: PeerId,
308 command_tx: mpsc::Sender<ConnectionCommand>,
309 inbound_rx: mpsc::Receiver<ConnectionInbound>,
310 shared: Arc<ConnectionSharedState>,
311}
312
313impl Connection {
314 pub(crate) fn new(
315 peer_id: PeerId,
316 address: SocketAddr,
317 command_tx: mpsc::Sender<ConnectionCommand>,
318 inbound_rx: mpsc::Receiver<ConnectionInbound>,
319 shared: Arc<ConnectionSharedState>,
320 ) -> Self {
321 Self {
322 remote_addr: address,
323 id: ConnectionId::from(peer_id),
324 peer_id,
325 command_tx,
326 inbound_rx,
327 shared,
328 }
329 }
330
331 pub fn id(&self) -> ConnectionId {
333 self.id
334 }
335
336 pub fn remote_addr(&self) -> SocketAddr {
338 self.remote_addr
339 }
340
341 pub fn metadata(&self) -> ConnectionMetadata {
343 ConnectionMetadata {
344 id: self.id,
345 remote_addr: self.remote_addr,
346 }
347 }
348
349 pub(crate) fn peer_id(&self) -> PeerId {
350 self.peer_id
351 }
352
353 pub fn close_reason(&self) -> Option<ConnectionCloseReason> {
355 self.shared.close_reason()
356 }
357
358 pub(crate) async fn send_with_options(
359 &self,
360 payload: impl Into<Bytes>,
361 options: SendOptions,
362 ) -> Result<(), queue::SendQueueError> {
363 send_command_future(
364 self.shared.clone(),
365 self.command_tx.clone(),
366 self.peer_id,
367 payload.into(),
368 options,
369 )
370 .await
371 }
372
373 pub async fn send_bytes(&self, payload: impl Into<Bytes>) -> Result<(), queue::SendQueueError> {
375 self.send_with_options(payload, SendOptions::default())
376 .await
377 }
378
379 pub async fn send(&self, payload: impl AsRef<[u8]>) -> Result<(), queue::SendQueueError> {
381 self.send_bytes(Bytes::copy_from_slice(payload.as_ref()))
382 .await
383 }
384
385 pub async fn send_compat(
387 &self,
388 stream: &[u8],
389 _immediate: bool,
390 ) -> Result<(), queue::SendQueueError> {
391 self.send(stream).await
392 }
393
394 pub async fn recv_bytes(&mut self) -> Result<Bytes, RecvError> {
396 match self.inbound_rx.recv().await {
397 Some(ConnectionInbound::Packet(payload)) => Ok(payload),
398 Some(ConnectionInbound::DecodeError(message)) => {
399 Err(RecvError::DecodeError { message })
400 }
401 Some(ConnectionInbound::Closed(reason)) => {
402 self.shared.mark_closed(reason.clone());
403 Err(RecvError::ConnectionClosed { reason })
404 }
405 None => {
406 if let Some(reason) = self.shared.close_reason() {
407 Err(RecvError::ConnectionClosed { reason })
408 } else {
409 self.shared
410 .mark_closed(ConnectionCloseReason::ListenerStopped);
411 Err(RecvError::ChannelClosed)
412 }
413 }
414 }
415 }
416
417 pub async fn recv(&mut self) -> Result<Vec<u8>, RecvError> {
419 self.recv_bytes().await.map(|payload| payload.to_vec())
420 }
421
422 pub async fn close(&self) {
424 if self.shared.is_closed() {
425 return;
426 }
427
428 let (response_tx, response_rx) = oneshot::channel();
429 if self
430 .command_tx
431 .send(ConnectionCommand::Disconnect {
432 peer_id: self.peer_id,
433 response: response_tx,
434 })
435 .await
436 .is_err()
437 {
438 self.shared
439 .mark_closed(ConnectionCloseReason::ListenerStopped);
440 return;
441 }
442
443 if response_rx.await.is_ok() {
444 self.shared
445 .mark_closed(ConnectionCloseReason::RequestedByLocal);
446 }
447 }
448
449 pub async fn is_closed(&self) -> bool {
451 self.shared.is_closed()
452 }
453
454 pub fn into_io(self) -> ConnectionIo {
456 ConnectionIo::new(self)
457 }
458}
459
460impl Drop for Connection {
461 fn drop(&mut self) {
462 if self.shared.is_closed() {
463 return;
464 }
465
466 let _ = self
467 .command_tx
468 .try_send(ConnectionCommand::DisconnectNoWait {
469 peer_id: self.peer_id,
470 });
471 }
472}
473
474pub struct ConnectionIo {
476 connection: Connection,
477 read_remainder: Option<Bytes>,
478 write_in_flight: Option<PendingWrite>,
479 shutdown_in_flight: Option<BoxIoFuture>,
480}
481
482impl ConnectionIo {
483 fn new(connection: Connection) -> Self {
484 Self {
485 connection,
486 read_remainder: None,
487 write_in_flight: None,
488 shutdown_in_flight: None,
489 }
490 }
491
492 pub fn connection(&self) -> &Connection {
494 &self.connection
495 }
496
497 pub fn connection_mut(&mut self) -> &mut Connection {
499 &mut self.connection
500 }
501
502 pub fn into_inner(self) -> Connection {
504 self.connection
505 }
506
507 fn poll_pending_write(&mut self, cx: &mut Context<'_>) -> Poll<io::Result<Option<usize>>> {
508 let Some(mut state) = self.write_in_flight.take() else {
509 return Poll::Ready(Ok(None));
510 };
511
512 match state.fut.as_mut().poll(cx) {
513 Poll::Ready(Ok(())) => Poll::Ready(Ok(Some(state.len))),
514 Poll::Ready(Err(error)) => Poll::Ready(Err(send_queue_error_to_io_error(error))),
515 Poll::Pending => {
516 self.write_in_flight = Some(state);
517 Poll::Pending
518 }
519 }
520 }
521}
522
523impl AsyncRead for ConnectionIo {
524 fn poll_read(
525 mut self: Pin<&mut Self>,
526 cx: &mut Context<'_>,
527 read_buf: &mut ReadBuf<'_>,
528 ) -> Poll<io::Result<()>> {
529 if read_buf.remaining() == 0 {
530 return Poll::Ready(Ok(()));
531 }
532
533 if let Some(mut remainder) = self.read_remainder.take() {
534 fill_read_buf_from_payload(read_buf, &mut remainder);
535 if !remainder.is_empty() {
536 self.read_remainder = Some(remainder);
537 }
538 return Poll::Ready(Ok(()));
539 }
540
541 match Pin::new(&mut self.connection.inbound_rx).poll_recv(cx) {
542 Poll::Ready(Some(ConnectionInbound::Packet(mut payload))) => {
543 fill_read_buf_from_payload(read_buf, &mut payload);
544 if !payload.is_empty() {
545 self.read_remainder = Some(payload);
546 }
547 Poll::Ready(Ok(()))
548 }
549 Poll::Ready(Some(ConnectionInbound::DecodeError(message))) => {
550 Poll::Ready(Err(io::Error::new(io::ErrorKind::InvalidData, message)))
551 }
552 Poll::Ready(Some(ConnectionInbound::Closed(reason))) => {
553 self.connection.shared.mark_closed(reason.clone());
554 if is_eof_close_reason(&reason) {
555 Poll::Ready(Ok(()))
556 } else {
557 Poll::Ready(Err(close_reason_to_io_error(reason)))
558 }
559 }
560 Poll::Ready(None) => {
561 if let Some(reason) = self.connection.shared.close_reason() {
562 if is_eof_close_reason(&reason) {
563 Poll::Ready(Ok(()))
564 } else {
565 Poll::Ready(Err(close_reason_to_io_error(reason)))
566 }
567 } else {
568 self.connection
569 .shared
570 .mark_closed(ConnectionCloseReason::ListenerStopped);
571 Poll::Ready(Ok(()))
572 }
573 }
574 Poll::Pending => Poll::Pending,
575 }
576 }
577}
578
579impl AsyncWrite for ConnectionIo {
580 fn poll_write(
581 mut self: Pin<&mut Self>,
582 cx: &mut Context<'_>,
583 buf: &[u8],
584 ) -> Poll<io::Result<usize>> {
585 if self.shutdown_in_flight.is_some() {
586 return Poll::Ready(Err(io::Error::new(
587 io::ErrorKind::BrokenPipe,
588 "connection shutdown already in progress",
589 )));
590 }
591
592 match self.as_mut().get_mut().poll_pending_write(cx) {
593 Poll::Ready(Ok(Some(written))) => return Poll::Ready(Ok(written)),
594 Poll::Ready(Err(error)) => return Poll::Ready(Err(error)),
595 Poll::Ready(Ok(None)) => {}
596 Poll::Pending => return Poll::Pending,
597 }
598
599 if buf.is_empty() {
600 return Poll::Ready(Ok(0));
601 }
602
603 if self.connection.shared.is_closed() {
604 return Poll::Ready(Err(io::Error::new(
605 io::ErrorKind::BrokenPipe,
606 "connection already closed",
607 )));
608 }
609
610 let payload = Bytes::copy_from_slice(buf);
611 self.write_in_flight = Some(PendingWrite {
612 len: buf.len(),
613 fut: send_command_future(
614 self.connection.shared.clone(),
615 self.connection.command_tx.clone(),
616 self.connection.peer_id,
617 payload,
618 SendOptions::default(),
619 ),
620 });
621
622 match self.as_mut().get_mut().poll_pending_write(cx) {
623 Poll::Ready(Ok(Some(written))) => Poll::Ready(Ok(written)),
624 Poll::Ready(Ok(None)) => Poll::Ready(Ok(0)),
625 Poll::Ready(Err(error)) => Poll::Ready(Err(error)),
626 Poll::Pending => Poll::Pending,
627 }
628 }
629
630 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
631 match self.as_mut().get_mut().poll_pending_write(cx) {
632 Poll::Ready(Ok(_)) => Poll::Ready(Ok(())),
633 Poll::Ready(Err(error)) => Poll::Ready(Err(error)),
634 Poll::Pending => Poll::Pending,
635 }
636 }
637
638 fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
639 match self.as_mut().poll_flush(cx) {
640 Poll::Ready(Ok(())) => {}
641 Poll::Ready(Err(error)) => return Poll::Ready(Err(error)),
642 Poll::Pending => return Poll::Pending,
643 }
644
645 if self.connection.shared.is_closed() {
646 return Poll::Ready(Ok(()));
647 }
648
649 if self.shutdown_in_flight.is_none() {
650 self.shutdown_in_flight = Some(disconnect_command_future(
651 self.connection.shared.clone(),
652 self.connection.command_tx.clone(),
653 self.connection.peer_id,
654 ));
655 }
656
657 let Some(mut shutdown_future) = self.shutdown_in_flight.take() else {
658 return Poll::Ready(Ok(()));
659 };
660
661 match shutdown_future.as_mut().poll(cx) {
662 Poll::Ready(Ok(())) => {
663 self.connection
664 .shared
665 .mark_closed(ConnectionCloseReason::RequestedByLocal);
666 Poll::Ready(Ok(()))
667 }
668 Poll::Ready(Err(error)) => Poll::Ready(Err(error)),
669 Poll::Pending => {
670 self.shutdown_in_flight = Some(shutdown_future);
671 Poll::Pending
672 }
673 }
674 }
675}