use crate::{
error::{pointer_sized_decode_error, preallocation_size_limit, ReadResult, WriteResult},
io::{Reader, Writer},
schema::{SchemaRead, SchemaWrite},
};
/// Strategy for encoding and decoding the length prefix of a sequence.
///
/// An implementor defines how a `usize` length is read from a `Reader`,
/// how it is written to a `Writer`, and how many bytes its encoded form takes.
pub trait SeqLen {
/// Reads a length from `reader`.
///
/// `T` is supplied by the caller; an implementation may use it (for example
/// through `size_of::<T>()`) to validate the decoded length, or ignore it.
fn read<'de, T>(reader: &mut impl Reader<'de>) -> ReadResult<usize>;
/// Writes `len` to `writer` using this implementation's length encoding.
fn write(writer: &mut impl Writer, len: usize) -> WriteResult<()>;
/// Returns the number of bytes required to encode `len`.
fn write_bytes_needed(len: usize) -> WriteResult<usize>;
}
/// Byte limit used by [`BincodeLen`] when no explicit `MAX_SIZE` is supplied (4 MiB).
const DEFAULT_BINCODE_LEN_MAX_SIZE: usize = 4 * 1024 * 1024;

/// Length encoding that stores the sequence length as a `u64`.
///
/// `MAX_SIZE` is the upper bound, in bytes, on `len * size_of::<T>()` that
/// is accepted when a length is read.
pub struct BincodeLen<const MAX_SIZE: usize = DEFAULT_BINCODE_LEN_MAX_SIZE>;
impl<const MAX_SIZE: usize> SeqLen for BincodeLen<MAX_SIZE> {
    /// Reads a `u64` length, converts it to `usize`, and checks it against `MAX_SIZE`.
    ///
    /// # Errors
    ///
    /// - Any error produced while reading the `u64`.
    /// - `pointer_sized_decode_error()` when the value does not fit in `usize`.
    /// - `preallocation_size_limit(..)` when `len * size_of::<T>()` overflows
    ///   `usize` or is greater than `MAX_SIZE`.
    #[inline(always)]
    fn read<'de, T>(reader: &mut impl Reader<'de>) -> ReadResult<usize> {
        let raw = u64::get(reader)?;
        let count = usize::try_from(raw).map_err(|_| pointer_sized_decode_error())?;
        // Bound the total byte size implied by the length, not just the element count.
        match count.checked_mul(size_of::<T>()) {
            Some(total_bytes) if total_bytes <= MAX_SIZE => Ok(count),
            Some(total_bytes) => Err(preallocation_size_limit(total_bytes, MAX_SIZE)),
            // The multiplication overflowed, so the true size is beyond `usize::MAX`.
            None => Err(preallocation_size_limit(usize::MAX, MAX_SIZE)),
        }
    }

    /// Writes `len` to `writer` as a `u64`.
    #[inline(always)]
    fn write(writer: &mut impl Writer, len: usize) -> WriteResult<()> {
        let widened = len as u64;
        u64::write(writer, &widened)
    }

    /// Always returns the size of a `u64`; the length value itself is ignored.
    #[inline(always)]
    fn write_bytes_needed(_len: usize) -> WriteResult<usize> {
        let prefix_bytes = size_of::<u64>();
        Ok(prefix_bytes)
    }
}
#[cfg(feature = "solana-short-vec")]
pub mod short_vec {
use {
super::*,
crate::error::{read_length_encoding_overflow, write_length_encoding_overflow},
core::{
mem::{transmute, MaybeUninit},
ptr,
},
solana_short_vec::{decode_shortu16_len, ShortU16},
};
impl<'de> SchemaRead<'de> for ShortU16 {
    type Dst = Self;

    /// Decodes a short-u16 value from `reader` and stores it in `dst`.
    ///
    /// Up to 3 bytes are requested from the reader and handed to
    /// `decode_shortu16_len`; only the bytes it reports as used are consumed.
    ///
    /// # Errors
    ///
    /// Propagates any `fill_buf` error, and returns
    /// `read_length_encoding_overflow("u16::MAX")` when decoding fails.
    fn read(reader: &mut impl Reader<'de>, dst: &mut MaybeUninit<Self::Dst>) -> ReadResult<()> {
        let window = reader.fill_buf(3)?;
        let (value, consumed) = decode_shortu16_len(window)
            .map_err(|_| read_length_encoding_overflow("u16::MAX"))?;
        // SAFETY: `consumed` is the byte count reported by `decode_shortu16_len` for
        // the slice obtained from `fill_buf`; this relies on that count never
        // exceeding the slice it was given, so the reader holds at least that many bytes.
        unsafe { reader.consume_unchecked(consumed) };
        // SAFETY: `dst` points to storage for a `ShortU16`. `&raw mut` projects to
        // field `.0` without creating a reference to uninitialized memory, and
        // `MaybeUninit<u16>` has the same layout as `u16`.
        let field = unsafe { &mut *(&raw mut (*dst.as_mut_ptr()).0).cast::<MaybeUninit<u16>>() };
        // NOTE(review): the cast assumes a successful decode never yields a value
        // above `u16::MAX` — confirm against `decode_shortu16_len`.
        field.write(value as u16);
        Ok(())
    }
}
impl SchemaWrite for ShortU16 {
    type Src = Self;

