use crate::transport::TransportError;
/// Ethernet broadcast destination address (`ff:ff:ff:ff:ff:ff`).
pub const ETHERNET_BROADCAST: [u8; 6] = [0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF];
// Per-OS backend for the raw packet socket. The source file is chosen at
// compile time via `#[path]`; exactly one `platform` module exists per target.
#[cfg(target_os = "linux")]
#[path = "socket_linux.rs"]
mod platform;
#[cfg(target_os = "macos")]
#[path = "socket_macos.rs"]
mod platform;
// Re-export the backend's socket type so callers do not depend on which
// platform file provided it.
#[cfg(any(target_os = "linux", target_os = "macos"))]
pub use platform::PacketSocket;
#[cfg(target_os = "linux")]
mod async_impl {
    //! Linux async adapter: the packet socket's descriptor is registered with
    //! tokio's reactor through [`AsyncFd`], and every operation waits for
    //! readiness before issuing the underlying syscall.

    use super::PacketSocket;
    use crate::transport::TransportError;
    use tokio::io::unix::AsyncFd;

    /// Async wrapper around a [`PacketSocket`] driven by tokio readiness events.
    pub struct AsyncPacketSocket {
        inner: AsyncFd<PacketSocket>,
    }

    impl AsyncPacketSocket {
        /// Registers `socket` with tokio's reactor.
        ///
        /// NOTE(review): `AsyncFd` relies on the descriptor being non-blocking;
        /// presumably `PacketSocket` configures that — confirm in socket_linux.rs.
        ///
        /// # Errors
        /// `TransportError::StartFailed` if `AsyncFd::new` fails.
        pub fn new(socket: PacketSocket) -> Result<Self, TransportError> {
            AsyncFd::new(socket)
                .map(|registered| Self { inner: registered })
                .map_err(|e| TransportError::StartFailed(format!("AsyncFd::new failed: {}", e)))
        }

        /// Sends `data` as a frame addressed to `dest_mac`, returning the
        /// number of bytes the socket accepted.
        ///
        /// # Errors
        /// `TransportError::SendFailed` if waiting for writability fails or the
        /// send itself returns an error other than `WouldBlock`.
        pub async fn send_to(
            &self,
            data: &[u8],
            dest_mac: &[u8; 6],
        ) -> Result<usize, TransportError> {
            loop {
                let mut ready = self.inner.writable().await.map_err(|e| {
                    TransportError::SendFailed(format!("writable wait: {}", e))
                })?;
                // `try_io` yields `Err` (and clears readiness) only when the
                // closure reported `WouldBlock`; then we simply wait again.
                if let Ok(outcome) = ready.try_io(|fd| fd.get_ref().send_to(data, dest_mac)) {
                    return outcome.map_err(|e| TransportError::SendFailed(format!("{}", e)));
                }
            }
        }

        /// Receives one frame into `buf`, returning its length and the MAC
        /// address reported by the platform socket.
        ///
        /// # Errors
        /// `TransportError::RecvFailed` if waiting for readability fails or the
        /// receive itself returns an error other than `WouldBlock`.
        pub async fn recv_from(&self, buf: &mut [u8]) -> Result<(usize, [u8; 6]), TransportError> {
            loop {
                let mut ready = self.inner.readable().await.map_err(|e| {
                    TransportError::RecvFailed(format!("readable wait: {}", e))
                })?;
                // Spurious wakeup (`WouldBlock`) falls through to another wait.
                if let Ok(outcome) = ready.try_io(|fd| fd.get_ref().recv_from(buf)) {
                    return outcome.map_err(|e| TransportError::RecvFailed(format!("{}", e)));
                }
            }
        }

        /// Borrows the wrapped platform socket.
        pub fn get_ref(&self) -> &PacketSocket {
            self.inner.get_ref()
        }

        /// No-op on Linux: this implementation owns no background thread to stop.
        pub fn shutdown(&self) {}
    }
}
#[cfg(target_os = "macos")]
mod async_impl {
    //! macOS async adapter.
    //!
    //! Rather than registering the BPF descriptor with tokio (the Linux variant
    //! uses `AsyncFd`), a dedicated `bpf-reader` OS thread blocks in
    //! `select`/`read`. It splits each BPF buffer into frames with
    //! `platform::parse_next_frame` and forwards them over a bounded tokio
    //! channel. Sends run on tokio's blocking thread pool.

    use super::PacketSocket;
    use crate::transport::TransportError;
    use std::os::unix::io::AsRawFd;
    use std::sync::Arc;

    /// A received frame: its bytes plus the 6-byte MAC address the platform
    /// parser reported alongside it.
    type Frame = (Vec<u8>, [u8; 6]);

