use std::fmt;
use std::io::{self, Read, Write};
use std::net::{self, Ipv4Addr, Ipv6Addr, Shutdown};
#[cfg(all(unix, feature = "unix"))]
use std::os::unix::net::{UnixDatagram, UnixListener, UnixStream};
use std::time::Duration;
#[cfg(any(unix, target_os = "redox"))]
use libc::MSG_OOB;
#[cfg(windows)]
use winapi::um::winsock2::MSG_OOB;
use crate::sys;
use crate::{Domain, Protocol, SockAddr, Type};
/// An owned socket handle wrapping the platform-specific `sys::Socket`.
///
/// Every operation on this type delegates to the inner `sys::Socket`. The
/// `From` impls and `into_*` methods convert between this type and the
/// standard library's networking types (and, on Unix with the `unix`
/// feature, the Unix-domain socket types).
pub struct Socket {
    /// Platform-specific socket implementation that all methods delegate to.
    pub(crate) inner: sys::Socket,
}
impl Socket {
    // Every fallible method below returns the `io::Error` produced by the
    // platform layer (`sys::Socket`) unchanged; nothing is wrapped or retried.

    /// Creates a new socket for the given `domain` and `type_`.
    ///
    /// If `protocol` is `None`, `0` is passed to the platform layer. By the
    /// usual `socket(2)` convention, that selects the default protocol for the
    /// domain/type pair.
    pub fn new(domain: Domain, type_: Type, protocol: Option<Protocol>) -> io::Result<Socket> {
        let protocol = protocol.map_or(0, |p| p.0);
        Ok(Socket {
            inner: sys::Socket::new(domain.0, type_.0, protocol)?,
        })
    }

    /// Creates a pair of sockets that are connected to each other, so data
    /// written to one can be read from the other.
    ///
    /// Only available on Unix with the `pair` feature. A `None` `protocol`
    /// is passed to the platform layer as `0`, as in [`Socket::new`].
    #[cfg(all(unix, feature = "pair"))]
    pub fn pair(
        domain: Domain,
        type_: Type,
        protocol: Option<Protocol>,
    ) -> io::Result<(Socket, Socket)> {
        let protocol = protocol.map_or(0, |p| p.0);
        let sockets = sys::Socket::pair(domain.0, type_.0, protocol)?;
        Ok((Socket { inner: sockets.0 }, Socket { inner: sockets.1 }))
    }

    /// Consumes this socket and converts it into a [`net::TcpStream`].
    ///
    /// This method does not check that the socket actually is a TCP stream;
    /// the conversion is delegated to the `From<Socket>` impl.
    pub fn into_tcp_stream(self) -> net::TcpStream {
        self.into()
    }

    /// Consumes this socket and converts it into a [`net::TcpListener`].
    ///
    /// This method does not check that the socket actually is listening; the
    /// conversion is delegated to the `From<Socket>` impl.
    pub fn into_tcp_listener(self) -> net::TcpListener {
        self.into()
    }

    /// Consumes this socket and converts it into a [`net::UdpSocket`].
    ///
    /// This method does not check the socket type; the conversion is
    /// delegated to the `From<Socket>` impl.
    pub fn into_udp_socket(self) -> net::UdpSocket {
        self.into()
    }

    /// Consumes this socket and converts it into a [`UnixStream`].
    ///
    /// Only available on Unix with the `unix` feature.
    #[cfg(all(unix, feature = "unix"))]
    pub fn into_unix_stream(self) -> UnixStream {
        self.into()
    }

    /// Consumes this socket and converts it into a [`UnixListener`].
    ///
    /// Only available on Unix with the `unix` feature.
    #[cfg(all(unix, feature = "unix"))]
    pub fn into_unix_listener(self) -> UnixListener {
        self.into()
    }

    /// Consumes this socket and converts it into a [`UnixDatagram`].
    ///
    /// Only available on Unix with the `unix` feature.
    #[cfg(all(unix, feature = "unix"))]
    pub fn into_unix_datagram(self) -> UnixDatagram {
        self.into()
    }

    /// Connects this socket to `addr`.
    pub fn connect(&self, addr: &SockAddr) -> io::Result<()> {
        self.inner.connect(addr)
    }

