use core::{
net::{IpAddr, Ipv4Addr, Ipv6Addr},
str::FromStr,
};
#[cfg(not(any(feature = "alloc", feature = "std")))]
use core::str::from_utf8;
#[cfg(any(feature = "alloc", feature = "std"))]
use simdutf8::basic::from_utf8;
pub use derive_more::TryUnwrapError;
use super::Domain;
#[derive(
Clone,
Copy,
Debug,
Eq,
PartialEq,
PartialOrd,
Ord,
Hash,
derive_more::Display,
derive_more::IsVariant,
derive_more::Unwrap,
)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[cfg_attr(feature = "serde", serde(rename_all = "snake_case"))]
#[unwrap(ref, ref_mut)]
pub enum Host<S> {
Ip(IpAddr),
Domain(S),
}
// `Host<S>` carries the `CheapClone` marker whenever its domain storage `S` does.
#[cfg(feature = "cheap-clone")]
impl<S> cheap_clone::CheapClone for Host<S> where S: cheap_clone::CheapClone {}
impl<'a> Host<&'a str> {
#[inline]
pub fn try_from_ascii_str(input: &'a str) -> Result<Self, ParseAsciiHostError> {
if let Ok(ip) = input.parse() {
return Ok(Host::Ip(ip));
}
Domain::try_from_ascii_str(input)
.map(|d| Host::Domain(d.as_ref().0))
.map_err(|_| ParseAsciiHostError(()))
}
#[inline]
pub const fn as_bytes(&self) -> Host<&'a [u8]> {
match self {
Self::Ip(ip) => Host::Ip(*ip),
Self::Domain(domain) => Host::Domain(domain.as_bytes()),
}
}
}
impl<'a> Host<&'a [u8]> {
    /// Parses raw bytes as an ASCII host.
    ///
    /// The bytes must first be valid UTF-8. The resulting text is tried as an
    /// [`IpAddr`]; if that fails, the original bytes are handed to
    /// `Domain::try_from_ascii_bytes`.
    ///
    /// # Errors
    ///
    /// Returns [`ParseAsciiHostError`] if `input` is not valid UTF-8, or is
    /// neither an IP address nor accepted by `Domain::try_from_ascii_bytes`.
    #[inline]
    pub fn try_from_ascii_bytes(input: &'a [u8]) -> Result<Self, ParseAsciiHostError> {
        let input_str = from_utf8(input).map_err(|_| ParseAsciiHostError(()))?;
        if let Ok(ip) = input_str.parse() {
            return Ok(Host::Ip(ip));
        }
        Domain::try_from_ascii_bytes(input)
            .map(|d| Host::Domain(d.as_ref().0))
            .map_err(|_| ParseAsciiHostError(()))
    }

    /// Fallible counterpart of [`as_str`](Self::as_str): views the domain bytes
    /// as `&str`. IP hosts always succeed.
    ///
    /// # Errors
    ///
    /// Returns the [`Utf8Error`](core::str::Utf8Error) produced by
    /// [`core::str::from_utf8`] when the domain bytes are not valid UTF-8.
    #[inline]
    pub const fn try_as_str(&self) -> Result<Host<&'a str>, core::str::Utf8Error> {
        match *self {
            Self::Ip(ip) => Ok(Host::Ip(ip)),
            Self::Domain(domain) => match core::str::from_utf8(domain) {
                Ok(domain) => Ok(Host::Domain(domain)),
                Err(err) => Err(err),
            },
        }
    }

    /// Views the domain bytes as `&str`; IP hosts are copied through as-is.
    ///
    /// # Panics
    ///
    /// Panics if this is a `Domain` whose bytes are not valid UTF-8. Because the
    /// enum variants are public, such a value can be built directly (e.g.
    /// `Host::Domain(&b"\xff"[..])`); prefer [`try_as_str`](Self::try_as_str)
    /// when the bytes did not come from a UTF-8 source.
    #[inline]
    pub const fn as_str(&self) -> Host<&'a str> {
        match self.try_as_str() {
            Ok(host) => host,
            Err(_) => panic!("Host<&[u8]>::as_str: domain is not valid UTF-8"),
        }
    }
}
impl<S> From<Domain<S>> for Host<S> {
fn from(value: Domain<S>) -> Self {
Self::Domain(value.0)
}
}
impl<S> From<IpAddr> for Host<S> {
fn from(value: IpAddr) -> Self {
Self::from_ip(value)
}
}
/// Wraps an [`Ipv4Addr`] in the `Ip` variant.
impl<S> From<Ipv4Addr> for Host<S> {
    fn from(addr: Ipv4Addr) -> Self {
        Self::from_ip(IpAddr::from(addr))
    }
}
/// Wraps an [`Ipv6Addr`] in the `Ip` variant.
impl<S> From<Ipv6Addr> for Host<S> {
    fn from(addr: Ipv6Addr) -> Self {
        Self::from_ip(IpAddr::V6(addr))
    }
}
impl<S> Host<S> {
    /// Builds a host holding the given IP address.
    #[inline]
    pub const fn from_ip(ip: IpAddr) -> Self {
        Host::Ip(ip)
    }

