relay_core_lib/capture/
linux_tproxy.rs1#[cfg(target_os = "linux")]
2use std::io;
3#[cfg(target_os = "linux")]
4use std::os::unix::io::AsRawFd;
5#[cfg(target_os = "linux")]
6use tokio::net::UdpSocket;
7#[cfg(target_os = "linux")]
8use std::net::{SocketAddr, SocketAddrV4, SocketAddrV6, Ipv4Addr, Ipv6Addr};
9#[cfg(target_os = "linux")]
10use std::mem;
11
12#[cfg(target_os = "linux")]
13pub struct LinuxTproxy;
14
15#[cfg(target_os = "linux")]
16impl LinuxTproxy {
17 pub fn enable_tproxy(socket: &UdpSocket) -> io::Result<()> {
19 let fd = socket.as_raw_fd();
20 unsafe {
21 let enable: libc::c_int = 1;
22
23 if libc::setsockopt(
25 fd,
26 libc::SOL_IP,
27 libc::IP_TRANSPARENT,
28 &enable as *const _ as *const libc::c_void,
29 mem::size_of::<libc::c_int>() as libc::socklen_t,
30 ) < 0 {
31 return Err(io::Error::last_os_error());
32 }
33
34 if libc::setsockopt(
36 fd,
37 libc::SOL_IP,
38 libc::IP_RECVORIGDSTADDR,
39 &enable as *const _ as *const libc::c_void,
40 mem::size_of::<libc::c_int>() as libc::socklen_t,
41 ) < 0 {
42 return Err(io::Error::last_os_error());
43 }
44
45 let _ = libc::setsockopt(
48 fd,
49 libc::SOL_IPV6,
50 libc::IPV6_TRANSPARENT,
51 &enable as *const _ as *const libc::c_void,
52 mem::size_of::<libc::c_int>() as libc::socklen_t,
53 );
54
55 let _ = libc::setsockopt(
57 fd,
58 libc::SOL_IPV6,
59 libc::IPV6_RECVORIGDSTADDR,
60 &enable as *const _ as *const libc::c_void,
61 mem::size_of::<libc::c_int>() as libc::socklen_t,
62 );
63 }
64 Ok(())
65 }
66
67 pub fn create_transparent_udp_socket(addr: SocketAddr) -> io::Result<UdpSocket> {
70 use std::os::unix::io::FromRawFd;
71
72 unsafe {
73 let (domain, sockaddr, socklen) = match addr {
74 SocketAddr::V4(v4) => {
75 let mut sin: libc::sockaddr_in = mem::zeroed();
76 sin.sin_family = libc::AF_INET as libc::sa_family_t;
77 sin.sin_port = v4.port().to_be();
78 sin.sin_addr.s_addr = u32::from(*v4.ip()).to_be();
79 (
80 libc::AF_INET,
81 &sin as *const _ as *const libc::sockaddr,
82 mem::size_of::<libc::sockaddr_in>() as libc::socklen_t,
83 )
84 },
85 SocketAddr::V6(v6) => {
86 let mut sin6: libc::sockaddr_in6 = mem::zeroed();
87 sin6.sin6_family = libc::AF_INET6 as libc::sa_family_t;
88 sin6.sin6_port = v6.port().to_be();
89 std::ptr::copy_nonoverlapping(
90 v6.ip().octets().as_ptr(),
91 sin6.sin6_addr.s6_addr.as_mut_ptr(),
92 16
93 );
94
95 (
96 libc::AF_INET6,
97 &sin6 as *const _ as *const libc::sockaddr,
98 mem::size_of::<libc::sockaddr_in6>() as libc::socklen_t,
99 )
100 }
101 };
102
103 let fd = libc::socket(domain, libc::SOCK_DGRAM, 0);
104 if fd < 0 {
105 return Err(io::Error::last_os_error());
106 }
107
108 let close_fd = |fd: libc::c_int| {
110 libc::close(fd);
111 };
112
113 let enable: libc::c_int = 1;
115 if libc::setsockopt(
116 fd,
117 libc::SOL_IP,
118 libc::IP_TRANSPARENT,
119 &enable as *const _ as *const libc::c_void,
120 mem::size_of::<libc::c_int>() as libc::socklen_t,
121 ) < 0 {
122 let err = io::Error::last_os_error();
123 close_fd(fd);
124 return Err(err);
125 }
126
127 if domain == libc::AF_INET6 {
128 let _ = libc::setsockopt(
130 fd,
131 libc::SOL_IPV6,
132 libc::IPV6_TRANSPARENT,
133 &enable as *const _ as *const libc::c_void,
134 mem::size_of::<libc::c_int>() as libc::socklen_t,
135 );
136 }
137
138 if libc::setsockopt(
140 fd,
141 libc::SOL_SOCKET,
142 libc::SO_REUSEADDR,
143 &enable as *const _ as *const libc::c_void,
144 mem::size_of::<libc::c_int>() as libc::socklen_t,
145 ) < 0 {
146 let err = io::Error::last_os_error();
147 close_fd(fd);
148 return Err(err);
149 }
150
151 if libc::bind(fd, sockaddr, socklen) < 0 {
153 let err = io::Error::last_os_error();
154 close_fd(fd);
155 return Err(err);
156 }
157
158 let flags = libc::fcntl(fd, libc::F_GETFL);
160 if flags < 0 || libc::fcntl(fd, libc::F_SETFL, flags | libc::O_NONBLOCK) < 0 {
161 let err = io::Error::last_os_error();
162 close_fd(fd);
163 return Err(err);
164 }
165
166 let std_socket = std::net::UdpSocket::from_raw_fd(fd);
167 UdpSocket::from_std(std_socket)
168 }
169 }
170
171 pub async fn recv_original_dst(socket: &UdpSocket, buf: &mut [u8]) -> io::Result<(usize, SocketAddr, Option<SocketAddr>)> {
174 let fd = socket.as_raw_fd();
175 socket.async_io(tokio::io::Interest::READABLE, || {
176 let mut iov = libc::iovec {
177 iov_base: buf.as_mut_ptr() as *mut libc::c_void,
178 iov_len: buf.len(),
179 };
180
181 let mut cmsg_buf = [0u8; 64];
184
185 let mut src_addr: libc::sockaddr_storage = unsafe { mem::zeroed() };
187
188 let mut msg = libc::msghdr {
189 msg_name: &mut src_addr as *mut _ as *mut libc::c_void,
190 msg_namelen: mem::size_of::<libc::sockaddr_storage>() as libc::socklen_t,
191 msg_iov: &mut iov,
192 msg_iovlen: 1,
193 msg_control: cmsg_buf.as_mut_ptr() as *mut libc::c_void,
194 msg_controllen: cmsg_buf.len(),
195 msg_flags: 0,
196 };
197
198 let n = unsafe { libc::recvmsg(fd, &mut msg, 0) };
199
200 if n < 0 {
201 return Err(io::Error::last_os_error());
202 }
203
204 let source = unsafe { sockaddr_to_socket_addr(&src_addr)? };
205 let orig_dst = unsafe { parse_orig_dst(&msg) };
206
207 Ok((n as usize, source, orig_dst))
208 }).await
209 }
210}
211
212#[cfg(target_os = "linux")]
213unsafe fn sockaddr_to_socket_addr(storage: &libc::sockaddr_storage) -> io::Result<SocketAddr> {
214 match storage.ss_family as libc::c_int {
215 libc::AF_INET => {
216 let addr: &libc::sockaddr_in = unsafe { mem::transmute(storage) };
217 let ip = Ipv4Addr::from(u32::from_be(addr.sin_addr.s_addr));
218 let port = u16::from_be(addr.sin_port);
219 Ok(SocketAddr::V4(SocketAddrV4::new(ip, port)))
220 }
221 libc::AF_INET6 => {
222 let addr: &libc::sockaddr_in6 = unsafe { mem::transmute(storage) };
223 let ip = Ipv6Addr::from(addr.sin6_addr.s6_addr);
224 let port = u16::from_be(addr.sin6_port);
225 Ok(SocketAddr::V6(SocketAddrV6::new(ip, port, addr.sin6_flowinfo, addr.sin6_scope_id)))
226 }
227 _ => Err(io::Error::new(io::ErrorKind::InvalidData, "Unknown address family")),
228 }
229}
230
231#[cfg(target_os = "linux")]
232unsafe fn parse_orig_dst(msg: &libc::msghdr) -> Option<SocketAddr> {
233 let mut cmsg: *mut libc::cmsghdr = unsafe { libc::CMSG_FIRSTHDR(msg) };
234
235 while !cmsg.is_null() {
236 unsafe {
237 if (*cmsg).cmsg_level == libc::SOL_IP && (*cmsg).cmsg_type == libc::IP_RECVORIGDSTADDR {
238 let data = libc::CMSG_DATA(cmsg) as *const libc::sockaddr_in;
239 let ip = Ipv4Addr::from(u32::from_be((*data).sin_addr.s_addr));
240 let port = u16::from_be((*data).sin_port);
241 return Some(SocketAddr::V4(SocketAddrV4::new(ip, port)));
242 } else if (*cmsg).cmsg_level == libc::SOL_IPV6 && (*cmsg).cmsg_type == libc::IPV6_RECVORIGDSTADDR {
243 let data = libc::CMSG_DATA(cmsg) as *const libc::sockaddr_in6;
244 let ip = Ipv6Addr::from((*data).sin6_addr.s6_addr);
245 let port = u16::from_be((*data).sin6_port);
246 return Some(SocketAddr::V6(SocketAddrV6::new(ip, port, (*data).sin6_flowinfo, (*data).sin6_scope_id)));
247 }
248
249 cmsg = libc::CMSG_NXTHDR(msg, cmsg);
250 }
251 }
252
253 None
254}