    /// Connects this socket to `addr`, giving up after `timeout`.
    ///
    /// A connection that does not complete in time is expected to be
    /// reported as [`io::ErrorKind::TimedOut`].
    pub fn connect_timeout(&self, addr: &SockAddr, timeout: Duration) -> io::Result<()> {
        self.inner.connect_timeout(addr, timeout)
    }

    /// Binds this socket to the local address `addr`.
    pub fn bind(&self, addr: &SockAddr) -> io::Result<()> {
        self.inner.bind(addr)
    }

    /// Marks this socket as accepting connections. `backlog` is passed
    /// through to the platform layer unchanged.
    pub fn listen(&self, backlog: i32) -> io::Result<()> {
        self.inner.listen(backlog)
    }

    /// Accepts an incoming connection.
    ///
    /// Returns the newly connected socket together with the address the
    /// platform layer reports for it (presumably the peer's address).
    pub fn accept(&self) -> io::Result<(Socket, SockAddr)> {
        self.inner
            .accept()
            .map(|(socket, addr)| (Socket { inner: socket }, addr))
    }

    /// Returns the local address this socket is bound to.
    pub fn local_addr(&self) -> io::Result<SockAddr> {
        self.inner.local_addr()
    }

    /// Returns the address of the remote peer this socket is connected to.
    pub fn peer_addr(&self) -> io::Result<SockAddr> {
        self.inner.peer_addr()
    }

    /// Creates a new `Socket` handle from this one via the platform
    /// `try_clone` (presumably both refer to the same underlying socket).
    pub fn try_clone(&self) -> io::Result<Socket> {
        self.inner.try_clone().map(|s| Socket { inner: s })
    }

    /// Retrieves the pending error on this socket, if any.
    ///
    /// `Ok(None)` means no error was pending. Whether retrieving it also
    /// clears it depends on the platform layer (presumably it does).
    pub fn take_error(&self) -> io::Result<Option<io::Error>> {
        self.inner.take_error()
    }

    /// Enables (`true`) or disables (`false`) non-blocking mode.
    pub fn set_nonblocking(&self, nonblocking: bool) -> io::Result<()> {
        self.inner.set_nonblocking(nonblocking)
    }

    /// Shuts down the read half, the write half, or both halves of this
    /// connection, as selected by `how`.
    pub fn shutdown(&self, how: Shutdown) -> io::Result<()> {
        self.inner.shutdown(how)
    }

    /// Receives data into `buf` with no flags set (`0`).
    ///
    /// Returns the number of bytes received.
    pub fn recv(&self, buf: &mut [u8]) -> io::Result<usize> {
        self.inner.recv(buf, 0)
    }

    /// Receives data into `buf`, passing `flags` to the platform `recv`
    /// unchanged.
    pub fn recv_with_flags(&self, buf: &mut [u8], flags: i32) -> io::Result<usize> {
        self.inner.recv(buf, flags)
    }

    /// Receives out-of-band data into `buf` by calling `recv` with `MSG_OOB`.
    pub fn recv_out_of_band(&self, buf: &mut [u8]) -> io::Result<usize> {
        self.inner.recv(buf, MSG_OOB)
    }

    /// Peeks at incoming data, copying it into `buf`. Presumably the data
    /// stays queued for a later `recv`.
    pub fn peek(&self, buf: &mut [u8]) -> io::Result<usize> {
        self.inner.peek(buf)
    }

    /// Receives data into `buf` with no flags set (`0`).
    ///
    /// Returns the byte count and the sender's address.
    pub fn recv_from(&self, buf: &mut [u8]) -> io::Result<(usize, SockAddr)> {
        self.inner.recv_from(buf, 0)
    }

    /// Like [`Socket::recv_from`], but passes `flags` to the platform layer
    /// unchanged.
    pub fn recv_from_with_flags(
        &self,
        buf: &mut [u8],
        flags: i32,
    ) -> io::Result<(usize, SockAddr)> {
        self.inner.recv_from(buf, flags)
    }

    /// Peeks at incoming data, returning the byte count and the sender's
    /// address. Presumably the data stays queued.
    pub fn peek_from(&self, buf: &mut [u8]) -> io::Result<(usize, SockAddr)> {
        self.inner.peek_from(buf)
    }

