use super::error::CommonError;
use crate::libc_call;
use crate::time::DateTime;
use bebytes::BeBytes;
use libc::iovec;
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr};
use std::os::fd::{AsRawFd, RawFd};
/// Default size in bytes of each per-message data buffer handed to `init_vec_of_mmsghdr`.
pub const DEFAULT_BUFFER_SIZE: usize = 4 * 1024;
/// Size in bytes of the ancillary-data (control message) buffer attached to each `msghdr`.
const CMSG_SPACE_SIZE: usize = 128;
pub trait Socket<T: AsRawFd>: Sized + AsRawFd {
unsafe fn from_raw_fd(fd: RawFd) -> T;
fn send(&self, message: impl BeBytes) -> Result<(isize, DateTime), CommonError>;
fn send_to(
&self,
address: &SocketAddr,
message: impl BeBytes,
) -> Result<(isize, DateTime), CommonError>;
fn receive(&self, buffer: &mut [u8]) -> Result<(isize, DateTime), CommonError>;
fn receive_from(&self, buffer: &mut [u8])
-> Result<(isize, SocketAddr, DateTime), CommonError>;
fn set_socket_options(
&mut self,
level: i32,
name: i32,
value: Option<i32>,
) -> Result<i32, CommonError> {
libc_call!(setsockopt(
self.as_raw_fd(),
level,
name,
&value.unwrap_or(0) as *const std::ffi::c_int as *const std::ffi::c_void,
std::mem::size_of_val(&value) as libc::socklen_t
))
.map_err(CommonError::Io)
}
fn set_fcntl_options(&self) -> Result<i32, CommonError> {
let flags = libc_call!(fcntl(self.as_raw_fd(), libc::F_GETFL)).map_err(CommonError::Io)?;
let new_flags = flags | libc::O_NONBLOCK | libc::O_CLOEXEC;
libc_call!(fcntl(self.as_raw_fd(), libc::F_SETFL, new_flags)).map_err(CommonError::Io)
}
fn set_timestamping_options(&mut self) -> Result<i32, CommonError> {
let value = libc::SOF_TIMESTAMPING_SOFTWARE
| libc::SOF_TIMESTAMPING_RX_SOFTWARE
| libc::SOF_TIMESTAMPING_TX_SOFTWARE;
self.set_socket_options(libc::SOL_SOCKET, libc::SO_TIMESTAMPING, Some(value as i32))
}
}
/// Converts `addr` into a C `sockaddr` plus the address length to hand to the kernel.
///
/// The address is laid out in a zeroed `sockaddr_storage` as a `sockaddr_in` (IPv4) or
/// `sockaddr_in6` (IPv6), with the port and IPv4 address in network byte order. The
/// first `size_of::<libc::sockaddr>()` bytes of that storage are returned by value.
///
/// # IPv6 limitation
///
/// A `libc::sockaddr` is smaller than a `libc::sockaddr_in6`. For IPv6, the returned
/// struct therefore holds only a prefix: family, port, flow info and the first 8 bytes of
/// the address. The returned length is still `size_of::<sockaddr_in6>()`, so passing
/// `&result.0` together with `result.1` to a syscall lets the kernel read past the struct.
/// Only the IPv4 result is complete; IPv6 callers need a `sockaddr_storage`-sized buffer.
pub fn socketaddr_to_sockaddr(addr: &SocketAddr) -> (libc::sockaddr, u32) {
    // SAFETY: sockaddr_storage is a plain C struct; all-zero bytes is a valid value.
    let mut storage: libc::sockaddr_storage = unsafe { core::mem::zeroed() };
    log::debug!("addr: {}", addr);
    let sock_addr_len = match addr {
        SocketAddr::V4(a) => {
            let sin = &mut storage as *mut libc::sockaddr_storage as *mut libc::sockaddr_in;
            // SAFETY: sockaddr_storage is large and aligned enough for any sockaddr_* type,
            // and `sin` points at the local `storage`.
            unsafe {
                (*sin).sin_family = libc::AF_INET as libc::sa_family_t;
                (*sin).sin_port = a.port().to_be();
                // octets() is already in network order; from_ne_bytes keeps that byte order.
                (*sin).sin_addr.s_addr = u32::from_ne_bytes(a.ip().octets());
            }
            core::mem::size_of::<libc::sockaddr_in>() as libc::socklen_t
        }
        SocketAddr::V6(a) => {
            let sin6 = &mut storage as *mut libc::sockaddr_storage as *mut libc::sockaddr_in6;
            // SAFETY: as above — storage can hold a sockaddr_in6.
            unsafe {
                (*sin6).sin6_family = libc::AF_INET6 as libc::sa_family_t;
                (*sin6).sin6_port = a.port().to_be();
                (*sin6).sin6_addr.s6_addr.copy_from_slice(&a.ip().octets());
                (*sin6).sin6_flowinfo = a.flowinfo();
                (*sin6).sin6_scope_id = a.scope_id();
            }
            core::mem::size_of::<libc::sockaddr_in6>() as libc::socklen_t
        }
    };
    // SAFETY: storage is at least size_of::<sockaddr>() bytes and fully initialised.
    let sock_addr =
        unsafe { (&storage as *const libc::sockaddr_storage as *const libc::sockaddr).read() };
    (sock_addr, sock_addr_len)
}
/// Converts a kernel-filled `sockaddr_storage` into a [`SocketAddr`].
///
/// `ss_family` selects the interpretation: `AF_INET` is read as `sockaddr_in`, and
/// `AF_INET6` as `sockaddr_in6`. For IPv6, flow info and scope id are preserved.
///
/// # Errors
///
/// [`CommonError::UnknownAddressFamily`] for any other family.
pub fn storage_to_socket_addr(
    addr_storage: &libc::sockaddr_storage,
) -> Result<SocketAddr, CommonError> {
    match addr_storage.ss_family as i32 {
        libc::AF_INET => {
            // SAFETY: ss_family says the storage holds a sockaddr_in, and sockaddr_storage
            // is large and aligned enough for it. The cast is applied to the storage
            // itself, not to a reference to it.
            let sin = unsafe {
                &*(addr_storage as *const libc::sockaddr_storage as *const libc::sockaddr_in)
            };
            // s_addr holds the octets in network order in memory; the native-endian byte
            // view yields them in order on any host endianness.
            let ip = Ipv4Addr::from(sin.sin_addr.s_addr.to_ne_bytes());
            Ok(SocketAddr::new(IpAddr::V4(ip), u16::from_be(sin.sin_port)))
        }
        libc::AF_INET6 => {
            // SAFETY: ss_family says the storage holds a sockaddr_in6; see above.
            let sin6 = unsafe {
                &*(addr_storage as *const libc::sockaddr_storage as *const libc::sockaddr_in6)
            };
            Ok(SocketAddr::V6(std::net::SocketAddrV6::new(
                Ipv6Addr::from(sin6.sin6_addr.s6_addr),
                u16::from_be(sin6.sin6_port),
                sin6.sin6_flowinfo,
                sin6.sin6_scope_id,
            )))
        }
        _ => Err(CommonError::UnknownAddressFamily),
    }
}
/// Backing store for a `msghdr` control buffer. The `u64` elements give the buffer the
/// 8-byte alignment `cmsghdr` needs, and the length rounds `CMSG_SPACE_SIZE` up to whole
/// words.
type MsgControlBuf = [u64; (CMSG_SPACE_SIZE + 7) / 8];

