use indexmap::IndexMap;
#[cfg(feature = "regex")]
use regex::Regex;
use std::io;
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6};
use crate::ip::IpVersionAddrIter;
/// In-process name registry mapping hostnames to the IP addresses assigned to them.
///
/// A hostname receives the next address from `addrs` the first time it is looked up;
/// that assignment is then remembered in `names`.
pub struct Dns {
    /// Supplies a fresh address for each hostname that has not been seen yet.
    addrs: IpVersionAddrIter,
    /// Hostname -> assigned address, kept in registration (insertion) order.
    names: IndexMap<String, IpAddr>,
}
/// Conversion of a host identifier into a single [`IpAddr`].
///
/// The method receives `&mut Dns` so that an implementation may register a new
/// hostname in the registry while resolving it.
pub trait ToIpAddr
where
    Self: sealed::Sealed,
{
    #[doc(hidden)]
    fn to_ip_addr(&self, dns: &mut Dns) -> IpAddr;
}
/// Conversion of a host identifier (or host pattern) into zero or more [`IpAddr`]s.
///
/// Every [`ToIpAddr`] type gets this for free, yielding exactly one address.
pub trait ToIpAddrs
where
    Self: sealed::Sealed,
{
    #[doc(hidden)]
    fn to_ip_addrs(&self, dns: &mut Dns) -> Vec<IpAddr>;
}
/// Conversion into a single [`SocketAddr`], resolving hostnames against a [`Dns`] registry.
///
/// Unlike [`ToIpAddr`], this only reads the registry (`&Dns`). Implementations in this
/// module report an unknown hostname as an `io::Error` rather than registering it.
pub trait ToSocketAddrs
where
    Self: sealed::Sealed,
{
    #[doc(hidden)]
    fn to_socket_addr(&self, dns: &Dns) -> io::Result<SocketAddr>;
}
impl Dns {
pub(crate) fn new(addrs: IpVersionAddrIter) -> Dns {
Dns {
addrs,
names: IndexMap::new(),
}
}
pub(crate) fn lookup(&mut self, addr: impl ToIpAddr) -> IpAddr {
addr.to_ip_addr(self)
}
pub(crate) fn lookup_many(&mut self, addrs: impl ToIpAddrs) -> Vec<IpAddr> {
addrs.to_ip_addrs(self)
}
pub(crate) fn reverse(&self, addr: IpAddr) -> Option<&str> {
self.names
.iter()
.find(|(_, a)| **a == addr)
.map(|(name, _)| name.as_str())
}
}
impl ToIpAddr for String {
    /// Resolves exactly as the borrowed `&str` form of this string would.
    fn to_ip_addr(&self, dns: &mut Dns) -> IpAddr {
        let host: &str = self.as_str();
        <&str as ToIpAddr>::to_ip_addr(&host, dns)
    }
}
impl ToIpAddr for &str {
    /// Resolves a host string to an address.
    ///
    /// Anything that parses as an [`IpAddr`] (e.g. `"10.0.0.1"`, `"::1"`) is returned as-is
    /// and is never registered. Any other string is treated as a hostname. If it was seen
    /// before, the address assigned then is returned. Otherwise the next address is taken
    /// from `dns.addrs`, recorded in `dns.names`, and returned.
    fn to_ip_addr(&self, dns: &mut Dns) -> IpAddr {
        if let Ok(ipaddr) = self.parse() {
            return ipaddr;
        }
        // Probe with the borrowed `&str` first so that resolving an already-registered
        // hostname does not allocate. The entry API needs an owned `String` key on every
        // call, even when nothing ends up being inserted.
        if let Some(&ip) = dns.names.get(*self) {
            return ip;
        }
        let ip = dns.addrs.next();
        dns.names.insert(self.to_string(), ip);
        ip
    }
}
impl ToIpAddr for IpAddr {
    /// Already an address: returned unchanged; the registry is not consulted.
    fn to_ip_addr(&self, _dns: &mut Dns) -> IpAddr {
        match self {
            IpAddr::V4(v4) => IpAddr::V4(*v4),
            IpAddr::V6(v6) => IpAddr::V6(*v6),
        }
    }
}
impl ToIpAddr for Ipv4Addr {
    /// Wraps the IPv4 address as [`IpAddr::V4`]; the registry is not consulted.
    fn to_ip_addr(&self, _dns: &mut Dns) -> IpAddr {
        IpAddr::from(*self)
    }
}
impl ToIpAddr for Ipv6Addr {
    /// Wraps the IPv6 address as [`IpAddr::V6`]; the registry is not consulted.
    fn to_ip_addr(&self, _dns: &mut Dns) -> IpAddr {
        let v6 = *self;
        v6.into()
    }
}
impl<T: ToIpAddr> ToIpAddrs for T {
    /// Any single-address type resolves to a one-element list.
    fn to_ip_addrs(&self, dns: &mut Dns) -> Vec<IpAddr> {
        let single = self.to_ip_addr(dns);
        Vec::from([single])
    }
}
#[cfg(feature = "regex")]
impl ToIpAddrs for Regex {
    /// Returns the addresses of all registered hostnames this regex matches, in the order
    /// the hostnames were registered.
    ///
    /// Only hostnames already present in `dns` are considered. Nothing new is registered
    /// and no fresh addresses are drawn, so a pattern matching no hostname yields an
    /// empty `Vec`.
    fn to_ip_addrs(&self, dns: &mut Dns) -> Vec<IpAddr> {
        // Read the stored addresses directly. This avoids cloning every key into a
        // temporary `Vec` and re-resolving each one through `ToIpAddr`, which needed
        // `&mut Dns` while the map was being iterated.
        dns.names
            .iter()
            .filter(|(host, _)| self.is_match(host.as_str()))
            .map(|(_, &ip)| ip)
            .collect()
    }
}
impl ToSocketAddrs for (String, u16) {
    /// Borrows the owned hostname and defers to the `(&str, u16)` implementation.
    fn to_socket_addr(&self, dns: &Dns) -> io::Result<SocketAddr> {
        let (host, port) = self;
        (host.as_str(), *port).to_socket_addr(dns)
    }
}
impl ToSocketAddrs for (&str, u16) {
    /// Builds a socket address from a host and a port.
    ///
    /// An IP-literal host is used directly. Any other host must already be registered in
    /// `dns`.
    ///
    /// # Errors
    ///
    /// Returns an [`io::ErrorKind::NotFound`] error when the host is neither an IP literal
    /// nor a registered hostname.
    fn to_socket_addr(&self, dns: &Dns) -> io::Result<SocketAddr> {
        let (host, port) = *self;
        let ip = match host.parse::<IpAddr>() {
            Ok(literal) => literal,
            Err(_) => *dns.names.get(host).ok_or_else(|| {
                io::Error::new(
                    io::ErrorKind::NotFound,
                    format!("no ip address found for a hostname: {}", host),
                )
            })?,
        };
        Ok(SocketAddr::new(ip, port))
    }
}
impl ToSocketAddrs for SocketAddr {
    /// Already a socket address: returned unchanged.
    fn to_socket_addr(&self, _dns: &Dns) -> io::Result<SocketAddr> {
        let addr = *self;
        Ok(addr)
    }
}
impl ToSocketAddrs for SocketAddrV4 {
    /// Wraps the IPv4 socket address as [`SocketAddr::V4`].
    fn to_socket_addr(&self, _dns: &Dns) -> io::Result<SocketAddr> {
        Ok(SocketAddr::from(*self))
    }
}
impl ToSocketAddrs for SocketAddrV6 {
    /// Wraps the IPv6 socket address (flow info and scope id included) as [`SocketAddr::V6`].
    fn to_socket_addr(&self, _dns: &Dns) -> io::Result<SocketAddr> {
        let v6 = *self;
        Ok(v6.into())
    }
}
impl ToSocketAddrs for (IpAddr, u16) {
    /// Pairs the address with the port; the registry is not consulted.
    fn to_socket_addr(&self, _dns: &Dns) -> io::Result<SocketAddr> {
        let (ip, port) = *self;
        Ok(SocketAddr::new(ip, port))
    }
}
impl ToSocketAddrs for (Ipv4Addr, u16) {
fn to_socket_addr(&self, _: &Dns) -> io::Result<SocketAddr> {
Ok((*self).into())
}
}
impl ToSocketAddrs for (Ipv6Addr, u16) {
fn to_socket_addr(&self, _: &Dns) -> io::Result<SocketAddr> {
Ok((*self).into())
}
}
impl<T> ToSocketAddrs for &T
where
    T: ?Sized + ToSocketAddrs,
{
    /// Forwards to the referenced value's implementation.
    fn to_socket_addr(&self, dns: &Dns) -> io::Result<SocketAddr> {
        T::to_socket_addr(*self, dns)
    }
}
impl ToSocketAddrs for str {
    /// Parses `"ip:port"`, `"[v6]:port"`, or `"host:port"` into a socket address.
    ///
    /// A string that is already a valid socket address is returned directly. Otherwise it
    /// is split at the last `':'`, and the `(host, port)` pair is resolved via the
    /// `(&str, u16)` implementation.
    ///
    /// # Errors
    ///
    /// - [`io::ErrorKind::InvalidInput`] if there is no `':'` or the port is not a valid `u16`.
    /// - Any error from resolving the host (e.g. an unregistered hostname).
    fn to_socket_addr(&self, dns: &Dns) -> io::Result<SocketAddr> {
        if let Ok(addr) = self.parse::<SocketAddr>() {
            return Ok(addr);
        }
        let invalid_input =
            |msg: &'static str| io::Error::new(io::ErrorKind::InvalidInput, msg);
        let (host, port_str) = self
            .rsplit_once(':')
            .ok_or_else(|| invalid_input("invalid socket address"))?;
        let port = port_str
            .parse::<u16>()
            .map_err(|_| invalid_input("invalid port value"))?;
        (host, port).to_socket_addr(dns)
    }
}
impl ToSocketAddrs for String {
    /// Parses exactly as the `str` implementation does.
    fn to_socket_addr(&self, dns: &Dns) -> io::Result<SocketAddr> {
        <str as ToSocketAddrs>::to_socket_addr(self, dns)
    }
}
mod sealed {
    /// Supertrait of the public conversion traits in this module.
    ///
    /// Because of the blanket impl below, every type (sized or not) satisfies this bound,
    /// so on its own it does not restrict which types may implement those traits.
    /// NOTE(review): confirm whether that is intended.
    pub trait Sealed {}