    /// Returns the encoded size (1 to 3 bytes) of the wrapped `u16`.
    fn size_of(src: &Self::Src) -> WriteResult<usize> {
        let value = src.0;
        Ok(short_u16_bytes_needed(value))
    }

    /// Encodes the wrapped `u16` into a 3-byte stack buffer and writes the
    /// bytes actually used to `writer`.
    fn write(writer: &mut impl Writer, src: &Self::Src) -> WriteResult<()> {
        let value = src.0;
        let encoded_len = short_u16_bytes_needed(value);
        let mut scratch = [MaybeUninit::<u8>::uninit(); 3];
        // SAFETY: `scratch` holds 3 bytes and `encoded_len` is between 1 and 3, so
        // every write performed by `encode_short_u16` stays inside the buffer.
        unsafe { encode_short_u16(scratch.as_mut_ptr().cast::<u8>(), encoded_len, value) };
        // SAFETY: `encode_short_u16` initialized the first `encoded_len` bytes, and
        // `encoded_len <= 3` keeps the range in bounds.
        let encoded = unsafe {
            transmute::<&[MaybeUninit<u8>], &[u8]>(scratch.get_unchecked(..encoded_len))
        };
        writer.write(encoded)?;
        Ok(())
    }
}
/// Alias of `ShortU16` under which the `SeqLen` length encoding is implemented.
pub type ShortU16Len = ShortU16;
/// Returns how many bytes (1, 2 or 3) the short-u16 encoding of `len` occupies.
///
/// Values below `0x80` take one byte, values below `0x4000` take two, and
/// everything else takes three.
#[inline(always)]
// The sum is at most 3, so it cannot overflow.
#[allow(clippy::arithmetic_side_effects)]
fn short_u16_bytes_needed(len: u16) -> usize {
    let needs_second = usize::from(len > 0x7f);
    let needs_third = usize::from(len > 0x3fff);
    1 + needs_second + needs_third
}
#[inline(always)]
fn try_short_u16_bytes_needed<T: TryInto<u16>>(len: T) -> WriteResult<usize> {
match len.try_into() {
Ok(len) => Ok(short_u16_bytes_needed(len)),
Err(_) => Err(write_length_encoding_overflow("u16::MAX")),
}
}
/// Writes the short-u16 encoding of `len` to `dst` using exactly `needed` bytes.
///
/// The value is emitted in 7-bit groups, least significant group first. Every
/// byte except the last has its top bit (`0x80`) set.
///
/// # Safety
///
/// `dst` must be valid for writes of `needed` bytes.
///
/// # Panics
///
/// Panics (via `unreachable!`) when `needed` is not 1, 2 or 3. Callers in this
/// file pass `short_u16_bytes_needed(len)`.
#[inline(always)]
unsafe fn encode_short_u16(dst: *mut u8, needed: usize, len: u16) {
    // Lowest 7 bits, and the next 7 bits above them.
    let low_group = (len & 0x7f) as u8;
    let mid_group = ((len >> 7) & 0x7f) as u8;
    match needed {
        1 => ptr::write(dst, len as u8),
        2 => {
            ptr::write(dst, low_group | 0x80);
            ptr::write(dst.add(1), (len >> 7) as u8);
        }
        3 => {
            ptr::write(dst, low_group | 0x80);
            ptr::write(dst.add(1), mid_group | 0x80);
            // Only the top two bits of the value remain for the final byte.
            ptr::write(dst.add(2), (len >> 14) as u8);
        }
        _ => unreachable!(),
    }
}
impl SeqLen for ShortU16Len {
    /// Decodes a short-u16 length from `reader`; the type parameter `T` is unused.
    ///
    /// # Errors
    ///
    /// Propagates any `fill_buf` error, and returns
    /// `read_length_encoding_overflow("u16::MAX")` when decoding fails.
    #[inline(always)]
    fn read<'de, T>(reader: &mut impl Reader<'de>) -> ReadResult<usize> {
        let window = reader.fill_buf(3)?;
        let (len, consumed) = match decode_shortu16_len(window) {
            Ok(decoded) => decoded,
            Err(_) => return Err(read_length_encoding_overflow("u16::MAX")),
        };
        // SAFETY: `consumed` is the byte count reported by `decode_shortu16_len` for
        // the slice obtained from `fill_buf`; this relies on that count never
        // exceeding the slice it was given, so the reader holds at least that many bytes.
        unsafe { reader.consume_unchecked(consumed) };
        Ok(len)
    }

    /// Writes `len` in short-u16 form.
    ///
    /// # Errors
    ///
    /// Returns `write_length_encoding_overflow("u16::MAX")` when `len` exceeds
    /// `u16::MAX`.
    #[inline(always)]
    fn write(writer: &mut impl Writer, len: usize) -> WriteResult<()> {
        match u16::try_from(len) {
            Ok(short) => <ShortU16 as SchemaWrite>::write(writer, &ShortU16(short)),
            Err(_) => Err(write_length_encoding_overflow("u16::MAX")),
        }
    }

