use std::{
    ffi::c_int,
    io,
    os::fd::{AsRawFd, FromRawFd, OwnedFd, RawFd},
};
use tokio::io::unix::AsyncFd;
3
/// Options for opening a raw packet socket with [`RawSock::new`].
///
/// Fields are public so callers outside this module can construct the options.
#[derive(Debug, Clone, Copy)]
pub struct SockOpts<'opt> {
    /// Link-layer protocol (EtherType) in host byte order, e.g. `libc::ETH_P_ALL`.
    /// `RawSock::new` converts it to network byte order itself.
    pub protocol: c_int,
    /// Name of the interface to bind to, e.g. `"lo"`. Must be shorter than
    /// `libc::IFNAMSIZ` bytes.
    pub intf: &'opt str,
}
11
12pub struct RawSock {
13 fd: AsyncFd<RawFd>,
14}
15
16impl RawSock {
17 pub fn new(opts: SockOpts) -> Result<Self, io::Error> {
18 unsafe {
19 if opts.intf.len() >= libc::IFNAMSIZ {
20 return Err(io::Error::other("invalid interface name - exceeds length"));
21 }
22
23 let sock_fd = libc::socket(
24 libc::AF_PACKET,
25 libc::SOCK_RAW | libc::SOCK_NONBLOCK,
26 opts.protocol
27 );
28
29 if sock_fd < 0 {
30 return Err(io::Error::last_os_error())
31 }
32
33 let mut ifreq = libc::ifreq {
34 ifr_name: [0;libc::IFNAMSIZ],
35 ifr_ifru: std::mem::zeroed(),
36 };
37
38 let intf_c = &*(opts.intf.as_bytes() as *const _ as *const [i8]);
39 ifreq.ifr_name[..intf_c.len()].copy_from_slice(intf_c);
40
41 if libc::ioctl(
42 sock_fd,
43 libc::SIOCGIFINDEX,
44 &ifreq as *const _,
45 ) < 0 {
46 return Err(io::Error::last_os_error())
47 }
48
49 let addr = libc::sockaddr_ll {
50 sll_family: libc::AF_PACKET as u16,
51 sll_protocol: u16::to_be(opts.protocol as u16),
52 sll_ifindex: ifreq.ifr_ifru.ifru_ifindex,
53 sll_hatype: 0,
54 sll_pkttype: 0,
55 sll_halen: 0,
56 sll_addr: [0; 8],
57 };
58
59 if libc::bind(sock_fd, &addr as *const _ as *const libc::sockaddr, std::mem::size_of::<libc::sockaddr_ll>() as u32) < 0 {
60 return Err(io::Error::last_os_error())
61 }
62
63 Ok(Self {
64 fd: AsyncFd::new(sock_fd).unwrap(),
65 })
66 }
67 }
68
69 pub async fn read(&self, buf: &mut [u8]) -> io::Result<usize> {
70 loop {
71 let guard = self.fd.readable().await?;
72
73 unsafe {
74 let res = libc::recv(
75 guard.get_ref().as_raw_fd(),
76 buf as *mut _ as *mut libc::c_void,
77 buf.len(),
78 0
79 );
80
81 if res < 0 {
82 let err = io::Error::last_os_error();
83
84 match err.kind() {
85 io::ErrorKind::WouldBlock => continue,
86 _ => return Err(err)
87 }
88 } else {
89 return Ok(res as usize)
90 }
91 }
92 }
93 }
94
95 pub async fn write(&self, buf: &[u8]) -> io::Result<usize> {
96 loop {
97 let guard = self.fd.writable().await?;
98
99 unsafe {
100 let res = libc::send(
101 guard.get_ref().as_raw_fd(),
102 buf as *const _ as *const libc::c_void,
103 buf.len(),
104 0,
105 );
106
107 if res < 0 {
108 let err = io::Error::last_os_error();
109
110 match err.kind() {
111 io::ErrorKind::WouldBlock => continue,
112 _ => return Err(err)
113 }
114 } else {
115 return Ok(res as usize)
116 }
117 }
118 }
119 }
120}
121
#[cfg(test)]
mod tests {
    use super::*;

    /// Round-trips one Ethernet frame through a raw socket bound to `lo`.
    ///
    /// Opening an `AF_PACKET` socket needs `CAP_NET_RAW` (or root). The test
    /// expects the frame it writes to be the next frame it reads back. Other
    /// loopback traffic arriving first could therefore make it fail.
    #[tokio::test(flavor = "multi_thread", worker_threads = 1)]
    async fn test_creation() {
        let sock = RawSock::new(SockOpts {
            protocol: libc::ETH_P_ALL,
            intf: "lo",
        })
        .unwrap();

        // ICMPv6 echo request from ::1 to ::1, wrapped in an Ethernet header.
        let packet: &[u8] = &[
            // Ethernet: destination MAC, source MAC, EtherType 0x86dd (IPv6)
            0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
            0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
            0x86, 0xdd,
            // IPv6: version/traffic class/flow label, payload length 0x0040,
            // next header 0x3a (ICMPv6), hop limit 0x40
            0x60, 0x04, 0x90, 0x15, 0x00, 0x40, 0x3a, 0x40,
            // IPv6 source address ::1
            0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
            0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01,
            // IPv6 destination address ::1
            0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
            0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01,
            // ICMPv6: type 0x80 (echo request), code, checksum, identifier, sequence
            0x80, 0x00, 0xd0, 0x40, 0x00, 0x0a, 0x00, 0x01,
            // Echo payload
            0xb9, 0xb1, 0x09, 0x68, 0x00, 0x00, 0x00, 0x00,
            0x27, 0x4b, 0x06, 0x00, 0x00, 0x00, 0x00, 0x00,
            0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17,
            0x18, 0x19, 0x1a, 0x1b, 0x1c, 0x1d, 0x1e, 0x1f,
            0x20, 0x21, 0x22, 0x23, 0x24, 0x25, 0x26, 0x27,
            0x28, 0x29, 0x2a, 0x2b, 0x2c, 0x2d, 0x2e, 0x2f,
            0x30, 0x31, 0x32, 0x33, 0x34, 0x35, 0x36, 0x37,
        ];

        let mut received = [0u8; 128];
        sock.write(packet).await.unwrap();
        let len = sock.read(&mut received).await.unwrap();

        assert_eq!(len, packet.len());
        assert_eq!(&received[..len], packet);
    }
}