    /// Number of parsed frames that may be queued between the reader thread
    /// and `recv_from` before the reader thread blocks in `blocking_send`.
    const FRAME_QUEUE_DEPTH: usize = 1024;

    /// Async wrapper around a BPF-backed [`PacketSocket`].
    pub struct AsyncPacketSocket {
        /// Shared with the reader thread, which reads from the same socket.
        inner: Arc<PacketSocket>,
        /// Frames produced by the reader thread. Wrapped in an async mutex so
        /// `recv_from` can take `&self`.
        rx: tokio::sync::Mutex<tokio::sync::mpsc::Receiver<Frame>>,
        /// Taken and joined in `Drop`.
        reader_thread: Option<std::thread::JoinHandle<()>>,
    }

    impl AsyncPacketSocket {
        /// Takes ownership of `socket` and starts the `bpf-reader` thread.
        ///
        /// # Errors
        /// `TransportError::StartFailed` if the reader thread cannot be spawned.
        pub fn new(socket: PacketSocket) -> Result<Self, TransportError> {
            let (tx, rx) = tokio::sync::mpsc::channel::<Frame>(FRAME_QUEUE_DEPTH);
            let inner = Arc::new(socket);
            let reader_socket = Arc::clone(&inner);
            let reader_thread = std::thread::Builder::new()
                .name("bpf-reader".into())
                .spawn(move || reader_loop(&reader_socket, tx))
                .map_err(|e| TransportError::StartFailed(format!("reader thread: {}", e)))?;
            Ok(Self {
                inner,
                rx: tokio::sync::Mutex::new(rx),
                reader_thread: Some(reader_thread),
            })
        }

        /// Sends `data` as a frame addressed to `dest_mac` on tokio's blocking
        /// pool and returns the number of bytes written.
        ///
        /// `data` is copied so that the blocking task can own it.
        ///
        /// # Errors
        /// `TransportError::SendFailed` if the socket send fails, or if the
        /// blocking task panics or is cancelled.
        pub async fn send_to(
            &self,
            data: &[u8],
            dest_mac: &[u8; 6],
        ) -> Result<usize, TransportError> {
            let socket = Arc::clone(&self.inner);
            let data = data.to_vec();
            let dest = *dest_mac;
            tokio::task::spawn_blocking(move || {
                socket
                    .send_to(&data, &dest)
                    .map_err(|e| TransportError::SendFailed(format!("{}", e)))
            })
            .await
            .map_err(|e| TransportError::SendFailed(format!("spawn_blocking: {}", e)))?
        }

        /// Waits for the next frame from the reader thread and copies it into `buf`.
        ///
        /// Frames longer than `buf` are silently truncated. The returned length
        /// is the number of bytes actually copied. Concurrent callers are
        /// serialized by the receiver mutex.
        ///
        /// # Errors
        /// `TransportError::RecvFailed` once the reader thread has exited and
        /// every queued frame has been consumed.
        pub async fn recv_from(&self, buf: &mut [u8]) -> Result<(usize, [u8; 6]), TransportError> {
            let mut rx = self.rx.lock().await;
            match rx.recv().await {
                Some((data, mac)) => {
                    let n = data.len().min(buf.len());
                    buf[..n].copy_from_slice(&data[..n]);
                    Ok((n, mac))
                }
                None => Err(TransportError::RecvFailed("reader thread stopped".into())),
            }
        }

        /// Borrows the wrapped platform socket.
        pub fn get_ref(&self) -> &PacketSocket {
            &self.inner
        }

        /// Calls `PacketSocket::request_shutdown`.
        ///
        /// The reader thread exits when `shutdown_read_fd` becomes readable.
        /// Presumably that is what `request_shutdown` triggers; confirm in
        /// socket_macos.rs. Frames already queued can still be received.
        pub fn shutdown(&self) {
            self.inner.request_shutdown();
        }
    }

    impl Drop for AsyncPacketSocket {
        fn drop(&mut self) {
            self.inner.request_shutdown();
            // The reader thread may be parked in `blocking_send` on a full
            // channel. The shutdown descriptor cannot wake it there, and `rx`
            // itself is only dropped after this method returns. Closing the
            // receiver fails any pending or future send, so the thread returns
            // and the join below cannot hang.
            self.rx.get_mut().close();
            if let Some(handle) = self.reader_thread.take() {
                let _ = handle.join();
            }
        }
    }

