1use std::future::Future;
2use std::io;
3use std::net::SocketAddr;
4use std::pin::Pin;
5use std::sync::{
6 Arc,
7 atomic::{AtomicBool, Ordering},
8};
9use std::task::{Context, Poll};
10
11use bytes::Bytes;
12use thiserror::Error;
13use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
14use tokio::sync::{mpsc, oneshot};
15
16use crate::concurrency::FastMutex;
17use crate::server::{PeerDisconnectReason, PeerId, SendOptions};
18
/// Opaque, copyable identifier of a [`Connection`].
///
/// Wraps a raw `u64`; see the `From<PeerId>` conversion for how it is derived.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct ConnectionId(u64);

impl ConnectionId {
    /// Wraps a raw numeric identifier without any validation.
    pub const fn from_u64(value: u64) -> Self {
        ConnectionId(value)
    }

    /// Unwraps the raw numeric identifier.
    pub const fn as_u64(self) -> u64 {
        let ConnectionId(raw) = self;
        raw
    }
}
31
32impl From<PeerId> for ConnectionId {
33 fn from(value: PeerId) -> Self {
34 Self::from_u64(value.as_u64())
35 }
36}
37
38#[derive(Debug, Clone, Copy, PartialEq, Eq)]
39pub struct ConnectionMetadata {
40 id: ConnectionId,
41 remote_addr: SocketAddr,
42}
43
44impl ConnectionMetadata {
45 pub const fn id(self) -> ConnectionId {
46 self.id
47 }
48
49 pub const fn remote_addr(self) -> SocketAddr {
50 self.remote_addr
51 }
52}
53
54#[derive(Debug, Clone, PartialEq, Eq)]
55pub enum RemoteDisconnectReason {
56 Requested,
57 RemoteDisconnectionNotification { reason_code: Option<u8> },
58 RemoteDetectLostConnection,
59 WorkerStopped { shard_id: usize },
60}
61
62impl From<PeerDisconnectReason> for RemoteDisconnectReason {
63 fn from(value: PeerDisconnectReason) -> Self {
64 match value {
65 PeerDisconnectReason::Requested => Self::Requested,
66 PeerDisconnectReason::RemoteDisconnectionNotification { reason_code } => {
67 Self::RemoteDisconnectionNotification { reason_code }
68 }
69 PeerDisconnectReason::RemoteDetectLostConnection => Self::RemoteDetectLostConnection,
70 PeerDisconnectReason::WorkerStopped { shard_id } => Self::WorkerStopped { shard_id },
71 }
72 }
73}
74
/// Why a [`Connection`] stopped being usable.
///
/// `ConnectionIo` treats `RequestedByLocal`, `PeerDisconnected` and `ListenerStopped`
/// as a clean end-of-stream; the remaining variants surface as I/O errors.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ConnectionCloseReason {
    /// Closed from this side (`Connection::close` or `ConnectionIo::poll_shutdown`).
    RequestedByLocal,
    /// The remote side disconnected, with the transport-level reason.
    PeerDisconnected(RemoteDisconnectReason),
    /// The inbound channel ended without a recorded reason, or the disconnect
    /// command could not be delivered.
    ListenerStopped,
    /// NOTE(review): never constructed in this file; presumably set by the worker when
    /// the inbound queue overflows — confirm against the sender side.
    InboundBackpressure,
    /// A transport failure; presumably carries a human-readable description.
    TransportError(String),
}
83
/// Error returned by `Connection::recv` / `Connection::recv_bytes`.
#[derive(Debug, Error, Clone, PartialEq, Eq)]
pub enum RecvError {
    /// The connection was closed; `reason` is the inbound close reason, or the
    /// previously recorded one if the inbound channel simply ended.
    #[error("connection closed: {reason:?}")]
    ConnectionClosed { reason: ConnectionCloseReason },
    /// An inbound payload could not be decoded; the connection itself may still be open.
    #[error("decode error: {message}")]
    DecodeError { message: String },
    /// The inbound channel ended and no close reason had been recorded.
    #[error("connection receive channel closed")]
    ChannelClosed,
}
93
/// Errors produced while queueing an outbound payload on a connection.
pub mod queue {
    use thiserror::Error;

