1#[cfg(target_os = "linux")]
2use std::io;
3#[cfg(target_os = "linux")]
4use std::mem;
5#[cfg(target_os = "linux")]
6use std::net::{Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6};
7#[cfg(target_os = "linux")]
8use std::os::unix::io::AsRawFd;
9#[cfg(target_os = "linux")]
10use tokio::net::UdpSocket;
11
12#[cfg(target_os = "linux")]
13pub struct LinuxTproxy;
14
15#[cfg(target_os = "linux")]
16impl LinuxTproxy {
17 pub fn enable_tproxy(socket: &UdpSocket) -> io::Result<()> {
19 let fd = socket.as_raw_fd();
20 unsafe {
21 let enable: libc::c_int = 1;
22
23 if libc::setsockopt(
25 fd,
26 libc::SOL_IP,
27 libc::IP_TRANSPARENT,
28 &enable as *const _ as *const libc::c_void,
29 mem::size_of::<libc::c_int>() as libc::socklen_t,
30 ) < 0
31 {
32 return Err(io::Error::last_os_error());
33 }
34
35 if libc::setsockopt(
37 fd,
38 libc::SOL_IP,
39 libc::IP_RECVORIGDSTADDR,
40 &enable as *const _ as *const libc::c_void,
41 mem::size_of::<libc::c_int>() as libc::socklen_t,
42 ) < 0
43 {
44 return Err(io::Error::last_os_error());
45 }
46
47 let _ = libc::setsockopt(
50 fd,
51 libc::SOL_IPV6,
52 libc::IPV6_TRANSPARENT,
53 &enable as *const _ as *const libc::c_void,
54 mem::size_of::<libc::c_int>() as libc::socklen_t,
55 );
56
57 let _ = libc::setsockopt(
59 fd,
60 libc::SOL_IPV6,
61 libc::IPV6_RECVORIGDSTADDR,
62 &enable as *const _ as *const libc::c_void,
63 mem::size_of::<libc::c_int>() as libc::socklen_t,
64 );
65 }
66 Ok(())
67 }
68
69 pub fn create_transparent_udp_socket(addr: SocketAddr) -> io::Result<UdpSocket> {
72 use std::os::unix::io::FromRawFd;
73
74 unsafe {
75 let (domain, sockaddr, socklen) = match addr {
76 SocketAddr::V4(v4) => {
77 let mut sin: libc::sockaddr_in = mem::zeroed();
78 sin.sin_family = libc::AF_INET as libc::sa_family_t;
79 sin.sin_port = v4.port().to_be();
80 sin.sin_addr.s_addr = u32::from(*v4.ip()).to_be();
81 (
82 libc::AF_INET,
83 &sin as *const _ as *const libc::sockaddr,
84 mem::size_of::<libc::sockaddr_in>() as libc::socklen_t,
85 )
86 }
87 SocketAddr::V6(v6) => {
88 let mut sin6: libc::sockaddr_in6 = mem::zeroed();
89 sin6.sin6_family = libc::AF_INET6 as libc::sa_family_t;
90 sin6.sin6_port = v6.port().to_be();
91 std::ptr::copy_nonoverlapping(
92 v6.ip().octets().as_ptr(),
93 sin6.sin6_addr.s6_addr.as_mut_ptr(),
94 16,
95 );
96
97 (
98 libc::AF_INET6,
99 &sin6 as *const _ as *const libc::sockaddr,
100 mem::size_of::<libc::sockaddr_in6>() as libc::socklen_t,
101 )
102 }
103 };
104
105 let fd = libc::socket(domain, libc::SOCK_DGRAM, 0);
106 if fd < 0 {
107 return Err(io::Error::last_os_error());
108 }
109
110 let close_fd = |fd: libc::c_int| {
112 libc::close(fd);
113 };
114
115 let enable: libc::c_int = 1;
117 if libc::setsockopt(
118 fd,
119 libc::SOL_IP,
120 libc::IP_TRANSPARENT,
121 &enable as *const _ as *const libc::c_void,
122 mem::size_of::<libc::c_int>() as libc::socklen_t,
123 ) < 0
124 {
125 let err = io::Error::last_os_error();
126 close_fd(fd);
127 return Err(err);
128 }
129
130 if domain == libc::AF_INET6 {
131 let _ = libc::setsockopt(
133 fd,
134 libc::SOL_IPV6,
135 libc::IPV6_TRANSPARENT,
136 &enable as *const _ as *const libc::c_void,
137 mem::size_of::<libc::c_int>() as libc::socklen_t,
138 );
139 }
140
141 if libc::setsockopt(
143 fd,
144 libc::SOL_SOCKET,
145 libc::SO_REUSEADDR,
146 &enable as *const _ as *const libc::c_void,
147 mem::size_of::<libc::c_int>() as libc::socklen_t,
148 ) < 0
149 {
150 let err = io::Error::last_os_error();
151 close_fd(fd);
152 return Err(err);
153 }
154
155 if libc::bind(fd, sockaddr, socklen) < 0 {
157 let err = io::Error::last_os_error();
158 close_fd(fd);
159 return Err(err);
160 }
161
162 let flags = libc::fcntl(fd, libc::F_GETFL);
164 if flags < 0 || libc::fcntl(fd, libc::F_SETFL, flags | libc::O_NONBLOCK) < 0 {
165 let err = io::Error::last_os_error();
166 close_fd(fd);
167 return Err(err);
168 }
169
170 let std_socket = std::net::UdpSocket::from_raw_fd(fd);
171 UdpSocket::from_std(std_socket)
172 }
173 }
174
175 pub async fn recv_original_dst(
178 socket: &UdpSocket,
179 buf: &mut [u8],
180 ) -> io::Result<(usize, SocketAddr, Option<SocketAddr>)> {
181 let fd = socket.as_raw_fd();
182 socket
183 .async_io(tokio::io::Interest::READABLE, || {
184 let mut iov = libc::iovec {
185 iov_base: buf.as_mut_ptr() as *mut libc::c_void,
186 iov_len: buf.len(),
187 };
188
189 let mut cmsg_buf = [0u8; 64];
192
193 let mut src_addr: libc::sockaddr_storage = unsafe { mem::zeroed() };
195
196 let mut msg = libc::msghdr {
197 msg_name: &mut src_addr as *mut _ as *mut libc::c_void,
198 msg_namelen: mem::size_of::<libc::sockaddr_storage>() as libc::socklen_t,
199 msg_iov: &mut iov,
200 msg_iovlen: 1,
201 msg_control: cmsg_buf.as_mut_ptr() as *mut libc::c_void,
202 msg_controllen: cmsg_buf.len(),
203 msg_flags: 0,
204 };
205
206 let n = unsafe { libc::recvmsg(fd, &mut msg, 0) };
207
208 if n < 0 {
209 return Err(io::Error::last_os_error());
210 }
211
212 let source = unsafe { sockaddr_to_socket_addr(&src_addr)? };
213 let orig_dst = unsafe { parse_orig_dst(&msg) };
214
215 Ok((n as usize, source, orig_dst))
216 })
217 .await
218 }
219}
220
221#[cfg(target_os = "linux")]
222unsafe fn sockaddr_to_socket_addr(storage: &libc::sockaddr_storage) -> io::Result<SocketAddr> {
223 match storage.ss_family as libc::c_int {
224 libc::AF_INET => {
225 let addr: &libc::sockaddr_in = unsafe { mem::transmute(storage) };
226 let ip = Ipv4Addr::from(u32::from_be(addr.sin_addr.s_addr));
227 let port = u16::from_be(addr.sin_port);
228 Ok(SocketAddr::V4(SocketAddrV4::new(ip, port)))
229 }
230 libc::AF_INET6 => {
231 let addr: &libc::sockaddr_in6 = unsafe { mem::transmute(storage) };
232 let ip = Ipv6Addr::from(addr.sin6_addr.s6_addr);
233 let port = u16::from_be(addr.sin6_port);
234 Ok(SocketAddr::V6(SocketAddrV6::new(
235 ip,
236 port,
237 addr.sin6_flowinfo,
238 addr.sin6_scope_id,
239 )))
240 }
241 _ => Err(io::Error::new(
242 io::ErrorKind::InvalidData,
243 "Unknown address family",
244 )),
245 }
246}
247
248#[cfg(target_os = "linux")]
249unsafe fn parse_orig_dst(msg: &libc::msghdr) -> Option<SocketAddr> {
250 let mut cmsg: *mut libc::cmsghdr = unsafe { libc::CMSG_FIRSTHDR(msg) };
251
252 while !cmsg.is_null() {
253 unsafe {
254 if (*cmsg).cmsg_level == libc::SOL_IP && (*cmsg).cmsg_type == libc::IP_RECVORIGDSTADDR {
255 let data = libc::CMSG_DATA(cmsg) as *const libc::sockaddr_in;
256 let ip = Ipv4Addr::from(u32::from_be((*data).sin_addr.s_addr));
257 let port = u16::from_be((*data).sin_port);
258 return Some(SocketAddr::V4(SocketAddrV4::new(ip, port)));
259 } else if (*cmsg).cmsg_level == libc::SOL_IPV6
260 && (*cmsg).cmsg_type == libc::IPV6_RECVORIGDSTADDR
261 {
262 let data = libc::CMSG_DATA(cmsg) as *const libc::sockaddr_in6;
263 let ip = Ipv6Addr::from((*data).sin6_addr.s6_addr);
264 let port = u16::from_be((*data).sin6_port);
265 return Some(SocketAddr::V6(SocketAddrV6::new(
266 ip,
267 port,
268 (*data).sin6_flowinfo,
269 (*data).sin6_scope_id,
270 )));
271 }
272
273 cmsg = libc::CMSG_NXTHDR(msg, cmsg);
274 }
275 }
276
277 None
278}