use alloc::{borrow::ToOwned, format, string::String, vec::Vec};
use core::{fmt, marker::PhantomData};
use serde::{
Deserializer,
de::{DeserializeSeed, IgnoredAny, MapAccess, SeqAccess, Visitor},
};
use super::{
error::{ActualValue, FieldError},
state::Field,
};
pub trait FieldDecode<'de>: Sized {
fn decode_field<D>(deserializer: D) -> Result<Field<Self>, D::Error>
where
D: Deserializer<'de>;
}
/// Per-type conversions from the primitive values a self-describing
/// deserializer can report.
///
/// Every hook defaults to a type-mismatch error built from `Self::EXPECTED`
/// and the kind of input that arrived, so an implementor overrides only the
/// hooks for inputs it accepts. Some hooks delegate by default:
/// `from_i64`/`from_u64` widen to `from_i128`/`from_u128`, and
/// `from_string`/`from_byte_buf` lend their argument to `from_str`/`from_bytes`.
pub trait ScalarFieldDecode: Sized {
    /// Description of the accepted input. It appears in type-mismatch errors
    /// and as the scalar visitor's `expecting` text.
    const EXPECTED: &'static str;

    /// Decodes a boolean. Rejected by default.
    fn from_bool(_value: bool) -> Result<Self, FieldError> {
        Err(FieldError::type_mismatch(Self::EXPECTED, ActualValue::Boolean))
    }

    /// Decodes a single character. Rejected by default.
    fn from_char(_value: char) -> Result<Self, FieldError> {
        Err(FieldError::type_mismatch(Self::EXPECTED, ActualValue::Character))
    }

    /// Decodes a signed 64-bit integer by widening it and calling `from_i128`.
    fn from_i64(value: i64) -> Result<Self, FieldError> {
        Self::from_i128(value.into())
    }

    /// Decodes an unsigned 64-bit integer by widening it and calling `from_u128`.
    fn from_u64(value: u64) -> Result<Self, FieldError> {
        Self::from_u128(value.into())
    }

    /// Decodes a signed integer. Rejected by default.
    fn from_i128(_value: i128) -> Result<Self, FieldError> {
        Err(FieldError::type_mismatch(Self::EXPECTED, ActualValue::SignedInteger))
    }

    /// Decodes an unsigned integer. Rejected by default.
    fn from_u128(_value: u128) -> Result<Self, FieldError> {
        Err(FieldError::type_mismatch(Self::EXPECTED, ActualValue::UnsignedInteger))
    }

    /// Decodes a floating point number. Rejected by default.
    fn from_f64(_value: f64) -> Result<Self, FieldError> {
        Err(FieldError::type_mismatch(Self::EXPECTED, ActualValue::FloatingPoint))
    }

    /// Decodes a borrowed string. Rejected by default.
    fn from_str(_value: &str) -> Result<Self, FieldError> {
        Err(FieldError::type_mismatch(Self::EXPECTED, ActualValue::String))
    }

    /// Decodes an owned string by lending it to `from_str`.
    fn from_string(value: String) -> Result<Self, FieldError> {
        Self::from_str(value.as_str())
    }

    /// Decodes a borrowed byte slice. Rejected by default.
    fn from_bytes(_value: &[u8]) -> Result<Self, FieldError> {
        Err(FieldError::type_mismatch(Self::EXPECTED, ActualValue::Bytes))
    }

