use anyhow::Context;
use std::net::{IpAddr, SocketAddr};
pub(crate) fn resolve(addr: Option<&str>, default: &str) -> anyhow::Result<SocketAddr> {
use std::net::ToSocketAddrs;
addr.unwrap_or(default)
.to_socket_addrs()
.context("invalid address")?
.next()
.context("no addresses resolved")
}
/// Chooses the address from `addrs` that best fits the address family of `local`.
///
/// Preference order:
/// 1. the first address already in `local`'s family — returned immediately, without pulling
///    any further items from `addrs`;
/// 2. otherwise, the first address that `normalize_family` can translate into `local`'s
///    family (the translated form is returned);
/// 3. otherwise, the first remaining address, unchanged, even though its family differs.
///
/// Returns `None` only when `addrs` yields nothing.
pub(crate) fn pick_addr(
    addrs: impl IntoIterator<Item = SocketAddr>,
    local: SocketAddr,
) -> Option<SocketAddr> {
    let want_v4 = local.is_ipv4();
    let mut first_convertible: Option<SocketAddr> = None;
    let mut first_foreign: Option<SocketAddr> = None;

    for candidate in addrs {
        // An exact family match wins outright.
        if candidate.is_ipv4() == want_v4 {
            return Some(candidate);
        }

        let translated = normalize_family(candidate, local);
        if translated.is_ipv4() != want_v4 {
            // Translation could not bridge the families; remember the original as a last resort.
            first_foreign.get_or_insert(candidate);
        } else {
            first_convertible.get_or_insert(translated);
        }
    }

    first_convertible.or(first_foreign)
}
/// Translates `addr` into the address family of `local` when a mapping exists.
///
/// - An IPv4-mapped IPv6 address (`::ffff:a.b.c.d`) is unwrapped to plain IPv4 when `local`
///   is IPv4.
/// - A plain IPv4 address is wrapped as an IPv4-mapped IPv6 address when `local` is IPv6.
/// - Anything else, including IPv6 addresses with no IPv4 mapping, is returned unchanged, so
///   callers must re-check the family of the result.
///
/// The port is always kept. A translated address is rebuilt with `SocketAddr::new`, so IPv6
/// flow info and scope id are not carried across.
fn normalize_family(addr: SocketAddr, local: SocketAddr) -> SocketAddr {
    let want_v4 = local.is_ipv4();

    let translated_ip = match addr.ip() {
        IpAddr::V6(ip6) if want_v4 => ip6.to_ipv4_mapped().map(IpAddr::V4),
        IpAddr::V4(ip4) if !want_v4 => Some(IpAddr::V6(ip4.to_ipv6_mapped())),
        _ => None,
    };

    translated_ip.map_or(addr, |ip| SocketAddr::new(ip, addr.port()))
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn resolves_socket_literal() {
        let addr = resolve(Some("[::]:0"), "[::]:443").unwrap();
        assert!(addr.ip().is_unspecified());
        assert_eq!(addr.port(), 0);
    }

    // Requires the test host to be able to resolve `localhost`.
    #[test]
    fn resolves_dns_hostname() {
        let addr = resolve(Some("localhost:0"), "[::]:443").unwrap();
        assert!(addr.ip().is_loopback());
        assert_eq!(addr.port(), 0);
    }

    #[test]
    fn falls_back_to_default() {
        let addr = resolve(None, "127.0.0.1:1234").unwrap();
        assert_eq!(addr, "127.0.0.1:1234".parse::<SocketAddr>().unwrap());
    }

    #[test]
    fn rejects_malformed_address() {
        // No `:port` suffix, so this cannot be a valid socket address.
        assert!(resolve(Some("not-an-address"), "[::]:443").is_err());
    }

    #[test]
    fn pick_addr_prefers_matching_family() {
        let v4: SocketAddr = "127.0.0.1:443".parse().unwrap();
        let v6: SocketAddr = "[::1]:443".parse().unwrap();
        let local_v4: SocketAddr = "0.0.0.0:0".parse().unwrap();
        let local_v6: SocketAddr = "[::]:0".parse().unwrap();
        assert_eq!(pick_addr([v6, v4], local_v4), Some(v4));
        assert_eq!(pick_addr([v4, v6], local_v6), Some(v6));
    }

    #[test]
    fn pick_addr_returns_first_matching_family() {
        let first: SocketAddr = "10.0.0.1:443".parse().unwrap();
        let second: SocketAddr = "10.0.0.2:443".parse().unwrap();
        let local_v4: SocketAddr = "0.0.0.0:0".parse().unwrap();
        assert_eq!(pick_addr([first, second], local_v4), Some(first));
    }

    #[test]
    fn pick_addr_wraps_v4_for_v6_socket() {
        let v4: SocketAddr = "127.0.0.1:443".parse().unwrap();
        let mapped: SocketAddr = "[::ffff:127.0.0.1]:443".parse().unwrap();
        let local_v6: SocketAddr = "[::]:0".parse().unwrap();
        assert_eq!(pick_addr([v4], local_v6), Some(mapped));
    }

    #[test]
    fn pick_addr_unwraps_v4_mapped_for_v4_socket() {
        let mapped: SocketAddr = "[::ffff:127.0.0.1]:443".parse().unwrap();
        let v4: SocketAddr = "127.0.0.1:443".parse().unwrap();
        let local_v4: SocketAddr = "0.0.0.0:0".parse().unwrap();
        assert_eq!(pick_addr([mapped], local_v4), Some(v4));
    }

    #[test]
    fn pick_addr_prefers_convertible_over_unconvertible() {
        let unmappable: SocketAddr = "[2001:db8::1]:443".parse().unwrap();
        let mapped: SocketAddr = "[::ffff:10.0.0.1]:443".parse().unwrap();
        let expected: SocketAddr = "10.0.0.1:443".parse().unwrap();
        let local_v4: SocketAddr = "0.0.0.0:0".parse().unwrap();
        // The convertible address wins even though it comes second.
        assert_eq!(pick_addr([unmappable, mapped], local_v4), Some(expected));
    }

    #[test]
    fn pick_addr_falls_back_for_unmappable_v6() {
        let v6: SocketAddr = "[2001:db8::1]:443".parse().unwrap();
        let local_v4: SocketAddr = "0.0.0.0:0".parse().unwrap();
        assert_eq!(pick_addr([v6], local_v4), Some(v6));
    }

    #[test]
    fn pick_addr_empty() {
        let local_v4: SocketAddr = "0.0.0.0:0".parse().unwrap();
        let local_v6: SocketAddr = "[::]:0".parse().unwrap();
        assert_eq!(pick_addr(std::iter::empty(), local_v4), None);
        assert_eq!(pick_addr(std::iter::empty(), local_v6), None);
    }
}