// cloudpub_common/transport/tcp.rs

1use crate::config::{TcpConfig, TransportConfig};
2use crate::constants::MESSAGE_TIMEOUT_SECS;
3
4use super::{AddrMaybeCached, ProtobufStream, SocketOpts, Transport};
5pub use crate::unix_tcp::{Listener, NamedSocketAddr, SocketAddr, Stream};
6use crate::utils::host_port_pair;
7use anyhow::{Context as _, Result};
8use async_http_proxy::{http_connect_tokio, http_connect_tokio_with_basic_auth};
9use async_trait::async_trait;
10use socket2::{SockRef, TcpKeepalive};
11#[cfg(unix)]
12use std::os::fd::RawFd;
13use std::str::FromStr;
14use std::time::Duration;
15type RawTcpStream = Stream;
16use crate::protocol::message::Message as ProtocolMessage;
17use crate::protocol::{read_message, write_message};
18use std::pin::Pin;
19use std::task::{Context, Poll};
20use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt, ReadBuf};
21use tokio::time::timeout;
22use tracing::trace;
23use url::Url;
24
25#[derive(Debug)]
26pub struct TcpStream {
27    inner: RawTcpStream,
28}
29
30impl TcpStream {
31    pub fn new(stream: RawTcpStream) -> Self {
32        Self { inner: stream }
33    }
34
35    pub fn into_inner(self) -> RawTcpStream {
36        self.inner
37    }
38
39    pub fn get_ref(&self) -> &RawTcpStream {
40        &self.inner
41    }
42
43    pub fn get_mut(&mut self) -> &mut RawTcpStream {
44        &mut self.inner
45    }
46
47    pub fn into_stream(self) -> Stream {
48        self.inner
49    }
50}
51
52impl AsyncRead for TcpStream {
53    fn poll_read(
54        mut self: Pin<&mut Self>,
55        cx: &mut Context<'_>,
56        buf: &mut ReadBuf<'_>,
57    ) -> Poll<std::io::Result<()>> {
58        Pin::new(&mut self.inner).poll_read(cx, buf)
59    }
60}
61
62impl AsyncWrite for TcpStream {
63    fn poll_write(
64        mut self: Pin<&mut Self>,
65        cx: &mut Context<'_>,
66        buf: &[u8],
67    ) -> Poll<Result<usize, std::io::Error>> {
68        Pin::new(&mut self.inner).poll_write(cx, buf)
69    }
70
71    fn poll_flush(
72        mut self: Pin<&mut Self>,
73        cx: &mut Context<'_>,
74    ) -> Poll<Result<(), std::io::Error>> {
75        Pin::new(&mut self.inner).poll_flush(cx)
76    }
77
78    fn poll_shutdown(
79        mut self: Pin<&mut Self>,
80        cx: &mut Context<'_>,
81    ) -> Poll<Result<(), std::io::Error>> {
82        Pin::new(&mut self.inner).poll_shutdown(cx)
83    }
84}
85
86#[async_trait]
87impl ProtobufStream for TcpStream {
88    async fn recv_message(&mut self) -> anyhow::Result<Option<ProtocolMessage>> {
89        let timeout_duration = Duration::from_secs(MESSAGE_TIMEOUT_SECS);
90        match timeout(timeout_duration, read_message(&mut self.inner)).await {
91            Ok(Ok(msg)) => Ok(Some(msg)),
92            Ok(Err(e)) => {
93                if let Some(io_err) = e.downcast_ref::<std::io::Error>() {
94                    if io_err.kind() == std::io::ErrorKind::UnexpectedEof {
95                        return Ok(None);
96                    }
97                }
98                Err(e)
99            }
100            Err(_) => Err(anyhow::anyhow!(
101                "Timeout reading message after {} seconds",
102                MESSAGE_TIMEOUT_SECS
103            )),
104        }
105    }
106
107    async fn send_message(&mut self, msg: &ProtocolMessage) -> anyhow::Result<()> {
108        let timeout_duration = Duration::from_secs(MESSAGE_TIMEOUT_SECS);
109        timeout(timeout_duration, write_message(&mut self.inner, msg))
110            .await
111            .map_err(|_| {
112                anyhow::anyhow!(
113                    "Timeout writing message after {} seconds",
114                    MESSAGE_TIMEOUT_SECS
115                )
116            })?
117    }
118
119    async fn close(&mut self) -> anyhow::Result<()> {
120        self.inner
121            .shutdown()
122            .await
123            .context("Failed to shutdown stream")
124    }
125}
126
127#[derive(Debug)]
128pub struct TcpTransport {
129    pub socket_opts: SocketOpts,
130    pub cfg: TcpConfig,
131}
132
133#[async_trait]
134impl Transport for TcpTransport {
135    type Acceptor = Listener;
136    type Stream = TcpStream;
137    type RawStream = Stream;
138
139    fn new(config: &TransportConfig) -> Result<Self> {
140        Ok(TcpTransport {
141            socket_opts: SocketOpts::for_control_channel(),
142            cfg: config.tcp.clone(),
143        })
144    }
145
146    #[cfg(unix)]
147    fn as_raw_fd(conn: &Self::Stream) -> RawFd {
148        use std::os::fd::AsRawFd;
149        match conn.get_ref() {
150            Stream::Tcp(tcp_stream) => tcp_stream.as_raw_fd(),
151            #[cfg(unix)]
152            Stream::Unix(unix_stream) => unix_stream.as_raw_fd(),
153        }
154    }
155
156    fn hint(conn: &Self::Stream, opt: SocketOpts) {
157        opt.apply(conn.get_ref());
158    }
159
160    async fn bind(&self, addr: NamedSocketAddr) -> Result<Self::Acceptor> {
161        #[cfg(unix)]
162        if let NamedSocketAddr::Unix(path) = &addr {
163            // Ensure the socket file is removed if it exists
164            if path.exists() {
165                tokio::fs::remove_file(path).await?;
166            }
167        }
168        Ok(Listener::bind(&addr).await?)
169    }
170
171    async fn accept(&self, a: &Self::Acceptor) -> Result<(Self::RawStream, SocketAddr)> {
172        let (s, addr) = a.accept().await?;
173        self.socket_opts.apply(&s);
174        Ok((s, addr))
175    }
176
177    async fn handshake(&self, conn: Self::RawStream) -> Result<Self::Stream> {
178        Ok(TcpStream::new(conn))
179    }
180
181    async fn connect(&self, addr: &AddrMaybeCached) -> Result<Self::Stream> {
182        let s = tcp_connect_with_proxy(addr, self.cfg.proxy.as_ref()).await?;
183        self.socket_opts.apply(&s);
184        Ok(TcpStream::new(s))
185    }
186}
187
188// Tokio hesitates to expose this option...So we have to do it on our own :(
189// The good news is that using socket2 it can be easily done, without losing portability.
190// See https://github.com/tokio-rs/tokio/issues/3082
191pub fn try_set_tcp_keepalive(
192    conn: &RawTcpStream,
193    keepalive_duration: Duration,
194    keepalive_interval: Duration,
195) -> Result<()> {
196    match conn {
197        Stream::Tcp(tcp_stream) => {
198            let s = SockRef::from(tcp_stream);
199            let keepalive = TcpKeepalive::new()
200                .with_time(keepalive_duration)
201                .with_interval(keepalive_interval);
202
203            trace!(
204                "Set TCP keepalive {:?} {:?}",
205                keepalive_duration,
206                keepalive_interval
207            );
208
209            Ok(s.set_tcp_keepalive(&keepalive)?)
210        }
211        #[cfg(unix)]
212        Stream::Unix(_) => {
213            // Unix sockets don't support TCP keepalive
214            Ok(())
215        }
216    }
217}
218
219/// Create a TcpStream using a proxy
220/// e.g. socks5://user:pass@127.0.0.1:1080 http://127.0.0.1:8080
221pub async fn tcp_connect_with_proxy(addr: &AddrMaybeCached, proxy: Option<&Url>) -> Result<Stream> {
222    if let Some(url) = proxy {
223        let addr = &addr.addr;
224        let proxy_addr = format!(
225            "{}:{}",
226            url.host_str().expect("proxy url should have host field"),
227            url.port().expect("proxy url should have port field")
228        );
229        let mut s = Stream::connect(&NamedSocketAddr::from_str(&proxy_addr)?).await?;
230
231        let auth = if !url.username().is_empty() || url.password().is_some() {
232            Some(async_socks5::Auth {
233                username: url.username().into(),
234                password: url.password().unwrap_or("").into(),
235            })
236        } else {
237            None
238        };
239        match url.scheme() {
240            "socks5" => {
241                async_socks5::connect(&mut s, host_port_pair(addr)?, auth).await?;
242            }
243            "http" => {
244                let (host, port) = host_port_pair(addr)?;
245                match auth {
246                    Some(auth) => {
247                        http_connect_tokio_with_basic_auth(
248                            &mut s,
249                            host,
250                            port,
251                            &auth.username,
252                            &auth.password,
253                        )
254                        .await?
255                    }
256                    None => http_connect_tokio(&mut s, host, port).await?,
257                }
258            }
259            _ => panic!("unknown proxy scheme"),
260        }
261        Ok(s)
262    } else {
263        Ok(match addr.socket_addr.as_ref() {
264            Some(s) => Stream::connect(s).await?,
265            None => Stream::connect(&NamedSocketAddr::from_str(&addr.addr)?).await?,
266        })
267    }
268}