    /// Decodes an owned byte buffer by lending it to `from_bytes`.
    fn from_byte_buf(value: Vec<u8>) -> Result<Self, FieldError> {
        Self::from_bytes(value.as_slice())
    }
}
impl<'de, T> FieldDecode<'de> for T
where
T: ScalarFieldDecode,
{
fn decode_field<D>(deserializer: D) -> Result<Field<Self>, D::Error>
where
D: Deserializer<'de>,
{
struct ScalarVisitor<T>(PhantomData<T>);
impl<'de, T> Visitor<'de> for ScalarVisitor<T>
where
T: ScalarFieldDecode,
{
type Value = Field<T>;
fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
formatter.write_str(T::EXPECTED)
}
fn visit_unit<E>(self) -> Result<Self::Value, E> {
Ok(Field::Missing)
}
fn visit_none<E>(self) -> Result<Self::Value, E> {
Ok(Field::Missing)
}
fn visit_some<D>(self, deserializer: D) -> Result<Self::Value, D::Error>
where
D: Deserializer<'de>,
{
T::decode_field(deserializer)
}
fn visit_bool<E>(self, value: bool) -> Result<Self::Value, E> {
Ok(Field::from_result(T::from_bool(value)))
}
fn visit_char<E>(self, value: char) -> Result<Self::Value, E> {
Ok(Field::from_result(T::from_char(value)))
}
fn visit_i64<E>(self, value: i64) -> Result<Self::Value, E> {
Ok(Field::from_result(T::from_i64(value)))
}
fn visit_u64<E>(self, value: u64) -> Result<Self::Value, E> {
Ok(Field::from_result(T::from_u64(value)))
}
fn visit_i128<E>(self, value: i128) -> Result<Self::Value, E> {
Ok(Field::from_result(T::from_i128(value)))
}
fn visit_u128<E>(self, value: u128) -> Result<Self::Value, E> {
Ok(Field::from_result(T::from_u128(value)))
}
fn visit_f64<E>(self, value: f64) -> Result<Self::Value, E> {
Ok(Field::from_result(T::from_f64(value)))
}
fn visit_str<E>(self, value: &str) -> Result<Self::Value, E> {
Ok(Field::from_result(T::from_str(value)))
}
fn visit_string<E>(self, value: String) -> Result<Self::Value, E> {
Ok(Field::from_result(T::from_string(value)))
}
fn visit_bytes<E>(self, value: &[u8]) -> Result<Self::Value, E> {
Ok(Field::from_result(T::from_bytes(value)))
}
fn visit_borrowed_bytes<E>(self, value: &'de [u8]) -> Result<Self::Value, E> {
Ok(Field::from_result(T::from_bytes(value)))
}
fn visit_byte_buf<E>(self, value: Vec<u8>) -> Result<Self::Value, E> {
Ok(Field::from_result(T::from_byte_buf(value)))
}
fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
where
A: SeqAccess<'de>,
{
invalid_seq(
&mut seq,
FieldError::type_mismatch(T::EXPECTED, ActualValue::Array),
)
}
fn visit_map<A>(self, mut map: A) -> Result<Self::Value, A::Error>
where
A: MapAccess<'de>,
{
scalar_from_map::<_, T>(&mut map)
}
}
deserializer.deserialize_any(ScalarVisitor::<T>(PhantomData))
}
}
impl<'de, T> FieldDecode<'de> for Option<T>
where
    T: FieldDecode<'de>,
{
    /// Decodes an optional field through `deserialize_option`.
    ///
    /// `None`/unit, and an inner value that decodes as missing, both become
    /// `Field::Valid(None)`. A valid inner value is wrapped in `Some`, and an
    /// invalid inner value stays invalid.
    fn decode_field<D>(deserializer: D) -> Result<Field<Self>, D::Error>
    where
        D: Deserializer<'de>,
    {
        struct OptionVisitor<T>(PhantomData<T>);

        impl<'de, T> Visitor<'de> for OptionVisitor<T>
        where
            T: FieldDecode<'de>,
        {
            type Value = Field<Option<T>>;

            fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
                formatter.write_str("optional recoverable field")
            }

            fn visit_none<E>(self) -> Result<Self::Value, E> {
                Ok(Field::Valid(None))
            }

            fn visit_unit<E>(self) -> Result<Self::Value, E> {
                Ok(Field::Valid(None))
            }

            fn visit_some<D>(self, deserializer: D) -> Result<Self::Value, D::Error>
            where
                D: Deserializer<'de>,
            {
                let inner = T::decode_field(deserializer)?;
                let field = match inner {
                    Field::Valid(value) => Field::Valid(Some(value)),
                    Field::Invalid(error) => Field::Invalid(error),
                    // An inner value that decodes as missing makes the option
                    // empty; it does not make the field missing.
                    Field::Missing => Field::Valid(None),
                };
                Ok(field)
            }
        }

        deserializer.deserialize_option(OptionVisitor::<T>(PhantomData))
    }
}
/// Map key that marks a number encoded as a single-entry map whose value is
/// the number's text (see `scalar_from_map`).
///
/// NOTE(review): this mirrors a private serde_json implementation detail
/// (presumably its `arbitrary_precision` feature), which is not a stable API.
/// Confirm it against the serde_json version in use.
const SERDE_JSON_ARBITRARY_PRECISION_NUMBER: &str = "$serde_json::private::Number";