    /// Returns `true` only for an IPv4 address (domains yield `false`).
    #[inline]
    pub const fn is_ipv4(&self) -> bool {
        match self {
            Self::Ip(addr) => addr.is_ipv4(),
            Self::Domain(_) => false,
        }
    }

    /// Returns `true` only for an IPv6 address (domains yield `false`).
    #[inline]
    pub const fn is_ipv6(&self) -> bool {
        match self {
            Self::Ip(addr) => addr.is_ipv6(),
            Self::Domain(_) => false,
        }
    }

    /// Borrows the domain storage, copying the IP address if present.
    #[inline]
    pub const fn as_ref(&self) -> Host<&S> {
        match self {
            Self::Ip(addr) => Host::Ip(*addr),
            Self::Domain(name) => Host::Domain(name),
        }
    }

    /// Borrows the domain storage through `S`'s `Deref` target
    /// (e.g. `Host<String>` -> `Host<&str>`), copying the IP address if present.
    #[inline]
    pub fn as_deref(&self) -> Host<&S::Target>
    where
        S: core::ops::Deref,
    {
        match self {
            Self::Ip(addr) => Host::Ip(*addr),
            Self::Domain(name) => Host::Domain(&**name),
        }
    }

    /// Returns the IP address, or `None` for a domain.
    #[inline]
    pub const fn ip(&self) -> Option<&IpAddr> {
        if let Self::Ip(addr) = self {
            Some(addr)
        } else {
            None
        }
    }

    /// Returns the domain storage, or `None` for an IP address.
    #[inline]
    pub const fn domain(&self) -> Option<&S> {
        match self {
            Self::Domain(name) => Some(name),
            Self::Ip(_) => None,
        }
    }
}
impl<S> Host<&S> {
    /// Converts `Host<&S>` into `Host<S>` by copying the borrowed domain.
    #[inline]
    pub const fn copied(self) -> Host<S>
    where
        S: Copy,
    {
        match self {
            Host::Ip(addr) => Host::Ip(addr),
            Host::Domain(name) => Host::Domain(*name),
        }
    }