    /// Failure modes of sending a payload through the connection's command channel.
    #[derive(Debug, Error, Clone, PartialEq, Eq)]
    pub enum SendQueueError {
        /// The command channel to the worker is closed, so the payload was not queued.
        #[error("connection command channel closed")]
        CommandChannelClosed,
        /// The command was queued, but its response sender was dropped before replying.
        #[error("connection command response dropped")]
        ResponseDropped,
        /// The connection was already closed, or the worker reported a send error.
        #[error("transport send failed: {message}")]
        Transport { message: String },
    }
}
107
/// Messages delivered to a [`Connection`] over its inbound channel.
///
/// NOTE(review): the sending side is not in this file; presumably it is the server
/// worker that owns the peer.
#[derive(Debug)]
pub(crate) enum ConnectionInbound {
    /// A received payload, handed to `recv_bytes` as-is or streamed by `ConnectionIo`.
    Packet(Bytes),
    /// A payload that failed to decode; surfaced as `RecvError::DecodeError` or an
    /// `InvalidData` I/O error.
    DecodeError(String),
    /// Terminal notification; the receiver records the reason in the shared state.
    Closed(ConnectionCloseReason),
}
114
/// Commands sent from a [`Connection`] to the side that owns the transport.
#[derive(Debug)]
pub(crate) enum ConnectionCommand {
    /// Send `payload` to `peer_id`; the outcome is reported on `response`.
    Send {
        peer_id: PeerId,
        payload: Bytes,
        options: SendOptions,
        response: oneshot::Sender<io::Result<()>>,
    },
    /// Disconnect `peer_id`; the caller waits for the outcome on `response`.
    Disconnect {
        peer_id: PeerId,
        response: oneshot::Sender<io::Result<()>>,
    },
    /// Fire-and-forget disconnect, used when a `Connection` is dropped while still open.
    DisconnectNoWait {
        peer_id: PeerId,
    },
    /// NOTE(review): not constructed in this file; presumably shuts down the whole
    /// worker/listener — confirm against the command handler.
    Shutdown {
        response: oneshot::Sender<io::Result<()>>,
    },
}
134
135#[derive(Debug)]
136pub(crate) struct ConnectionSharedState {
137 closed: AtomicBool,
138 close_reason: FastMutex<Option<ConnectionCloseReason>>,
139}
140
141impl ConnectionSharedState {
142 pub(crate) fn new() -> Self {
143 Self {
144 closed: AtomicBool::new(false),
145 close_reason: FastMutex::new(None),
146 }
147 }
148
149 pub(crate) fn mark_closed(&self, reason: ConnectionCloseReason) {
150 self.closed.store(true, Ordering::Release);
151 *self.close_reason.lock() = Some(reason);
152 }
153
154 pub(crate) fn is_closed(&self) -> bool {
155 self.closed.load(Ordering::Acquire)
156 }
157
158 pub(crate) fn close_reason(&self) -> Option<ConnectionCloseReason> {
159 self.close_reason.lock().clone()
160 }
161}
162
/// Boxed future that queues one outbound payload and resolves with the send outcome.
type BoxSendFuture = Pin<Box<dyn Future<Output = Result<(), queue::SendQueueError>> + Send>>;
/// Boxed future for I/O-style commands such as the disconnect issued by `poll_shutdown`.
type BoxIoFuture = Pin<Box<dyn Future<Output = io::Result<()>> + Send>>;

