Skip to main content

msg_transport/tcp/
mod.rs

1use futures::future::BoxFuture;
2use std::{
3    io,
4    net::SocketAddr,
5    task::{Context, Poll},
6};
7use tokio::net::{TcpListener, TcpStream};
8use tracing::debug;
9
10use msg_common::async_error;
11
12use crate::{Acceptor, PeerAddress, Transport, TransportExt};
13
14mod stats;
15pub use stats::TcpStats;
16
/// Configuration for the [`Tcp`] transport.
///
/// The type currently carries no options. Besides `Debug` and `Default`, it derives
/// `Clone`, `PartialEq` and `Eq` so callers can duplicate and compare configurations
/// like any other public value type.
#[derive(Debug, Default, Clone, PartialEq, Eq)]
pub struct Config;
19
20#[derive(Debug, Default)]
21pub struct Tcp {
22    #[allow(unused)]
23    config: Config,
24    listener: Option<tokio::net::TcpListener>,
25}
26
27impl Tcp {
28    pub fn new(config: Config) -> Self {
29        Self { config, listener: None }
30    }
31}
32
33impl PeerAddress<SocketAddr> for TcpStream {
34    fn peer_addr(&self) -> io::Result<SocketAddr> {
35        self.peer_addr()
36    }
37}
38
39#[async_trait::async_trait]
40impl Transport<SocketAddr> for Tcp {
41    type Stats = TcpStats;
42    type Io = TcpStream;
43
44    type Control = ();
45
46    type Error = io::Error;
47
48    type Connect = BoxFuture<'static, Result<Self::Io, Self::Error>>;
49    type Accept = BoxFuture<'static, Result<Self::Io, Self::Error>>;
50
51    fn local_addr(&self) -> Option<SocketAddr> {
52        self.listener.as_ref().and_then(|l| l.local_addr().ok())
53    }
54
55    async fn bind(&mut self, addr: SocketAddr) -> Result<(), Self::Error> {
56        let listener = TcpListener::bind(addr).await?;
57
58        self.listener = Some(listener);
59
60        Ok(())
61    }
62
63    fn connect(&mut self, addr: SocketAddr) -> Self::Connect {
64        Box::pin(async move {
65            let stream = TcpStream::connect(addr).await?;
66            stream.set_nodelay(true)?;
67
68            Ok(stream)
69        })
70    }
71
72    fn poll_accept(self: std::pin::Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Accept> {
73        let this = self.get_mut();
74
75        let Some(ref listener) = this.listener else {
76            return Poll::Ready(async_error(io::ErrorKind::NotConnected.into()));
77        };
78
79        match listener.poll_accept(cx) {
80            Poll::Ready(Ok((io, addr))) => {
81                debug!(%addr, "accepted connection");
82
83                Poll::Ready(Box::pin(async move {
84                    io.set_nodelay(true)?;
85                    Ok(io)
86                }))
87            }
88            Poll::Ready(Err(e)) => Poll::Ready(async_error(e)),
89            Poll::Pending => Poll::Pending,
90        }
91    }
92}
93
94impl TransportExt<SocketAddr> for Tcp {
95    fn accept(&mut self) -> Acceptor<'_, Self, SocketAddr>
96    where
97        Self: Sized + Unpin,
98    {
99        Acceptor::new(self)
100    }
101}