/// Classification of the first key of a map offered to a scalar field.
enum ScalarMapKey {
    /// The key is the borrowed string `SERDE_JSON_ARBITRARY_PRECISION_NUMBER`.
    SerdeJsonArbitraryPrecisionNumber,
    /// Any other key, including every non-string key.
    Other,
}
struct ScalarMapKeySeed;
impl<'de> DeserializeSeed<'de> for ScalarMapKeySeed {
type Value = ScalarMapKey;
fn deserialize<D>(self, deserializer: D) -> Result<Self::Value, D::Error>
where
D: Deserializer<'de>,
{
deserializer.deserialize_any(ScalarMapKeyVisitor)
}
}
struct ScalarMapKeyVisitor;
/// Generates every non-string `Visitor` method, each answering `$value`.
///
/// Unit, `None`, and all bool/char/numeric/bytes primitives produce `$value`
/// directly. Sequences and maps are drained first so the access is fully
/// consumed. `Some` is unwrapped by re-dispatching the inner value through
/// `deserialize_any`. String methods are deliberately not generated, so each
/// invoking impl decides how strings are classified. The expansion names the
/// `'de` lifetime, so it must be used inside an `impl<'de> Visitor<'de>` block.
macro_rules! visit_non_string_as {
    ($value:expr) => {
        fn visit_unit<E>(self) -> Result<Self::Value, E> {
            Ok($value)
        }

        fn visit_none<E>(self) -> Result<Self::Value, E> {
            Ok($value)
        }

        fn visit_some<D>(self, deserializer: D) -> Result<Self::Value, D::Error>
        where
            D: Deserializer<'de>,
        {
            deserializer.deserialize_any(self)
        }

        fn visit_bool<E>(self, _: bool) -> Result<Self::Value, E> {
            Ok($value)
        }

        fn visit_char<E>(self, _: char) -> Result<Self::Value, E> {
            Ok($value)
        }

        fn visit_i64<E>(self, _: i64) -> Result<Self::Value, E> {
            Ok($value)
        }

        fn visit_u64<E>(self, _: u64) -> Result<Self::Value, E> {
            Ok($value)
        }

        fn visit_i128<E>(self, _: i128) -> Result<Self::Value, E> {
            Ok($value)
        }

        fn visit_u128<E>(self, _: u128) -> Result<Self::Value, E> {
            Ok($value)
        }

        fn visit_f64<E>(self, _: f64) -> Result<Self::Value, E> {
            Ok($value)
        }

        fn visit_bytes<E>(self, _: &[u8]) -> Result<Self::Value, E> {
            Ok($value)
        }

        fn visit_borrowed_bytes<E>(self, _: &'de [u8]) -> Result<Self::Value, E> {
            Ok($value)
        }

        fn visit_byte_buf<E>(self, _: Vec<u8>) -> Result<Self::Value, E> {
            Ok($value)
        }

        fn visit_seq<A>(self, mut elements: A) -> Result<Self::Value, A::Error>
        where
            A: SeqAccess<'de>,
        {
            drain_seq(&mut elements)?;
            Ok($value)
        }

        fn visit_map<A>(self, mut entries: A) -> Result<Self::Value, A::Error>
        where
            A: MapAccess<'de>,
        {
            drain_map(&mut entries)?;
            Ok($value)
        }
    };
}
impl<'de> Visitor<'de> for ScalarMapKeyVisitor {
    type Value = ScalarMapKey;

    fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
        formatter.write_str("map key")
    }

    // Only a key delivered as a borrowed `'de` string is compared against the
    // sentinel. Transient (`visit_str`) and owned (`visit_string`) keys are
    // always `Other`.
    // NOTE(review): presumably this matches how serde_json hands over its
    // private number key (as a borrowed static str) and keeps ordinary owned
    // keys from being mistaken for it. Confirm before changing.
    fn visit_borrowed_str<E>(self, value: &'de str) -> Result<Self::Value, E> {
        Ok(scalar_map_key_from_str(value))
    }

    fn visit_str<E>(self, _: &str) -> Result<Self::Value, E> {
        Ok(ScalarMapKey::Other)
    }

    fn visit_string<E>(self, _: String) -> Result<Self::Value, E> {
        Ok(ScalarMapKey::Other)
    }

    visit_non_string_as!(ScalarMapKey::Other);
}
/// Outcome of decoding the value paired with the arbitrary-precision map key.
enum ArbitraryPrecisionNumber<T> {
    /// The value was an owned string. This holds the (possibly invalid) field
    /// produced by decoding that text as a JSON number into `T`.
    Scalar(Field<T>),
    /// The value was anything else, so the map is not treated as a number.
    NotNumber,
}
/// Seed that decodes the value paired with the arbitrary-precision map key
/// into an `ArbitraryPrecisionNumber<T>`.
struct ArbitraryPrecisionNumberSeed<T>(PhantomData<T>);

