use std::{
net::{IpAddr, SocketAddr},
time::Duration,
};
use tracing::{debug, info};
/// Upper bound on how long a single DNS lookup may run before it is abandoned
/// with [`ResolveError::Timeout`].
const DNS_TIMEOUT: Duration = Duration::from_secs(5);

/// Errors produced while turning a host string into a [`SocketAddr`].
#[derive(Debug)]
pub enum ResolveError {
    /// The lookup completed but yielded no addresses; carries the looked-up input.
    NoResults(String),
    /// The resolver itself reported an I/O error.
    LookupFailed(std::io::Error),
    /// The lookup did not finish within [`DNS_TIMEOUT`]; carries the looked-up input.
    Timeout(String),
}

impl std::fmt::Display for ResolveError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            ResolveError::NoResults(host) => {
                write!(f, "DNS resolution returned no results for '{host}'")
            }
            // The io error text is kept in the message so callers that only print
            // the top-level error still see why the lookup failed.
            ResolveError::LookupFailed(e) => write!(f, "DNS resolution failed: {e}"),
            ResolveError::Timeout(host) => {
                write!(
                    f,
                    "DNS resolution for '{host}' timed out after {}s",
                    DNS_TIMEOUT.as_secs()
                )
            }
        }
    }
}

impl std::error::Error for ResolveError {
    /// Exposes the wrapped [`std::io::Error`] for [`ResolveError::LookupFailed`]
    /// so callers walking the error chain (or downcasting) can inspect its
    /// `ErrorKind`. The other variants have no underlying cause.
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        match self {
            ResolveError::LookupFailed(e) => Some(e),
            ResolveError::NoResults(_) | ResolveError::Timeout(_) => None,
        }
    }
}
/// Resolves `host` plus `port` into a single [`SocketAddr`].
///
/// A bare IP literal (IPv4, or IPv6 *without* brackets) is returned directly
/// without any lookup. Anything else is formatted as `"{host}:{port}"` and handed
/// to [`tokio::net::lookup_host`], bounded by [`DNS_TIMEOUT`]. The first address
/// the lookup yields is returned. Logs at `info` before a lookup and at `debug`
/// on success.
///
/// # Errors
///
/// - [`ResolveError::Timeout`] if the lookup does not finish within `DNS_TIMEOUT`.
/// - [`ResolveError::LookupFailed`] if the resolver reports an I/O error.
/// - [`ResolveError::NoResults`] if the lookup succeeds but yields no address.
pub async fn resolve_host(host: &str, port: u16) -> Result<SocketAddr, ResolveError> {
    // Fast path: literal IPs need no DNS round-trip.
    if let Ok(literal) = host.parse::<IpAddr>() {
        return Ok(SocketAddr::new(literal, port));
    }

    info!("Resolving hostname '{host}' via DNS...");
    let target = format!("{host}:{port}");
    let pending = tokio::net::lookup_host(&target);

    let mut candidates = match tokio::time::timeout(DNS_TIMEOUT, pending).await {
        Ok(Ok(iter)) => iter,
        Ok(Err(io_err)) => return Err(ResolveError::LookupFailed(io_err)),
        Err(_elapsed) => return Err(ResolveError::Timeout(host.to_string())),
    };

    match candidates.next() {
        Some(first) => {
            debug!("Resolved '{host}' -> {first}");
            Ok(first)
        }
        None => Err(ResolveError::NoResults(host.to_string())),
    }
}
/// Resolves a combined address string into a single [`SocketAddr`].
///
/// `addr` may be a literal socket address (e.g. `127.0.0.1:80` or `[::1]:80`),
/// which is returned directly without any lookup, or a `host:port` string, which
/// is passed as-is to [`tokio::net::lookup_host`] and bounded by [`DNS_TIMEOUT`].
/// The first address the lookup yields is returned. Logs at `info` before a
/// lookup and at `debug` on success.
///
/// # Errors
///
/// - [`ResolveError::Timeout`] if the lookup does not finish within `DNS_TIMEOUT`.
/// - [`ResolveError::LookupFailed`] if the resolver reports an I/O error
///   (this includes input that is not in `host:port` form).
/// - [`ResolveError::NoResults`] if the lookup succeeds but yields no address.
pub async fn resolve_host_port(addr: &str) -> Result<SocketAddr, ResolveError> {
    // Fast path: an already-literal socket address needs no DNS round-trip.
    if let Ok(literal) = addr.parse::<SocketAddr>() {
        return Ok(literal);
    }

    info!("Resolving address '{addr}' via DNS...");
    let pending = tokio::net::lookup_host(addr);

    let mut candidates = match tokio::time::timeout(DNS_TIMEOUT, pending).await {
        Ok(Ok(iter)) => iter,
        Ok(Err(io_err)) => return Err(ResolveError::LookupFailed(io_err)),
        Err(_elapsed) => return Err(ResolveError::Timeout(addr.to_string())),
    };

    match candidates.next() {
        Some(resolved) => {
            debug!("Resolved '{addr}' -> {resolved}");
            Ok(resolved)
        }
        None => Err(ResolveError::NoResults(addr.to_string())),
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn resolve_ipv4_address() {
        let addr = resolve_host("127.0.0.1", 3333).await.unwrap();
        assert_eq!(addr, SocketAddr::new("127.0.0.1".parse().unwrap(), 3333));
    }

    #[tokio::test]
    async fn resolve_ipv6_address() {
        let addr = resolve_host("::1", 3333).await.unwrap();
        assert_eq!(addr, SocketAddr::new("::1".parse().unwrap(), 3333));
    }

    #[tokio::test]
    async fn resolve_bracketed_ipv6_host() {
        // "[::1]" is not a bare `IpAddr`, so it takes the lookup path; the
        // "{host}:{port}" string built there parses directly as a socket address.
        let addr = resolve_host("[::1]", 3333).await.unwrap();
        assert_eq!(addr, SocketAddr::new("::1".parse().unwrap(), 3333));
    }

    #[tokio::test]
    async fn resolve_localhost_hostname() {
        let addr = resolve_host("localhost", 3333).await.unwrap();
        assert_eq!(addr.port(), 3333);
        assert!(addr.ip().is_loopback());
    }

    #[tokio::test]
    async fn resolve_invalid_hostname_fails() {
        // The `.invalid` TLD is reserved (RFC 6761) and must never resolve.
        let result = resolve_host("this.hostname.definitely.does.not.exist.invalid", 3333).await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn resolve_host_port_ipv4() {
        let addr = resolve_host_port("127.0.0.1:3333").await.unwrap();
        assert_eq!(addr, SocketAddr::new("127.0.0.1".parse().unwrap(), 3333));
    }

    #[tokio::test]
    async fn resolve_host_port_ipv6() {
        let addr = resolve_host_port("[::1]:3333").await.unwrap();
        assert_eq!(addr, SocketAddr::new("::1".parse().unwrap(), 3333));
    }

    #[tokio::test]
    async fn resolve_host_port_localhost() {
        let addr = resolve_host_port("localhost:3333").await.unwrap();
        assert_eq!(addr.port(), 3333);
        assert!(addr.ip().is_loopback());
    }

    #[tokio::test]
    async fn resolve_host_port_invalid_fails() {
        // The `.invalid` TLD is reserved (RFC 6761) and must never resolve.
        let result =
            resolve_host_port("this.hostname.definitely.does.not.exist.invalid:3333").await;
        assert!(result.is_err());
    }

    #[test]
    fn error_messages() {
        let no_results = ResolveError::NoResults("example.com".to_string());
        assert_eq!(
            no_results.to_string(),
            "DNS resolution returned no results for 'example.com'"
        );

        let timeout = ResolveError::Timeout("example.com".to_string());
        assert_eq!(
            timeout.to_string(),
            format!(
                "DNS resolution for 'example.com' timed out after {}s",
                DNS_TIMEOUT.as_secs()
            )
        );

        let failed = ResolveError::LookupFailed(std::io::Error::new(
            std::io::ErrorKind::NotFound,
            "boom",
        ));
        assert_eq!(failed.to_string(), "DNS resolution failed: boom");
    }
}