use bebytes::BeBytes;
use libc::MSG_NOSIGNAL;
use crate::{
socket::{socketaddr_to_sockaddr, storage_to_socket_addr, Socket},
time::DateTime,
CommonError,
};
use core::ops::Deref;
use std::{
io,
net::SocketAddr,
os::fd::{AsRawFd, RawFd},
};
/// Failure modes of the low-level socket setup steps, each carrying the OS error
/// that caused it.
#[derive(Debug)]
pub enum SocketError {
    /// `bind(2)` failed.
    BindFailed(io::Error),
    /// `listen(2)` failed.
    ListenFailed(io::Error),
    /// `accept(2)` failed.
    AcceptFailed(io::Error),
}
/// A TCP socket backed by a raw file descriptor.
///
/// The descriptor is owned by this value: the `Drop` impl closes it.
#[derive(Debug)]
pub struct TimestampedTcpSocket {
    /// The underlying socket file descriptor.
    inner: RawFd,
}
impl Drop for TimestampedTcpSocket {
    /// Closes the owned file descriptor. The result of `close(2)` is ignored
    /// because there is no way to report it from `drop`.
    fn drop(&mut self) {
        let fd = self.inner;
        // SAFETY: `fd` is owned by this value and is not used again after drop.
        unsafe {
            libc::close(fd);
        }
    }
}
impl AsRawFd for TimestampedTcpSocket {
    /// Returns the underlying descriptor without giving up ownership of it.
    fn as_raw_fd(&self) -> RawFd {
        let Self { inner } = self;
        *inner
    }
}
impl From<&mut i32> for TimestampedTcpSocket {
fn from(value: &mut i32) -> Self {
Self::new(value.as_raw_fd())
}
}
impl Deref for TimestampedTcpSocket {
    type Target = RawFd;

    /// Exposes a reference to the raw descriptor; ownership stays with `self`.
    fn deref(&self) -> &RawFd {
        let fd_ref: &RawFd = &self.inner;
        fd_ref
    }
}
impl TimestampedTcpSocket {
pub fn new(socket: RawFd) -> Self {
unsafe {
libc::setsockopt(
socket.as_raw_fd(),
libc::SOL_SOCKET,
libc::SO_REUSEADDR,
1 as *const _,
std::mem::size_of::<i32>() as u32,
);
}
Self { inner: socket }
}
pub fn bind(addr: &SocketAddr) -> Result<Self, CommonError> {
let socket_fd = match addr {
SocketAddr::V4(_) => unsafe { libc::socket(libc::AF_INET, libc::SOCK_STREAM, 0) },
SocketAddr::V6(_) => unsafe { libc::socket(libc::AF_INET6, libc::SOCK_STREAM, 0) },
};
if socket_fd < 0 {
return Err(CommonError::SocketCreateFailed(io::Error::last_os_error()));
}
let (sock_addr, sock_addr_len) = socketaddr_to_sockaddr(addr);
let sock_addr_ptr = &sock_addr as *const _;
if unsafe { libc::bind(socket_fd, sock_addr_ptr, sock_addr_len) } < 0 {
return Err(CommonError::SocketBindFailed(io::Error::last_os_error()));
}
Ok(TimestampedTcpSocket { inner: socket_fd })
}
pub fn listen(&self, backlog: i32) -> Result<(), CommonError> {
if unsafe { libc::listen(self.inner, backlog) } < 0 {
Err(CommonError::SocketListenFailed(io::Error::last_os_error()))
} else {
Ok(())
}
}
pub fn accept(&self) -> Result<(TimestampedTcpSocket, SocketAddr), CommonError> {
let mut addr_storage: libc::sockaddr_storage = unsafe { std::mem::zeroed() };
let mut addr_len = std::mem::size_of_val(&addr_storage) as libc::socklen_t;
let new_socket_fd = unsafe {
libc::accept(
self.inner,
&mut addr_storage as *mut libc::sockaddr_storage as *mut libc::sockaddr,
&mut addr_len,
)
};
if new_socket_fd < 0 {
return Err(CommonError::SocketAcceptFailed(io::Error::last_os_error()));
}
let client_addr = storage_to_socket_addr(&addr_storage)?;
Ok((
TimestampedTcpSocket {
inner: new_socket_fd,
},
client_addr,
))
}
pub fn connect(&mut self, addr: SocketAddr) -> Result<i32, CommonError> {
let socket_fd = self.inner;
if socket_fd < 0 {
return Err(CommonError::SocketCreateFailed(io::Error::last_os_error()));
}
let (sock_addr, sock_addr_len) = socketaddr_to_sockaddr(&addr);
let sock_addr_ptr = &sock_addr as *const _;
let result = unsafe { libc::connect(socket_fd, sock_addr_ptr, sock_addr_len) };
log::debug!("Connect result: {}", result);
if result < 0 {
let err = io::Error::last_os_error();
unsafe { libc::close(socket_fd) };
return Err(CommonError::SocketConnectFailed(err));
}
Ok(result)
}
}
impl Socket<TimestampedTcpSocket> for TimestampedTcpSocket {
    /// Wraps `fd` without any socket-option setup.
    ///
    /// # Safety
    /// `fd` must be an open socket descriptor that the caller owns exclusively.
    /// The returned value closes it on drop.
    unsafe fn from_raw_fd(fd: RawFd) -> TimestampedTcpSocket {
        Self { inner: fd }
    }

