use serde::de::{self, Visitor};
use serde::{Deserialize, Deserializer, Serialize, Serializer};
use std::fmt;
use std::marker::PhantomData;
use std::str::FromStr;
use uuid::Uuid;
pub(crate) fn serialize_display_or_inner<DisplayValue, Inner, S>(
display_value: &DisplayValue,
inner: &Inner,
serializer: S,
) -> Result<S::Ok, S::Error>
where
DisplayValue: fmt::Display + ?Sized,
Inner: Serialize,
S: Serializer,
{
if serializer.is_human_readable() {
serializer.collect_str(display_value)
} else {
inner.serialize(serializer)
}
}
pub(crate) fn deserialize_i64_wrapper<'de, Value, D, F, E>(
deserializer: D,
validate: F,
) -> Result<Value, D::Error>
where
Value: FromStr<Err = E>,
F: FnOnce(i64) -> Result<Value, E>,
E: fmt::Display,
D: Deserializer<'de>,
{
if !deserializer.is_human_readable() {
let raw = i64::deserialize(deserializer)?;
return validate(raw).map_err(de::Error::custom);
}
deserializer.deserialize_any(I64WrapperVisitor::<Value, F> {
validate,
_marker: PhantomData,
})
}
/// Visitor that builds an `i64`-backed wrapper from either a string or an integer.
///
/// Strings are parsed with `Value`'s `FromStr` impl. Integers are passed to the
/// `validate` closure.
struct I64WrapperVisitor<Value, F> {
    /// Turns a raw `i64` into `Value`, or reports why it cannot.
    validate: F,
    _marker: PhantomData<Value>,
}

impl<'de, Value, F, E> Visitor<'de> for I64WrapperVisitor<Value, F>
where
    Value: FromStr<Err = E>,
    F: FnOnce(i64) -> Result<Value, E>,
    E: fmt::Display,
{
    type Value = Value;

    fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
        formatter.write_str("a string or integer identifier")
    }

    fn visit_str<DE>(self, value: &str) -> Result<Self::Value, DE>
    where
        DE: de::Error,
    {
        value.parse::<Value>().map_err(DE::custom)
    }

    fn visit_i64<DE>(self, value: i64) -> Result<Self::Value, DE>
    where
        DE: de::Error,
    {
        let Self { validate, .. } = self;
        validate(value).map_err(DE::custom)
    }

    fn visit_u64<DE>(self, value: u64) -> Result<Self::Value, DE>
    where
        DE: de::Error,
    {
        // Unsigned input is accepted only when it can be represented as an i64.
        match i64::try_from(value) {
            Ok(signed) => self.visit_i64(signed),
            Err(_) => Err(DE::custom(format_args!(
                "integer {value} exceeds i64 range"
            ))),
        }
    }
}
pub(crate) fn deserialize_uuid_wrapper<'de, Value, D, F, E>(
deserializer: D,
validate: F,
) -> Result<Value, D::Error>
where
Value: FromStr<Err = E>,
F: FnOnce(Uuid) -> Result<Value, E>,
E: fmt::Display,
D: Deserializer<'de>,
{
if !deserializer.is_human_readable() {
let raw = Uuid::deserialize(deserializer)?;
return validate(raw).map_err(de::Error::custom);
}
deserializer.deserialize_str(UuidWrapperVisitor::<Value>(PhantomData))
}
/// Visitor that parses a UUID-backed wrapper from its string form.
///
/// Only string input is accepted. Other input falls back to serde's default
/// "invalid type" errors.
struct UuidWrapperVisitor<Value>(PhantomData<Value>);

impl<'de, Value> Visitor<'de> for UuidWrapperVisitor<Value>
where
    Value: FromStr,
    Value::Err: fmt::Display,
{
    type Value = Value;

    fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
        formatter.write_str("a UUID string")
    }

    fn visit_str<DE>(self, value: &str) -> Result<Self::Value, DE>
    where
        DE: de::Error,
    {
        // The target type's FromStr impl decides what counts as a valid string.
        value.parse().map_err(DE::custom)
    }
}