use super::NetSim;
use std::future;
use std::io;
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6};
/// Resolves `host` into one or more socket addresses on the simulated network.
///
/// Socket-address and IP literals are returned without any lookup; host names
/// are resolved through `NetSim::current().lookup_host`.
///
/// # Errors
///
/// Fails with `InvalidInput` if `host` is malformed (e.g. a missing or invalid
/// port) or if the host name cannot be resolved.
pub async fn lookup_host(host: impl ToSocketAddrs) -> io::Result<impl Iterator<Item = SocketAddr>> {
    let addrs = to_socket_addrs(host).await?;
    Ok(addrs)
}
/// Values that can be converted or resolved into one or more socket addresses
/// by [`lookup_host`].
///
/// This trait is sealed: all behavior lives in `sealed::ToSocketAddrsPriv`,
/// which sits in a crate-private module and so cannot be implemented by
/// downstream crates.
pub trait ToSocketAddrs
where
    Self: sealed::ToSocketAddrsPriv + std::fmt::Debug,
{
}
/// Unwraps an `Option`, or returns early from the enclosing function with a
/// ready `InvalidInput` error whose message is `$msg`.
///
/// Only usable inside functions whose return type is `sealed::MaybeReady`.
macro_rules! try_opt {
    ($e:expr, $msg:expr $(,)?) => {
        match $e {
            Some(r) => r,
            None => {
                // Every path is spelled out so the expansion does not depend on
                // which names the calling function happens to have imported.
                return sealed::MaybeReady(sealed::State::Err(Some(::std::io::Error::new(
                    ::std::io::ErrorKind::InvalidInput,
                    $msg,
                ))));
            }
        }
    };
}
/// An already-completed future carrying an `io::Result<T>`; used by every
/// conversion that never needs to resolve anything.
type ReadyFuture<T> = std::future::Ready<std::io::Result<T>>;
/// Crate-internal entry point: starts converting `arg` into socket addresses
/// by invoking the sealed trait method with the private `Internal` token.
pub(crate) fn to_socket_addrs<T: ToSocketAddrs>(arg: T) -> T::Future {
    <T as sealed::ToSocketAddrsPriv>::to_socket_addrs(&arg, sealed::Internal)
}
impl<T> ToSocketAddrs for &T where T: ToSocketAddrs + ?Sized {}

// A shared reference converts exactly like the value it points to.
impl<T> sealed::ToSocketAddrsPriv for &T
where
    T: sealed::ToSocketAddrsPriv + ?Sized,
{
    type Iter = T::Iter;
    type Future = T::Future;

    fn to_socket_addrs(&self, _: sealed::Internal) -> Self::Future {
        let referent: &T = *self;
        T::to_socket_addrs(referent, sealed::Internal)
    }
}
impl ToSocketAddrs for SocketAddr {}

impl sealed::ToSocketAddrsPriv for SocketAddr {
    type Iter = std::option::IntoIter<SocketAddr>;
    type Future = ReadyFuture<Self::Iter>;

    // An already-resolved address yields itself exactly once, immediately.
    fn to_socket_addrs(&self, _: sealed::Internal) -> Self::Future {
        let addr: SocketAddr = *self;
        future::ready(Ok(Option::Some(addr).into_iter()))
    }
}
impl ToSocketAddrs for SocketAddrV4 {}

impl sealed::ToSocketAddrsPriv for SocketAddrV4 {
    type Iter = std::option::IntoIter<SocketAddr>;
    type Future = ReadyFuture<Self::Iter>;

    // Widen to the general `SocketAddr` and reuse its conversion.
    fn to_socket_addrs(&self, _: sealed::Internal) -> Self::Future {
        let widened = SocketAddr::from(*self);
        widened.to_socket_addrs(sealed::Internal)
    }
}
impl ToSocketAddrs for SocketAddrV6 {}

impl sealed::ToSocketAddrsPriv for SocketAddrV6 {
    type Iter = std::option::IntoIter<SocketAddr>;
    type Future = ReadyFuture<Self::Iter>;

    // Widen to the general `SocketAddr` and reuse its conversion.
    fn to_socket_addrs(&self, _: sealed::Internal) -> Self::Future {
        let widened = SocketAddr::from(*self);
        widened.to_socket_addrs(sealed::Internal)
    }
}
impl ToSocketAddrs for (IpAddr, u16) {}

impl sealed::ToSocketAddrsPriv for (IpAddr, u16) {
    type Iter = std::option::IntoIter<SocketAddr>;
    type Future = ReadyFuture<Self::Iter>;