    /// Body of the `bpf-reader` thread.
    ///
    /// Each iteration does three things:
    /// 1. Forwards every frame `parse_next_frame` can extract from the current
    ///    buffer.
    /// 2. Blocks in `select` until the BPF or shutdown descriptor is readable.
    /// 3. Refills the buffer with one `read`.
    ///
    /// Returns when:
    /// - shutdown is signalled;
    /// - the channel's receiving side is gone or closed;
    /// - `select` fails with anything but `EINTR`;
    /// - `read` fails with anything but `EINTR`/`EAGAIN`.
    ///
    /// Persistent read errors previously made this loop spin at full CPU.
    /// Exiting instead drops `tx`, so `recv_from` reports the stop.
    fn reader_loop(socket: &PacketSocket, tx: tokio::sync::mpsc::Sender<Frame>) {
        let bpf_fd = socket.as_raw_fd();
        let shutdown_fd = socket.shutdown_read_fd();
        let bpf_buflen = socket.bpf_buflen();
        let mut read_buf = vec![0u8; bpf_buflen];
        let mut parse_buf = vec![0u8; bpf_buflen];
        let mut parse_offset: usize = 0;
        let mut parse_len: usize = 0;
        let nfds = bpf_fd.max(shutdown_fd) + 1;

        loop {
            while let Some(result) = super::platform::parse_next_frame(
                &parse_buf,
                &mut parse_offset,
                parse_len,
                &mut read_buf,
            ) {
                match result {
                    Ok((n, mac)) => {
                        let data = read_buf[..n].to_vec();
                        if tx.blocking_send((data, mac)).is_err() {
                            // Receiver dropped or closed: nobody wants frames.
                            return;
                        }
                    }
                    // Unparseable remainder: abandon it; the next read below
                    // overwrites the buffer and resets the offset.
                    Err(_) => break,
                }
            }

            // SAFETY: `fd_set` is a plain C bitmap for which all-zero bytes are
            // a valid value, and `select` only writes into the local `read_fds`.
            // Both descriptors come from `socket`, which is borrowed for the
            // whole loop.
            // NOTE(review): an `fd_set` can only hold descriptors below
            // FD_SETSIZE. Confirm BPF/shutdown fds can't exceed it.
            let shutdown_requested = unsafe {
                let mut read_fds: libc::fd_set = std::mem::zeroed();
                libc::FD_ZERO(&mut read_fds);
                libc::FD_SET(bpf_fd, &mut read_fds);
                libc::FD_SET(shutdown_fd, &mut read_fds);
                let ret = libc::select(
                    nfds,
                    &mut read_fds,
                    std::ptr::null_mut(),
                    std::ptr::null_mut(),
                    std::ptr::null_mut(),
                );
                if ret < 0 {
                    if std::io::Error::last_os_error().kind() == std::io::ErrorKind::Interrupted {
                        continue;
                    }
                    break;
                }
                libc::FD_ISSET(shutdown_fd, &read_fds)
            };
            if shutdown_requested {
                break;
            }

            // SAFETY: `parse_buf` is exactly `bpf_buflen` bytes long and is not
            // otherwise borrowed during the call.
            let ret = unsafe {
                libc::read(
                    bpf_fd,
                    parse_buf.as_mut_ptr() as *mut libc::c_void,
                    bpf_buflen,
                )
            };
            // Whatever the outcome, leftovers from the previous read are stale.
            parse_offset = 0;
            match usize::try_from(ret) {
                // Includes 0 bytes: nothing to parse, go back to `select`.
                Ok(len) => parse_len = len,
                Err(_) => {
                    parse_len = 0;
                    let err = std::io::Error::last_os_error();
                    match err.kind() {
                        std::io::ErrorKind::Interrupted | std::io::ErrorKind::WouldBlock => {}
                        // EBADF (descriptor closed) and any other persistent
                        // failure: retrying would spin, so stop the thread.
                        _ => break,
                    }
                }
            }
        }
    }
}
// Re-export the platform's async wrapper under a single name.
#[cfg(any(target_os = "linux", target_os = "macos"))]
pub use async_impl::AsyncPacketSocket;

#[cfg(any(target_os = "linux", target_os = "macos"))]
impl PacketSocket {
    /// Consumes this socket and wraps it in an [`AsyncPacketSocket`].
    ///
    /// # Errors
    /// Propagates whatever `AsyncPacketSocket::new` returns.
    pub fn into_async(self) -> Result<AsyncPacketSocket, TransportError> {
        let wrapped = AsyncPacketSocket::new(self)?;
        Ok(wrapped)
    }
}
/// Windows placeholder: no packet I/O is implemented for this target in this
/// file. Presumably it exists so code naming the type still compiles on
/// Windows — confirm.
#[cfg(windows)]
pub struct PacketSocket;
/// Windows placeholder for the async wrapper; like `PacketSocket` here, it
/// has no methods in this file.
#[cfg(windows)]
pub struct AsyncPacketSocket;