1use std::{
2 io::{self, ErrorKind},
3 mem,
4 net::{Ipv4Addr, Ipv6Addr, SocketAddr},
5 os::unix::io::{AsRawFd, FromRawFd, IntoRawFd, RawFd},
6 pin::Pin,
7 ptr,
8 sync::atomic::{AtomicBool, Ordering},
9 task::{self, Poll},
10};
11
12use log::{debug, error, warn};
13use pin_project::pin_project;
14use socket2::{Domain, Protocol, SockAddr, Socket, Type};
15use tokio::{
16 io::{AsyncRead, AsyncWrite, ReadBuf},
17 net::{TcpSocket, TcpStream as TokioTcpStream, UdpSocket},
18};
19use tokio_tfo::TfoStream;
20
21use crate::net::{
22 AcceptOpts, AddrFamily, ConnectOpts,
23 sys::{set_common_sockopt_after_connect, set_common_sockopt_for_connect, socket_bind_dual_stack},
24 udp::{BatchRecvMessage, BatchSendMessage},
25};
26
/// An outbound TCP stream produced by [`TcpStream::connect`].
///
/// The variant depends on `ConnectOpts::tcp.fastopen` at connect time: a
/// regular `connect` yields `Standard`, a TCP Fast Open connect yields
/// `FastOpen`. Both variants are structurally pinned so the I/O trait impls
/// can project to the inner stream.
#[pin_project(project = TcpStreamProj)]
pub enum TcpStream {
    /// Stream established with a regular `TcpSocket::connect`.
    Standard(#[pin] TokioTcpStream),
    /// Stream established with `TfoStream::connect_with_socket` (TCP Fast Open).
    FastOpen(#[pin] TfoStream),
}
33
34impl TcpStream {
35 pub async fn connect(addr: SocketAddr, opts: &ConnectOpts) -> io::Result<Self> {
36 if opts.tcp.mptcp {
37 return Self::connect_mptcp(addr, opts).await;
38 }
39
40 let socket = match addr {
41 SocketAddr::V4(..) => TcpSocket::new_v4()?,
42 SocketAddr::V6(..) => TcpSocket::new_v6()?,
43 };
44
45 Self::connect_with_socket(socket, addr, opts).await
46 }
47
48 async fn connect_mptcp(addr: SocketAddr, opts: &ConnectOpts) -> io::Result<Self> {
49 let socket = create_mptcp_socket(&addr)?;
50 Self::connect_with_socket(socket, addr, opts).await
51 }
52
53 async fn connect_with_socket(socket: TcpSocket, addr: SocketAddr, opts: &ConnectOpts) -> io::Result<Self> {
54 #[cfg(target_os = "android")]
57 if !addr.ip().is_loopback() {
58 android::vpn_protect(&socket, opts).await?;
59 }
60
61 if let Some(mark) = opts.fwmark {
64 let ret = unsafe {
65 libc::setsockopt(
66 socket.as_raw_fd(),
67 libc::SOL_SOCKET,
68 libc::SO_MARK,
69 &mark as *const _ as *const _,
70 mem::size_of_val(&mark) as libc::socklen_t,
71 )
72 };
73 if ret != 0 {
74 let err = io::Error::last_os_error();
75 error!("set SO_MARK error: {}", err);
76 return Err(err);
77 }
78 }
79
80 if let Some(ref iface) = opts.bind_interface {
82 set_bindtodevice(&socket, iface)?;
83 }
84
85 set_common_sockopt_for_connect(addr, &socket, opts)?;
86
87 if !opts.tcp.fastopen {
88 let stream = socket.connect(addr).await?;
90 set_common_sockopt_after_connect(&stream, opts)?;
91
92 return Ok(Self::Standard(stream));
93 }
94
95 let stream = TfoStream::connect_with_socket(socket, addr).await?;
96 set_common_sockopt_after_connect(&stream, opts)?;
97
98 Ok(Self::FastOpen(stream))
99 }
100
101 pub fn local_addr(&self) -> io::Result<SocketAddr> {
102 match *self {
103 Self::Standard(ref s) => s.local_addr(),
104 Self::FastOpen(ref s) => s.local_addr(),
105 }
106 }
107
108 pub fn peer_addr(&self) -> io::Result<SocketAddr> {
109 match *self {
110 Self::Standard(ref s) => s.peer_addr(),
111 Self::FastOpen(ref s) => s.peer_addr(),
112 }
113 }
114
115 pub fn nodelay(&self) -> io::Result<bool> {
116 match *self {
117 Self::Standard(ref s) => s.nodelay(),
118 Self::FastOpen(ref s) => s.nodelay(),
119 }
120 }
121
122 pub fn set_nodelay(&self, nodelay: bool) -> io::Result<()> {
123 match *self {
124 Self::Standard(ref s) => s.set_nodelay(nodelay),
125 Self::FastOpen(ref s) => s.set_nodelay(nodelay),
126 }
127 }
128}
129
130impl AsRawFd for TcpStream {
131 fn as_raw_fd(&self) -> RawFd {
132 match *self {
133 Self::Standard(ref s) => s.as_raw_fd(),
134 Self::FastOpen(ref s) => s.as_raw_fd(),
135 }
136 }
137}
138
139impl AsyncRead for TcpStream {
140 fn poll_read(self: Pin<&mut Self>, cx: &mut task::Context<'_>, buf: &mut ReadBuf<'_>) -> Poll<io::Result<()>> {
141 match self.project() {
142 TcpStreamProj::Standard(s) => s.poll_read(cx, buf),
143 TcpStreamProj::FastOpen(s) => s.poll_read(cx, buf),
144 }
145 }
146}
147
148impl AsyncWrite for TcpStream {
149 fn poll_write(self: Pin<&mut Self>, cx: &mut task::Context<'_>, buf: &[u8]) -> Poll<io::Result<usize>> {
150 match self.project() {
151 TcpStreamProj::Standard(s) => s.poll_write(cx, buf),
152 TcpStreamProj::FastOpen(s) => s.poll_write(cx, buf),
153 }
154 }
155
156 fn poll_flush(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<io::Result<()>> {
157 match self.project() {
158 TcpStreamProj::Standard(s) => s.poll_flush(cx),
159 TcpStreamProj::FastOpen(s) => s.poll_flush(cx),
160 }
161 }
162
163 fn poll_shutdown(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<io::Result<()>> {
164 match self.project() {
165 TcpStreamProj::Standard(s) => s.poll_shutdown(cx),
166 TcpStreamProj::FastOpen(s) => s.poll_shutdown(cx),
167 }
168 }
169}
170
171pub fn set_tcp_fastopen<S: AsRawFd>(socket: &S) -> io::Result<()> {
175 let queue: libc::c_int = 1024;
184
185 unsafe {
186 let ret = libc::setsockopt(
187 socket.as_raw_fd(),
188 libc::IPPROTO_TCP,
189 libc::TCP_FASTOPEN,
190 &queue as *const _ as *const libc::c_void,
191 mem::size_of_val(&queue) as libc::socklen_t,
192 );
193
194 if ret != 0 {
195 let err = io::Error::last_os_error();
196 error!("set TCP_FASTOPEN error: {}", err);
197 return Err(err);
198 }
199 }
200
201 Ok(())
202}
203
204fn create_mptcp_socket(bind_addr: &SocketAddr) -> io::Result<TcpSocket> {
205 unsafe {
208 let family = match bind_addr {
209 SocketAddr::V4(..) => libc::AF_INET,
210 SocketAddr::V6(..) => libc::AF_INET6,
211 };
212 let fd = libc::socket(family, libc::SOCK_STREAM, libc::IPPROTO_MPTCP);
213 if fd < 0 {
214 let err = io::Error::last_os_error();
215 return Err(err);
216 }
217 let socket = Socket::from_raw_fd(fd);
218 socket.set_nonblocking(true)?;
219 Ok(TcpSocket::from_raw_fd(socket.into_raw_fd()))
220 }
221}
222
223pub async fn create_inbound_tcp_socket(bind_addr: &SocketAddr, accept_opts: &AcceptOpts) -> io::Result<TcpSocket> {
225 if accept_opts.tcp.mptcp {
226 create_mptcp_socket(bind_addr)
227 } else {
228 match bind_addr {
229 SocketAddr::V4(..) => TcpSocket::new_v4(),
230 SocketAddr::V6(..) => TcpSocket::new_v6(),
231 }
232 }
233}
234
235#[inline]
237pub fn set_disable_ip_fragmentation<S: AsRawFd>(af: AddrFamily, socket: &S) -> io::Result<()> {
238 unsafe {
242 let value: i32 = libc::IP_PMTUDISC_DO;
243 let ret = libc::setsockopt(
244 socket.as_raw_fd(),
245 libc::IPPROTO_IP,
246 libc::IP_MTU_DISCOVER,
247 &value as *const _ as *const _,
248 mem::size_of_val(&value) as libc::socklen_t,
249 );
250
251 if ret < 0 {
252 return Err(io::Error::last_os_error());
253 }
254
255 if af == AddrFamily::Ipv6 {
256 let value: i32 = libc::IP_PMTUDISC_DO;
257 let ret = libc::setsockopt(
258 socket.as_raw_fd(),
259 libc::IPPROTO_IPV6,
260 libc::IPV6_MTU_DISCOVER,
261 &value as *const _ as *const _,
262 mem::size_of_val(&value) as libc::socklen_t,
263 );
264
265 if ret < 0 {
266 return Err(io::Error::last_os_error());
267 }
268 }
269 }
270
271 Ok(())
272}
273
274#[inline]
276pub async fn create_outbound_udp_socket(af: AddrFamily, config: &ConnectOpts) -> io::Result<UdpSocket> {
277 let bind_addr = match (af, config.bind_local_addr) {
278 (AddrFamily::Ipv4, Some(SocketAddr::V4(addr))) => addr.into(),
279 (AddrFamily::Ipv4, Some(SocketAddr::V6(addr))) => {
280 match addr.ip().to_ipv4_mapped() {
282 Some(addr) => SocketAddr::new(addr.into(), 0),
283 None => return Err(io::Error::new(ErrorKind::InvalidInput, "Invalid IPv6 address")),
284 }
285 }
286 (AddrFamily::Ipv6, Some(SocketAddr::V6(addr))) => addr.into(),
287 (AddrFamily::Ipv6, Some(SocketAddr::V4(addr))) => {
288 SocketAddr::new(addr.ip().to_ipv6_mapped().into(), 0)
290 }
291 (AddrFamily::Ipv4, ..) => SocketAddr::new(Ipv4Addr::UNSPECIFIED.into(), 0),
292 (AddrFamily::Ipv6, ..) => SocketAddr::new(Ipv6Addr::UNSPECIFIED.into(), 0),
293 };
294
295 bind_outbound_udp_socket(&bind_addr, config).await
296}
297
298pub async fn bind_outbound_udp_socket(bind_addr: &SocketAddr, config: &ConnectOpts) -> io::Result<UdpSocket> {
300 let af = AddrFamily::from(bind_addr);
301
302 let socket = if af != AddrFamily::Ipv6 {
303 UdpSocket::bind(bind_addr).await?
304 } else {
305 let socket = Socket::new(Domain::for_address(*bind_addr), Type::DGRAM, Some(Protocol::UDP))?;
306 socket_bind_dual_stack(&socket, bind_addr, false)?;
307
308 socket.set_nonblocking(true)?;
310 UdpSocket::from_std(socket.into())?
311 };
312
313 if !config.udp.allow_fragmentation {
314 if let Err(err) = set_disable_ip_fragmentation(af, &socket) {
315 warn!("failed to disable IP fragmentation, error: {}", err);
316 }
317 }
318
319 #[cfg(target_os = "android")]
322 android::vpn_protect(&socket, config).await?;
323
324 if let Some(mark) = config.fwmark {
327 let ret = unsafe {
328 libc::setsockopt(
329 socket.as_raw_fd(),
330 libc::SOL_SOCKET,
331 libc::SO_MARK,
332 &mark as *const _ as *const _,
333 mem::size_of_val(&mark) as libc::socklen_t,
334 )
335 };
336 if ret != 0 {
337 let err = io::Error::last_os_error();
338 error!("set SO_MARK error: {}", err);
339 return Err(err);
340 }
341 }
342
343 if let Some(ref iface) = config.bind_interface {
345 set_bindtodevice(&socket, iface)?;
346 }
347
348 Ok(socket)
349}
350
351fn set_bindtodevice<S: AsRawFd>(socket: &S, iface: &str) -> io::Result<()> {
352 let iface_bytes = iface.as_bytes();
353
354 unsafe {
355 let ret = libc::setsockopt(
356 socket.as_raw_fd(),
357 libc::SOL_SOCKET,
358 libc::SO_BINDTODEVICE,
359 iface_bytes.as_ptr() as *const _ as *const libc::c_void,
360 iface_bytes.len() as libc::socklen_t,
361 );
362
363 if ret != 0 {
364 let err = io::Error::last_os_error();
365 error!("set SO_BINDTODEVICE error: {}", err);
366 return Err(err);
367 }
368 }
369
370 Ok(())
371}
372
/// Android-only helpers for protecting sockets via the host VPN service.
#[cfg(target_os = "android")]
mod android {
    use std::{
        io::{self, ErrorKind},
        os::unix::io::{AsRawFd, RawFd},
        path::Path,
        time::Duration,
    };
    use tokio::{io::AsyncReadExt, time};

    use super::super::uds::UnixStream;
    use super::ConnectOpts;

    /// How long to wait for the protect service to answer over the Unix socket.
    const PROTECT_TIMEOUT: Duration = Duration::from_secs(3);

    /// Reply byte with which the protect service signals failure.
    const PROTECT_FAILED: u8 = 0xFF;

    /// Sends `fd` to the Unix socket at `protect_path` and waits for a one-byte reply.
    ///
    /// The descriptor is sent with `send_with_fd` alongside a single payload byte.
    /// A reply of `0xFF` is reported as an error; any other byte means success.
    async fn send_vpn_protect_uds<P: AsRef<Path>>(protect_path: P, fd: RawFd) -> io::Result<()> {
        let mut stream = UnixStream::connect(protect_path).await?;

        let payload: [u8; 1] = [1];
        let fds: [RawFd; 1] = [fd];
        stream.send_with_fd(&payload, &fds).await?;

        let mut reply = [0u8; 1];
        stream.read_exact(&mut reply).await?;

        match reply[0] {
            PROTECT_FAILED => Err(io::Error::other("protect() failed")),
            _ => Ok(()),
        }
    }

    /// Asks the configured protect mechanisms in `opts` to protect `socket`.
    ///
    /// Protection happens in two steps:
    /// 1. If `opts.vpn_protect_path` is set, the descriptor is sent over that Unix
    ///    socket, waiting at most 3 seconds.
    /// 2. If `opts.vpn_socket_protect` is set, its `protect` method is called with
    ///    the descriptor.
    ///
    /// # Errors
    /// - Returns `ErrorKind::TimedOut` if the Unix-socket exchange exceeds the timeout.
    /// - Otherwise returns the first error from either protect step.
    pub async fn vpn_protect<S>(socket: &S, opts: &ConnectOpts) -> io::Result<()>
    where
        S: AsRawFd + Send + Sync + 'static,
    {
        let fd = socket.as_raw_fd();

        if let Some(ref path) = opts.vpn_protect_path {
            time::timeout(PROTECT_TIMEOUT, send_vpn_protect_uds(path, fd))
                .await
                .map_err(|_| io::Error::new(ErrorKind::TimedOut, "protect() timeout"))??;
        }

        if let Some(ref protect) = opts.vpn_socket_protect {
            protect.protect(fd)?;
        }

        Ok(())
    }
}
436
/// Whether `recvmmsg`/`sendmmsg` may be used.
///
/// Starts as `true` and is cleared permanently the first time either call fails
/// with `ENOSYS`. From then on, `batch_recvmsg`/`batch_sendmsg` handle one message
/// per call via `recvmsg`/`sendmsg`. Relaxed ordering is used because the flag only
/// selects a code path; it guards no other data.
static SUPPORT_BATCH_SEND_RECV_MSG: AtomicBool = AtomicBool::new(true);
438
439fn recvmsg_fallback<S: AsRawFd>(sock: &S, msg: &mut BatchRecvMessage<'_>) -> io::Result<()> {
440 let mut hdr: libc::msghdr = unsafe { mem::zeroed() };
441
442 let addr_storage: libc::sockaddr_storage = unsafe { mem::zeroed() };
443 let addr_len = mem::size_of_val(&addr_storage) as libc::socklen_t;
444 let sock_addr = unsafe { SockAddr::new(addr_storage, addr_len) };
445 hdr.msg_name = sock_addr.as_ptr() as *mut _;
446 hdr.msg_namelen = sock_addr.len() as _;
447
448 hdr.msg_iov = msg.data.as_ptr() as *mut _;
449 hdr.msg_iovlen = msg.data.len() as _;
450
451 let ret = unsafe { libc::recvmsg(sock.as_raw_fd(), &mut hdr as *mut _, 0) };
452 if ret < 0 {
453 return Err(io::Error::last_os_error());
454 }
455
456 msg.addr = sock_addr.as_socket().expect("SockAddr.as_socket");
457 msg.data_len = ret as usize;
458
459 Ok(())
460}
461
462pub fn batch_recvmsg<S: AsRawFd>(sock: &S, msgs: &mut [BatchRecvMessage<'_>]) -> io::Result<usize> {
463 if msgs.is_empty() {
464 return Ok(0);
465 }
466
467 if !SUPPORT_BATCH_SEND_RECV_MSG.load(Ordering::Relaxed) {
468 recvmsg_fallback(sock, &mut msgs[0])?;
469 return Ok(1);
470 }
471
472 let mut vec_msg_name = Vec::with_capacity(msgs.len());
473 let mut vec_msg_hdr = Vec::with_capacity(msgs.len());
474
475 for msg in msgs.iter_mut() {
476 let mut hdr: libc::mmsghdr = unsafe { mem::zeroed() };
477
478 let addr_storage: libc::sockaddr_storage = unsafe { mem::zeroed() };
479 let addr_len = mem::size_of_val(&addr_storage) as libc::socklen_t;
480
481 vec_msg_name.push(unsafe { SockAddr::new(addr_storage, addr_len) });
482 let sock_addr = vec_msg_name.last_mut().unwrap();
483 hdr.msg_hdr.msg_name = sock_addr.as_ptr() as *mut _;
484 hdr.msg_hdr.msg_namelen = sock_addr.len() as _;
485
486 hdr.msg_hdr.msg_iov = msg.data.as_ptr() as *mut _;
487 hdr.msg_hdr.msg_iovlen = msg.data.len() as _;
488
489 vec_msg_hdr.push(hdr);
490 }
491
492 let ret = unsafe {
493 libc::recvmmsg(
494 sock.as_raw_fd(),
495 vec_msg_hdr.as_mut_ptr(),
496 vec_msg_hdr.len() as _,
497 0,
498 ptr::null_mut(),
499 )
500 };
501 if ret < 0 {
502 let err = io::Error::last_os_error();
503 if let Some(libc::ENOSYS) = err.raw_os_error() {
504 debug!("recvmmsg is not supported, fallback to recvmsg, error: {:?}", err);
505 SUPPORT_BATCH_SEND_RECV_MSG.store(false, Ordering::Relaxed);
506
507 recvmsg_fallback(sock, &mut msgs[0])?;
508 return Ok(1);
509 }
510 return Err(err);
511 }
512
513 for idx in 0..ret as usize {
514 let msg = &mut msgs[idx];
515 let hdr = &vec_msg_hdr[idx];
516 let name = &vec_msg_name[idx];
517 msg.addr = name.as_socket().expect("SockAddr.as_socket");
518 msg.data_len = hdr.msg_len as usize;
519 }
520
521 Ok(ret as usize)
522}
523
524fn sendmsg_fallback<S: AsRawFd>(sock: &S, msg: &mut BatchSendMessage<'_>) -> io::Result<()> {
525 let mut hdr: libc::msghdr = unsafe { mem::zeroed() };
526
527 let sock_addr = msg.addr.map(SockAddr::from);
528 if let Some(ref sa) = sock_addr {
529 hdr.msg_name = sa.as_ptr() as *mut _;
530 hdr.msg_namelen = sa.len() as _;
531 }
532
533 hdr.msg_iov = msg.data.as_ptr() as *mut _;
534 hdr.msg_iovlen = msg.data.len() as _;
535
536 let ret = unsafe { libc::sendmsg(sock.as_raw_fd(), &hdr as *const _, 0) };
537 if ret < 0 {
538 return Err(io::Error::last_os_error());
539 }
540 msg.data_len = ret as usize;
541
542 Ok(())
543}
544
545pub fn batch_sendmsg<S: AsRawFd>(sock: &S, msgs: &mut [BatchSendMessage<'_>]) -> io::Result<usize> {
546 if msgs.is_empty() {
547 return Ok(0);
548 }
549
550 if !SUPPORT_BATCH_SEND_RECV_MSG.load(Ordering::Relaxed) {
551 sendmsg_fallback(sock, &mut msgs[0])?;
552 return Ok(1);
553 }
554
555 let mut vec_msg_name = Vec::with_capacity(msgs.len());
556 let mut vec_msg_hdr = Vec::with_capacity(msgs.len());
557
558 for msg in msgs.iter_mut() {
559 let mut hdr: libc::mmsghdr = unsafe { mem::zeroed() };
560
561 if let Some(addr) = msg.addr {
562 vec_msg_name.push(SockAddr::from(addr));
563 let sock_addr = vec_msg_name.last_mut().unwrap();
564 hdr.msg_hdr.msg_name = sock_addr.as_ptr() as *mut _;
565 hdr.msg_hdr.msg_namelen = sock_addr.len() as _;
566 }
567
568 hdr.msg_hdr.msg_iov = msg.data.as_ptr() as *mut _;
569 hdr.msg_hdr.msg_iovlen = msg.data.len() as _;
570
571 vec_msg_hdr.push(hdr);
572 }
573
574 let ret = unsafe { libc::sendmmsg(sock.as_raw_fd(), vec_msg_hdr.as_mut_ptr(), vec_msg_hdr.len() as _, 0) };
575 if ret < 0 {
576 let err = io::Error::last_os_error();
577 if let Some(libc::ENOSYS) = err.raw_os_error() {
578 debug!("sendmmsg is not supported, fallback to sendmsg, error: {:?}", err);
579 SUPPORT_BATCH_SEND_RECV_MSG.store(false, Ordering::Relaxed);
580
581 sendmsg_fallback(sock, &mut msgs[0])?;
582 return Ok(1);
583 }
584 return Err(err);
585 }
586
587 for idx in 0..ret as usize {
588 let msg = &mut msgs[idx];
589 let hdr = &vec_msg_hdr[idx];
590 msg.data_len = hdr.msg_len as usize;
591 }
592
593 Ok(ret as usize)
594}