    // An IP plus port is already a complete address; no lookup is needed.
    fn to_socket_addrs(&self, _: sealed::Internal) -> Self::Future {
        let (ip, port) = *self;
        let addr = SocketAddr::new(ip, port);
        future::ready(Ok(Some(addr).into_iter()))
    }
}
impl ToSocketAddrs for (Ipv4Addr, u16) {}
impl sealed::ToSocketAddrsPriv for (Ipv4Addr, u16) {
type Iter = std::option::IntoIter<SocketAddr>;
type Future = ReadyFuture<Self::Iter>;
fn to_socket_addrs(&self, _: sealed::Internal) -> Self::Future {
let (ip, port) = *self;
SocketAddrV4::new(ip, port).to_socket_addrs(sealed::Internal)
}
}
impl ToSocketAddrs for (Ipv6Addr, u16) {}
impl sealed::ToSocketAddrsPriv for (Ipv6Addr, u16) {
type Iter = std::option::IntoIter<SocketAddr>;
type Future = ReadyFuture<Self::Iter>;
fn to_socket_addrs(&self, _: sealed::Internal) -> Self::Future {
let (ip, port) = *self;
SocketAddrV6::new(ip, port, 0, 0).to_socket_addrs(sealed::Internal)
}
}
impl ToSocketAddrs for &[SocketAddr] {}

impl sealed::ToSocketAddrsPriv for &[SocketAddr] {
    type Iter = std::vec::IntoIter<SocketAddr>;
    type Future = ReadyFuture<Self::Iter>;

    fn to_socket_addrs(&self, _: sealed::Internal) -> Self::Future {
        // `Iter` must be `'static`, so the borrowed slice cannot be iterated in
        // place; its addresses are copied into an owned `Vec` instead.
        #[inline]
        fn owned_copy(addrs: &[SocketAddr]) -> Vec<SocketAddr> {
            Vec::from(addrs)
        }

        future::ready(Ok(owned_copy(self).into_iter()))
    }
}
impl ToSocketAddrs for str {}

impl sealed::ToSocketAddrsPriv for str {
    type Iter = sealed::OneOrMore;
    type Future = sealed::MaybeReady;

    // Accepts either a literal socket address ("1.2.3.4:80", "[::1]:80") or
    // "host:port", where the text after the last ':' must be a valid `u16`.
    fn to_socket_addrs(&self, _: sealed::Internal) -> Self::Future {
        use sealed::MaybeReady;

        match self.parse::<SocketAddr>() {
            Ok(literal) => MaybeReady(sealed::State::Ready(Some(literal))),
            Err(_) => {
                let (host, port_str) = try_opt!(self.rsplit_once(':'), "invalid socket address");
                let port = try_opt!(port_str.parse::<u16>().ok(), "invalid port value");
                (host, port).to_socket_addrs(sealed::Internal)
            }
        }
    }
}
impl ToSocketAddrs for (&str, u16) {}
impl sealed::ToSocketAddrsPriv for (&str, u16) {
type Iter = sealed::OneOrMore;
type Future = sealed::MaybeReady;
fn to_socket_addrs(&self, _: sealed::Internal) -> Self::Future {
use sealed::MaybeReady;
let (host, port) = *self;
if let Ok(addr) = host.parse::<Ipv4Addr>() {
let addr = SocketAddrV4::new(addr, port);
let addr = SocketAddr::V4(addr);
return MaybeReady(sealed::State::Ready(Some(addr)));
}
if let Ok(addr) = host.parse::<Ipv6Addr>() {
let addr = SocketAddrV6::new(addr, port, 0, 0);
let addr = SocketAddr::V6(addr);
return MaybeReady(sealed::State::Ready(Some(addr)));
}
let ip = try_opt!(NetSim::current().lookup_host(host), "couldn't resolve host");
let addr = SocketAddr::from((ip, port));
MaybeReady(sealed::State::Ready(Some(addr)))
}
}
impl ToSocketAddrs for (String, u16) {}

impl sealed::ToSocketAddrsPriv for (String, u16) {
    type Iter = sealed::OneOrMore;
    type Future = sealed::MaybeReady;

    // Borrow the owned host and delegate to the `(&str, u16)` conversion.
    fn to_socket_addrs(&self, _: sealed::Internal) -> Self::Future {
        let (host, port) = self;
        let borrowed: (&str, u16) = (host.as_str(), *port);
        borrowed.to_socket_addrs(sealed::Internal)
    }
}
impl ToSocketAddrs for String {}

// An owned string converts exactly like the `str` it derefs to.
impl sealed::ToSocketAddrsPriv for String {
    type Iter = <str as sealed::ToSocketAddrsPriv>::Iter;
    type Future = <str as sealed::ToSocketAddrsPriv>::Future;

    fn to_socket_addrs(&self, _: sealed::Internal) -> Self::Future {
        let text: &str = self;
        <str as sealed::ToSocketAddrsPriv>::to_socket_addrs(text, sealed::Internal)
    }
}
pub(crate) mod sealed {
    //! Crate-private plumbing behind the public `ToSocketAddrs` trait.
    //!
    //! Items here are `pub` only so they can appear in the associated types of
    //! the public trait; the module itself is `pub(crate)`, so code outside this
    //! crate cannot name or implement them.