impl<'de, T: ScalarFieldDecode> DeserializeSeed<'de> for ArbitraryPrecisionNumberSeed<T> {
    type Value = ArbitraryPrecisionNumber<T>;

    fn deserialize<D>(self, deserializer: D) -> Result<Self::Value, D::Error>
    where
        D: Deserializer<'de>,
    {
        let visitor = ArbitraryPrecisionNumberVisitor::<T>(PhantomData);
        deserializer.deserialize_any(visitor)
    }
}
/// Visitor used by `ArbitraryPrecisionNumberSeed`.
///
/// Only an owned `String` is decoded as the number's text. Borrowed or
/// transient strings, and every non-string input, report `NotNumber`.
/// NOTE(review): presumably this matches serde_json handing the digits over as
/// an owned string. Confirm against the serde_json version in use.
struct ArbitraryPrecisionNumberVisitor<T>(PhantomData<T>);

impl<'de, T> Visitor<'de> for ArbitraryPrecisionNumberVisitor<T>
where
    T: ScalarFieldDecode,
{
    type Value = ArbitraryPrecisionNumber<T>;

    fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
        formatter.write_str("string containing a JSON number")
    }

    fn visit_string<E>(self, value: String) -> Result<Self::Value, E> {
        let decoded = scalar_from_json_number::<T>(&value);
        Ok(ArbitraryPrecisionNumber::Scalar(Field::from_result(decoded)))
    }

    fn visit_str<E>(self, _: &str) -> Result<Self::Value, E> {
        Ok(ArbitraryPrecisionNumber::NotNumber)
    }

    fn visit_borrowed_str<E>(self, _: &'de str) -> Result<Self::Value, E> {
        Ok(ArbitraryPrecisionNumber::NotNumber)
    }

    visit_non_string_as!(ArbitraryPrecisionNumber::NotNumber);
}
/// Classifies a string key: the arbitrary-precision sentinel, or anything else.
fn scalar_map_key_from_str(value: &str) -> ScalarMapKey {
    match value {
        SERDE_JSON_ARBITRARY_PRECISION_NUMBER => ScalarMapKey::SerdeJsonArbitraryPrecisionNumber,
        _ => ScalarMapKey::Other,
    }
}
/// Decodes a map that was offered where a scalar `T` was expected.
///
/// A single-entry map keyed by the arbitrary-precision sentinel, whose value
/// is an owned string, is decoded as that number. This may still produce an
/// invalid field if the number does not suit `T`. Every other map (empty,
/// other first key, extra entries, or a non-string value) is consumed in full
/// and reported as an object type mismatch.
///
/// # Errors
/// Only errors raised by the map access itself are returned.
fn scalar_from_map<'de, A, T>(map: &mut A) -> Result<Field<T>, A::Error>
where
    A: MapAccess<'de>,
    T: ScalarFieldDecode,
{
    let mismatch = || FieldError::type_mismatch(T::EXPECTED, ActualValue::Object);
    match map.next_key_seed(ScalarMapKeySeed)? {
        None => Ok(Field::invalid(mismatch())),
        Some(ScalarMapKey::Other) => {
            // The pending value must be consumed before draining the rest.
            map.next_value::<IgnoredAny>()?;
            invalid_map(map, mismatch())
        }
        Some(ScalarMapKey::SerdeJsonArbitraryPrecisionNumber) => {
            let number = map.next_value_seed(ArbitraryPrecisionNumberSeed::<T>(PhantomData))?;
            // A real encoded number has exactly one entry; anything more is an
            // ordinary object.
            if map.next_key::<IgnoredAny>()?.is_some() {
                map.next_value::<IgnoredAny>()?;
                return invalid_map(map, mismatch());
            }
            Ok(match number {
                ArbitraryPrecisionNumber::Scalar(field) => field,
                ArbitraryPrecisionNumber::NotNumber => Field::invalid(mismatch()),
            })
        }
    }
}
/// Decodes the text of a JSON number into `T` using the narrowest hook that
/// can represent it.
///
/// Text containing `.`, `e` or `E` is treated as a float. Otherwise the
/// integer is tried as 64-bit, then 128-bit (signed when it starts with `-`).
/// If it still does not fit, the text is decoded as a float.
fn scalar_from_json_number<T>(value: &str) -> Result<T, FieldError>
where
    T: ScalarFieldDecode,
{
    if is_float_json_number(value) {
        return scalar_from_json_float(value);
    }
    if !value.starts_with('-') {
        return match value.parse::<u64>() {
            Ok(narrow) => T::from_u64(narrow),
            Err(_) => match value.parse::<u128>() {
                Ok(wide) => T::from_u128(wide),
                Err(_) => scalar_from_json_float(value),
            },
        };
    }
    match value.parse::<i64>() {
        Ok(narrow) => T::from_i64(narrow),
        Err(_) => match value.parse::<i128>() {
            Ok(wide) => T::from_i128(wide),
            Err(_) => scalar_from_json_float(value),
        },
    }
}
/// Decodes `value`, the text of a JSON number, as a float via `T::from_f64`.
///
/// Rust's `f64` parser turns out-of-range magnitudes into infinity instead of
/// failing. Overflow therefore still reaches `T::from_f64`, which decides
/// whether a non-finite value is acceptable. A parse failure can only mean
/// the text is not a number at all, and the error says exactly that. It
/// previously claimed the number was "not finite", which was never the
/// cause.
///
/// # Errors
/// Returns an error when the text cannot be parsed, or whatever error
/// `T::from_f64` reports for the parsed value.
fn scalar_from_json_float<T>(value: &str) -> Result<T, FieldError>
where
    T: ScalarFieldDecode,
{
    match value.parse::<f64>() {
        Ok(float) => T::from_f64(float),
        Err(_) => Err(FieldError::new(format!("`{value}` is not a valid number"))),
    }
}
/// Reports whether JSON number text has a fraction or an exponent, i.e.
/// whether it contains `.`, `e` or `E`.
fn is_float_json_number(value: &str) -> bool {
    value.contains(['.', 'e', 'E'])
}
/// Strings accept text in any form, plus single characters.
impl ScalarFieldDecode for String {
    const EXPECTED: &'static str = "string";