    /// Sends `buf` with no flags set (`0`), returning the number of bytes
    /// sent.
    pub fn send(&self, buf: &[u8]) -> io::Result<usize> {
        self.inner.send(buf, 0)
    }

    /// Sends `buf`, passing `flags` to the platform `send` unchanged.
    pub fn send_with_flags(&self, buf: &[u8], flags: i32) -> io::Result<usize> {
        self.inner.send(buf, flags)
    }

    /// Sends `buf` as out-of-band data by calling `send` with `MSG_OOB`.
    pub fn send_out_of_band(&self, buf: &[u8]) -> io::Result<usize> {
        self.inner.send(buf, MSG_OOB)
    }

    /// Sends `buf` to `addr` with no flags set (`0`).
    pub fn send_to(&self, buf: &[u8], addr: &SockAddr) -> io::Result<usize> {
        self.inner.send_to(buf, 0, addr)
    }

    /// Sends `buf` to `addr`, passing `flags` to the platform layer
    /// unchanged.
    ///
    /// This method takes `addr` before `flags`; the platform call takes
    /// them in the opposite order.
    pub fn send_to_with_flags(&self, buf: &[u8], addr: &SockAddr, flags: i32) -> io::Result<usize> {
        self.inner.send_to(buf, flags, addr)
    }

    /// Gets the IP time-to-live (TTL) value of this socket.
    pub fn ttl(&self) -> io::Result<u32> {
        self.inner.ttl()
    }

    /// Sets the IP time-to-live (TTL) value of this socket.
    pub fn set_ttl(&self, ttl: u32) -> io::Result<()> {
        self.inner.set_ttl(ttl)
    }

    /// Gets the IPv6 unicast hop limit of this socket.
    pub fn unicast_hops_v6(&self) -> io::Result<u32> {
        self.inner.unicast_hops_v6()
    }

    /// Sets the IPv6 unicast hop limit of this socket.
    pub fn set_unicast_hops_v6(&self, ttl: u32) -> io::Result<()> {
        self.inner.set_unicast_hops_v6(ttl)
    }

    /// Gets the "IPv6 only" flag of this socket (presumably `IPV6_V6ONLY`).
    pub fn only_v6(&self) -> io::Result<bool> {
        self.inner.only_v6()
    }

    /// Sets the "IPv6 only" flag of this socket (presumably `IPV6_V6ONLY`).
    pub fn set_only_v6(&self, only_v6: bool) -> io::Result<()> {
        self.inner.set_only_v6(only_v6)
    }

    /// Gets the read timeout. `None` presumably means no timeout.
    pub fn read_timeout(&self) -> io::Result<Option<Duration>> {
        self.inner.read_timeout()
    }

    /// Sets the read timeout. `None` presumably removes the timeout.
    pub fn set_read_timeout(&self, dur: Option<Duration>) -> io::Result<()> {
        self.inner.set_read_timeout(dur)
    }

    /// Gets the write timeout. `None` presumably means no timeout.
    pub fn write_timeout(&self) -> io::Result<Option<Duration>> {
        self.inner.write_timeout()
    }

    /// Sets the write timeout. `None` presumably removes the timeout.
    pub fn set_write_timeout(&self, dur: Option<Duration>) -> io::Result<()> {
        self.inner.set_write_timeout(dur)
    }

    /// Gets the TCP no-delay flag (presumably `TCP_NODELAY`).
    pub fn nodelay(&self) -> io::Result<bool> {
        self.inner.nodelay()
    }

    /// Sets the TCP no-delay flag (presumably `TCP_NODELAY`).
    pub fn set_nodelay(&self, nodelay: bool) -> io::Result<()> {
        self.inner.set_nodelay(nodelay)
    }

    /// Gets the broadcast flag of this socket.
    pub fn broadcast(&self) -> io::Result<bool> {
        self.inner.broadcast()
    }

    /// Sets the broadcast flag of this socket.
    pub fn set_broadcast(&self, broadcast: bool) -> io::Result<()> {
        self.inner.set_broadcast(broadcast)
    }

    /// Gets the IPv4 multicast loopback flag.
    pub fn multicast_loop_v4(&self) -> io::Result<bool> {
        self.inner.multicast_loop_v4()
    }

