use std::hash::Hash;
use std::mem::{self, size_of};
use std::net::{SocketAddr, SocketAddrV4, SocketAddrV6};
#[cfg(not(target_os = "wasi"))]
use std::path::Path;
use std::{fmt, io, ptr};
#[cfg(windows)]
use windows_sys::Win32::Networking::WinSock::SOCKADDR_IN6_0;
#[cfg(not(target_os = "wasi"))]
use crate::sys::AF_UNIX;
use crate::sys::{c_int, sockaddr_in, sockaddr_in6, sockaddr_storage, AF_INET, AF_INET6};
use crate::Domain;
#[allow(non_camel_case_types)]
pub type socklen_t = crate::sys::socklen_t;
#[allow(non_camel_case_types)]
pub type sa_family_t = crate::sys::sa_family_t;
#[repr(transparent)]
pub struct SockAddrStorage {
storage: sockaddr_storage,
}
impl SockAddrStorage {
#[inline]
pub fn zeroed() -> Self {
unsafe { mem::zeroed() }
}
#[inline]
pub fn size_of(&self) -> socklen_t {
size_of::<Self>() as socklen_t
}
#[inline]
pub unsafe fn view_as<T>(&mut self) -> &mut T {
assert!(size_of::<T>() <= size_of::<Self>());
&mut *(self as *mut Self as *mut T)
}
}
impl std::fmt::Debug for SockAddrStorage {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("sockaddr_storage")
.field("ss_family", &self.storage.ss_family)
.finish_non_exhaustive()
}
}
#[derive(Clone)]
pub struct SockAddr {
storage: sockaddr_storage,
len: socklen_t,
}
#[allow(clippy::len_without_is_empty)]
impl SockAddr {
pub const unsafe fn new(storage: SockAddrStorage, len: socklen_t) -> SockAddr {
SockAddr {
storage: storage.storage,
len: len as socklen_t,
}
}
pub unsafe fn try_init<F, T>(init: F) -> io::Result<(T, SockAddr)>
where
F: FnOnce(*mut SockAddrStorage, *mut socklen_t) -> io::Result<T>,
{
const STORAGE_SIZE: socklen_t = size_of::<sockaddr_storage>() as socklen_t;
let mut storage = SockAddrStorage::zeroed();
let mut len = STORAGE_SIZE;
init(&mut storage, &mut len).map(|res| {
debug_assert!(len <= STORAGE_SIZE, "overflown address storage");
(res, SockAddr::new(storage, len))
})
}
#[cfg(not(target_os = "wasi"))]
pub fn unix<P>(path: P) -> io::Result<SockAddr>
where
P: AsRef<Path>,
{
crate::sys::unix_sockaddr(path.as_ref())
}
pub unsafe fn set_length(&mut self, length: socklen_t) {
self.len = length;
}
pub const fn family(&self) -> sa_family_t {
self.storage.ss_family
}
pub const fn domain(&self) -> Domain {
Domain(self.storage.ss_family as c_int)
}
pub const fn len(&self) -> socklen_t {
self.len
}
pub const fn as_ptr(&self) -> *const SockAddrStorage {
&self.storage as *const sockaddr_storage as *const SockAddrStorage
}
pub const fn as_storage(self) -> SockAddrStorage {
SockAddrStorage {
storage: self.storage,
}
}
pub const fn is_ipv4(&self) -> bool {
self.storage.ss_family == AF_INET as sa_family_t
}
pub const fn is_ipv6(&self) -> bool {
self.storage.ss_family == AF_INET6 as sa_family_t
}
#[cfg(not(target_os = "wasi"))]
pub fn is_unix(&self) -> bool {
self.storage.ss_family == AF_UNIX as sa_family_t
}
pub fn as_socket(&self) -> Option<SocketAddr> {
if self.storage.ss_family == AF_INET as sa_family_t {
let addr = unsafe { &*(ptr::addr_of!(self.storage).cast::<sockaddr_in>()) };
let ip = crate::sys::from_in_addr(addr.sin_addr);
let port = u16::from_be(addr.sin_port);
Some(SocketAddr::V4(SocketAddrV4::new(ip, port)))
} else if self.storage.ss_family == AF_INET6 as sa_family_t {
let addr = unsafe { &*(ptr::addr_of!(self.storage).cast::<sockaddr_in6>()) };
let ip = crate::sys::from_in6_addr(addr.sin6_addr);
let port = u16::from_be(addr.sin6_port);
Some(SocketAddr::V6(SocketAddrV6::new(
ip,
port,
addr.sin6_flowinfo,
#[cfg(any(unix, all(target_os = "wasi", not(target_env = "p1"))))]
addr.sin6_scope_id,
#[cfg(windows)]
unsafe {
addr.Anonymous.sin6_scope_id
},
)))
} else {
None
}
}
pub fn as_socket_ipv4(&self) -> Option<SocketAddrV4> {
match self.as_socket() {
Some(SocketAddr::V4(addr)) => Some(addr),
_ => None,
}
}
pub fn as_socket_ipv6(&self) -> Option<SocketAddrV6> {
match self.as_socket() {
Some(SocketAddr::V6(addr)) => Some(addr),
_ => None,
}
}
fn as_bytes(&self) -> &[u8] {
unsafe { std::slice::from_raw_parts(self.as_ptr().cast(), self.len as usize) }
}
}
impl From<SocketAddr> for SockAddr {
fn from(addr: SocketAddr) -> SockAddr {
match addr {
SocketAddr::V4(addr) => addr.into(),
SocketAddr::V6(addr) => addr.into(),
}
}
}
impl From<SocketAddrV4> for SockAddr {
fn from(addr: SocketAddrV4) -> SockAddr {
let mut storage = unsafe { mem::zeroed::<sockaddr_storage>() };
let len = {
let storage = unsafe { &mut *ptr::addr_of_mut!(storage).cast::<sockaddr_in>() };
storage.sin_family = AF_INET as sa_family_t;
storage.sin_port = addr.port().to_be();
storage.sin_addr = crate::sys::to_in_addr(addr.ip());
#[cfg(not(target_os = "wasi"))]
{
storage.sin_zero = Default::default();
}
mem::size_of::<sockaddr_in>() as socklen_t
};
#[cfg(any(
target_os = "dragonfly",
target_os = "freebsd",
target_os = "haiku",
target_os = "hermit",
target_os = "ios",
target_os = "visionos",
target_os = "macos",
target_os = "netbsd",
target_os = "nto",
target_os = "openbsd",
target_os = "tvos",
target_os = "vxworks",
target_os = "watchos",
))]
{
storage.ss_len = len as u8;
}
SockAddr { storage, len }
}
}
impl From<SocketAddrV6> for SockAddr {
fn from(addr: SocketAddrV6) -> SockAddr {
let mut storage = unsafe { mem::zeroed::<sockaddr_storage>() };
let len = {
let storage = unsafe { &mut *ptr::addr_of_mut!(storage).cast::<sockaddr_in6>() };
storage.sin6_family = AF_INET6 as sa_family_t;
storage.sin6_port = addr.port().to_be();
storage.sin6_addr = crate::sys::to_in6_addr(addr.ip());
storage.sin6_flowinfo = addr.flowinfo();
#[cfg(any(unix, all(target_os = "wasi", not(target_env = "p1"))))]
{
storage.sin6_scope_id = addr.scope_id();
}
#[cfg(windows)]
{
storage.Anonymous = SOCKADDR_IN6_0 {
sin6_scope_id: addr.scope_id(),
};
}
mem::size_of::<sockaddr_in6>() as socklen_t
};
#[cfg(any(
target_os = "dragonfly",
target_os = "freebsd",
target_os = "haiku",
target_os = "hermit",
target_os = "ios",
target_os = "visionos",
target_os = "macos",
target_os = "netbsd",
target_os = "nto",
target_os = "openbsd",
target_os = "tvos",
target_os = "vxworks",
target_os = "watchos",
))]
{
storage.ss_len = len as u8;
}
SockAddr { storage, len }
}
}
impl fmt::Debug for SockAddr {
fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
let mut f = fmt.debug_struct("SockAddr");
#[cfg(any(
target_os = "dragonfly",
target_os = "freebsd",
target_os = "haiku",
target_os = "hermit",
target_os = "ios",
target_os = "visionos",
target_os = "macos",
target_os = "netbsd",
target_os = "nto",
target_os = "openbsd",
target_os = "tvos",
target_os = "vxworks",
target_os = "watchos",
))]
f.field("ss_len", &self.storage.ss_len);
f.field("ss_family", &self.storage.ss_family)
.field("len", &self.len)
.finish()
}
}
impl PartialEq for SockAddr {
fn eq(&self, other: &Self) -> bool {
self.as_bytes() == other.as_bytes()
}
}
impl Eq for SockAddr {}
impl Hash for SockAddr {
fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
self.as_bytes().hash(state);
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn ipv4() {
use std::net::Ipv4Addr;
let std = SocketAddrV4::new(Ipv4Addr::new(1, 2, 3, 4), 9876);
let addr = SockAddr::from(std);
assert!(addr.is_ipv4());
assert!(!addr.is_ipv6());
#[cfg(not(target_os = "wasi"))]
assert!(!addr.is_unix());
assert_eq!(addr.family(), AF_INET as sa_family_t);
assert_eq!(addr.domain(), Domain::IPV4);
assert_eq!(addr.len(), size_of::<sockaddr_in>() as socklen_t);
assert_eq!(addr.as_socket(), Some(SocketAddr::V4(std)));
assert_eq!(addr.as_socket_ipv4(), Some(std));
assert!(addr.as_socket_ipv6().is_none());
let addr = SockAddr::from(SocketAddr::from(std));
assert_eq!(addr.family(), AF_INET as sa_family_t);
assert_eq!(addr.len(), size_of::<sockaddr_in>() as socklen_t);
assert_eq!(addr.as_socket(), Some(SocketAddr::V4(std)));
assert_eq!(addr.as_socket_ipv4(), Some(std));
assert!(addr.as_socket_ipv6().is_none());
#[cfg(all(unix, not(target_os = "wasi")))]
{
assert!(addr.as_pathname().is_none());
assert!(addr.as_abstract_namespace().is_none());
}
}
#[test]
fn ipv6() {
use std::net::Ipv6Addr;
let std = SocketAddrV6::new(Ipv6Addr::new(1, 2, 3, 4, 5, 6, 7, 8), 9876, 11, 12);
let addr = SockAddr::from(std);
assert!(addr.is_ipv6());
assert!(!addr.is_ipv4());
#[cfg(not(target_os = "wasi"))]
assert!(!addr.is_unix());
assert_eq!(addr.family(), AF_INET6 as sa_family_t);
assert_eq!(addr.domain(), Domain::IPV6);
assert_eq!(addr.len(), size_of::<sockaddr_in6>() as socklen_t);
assert_eq!(addr.as_socket(), Some(SocketAddr::V6(std)));
assert!(addr.as_socket_ipv4().is_none());
assert_eq!(addr.as_socket_ipv6(), Some(std));
let addr = SockAddr::from(SocketAddr::from(std));
assert_eq!(addr.family(), AF_INET6 as sa_family_t);
assert_eq!(addr.len(), size_of::<sockaddr_in6>() as socklen_t);
assert_eq!(addr.as_socket(), Some(SocketAddr::V6(std)));
assert!(addr.as_socket_ipv4().is_none());
assert_eq!(addr.as_socket_ipv6(), Some(std));
#[cfg(all(unix, not(target_os = "wasi")))]
{
assert!(addr.as_pathname().is_none());
assert!(addr.as_abstract_namespace().is_none());
}
}
#[test]
fn ipv4_eq() {
use std::net::Ipv4Addr;
let std1 = SocketAddrV4::new(Ipv4Addr::new(1, 2, 3, 4), 9876);
let std2 = SocketAddrV4::new(Ipv4Addr::new(5, 6, 7, 8), 8765);
test_eq(
SockAddr::from(std1),
SockAddr::from(std1),
SockAddr::from(std2),
);
}
#[test]
fn ipv4_hash() {
use std::net::Ipv4Addr;
let std1 = SocketAddrV4::new(Ipv4Addr::new(1, 2, 3, 4), 9876);
let std2 = SocketAddrV4::new(Ipv4Addr::new(5, 6, 7, 8), 8765);
test_hash(
SockAddr::from(std1),
SockAddr::from(std1),
SockAddr::from(std2),
);
}
#[test]
fn ipv6_eq() {
use std::net::Ipv6Addr;
let std1 = SocketAddrV6::new(Ipv6Addr::new(1, 2, 3, 4, 5, 6, 7, 8), 9876, 11, 12);
let std2 = SocketAddrV6::new(Ipv6Addr::new(3, 4, 5, 6, 7, 8, 9, 0), 7654, 13, 14);
test_eq(
SockAddr::from(std1),
SockAddr::from(std1),
SockAddr::from(std2),
);
}
#[test]
fn ipv6_hash() {
use std::net::Ipv6Addr;
let std1 = SocketAddrV6::new(Ipv6Addr::new(1, 2, 3, 4, 5, 6, 7, 8), 9876, 11, 12);
let std2 = SocketAddrV6::new(Ipv6Addr::new(3, 4, 5, 6, 7, 8, 9, 0), 7654, 13, 14);
test_hash(
SockAddr::from(std1),
SockAddr::from(std1),
SockAddr::from(std2),
);
}
#[test]
fn ipv4_ipv6_eq() {
use std::net::Ipv4Addr;
use std::net::Ipv6Addr;
let std1 = SocketAddrV4::new(Ipv4Addr::new(1, 2, 3, 4), 9876);
let std2 = SocketAddrV6::new(Ipv6Addr::new(1, 2, 3, 4, 5, 6, 7, 8), 9876, 11, 12);
test_eq(
SockAddr::from(std1),
SockAddr::from(std1),
SockAddr::from(std2),
);
test_eq(
SockAddr::from(std2),
SockAddr::from(std2),
SockAddr::from(std1),
);
}
#[test]
fn ipv4_ipv6_hash() {
use std::net::Ipv4Addr;
use std::net::Ipv6Addr;
let std1 = SocketAddrV4::new(Ipv4Addr::new(1, 2, 3, 4), 9876);
let std2 = SocketAddrV6::new(Ipv6Addr::new(1, 2, 3, 4, 5, 6, 7, 8), 9876, 11, 12);
test_hash(
SockAddr::from(std1),
SockAddr::from(std1),
SockAddr::from(std2),
);
test_hash(
SockAddr::from(std2),
SockAddr::from(std2),
SockAddr::from(std1),
);
}
#[allow(clippy::eq_op)] fn test_eq(a0: SockAddr, a1: SockAddr, b: SockAddr) {
assert!(a0 == a0);
assert!(a0 == a1);
assert!(a1 == a0);
assert!(a0 != b);
assert!(b != a0);
}
fn test_hash(a0: SockAddr, a1: SockAddr, b: SockAddr) {
assert!(calculate_hash(&a0) == calculate_hash(&a0));
assert!(calculate_hash(&a0) == calculate_hash(&a1));
assert!(calculate_hash(&a0) != calculate_hash(&b));
}
fn calculate_hash(x: &SockAddr) -> u64 {
use std::collections::hash_map::DefaultHasher;
use std::hash::Hasher;
let mut hasher = DefaultHasher::new();
x.hash(&mut hasher);
hasher.finish()
}
}