    /// Accepts a character as a one-character string.
    ///
    /// serde's default `Visitor::visit_char` treats a `char` as a
    /// one-character string. The scalar visitor in this module overrides
    /// `visit_char`, so without this hook a char input would be a type
    /// mismatch for `String`, even though `char` itself accepts
    /// one-character strings.
    fn from_char(value: char) -> Result<Self, FieldError> {
        Ok(String::from(value))
    }

    fn from_str(value: &str) -> Result<Self, FieldError> {
        Ok(value.to_owned())
    }

    /// Takes ownership directly instead of copying through `from_str`.
    fn from_string(value: String) -> Result<Self, FieldError> {
        Ok(value)
    }
}
/// Booleans accept only boolean input.
impl ScalarFieldDecode for bool {
    const EXPECTED: &'static str = "boolean";

    fn from_bool(flag: bool) -> Result<Self, FieldError> {
        Ok(flag)
    }
}

/// Characters accept a char, or a string holding exactly one char.
impl ScalarFieldDecode for char {
    const EXPECTED: &'static str = "single character string";

    fn from_char(value: char) -> Result<Self, FieldError> {
        Ok(value)
    }

    /// Empty strings and strings with more than one char are reported as a
    /// string type mismatch.
    fn from_str(value: &str) -> Result<Self, FieldError> {
        let mut chars = value.chars();
        if let (Some(only), None) = (chars.next(), chars.next()) {
            return Ok(only);
        }
        Err(FieldError::type_mismatch(Self::EXPECTED, ActualValue::String))
    }
}
macro_rules! impl_signed_integer {
($($ty:ty),* $(,)?) => {
$(
impl ScalarFieldDecode for $ty {
const EXPECTED: &'static str = concat!("signed integer fitting in ", stringify!($ty));
fn from_i128(value: i128) -> Result<Self, FieldError> {
Self::try_from(value)
.map_err(|_| integer_out_of_range(value, stringify!($ty)))
}
fn from_u128(value: u128) -> Result<Self, FieldError> {
Self::try_from(value)
.map_err(|_| integer_out_of_range(value, stringify!($ty)))
}
}
)*
};
}
macro_rules! impl_unsigned_integer {
($($ty:ty),* $(,)?) => {
$(
impl ScalarFieldDecode for $ty {
const EXPECTED: &'static str = concat!("unsigned integer fitting in ", stringify!($ty));
fn from_i128(value: i128) -> Result<Self, FieldError> {
Self::try_from(value)
.map_err(|_| integer_out_of_range(value, stringify!($ty)))
}
fn from_u128(value: u128) -> Result<Self, FieldError> {
Self::try_from(value)
.map_err(|_| integer_out_of_range(value, stringify!($ty)))
}
}
)*
};
}
impl_signed_integer!(i8, i16, i32, i64, i128, isize);
impl_unsigned_integer!(u8, u16, u32, u64, u128, usize);
/// `f32` accepts integers and floats whose value lies within `f32`'s finite
/// range. Range and finiteness are enforced by `f32_from_f64`.
// Converting `i128`/`u128` to `f64` may round. That loss is accepted here, as
// in the `f64` impl, which already carries the same lint allowance.
#[allow(clippy::cast_precision_loss)]
impl ScalarFieldDecode for f32 {
    const EXPECTED: &'static str = "finite floating point number fitting in f32";

