Skip to main content

fips_core/transport/ethernet/
socket_linux.rs

//! AF_PACKET socket creation, binding, and ioctl helpers (Linux).
2
3use crate::transport::TransportError;
4use std::os::unix::io::{AsRawFd, RawFd};
5
/// Owning handle for an AF_PACKET `SOCK_DGRAM` socket bound to one interface.
///
/// The descriptor is closed when the handle is dropped. The synchronous
/// `send_to` / `recv_from` methods are what the async wrappers drive through
/// `AsyncFd`.
pub struct PacketSocket {
    /// Raw socket descriptor; owned by this value and closed on drop.
    fd: RawFd,
    /// Kernel index of the interface the socket is bound to.
    if_index: i32,
    /// EtherType in host byte order; converted with `to_be()` wherever it is
    /// handed to the kernel.
    ethertype: u16,
}
15
16impl PacketSocket {
17    /// Create and bind an AF_PACKET SOCK_DGRAM socket.
18    ///
19    /// Returns an error with a clear message if CAP_NET_RAW is missing.
20    pub fn open(interface: &str, ethertype: u16) -> Result<Self, TransportError> {
21        let fd = unsafe {
22            libc::socket(
23                libc::AF_PACKET,
24                libc::SOCK_DGRAM,
25                (ethertype).to_be() as i32,
26            )
27        };
28        if fd < 0 {
29            let err = std::io::Error::last_os_error();
30            if err.raw_os_error() == Some(libc::EPERM) {
31                return Err(TransportError::StartFailed(
32                    "AF_PACKET requires CAP_NET_RAW capability \
33                     (run as root or use: setcap cap_net_raw=ep <binary>)"
34                        .into(),
35                ));
36            }
37            return Err(TransportError::StartFailed(format!(
38                "socket(AF_PACKET) failed: {}",
39                err
40            )));
41        }
42
43        // Look up interface index
44        let if_index = get_if_index(fd, interface)?;
45
46        // Bind to the interface
47        let mut sll: libc::sockaddr_ll = unsafe { std::mem::zeroed() };
48        sll.sll_family = libc::AF_PACKET as u16;
49        sll.sll_protocol = ethertype.to_be();
50        sll.sll_ifindex = if_index;
51
52        let ret = unsafe {
53            libc::bind(
54                fd,
55                &sll as *const libc::sockaddr_ll as *const libc::sockaddr,
56                std::mem::size_of::<libc::sockaddr_ll>() as libc::socklen_t,
57            )
58        };
59        if ret < 0 {
60            let err = std::io::Error::last_os_error();
61            unsafe { libc::close(fd) };
62            return Err(TransportError::StartFailed(format!(
63                "bind(AF_PACKET, {}) failed: {}",
64                interface, err
65            )));
66        }
67
68        // Set non-blocking for async integration
69        let flags = unsafe { libc::fcntl(fd, libc::F_GETFL) };
70        if flags < 0 {
71            let err = std::io::Error::last_os_error();
72            unsafe { libc::close(fd) };
73            return Err(TransportError::StartFailed(format!(
74                "fcntl(F_GETFL) failed: {}",
75                err
76            )));
77        }
78        let ret = unsafe { libc::fcntl(fd, libc::F_SETFL, flags | libc::O_NONBLOCK) };
79        if ret < 0 {
80            let err = std::io::Error::last_os_error();
81            unsafe { libc::close(fd) };
82            return Err(TransportError::StartFailed(format!(
83                "fcntl(F_SETFL, O_NONBLOCK) failed: {}",
84                err
85            )));
86        }
87
88        Ok(Self {
89            fd,
90            if_index,
91            ethertype,
92        })
93    }
94
95    /// Get the interface index.
96    pub fn if_index(&self) -> i32 {
97        self.if_index
98    }
99
100    /// Get the local MAC address of the bound interface.
101    pub fn local_mac(&self) -> Result<[u8; 6], TransportError> {
102        get_mac_addr(self.fd, self.if_index)
103    }
104
105    /// Get the interface MTU.
106    pub fn interface_mtu(&self) -> Result<u16, TransportError> {
107        get_if_mtu(self.fd, self.if_index)
108    }
109
110    /// Set the socket receive buffer size.
111    pub fn set_recv_buffer_size(&self, size: usize) -> Result<(), TransportError> {
112        let size = size as libc::c_int;
113        let ret = unsafe {
114            libc::setsockopt(
115                self.fd,
116                libc::SOL_SOCKET,
117                libc::SO_RCVBUF,
118                &size as *const libc::c_int as *const libc::c_void,
119                std::mem::size_of::<libc::c_int>() as libc::socklen_t,
120            )
121        };
122        if ret < 0 {
123            return Err(TransportError::StartFailed(format!(
124                "setsockopt(SO_RCVBUF) failed: {}",
125                std::io::Error::last_os_error()
126            )));
127        }
128        Ok(())
129    }
130
131    /// Set the socket send buffer size.
132    pub fn set_send_buffer_size(&self, size: usize) -> Result<(), TransportError> {
133        let size = size as libc::c_int;
134        let ret = unsafe {
135            libc::setsockopt(
136                self.fd,
137                libc::SOL_SOCKET,
138                libc::SO_SNDBUF,
139                &size as *const libc::c_int as *const libc::c_void,
140                std::mem::size_of::<libc::c_int>() as libc::socklen_t,
141            )
142        };
143        if ret < 0 {
144            return Err(TransportError::StartFailed(format!(
145                "setsockopt(SO_SNDBUF) failed: {}",
146                std::io::Error::last_os_error()
147            )));
148        }
149        Ok(())
150    }
151
152    /// Send a payload to a destination MAC address.
153    ///
154    /// Returns the number of bytes sent, or an io::Error.
155    pub fn send_to(&self, data: &[u8], dest_mac: &[u8; 6]) -> std::io::Result<usize> {
156        let mut sll: libc::sockaddr_ll = unsafe { std::mem::zeroed() };
157        sll.sll_family = libc::AF_PACKET as u16;
158        sll.sll_protocol = self.ethertype.to_be();
159        sll.sll_ifindex = self.if_index;
160        sll.sll_halen = 6;
161        sll.sll_addr[..6].copy_from_slice(dest_mac);
162
163        let ret = unsafe {
164            libc::sendto(
165                self.fd,
166                data.as_ptr() as *const libc::c_void,
167                data.len(),
168                0,
169                &sll as *const libc::sockaddr_ll as *const libc::sockaddr,
170                std::mem::size_of::<libc::sockaddr_ll>() as libc::socklen_t,
171            )
172        };
173        if ret < 0 {
174            Err(std::io::Error::last_os_error())
175        } else {
176            Ok(ret as usize)
177        }
178    }
179
180    /// Receive a payload and source MAC address.
181    ///
182    /// Returns (bytes_read, source_mac), or an io::Error.
183    pub fn recv_from(&self, buf: &mut [u8]) -> std::io::Result<(usize, [u8; 6])> {
184        let mut sll: libc::sockaddr_ll = unsafe { std::mem::zeroed() };
185        let mut sll_len = std::mem::size_of::<libc::sockaddr_ll>() as libc::socklen_t;
186
187        let ret = unsafe {
188            libc::recvfrom(
189                self.fd,
190                buf.as_mut_ptr() as *mut libc::c_void,
191                buf.len(),
192                0,
193                &mut sll as *mut libc::sockaddr_ll as *mut libc::sockaddr,
194                &mut sll_len,
195            )
196        };
197        if ret < 0 {
198            return Err(std::io::Error::last_os_error());
199        }
200
201        let mut src_mac = [0u8; 6];
202        src_mac.copy_from_slice(&sll.sll_addr[..6]);
203
204        Ok((ret as usize, src_mac))
205    }
206}
207
208impl AsRawFd for PacketSocket {
209    fn as_raw_fd(&self) -> RawFd {
210        self.fd
211    }
212}
213
214impl Drop for PacketSocket {
215    fn drop(&mut self) {
216        unsafe {
217            libc::close(self.fd);
218        }
219    }
220}
221
// ============================================================================
// Interface lookup and ioctl helpers
// ============================================================================
225
226/// Get the interface index by name.
227fn get_if_index(_fd: RawFd, interface: &str) -> Result<i32, TransportError> {
228    let c_name = std::ffi::CString::new(interface).map_err(|_| {
229        TransportError::StartFailed(format!("invalid interface name: {}", interface))
230    })?;
231
232    let idx = unsafe { libc::if_nametoindex(c_name.as_ptr()) };
233    if idx == 0 {
234        return Err(TransportError::StartFailed(format!(
235            "interface not found: {} ({})",
236            interface,
237            std::io::Error::last_os_error()
238        )));
239    }
240    Ok(idx as i32)
241}
242
243/// Get the MAC address of an interface by its index.
244fn get_mac_addr(fd: RawFd, if_index: i32) -> Result<[u8; 6], TransportError> {
245    // First get the interface name from the index
246    let mut ifr: libc::ifreq = unsafe { std::mem::zeroed() };
247
248    // Use if_indextoname to get the name
249    let mut name_buf = [0u8; libc::IFNAMSIZ];
250    let ret = unsafe {
251        libc::if_indextoname(
252            if_index as libc::c_uint,
253            name_buf.as_mut_ptr() as *mut libc::c_char,
254        )
255    };
256    if ret.is_null() {
257        return Err(TransportError::StartFailed(format!(
258            "if_indextoname({}) failed: {}",
259            if_index,
260            std::io::Error::last_os_error()
261        )));
262    }
263
264    // Copy name into ifreq
265    let name_len = name_buf
266        .iter()
267        .position(|&b| b == 0)
268        .unwrap_or(name_buf.len());
269    let copy_len = name_len.min(libc::IFNAMSIZ - 1);
270    unsafe {
271        std::ptr::copy_nonoverlapping(
272            name_buf.as_ptr(),
273            ifr.ifr_name.as_mut_ptr() as *mut u8,
274            copy_len,
275        );
276    }
277
278    #[cfg(target_env = "musl")]
279    let ioctl_req = libc::SIOCGIFHWADDR as libc::c_int;
280    #[cfg(not(target_env = "musl"))]
281    let ioctl_req = libc::SIOCGIFHWADDR as libc::c_ulong;
282    let ret = unsafe { libc::ioctl(fd, ioctl_req, &ifr) };
283    if ret < 0 {
284        return Err(TransportError::StartFailed(format!(
285            "ioctl(SIOCGIFHWADDR) failed: {}",
286            std::io::Error::last_os_error()
287        )));
288    }
289
290    let mut mac = [0u8; 6];
291    unsafe {
292        let sa_data = ifr.ifr_ifru.ifru_hwaddr.sa_data;
293        for (i, byte) in mac.iter_mut().enumerate() {
294            *byte = sa_data[i] as u8;
295        }
296    }
297
298    Ok(mac)
299}
300
301/// Get the MTU of an interface by its index.
302fn get_if_mtu(fd: RawFd, if_index: i32) -> Result<u16, TransportError> {
303    let mut ifr: libc::ifreq = unsafe { std::mem::zeroed() };
304
305    // Get the interface name from index
306    let mut name_buf = [0u8; libc::IFNAMSIZ];
307    let ret = unsafe {
308        libc::if_indextoname(
309            if_index as libc::c_uint,
310            name_buf.as_mut_ptr() as *mut libc::c_char,
311        )
312    };
313    if ret.is_null() {
314        return Err(TransportError::StartFailed(format!(
315            "if_indextoname({}) failed: {}",
316            if_index,
317            std::io::Error::last_os_error()
318        )));
319    }
320
321    let name_len = name_buf
322        .iter()
323        .position(|&b| b == 0)
324        .unwrap_or(name_buf.len());
325    let copy_len = name_len.min(libc::IFNAMSIZ - 1);
326    unsafe {
327        std::ptr::copy_nonoverlapping(
328            name_buf.as_ptr(),
329            ifr.ifr_name.as_mut_ptr() as *mut u8,
330            copy_len,
331        );
332    }
333
334    #[cfg(target_env = "musl")]
335    let ioctl_req = libc::SIOCGIFMTU as libc::c_int;
336    #[cfg(not(target_env = "musl"))]
337    let ioctl_req = libc::SIOCGIFMTU as libc::c_ulong;
338    let ret = unsafe { libc::ioctl(fd, ioctl_req, &ifr) };
339    if ret < 0 {
340        return Err(TransportError::StartFailed(format!(
341            "ioctl(SIOCGIFMTU) failed: {}",
342            std::io::Error::last_os_error()
343        )));
344    }
345
346    let mtu = unsafe { ifr.ifr_ifru.ifru_mtu } as u16;
347    Ok(mtu)
348}