    /// Sets the IPv4 multicast loopback flag.
    pub fn set_multicast_loop_v4(&self, multicast_loop_v4: bool) -> io::Result<()> {
        self.inner.set_multicast_loop_v4(multicast_loop_v4)
    }

    /// Gets the IPv4 multicast time-to-live value.
    pub fn multicast_ttl_v4(&self) -> io::Result<u32> {
        self.inner.multicast_ttl_v4()
    }

    /// Sets the IPv4 multicast time-to-live value.
    pub fn set_multicast_ttl_v4(&self, multicast_ttl_v4: u32) -> io::Result<()> {
        self.inner.set_multicast_ttl_v4(multicast_ttl_v4)
    }

    /// Gets the IPv6 multicast hop limit.
    pub fn multicast_hops_v6(&self) -> io::Result<u32> {
        self.inner.multicast_hops_v6()
    }

    /// Sets the IPv6 multicast hop limit.
    pub fn set_multicast_hops_v6(&self, hops: u32) -> io::Result<()> {
        self.inner.set_multicast_hops_v6(hops)
    }

    /// Gets the IPv4 address of the interface used for outgoing multicast.
    pub fn multicast_if_v4(&self) -> io::Result<Ipv4Addr> {
        self.inner.multicast_if_v4()
    }

    /// Sets the IPv4 address of the interface used for outgoing multicast.
    pub fn set_multicast_if_v4(&self, interface: &Ipv4Addr) -> io::Result<()> {
        self.inner.set_multicast_if_v4(interface)
    }

    /// Gets the interface used for outgoing IPv6 multicast (presumably an
    /// interface index).
    pub fn multicast_if_v6(&self) -> io::Result<u32> {
        self.inner.multicast_if_v6()
    }

    /// Sets the interface used for outgoing IPv6 multicast (presumably an
    /// interface index).
    pub fn set_multicast_if_v6(&self, interface: u32) -> io::Result<()> {
        self.inner.set_multicast_if_v6(interface)
    }

    /// Gets the IPv6 multicast loopback flag.
    pub fn multicast_loop_v6(&self) -> io::Result<bool> {
        self.inner.multicast_loop_v6()
    }

    /// Sets the IPv6 multicast loopback flag.
    pub fn set_multicast_loop_v6(&self, multicast_loop_v6: bool) -> io::Result<()> {
        self.inner.set_multicast_loop_v6(multicast_loop_v6)
    }

    /// Joins the IPv4 multicast group `multiaddr` on the interface with
    /// address `interface`.
    pub fn join_multicast_v4(&self, multiaddr: &Ipv4Addr, interface: &Ipv4Addr) -> io::Result<()> {
        self.inner.join_multicast_v4(multiaddr, interface)
    }

    /// Joins the IPv6 multicast group `multiaddr` on `interface` (presumably
    /// an interface index).
    pub fn join_multicast_v6(&self, multiaddr: &Ipv6Addr, interface: u32) -> io::Result<()> {
        self.inner.join_multicast_v6(multiaddr, interface)
    }

    /// Leaves the IPv4 multicast group `multiaddr` on the interface with
    /// address `interface`.
    pub fn leave_multicast_v4(&self, multiaddr: &Ipv4Addr, interface: &Ipv4Addr) -> io::Result<()> {
        self.inner.leave_multicast_v4(multiaddr, interface)
    }

    /// Leaves the IPv6 multicast group `multiaddr` on `interface` (presumably
    /// an interface index).
    pub fn leave_multicast_v6(&self, multiaddr: &Ipv6Addr, interface: u32) -> io::Result<()> {
        self.inner.leave_multicast_v6(multiaddr, interface)
    }

    /// Gets the linger duration. `None` presumably means lingering is
    /// disabled.
    pub fn linger(&self) -> io::Result<Option<Duration>> {
        self.inner.linger()
    }

    /// Sets the linger duration. `None` presumably disables lingering.
    pub fn set_linger(&self, dur: Option<Duration>) -> io::Result<()> {
        self.inner.set_linger(dur)
    }

    /// Gets the address-reuse flag (presumably `SO_REUSEADDR`).
    pub fn reuse_address(&self) -> io::Result<bool> {
        self.inner.reuse_address()
    }