/// Writes `address` into a zeroed `sockaddr_storage` as a `sockaddr_in`/`sockaddr_in6`.
/// Returns the storage together with the size of that concrete type.
fn sockaddr_storage_from(address: &SocketAddr) -> (libc::sockaddr_storage, libc::socklen_t) {
    // SAFETY: these are plain C structs for which all-zero bytes is a valid value.
    let mut storage: libc::sockaddr_storage = unsafe { core::mem::zeroed() };
    let len = match address {
        SocketAddr::V4(a) => {
            let mut sin: libc::sockaddr_in = unsafe { core::mem::zeroed() };
            sin.sin_family = libc::AF_INET as libc::sa_family_t;
            sin.sin_port = a.port().to_be();
            // octets() is already in network order; from_ne_bytes keeps that byte order.
            sin.sin_addr.s_addr = u32::from_ne_bytes(a.ip().octets());
            // SAFETY: sockaddr_storage is at least as large and as aligned as sockaddr_in.
            unsafe {
                (&mut storage as *mut libc::sockaddr_storage)
                    .cast::<libc::sockaddr_in>()
                    .write(sin)
            };
            core::mem::size_of::<libc::sockaddr_in>()
        }
        SocketAddr::V6(a) => {
            let mut sin6: libc::sockaddr_in6 = unsafe { core::mem::zeroed() };
            sin6.sin6_family = libc::AF_INET6 as libc::sa_family_t;
            sin6.sin6_port = a.port().to_be();
            sin6.sin6_flowinfo = a.flowinfo();
            sin6.sin6_addr.s6_addr = a.ip().octets();
            sin6.sin6_scope_id = a.scope_id();
            // SAFETY: sockaddr_storage is at least as large and as aligned as sockaddr_in6.
            unsafe {
                (&mut storage as *mut libc::sockaddr_storage)
                    .cast::<libc::sockaddr_in6>()
                    .write(sin6)
            };
            core::mem::size_of::<libc::sockaddr_in6>()
        }
    };
    (storage, len as libc::socklen_t)
}

