1use futures::future::BoxFuture;
2use std::{
3 io,
4 net::SocketAddr,
5 task::{Context, Poll},
6};
7use tokio::net::{TcpListener, TcpStream};
8use tracing::debug;
9
10use msg_common::async_error;
11
12use crate::{Acceptor, PeerAddress, Transport, TransportExt};
13
14mod stats;
15pub use stats::TcpStats;
16
/// Configuration for the [`Tcp`] transport.
///
/// This is a unit struct: it currently carries no options.
#[derive(Default, Debug)]
pub struct Config;
19
20#[derive(Debug, Default)]
21pub struct Tcp {
22 #[allow(unused)]
23 config: Config,
24 listener: Option<tokio::net::TcpListener>,
25}
26
27impl Tcp {
28 pub fn new(config: Config) -> Self {
29 Self { config, listener: None }
30 }
31}
32
33impl PeerAddress<SocketAddr> for TcpStream {
34 fn peer_addr(&self) -> io::Result<SocketAddr> {
35 self.peer_addr()
36 }
37}
38
39#[async_trait::async_trait]
40impl Transport<SocketAddr> for Tcp {
41 type Stats = TcpStats;
42 type Io = TcpStream;
43
44 type Control = ();
45
46 type Error = io::Error;
47
48 type Connect = BoxFuture<'static, Result<Self::Io, Self::Error>>;
49 type Accept = BoxFuture<'static, Result<Self::Io, Self::Error>>;
50
51 fn local_addr(&self) -> Option<SocketAddr> {
52 self.listener.as_ref().and_then(|l| l.local_addr().ok())
53 }
54
55 async fn bind(&mut self, addr: SocketAddr) -> Result<(), Self::Error> {
56 let listener = TcpListener::bind(addr).await?;
57
58 self.listener = Some(listener);
59
60 Ok(())
61 }
62
63 fn connect(&mut self, addr: SocketAddr) -> Self::Connect {
64 Box::pin(async move {
65 let stream = TcpStream::connect(addr).await?;
66 stream.set_nodelay(true)?;
67
68 Ok(stream)
69 })
70 }
71
72 fn poll_accept(self: std::pin::Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Accept> {
73 let this = self.get_mut();
74
75 let Some(ref listener) = this.listener else {
76 return Poll::Ready(async_error(io::ErrorKind::NotConnected.into()));
77 };
78
79 match listener.poll_accept(cx) {
80 Poll::Ready(Ok((io, addr))) => {
81 debug!(%addr, "accepted connection");
82
83 Poll::Ready(Box::pin(async move {
84 io.set_nodelay(true)?;
85 Ok(io)
86 }))
87 }
88 Poll::Ready(Err(e)) => Poll::Ready(async_error(e)),
89 Poll::Pending => Poll::Pending,
90 }
91 }
92}
93
94impl TransportExt<SocketAddr> for Tcp {
95 fn accept(&mut self) -> Acceptor<'_, Self, SocketAddr>
96 where
97 Self: Sized + Unpin,
98 {
99 Acceptor::new(self)
100 }
101}