    /// Sets the address-reuse flag (presumably `SO_REUSEADDR`).
    pub fn set_reuse_address(&self, reuse: bool) -> io::Result<()> {
        self.inner.set_reuse_address(reuse)
    }

    /// Gets the receive buffer size, as reported by the platform layer.
    pub fn recv_buffer_size(&self) -> io::Result<usize> {
        self.inner.recv_buffer_size()
    }

    /// Sets the receive buffer size. The platform may adjust the value.
    pub fn set_recv_buffer_size(&self, size: usize) -> io::Result<()> {
        self.inner.set_recv_buffer_size(size)
    }

    /// Gets the send buffer size, as reported by the platform layer.
    pub fn send_buffer_size(&self) -> io::Result<usize> {
        self.inner.send_buffer_size()
    }

    /// Sets the send buffer size. The platform may adjust the value.
    pub fn set_send_buffer_size(&self, size: usize) -> io::Result<()> {
        self.inner.set_send_buffer_size(size)
    }

    /// Gets the TCP keepalive setting. `None` presumably means keepalive is
    /// disabled.
    pub fn keepalive(&self) -> io::Result<Option<Duration>> {
        self.inner.keepalive()
    }

    /// Sets the TCP keepalive setting. `None` presumably disables keepalive.
    pub fn set_keepalive(&self, keepalive: Option<Duration>) -> io::Result<()> {
        self.inner.set_keepalive(keepalive)
    }

    /// Gets whether out-of-band data is delivered inline with normal data.
    pub fn out_of_band_inline(&self) -> io::Result<bool> {
        self.inner.out_of_band_inline()
    }

    /// Sets whether out-of-band data is delivered inline with normal data.
    pub fn set_out_of_band_inline(&self, oob_inline: bool) -> io::Result<()> {
        self.inner.set_out_of_band_inline(oob_inline)
    }

    /// Gets the port-reuse flag (presumably `SO_REUSEPORT`).
    ///
    /// Available on Unix (except Solaris and illumos) with the `reuseport`
    /// feature.
    #[cfg(all(
        unix,
        not(any(target_os = "solaris", target_os = "illumos")),
        feature = "reuseport"
    ))]
    pub fn reuse_port(&self) -> io::Result<bool> {
        self.inner.reuse_port()
    }

    /// Sets the port-reuse flag (presumably `SO_REUSEPORT`).
    ///
    /// Available on Unix (except Solaris and illumos) with the `reuseport`
    /// feature.
    #[cfg(all(
        unix,
        not(any(target_os = "solaris", target_os = "illumos")),
        feature = "reuseport"
    ))]
    pub fn set_reuse_port(&self, reuse: bool) -> io::Result<()> {
        self.inner.set_reuse_port(reuse)
    }
}
/// Reads through an exclusive handle by forwarding to the platform socket.
impl Read for Socket {
    fn read(&mut self, dst: &mut [u8]) -> io::Result<usize> {
        self.inner.read(dst)
    }
}

/// Reads through a shared `&Socket` by forwarding to the `Read` impl of a
/// shared reference to the platform socket.
impl Read for &Socket {
    fn read(&mut self, dst: &mut [u8]) -> io::Result<usize> {
        // `Read::read` needs `&mut self`, so keep the shared platform
        // reference in a mutable local.
        let mut shared = &self.inner;
        shared.read(dst)
    }
}
/// Writes through an exclusive handle by forwarding to the platform socket.
impl Write for Socket {
    fn flush(&mut self) -> io::Result<()> {
        self.inner.flush()
    }

    fn write(&mut self, src: &[u8]) -> io::Result<usize> {
        self.inner.write(src)
    }
}

/// Writes through a shared `&Socket` by forwarding to the `Write` impl of a
/// shared reference to the platform socket.
impl Write for &Socket {
    fn flush(&mut self) -> io::Result<()> {
        let mut shared = &self.inner;
        shared.flush()
    }