    /// Returns the encoded size of `len`, or an overflow error when `len` does
    /// not fit in `u16`.
    #[inline(always)]
    fn write_bytes_needed(len: usize) -> WriteResult<usize> {
        try_short_u16_bytes_needed(len)
    }
}
#[cfg(all(test, feature = "alloc", feature = "derive"))]
mod tests {
use {
super::*,
crate::{
containers::{self, Pod},
proptest_config::proptest_cfg,
},
alloc::vec::Vec,
proptest::prelude::*,
solana_short_vec::ShortU16,
wincode_derive::{SchemaRead, SchemaWrite},
};
// Encodes `len` with this module's encoder and returns the bytes as a `Vec`.
fn our_short_u16_encode(len: u16) -> Vec<u8> {
    let byte_count = short_u16_bytes_needed(len);
    let mut encoded: Vec<u8> = Vec::with_capacity(byte_count);
    // SAFETY: the vector has capacity for `byte_count` bytes (1 to 3), and
    // `encode_short_u16` initializes exactly that many before `set_len` exposes them.
    unsafe {
        encode_short_u16(encoded.as_mut_ptr(), byte_count, len);
        encoded.set_len(byte_count);
    }
    encoded
}
// Test fixture holding two vectors. Each field is serialized by serde through
// `solana_short_vec` and by wincode through `containers::Vec<_, ShortU16Len>`,
// so the two encodings can be compared.
#[derive(
serde::Serialize, serde::Deserialize, Debug, PartialEq, Eq, SchemaWrite, SchemaRead,
)]
#[wincode(internal)]
struct ShortVecStruct {
// Vector of single bytes.
#[serde(with = "solana_short_vec")]
#[wincode(with = "containers::Vec<Pod<u8>, ShortU16Len>")]
bytes: Vec<u8>,
// Vector of 32-byte arrays.
#[serde(with = "solana_short_vec")]
#[wincode(with = "containers::Vec<Pod<[u8; 32]>, ShortU16Len>")]
ar: Vec<[u8; 32]>,
}
// Test fixture with a bare `ShortU16` field, exercising `ShortU16` as a
// schema type in its own right (no `with` attribute).
#[derive(SchemaWrite, SchemaRead, serde::Serialize, serde::Deserialize)]
#[wincode(internal)]
struct ShortVecAsSchema {
short_u16: ShortU16,
}
// Strategy producing `ShortVecStruct` values with 0..=100 bytes and
// 0..=16 32-byte arrays.
fn strat_short_vec_struct() -> impl Strategy<Value = ShortVecStruct> {
    let bytes_strategy = proptest::collection::vec(any::<u8>(), 0..=100);
    let arrays_strategy = proptest::collection::vec(any::<[u8; 32]>(), 0..=16);
    (bytes_strategy, arrays_strategy).prop_map(|(bytes, ar)| ShortVecStruct { bytes, ar })
}
// Property tests comparing this module's short-u16 handling with bincode.
proptest! {
#![proptest_config(proptest_cfg())]
// Every `u16` must encode to the same bytes as bincode's serialization of `ShortU16`.
#[test]
fn encode_u16_equivalence(len in 0..=u16::MAX) {
let our = our_short_u16_encode(len);
let bincode = bincode::serialize(&ShortU16(len)).unwrap();
prop_assert_eq!(our, bincode);
}
// Short-vec-prefixed vectors must serialize to identical bytes with both
// libraries, and both must deserialize back to the original value.
#[test]
fn test_short_vec_struct(short_vec_struct in strat_short_vec_struct()) {
let bincode_serialized = bincode::serialize(&short_vec_struct).unwrap();
let schema_serialized = crate::serialize(&short_vec_struct).unwrap();
prop_assert_eq!(&bincode_serialized, &schema_serialized);
let bincode_deserialized: ShortVecStruct = bincode::deserialize(&bincode_serialized).unwrap();
let schema_deserialized: ShortVecStruct = crate::deserialize(&schema_serialized).unwrap();
prop_assert_eq!(&short_vec_struct, &bincode_deserialized);
prop_assert_eq!(short_vec_struct, schema_deserialized);
}
// A `ShortU16` used directly as a field must round-trip identically with both libraries.
#[test]
fn test_short_vec_as_schema(sv in any::<u16>()) {
let val = ShortVecAsSchema { short_u16: ShortU16(sv) };
let bincode_serialized = bincode::serialize(&val).unwrap();
let wincode_serialized = crate::serialize(&val).unwrap();
prop_assert_eq!(&bincode_serialized, &wincode_serialized);
let bincode_deserialized: ShortVecAsSchema = bincode::deserialize(&bincode_serialized).unwrap();
let wincode_deserialized: ShortVecAsSchema = crate::deserialize(&wincode_serialized).unwrap();
prop_assert_eq!(val.short_u16.0, bincode_deserialized.short_u16.0);
prop_assert_eq!(val.short_u16.0, wincode_deserialized.short_u16.0);
}
}
}
}
#[cfg(feature = "solana-short-vec")]
pub use short_vec::*;