ruint 1.18.0

Unsigned integer type with const-generic bit length
Documentation
//! Support for the [`serde`](https://crates.io/crates/serde) crate.

#![cfg(feature = "serde")]
#![cfg_attr(docsrs, doc(cfg(feature = "serde")))]

use crate::{Bits, Uint, fmt::StackString, nbytes};
use core::{fmt, str};
use serde_core::{
    Deserialize, Deserializer, Serialize, Serializer,
    de::{Error, Unexpected, Visitor},
};

/// Canonical serialization for all human-readable instances of `Uint<0, 0>`,
/// and minimal human-readable `Uint<BITS, LIMBS>::ZERO` for any bit size.
const ZERO_STR: &str = "0x0";

impl<const BITS: usize, const LIMBS: usize> Uint<BITS, LIMBS> {
    fn serialize_human<const FULL: bool, S: Serializer>(&self, s: S) -> Result<S::Ok, S::Error> {
        // Ideally `SIZE` is `2 + Self::BYTES * 2`.
        if BITS <= 8 {
            self.serialize_human_sized::<FULL, 8, _>(s)
        } else {
            self.serialize_human_sized::<FULL, BITS, _>(s)
        }
    }

    fn serialize_human_sized<const FULL: bool, const SIZE: usize, S: Serializer>(
        &self,
        s: S,
    ) -> Result<S::Ok, S::Error> {
        let mut buffer = StackString::<SIZE>::new();
        self.write_hex::<FULL>(&mut buffer);
        s.serialize_str(buffer.as_str())
    }

    fn serialize_binary<S: Serializer>(&self, s: S) -> Result<S::Ok, S::Error> {
        s.serialize_bytes(&self.to_be_bytes_vec())
    }

    fn write_hex<const FULL: bool>(&self, s: &mut (impl fmt::Write + ?Sized)) {
        if BITS == 0 {
            s.write_str(ZERO_STR)
        } else if FULL {
            write!(s, "{:#0w$x}", *self, w = 2 + Self::BYTES * 2)
        } else {
            write!(s, "{:#x}", *self)
        }
        .unwrap();
    }
}

/// Serialize a [`Uint`] value.
///
/// For human readable formats a `0x` prefixed lower case hex string is used.
/// For binary formats a byte array is used. Leading zeros are included.
impl<const BITS: usize, const LIMBS: usize> Serialize for Uint<BITS, LIMBS> {
    fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
        if serializer.is_human_readable() {
            self.serialize_human::<false, _>(serializer)
        } else {
            self.serialize_binary(serializer)
        }
    }
}

/// Deserialize human readable hex strings or byte arrays into [`Uint`].
///
/// Hex strings can be upper/lower/mixed case, have an optional `0x` prefix, and
/// can be any length. They are interpreted big-endian.
impl<'de, const BITS: usize, const LIMBS: usize> Deserialize<'de> for Uint<BITS, LIMBS> {
    fn deserialize<D: Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
        if deserializer.is_human_readable() {
            deserializer.deserialize_any(HrVisitor)
        } else {
            deserializer.deserialize_bytes(ByteVisitor)
        }
    }
}

impl<const BITS: usize, const LIMBS: usize> Serialize for Bits<BITS, LIMBS> {
    fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
        if serializer.is_human_readable() {
            self.as_uint().serialize_human::<true, _>(serializer)
        } else {
            self.as_uint().serialize_binary(serializer)
        }
    }
}

impl<'de, const BITS: usize, const LIMBS: usize> Deserialize<'de> for Bits<BITS, LIMBS> {
    fn deserialize<D: Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
        Uint::deserialize(deserializer).map(Self::from)
    }
}

/// Serde Visitor for human readable formats.
///
/// Accepts either a primitive number, a decimal or a hexadecimal string.
struct HrVisitor<const BITS: usize, const LIMBS: usize>;

impl<const BITS: usize, const LIMBS: usize> Visitor<'_> for HrVisitor<BITS, LIMBS> {
    type Value = Uint<BITS, LIMBS>;

    fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(formatter, "a {} byte hex string", nbytes(BITS))
    }

    fn visit_u64<E: Error>(self, v: u64) -> Result<Self::Value, E> {
        Uint::try_from(v).map_err(|_| Error::invalid_value(Unexpected::Unsigned(v), &self))
    }

    fn visit_u128<E: Error>(self, v: u128) -> Result<Self::Value, E> {
        // `Unexpected::Unsigned` cannot contain a `u128`
        Uint::try_from(v).map_err(Error::custom)
    }

    fn visit_str<E: Error>(self, value: &str) -> Result<Self::Value, E> {
        // Shortcut for common case
        if value == ZERO_STR {
            return Ok(Uint::<BITS, LIMBS>::ZERO);
        }
        // `ZERO_STR` is the only valid serialization of `Uint<0, 0>`, so if we
        // have not shortcut, we are in an error case
        if BITS == 0 {
            return Err(Error::invalid_value(Unexpected::Str(value), &self));
        }

        value
            .parse()
            .map_err(|_| Error::invalid_value(Unexpected::Str(value), &self))
    }
}