    fn write(&mut self, src: &[u8]) -> io::Result<usize> {
        // `Write::write` needs `&mut self`, so keep the shared platform
        // reference in a mutable local.
        let mut shared = &self.inner;
        shared.write(src)
    }
}
/// Formats the socket using the platform socket's `Debug` output.
impl fmt::Debug for Socket {
    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
        fmt::Debug::fmt(&self.inner, formatter)
    }
}
impl From<net::TcpStream> for Socket {
fn from(socket: net::TcpStream) -> Socket {
Socket {
inner: socket.into(),
}
}
}
impl From<net::TcpListener> for Socket {
fn from(socket: net::TcpListener) -> Socket {
Socket {
inner: socket.into(),
}
}
}
impl From<net::UdpSocket> for Socket {
fn from(socket: net::UdpSocket) -> Socket {
Socket {
inner: socket.into(),
}
}
}
#[cfg(all(unix, feature = "unix"))]
impl From<UnixStream> for Socket {
fn from(socket: UnixStream) -> Socket {
Socket {
inner: socket.into(),
}
}
}
#[cfg(all(unix, feature = "unix"))]
impl From<UnixListener> for Socket {
fn from(socket: UnixListener) -> Socket {
Socket {
inner: socket.into(),
}
}
}
#[cfg(all(unix, feature = "unix"))]
impl From<UnixDatagram> for Socket {
fn from(socket: UnixDatagram) -> Socket {
Socket {
inner: socket.into(),
}
}
}
impl From<Socket> for net::TcpStream {
fn from(socket: Socket) -> net::TcpStream {
socket.inner.into()
}
}
impl From<Socket> for net::TcpListener {
fn from(socket: Socket) -> net::TcpListener {
socket.inner.into()
}
}
impl From<Socket> for net::UdpSocket {
fn from(socket: Socket) -> net::UdpSocket {
socket.inner.into()
}
}
#[cfg(all(unix, feature = "unix"))]
impl From<Socket> for UnixStream {
fn from(socket: Socket) -> UnixStream {
socket.inner.into()
}
}
#[cfg(all(unix, feature = "unix"))]
impl From<Socket> for UnixListener {
fn from(socket: Socket) -> UnixListener {
socket.inner.into()
}
}
#[cfg(all(unix, feature = "unix"))]
impl From<Socket> for UnixDatagram {
fn from(socket: Socket) -> UnixDatagram {
socket.inner.into()
}
}
#[cfg(test)]
mod test {
    use std::net::SocketAddr;

    use super::*;

    /// Parses `s` as a `SocketAddr` and converts it into a `SockAddr`.
    fn sock_addr(s: &str) -> SockAddr {
        s.parse::<SocketAddr>().unwrap().into()
    }

    #[test]
    fn connect_timeout_unrouteable() {
        // Nothing is expected to answer at this address, so the connect
        // attempt should hit the timeout rather than fail fast.
        let addr = sock_addr("10.255.255.1:80");
        let socket = Socket::new(Domain::ipv4(), Type::stream(), None).unwrap();
        match socket.connect_timeout(&addr, Duration::from_millis(250)) {
            Ok(_) => panic!("unexpected success"),
            Err(ref e) if e.kind() == io::ErrorKind::TimedOut => {}
            Err(e) => panic!("unexpected error {}", e),
        }
    }

    #[test]
    fn connect_timeout_unbound() {
        // Reserve a local port, then release it so nothing listens there.
        let socket = Socket::new(Domain::ipv4(), Type::stream(), None).unwrap();
        socket.bind(&sock_addr("127.0.0.1:0")).unwrap();
        let addr = socket.local_addr().unwrap();
        drop(socket);

        let socket = Socket::new(Domain::ipv4(), Type::stream(), None).unwrap();
        match socket.connect_timeout(&addr, Duration::from_millis(250)) {
            Ok(_) => panic!("unexpected success"),
            Err(ref e)
                if e.kind() == io::ErrorKind::ConnectionRefused
                    || e.kind() == io::ErrorKind::TimedOut => {}
            Err(e) => panic!("unexpected error {}", e),
        }
    }

    #[test]
    fn connect_timeout_valid() {
        let listener = Socket::new(Domain::ipv4(), Type::stream(), None).unwrap();
        listener.bind(&sock_addr("127.0.0.1:0")).unwrap();
        listener.listen(128).unwrap();
        let addr = listener.local_addr().unwrap();

        let socket = Socket::new(Domain::ipv4(), Type::stream(), None).unwrap();
        socket
            .connect_timeout(&addr, Duration::from_millis(250))
            .unwrap();
    }

