cloudpub_common/transport/
tcp.rs1use crate::config::{TcpConfig, TransportConfig};
2use crate::constants::MESSAGE_TIMEOUT_SECS;
3
4use super::{AddrMaybeCached, ProtobufStream, SocketOpts, Transport};
5pub use crate::unix_tcp::{Listener, NamedSocketAddr, SocketAddr, Stream};
6use crate::utils::host_port_pair;
7use anyhow::{Context as _, Result};
8use async_http_proxy::{http_connect_tokio, http_connect_tokio_with_basic_auth};
9use async_trait::async_trait;
10use socket2::{SockRef, TcpKeepalive};
11#[cfg(unix)]
12use std::os::fd::RawFd;
13use std::str::FromStr;
14use std::time::Duration;
15type RawTcpStream = Stream;
16use crate::protocol::message::Message as ProtocolMessage;
17use crate::protocol::{read_message, write_message};
18use std::pin::Pin;
19use std::task::{Context, Poll};
20use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt, ReadBuf};
21use tokio::time::timeout;
22use tracing::trace;
23use url::Url;
24
25#[derive(Debug)]
26pub struct TcpStream {
27 inner: RawTcpStream,
28}
29
30impl TcpStream {
31 pub fn new(stream: RawTcpStream) -> Self {
32 Self { inner: stream }
33 }
34
35 pub fn into_inner(self) -> RawTcpStream {
36 self.inner
37 }
38
39 pub fn get_ref(&self) -> &RawTcpStream {
40 &self.inner
41 }
42
43 pub fn get_mut(&mut self) -> &mut RawTcpStream {
44 &mut self.inner
45 }
46
47 pub fn into_stream(self) -> Stream {
48 self.inner
49 }
50}
51
52impl AsyncRead for TcpStream {
53 fn poll_read(
54 mut self: Pin<&mut Self>,
55 cx: &mut Context<'_>,
56 buf: &mut ReadBuf<'_>,
57 ) -> Poll<std::io::Result<()>> {
58 Pin::new(&mut self.inner).poll_read(cx, buf)
59 }
60}
61
62impl AsyncWrite for TcpStream {
63 fn poll_write(
64 mut self: Pin<&mut Self>,
65 cx: &mut Context<'_>,
66 buf: &[u8],
67 ) -> Poll<Result<usize, std::io::Error>> {
68 Pin::new(&mut self.inner).poll_write(cx, buf)
69 }
70
71 fn poll_flush(
72 mut self: Pin<&mut Self>,
73 cx: &mut Context<'_>,
74 ) -> Poll<Result<(), std::io::Error>> {
75 Pin::new(&mut self.inner).poll_flush(cx)
76 }
77
78 fn poll_shutdown(
79 mut self: Pin<&mut Self>,
80 cx: &mut Context<'_>,
81 ) -> Poll<Result<(), std::io::Error>> {
82 Pin::new(&mut self.inner).poll_shutdown(cx)
83 }
84}
85
86#[async_trait]
87impl ProtobufStream for TcpStream {
88 async fn recv_message(&mut self) -> anyhow::Result<Option<ProtocolMessage>> {
89 let timeout_duration = Duration::from_secs(MESSAGE_TIMEOUT_SECS);
90 match timeout(timeout_duration, read_message(&mut self.inner)).await {
91 Ok(Ok(msg)) => Ok(Some(msg)),
92 Ok(Err(e)) => {
93 if let Some(io_err) = e.downcast_ref::<std::io::Error>() {
94 if io_err.kind() == std::io::ErrorKind::UnexpectedEof {
95 return Ok(None);
96 }
97 }
98 Err(e)
99 }
100 Err(_) => Err(anyhow::anyhow!(
101 "Timeout reading message after {} seconds",
102 MESSAGE_TIMEOUT_SECS
103 )),
104 }
105 }
106
107 async fn send_message(&mut self, msg: &ProtocolMessage) -> anyhow::Result<()> {
108 let timeout_duration = Duration::from_secs(MESSAGE_TIMEOUT_SECS);
109 timeout(timeout_duration, write_message(&mut self.inner, msg))
110 .await
111 .map_err(|_| {
112 anyhow::anyhow!(
113 "Timeout writing message after {} seconds",
114 MESSAGE_TIMEOUT_SECS
115 )
116 })?
117 }
118
119 async fn close(&mut self) -> anyhow::Result<()> {
120 self.inner
121 .shutdown()
122 .await
123 .context("Failed to shutdown stream")
124 }
125}
126
127#[derive(Debug)]
128pub struct TcpTransport {
129 pub socket_opts: SocketOpts,
130 pub cfg: TcpConfig,
131}
132
133#[async_trait]
134impl Transport for TcpTransport {
135 type Acceptor = Listener;
136 type Stream = TcpStream;
137 type RawStream = Stream;
138
139 fn new(config: &TransportConfig) -> Result<Self> {
140 Ok(TcpTransport {
141 socket_opts: SocketOpts::for_control_channel(),
142 cfg: config.tcp.clone(),
143 })
144 }
145
146 #[cfg(unix)]
147 fn as_raw_fd(conn: &Self::Stream) -> RawFd {
148 use std::os::fd::AsRawFd;
149 match conn.get_ref() {
150 Stream::Tcp(tcp_stream) => tcp_stream.as_raw_fd(),
151 #[cfg(unix)]
152 Stream::Unix(unix_stream) => unix_stream.as_raw_fd(),
153 }
154 }
155
156 fn hint(conn: &Self::Stream, opt: SocketOpts) {
157 opt.apply(conn.get_ref());
158 }
159
160 async fn bind(&self, addr: NamedSocketAddr) -> Result<Self::Acceptor> {
161 #[cfg(unix)]
162 if let NamedSocketAddr::Unix(path) = &addr {
163 if path.exists() {
165 tokio::fs::remove_file(path).await?;
166 }
167 }
168 Ok(Listener::bind(&addr).await?)
169 }
170
171 async fn accept(&self, a: &Self::Acceptor) -> Result<(Self::RawStream, SocketAddr)> {
172 let (s, addr) = a.accept().await?;
173 self.socket_opts.apply(&s);
174 Ok((s, addr))
175 }
176
177 async fn handshake(&self, conn: Self::RawStream) -> Result<Self::Stream> {
178 Ok(TcpStream::new(conn))
179 }
180
181 async fn connect(&self, addr: &AddrMaybeCached) -> Result<Self::Stream> {
182 let s = tcp_connect_with_proxy(addr, self.cfg.proxy.as_ref()).await?;
183 self.socket_opts.apply(&s);
184 Ok(TcpStream::new(s))
185 }
186}
187
188pub fn try_set_tcp_keepalive(
192 conn: &RawTcpStream,
193 keepalive_duration: Duration,
194 keepalive_interval: Duration,
195) -> Result<()> {
196 match conn {
197 Stream::Tcp(tcp_stream) => {
198 let s = SockRef::from(tcp_stream);
199 let keepalive = TcpKeepalive::new()
200 .with_time(keepalive_duration)
201 .with_interval(keepalive_interval);
202
203 trace!(
204 "Set TCP keepalive {:?} {:?}",
205 keepalive_duration,
206 keepalive_interval
207 );
208
209 Ok(s.set_tcp_keepalive(&keepalive)?)
210 }
211 #[cfg(unix)]
212 Stream::Unix(_) => {
213 Ok(())
215 }
216 }
217}
218
219pub async fn tcp_connect_with_proxy(addr: &AddrMaybeCached, proxy: Option<&Url>) -> Result<Stream> {
222 if let Some(url) = proxy {
223 let addr = &addr.addr;
224 let proxy_addr = format!(
225 "{}:{}",
226 url.host_str().expect("proxy url should have host field"),
227 url.port().expect("proxy url should have port field")
228 );
229 let mut s = Stream::connect(&NamedSocketAddr::from_str(&proxy_addr)?).await?;
230
231 let auth = if !url.username().is_empty() || url.password().is_some() {
232 Some(async_socks5::Auth {
233 username: url.username().into(),
234 password: url.password().unwrap_or("").into(),
235 })
236 } else {
237 None
238 };
239 match url.scheme() {
240 "socks5" => {
241 async_socks5::connect(&mut s, host_port_pair(addr)?, auth).await?;
242 }
243 "http" => {
244 let (host, port) = host_port_pair(addr)?;
245 match auth {
246 Some(auth) => {
247 http_connect_tokio_with_basic_auth(
248 &mut s,
249 host,
250 port,
251 &auth.username,
252 &auth.password,
253 )
254 .await?
255 }
256 None => http_connect_tokio(&mut s, host, port).await?,
257 }
258 }
259 _ => panic!("unknown proxy scheme"),
260 }
261 Ok(s)
262 } else {
263 Ok(match addr.socket_addr.as_ref() {
264 Some(s) => Stream::connect(s).await?,
265 None => Stream::connect(&NamedSocketAddr::from_str(&addr.addr)?).await?,
266 })
267 }
268}