shadowsocks/net/sys/unix/linux/
mod.rs

1use std::{
2    io::{self, ErrorKind},
3    mem,
4    net::{Ipv4Addr, Ipv6Addr, SocketAddr},
5    os::unix::io::{AsRawFd, FromRawFd, IntoRawFd, RawFd},
6    pin::Pin,
7    ptr,
8    sync::atomic::{AtomicBool, Ordering},
9    task::{self, Poll},
10};
11
12use log::{debug, error, warn};
13use pin_project::pin_project;
14use socket2::{Domain, Protocol, SockAddr, Socket, Type};
15use tokio::{
16    io::{AsyncRead, AsyncWrite, ReadBuf},
17    net::{TcpSocket, TcpStream as TokioTcpStream, UdpSocket},
18};
19use tokio_tfo::TfoStream;
20
21use crate::net::{
22    AcceptOpts, AddrFamily, ConnectOpts,
23    sys::{set_common_sockopt_after_connect, set_common_sockopt_for_connect, socket_bind_dual_stack},
24    udp::{BatchRecvMessage, BatchSendMessage},
25};
26
/// A `TcpStream` that supports TFO (TCP Fast Open)
///
/// `TcpStream::connect` yields `FastOpen` when `opts.tcp.fastopen` is set and `Standard`
/// otherwise. Both variants are pin-projected (`TcpStreamProj`) so the `AsyncRead` /
/// `AsyncWrite` implementations can forward polls to the inner stream.
#[pin_project(project = TcpStreamProj)]
pub enum TcpStream {
    /// Plain tokio `TcpStream` (TFO disabled).
    Standard(#[pin] TokioTcpStream),
    /// `tokio_tfo::TfoStream` (TFO enabled).
    FastOpen(#[pin] TfoStream),
}
33
34impl TcpStream {
35    pub async fn connect(addr: SocketAddr, opts: &ConnectOpts) -> io::Result<Self> {
36        if opts.tcp.mptcp {
37            return Self::connect_mptcp(addr, opts).await;
38        }
39
40        let socket = match addr {
41            SocketAddr::V4(..) => TcpSocket::new_v4()?,
42            SocketAddr::V6(..) => TcpSocket::new_v6()?,
43        };
44
45        Self::connect_with_socket(socket, addr, opts).await
46    }
47
48    async fn connect_mptcp(addr: SocketAddr, opts: &ConnectOpts) -> io::Result<Self> {
49        let socket = create_mptcp_socket(&addr)?;
50        Self::connect_with_socket(socket, addr, opts).await
51    }
52
53    async fn connect_with_socket(socket: TcpSocket, addr: SocketAddr, opts: &ConnectOpts) -> io::Result<Self> {
54        // Any traffic to localhost should not be protected
55        // This is a workaround for VPNService
56        #[cfg(target_os = "android")]
57        if !addr.ip().is_loopback() {
58            android::vpn_protect(&socket, opts).await?;
59        }
60
61        // Set SO_MARK for mark-based routing on Linux (since 2.6.25)
62        // NOTE: This will require CAP_NET_ADMIN capability (root in most cases)
63        if let Some(mark) = opts.fwmark {
64            let ret = unsafe {
65                libc::setsockopt(
66                    socket.as_raw_fd(),
67                    libc::SOL_SOCKET,
68                    libc::SO_MARK,
69                    &mark as *const _ as *const _,
70                    mem::size_of_val(&mark) as libc::socklen_t,
71                )
72            };
73            if ret != 0 {
74                let err = io::Error::last_os_error();
75                error!("set SO_MARK error: {}", err);
76                return Err(err);
77            }
78        }
79
80        // Set SO_BINDTODEVICE for binding to a specific interface
81        if let Some(ref iface) = opts.bind_interface {
82            set_bindtodevice(&socket, iface)?;
83        }
84
85        set_common_sockopt_for_connect(addr, &socket, opts)?;
86
87        if !opts.tcp.fastopen {
88            // If TFO is not enabled, it just works like a normal TcpStream
89            let stream = socket.connect(addr).await?;
90            set_common_sockopt_after_connect(&stream, opts)?;
91
92            return Ok(Self::Standard(stream));
93        }
94
95        let stream = TfoStream::connect_with_socket(socket, addr).await?;
96        set_common_sockopt_after_connect(&stream, opts)?;
97
98        Ok(Self::FastOpen(stream))
99    }
100
101    pub fn local_addr(&self) -> io::Result<SocketAddr> {
102        match *self {
103            Self::Standard(ref s) => s.local_addr(),
104            Self::FastOpen(ref s) => s.local_addr(),
105        }
106    }
107
108    pub fn peer_addr(&self) -> io::Result<SocketAddr> {
109        match *self {
110            Self::Standard(ref s) => s.peer_addr(),
111            Self::FastOpen(ref s) => s.peer_addr(),
112        }
113    }
114
115    pub fn nodelay(&self) -> io::Result<bool> {
116        match *self {
117            Self::Standard(ref s) => s.nodelay(),
118            Self::FastOpen(ref s) => s.nodelay(),
119        }
120    }
121
122    pub fn set_nodelay(&self, nodelay: bool) -> io::Result<()> {
123        match *self {
124            Self::Standard(ref s) => s.set_nodelay(nodelay),
125            Self::FastOpen(ref s) => s.set_nodelay(nodelay),
126        }
127    }
128}
129
130impl AsRawFd for TcpStream {
131    fn as_raw_fd(&self) -> RawFd {
132        match *self {
133            Self::Standard(ref s) => s.as_raw_fd(),
134            Self::FastOpen(ref s) => s.as_raw_fd(),
135        }
136    }
137}
138
139impl AsyncRead for TcpStream {
140    fn poll_read(self: Pin<&mut Self>, cx: &mut task::Context<'_>, buf: &mut ReadBuf<'_>) -> Poll<io::Result<()>> {
141        match self.project() {
142            TcpStreamProj::Standard(s) => s.poll_read(cx, buf),
143            TcpStreamProj::FastOpen(s) => s.poll_read(cx, buf),
144        }
145    }
146}
147
148impl AsyncWrite for TcpStream {
149    fn poll_write(self: Pin<&mut Self>, cx: &mut task::Context<'_>, buf: &[u8]) -> Poll<io::Result<usize>> {
150        match self.project() {
151            TcpStreamProj::Standard(s) => s.poll_write(cx, buf),
152            TcpStreamProj::FastOpen(s) => s.poll_write(cx, buf),
153        }
154    }
155
156    fn poll_flush(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<io::Result<()>> {
157        match self.project() {
158            TcpStreamProj::Standard(s) => s.poll_flush(cx),
159            TcpStreamProj::FastOpen(s) => s.poll_flush(cx),
160        }
161    }
162
163    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<io::Result<()>> {
164        match self.project() {
165            TcpStreamProj::Standard(s) => s.poll_shutdown(cx),
166            TcpStreamProj::FastOpen(s) => s.poll_shutdown(cx),
167        }
168    }
169}
170
171/// Enable `TCP_FASTOPEN`
172///
173/// `TCP_FASTOPEN` was supported since Linux 3.7
174pub fn set_tcp_fastopen<S: AsRawFd>(socket: &S) -> io::Result<()> {
175    // https://lwn.net/Articles/508865/
176    //
177    // The option value, qlen, specifies this server's limit on the size of the queue of TFO requests that have
178    // not yet completed the three-way handshake (see the remarks on prevention of resource-exhaustion attacks above).
179    //
180    // It was recommended to be `5` in this document.
181    //
182    // But since mio's TcpListener sets backlogs to 1024, it would be nice to have 1024 slots for handshaking TFO requests.
183    let queue: libc::c_int = 1024;
184
185    unsafe {
186        let ret = libc::setsockopt(
187            socket.as_raw_fd(),
188            libc::IPPROTO_TCP,
189            libc::TCP_FASTOPEN,
190            &queue as *const _ as *const libc::c_void,
191            mem::size_of_val(&queue) as libc::socklen_t,
192        );
193
194        if ret != 0 {
195            let err = io::Error::last_os_error();
196            error!("set TCP_FASTOPEN error: {}", err);
197            return Err(err);
198        }
199    }
200
201    Ok(())
202}
203
204fn create_mptcp_socket(bind_addr: &SocketAddr) -> io::Result<TcpSocket> {
205    // https://www.kernel.org/doc/html/next/networking/mptcp.html
206
207    unsafe {
208        let family = match bind_addr {
209            SocketAddr::V4(..) => libc::AF_INET,
210            SocketAddr::V6(..) => libc::AF_INET6,
211        };
212        let fd = libc::socket(family, libc::SOCK_STREAM, libc::IPPROTO_MPTCP);
213        if fd < 0 {
214            let err = io::Error::last_os_error();
215            return Err(err);
216        }
217        let socket = Socket::from_raw_fd(fd);
218        socket.set_nonblocking(true)?;
219        Ok(TcpSocket::from_raw_fd(socket.into_raw_fd()))
220    }
221}
222
223/// Create a TCP socket for listening
224pub async fn create_inbound_tcp_socket(bind_addr: &SocketAddr, accept_opts: &AcceptOpts) -> io::Result<TcpSocket> {
225    if accept_opts.tcp.mptcp {
226        create_mptcp_socket(bind_addr)
227    } else {
228        match bind_addr {
229            SocketAddr::V4(..) => TcpSocket::new_v4(),
230            SocketAddr::V6(..) => TcpSocket::new_v6(),
231        }
232    }
233}
234
235/// Disable IP fragmentation
236#[inline]
237pub fn set_disable_ip_fragmentation<S: AsRawFd>(af: AddrFamily, socket: &S) -> io::Result<()> {
238    // For Linux, IP_MTU_DISCOVER should be enabled for both IPv4 and IPv6 sockets
239    // https://man7.org/linux/man-pages/man7/ip.7.html
240
241    unsafe {
242        let value: i32 = libc::IP_PMTUDISC_DO;
243        let ret = libc::setsockopt(
244            socket.as_raw_fd(),
245            libc::IPPROTO_IP,
246            libc::IP_MTU_DISCOVER,
247            &value as *const _ as *const _,
248            mem::size_of_val(&value) as libc::socklen_t,
249        );
250
251        if ret < 0 {
252            return Err(io::Error::last_os_error());
253        }
254
255        if af == AddrFamily::Ipv6 {
256            let value: i32 = libc::IP_PMTUDISC_DO;
257            let ret = libc::setsockopt(
258                socket.as_raw_fd(),
259                libc::IPPROTO_IPV6,
260                libc::IPV6_MTU_DISCOVER,
261                &value as *const _ as *const _,
262                mem::size_of_val(&value) as libc::socklen_t,
263            );
264
265            if ret < 0 {
266                return Err(io::Error::last_os_error());
267            }
268        }
269    }
270
271    Ok(())
272}
273
274/// Create a `UdpSocket` with specific address family
275#[inline]
276pub async fn create_outbound_udp_socket(af: AddrFamily, config: &ConnectOpts) -> io::Result<UdpSocket> {
277    let bind_addr = match (af, config.bind_local_addr) {
278        (AddrFamily::Ipv4, Some(SocketAddr::V4(addr))) => addr.into(),
279        (AddrFamily::Ipv4, Some(SocketAddr::V6(addr))) => {
280            // Map IPv6 bind_local_addr to IPv4 if AF is IPv4
281            match addr.ip().to_ipv4_mapped() {
282                Some(addr) => SocketAddr::new(addr.into(), 0),
283                None => return Err(io::Error::new(ErrorKind::InvalidInput, "Invalid IPv6 address")),
284            }
285        }
286        (AddrFamily::Ipv6, Some(SocketAddr::V6(addr))) => addr.into(),
287        (AddrFamily::Ipv6, Some(SocketAddr::V4(addr))) => {
288            // Map IPv4 bind_local_addr to IPv6 if AF is IPv6
289            SocketAddr::new(addr.ip().to_ipv6_mapped().into(), 0)
290        }
291        (AddrFamily::Ipv4, ..) => SocketAddr::new(Ipv4Addr::UNSPECIFIED.into(), 0),
292        (AddrFamily::Ipv6, ..) => SocketAddr::new(Ipv6Addr::UNSPECIFIED.into(), 0),
293    };
294
295    bind_outbound_udp_socket(&bind_addr, config).await
296}
297
298/// Create a `UdpSocket` binded to `bind_addr`
299pub async fn bind_outbound_udp_socket(bind_addr: &SocketAddr, config: &ConnectOpts) -> io::Result<UdpSocket> {
300    let af = AddrFamily::from(bind_addr);
301
302    let socket = if af != AddrFamily::Ipv6 {
303        UdpSocket::bind(bind_addr).await?
304    } else {
305        let socket = Socket::new(Domain::for_address(*bind_addr), Type::DGRAM, Some(Protocol::UDP))?;
306        socket_bind_dual_stack(&socket, bind_addr, false)?;
307
308        // UdpSocket::from_std requires socket to be non-blocked
309        socket.set_nonblocking(true)?;
310        UdpSocket::from_std(socket.into())?
311    };
312
313    if !config.udp.allow_fragmentation {
314        if let Err(err) = set_disable_ip_fragmentation(af, &socket) {
315            warn!("failed to disable IP fragmentation, error: {}", err);
316        }
317    }
318
319    // Any traffic except localhost should be protected
320    // This is a workaround for VPNService
321    #[cfg(target_os = "android")]
322    android::vpn_protect(&socket, config).await?;
323
324    // Set SO_MARK for mark-based routing on Linux (since 2.6.25)
325    // NOTE: This will require CAP_NET_ADMIN capability (root in most cases)
326    if let Some(mark) = config.fwmark {
327        let ret = unsafe {
328            libc::setsockopt(
329                socket.as_raw_fd(),
330                libc::SOL_SOCKET,
331                libc::SO_MARK,
332                &mark as *const _ as *const _,
333                mem::size_of_val(&mark) as libc::socklen_t,
334            )
335        };
336        if ret != 0 {
337            let err = io::Error::last_os_error();
338            error!("set SO_MARK error: {}", err);
339            return Err(err);
340        }
341    }
342
343    // Set SO_BINDTODEVICE for binding to a specific interface
344    if let Some(ref iface) = config.bind_interface {
345        set_bindtodevice(&socket, iface)?;
346    }
347
348    Ok(socket)
349}
350
351fn set_bindtodevice<S: AsRawFd>(socket: &S, iface: &str) -> io::Result<()> {
352    let iface_bytes = iface.as_bytes();
353
354    unsafe {
355        let ret = libc::setsockopt(
356            socket.as_raw_fd(),
357            libc::SOL_SOCKET,
358            libc::SO_BINDTODEVICE,
359            iface_bytes.as_ptr() as *const _ as *const libc::c_void,
360            iface_bytes.len() as libc::socklen_t,
361        );
362
363        if ret != 0 {
364            let err = io::Error::last_os_error();
365            error!("set SO_BINDTODEVICE error: {}", err);
366            return Err(err);
367        }
368    }
369
370    Ok(())
371}
372
#[cfg(target_os = "android")]
mod android {
    use std::{
        io::{self, ErrorKind},
        os::unix::io::{AsRawFd, RawFd},
        path::Path,
        time::Duration,
    };
    use tokio::{io::AsyncReadExt, time};

    use super::super::uds::UnixStream;
    use super::ConnectOpts;

    /// How long to wait for the `protect()` RPC (3 seconds, like shadowsocks-libev).
    const PROTECT_TIMEOUT: Duration = Duration::from_secs(3);

    /// RPC asking Android to `protect()` the socket `fd` so its traffic bypasses the VPN.
    ///
    /// https://developer.android.com/reference/android/net/VpnService#protect(java.net.Socket)
    ///
    /// The descriptor is sent over the Unix domain socket at `protect_path` together with one
    /// dummy byte. The peer answers with a single status byte, where `0xFF` means failure.
    /// More detail could be found in [shadowsocks-android](https://github.com/shadowsocks/shadowsocks-android) project.
    async fn send_vpn_protect_uds<P: AsRef<Path>>(protect_path: P, fd: RawFd) -> io::Result<()> {
        let mut channel = UnixStream::connect(protect_path).await?;

        // One dummy payload byte carries the descriptor as ancillary data.
        let payload: [u8; 1] = [1];
        let passed_fds: [RawFd; 1] = [fd];
        channel.send_with_fd(&payload, &passed_fds).await?;

        let mut status = [0u8; 1];
        channel.read_exact(&mut status).await?;

        match status[0] {
            0xFF => Err(io::Error::other("protect() failed")),
            _ => Ok(()),
        }
    }

    /// Runs VPNService#protect on Android for `socket`, if configured.
    ///
    /// https://developer.android.com/reference/android/net/VpnService#protect(java.net.Socket)
    ///
    /// Two mechanisms are tried in order, each only when present in `opts`:
    /// 1. the Unix-domain-socket RPC at `opts.vpn_protect_path` (shadowsocks-android), bounded by
    ///    a 3 second timeout that maps to `ErrorKind::TimedOut`;
    /// 2. the custom `opts.vpn_socket_protect` hook.
    pub async fn vpn_protect<S>(socket: &S, opts: &ConnectOpts) -> io::Result<()>
    where
        S: AsRawFd + Send + Sync + 'static,
    {
        let fd = socket.as_raw_fd();

        if let Some(path) = &opts.vpn_protect_path {
            time::timeout(PROTECT_TIMEOUT, send_vpn_protect_uds(path, fd))
                .await
                .map_err(|_| io::Error::new(ErrorKind::TimedOut, "protect() timeout"))??;
        }

        if let Some(protect) = &opts.vpn_socket_protect {
            protect.protect(fd)?;
        }

        Ok(())
    }
}
436
/// Whether `recvmmsg(2)` / `sendmmsg(2)` are usable.
///
/// Starts as `true` and is cleared process-wide the first time either call fails with `ENOSYS`.
/// After that, `batch_recvmsg` / `batch_sendmsg` go straight to the single-message
/// `recvmsg` / `sendmsg` fallbacks. `Relaxed` ordering suffices: it is a standalone hint flag.
static SUPPORT_BATCH_SEND_RECV_MSG: AtomicBool = AtomicBool::new(true);
438
439fn recvmsg_fallback<S: AsRawFd>(sock: &S, msg: &mut BatchRecvMessage<'_>) -> io::Result<()> {
440    let mut hdr: libc::msghdr = unsafe { mem::zeroed() };
441
442    let addr_storage: libc::sockaddr_storage = unsafe { mem::zeroed() };
443    let addr_len = mem::size_of_val(&addr_storage) as libc::socklen_t;
444    let sock_addr = unsafe { SockAddr::new(addr_storage, addr_len) };
445    hdr.msg_name = sock_addr.as_ptr() as *mut _;
446    hdr.msg_namelen = sock_addr.len() as _;
447
448    hdr.msg_iov = msg.data.as_ptr() as *mut _;
449    hdr.msg_iovlen = msg.data.len() as _;
450
451    let ret = unsafe { libc::recvmsg(sock.as_raw_fd(), &mut hdr as *mut _, 0) };
452    if ret < 0 {
453        return Err(io::Error::last_os_error());
454    }
455
456    msg.addr = sock_addr.as_socket().expect("SockAddr.as_socket");
457    msg.data_len = ret as usize;
458
459    Ok(())
460}
461
462pub fn batch_recvmsg<S: AsRawFd>(sock: &S, msgs: &mut [BatchRecvMessage<'_>]) -> io::Result<usize> {
463    if msgs.is_empty() {
464        return Ok(0);
465    }
466
467    if !SUPPORT_BATCH_SEND_RECV_MSG.load(Ordering::Relaxed) {
468        recvmsg_fallback(sock, &mut msgs[0])?;
469        return Ok(1);
470    }
471
472    let mut vec_msg_name = Vec::with_capacity(msgs.len());
473    let mut vec_msg_hdr = Vec::with_capacity(msgs.len());
474
475    for msg in msgs.iter_mut() {
476        let mut hdr: libc::mmsghdr = unsafe { mem::zeroed() };
477
478        let addr_storage: libc::sockaddr_storage = unsafe { mem::zeroed() };
479        let addr_len = mem::size_of_val(&addr_storage) as libc::socklen_t;
480
481        vec_msg_name.push(unsafe { SockAddr::new(addr_storage, addr_len) });
482        let sock_addr = vec_msg_name.last_mut().unwrap();
483        hdr.msg_hdr.msg_name = sock_addr.as_ptr() as *mut _;
484        hdr.msg_hdr.msg_namelen = sock_addr.len() as _;
485
486        hdr.msg_hdr.msg_iov = msg.data.as_ptr() as *mut _;
487        hdr.msg_hdr.msg_iovlen = msg.data.len() as _;
488
489        vec_msg_hdr.push(hdr);
490    }
491
492    let ret = unsafe {
493        libc::recvmmsg(
494            sock.as_raw_fd(),
495            vec_msg_hdr.as_mut_ptr(),
496            vec_msg_hdr.len() as _,
497            0,
498            ptr::null_mut(),
499        )
500    };
501    if ret < 0 {
502        let err = io::Error::last_os_error();
503        if let Some(libc::ENOSYS) = err.raw_os_error() {
504            debug!("recvmmsg is not supported, fallback to recvmsg, error: {:?}", err);
505            SUPPORT_BATCH_SEND_RECV_MSG.store(false, Ordering::Relaxed);
506
507            recvmsg_fallback(sock, &mut msgs[0])?;
508            return Ok(1);
509        }
510        return Err(err);
511    }
512
513    for idx in 0..ret as usize {
514        let msg = &mut msgs[idx];
515        let hdr = &vec_msg_hdr[idx];
516        let name = &vec_msg_name[idx];
517        msg.addr = name.as_socket().expect("SockAddr.as_socket");
518        msg.data_len = hdr.msg_len as usize;
519    }
520
521    Ok(ret as usize)
522}
523
524fn sendmsg_fallback<S: AsRawFd>(sock: &S, msg: &mut BatchSendMessage<'_>) -> io::Result<()> {
525    let mut hdr: libc::msghdr = unsafe { mem::zeroed() };
526
527    let sock_addr = msg.addr.map(SockAddr::from);
528    if let Some(ref sa) = sock_addr {
529        hdr.msg_name = sa.as_ptr() as *mut _;
530        hdr.msg_namelen = sa.len() as _;
531    }
532
533    hdr.msg_iov = msg.data.as_ptr() as *mut _;
534    hdr.msg_iovlen = msg.data.len() as _;
535
536    let ret = unsafe { libc::sendmsg(sock.as_raw_fd(), &hdr as *const _, 0) };
537    if ret < 0 {
538        return Err(io::Error::last_os_error());
539    }
540    msg.data_len = ret as usize;
541
542    Ok(())
543}
544
545pub fn batch_sendmsg<S: AsRawFd>(sock: &S, msgs: &mut [BatchSendMessage<'_>]) -> io::Result<usize> {
546    if msgs.is_empty() {
547        return Ok(0);
548    }
549
550    if !SUPPORT_BATCH_SEND_RECV_MSG.load(Ordering::Relaxed) {
551        sendmsg_fallback(sock, &mut msgs[0])?;
552        return Ok(1);
553    }
554
555    let mut vec_msg_name = Vec::with_capacity(msgs.len());
556    let mut vec_msg_hdr = Vec::with_capacity(msgs.len());
557
558    for msg in msgs.iter_mut() {
559        let mut hdr: libc::mmsghdr = unsafe { mem::zeroed() };
560
561        if let Some(addr) = msg.addr {
562            vec_msg_name.push(SockAddr::from(addr));
563            let sock_addr = vec_msg_name.last_mut().unwrap();
564            hdr.msg_hdr.msg_name = sock_addr.as_ptr() as *mut _;
565            hdr.msg_hdr.msg_namelen = sock_addr.len() as _;
566        }
567
568        hdr.msg_hdr.msg_iov = msg.data.as_ptr() as *mut _;
569        hdr.msg_hdr.msg_iovlen = msg.data.len() as _;
570
571        vec_msg_hdr.push(hdr);
572    }
573
574    let ret = unsafe { libc::sendmmsg(sock.as_raw_fd(), vec_msg_hdr.as_mut_ptr(), vec_msg_hdr.len() as _, 0) };
575    if ret < 0 {
576        let err = io::Error::last_os_error();
577        if let Some(libc::ENOSYS) = err.raw_os_error() {
578            debug!("sendmmsg is not supported, fallback to sendmsg, error: {:?}", err);
579            SUPPORT_BATCH_SEND_RECV_MSG.store(false, Ordering::Relaxed);
580
581            sendmsg_fallback(sock, &mut msgs[0])?;
582            return Ok(1);
583        }
584        return Err(err);
585    }
586
587    for idx in 0..ret as usize {
588        let msg = &mut msgs[idx];
589        let hdr = &vec_msg_hdr[idx];
590        msg.data_len = hdr.msg_len as usize;
591    }
592
593    Ok(ret as usize)
594}