    /// Converts `Host<&S>` into `Host<S>` by cloning the borrowed domain.
    #[inline]
    pub fn cloned(self) -> Host<S>
    where
        S: Clone,
    {
        match self {
            Host::Ip(addr) => Host::Ip(addr),
            Host::Domain(name) => Host::Domain(S::clone(name)),
        }
    }
}
// Shared body of `FromStr::from_str` and `TryFrom<&str>::try_from` for `Host<S>`.
//
// Invoked as `try_from_str!(method(input))`. The input is first parsed as an
// `IpAddr`; on success the macro `return`s `Ok(Host::Ip(..))` from the *calling*
// function. Otherwise it evaluates to `input.method()`, whose output is pinned
// to `Domain<S>` by the closure annotation (`S` being the caller's generic
// parameter), mapped to `Host::Domain`; any error from `method` is discarded
// and replaced with `ParseHostError`.
macro_rules! try_from_str {
($convert:ident($s:ident)) => {{
if let Ok(ip) = $s.parse() {
return Ok(Host::Ip(ip));
}
$s.$convert()
.map(|d: Domain<S>| Host::Domain(d.0))
.map_err(|_| ParseHostError(()))
}};
}
/// Parses a host, trying an IP address first and `Domain<S>`'s `FromStr` second.
///
/// Any failure is reported as [`ParseHostError`].
impl<S> FromStr for Host<S>
where
    Domain<S>: FromStr,
{
    type Err = ParseHostError;

    fn from_str(text: &str) -> Result<Self, ParseHostError> {
        try_from_str!(parse(text))
    }
}
/// Converts a string into a host, trying an IP address first and
/// `Domain<S>`'s `TryFrom<&str>` second.
///
/// Any failure is reported as [`ParseHostError`].
impl<'a, S> TryFrom<&'a str> for Host<S>
where
    Domain<S>: TryFrom<&'a str>,
{
    type Error = ParseHostError;

    fn try_from(text: &'a str) -> Result<Self, ParseHostError> {
        try_from_str!(try_into(text))
    }
}
#[derive(Debug, Clone, Copy, thiserror::Error)]
#[error("{}", self.as_str())]
pub struct ParseAsciiHostError(pub(super) ());
impl ParseAsciiHostError {
#[inline]
pub const fn as_str(&self) -> &'static str {
"invalid ASCII host"
}
}
#[derive(Debug, Clone, thiserror::Error)]
#[error("{}", self.as_str())]
pub struct ParseHostError(pub(super) ());
impl ParseHostError {
#[inline]
pub const fn as_str(&self) -> &'static str {
"invalid host"
}
}
#[cfg(test)]
mod tests {
    use super::*;
    #[cfg(any(feature = "std", feature = "alloc"))]
    use std::string::String;

    #[cfg(any(feature = "std", feature = "alloc"))]
    #[test]
    fn negative_from_str() {
        let err = "@a".parse::<Host<String>>().unwrap_err();
        assert_eq!(err.as_str(), "invalid host");
    }

    #[cfg(any(feature = "std", feature = "alloc"))]
    #[test]
    fn negative_try_from_str() {
        let err = Host::<String>::try_from("@a").unwrap_err();
        assert_eq!(err.as_str(), "invalid host");
    }

    #[test]
    fn ip_from_ascii_str() {
        let host = Host::try_from_ascii_str("127.0.0.1").unwrap();
        assert!(host.is_ipv4());
        let host = Host::try_from_ascii_str("::1").unwrap();
        assert!(host.is_ipv6());
        let err = Host::try_from_ascii_str("@a").unwrap_err();
        assert_eq!(err.as_str(), "invalid ASCII host");
    }

    #[test]
    fn ip_from_ascii_bytes() {
        let host = Host::try_from_ascii_bytes(b"127.0.0.1").unwrap();
        assert!(host.is_ipv4());
        let host = Host::try_from_ascii_bytes(b"::1").unwrap();
        assert!(host.is_ipv6());
        // The negative case must go through the bytes entry point, not the str one.
        let err = Host::try_from_ascii_bytes(b"@a").unwrap_err();
        assert_eq!(err.as_str(), "invalid ASCII host");
        // Invalid UTF-8 is rejected before any IP/domain parsing.
        let err = Host::try_from_ascii_bytes(b"\xff").unwrap_err();
        assert_eq!(err.as_str(), "invalid ASCII host");
    }

    #[test]
    fn str_bytes_round_trip() {
        let host: Host<&str> = Host::Domain("example.com");
        assert_eq!(host.as_bytes().as_str(), host);
        let host: Host<&str> = Host::Ip(IpAddr::V4(Ipv4Addr::LOCALHOST));
        assert_eq!(host.as_bytes().as_str(), host);
    }

    #[test]
    #[should_panic]
    fn as_str_panics_on_invalid_utf8() {
        let host: Host<&[u8]> = Host::Domain(&b"\xff"[..]);
        let _ = host.as_str();
    }
}