    fn from_i128(value: i128) -> Result<Self, FieldError> {
        f32_from_f64(value as f64)
    }

    fn from_u128(value: u128) -> Result<Self, FieldError> {
        f32_from_f64(value as f64)
    }

    fn from_f64(value: f64) -> Result<Self, FieldError> {
        f32_from_f64(value)
    }
}
/// `f64` accepts any integer, rounded to the nearest `f64`, and any finite float.
// `i128`/`u128 as f64` can round for large magnitudes; that is accepted.
#[allow(clippy::cast_precision_loss)]
impl ScalarFieldDecode for f64 {
    const EXPECTED: &'static str = "finite floating point number";

    fn from_i128(value: i128) -> Result<Self, FieldError> {
        Ok(value as f64)
    }

    fn from_u128(value: u128) -> Result<Self, FieldError> {
        Ok(value as f64)
    }

    /// Rejects NaN and the infinities.
    fn from_f64(value: f64) -> Result<Self, FieldError> {
        if !value.is_finite() {
            return Err(FieldError::new(format!(
                "floating point `{value}` is not finite"
            )));
        }
        Ok(value)
    }
}
pub fn drain_seq<'de, A>(seq: &mut A) -> Result<(), A::Error>
where
A: SeqAccess<'de>,
{
while seq.next_element::<IgnoredAny>()?.is_some() {}
Ok(())
}
pub fn drain_map<'de, A>(map: &mut A) -> Result<(), A::Error>
where
A: MapAccess<'de>,
{
while map.next_entry::<IgnoredAny, IgnoredAny>()?.is_some() {}
Ok(())
}
pub fn invalid_seq<'de, A, T>(
seq: &mut A,
error: impl Into<FieldError>,
) -> Result<Field<T>, A::Error>
where
A: SeqAccess<'de>,
{
drain_seq(seq)?;
Ok(Field::invalid(error))
}
pub fn invalid_map<'de, A, T>(
map: &mut A,
error: impl Into<FieldError>,
) -> Result<Field<T>, A::Error>
where
A: MapAccess<'de>,
{
drain_map(map)?;
Ok(Field::invalid(error))
}
/// Builds the error for an integer `value` that does not fit the type named `ty`.
fn integer_out_of_range(value: impl fmt::Display, ty: &'static str) -> FieldError {
    let message = format!("integer `{value}` does not fit in {ty}");
    FieldError::new(message)
}
/// Narrows `value` to `f32` if it lies within `f32`'s finite range.
///
/// # Errors
/// Returns an error for NaN, the infinities, and magnitudes beyond `f32::MAX`.
// The range check below guarantees the cast only rounds and never saturates.
#[allow(clippy::cast_possible_truncation)]
fn f32_from_f64(value: f64) -> Result<f32, FieldError> {
    let representable = f64::from(f32::MIN)..=f64::from(f32::MAX);
    // NaN and the infinities fall outside this closed range, so they are
    // rejected here as well.
    if !representable.contains(&value) {
        return Err(FieldError::new(format!(
            "floating point `{value}` does not fit in f32"
        )));
    }
    Ok(value as f32)
}
#[cfg(test)]
mod tests {
    use super::*;
    use serde::{Deserialize, de::value};

    /// An infinite `f64` must come back as an invalid field whose message
    /// mentions non-finiteness, not as a deserializer error.
    #[test]
    fn f64_rejects_non_finite_values() {
        let input = value::F64Deserializer::<value::Error>::new(f64::INFINITY);
        let field = Field::<f64>::deserialize(input).unwrap();
        assert!(field.is_invalid());
        let message = field.error().map(crate::FieldError::message);
        assert!(matches!(message, Some(text) if text.contains("not finite")));
    }
}