    impl<T> Sealed for T where T: ?Sized {}
}
#[cfg(test)]
mod tests {
    use crate::{dns::Dns, ip::IpVersionAddrIter, ToSocketAddrs};
    use std::net::{Ipv4Addr, SocketAddr};

    /// A fresh registry backed by the default address iterator.
    fn empty_dns() -> Dns {
        Dns::new(IpVersionAddrIter::default())
    }

    #[test]
    fn parse_str() {
        let mut dns = empty_dns();
        let foo_ip = dns.lookup("foo");

        // A registered hostname resolves to the address it was assigned.
        let expected: SocketAddr = format!("{foo_ip}:5000").parse().unwrap();
        assert_eq!("foo:5000".to_socket_addr(&dns).unwrap(), expected);

        // Literal socket addresses pass straight through.
        for literal in ["127.0.0.1:5000", "[::1]:5000"] {
            let expected: SocketAddr = literal.parse().unwrap();
            assert_eq!(literal.to_socket_addr(&dns).unwrap(), expected);
        }
    }

    #[test]
    fn raw_value_parsing() {
        let mut dns = empty_dns();

        let literal = Ipv4Addr::new(192, 168, 2, 2);
        assert_eq!(dns.lookup(literal), literal);

        let parsed = Ipv4Addr::new(192, 168, 3, 3);
        assert_eq!(dns.lookup("192.168.3.3"), parsed);
        assert_eq!("192.168.3.3:0".to_socket_addr(&dns).unwrap().ip(), parsed);
    }
}