use crate::socket::to_socket_protocol;
use crate::socket::{IpVersion, SocketOption};
use socket2::{SockAddr, Socket as SystemSocket};
use std::io;
use std::mem::MaybeUninit;
use std::net::{Shutdown, SocketAddr, TcpListener, TcpStream, UdpSocket};
use std::sync::Arc;
use std::time::Duration;
/// Handle to an operating-system socket backed by `socket2`.
///
/// The underlying socket is held behind an [`Arc`], so the derived [`Clone`] produces
/// another handle to the *same* OS socket instead of duplicating it; use
/// [`Socket::try_clone`] to obtain an independently owned duplicate. Because the
/// `into_*` conversions unwrap the [`Arc`], they only succeed while no other clone of
/// this handle is alive.
#[derive(Debug, Clone)]
pub struct Socket {
    /// Shared ownership of the wrapped `socket2` socket.
    inner: Arc<SystemSocket>,
}
impl Socket {
pub fn new(socket_option: SocketOption) -> io::Result<Socket> {
let socket: SystemSocket = if let Some(protocol) = socket_option.protocol {
SystemSocket::new(
socket_option.ip_version.to_domain(),
socket_option.socket_type.to_type(),
Some(to_socket_protocol(protocol)),
)?
} else {
SystemSocket::new(
socket_option.ip_version.to_domain(),
socket_option.socket_type.to_type(),
None,
)?
};
if socket_option.non_blocking {
socket.set_nonblocking(true)?;
}
Ok(Socket {
inner: Arc::new(socket),
})
}
pub fn bind(&self, addr: SocketAddr) -> io::Result<()> {
let addr: SockAddr = SockAddr::from(addr);
self.inner.bind(&addr)
}
pub fn send(&self, buf: &[u8]) -> io::Result<usize> {
match self.inner.send(buf) {
Ok(n) => Ok(n),
Err(e) => Err(e),
}
}
pub fn send_to(&self, buf: &[u8], target: SocketAddr) -> io::Result<usize> {
let target: SockAddr = SockAddr::from(target);
match self.inner.send_to(buf, &target) {
Ok(n) => Ok(n),
Err(e) => Err(e),
}
}
pub fn receive(&self, buf: &mut Vec<u8>) -> io::Result<usize> {
let recv_buf = unsafe { &mut *(buf.as_mut_slice() as *mut [u8] as *mut [MaybeUninit<u8>]) };
match self.inner.recv(recv_buf) {
Ok(result) => Ok(result),
Err(e) => Err(e),
}
}
pub fn receive_from(&self, buf: &mut Vec<u8>) -> io::Result<(usize, SocketAddr)> {
let recv_buf = unsafe { &mut *(buf.as_mut_slice() as *mut [u8] as *mut [MaybeUninit<u8>]) };
match self.inner.recv_from(recv_buf) {
Ok(result) => {
let (n, addr) = result;
match addr.as_socket() {
Some(addr) => return Ok((n, addr)),
None => {
return Err(io::Error::new(
io::ErrorKind::Other,
"Invalid socket address",
))
}
}
}
Err(e) => Err(e),
}
}
pub fn write(&self, buf: &[u8]) -> io::Result<usize> {
match self.inner.send(buf) {
Ok(n) => Ok(n),
Err(e) => Err(e),
}
}
pub fn read(&self, buf: &mut Vec<u8>) -> io::Result<usize> {
let recv_buf = unsafe { &mut *(buf.as_mut_slice() as *mut [u8] as *mut [MaybeUninit<u8>]) };
match self.inner.recv(recv_buf) {
Ok(result) => Ok(result),
Err(e) => Err(e),
}
}
pub fn ttl(&self, ip_version: IpVersion) -> io::Result<u32> {
match ip_version {
IpVersion::V4 => self.inner.ttl(),
IpVersion::V6 => self.inner.unicast_hops_v6(),
}
}
pub fn set_ttl(&self, ttl: u32, ip_version: IpVersion) -> io::Result<()> {
match ip_version {
IpVersion::V4 => self.inner.set_ttl(ttl),
IpVersion::V6 => self.inner.set_unicast_hops_v6(ttl),
}
}
pub fn tos(&self) -> io::Result<u32> {
self.inner.tos()
}
pub fn set_tos(&self, tos: u32) -> io::Result<()> {
self.inner.set_tos(tos)
}
pub fn receive_tos(&self) -> io::Result<bool> {
self.inner.recv_tos()
}
pub fn set_receive_tos(&self, receive_tos: bool) -> io::Result<()> {
self.inner.set_recv_tos(receive_tos)
}
pub fn connect(&self, addr: &SocketAddr) -> io::Result<()> {
let addr: SockAddr = SockAddr::from(*addr);
self.inner.connect(&addr)
}
pub fn connect_timeout(&self, addr: &SocketAddr, timeout: Duration) -> io::Result<()> {
let addr: SockAddr = SockAddr::from(*addr);
self.inner.connect_timeout(&addr, timeout)
}
pub fn listen(&self, backlog: i32) -> io::Result<()> {
self.inner.listen(backlog)
}
pub fn accept(&self) -> io::Result<(Socket, SocketAddr)> {
match self.inner.accept() {
Ok((socket, addr)) => Ok((
Socket {
inner: Arc::new(socket),
},
addr.as_socket().unwrap(),
)),
Err(e) => Err(e),
}
}
pub fn local_addr(&self) -> io::Result<SocketAddr> {
match self.inner.local_addr() {
Ok(addr) => Ok(addr.as_socket().unwrap()),
Err(e) => Err(e),
}
}
pub fn peer_addr(&self) -> io::Result<SocketAddr> {
match self.inner.peer_addr() {
Ok(addr) => Ok(addr.as_socket().unwrap()),
Err(e) => Err(e),
}
}
pub fn socket_type(&self) -> io::Result<crate::socket::SocketType> {
match self.inner.r#type() {
Ok(socktype) => Ok(crate::socket::SocketType::from_type(socktype)),
Err(e) => Err(e),
}
}
pub fn try_clone(&self) -> io::Result<Socket> {
match self.inner.try_clone() {
Ok(socket) => Ok(Socket {
inner: Arc::new(socket),
}),
Err(e) => Err(e),
}
}
#[cfg(not(target_os = "windows"))]
pub fn is_nonblocking(&self) -> io::Result<bool> {
self.inner.nonblocking()
}
pub fn set_nonblocking(&self, nonblocking: bool) -> io::Result<()> {
self.inner.set_nonblocking(nonblocking)
}
pub fn shutdown(&self, how: Shutdown) -> io::Result<()> {
self.inner.shutdown(how)
}
pub fn is_broadcast(&self) -> io::Result<bool> {
self.inner.broadcast()
}
pub fn set_broadcast(&self, broadcast: bool) -> io::Result<()> {
self.inner.set_broadcast(broadcast)
}
pub fn get_error(&self) -> io::Result<Option<io::Error>> {
self.inner.take_error()
}
pub fn keepalive(&self) -> io::Result<bool> {
self.inner.keepalive()
}
pub fn set_keepalive(&self, keepalive: bool) -> io::Result<()> {
self.inner.set_keepalive(keepalive)
}
pub fn linger(&self) -> io::Result<Option<Duration>> {
self.inner.linger()
}
pub fn set_linger(&self, dur: Option<Duration>) -> io::Result<()> {
self.inner.set_linger(dur)
}
pub fn receive_buffer_size(&self) -> io::Result<usize> {
self.inner.recv_buffer_size()
}
pub fn set_receive_buffer_size(&self, size: usize) -> io::Result<()> {
self.inner.set_recv_buffer_size(size)
}
pub fn receive_timeout(&self) -> io::Result<Option<Duration>> {
self.inner.read_timeout()
}
pub fn set_receive_timeout(&self, duration: Option<Duration>) -> io::Result<()> {
self.inner.set_read_timeout(duration)
}
pub fn reuse_address(&self) -> io::Result<bool> {
self.inner.reuse_address()
}
pub fn set_reuse_address(&self, reuse: bool) -> io::Result<()> {
self.inner.set_reuse_address(reuse)
}
pub fn send_buffer_size(&self) -> io::Result<usize> {
self.inner.send_buffer_size()
}
pub fn set_send_buffer_size(&self, size: usize) -> io::Result<()> {
self.inner.set_send_buffer_size(size)
}
pub fn send_timeout(&self) -> io::Result<Option<Duration>> {
self.inner.write_timeout()
}
pub fn set_send_timeout(&self, duration: Option<Duration>) -> io::Result<()> {
self.inner.set_write_timeout(duration)
}
pub fn is_ip_header_included(&self) -> io::Result<bool> {
self.inner.header_included()
}
pub fn set_ip_header_included(&self, include: bool) -> io::Result<()> {
self.inner.set_header_included(include)
}
pub fn nodelay(&self) -> io::Result<bool> {
self.inner.nodelay()
}
pub fn set_nodelay(&self, nodelay: bool) -> io::Result<()> {
self.inner.set_nodelay(nodelay)
}
pub fn into_tcp_stream(self) -> io::Result<TcpStream> {
match Arc::try_unwrap(self.inner) {
Ok(socket) => Ok(socket.into()),
Err(_) => Err(io::Error::new(
io::ErrorKind::Other,
"Failed to unwrap socket",
)),
}
}
pub fn into_tcp_listener(self) -> io::Result<TcpListener> {
match Arc::try_unwrap(self.inner) {
Ok(socket) => Ok(socket.into()),
Err(_) => Err(io::Error::new(
io::ErrorKind::Other,
"Failed to unwrap socket",
)),
}
}
pub fn into_udp_socket(self) -> io::Result<UdpSocket> {
match Arc::try_unwrap(self.inner) {
Ok(socket) => Ok(socket.into()),
Err(_) => Err(io::Error::new(
io::ErrorKind::Other,
"Failed to unwrap socket",
)),
}
}
}