/// A write started by `ConnectionIo::poll_write` that has not finished yet.
struct PendingWrite {
    /// Number of bytes from the caller's buffer that this write carries.
    len: usize,
    /// The send future driving the write.
    fut: BoxSendFuture,
}
170
171fn is_eof_close_reason(reason: &ConnectionCloseReason) -> bool {
172 matches!(
173 reason,
174 ConnectionCloseReason::RequestedByLocal
175 | ConnectionCloseReason::PeerDisconnected(_)
176 | ConnectionCloseReason::ListenerStopped
177 )
178}
179
180fn close_reason_to_io_error(reason: ConnectionCloseReason) -> io::Error {
181 if is_eof_close_reason(&reason) {
182 io::Error::new(
183 io::ErrorKind::UnexpectedEof,
184 format!("connection closed: {reason:?}"),
185 )
186 } else {
187 io::Error::new(
188 io::ErrorKind::BrokenPipe,
189 format!("connection closed: {reason:?}"),
190 )
191 }
192}
193
194fn send_queue_error_to_io_error(error: queue::SendQueueError) -> io::Error {
195 match error {
196 queue::SendQueueError::CommandChannelClosed => io::Error::new(
197 io::ErrorKind::BrokenPipe,
198 "connection command channel closed",
199 ),
200 queue::SendQueueError::ResponseDropped => io::Error::new(
201 io::ErrorKind::BrokenPipe,
202 "connection command response dropped",
203 ),
204 queue::SendQueueError::Transport { message } => {
205 io::Error::new(io::ErrorKind::BrokenPipe, message)
206 }
207 }
208}
209
210fn send_command_future(
211 shared: Arc<ConnectionSharedState>,
212 command_tx: mpsc::Sender<ConnectionCommand>,
213 peer_id: PeerId,
214 payload: Bytes,
215 options: SendOptions,
216) -> BoxSendFuture {
217 Box::pin(async move {
218 if shared.is_closed() {
219 return Err(queue::SendQueueError::Transport {
220 message: "connection already closed".to_string(),
221 });
222 }
223
224 let (response_tx, response_rx) = oneshot::channel();
225 command_tx
226 .send(ConnectionCommand::Send {
227 peer_id,
228 payload,
229 options,
230 response: response_tx,
231 })
232 .await
233 .map_err(|_| queue::SendQueueError::CommandChannelClosed)?;
234
235 match response_rx.await {
236 Ok(Ok(())) => Ok(()),
237 Ok(Err(err)) => Err(queue::SendQueueError::Transport {
238 message: err.to_string(),
239 }),
240 Err(_) => Err(queue::SendQueueError::ResponseDropped),
241 }
242 })
243}
244
245fn disconnect_command_future(
246 shared: Arc<ConnectionSharedState>,
247 command_tx: mpsc::Sender<ConnectionCommand>,
248 peer_id: PeerId,
249) -> BoxIoFuture {
250 Box::pin(async move {
251 if shared.is_closed() {
252 return Ok(());
253 }
254
255 let (response_tx, response_rx) = oneshot::channel();
256 command_tx
257 .send(ConnectionCommand::Disconnect {
258 peer_id,
259 response: response_tx,
260 })
261 .await
262 .map_err(|_| {
263 io::Error::new(
264 io::ErrorKind::BrokenPipe,
265 "connection command channel closed",
266 )
267 })?;
268
269 match response_rx.await {
270 Ok(result) => result,
271 Err(_) => Err(io::Error::new(
272 io::ErrorKind::BrokenPipe,
273 "connection command response dropped",
274 )),
275 }
276 })
277}
278
279fn fill_read_buf_from_payload(read_buf: &mut ReadBuf<'_>, payload: &mut Bytes) {
280 let copy_len = payload.len().min(read_buf.remaining());
281 if copy_len == 0 {
282 return;
283 }
284
285 let copied = payload.split_to(copy_len);
286 read_buf.put_slice(&copied);
287}
288
/// Handle to a single peer connection: a message-oriented send/receive API over
/// channels to the transport side.
///
/// Dropping a still-open handle sends a best-effort fire-and-forget disconnect.
pub struct Connection {
    /// Remote address supplied at construction.
    remote_addr: SocketAddr,
    /// Public identifier, derived from `peer_id`.
    id: ConnectionId,
    /// Transport-level peer this handle addresses in its commands.
    peer_id: PeerId,
    /// Outgoing commands (send / disconnect).
    command_tx: mpsc::Sender<ConnectionCommand>,
    /// Incoming packets, decode errors and the terminal close notification.
    inbound_rx: mpsc::Receiver<ConnectionInbound>,
    /// Close flag and reason shared with in-flight futures.
    shared: Arc<ConnectionSharedState>,
}
297
298impl Connection {
299 pub(crate) fn new(
300 peer_id: PeerId,
301 address: SocketAddr,
302 command_tx: mpsc::Sender<ConnectionCommand>,
303 inbound_rx: mpsc::Receiver<ConnectionInbound>,
304 shared: Arc<ConnectionSharedState>,
305 ) -> Self {
306 Self {
307 remote_addr: address,
308 id: ConnectionId::from(peer_id),
309 peer_id,
310 command_tx,
311 inbound_rx,
312 shared,
313 }
314 }
315
316 pub fn id(&self) -> ConnectionId {
317 self.id
318 }
319
320 pub fn remote_addr(&self) -> SocketAddr {
321 self.remote_addr
322 }
323
324 pub fn metadata(&self) -> ConnectionMetadata {
325 ConnectionMetadata {
326 id: self.id,
327 remote_addr: self.remote_addr,
328 }
329 }
330
331 pub(crate) fn peer_id(&self) -> PeerId {
332 self.peer_id
333 }
334
335 pub fn close_reason(&self) -> Option<ConnectionCloseReason> {
336 self.shared.close_reason()
337 }
338
339 pub(crate) async fn send_with_options(
340 &self,
341 payload: impl Into<Bytes>,
342 options: SendOptions,
343 ) -> Result<(), queue::SendQueueError> {
344 send_command_future(
345 self.shared.clone(),
346 self.command_tx.clone(),
347 self.peer_id,
348 payload.into(),
349 options,
350 )
351 .await
352 }
353
354 pub async fn send_bytes(&self, payload: impl Into<Bytes>) -> Result<(), queue::SendQueueError> {
355 self.send_with_options(payload, SendOptions::default())
356 .await
357 }
358
359 pub async fn send(&self, payload: impl AsRef<[u8]>) -> Result<(), queue::SendQueueError> {
360 self.send_bytes(Bytes::copy_from_slice(payload.as_ref()))
361 .await
362 }
363
364 pub async fn send_compat(
365 &self,
366 stream: &[u8],
367 _immediate: bool,
368 ) -> Result<(), queue::SendQueueError> {
369 self.send(stream).await
370 }
371
372 pub async fn recv_bytes(&mut self) -> Result<Bytes, RecvError> {
373 match self.inbound_rx.recv().await {
374 Some(ConnectionInbound::Packet(payload)) => Ok(payload),
375 Some(ConnectionInbound::DecodeError(message)) => {
376 Err(RecvError::DecodeError { message })
377 }
378 Some(ConnectionInbound::Closed(reason)) => {
379 self.shared.mark_closed(reason.clone());
380 Err(RecvError::ConnectionClosed { reason })
381 }
382 None => {
383 if let Some(reason) = self.shared.close_reason() {
384 Err(RecvError::ConnectionClosed { reason })
385 } else {
386 self.shared
387 .mark_closed(ConnectionCloseReason::ListenerStopped);
388 Err(RecvError::ChannelClosed)
389 }
390 }
391 }
392 }
393
394 pub async fn recv(&mut self) -> Result<Vec<u8>, RecvError> {
395 self.recv_bytes().await.map(|payload| payload.to_vec())
396 }
397
398 pub async fn close(&self) {
399 if self.shared.is_closed() {
400 return;
401 }
402
403 let (response_tx, response_rx) = oneshot::channel();
404 if self
405 .command_tx
406 .send(ConnectionCommand::Disconnect {
407 peer_id: self.peer_id,
408 response: response_tx,
409 })
410 .await
411 .is_err()
412 {
413 self.shared
414 .mark_closed(ConnectionCloseReason::ListenerStopped);
415 return;
416 }
417
418 if response_rx.await.is_ok() {
419 self.shared
420 .mark_closed(ConnectionCloseReason::RequestedByLocal);
421 }
422 }
423
424 pub async fn is_closed(&self) -> bool {
425 self.shared.is_closed()
426 }
427
428 pub fn into_io(self) -> ConnectionIo {
429 ConnectionIo::new(self)
430 }
431}
432
433impl Drop for Connection {
434 fn drop(&mut self) {
435 if self.shared.is_closed() {
436 return;
437 }
438
439 let _ = self
440 .command_tx
441 .try_send(ConnectionCommand::DisconnectNoWait {
442 peer_id: self.peer_id,
443 });
444 }
445}
446
/// Byte-stream adapter (`AsyncRead` + `AsyncWrite`) over a message-oriented [`Connection`].
///
/// Each `poll_write` becomes one outbound packet; inbound packets are streamed out in
/// order, splitting across reads when the caller's buffer is smaller than a packet.
pub struct ConnectionIo {
    /// The wrapped connection.
    connection: Connection,
    /// Unread tail of the last inbound packet; never stored empty.
    read_remainder: Option<Bytes>,
    /// At most one outstanding send started by `poll_write`.
    write_in_flight: Option<PendingWrite>,
    /// Outstanding disconnect started by `poll_shutdown`.
    shutdown_in_flight: Option<BoxIoFuture>,
}
453
454impl ConnectionIo {
455 fn new(connection: Connection) -> Self {
456 Self {
457 connection,
458 read_remainder: None,
459 write_in_flight: None,
460 shutdown_in_flight: None,
461 }
462 }
463
464 pub fn connection(&self) -> &Connection {
465 &self.connection
466 }
467
468 pub fn connection_mut(&mut self) -> &mut Connection {
469 &mut self.connection
470 }
471
472 pub fn into_inner(self) -> Connection {
473 self.connection
474 }
475
476 fn poll_pending_write(&mut self, cx: &mut Context<'_>) -> Poll<io::Result<Option<usize>>> {
477 let Some(mut state) = self.write_in_flight.take() else {
478 return Poll::Ready(Ok(None));
479 };
480
481 match state.fut.as_mut().poll(cx) {
482 Poll::Ready(Ok(())) => Poll::Ready(Ok(Some(state.len))),
483 Poll::Ready(Err(error)) => Poll::Ready(Err(send_queue_error_to_io_error(error))),
484 Poll::Pending => {
485 self.write_in_flight = Some(state);
486 Poll::Pending
487 }
488 }
489 }
490}
491
492impl AsyncRead for ConnectionIo {
493 fn poll_read(
494 mut self: Pin<&mut Self>,
495 cx: &mut Context<'_>,
496 read_buf: &mut ReadBuf<'_>,
497 ) -> Poll<io::Result<()>> {
498 if read_buf.remaining() == 0 {
499 return Poll::Ready(Ok(()));
500 }
501
502 if let Some(mut remainder) = self.read_remainder.take() {
503 fill_read_buf_from_payload(read_buf, &mut remainder);
504 if !remainder.is_empty() {
505 self.read_remainder = Some(remainder);
506 }
507 return Poll::Ready(Ok(()));
508 }
509
510 match Pin::new(&mut self.connection.inbound_rx).poll_recv(cx) {
511 Poll::Ready(Some(ConnectionInbound::Packet(mut payload))) => {
512 fill_read_buf_from_payload(read_buf, &mut payload);
513 if !payload.is_empty() {
514 self.read_remainder = Some(payload);
515 }
516 Poll::Ready(Ok(()))
517 }
518 Poll::Ready(Some(ConnectionInbound::DecodeError(message))) => {
519 Poll::Ready(Err(io::Error::new(io::ErrorKind::InvalidData, message)))
520 }
521 Poll::Ready(Some(ConnectionInbound::Closed(reason))) => {
522 self.connection.shared.mark_closed(reason.clone());
523 if is_eof_close_reason(&reason) {
524 Poll::Ready(Ok(()))
525 } else {
526 Poll::Ready(Err(close_reason_to_io_error(reason)))
527 }
528 }
529 Poll::Ready(None) => {
530 if let Some(reason) = self.connection.shared.close_reason() {
531 if is_eof_close_reason(&reason) {
532 Poll::Ready(Ok(()))
533 } else {
534 Poll::Ready(Err(close_reason_to_io_error(reason)))
535 }
536 } else {
537 self.connection
538 .shared
539 .mark_closed(ConnectionCloseReason::ListenerStopped);
540 Poll::Ready(Ok(()))
541 }
542 }
543 Poll::Pending => Poll::Pending,
544 }
545 }
546}
547
548impl AsyncWrite for ConnectionIo {
549 fn poll_write(
550 mut self: Pin<&mut Self>,
551 cx: &mut Context<'_>,
552 buf: &[u8],
553 ) -> Poll<io::Result<usize>> {
554 if self.shutdown_in_flight.is_some() {
555 return Poll::Ready(Err(io::Error::new(
556 io::ErrorKind::BrokenPipe,
557 "connection shutdown already in progress",
558 )));
559 }
560
561 match self.as_mut().get_mut().poll_pending_write(cx) {
562 Poll::Ready(Ok(Some(written))) => return Poll::Ready(Ok(written)),
563 Poll::Ready(Err(error)) => return Poll::Ready(Err(error)),
564 Poll::Ready(Ok(None)) => {}
565 Poll::Pending => return Poll::Pending,
566 }
567
568 if buf.is_empty() {
569 return Poll::Ready(Ok(0));
570 }
571
572 if self.connection.shared.is_closed() {
573 return Poll::Ready(Err(io::Error::new(
574 io::ErrorKind::BrokenPipe,
575 "connection already closed",
576 )));
577 }
578
579 let payload = Bytes::copy_from_slice(buf);
580 self.write_in_flight = Some(PendingWrite {
581 len: buf.len(),
582 fut: send_command_future(
583 self.connection.shared.clone(),
584 self.connection.command_tx.clone(),
585 self.connection.peer_id,
586 payload,
587 SendOptions::default(),
588 ),
589 });
590
591 match self.as_mut().get_mut().poll_pending_write(cx) {
592 Poll::Ready(Ok(Some(written))) => Poll::Ready(Ok(written)),
593 Poll::Ready(Ok(None)) => Poll::Ready(Ok(0)),
594 Poll::Ready(Err(error)) => Poll::Ready(Err(error)),
595 Poll::Pending => Poll::Pending,
596 }
597 }
598
599 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
600 match self.as_mut().get_mut().poll_pending_write(cx) {
601 Poll::Ready(Ok(_)) => Poll::Ready(Ok(())),
602 Poll::Ready(Err(error)) => Poll::Ready(Err(error)),
603 Poll::Pending => Poll::Pending,
604 }
605 }
606
607 fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
608 match self.as_mut().poll_flush(cx) {
609 Poll::Ready(Ok(())) => {}
610 Poll::Ready(Err(error)) => return Poll::Ready(Err(error)),
611 Poll::Pending => return Poll::Pending,
612 }
613
614 if self.connection.shared.is_closed() {
615 return Poll::Ready(Ok(()));
616 }
617
618 if self.shutdown_in_flight.is_none() {
619 self.shutdown_in_flight = Some(disconnect_command_future(
620 self.connection.shared.clone(),
621 self.connection.command_tx.clone(),
622 self.connection.peer_id,
623 ));
624 }
625
626 let Some(mut shutdown_future) = self.shutdown_in_flight.take() else {
627 return Poll::Ready(Ok(()));
628 };
629
630 match shutdown_future.as_mut().poll(cx) {
631 Poll::Ready(Ok(())) => {
632 self.connection
633 .shared
634 .mark_closed(ConnectionCloseReason::RequestedByLocal);
635 Poll::Ready(Ok(()))
636 }
637 Poll::Ready(Err(error)) => Poll::Ready(Err(error)),
638 Poll::Pending => {
639 self.shutdown_in_flight = Some(shutdown_future);
640 Poll::Pending
641 }
642 }
643 }
644}