    /// Serializes `message` and sends it. The timestamp is taken immediately
    /// before `send(2)`.
    ///
    /// `MSG_NOSIGNAL` prevents `SIGPIPE` if the peer has closed the connection;
    /// the failure is reported as an error instead.
    fn send(&self, message: impl BeBytes) -> Result<(isize, DateTime), CommonError> {
        let bytes = message.to_be_bytes();
        let timestamp = DateTime::utc_now();
        // SAFETY: the pointer and length describe the live `bytes` buffer.
        let result = unsafe {
            libc::send(
                self.inner,
                bytes.as_ptr() as *const libc::c_void,
                bytes.len(),
                MSG_NOSIGNAL,
            )
        };
        if result < 0 {
            let error = io::Error::last_os_error();
            return Err(CommonError::from(error));
        }
        Ok((result, timestamp))
    }

    /// TCP sockets are connected, so `_address` is ignored and this is
    /// equivalent to [`send`](Self::send).
    fn send_to(
        &self,
        _address: &SocketAddr,
        message: impl BeBytes,
    ) -> Result<(isize, crate::time::DateTime), CommonError> {
        self.send(message)
    }

    /// Reads into `buffer` and returns the byte count, as returned by `recv(2)`,
    /// together with a timestamp.
    ///
    /// NOTE(review): the timestamp is taken *before* `recv(2)`. On a blocking
    /// socket it records when the call started, not when data arrived. Confirm
    /// this is intended.
    fn receive(&self, buffer: &mut [u8]) -> Result<(isize, DateTime), CommonError> {
        let timestamp = DateTime::utc_now();
        // SAFETY: the pointer and length describe the caller's mutable `buffer`.
        // No flags are passed: `MSG_NOSIGNAL` is a send-side flag with no meaning
        // for `recv(2)`.
        let result = unsafe {
            libc::recv(
                self.inner,
                buffer.as_mut_ptr() as *mut libc::c_void,
                buffer.len(),
                0,
            )
        };
        if result < 0 {
            let error = io::Error::last_os_error();
            return Err(CommonError::from(error));
        }
        Ok((result, timestamp))
    }

    /// Like [`receive`](Self::receive), but also returns the connected peer's
    /// address, looked up with `getpeername(2)` after the read.
    ///
    /// # Errors
    /// Any error from `receive`, `CommonError::SocketGetPeerName` if
    /// `getpeername(2)` fails, or any error from `storage_to_socket_addr`.
    fn receive_from(
        &self,
        buffer: &mut [u8],
    ) -> Result<(isize, SocketAddr, DateTime), CommonError> {
        let (result, timestamp) = self.receive(buffer)?;
        let mut addr_storage: libc::sockaddr_storage = unsafe { std::mem::zeroed() };
        let mut addr_len = std::mem::size_of::<libc::sockaddr_storage>() as libc::socklen_t;
        // SAFETY: `addr_storage` can hold any address family and `addr_len` is its size.
        if unsafe {
            libc::getpeername(
                self.inner,
                &mut addr_storage as *mut _ as *mut _,
                &mut addr_len,
            )
        } == -1
        {
            return Err(CommonError::SocketGetPeerName(io::Error::last_os_error()));
        }
        let peer_address = storage_to_socket_addr(&addr_storage)?;
        Ok((result, peer_address, timestamp))
    }
}