Skip to main content

relay_core_lib/capture/
linux_tproxy.rs

1#[cfg(target_os = "linux")]
2use std::io;
3#[cfg(target_os = "linux")]
4use std::mem;
5#[cfg(target_os = "linux")]
6use std::net::{Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6};
7#[cfg(target_os = "linux")]
8use std::os::unix::io::AsRawFd;
9#[cfg(target_os = "linux")]
10use tokio::net::UdpSocket;
11
#[cfg(target_os = "linux")]
/// Stateless namespace for Linux transparent-proxy (TPROXY) UDP socket helpers.
///
/// All functionality lives in associated functions; the type carries no data.
#[derive(Debug, Clone, Copy)]
pub struct LinuxTproxy;
14
15#[cfg(target_os = "linux")]
16impl LinuxTproxy {
17    /// Enable IP_TRANSPARENT and IP_RECVORIGDSTADDR on the socket
18    pub fn enable_tproxy(socket: &UdpSocket) -> io::Result<()> {
19        let fd = socket.as_raw_fd();
20        unsafe {
21            let enable: libc::c_int = 1;
22
23            // IP_TRANSPARENT (19)
24            if libc::setsockopt(
25                fd,
26                libc::SOL_IP,
27                libc::IP_TRANSPARENT,
28                &enable as *const _ as *const libc::c_void,
29                mem::size_of::<libc::c_int>() as libc::socklen_t,
30            ) < 0
31            {
32                return Err(io::Error::last_os_error());
33            }
34
35            // IP_RECVORIGDSTADDR (20)
36            if libc::setsockopt(
37                fd,
38                libc::SOL_IP,
39                libc::IP_RECVORIGDSTADDR,
40                &enable as *const _ as *const libc::c_void,
41                mem::size_of::<libc::c_int>() as libc::socklen_t,
42            ) < 0
43            {
44                return Err(io::Error::last_os_error());
45            }
46
47            // IPV6_TRANSPARENT (75)
48            // We ignore errors here as IPv6 might be disabled
49            let _ = libc::setsockopt(
50                fd,
51                libc::SOL_IPV6,
52                libc::IPV6_TRANSPARENT,
53                &enable as *const _ as *const libc::c_void,
54                mem::size_of::<libc::c_int>() as libc::socklen_t,
55            );
56
57            // IPV6_RECVORIGDSTADDR (74)
58            let _ = libc::setsockopt(
59                fd,
60                libc::SOL_IPV6,
61                libc::IPV6_RECVORIGDSTADDR,
62                &enable as *const _ as *const libc::c_void,
63                mem::size_of::<libc::c_int>() as libc::socklen_t,
64            );
65        }
66        Ok(())
67    }
68
69    /// Create a UDP socket bound to a specific address with IP_TRANSPARENT enabled.
70    /// This allows binding to non-local addresses (spoofing source IP).
71    pub fn create_transparent_udp_socket(addr: SocketAddr) -> io::Result<UdpSocket> {
72        use std::os::unix::io::FromRawFd;
73
74        unsafe {
75            let (domain, sockaddr, socklen) = match addr {
76                SocketAddr::V4(v4) => {
77                    let mut sin: libc::sockaddr_in = mem::zeroed();
78                    sin.sin_family = libc::AF_INET as libc::sa_family_t;
79                    sin.sin_port = v4.port().to_be();
80                    sin.sin_addr.s_addr = u32::from(*v4.ip()).to_be();
81                    (
82                        libc::AF_INET,
83                        &sin as *const _ as *const libc::sockaddr,
84                        mem::size_of::<libc::sockaddr_in>() as libc::socklen_t,
85                    )
86                }
87                SocketAddr::V6(v6) => {
88                    let mut sin6: libc::sockaddr_in6 = mem::zeroed();
89                    sin6.sin6_family = libc::AF_INET6 as libc::sa_family_t;
90                    sin6.sin6_port = v6.port().to_be();
91                    std::ptr::copy_nonoverlapping(
92                        v6.ip().octets().as_ptr(),
93                        sin6.sin6_addr.s6_addr.as_mut_ptr(),
94                        16,
95                    );
96
97                    (
98                        libc::AF_INET6,
99                        &sin6 as *const _ as *const libc::sockaddr,
100                        mem::size_of::<libc::sockaddr_in6>() as libc::socklen_t,
101                    )
102                }
103            };
104
105            let fd = libc::socket(domain, libc::SOCK_DGRAM, 0);
106            if fd < 0 {
107                return Err(io::Error::last_os_error());
108            }
109
110            // Helper to close fd on error
111            let close_fd = |fd: libc::c_int| {
112                libc::close(fd);
113            };
114
115            // Enable IP_TRANSPARENT
116            let enable: libc::c_int = 1;
117            if libc::setsockopt(
118                fd,
119                libc::SOL_IP,
120                libc::IP_TRANSPARENT,
121                &enable as *const _ as *const libc::c_void,
122                mem::size_of::<libc::c_int>() as libc::socklen_t,
123            ) < 0
124            {
125                let err = io::Error::last_os_error();
126                close_fd(fd);
127                return Err(err);
128            }
129
130            if domain == libc::AF_INET6 {
131                // Ignore IPv6 errors for now as it might not be supported
132                let _ = libc::setsockopt(
133                    fd,
134                    libc::SOL_IPV6,
135                    libc::IPV6_TRANSPARENT,
136                    &enable as *const _ as *const libc::c_void,
137                    mem::size_of::<libc::c_int>() as libc::socklen_t,
138                );
139            }
140
141            // SO_REUSEADDR
142            if libc::setsockopt(
143                fd,
144                libc::SOL_SOCKET,
145                libc::SO_REUSEADDR,
146                &enable as *const _ as *const libc::c_void,
147                mem::size_of::<libc::c_int>() as libc::socklen_t,
148            ) < 0
149            {
150                let err = io::Error::last_os_error();
151                close_fd(fd);
152                return Err(err);
153            }
154
155            // Bind
156            if libc::bind(fd, sockaddr, socklen) < 0 {
157                let err = io::Error::last_os_error();
158                close_fd(fd);
159                return Err(err);
160            }
161
162            // Non-blocking (Tokio needs this)
163            let flags = libc::fcntl(fd, libc::F_GETFL);
164            if flags < 0 || libc::fcntl(fd, libc::F_SETFL, flags | libc::O_NONBLOCK) < 0 {
165                let err = io::Error::last_os_error();
166                close_fd(fd);
167                return Err(err);
168            }
169
170            let std_socket = std::net::UdpSocket::from_raw_fd(fd);
171            UdpSocket::from_std(std_socket)
172        }
173    }
174
175    /// Receive a packet with original destination address
176    /// Returns (bytes_read, source_addr, original_dest_addr)
177    pub async fn recv_original_dst(
178        socket: &UdpSocket,
179        buf: &mut [u8],
180    ) -> io::Result<(usize, SocketAddr, Option<SocketAddr>)> {
181        let fd = socket.as_raw_fd();
182        socket
183            .async_io(tokio::io::Interest::READABLE, || {
184                let mut iov = libc::iovec {
185                    iov_base: buf.as_mut_ptr() as *mut libc::c_void,
186                    iov_len: buf.len(),
187                };
188
189                // Buffer for control messages (ancillary data)
190                // Enough space for IPv4 or IPv6 address
191                let mut cmsg_buf = [0u8; 64];
192
193                // Prepare sockaddr storage for source address
194                let mut src_addr: libc::sockaddr_storage = unsafe { mem::zeroed() };
195
196                let mut msg = libc::msghdr {
197                    msg_name: &mut src_addr as *mut _ as *mut libc::c_void,
198                    msg_namelen: mem::size_of::<libc::sockaddr_storage>() as libc::socklen_t,
199                    msg_iov: &mut iov,
200                    msg_iovlen: 1,
201                    msg_control: cmsg_buf.as_mut_ptr() as *mut libc::c_void,
202                    msg_controllen: cmsg_buf.len(),
203                    msg_flags: 0,
204                };
205
206                let n = unsafe { libc::recvmsg(fd, &mut msg, 0) };
207
208                if n < 0 {
209                    return Err(io::Error::last_os_error());
210                }
211
212                let source = unsafe { sockaddr_to_socket_addr(&src_addr)? };
213                let orig_dst = unsafe { parse_orig_dst(&msg) };
214
215                Ok((n as usize, source, orig_dst))
216            })
217            .await
218    }
219}
220
221#[cfg(target_os = "linux")]
222unsafe fn sockaddr_to_socket_addr(storage: &libc::sockaddr_storage) -> io::Result<SocketAddr> {
223    match storage.ss_family as libc::c_int {
224        libc::AF_INET => {
225            let addr: &libc::sockaddr_in = unsafe { mem::transmute(storage) };
226            let ip = Ipv4Addr::from(u32::from_be(addr.sin_addr.s_addr));
227            let port = u16::from_be(addr.sin_port);
228            Ok(SocketAddr::V4(SocketAddrV4::new(ip, port)))
229        }
230        libc::AF_INET6 => {
231            let addr: &libc::sockaddr_in6 = unsafe { mem::transmute(storage) };
232            let ip = Ipv6Addr::from(addr.sin6_addr.s6_addr);
233            let port = u16::from_be(addr.sin6_port);
234            Ok(SocketAddr::V6(SocketAddrV6::new(
235                ip,
236                port,
237                addr.sin6_flowinfo,
238                addr.sin6_scope_id,
239            )))
240        }
241        _ => Err(io::Error::new(
242            io::ErrorKind::InvalidData,
243            "Unknown address family",
244        )),
245    }
246}
247
248#[cfg(target_os = "linux")]
249unsafe fn parse_orig_dst(msg: &libc::msghdr) -> Option<SocketAddr> {
250    let mut cmsg: *mut libc::cmsghdr = unsafe { libc::CMSG_FIRSTHDR(msg) };
251
252    while !cmsg.is_null() {
253        unsafe {
254            if (*cmsg).cmsg_level == libc::SOL_IP && (*cmsg).cmsg_type == libc::IP_RECVORIGDSTADDR {
255                let data = libc::CMSG_DATA(cmsg) as *const libc::sockaddr_in;
256                let ip = Ipv4Addr::from(u32::from_be((*data).sin_addr.s_addr));
257                let port = u16::from_be((*data).sin_port);
258                return Some(SocketAddr::V4(SocketAddrV4::new(ip, port)));
259            } else if (*cmsg).cmsg_level == libc::SOL_IPV6
260                && (*cmsg).cmsg_type == libc::IPV6_RECVORIGDSTADDR
261            {
262                let data = libc::CMSG_DATA(cmsg) as *const libc::sockaddr_in6;
263                let ip = Ipv6Addr::from((*data).sin6_addr.s6_addr);
264                let port = u16::from_be((*data).sin6_port);
265                return Some(SocketAddr::V6(SocketAddrV6::new(
266                    ip,
267                    port,
268                    (*data).sin6_flowinfo,
269                    (*data).sin6_scope_id,
270                )));
271            }
272
273            cmsg = libc::CMSG_NXTHDR(msg, cmsg);
274        }
275    }
276
277    None
278}