1use anyhow::{anyhow, Result};
2use bytes::BytesMut;
3use log::{debug, warn};
4use std::io::ErrorKind;
5use std::os::fd::{FromRawFd, IntoRawFd};
6use std::os::unix::io::{AsRawFd, RawFd};
7use std::sync::Arc;
8use std::{io, mem};
9use tokio::net::UdpSocket;
10use tokio::select;
11use tokio::sync::mpsc::{channel, Receiver, Sender};
12use tokio::task::JoinHandle;
13
/// Capacity of the bounded channel that carries outgoing packets to the socket task.
const RAW_SOCKET_TRANSMIT_QUEUE_LEN: usize = 3_000;
/// Capacity of the bounded channel that carries packets read from the socket to the consumer.
const RAW_SOCKET_RECEIVE_QUEUE_LEN: usize = 3_000;
16
/// Protocol family a raw socket is opened for.
///
/// The variant decides which address family and protocol number are handed to `socket(2)`.
/// The type is a plain tag, so it is `Copy` and comparable: a caller can keep using a value
/// after passing it by value to a constructor, and can test which protocol a value names.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RawSocketProtocol {
    /// ICMP carried over IPv4 (`AF_INET` / `IPPROTO_ICMP`).
    Icmpv4,
    /// ICMPv6 carried over IPv6 (`AF_INET6` / `IPPROTO_ICMPV6`).
    Icmpv6,
    /// Link-layer packet socket receiving every ethertype (`AF_PACKET` / `ETH_P_ALL`).
    Ethernet,
}
23
24impl RawSocketProtocol {
25 pub fn to_socket_domain(&self) -> i32 {
26 match self {
27 RawSocketProtocol::Icmpv4 => libc::AF_INET,
28 RawSocketProtocol::Icmpv6 => libc::AF_INET6,
29 RawSocketProtocol::Ethernet => libc::AF_PACKET,
30 }
31 }
32
33 pub fn to_socket_protocol(&self) -> u16 {
34 match self {
35 RawSocketProtocol::Icmpv4 => libc::IPPROTO_ICMP as u16,
36 RawSocketProtocol::Icmpv6 => libc::IPPROTO_ICMPV6 as u16,
37 RawSocketProtocol::Ethernet => (libc::ETH_P_ALL as u16).to_be(),
38 }
39 }
40
41 pub fn to_socket_type(&self) -> i32 {
42 libc::SOCK_RAW
43 }
44}
45
46const SIOCGIFINDEX: libc::c_ulong = 0x8933;
47const SIOCGIFMTU: libc::c_ulong = 0x8921;
48
49#[derive(Debug)]
50pub struct RawSocketHandle {
51 protocol: RawSocketProtocol,
52 lower: libc::c_int,
53}
54
55impl AsRawFd for RawSocketHandle {
56 fn as_raw_fd(&self) -> RawFd {
57 self.lower
58 }
59}
60
61impl IntoRawFd for RawSocketHandle {
62 fn into_raw_fd(self) -> RawFd {
63 let fd = self.lower;
64 mem::forget(self);
65 fd
66 }
67}
68
69impl RawSocketHandle {
70 pub fn new(protocol: RawSocketProtocol) -> io::Result<RawSocketHandle> {
71 let lower = unsafe {
72 let lower = libc::socket(
73 protocol.to_socket_domain(),
74 protocol.to_socket_type() | libc::SOCK_NONBLOCK,
75 protocol.to_socket_protocol() as i32,
76 );
77 if lower == -1 {
78 return Err(io::Error::last_os_error());
79 }
80 lower
81 };
82
83 Ok(RawSocketHandle { protocol, lower })
84 }
85
86 pub fn bound_to_interface(interface: &str, protocol: RawSocketProtocol) -> Result<Self> {
87 let mut socket = RawSocketHandle::new(protocol)?;
88 socket.bind_to_interface(interface)?;
89 Ok(socket)
90 }
91
92 pub fn bind_to_interface(&mut self, interface: &str) -> io::Result<()> {
93 let mut ifreq = ifreq_for(interface);
94 let sockaddr = libc::sockaddr_ll {
95 sll_family: libc::AF_PACKET as u16,
96 sll_protocol: self.protocol.to_socket_protocol(),
97 sll_ifindex: ifreq_ioctl(self.lower, &mut ifreq, SIOCGIFINDEX)?,
98 sll_hatype: 1,
99 sll_pkttype: 0,
100 sll_halen: 6,
101 sll_addr: [0; 8],
102 };
103
104 unsafe {
105 let res = libc::bind(
106 self.lower,
107 &sockaddr as *const libc::sockaddr_ll as *const libc::sockaddr,
108 mem::size_of::<libc::sockaddr_ll>() as libc::socklen_t,
109 );
110 if res == -1 {
111 return Err(io::Error::last_os_error());
112 }
113 }
114
115 Ok(())
116 }
117
118 pub fn mtu_of_interface(&mut self, interface: &str) -> io::Result<usize> {
119 let mut ifreq = ifreq_for(interface);
120 ifreq_ioctl(self.lower, &mut ifreq, SIOCGIFMTU).map(|mtu| mtu as usize)
121 }
122
123 pub fn recv(&self, buffer: &mut [u8]) -> io::Result<usize> {
124 unsafe {
125 let len = libc::recv(
126 self.lower,
127 buffer.as_mut_ptr() as *mut libc::c_void,
128 buffer.len(),
129 0,
130 );
131 if len == -1 {
132 return Err(io::Error::last_os_error());
133 }
134 Ok(len as usize)
135 }
136 }
137
138 pub fn send(&self, buffer: &[u8]) -> io::Result<usize> {
139 unsafe {
140 let len = libc::send(
141 self.lower,
142 buffer.as_ptr() as *const libc::c_void,
143 buffer.len(),
144 0,
145 );
146 if len == -1 {
147 return Err(io::Error::last_os_error());
148 }
149 Ok(len as usize)
150 }
151 }
152}
153
154impl Drop for RawSocketHandle {
155 fn drop(&mut self) {
156 unsafe {
157 libc::close(self.lower);
158 }
159 }
160}
161
162#[repr(C)]
163#[derive(Debug)]
164struct Ifreq {
165 ifr_name: [libc::c_char; libc::IF_NAMESIZE],
166 ifr_data: libc::c_int,
167}
168
169fn ifreq_for(name: &str) -> Ifreq {
170 let mut ifreq = Ifreq {
171 ifr_name: [0; libc::IF_NAMESIZE],
172 ifr_data: 0,
173 };
174 for (i, byte) in name.as_bytes().iter().enumerate() {
175 ifreq.ifr_name[i] = *byte as libc::c_char
176 }
177 ifreq
178}
179
180fn ifreq_ioctl(
181 lower: libc::c_int,
182 ifreq: &mut Ifreq,
183 cmd: libc::c_ulong,
184) -> io::Result<libc::c_int> {
185 unsafe {
186 let res = libc::ioctl(lower, cmd as _, ifreq as *mut Ifreq);
187 if res == -1 {
188 return Err(io::Error::last_os_error());
189 }
190 }
191
192 Ok(ifreq.ifr_data)
193}
194
195pub struct AsyncRawSocketChannel {
196 pub sender: Sender<BytesMut>,
197 pub receiver: Receiver<BytesMut>,
198 _task: Arc<JoinHandle<()>>,
199}
200
201enum AsyncRawSocketChannelSelect {
202 TransmitPacket(Option<BytesMut>),
203 Readable(()),
204}
205
206impl AsyncRawSocketChannel {
207 pub fn new(mtu: usize, socket: RawSocketHandle) -> Result<AsyncRawSocketChannel> {
208 let (transmit_sender, transmit_receiver) = channel(RAW_SOCKET_TRANSMIT_QUEUE_LEN);
209 let (receive_sender, receive_receiver) = channel(RAW_SOCKET_RECEIVE_QUEUE_LEN);
210 let task = AsyncRawSocketChannel::launch(mtu, socket, transmit_receiver, receive_sender)?;
211 Ok(AsyncRawSocketChannel {
212 sender: transmit_sender,
213 receiver: receive_receiver,
214 _task: Arc::new(task),
215 })
216 }
217
218 fn launch(
219 mtu: usize,
220 socket: RawSocketHandle,
221 transmit_receiver: Receiver<BytesMut>,
222 receive_sender: Sender<BytesMut>,
223 ) -> Result<JoinHandle<()>> {
224 Ok(tokio::task::spawn(async move {
225 if let Err(error) =
226 AsyncRawSocketChannel::process(mtu, socket, transmit_receiver, receive_sender).await
227 {
228 warn!("failed to process raw socket: {}", error);
229 }
230 }))
231 }
232
233 async fn process(
234 mtu: usize,
235 socket: RawSocketHandle,
236 mut transmit_receiver: Receiver<BytesMut>,
237 receive_sender: Sender<BytesMut>,
238 ) -> Result<()> {
239 let socket = unsafe { std::net::UdpSocket::from_raw_fd(socket.into_raw_fd()) };
240 let socket = UdpSocket::from_std(socket)?;
241
242 let tear_off_size = 100 * mtu;
243 let mut buffer: BytesMut = BytesMut::with_capacity(tear_off_size);
244 loop {
245 if buffer.capacity() < mtu {
246 buffer = BytesMut::with_capacity(tear_off_size);
247 }
248
249 let selection = select! {
250 x = transmit_receiver.recv() => AsyncRawSocketChannelSelect::TransmitPacket(x),
251 x = socket.readable() => AsyncRawSocketChannelSelect::Readable(x?),
252 };
253
254 match selection {
255 AsyncRawSocketChannelSelect::Readable(_) => {
256 buffer.resize(mtu, 0);
257 match socket.try_recv(&mut buffer) {
258 Ok(len) => {
259 if len == 0 {
260 continue;
261 }
262 let packet = buffer.split_to(len);
263 if let Err(error) = receive_sender.try_send(packet) {
264 debug!(
265 "failed to process received packet from raw socket: {}",
266 error
267 );
268 }
269 }
270
271 Err(ref error) => {
272 if error.kind() == ErrorKind::WouldBlock {
273 continue;
274 }
275
276 if error.raw_os_error() == Some(6) {
278 break;
279 }
280
281 return Err(anyhow!("failed to read from raw socket: {}", error));
282 }
283 };
284 }
285
286 AsyncRawSocketChannelSelect::TransmitPacket(Some(packet)) => {
287 match socket.try_send(&packet) {
288 Ok(_len) => {}
289 Err(ref error) => {
290 if error.kind() == ErrorKind::WouldBlock {
291 debug!("failed to transmit: would block");
292 continue;
293 }
294
295 if error.raw_os_error() == Some(6) {
297 break;
298 }
299
300 return Err(anyhow!(
301 "failed to write {} bytes to raw socket: {}",
302 packet.len(),
303 error
304 ));
305 }
306 };
307 }
308
309 AsyncRawSocketChannelSelect::TransmitPacket(None) => {
310 break;
311 }
312 }
313 }
314
315 Ok(())
316 }
317}