extern crate alloc;
use crate::{
char_set::{AllowedAscii, PRINTABLE_ASCII},
dom::{Domain, DomainErr, Rfc1123Domain, Rfc1123Err},
};
use alloc::{borrow::ToOwned as _, string::String};
use core::{
fmt::{self, Formatter},
marker::PhantomData,
};
use serde::{
de::{self, Deserialize, Deserializer, Unexpected, Visitor},
ser::{Serialize, Serializer},
};
/// ASCII set handed to `DomainVisitor` by the `Deserialize` implementations of
/// `Domain` in this file; it is a reference to the crate's `PRINTABLE_ASCII` set.
static DOMAIN_CHARS: &AllowedAscii<[u8; 92]> = &PRINTABLE_ASCII;
impl<T: AsRef<[u8]>> Serialize for Domain<T> {
#[inline]
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
serializer.serialize_str(self.as_str())
}
}
impl<T: AsRef<[u8]>> Serialize for Rfc1123Domain<T> {
#[inline]
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
serializer.serialize_str(self.as_str())
}
}
/// Serde visitor that builds a `Domain` from string input, validating it
/// against a caller-supplied `AllowedAscii` set.
///
/// `T` is the backing storage of the `AllowedAscii` set; `T2` is the container
/// wrapped by the produced `Domain` (this file implements `Visitor` for
/// `T2 = &str` and `T2 = String`).
#[expect(
    clippy::partial_pub_fields,
    reason = "we don't expose PhantomData for obvious reasons, so this is fine"
)]
#[derive(Clone, Copy, Debug)]
pub struct DomainVisitor<'a, T, T2> {
    /// Zero-sized marker that ties the visitor to the output container type
    /// `T2` without storing a value of that type.
    _x: PhantomData<fn() -> T2>,
    /// ASCII set passed to `Domain::try_from_bytes` when a string is visited.
    pub allowed_ascii: &'a AllowedAscii<T>,
}
/// Translates a `DomainErr` into the serde error type `E`.
///
/// Length-related variants become `E::invalid_length`; an invalid byte becomes
/// `E::invalid_value` carrying the byte as an unsigned integer.
fn dom_err_to_serde<E: de::Error>(value: DomainErr) -> E {
    /// Expectation used for errors about the overall domain length.
    const DOMAIN_LEN: &str = "a valid domain with length inclusively between 1 and 253";
    /// Expectation used for errors about the length of a single label.
    const LABEL_LEN: &str =
        "a valid domain containing labels of length inclusively between 1 and 63";
    match value {
        DomainErr::Empty => E::invalid_length(0, &DOMAIN_LEN),
        DomainErr::RootDomain => {
            E::invalid_length(0, &"a valid domain with at least one non-root label")
        }
        DomainErr::LenExceeds253(len) => E::invalid_length(len, &DOMAIN_LEN),
        // The variant carries no length, so 64 is the value reported.
        DomainErr::LabelLenExceeds63 => E::invalid_length(64, &LABEL_LEN),
        DomainErr::EmptyLabel => E::invalid_length(0, &LABEL_LEN),
        DomainErr::InvalidByte(byte) => E::invalid_value(
            Unexpected::Unsigned(u64::from(byte)),
            &"a valid domain containing only the supplied ASCII subset",
        ),
    }
}
/// Constructor for `DomainVisitor`.
impl<'a, T, T2> DomainVisitor<'a, T, T2> {
/// Returns a visitor holding `allowed_ascii`, the set later passed to
/// `Domain::try_from_bytes` by the `Visitor` implementations in this file.
///
/// The reference may live for any lifetime `'b` that outlives `'a`.
#[expect(single_use_lifetimes, reason = "false positive")]
#[inline]
pub const fn new<'b: 'a>(allowed_ascii: &'b AllowedAscii<T>) -> Self {
Self {
_x: PhantomData,
allowed_ascii,
}
}
}
impl<'de: 'a, 'a, T: AsRef<[u8]>> Visitor<'de> for DomainVisitor<'_, T, &'a str> {
type Value = Domain<&'a str>;
#[inline]
fn expecting(&self, formatter: &mut Formatter<'_>) -> fmt::Result {
formatter.write_str("Domain")
}
#[inline]
fn visit_borrowed_str<E>(self, v: &'de str) -> Result<Self::Value, E>
where
E: de::Error,
{
Self::Value::try_from_bytes(v, self.allowed_ascii).map_err(|err| dom_err_to_serde::<E>(err))
}
}
impl<T: AsRef<[u8]>> Visitor<'_> for DomainVisitor<'_, T, String> {
type Value = Domain<String>;
#[inline]
fn expecting(&self, formatter: &mut Formatter<'_>) -> fmt::Result {
formatter.write_str("Domain")
}
#[inline]
fn visit_string<E>(self, v: String) -> Result<Self::Value, E>
where
E: de::Error,
{
Self::Value::try_from_bytes(v, self.allowed_ascii).map_err(|err| dom_err_to_serde::<E>(err))
}
#[inline]
fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
where
E: de::Error,
{
self.visit_string(v.to_owned())
}
}
impl<'de> Deserialize<'de> for Domain<String> {
#[inline]
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: Deserializer<'de>,
{
deserializer.deserialize_string(DomainVisitor::<'_, _, String>::new(DOMAIN_CHARS))
}
}
impl<'de: 'a, 'a> Deserialize<'de> for Domain<&'a str> {
#[inline]
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: Deserializer<'de>,
{
deserializer.deserialize_str(DomainVisitor::<'_, _, &str>::new(DOMAIN_CHARS))
}
}
/// Translates an `Rfc1123Err` into the serde error type `E`.
///
/// A wrapped `DomainErr` is delegated to `dom_err_to_serde`; the RFC 1123
/// specific variants become `E::invalid_value`.
fn rfc_err_to_serde<E: de::Error>(value: Rfc1123Err) -> E {
    /// Expectation reported when a label begins or ends with a hyphen.
    const HYPHEN_RULE: &str = "a valid domain conforming to RFC 1123 which requires all labels to not begin or end with a '-'";
    /// Description of the offending TLD.
    const BAD_TLD: &str =
        "tld that is not all letters nor begins with 'xn--' and has length of at least five";
    /// Expectation reported when the TLD is invalid.
    const TLD_RULE: &str = "a valid domain conforming to RFC 1123 which requires the last label (i.e., TLD) to either be all letters or have length of at least five and begins with 'xn--'";
    match value {
        Rfc1123Err::DomainErr(inner) => dom_err_to_serde::<E>(inner),
        Rfc1123Err::LabelStartsWithAHyphen | Rfc1123Err::LabelEndsWithAHyphen => {
            E::invalid_value(Unexpected::Str("-"), &HYPHEN_RULE)
        }
        Rfc1123Err::InvalidTld => E::invalid_value(Unexpected::Str(BAD_TLD), &TLD_RULE),
    }
}
/// Private serde visitor used by the `Deserialize` implementations of
/// `Rfc1123Domain`; `T` is the container wrapped by the produced value
/// (`&str` or `String` in this file). It only holds a zero-sized marker.
struct Rfc1123Visitor<T>(PhantomData<fn() -> T>);
impl<'de: 'a, 'a> Visitor<'de> for Rfc1123Visitor<&'a str> {
type Value = Rfc1123Domain<&'a str>;
fn expecting(&self, formatter: &mut Formatter<'_>) -> fmt::Result {
formatter.write_str("Rfc1123Domain")
}
fn visit_borrowed_str<E>(self, v: &'de str) -> Result<Self::Value, E>
where
E: de::Error,
{
Self::Value::try_from_bytes(v).map_err(|err| rfc_err_to_serde(err))
}
}
impl Visitor<'_> for Rfc1123Visitor<String> {
type Value = Rfc1123Domain<String>;
fn expecting(&self, formatter: &mut Formatter<'_>) -> fmt::Result {
formatter.write_str("Rfc1123Domain")
}
fn visit_string<E>(self, v: String) -> Result<Self::Value, E>
where
E: de::Error,
{
Self::Value::try_from_bytes(v).map_err(|err| rfc_err_to_serde(err))
}
fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
where
E: de::Error,
{
self.visit_string(v.to_owned())
}
}
impl<'de> Deserialize<'de> for Rfc1123Domain<String> {
#[inline]
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: Deserializer<'de>,
{
deserializer.deserialize_string(Rfc1123Visitor::<String>(PhantomData))
}
}
impl<'de: 'a, 'a> Deserialize<'de> for Rfc1123Domain<&'a str> {
#[inline]
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: Deserializer<'de>,
{
deserializer.deserialize_str(Rfc1123Visitor::<&'a str>(PhantomData))
}
}
#[cfg(test)]
mod tests {
    extern crate alloc;
    use crate::{
        char_set::ASCII_HYPHEN_DIGITS_LETTERS,
        dom::{Domain, Rfc1123Domain},
    };
    use alloc::string::String;
    /// Round-trips `Domain` and `Rfc1123Domain` through `serde_json`, covering
    /// both the borrowed (`&str`) and owned (`String`) variants.
    #[test]
    fn test_serde() {
        // Borrowed `Domain` from a JSON string without escapes: iterating it
        // yields two items.
        assert!(
            serde_json::from_str::<Domain<&str>>(r#""example.com""#)
                .is_ok_and(|dom| dom.into_iter().count() == 2)
        );
        // Owned `Domain` from a JSON string containing an escaped quote:
        // iterating it yields one item.
        assert!(
            serde_json::from_str::<Domain<String>>(r#""c\"om""#)
                .is_ok_and(|dom| dom.into_iter().count() == 1)
        );
        // The same escaped input is expected to fail for the borrowed variant
        // with a data error reported at column 7.
        assert!(
            serde_json::from_str::<Domain<&str>>(r#""c\"om""#)
                .is_err_and(|err| err.is_data() && err.column() == 7)
        );
        // Serializing a `Domain` wrapping a `&str` yields the quoted domain.
        assert!(
            serde_json::to_string(
                &Domain::try_from_bytes("example.com", &ASCII_HYPHEN_DIGITS_LETTERS).unwrap()
            )
            .is_ok_and(|output| output == r#""example.com""#)
        );
        // Serializing a `Domain` wrapping a byte-string literal yields the
        // same output.
        assert!(
            serde_json::to_string(
                &Domain::try_from_bytes(b"example.com", &ASCII_HYPHEN_DIGITS_LETTERS).unwrap()
            )
            .is_ok_and(|output| output == r#""example.com""#)
        );
        // Borrowed `Rfc1123Domain` from a JSON string without escapes.
        assert!(
            serde_json::from_str::<Rfc1123Domain<&str>>(r#""example.com""#)
                .is_ok_and(|dom| dom.into_iter().count() == 2)
        );
        // Owned `Rfc1123Domain` from a JSON string with a `\u006f` escape: the
        // TLD reads "com".
        assert!(
            serde_json::from_str::<Rfc1123Domain<String>>(r#""c\u006fm""#)
                .is_ok_and(|dom| dom.tld().as_str() == "com")
        );
        // The same escaped input is expected to fail for the borrowed variant
        // with a data error reported at column 10.
        assert!(
            serde_json::from_str::<Rfc1123Domain<&str>>(r#""c\u006fm""#)
                .is_err_and(|err| err.is_data() && err.column() == 10)
        );
    }
}