Skip to main content

relay_core_lib/capture/
linux_tproxy.rs

1#[cfg(target_os = "linux")]
2use std::io;
3#[cfg(target_os = "linux")]
4use std::os::unix::io::AsRawFd;
5#[cfg(target_os = "linux")]
6use tokio::net::UdpSocket;
7#[cfg(target_os = "linux")]
8use std::net::{SocketAddr, SocketAddrV4, SocketAddrV6, Ipv4Addr, Ipv6Addr};
9#[cfg(target_os = "linux")]
10use std::mem;
11
/// Namespace type for Linux transparent-proxy (TPROXY) UDP helpers.
///
/// The struct holds no data; its `impl` only exposes associated functions
/// that operate on a caller-supplied (or freshly created) UDP socket.
#[cfg(target_os = "linux")]
pub struct LinuxTproxy;
14
15#[cfg(target_os = "linux")]
16impl LinuxTproxy {
17    /// Enable IP_TRANSPARENT and IP_RECVORIGDSTADDR on the socket
18    pub fn enable_tproxy(socket: &UdpSocket) -> io::Result<()> {
19        let fd = socket.as_raw_fd();
20        unsafe {
21            let enable: libc::c_int = 1;
22            
23            // IP_TRANSPARENT (19)
24            if libc::setsockopt(
25                fd,
26                libc::SOL_IP,
27                libc::IP_TRANSPARENT,
28                &enable as *const _ as *const libc::c_void,
29                mem::size_of::<libc::c_int>() as libc::socklen_t,
30            ) < 0 {
31                return Err(io::Error::last_os_error());
32            }
33
34            // IP_RECVORIGDSTADDR (20)
35            if libc::setsockopt(
36                fd,
37                libc::SOL_IP,
38                libc::IP_RECVORIGDSTADDR,
39                &enable as *const _ as *const libc::c_void,
40                mem::size_of::<libc::c_int>() as libc::socklen_t,
41            ) < 0 {
42                return Err(io::Error::last_os_error());
43            }
44            
45            // IPV6_TRANSPARENT (75)
46            // We ignore errors here as IPv6 might be disabled
47            let _ = libc::setsockopt(
48                fd,
49                libc::SOL_IPV6,
50                libc::IPV6_TRANSPARENT,
51                &enable as *const _ as *const libc::c_void,
52                mem::size_of::<libc::c_int>() as libc::socklen_t,
53            );
54
55            // IPV6_RECVORIGDSTADDR (74)
56            let _ = libc::setsockopt(
57                fd,
58                libc::SOL_IPV6,
59                libc::IPV6_RECVORIGDSTADDR,
60                &enable as *const _ as *const libc::c_void,
61                mem::size_of::<libc::c_int>() as libc::socklen_t,
62            );
63        }
64        Ok(())
65    }
66
67    /// Create a UDP socket bound to a specific address with IP_TRANSPARENT enabled.
68    /// This allows binding to non-local addresses (spoofing source IP).
69    pub fn create_transparent_udp_socket(addr: SocketAddr) -> io::Result<UdpSocket> {
70        use std::os::unix::io::FromRawFd;
71
72        unsafe {
73            let (domain, sockaddr, socklen) = match addr {
74                SocketAddr::V4(v4) => {
75                    let mut sin: libc::sockaddr_in = mem::zeroed();
76                    sin.sin_family = libc::AF_INET as libc::sa_family_t;
77                    sin.sin_port = v4.port().to_be();
78                    sin.sin_addr.s_addr = u32::from(*v4.ip()).to_be();
79                    (
80                        libc::AF_INET,
81                        &sin as *const _ as *const libc::sockaddr,
82                        mem::size_of::<libc::sockaddr_in>() as libc::socklen_t,
83                    )
84                },
85                SocketAddr::V6(v6) => {
86                     let mut sin6: libc::sockaddr_in6 = mem::zeroed();
87                     sin6.sin6_family = libc::AF_INET6 as libc::sa_family_t;
88                     sin6.sin6_port = v6.port().to_be();
89                     std::ptr::copy_nonoverlapping(
90                         v6.ip().octets().as_ptr(), 
91                         sin6.sin6_addr.s6_addr.as_mut_ptr(), 
92                         16
93                     );
94                     
95                     (
96                         libc::AF_INET6, 
97                         &sin6 as *const _ as *const libc::sockaddr, 
98                         mem::size_of::<libc::sockaddr_in6>() as libc::socklen_t,
99                     )
100                }
101            };
102
103            let fd = libc::socket(domain, libc::SOCK_DGRAM, 0);
104            if fd < 0 {
105                return Err(io::Error::last_os_error());
106            }
107            
108            // Helper to close fd on error
109            let close_fd = |fd: libc::c_int| {
110                libc::close(fd);
111            };
112
113            // Enable IP_TRANSPARENT
114            let enable: libc::c_int = 1;
115            if libc::setsockopt(
116                fd,
117                libc::SOL_IP,
118                libc::IP_TRANSPARENT,
119                &enable as *const _ as *const libc::c_void,
120                mem::size_of::<libc::c_int>() as libc::socklen_t,
121            ) < 0 {
122                 let err = io::Error::last_os_error();
123                 close_fd(fd);
124                 return Err(err);
125            }
126            
127            if domain == libc::AF_INET6 {
128                 // Ignore IPv6 errors for now as it might not be supported
129                 let _ = libc::setsockopt(
130                    fd,
131                    libc::SOL_IPV6,
132                    libc::IPV6_TRANSPARENT,
133                    &enable as *const _ as *const libc::c_void,
134                    mem::size_of::<libc::c_int>() as libc::socklen_t,
135                );
136            }
137
138            // SO_REUSEADDR
139            if libc::setsockopt(
140                fd,
141                libc::SOL_SOCKET,
142                libc::SO_REUSEADDR,
143                &enable as *const _ as *const libc::c_void,
144                mem::size_of::<libc::c_int>() as libc::socklen_t,
145            ) < 0 {
146                 let err = io::Error::last_os_error();
147                 close_fd(fd);
148                 return Err(err);
149            }
150            
151            // Bind
152            if libc::bind(fd, sockaddr, socklen) < 0 {
153                 let err = io::Error::last_os_error();
154                 close_fd(fd);
155                 return Err(err);
156            }
157            
158            // Non-blocking (Tokio needs this)
159            let flags = libc::fcntl(fd, libc::F_GETFL);
160            if flags < 0 || libc::fcntl(fd, libc::F_SETFL, flags | libc::O_NONBLOCK) < 0 {
161                 let err = io::Error::last_os_error();
162                 close_fd(fd);
163                 return Err(err);
164            }
165
166            let std_socket = std::net::UdpSocket::from_raw_fd(fd);
167            UdpSocket::from_std(std_socket)
168        }
169    }
170
171    /// Receive a packet with original destination address
172    /// Returns (bytes_read, source_addr, original_dest_addr)
173    pub async fn recv_original_dst(socket: &UdpSocket, buf: &mut [u8]) -> io::Result<(usize, SocketAddr, Option<SocketAddr>)> {
174        let fd = socket.as_raw_fd();
175        socket.async_io(tokio::io::Interest::READABLE, || {
176            let mut iov = libc::iovec {
177                iov_base: buf.as_mut_ptr() as *mut libc::c_void,
178                iov_len: buf.len(),
179            };
180            
181            // Buffer for control messages (ancillary data)
182            // Enough space for IPv4 or IPv6 address
183            let mut cmsg_buf = [0u8; 64]; 
184            
185            // Prepare sockaddr storage for source address
186            let mut src_addr: libc::sockaddr_storage = unsafe { mem::zeroed() };
187            
188            let mut msg = libc::msghdr {
189                msg_name: &mut src_addr as *mut _ as *mut libc::c_void,
190                msg_namelen: mem::size_of::<libc::sockaddr_storage>() as libc::socklen_t,
191                msg_iov: &mut iov,
192                msg_iovlen: 1,
193                msg_control: cmsg_buf.as_mut_ptr() as *mut libc::c_void,
194                msg_controllen: cmsg_buf.len(),
195                msg_flags: 0,
196            };
197
198            let n = unsafe { libc::recvmsg(fd, &mut msg, 0) };
199            
200            if n < 0 {
201                return Err(io::Error::last_os_error());
202            }
203
204            let source = unsafe { sockaddr_to_socket_addr(&src_addr)? };
205            let orig_dst = unsafe { parse_orig_dst(&msg) };
206
207            Ok((n as usize, source, orig_dst))
208        }).await
209    }
210}
211
212#[cfg(target_os = "linux")]
213unsafe fn sockaddr_to_socket_addr(storage: &libc::sockaddr_storage) -> io::Result<SocketAddr> {
214    match storage.ss_family as libc::c_int {
215        libc::AF_INET => {
216            let addr: &libc::sockaddr_in = unsafe { mem::transmute(storage) };
217            let ip = Ipv4Addr::from(u32::from_be(addr.sin_addr.s_addr));
218            let port = u16::from_be(addr.sin_port);
219            Ok(SocketAddr::V4(SocketAddrV4::new(ip, port)))
220        }
221        libc::AF_INET6 => {
222            let addr: &libc::sockaddr_in6 = unsafe { mem::transmute(storage) };
223            let ip = Ipv6Addr::from(addr.sin6_addr.s6_addr);
224            let port = u16::from_be(addr.sin6_port);
225            Ok(SocketAddr::V6(SocketAddrV6::new(ip, port, addr.sin6_flowinfo, addr.sin6_scope_id)))
226        }
227        _ => Err(io::Error::new(io::ErrorKind::InvalidData, "Unknown address family")),
228    }
229}
230
231#[cfg(target_os = "linux")]
232unsafe fn parse_orig_dst(msg: &libc::msghdr) -> Option<SocketAddr> {
233    let mut cmsg: *mut libc::cmsghdr = unsafe { libc::CMSG_FIRSTHDR(msg) };
234    
235    while !cmsg.is_null() {
236        unsafe {
237            if (*cmsg).cmsg_level == libc::SOL_IP && (*cmsg).cmsg_type == libc::IP_RECVORIGDSTADDR {
238                let data = libc::CMSG_DATA(cmsg) as *const libc::sockaddr_in;
239                let ip = Ipv4Addr::from(u32::from_be((*data).sin_addr.s_addr));
240                let port = u16::from_be((*data).sin_port);
241                return Some(SocketAddr::V4(SocketAddrV4::new(ip, port)));
242            } else if (*cmsg).cmsg_level == libc::SOL_IPV6 && (*cmsg).cmsg_type == libc::IPV6_RECVORIGDSTADDR {
243                 let data = libc::CMSG_DATA(cmsg) as *const libc::sockaddr_in6;
244                 let ip = Ipv6Addr::from((*data).sin6_addr.s6_addr);
245                 let port = u16::from_be((*data).sin6_port);
246                 return Some(SocketAddr::V6(SocketAddrV6::new(ip, port, (*data).sin6_flowinfo, (*data).sin6_scope_id)));
247            }
248        
249            cmsg = libc::CMSG_NXTHDR(msg, cmsg);
250        }
251    }
252    
253    None
254}