/// Builds a `msghdr` describing one datagram.
///
/// The header uses `bytes` as its single I/O vector and `address` as the peer name. It
/// also carries a zeroed ancillary-data buffer of `CMSG_SPACE_SIZE` bytes.
///
/// The returned header points into `bytes`, which the caller must keep alive and unmoved
/// while the header is in use. It also points into three heap allocations made here: the
/// `iovec`, a `sockaddr_storage` for the name, and the control buffer. These are handed
/// over to the caller so the pointers stay valid after this function returns. Release
/// them with [`free_msghdr`] once the header, and every copy of it, is no longer used.
/// Otherwise they are leaked.
///
/// `msg_namelen` is the size of the concrete address type (`sockaddr_in` or
/// `sockaddr_in6`), so IPv6 peers are not truncated.
///
/// NOTE(review): a zero-filled control buffer with non-zero `msg_controllen` is not a
/// valid cmsg list. If the header is used for `sendmsg`/`sendmmsg` without ancillary
/// data, set `msg_controllen` to 0 first — confirm against callers.
pub fn to_msghdr(bytes: &mut [u8], address: &mut SocketAddr) -> libc::msghdr {
    let (storage, name_len) = sockaddr_storage_from(address);
    let name = Box::into_raw(Box::new(storage));
    let iov = Box::into_raw(Box::new(iovec {
        iov_base: bytes.as_mut_ptr() as *mut libc::c_void,
        iov_len: bytes.len(),
    }));
    let control: *mut MsgControlBuf = Box::into_raw(Box::new([0u64; (CMSG_SPACE_SIZE + 7) / 8]));

    // Start from zero so any private padding fields in msghdr (present on some
    // targets) are initialised.
    // SAFETY: msghdr is a plain C struct; all-zero (null pointers, zero lengths) is valid.
    let mut hdr: libc::msghdr = unsafe { core::mem::zeroed() };
    hdr.msg_name = name as *mut libc::c_void;
    hdr.msg_namelen = name_len;
    hdr.msg_iov = iov;
    // Exactly one iovec is attached (not the byte size of the iovec).
    hdr.msg_iovlen = 1;
    hdr.msg_control = control as *mut libc::c_void;
    hdr.msg_controllen = CMSG_SPACE_SIZE as _;
    hdr.msg_flags = 0;
    hdr
}