    #[test]
    #[cfg(all(unix, feature = "pair", feature = "unix"))]
    fn pair() {
        let (mut a, mut b) = Socket::pair(Domain::unix(), Type::stream(), None).unwrap();
        a.write_all(b"hello world").unwrap();
        let mut buf = [0; 11];
        b.read_exact(&mut buf).unwrap();
        assert_eq!(buf, &b"hello world"[..]);
    }

    #[test]
    #[cfg(all(unix, feature = "unix"))]
    fn unix() {
        use tempdir::TempDir;

        let dir = TempDir::new("unix").unwrap();
        let addr = SockAddr::unix(dir.path().join("sock")).unwrap();

        let listener = Socket::new(Domain::unix(), Type::stream(), None).unwrap();
        listener.bind(&addr).unwrap();
        listener.listen(10).unwrap();

        let mut a = Socket::new(Domain::unix(), Type::stream(), None).unwrap();
        a.connect(&addr).unwrap();
        let mut b = listener.accept().unwrap().0;

        a.write_all(b"hello world").unwrap();
        let mut buf = [0; 11];
        b.read_exact(&mut buf).unwrap();
        assert_eq!(buf, &b"hello world"[..]);
    }

    #[test]
    fn keepalive() {
        let socket = Socket::new(Domain::ipv4(), Type::stream(), None).unwrap();
        socket.set_keepalive(Some(Duration::from_secs(7))).unwrap();
        // Reading the value back is only asserted on Unix.
        #[cfg(unix)]
        assert_eq!(socket.keepalive().unwrap(), Some(Duration::from_secs(7)));
        socket.set_keepalive(None).unwrap();
        #[cfg(unix)]
        assert_eq!(socket.keepalive().unwrap(), None);
    }

    #[test]
    fn nodelay() {
        let socket = Socket::new(Domain::ipv4(), Type::stream(), None).unwrap();
        // `unwrap` (rather than `is_ok`) surfaces the underlying error on failure.
        socket.set_nodelay(true).unwrap();
        assert!(socket.nodelay().unwrap());
    }

    #[test]
    fn out_of_band_inline() {
        let socket = Socket::new(Domain::ipv4(), Type::stream(), None).unwrap();
        assert!(!socket.out_of_band_inline().unwrap());
        socket.set_out_of_band_inline(true).unwrap();
        assert!(socket.out_of_band_inline().unwrap());
    }

    #[test]
    #[cfg(any(target_os = "windows", target_os = "linux"))]
    fn out_of_band_send_recv() {
        let s1 = Socket::new(Domain::ipv4(), Type::stream(), None).unwrap();
        s1.bind(&sock_addr("127.0.0.1:0")).unwrap();
        let s1_addr = s1.local_addr().unwrap();
        s1.listen(1).unwrap();

        let s2 = Socket::new(Domain::ipv4(), Type::stream(), None).unwrap();
        s2.connect(&s1_addr).unwrap();
        let (s3, _) = s1.accept().unwrap();

        // Queue ten ordinary bytes, then a single out-of-band byte.
        let mut buf = [0; 10];
        assert_eq!(s2.send(&buf).unwrap(), 10);
        assert_eq!(s2.send_out_of_band(b"!").unwrap(), 1);

        assert_eq!(s3.recv_out_of_band(&mut buf).unwrap(), 1);
        assert_eq!(buf[0], b'!');
        // The ten ordinary bytes are still readable afterwards.
        assert_eq!(s3.recv(&mut buf).unwrap(), 10);
    }

    #[test]
    fn tcp() {
        let s1 = Socket::new(Domain::ipv4(), Type::stream(), None).unwrap();
        s1.bind(&sock_addr("127.0.0.1:0")).unwrap();
        let s1_addr = s1.local_addr().unwrap();
        s1.listen(1).unwrap();

        let s2 = Socket::new(Domain::ipv4(), Type::stream(), None).unwrap();
        s2.connect(&s1_addr).unwrap();
        let (s3, _) = s1.accept().unwrap();

        let msg = b"hello world";
        assert_eq!(s2.send(msg).unwrap(), msg.len());
        let mut buf = [0; 11];
        assert_eq!(s3.recv(&mut buf).unwrap(), 11);
        assert_eq!(&buf, msg);
    }
}