use bebytes::BeBytes;
use libc::{iovec, mmsghdr, msghdr, recvmsg, sendmsg, sockaddr_storage, timespec};
use std::io::{self, IoSliceMut};
use std::os::fd::{AsRawFd, RawFd};
use std::ptr;
use std::{io::IoSlice, net::SocketAddr, ops::Deref};
use crate::error::CommonError;
use crate::libc_call;
use crate::socket::{
init_vec_of_mmsghdr, retrieve_data_from_header, socketaddr_to_sockaddr, storage_to_socket_addr,
to_msghdr, Socket, DEFAULT_BUFFER_SIZE,
};
use crate::time::DateTime;
const MAX_MSG: usize = 2;
const CMSG_SPACE_SIZE: usize = 128;
pub struct TimestampedUdpSocket {
inner: RawFd,
}
impl Drop for TimestampedUdpSocket {
fn drop(&mut self) {
unsafe { libc::close(self.inner) };
}
}
impl AsRawFd for TimestampedUdpSocket {
fn as_raw_fd(&self) -> RawFd {
self.inner
}
}
impl From<&mut i32> for TimestampedUdpSocket {
fn from(value: &mut i32) -> Self {
Self::new(value.as_raw_fd())
}
}
impl Deref for TimestampedUdpSocket {
type Target = RawFd;
fn deref(&self) -> &Self::Target {
&self.inner
}
}
impl TimestampedUdpSocket {
pub fn new(socket: RawFd) -> Self {
Self { inner: socket }
}
pub fn bind(addr: &SocketAddr) -> Result<Self, CommonError> {
#[cfg(target_os = "linux")]
let socket_fd = match addr {
SocketAddr::V4(_) => unsafe { libc::socket(libc::AF_INET, libc::SOCK_DGRAM, 0) },
SocketAddr::V6(_) => unsafe { libc::socket(libc::AF_INET6, libc::SOCK_DGRAM, 0) },
};
if socket_fd < 0 {
return Err(CommonError::SocketCreateFailed(io::Error::last_os_error()));
}
let (sock_addr, sock_addr_len) = socketaddr_to_sockaddr(addr);
let sock_addr_ptr = &sock_addr as *const _;
if unsafe { libc::bind(socket_fd, sock_addr_ptr, sock_addr_len) } < 0 {
return Err(CommonError::SocketBindFailed(io::Error::last_os_error()));
}
Ok(Self { inner: socket_fd })
}
pub fn connect(&self, address: SocketAddr) -> Result<i32, CommonError> {
let (addr, len) = socketaddr_to_sockaddr(&address);
let res = libc_call!(connect(self.inner, &addr as *const _ as *const _, len))
.map_err(CommonError::Io)?;
Ok(res)
}
pub fn receive_from_multiple(
&self,
buffers: &mut [[u8; DEFAULT_BUFFER_SIZE]],
num_messages: usize,
) -> Result<Vec<(usize, SocketAddr, DateTime)>, CommonError> {
let fd = self.as_raw_fd();
let mut msg_hdrs: Vec<mmsghdr> = Vec::new();
for buffer in buffers.iter_mut() {
let mut addr_storage: SocketAddr = unsafe { std::mem::zeroed() };
let buffer_ptr = buffer.as_mut_ptr();
let msg_iov = iovec {
iov_base: buffer_ptr as *mut libc::c_void,
iov_len: buffer.len(),
};
let msg_hdr = msghdr {
msg_name: &mut addr_storage as *mut _ as *mut libc::c_void,
msg_namelen: std::mem::size_of_val(&addr_storage) as u32,
msg_iov: &msg_iov as *const _ as *mut _,
msg_iovlen: core::mem::size_of_val(&msg_iov),
msg_control: [0; CMSG_SPACE_SIZE].as_mut_ptr() as *mut libc::c_void,
msg_controllen: CMSG_SPACE_SIZE,
msg_flags: 0,
};
msg_hdrs.push(mmsghdr {
msg_hdr,
msg_len: std::mem::size_of::<msghdr>() as u32,
});
}
log::trace!("iov ptr {:?}", msg_hdrs[0].msg_hdr.msg_iov);
let (mut timestamp, result) = match recvmmsg_timestamped(fd, &mut msg_hdrs, num_messages) {
Ok(value) => value,
Err(e) => {
log::debug!("Error receiving multiple messages: {:?}", e);
return Err(e);
}
};
log::trace!("iov ptr {:?}", msg_hdrs[0].msg_hdr.msg_iov);
let mut received_data = Vec::new();
for mmsg_hdr in msg_hdrs.iter().take(result as usize) {
let socket_addr = storage_to_socket_addr(unsafe {
&*(mmsg_hdr.msg_hdr.msg_name as *const libc::sockaddr_storage)
})?;
if let Ok(datetime) = retrieve_data_from_header(&mmsg_hdr.msg_hdr) {
timestamp = datetime;
log::debug!("Timestamp {:?} from {:?}", timestamp, socket_addr);
};
received_data.push((mmsg_hdr.msg_len as usize, socket_addr, timestamp));
}
Ok(received_data)
}
pub fn retrieve_tx_timestamps(
&mut self,
addresses: &mut [SocketAddr],
) -> Result<Vec<DateTime>, CommonError> {
let mut timestamps = Vec::new();
let mut msg_buffers: [[u8; DEFAULT_BUFFER_SIZE]; MAX_MSG] = unsafe { core::mem::zeroed() };
let mut msgvec = init_vec_of_mmsghdr(MAX_MSG, &mut msg_buffers, addresses);
let res = unsafe {
libc::recvmmsg(
self.as_raw_fd(),
msgvec.as_mut_ptr(),
msgvec.len() as u32,
libc::MSG_ERRQUEUE,
std::ptr::null_mut::<timespec>(),
)
};
if res >= 0 {
for msg in &msgvec {
if let Ok(date_time) = retrieve_data_from_header(&msg.msg_hdr) {
timestamps.push(date_time);
}
}
Ok(timestamps)
} else {
let last_os_error = std::io::Error::last_os_error();
Err(CommonError::Io(last_os_error))
}
}
pub fn retrieve_tx_timestamp(&mut self) -> Result<(usize, SocketAddr, DateTime), CommonError> {
let mut msg_buffer = [0u8; DEFAULT_BUFFER_SIZE];
let mut address: SocketAddr = unsafe { core::mem::zeroed() };
let mut msgh = to_msghdr(&mut msg_buffer, &mut address);
#[cfg(target_os = "linux")]
{
let res = unsafe { libc::recvmsg(self.as_raw_fd(), &mut msgh, libc::MSG_ERRQUEUE) };
let socket_addr = storage_to_socket_addr(unsafe {
&*(msgh.msg_name as *const libc::sockaddr_storage)
})?;
if res >= 0 {
let datetime = retrieve_data_from_header(&msgh)?;
Ok((res as usize, socket_addr, datetime))
} else {
let err = std::io::Error::last_os_error();
Err(CommonError::Io(err))
}
}
}
}
fn recvmmsg_timestamped(
fd: i32,
msg_hdrs: &mut [mmsghdr],
num_messages: usize,
) -> Result<(DateTime, i32), CommonError> {
let timestamp = DateTime::utc_now();
let result = unsafe {
libc::recvmmsg(
fd,
msg_hdrs.as_mut_ptr(),
num_messages as u32,
0,
ptr::null::<timespec>() as *mut _,
)
};
log::trace!("First buffer {:?}", unsafe {
msg_hdrs[0].msg_hdr.msg_iov.read().iov_len
});
if result < 0 {
let last_os_error = io::Error::last_os_error();
log::debug!("Error receiving multiple messages: {:?}", last_os_error);
return Err(last_os_error.into());
}
Ok((timestamp, result))
}
impl Socket<TimestampedUdpSocket> for TimestampedUdpSocket {
unsafe fn from_raw_fd(fd: RawFd) -> TimestampedUdpSocket {
Self { inner: fd }
}
fn send(&self, buffer: impl BeBytes) -> Result<(isize, DateTime), CommonError> {
let data = buffer.to_be_bytes();
let length = data.len();
let timestamp = DateTime::utc_now();
let result = libc_call!(send(self.inner, data.as_ptr() as *const _, length, 0))
.map_err(CommonError::Io)?;
Ok((result, timestamp))
}
fn send_to(
&self,
address: &SocketAddr,
message: impl BeBytes,
) -> Result<(isize, DateTime), CommonError> {
let fd = self.as_raw_fd();
let bytes = message.to_be_bytes();
let iov = [IoSlice::new(&bytes)];
let (mut sock_addr, _len) = socketaddr_to_sockaddr(address);
log::trace!("Sending to {:?}", sock_addr.sa_data);
let msg = msghdr {
msg_name: &mut sock_addr as *mut _ as *mut libc::c_void,
msg_namelen: core::mem::size_of_val(&sock_addr) as u32,
msg_iov: iov.as_ptr() as *mut libc::iovec,
msg_iovlen: iov.len(),
msg_control: std::ptr::null_mut(),
msg_controllen: 0,
msg_flags: 0,
};
let utc_now = DateTime::utc_now();
let result = unsafe { sendmsg(fd, &msg, 0) };
Ok((result, utc_now))
}
fn receive(&self, _buffer: &mut [u8]) -> Result<(isize, DateTime), CommonError> {
unimplemented!()
}
fn receive_from(
&self,
buffer: &mut [u8],
) -> Result<(isize, SocketAddr, DateTime), CommonError> {
let fd = self.as_raw_fd();
let mut addr_storage: sockaddr_storage = unsafe { core::mem::zeroed() };
let iov = [IoSliceMut::new(buffer)];
let mut msg: msghdr = unsafe { core::mem::zeroed() };
msg.msg_name = &mut addr_storage as *mut _ as *mut libc::c_void;
msg.msg_namelen = core::mem::size_of_val(&addr_storage) as u32;
msg.msg_iov = iov.as_ptr() as *mut iovec;
msg.msg_iovlen = iov.len();
const SPACE_SIZE: usize = unsafe {
libc::CMSG_SPACE(core::mem::size_of::<libc::timeval>() as u32) as usize * MAX_MSG
};
let mut cmsg_space: [u8; SPACE_SIZE] = unsafe { core::mem::zeroed() };
msg.msg_control = cmsg_space.as_mut_ptr() as *mut libc::c_void;
msg.msg_controllen = cmsg_space.len();
let mut timestamp = DateTime::utc_now();
let n = unsafe { recvmsg(fd, &mut msg, 0) };
if n < 0 {
return Err(CommonError::Io(std::io::Error::last_os_error()));
}
let socket_addr =
storage_to_socket_addr(unsafe { &*(msg.msg_name as *const libc::sockaddr_storage) })?;
log::debug!("Socket address: {:?}", socket_addr);
if let Ok(date_time) = retrieve_data_from_header(&msg) {
timestamp = date_time;
log::debug!("Timestamp: {:?}", timestamp);
};
Ok((n, socket_addr, timestamp))
}
}
#[derive(BeBytes, PartialEq, Debug, Clone)]
struct Message {
pub data: Vec<u8>,
}