/// Releases the heap allocations made by [`to_msghdr`] and resets the affected pointers
/// and lengths to null/zero. The caller-provided data buffer is not touched.
///
/// # Safety
///
/// `hdr` must have been produced by [`to_msghdr`], with its `msg_name`, `msg_iov` and
/// `msg_control` pointers unmodified (or already nulled by this function). No other copy
/// of the header may be used afterwards, because copies share the same allocations.
pub unsafe fn free_msghdr(hdr: &mut libc::msghdr) {
    if !hdr.msg_name.is_null() {
        // SAFETY: per the contract, msg_name came from Box::<sockaddr_storage>::into_raw.
        drop(unsafe { Box::from_raw(hdr.msg_name as *mut libc::sockaddr_storage) });
        hdr.msg_name = core::ptr::null_mut();
        hdr.msg_namelen = 0;
    }
    if !hdr.msg_iov.is_null() {
        // SAFETY: per the contract, msg_iov came from Box::<iovec>::into_raw.
        drop(unsafe { Box::from_raw(hdr.msg_iov) });
        hdr.msg_iov = core::ptr::null_mut();
        hdr.msg_iovlen = 0;
    }
    if !hdr.msg_control.is_null() {
        // SAFETY: per the contract, msg_control came from Box::<MsgControlBuf>::into_raw.
        drop(unsafe { Box::from_raw(hdr.msg_control as *mut MsgControlBuf) });
        hdr.msg_control = core::ptr::null_mut();
        hdr.msg_controllen = 0;
    }
}
/// Extracts one timestamp per entry of `msg_hdrs`, in order, using
/// `retrieve_data_from_header` on each entry's inner `msghdr`.
///
/// Each entry's `msg_name` pointer is logged at trace level before it is inspected.
///
/// # Errors
///
/// Stops at, and returns, the first error produced for any entry.
pub fn retrieve_data_from_headers(
    msg_hdrs: Vec<libc::mmsghdr>,
) -> Result<Vec<DateTime>, CommonError> {
    msg_hdrs
        .iter()
        .map(|entry| {
            log::trace!("msg_hdr: {:?}", entry.msg_hdr.msg_name);
            retrieve_data_from_header(&entry.msg_hdr)
        })
        .collect()
}
/// Walks the control messages of `msg_hdr` and returns the first timestamp carried by the
/// last `SOL_SOCKET`/`SCM_TIMESTAMPING` message found.
///
/// An `IPPROTO_IP`/`IP_TOS` message, if present, is only logged at debug level. Messages
/// whose payload is too short for the expected data are skipped.
///
/// The caller must ensure `msg_control`/`msg_controllen` describe a readable buffer (for
/// example one filled by `recvmsg`/`recvmmsg`).
///
/// # Errors
///
/// `CommonError::Generic("No tx timestamp found")` when no usable `SCM_TIMESTAMPING`
/// message is present.
pub fn retrieve_data_from_header(msg_hdr: &libc::msghdr) -> Result<DateTime, CommonError> {
    const TIMESTAMPS_LEN: usize = core::mem::size_of::<[libc::timespec; 3]>();

    let mut result = Err(CommonError::Generic("No tx timestamp found".to_string()));
    // SAFETY: CMSG_FIRSTHDR only inspects msg_control/msg_controllen and returns null when
    // the buffer has no room for a header.
    let mut cmsg_ptr = unsafe { libc::CMSG_FIRSTHDR(msg_hdr) };
    while !cmsg_ptr.is_null() {
        // The control buffer is not guaranteed to be aligned for cmsghdr/timespec, so the
        // header and payloads are copied out with unaligned reads.
        // SAFETY: a non-null pointer from CMSG_FIRSTHDR/CMSG_NXTHDR lies inside the control
        // buffer with room for a full cmsghdr.
        let cmsg = unsafe { core::ptr::read_unaligned(cmsg_ptr) };
        // SAFETY: CMSG_DATA only performs pointer arithmetic on a valid cmsg pointer.
        let data = unsafe { libc::CMSG_DATA(cmsg_ptr) };
        let header_len = data as usize - cmsg_ptr as usize;
        let payload_len = (cmsg.cmsg_len as usize).saturating_sub(header_len);

        match (cmsg.cmsg_level, cmsg.cmsg_type) {
            (libc::SOL_SOCKET, libc::SCM_TIMESTAMPING) if payload_len >= TIMESTAMPS_LEN => {
                // SAFETY: cmsg_len says the payload holds at least three timespecs.
                let ts = unsafe { core::ptr::read_unaligned(data as *const [libc::timespec; 3]) };
                result = Ok(DateTime::from_timespec(ts[0]));
                log::debug!("Timestamp: {:?}", result);
            }
            (libc::IPPROTO_IP, libc::IP_TOS) if payload_len >= 1 => {
                // SAFETY: cmsg_len says the payload holds at least one byte.
                let tos_value: u8 = unsafe { *data };
                log::debug!("TOS value: {}", tos_value);
            }
            _ => {}
        }

        // SAFETY: cmsg_ptr is a valid header of msg_hdr; CMSG_NXTHDR returns null at the end.
        cmsg_ptr = unsafe { libc::CMSG_NXTHDR(msg_hdr, cmsg_ptr) };
    }
    result
}
/// Builds `max_msg` zero-initialised `mmsghdr`s and fills their `msg_hdr` via `to_msghdr`.
///
/// Entry `i` is paired with `msg_buffers[i]` and `addresses[i % addresses.len()]`, so
/// addresses are reused round-robin. Only the first `min(max_msg, msg_buffers.len())`
/// entries are filled. Any entries beyond `msg_buffers.len()` stay all-zero, with no
/// buffer and no name.
///
/// # Panics
///
/// Panics if at least one entry is to be filled while `addresses` is empty.
pub fn init_vec_of_mmsghdr(
    max_msg: usize,
    msg_buffers: &mut [[u8; DEFAULT_BUFFER_SIZE]],
    addresses: &mut [SocketAddr],
) -> Vec<libc::mmsghdr> {
    let filled = max_msg.min(msg_buffers.len());
    assert!(
        filled == 0 || !addresses.is_empty(),
        "init_vec_of_mmsghdr: `addresses` must not be empty when there are buffers to fill"
    );
    // SAFETY: mmsghdr is a plain C struct; all-zero (null pointers, zero lengths) is valid.
    let mut msgvec: Vec<libc::mmsghdr> = vec![unsafe { core::mem::zeroed() }; max_msg];
    for (i, (msg, buffer)) in msgvec.iter_mut().zip(msg_buffers.iter_mut()).enumerate() {
        let address = &mut addresses[i % addresses.len()];
        msg.msg_hdr = to_msghdr(buffer, address);
    }
    msgvec
}