    use std::future::Future;
    use std::io;
    use std::net::SocketAddr;
    use std::option;
    use std::pin::Pin;
    use std::task::{Context, Poll};
    use std::vec;

    /// The real conversion trait; `ToSocketAddrs` is an empty marker on top of it.
    #[doc(hidden)]
    pub trait ToSocketAddrsPriv {
        /// Iterator over the resulting addresses; `'static` so it never borrows `self`.
        type Iter: Iterator<Item = SocketAddr> + Send + 'static;
        /// Future yielding the addresses, or an error, once conversion completes.
        type Future: Future<Output = io::Result<Self::Iter>> + Send + 'static;

        /// Begins converting `self` into socket addresses.
        ///
        /// Requiring an [`Internal`] token keeps this method uncallable from
        /// outside the crate.
        fn to_socket_addrs(&self, internal: Internal) -> Self::Future;
    }

    /// Unnameable-outside-the-crate token guarding [`ToSocketAddrsPriv::to_socket_addrs`].
    #[allow(missing_debug_implementations)] // opaque zero-sized token; Debug adds nothing
    pub struct Internal;

    /// Future returned by the string-based conversions.
    ///
    /// It never returns `Poll::Pending`: the stored result is handed out on the
    /// first poll.
    #[doc(hidden)]
    #[derive(Debug)]
    pub struct MaybeReady(pub(super) State);

    /// Result held by a [`MaybeReady`] until it is polled.
    #[derive(Debug)]
    pub(super) enum State {
        /// A resolved address; taken (leaving `None`) by the first poll.
        Ready(Option<SocketAddr>),
        /// A conversion error; taken (leaving `None`) by the first poll.
        Err(Option<io::Error>),
    }

    /// Iterator over either a single (`One`) or several (`More`) addresses.
    #[doc(hidden)]
    #[derive(Debug)]
    pub enum OneOrMore {
        One(option::IntoIter<SocketAddr>),
        More(vec::IntoIter<SocketAddr>),
    }

    impl Future for MaybeReady {
        type Output = io::Result<OneOrMore>;

        /// Completes immediately with the stored result.
        ///
        /// # Panics
        ///
        /// Panics if polled again after it has already completed with an error.
        fn poll(mut self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Self::Output> {
            match self.0 {
                // A repeated poll after success yields an empty iterator.
                State::Ready(ref mut addr) => {
                    Poll::Ready(Ok(OneOrMore::One(addr.take().into_iter())))
                }
                State::Err(ref mut err) => {
                    let err = err
                        .take()
                        .expect("`MaybeReady` polled after completion");
                    Poll::Ready(Err(err))
                }
            }
        }
    }

    impl Iterator for OneOrMore {
        type Item = SocketAddr;

        fn next(&mut self) -> Option<Self::Item> {
            match self {
                OneOrMore::One(i) => i.next(),
                OneOrMore::More(i) => i.next(),
            }
        }

        fn size_hint(&self) -> (usize, Option<usize>) {
            match self {
                OneOrMore::One(i) => i.size_hint(),
                OneOrMore::More(i) => i.size_hint(),
            }
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::runtime::Runtime;

    /// Resolves `host` and returns its first address, panicking on any failure.
    async fn first_addr(host: impl ToSocketAddrs) -> SocketAddr {
        let mut addrs = lookup_host(host).await.unwrap();
        addrs.next().unwrap()
    }

    /// Registers `madsim.io -> 8.8.8.8` with the current simulated network.
    fn add_madsim_record() {
        NetSim::current().add_dns_record("madsim.io", Ipv4Addr::new(8, 8, 8, 8).into());
    }

    #[test]
    fn localhost() {
        let runtime = Runtime::new();
        runtime.block_on(async {
            let expected = SocketAddr::from((Ipv4Addr::LOCALHOST, 1));
            assert_eq!(first_addr("localhost:1").await, expected);
            assert_eq!(first_addr(("localhost", 1)).await, expected);
        });
    }

    #[test]
    fn dns() {
        let runtime = Runtime::new();
        runtime.block_on(async {
            add_madsim_record();
            let expected = SocketAddr::from((Ipv4Addr::new(8, 8, 8, 8), 1));
            assert_eq!(first_addr("madsim.io:1").await, expected);
            assert!(lookup_host(("mad.io", 1)).await.is_err());
        });
    }

    #[test]
    fn dns_failed_to_resolve() {
        let runtime = Runtime::new();
        runtime.block_on(async {
            add_madsim_record();
            let expected = SocketAddr::from((Ipv4Addr::new(8, 8, 8, 8), 1));
            assert_eq!(first_addr("madsim.io:1").await, expected);
            for bad in ["mad.io", "madsim.io:65536", "madsim:io"] {
                assert!(lookup_host(bad).await.is_err());
            }
        });
    }
}