/// Serde Visitor for non-human readable formats
struct ByteVisitor<const BITS: usize, const LIMBS: usize>;

impl<const BITS: usize, const LIMBS: usize> Visitor<'_> for ByteVisitor<BITS, LIMBS> {
    type Value = Uint<BITS, LIMBS>;

    fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(formatter, "{BITS} bits of binary data in big endian order")
    }

    fn visit_bytes<E: Error>(self, value: &[u8]) -> Result<Self::Value, E> {
        if value.len() != nbytes(BITS) {
            return Err(Error::invalid_length(value.len(), &self));
        }
        Uint::try_from_be_slice(value).ok_or_else(|| {
            Error::invalid_value(
                Unexpected::Other(&format!("too large for Uint<{BITS}>")),
                &self,
            )
        })
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::{const_for, nlimbs};
    use proptest::proptest;

    #[allow(unused_imports)]
    use alloc::vec::Vec;

    #[test]
    fn test_serde_human_readable() {
        const_for!(BITS in SIZES {
            const LIMBS: usize = nlimbs(BITS);
            proptest!(|(value: Uint<BITS, LIMBS>)| {
                let s = format!("U{BITS} => {value}");
                let serialized = serde_json::to_string(&value).expect(&s);
                let deserialized = serde_json::from_str::<Uint<BITS, LIMBS>>(&serialized).expect(&s);
                assert_eq!(value, deserialized);
            });
            proptest!(|(value: Bits<BITS, LIMBS>)| {
                let serialized = serde_json::to_string(&value).unwrap();
                let deserialized = serde_json::from_str::<Bits<BITS, LIMBS>>(&serialized).unwrap();
                assert_eq!(value, deserialized);
            });
        });
    }

    #[test]
    fn test_human_readable_de() {
        let jason = r#"[
            1,
            "0x1",
            "0o1",
            "0b1"
        ]"#;
        let numbers: Vec<Uint<1, 1>> = serde_json::from_str(jason).unwrap();
        uint! {
            assert_eq!(numbers, vec![1_U1, 1_U1, 1_U1, 1_U1]);
        }

        let jason = r#"[
            "",
            "0x",
            "0o",
            "0b"
        ]"#;
        let numbers: Vec<Uint<1, 1>> = serde_json::from_str(jason).unwrap();
        uint! {
            assert_eq!(numbers, vec![0_U1, 0_U1, 0_U1, 0_U1]);
        }
    }

    #[test]
    fn test_serde_machine_readable() {
        const_for!(BITS in SIZES {
            const LIMBS: usize = nlimbs(BITS);
            proptest!(|(value: Uint<BITS, LIMBS>)| {
                let serialized = bincode::serialize(&value).unwrap();
                let deserialized = bincode::deserialize::<Uint<BITS, LIMBS>>(&serialized[..]).unwrap();
                assert_eq!(value, deserialized);
            });
            proptest!(|(value: Bits<BITS, LIMBS>)| {
                let serialized = bincode::serialize(&value).unwrap();
                let deserialized = bincode::deserialize::<Bits<BITS, LIMBS>>(&serialized[..]).unwrap();
                assert_eq!(value, deserialized);
            });
        });
    }

    #[test]
    fn test_serde_invalid_size_error() {
        // Test that if we add a character to a value that is already the max length for
        // the given number of bits, we get an error.
        const_for!(BITS in SIZES {
            const LIMBS: usize = nlimbs(BITS);
            let value = Uint::<BITS, LIMBS>::MAX;
            let mut serialized = serde_json::to_string(&value).unwrap();

            // ensure format of serialized value is correct ("0x...")
            assert_eq!(&serialized[..3], "\"0x");
            // last character should be a quote
            assert_eq!(&serialized[serialized.len() - 1..], "\"");

            // strip the last character, add a zero, and finish with a quote
            serialized.pop();
            serialized.push('0');
            serialized.push('"');
            let deserialized = serde_json::from_str::<Uint<BITS, LIMBS>>(&serialized);
            assert!(deserialized.is_err(), "U{BITS} {serialized} => {deserialized:?}");
        });
    }
}