use proptest::{collection::size_range, prelude::*};
use std::matches;
use crate::serialization::{
arbitrary::max_allocation_is_big_enough, zcash_deserialize::MAX_U8_ALLOCATION,
SerializationError, TrustedPreallocate, ZcashDeserialize, ZcashSerialize,
MAX_PROTOCOL_MESSAGE_LEN,
};
impl ZcashSerialize for u8 {
fn zcash_serialize<W: std::io::Write>(&self, mut writer: W) -> Result<(), std::io::Error> {
writer.write_all(&[*self])
}
}
impl TrustedPreallocate for u8 {
fn max_allocation() -> u64 {
(MAX_PROTOCOL_MESSAGE_LEN - 5)
.try_into()
.expect("MAX_PROTOCOL_MESSAGE_LEN fits in u64")
}
}
proptest! {
#![proptest_config(ProptestConfig::with_cases(4))]
#[test]
fn u8_ser_deser_roundtrip(input in any_with::<Vec<u8>>(size_range(MAX_U8_ALLOCATION).lift()) ) {
let serialized = input.zcash_serialize_to_vec().expect("Serialization to vec must succeed");
let cursor = std::io::Cursor::new(serialized);
let deserialized = <Vec<u8>>::zcash_deserialize(cursor).expect("deserialization from vec must succeed");
prop_assert_eq!(deserialized, input)
}
}
#[test]
fn u8_deser_accepts_max_valid_input() {
let serialized = vec![0u8; MAX_U8_ALLOCATION]
.zcash_serialize_to_vec()
.expect("Serialization to vec must succeed");
let cursor = std::io::Cursor::new(serialized);
let deserialized = <Vec<u8>>::zcash_deserialize(cursor);
assert!(deserialized.is_ok())
}
#[test]
fn u8_deser_throws_when_input_too_large() {
let serialized = vec![0u8; MAX_U8_ALLOCATION + 1]
.zcash_serialize_to_vec()
.expect("Serialization to vec must succeed");
let cursor = std::io::Cursor::new(serialized);
let deserialized = <Vec<u8>>::zcash_deserialize(cursor);
assert!(matches!(
deserialized,
Err(SerializationError::Parse(
"Byte vector longer than MAX_U8_ALLOCATION"
))
))
}
#[test]
fn u8_size_is_correct() {
for byte in u8::MIN..=u8::MAX {
let serialized = byte
.zcash_serialize_to_vec()
.expect("Serialization to vec must succeed");
assert!(serialized.len() == 1)
}
}
#[test]
fn u8_max_allocation_is_correct() {
let (
smallest_disallowed_vec_len,
smallest_disallowed_serialized_len,
largest_allowed_vec_len,
largest_allowed_serialized_len,
) = max_allocation_is_big_enough(0u8);
assert_eq!((smallest_disallowed_vec_len - 1), MAX_U8_ALLOCATION);
assert!(smallest_disallowed_serialized_len > MAX_PROTOCOL_MESSAGE_LEN);
assert_eq!(largest_allowed_vec_len, MAX_U8_ALLOCATION);
assert_eq!(largest_allowed_serialized_len, MAX_PROTOCOL_MESSAGE_LEN);
}