use {
crate::{
SchemaRead, SchemaWrite, TypeMeta,
config::{ConfigCore, PREALLOCATION_SIZE_LIMIT_DISABLED},
error::{
PreallocationError, ReadResult, WriteResult, pointer_sized_decode_error,
preallocation_size_limit, write_length_encoding_overflow,
},
int_encoding::{ByteOrder, Endian},
io::{Reader, Writer},
},
core::{any::type_name, marker::PhantomData},
};
/// Sentinel value for a `PREALLOCATION_SIZE_LIMIT` const generic meaning "defer to the
/// config's limit".
///
/// [`PreallocationLimitOverride::from_usize`] maps this value to
/// [`PreallocationLimitOverride::UseConfig`]. A consequence is that a literal limit of `0`
/// cannot be selected through the const generic.
pub const PREALLOCATION_SIZE_LIMIT_USE_CONFIG: usize = 0;

/// How a length encoding chooses its preallocation size limit.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Default)]
pub enum PreallocationLimitOverride {
    /// Take the limit from the active config (`C::PREALLOCATION_SIZE_LIMIT`).
    #[default]
    UseConfig,
    /// Apply no preallocation limit at all.
    NoLimit,
    /// Use this limit instead of the config's.
    Override(usize),
}

impl PreallocationLimitOverride {
    /// Resolves the override against config `C` and returns the limit that applies.
    ///
    /// `None` means no limit applies. This happens either because the override is `NoLimit`,
    /// or because the config itself reports `None`.
    #[inline]
    pub const fn to_opt_limit_with_config<C: ConfigCore>(self) -> Option<usize> {
        match self {
            Self::NoLimit => None,
            Self::Override(bytes) => Some(bytes),
            Self::UseConfig => C::PREALLOCATION_SIZE_LIMIT,
        }
    }

    /// Decodes a const-generic limit value into an override.
    ///
    /// `PREALLOCATION_SIZE_LIMIT_USE_CONFIG` becomes `UseConfig`,
    /// `PREALLOCATION_SIZE_LIMIT_DISABLED` becomes `NoLimit`, and every other value becomes
    /// `Override(value)`.
    #[inline]
    pub const fn from_usize(limit: usize) -> Self {
        match limit {
            PREALLOCATION_SIZE_LIMIT_USE_CONFIG => Self::UseConfig,
            PREALLOCATION_SIZE_LIMIT_DISABLED => Self::NoLimit,
            custom => Self::Override(custom),
        }
    }
}
/// Strategy for encoding and decoding the length prefix of a sequence under config `C`.
///
/// Implementors supply `read`, `write` and `write_bytes_needed`. The provided
/// `*_prealloc_check` helpers add the preallocation size limit on top of them.
///
/// # Safety
///
/// NOTE(review): the exact contract is not written down in this file. Every implementation
/// visible here keeps `write_bytes_needed(len)` equal to the number of bytes `write(_, len)`
/// emits, or to the static size of the underlying integer. This suggests callers may size
/// buffers from it. Confirm against the code that consumes `SeqLen` before relying on it.
pub unsafe trait SeqLen<C: ConfigCore> {
    /// Preallocation limit applied by this encoding. By default it defers to the config.
    const PREALLOCATION_SIZE_LIMIT_OVERRIDE: PreallocationLimitOverride =
        PreallocationLimitOverride::UseConfig;

    /// Checks that preallocating `len` elements of `T` stays within the effective limit.
    ///
    /// The limit is `Self::PREALLOCATION_SIZE_LIMIT_OVERRIDE` resolved against `C`. If it
    /// resolves to `None`, every length passes. If `len * size_of::<T>()` overflows `usize`,
    /// the error reports `usize::MAX` bytes needed.
    #[inline]
    fn prealloc_check<T>(len: usize) -> Result<(), PreallocationError> {
        // A nested fn item does not capture `T`/`Self`, so this body is not re-instantiated
        // for every element type.
        fn ensure_within(
            len: usize,
            elem_size: usize,
            limit: usize,
        ) -> Result<(), PreallocationError> {
            match len.checked_mul(elem_size) {
                None => Err(preallocation_size_limit(usize::MAX, limit)),
                Some(bytes) if bytes > limit => Err(preallocation_size_limit(bytes, limit)),
                Some(_) => Ok(()),
            }
        }

        match Self::PREALLOCATION_SIZE_LIMIT_OVERRIDE.to_opt_limit_with_config::<C>() {
            Some(limit) => ensure_within(len, size_of::<T>(), limit),
            None => Ok(()),
        }
    }

    /// Reads a length prefix and rejects it if preallocating that many `T`s would exceed
    /// the limit.
    #[inline]
    fn read_prealloc_check<'de, T>(reader: impl Reader<'de>) -> ReadResult<usize> {
        let len = Self::read(reader)?;
        match Self::prealloc_check::<T>(len) {
            Ok(()) => Ok(len),
            Err(err) => Err(err.into()),
        }
    }

    /// Decodes a length prefix from `reader`.
    fn read<'de>(reader: impl Reader<'de>) -> ReadResult<usize>;

    /// Encodes `len` as a length prefix into `writer`.
    fn write(writer: impl Writer, len: usize) -> WriteResult<()>;

    /// Runs the preallocation check for `len` elements of `T`, then reports the encoded
    /// prefix size.
    fn write_bytes_needed_prealloc_check<T>(len: usize) -> WriteResult<usize> {
        match Self::prealloc_check::<T>(len) {
            Ok(()) => Self::write_bytes_needed(len),
            Err(err) => Err(err.into()),
        }
    }

    /// Number of bytes needed to encode `len` as a length prefix.
    fn write_bytes_needed(len: usize) -> WriteResult<usize>;
}
/// Length encoding that stores the length as a value of integer schema type `T`.
///
/// The length is converted to `T::Src` and written with `T`'s `SchemaWrite<C>`
/// implementation. It is read back through `T`'s `SchemaRead<C>` implementation, so the
/// wire format is whatever `T` uses under config `C`.
///
/// `PREALLOCATION_SIZE_LIMIT` is decoded with [`PreallocationLimitOverride::from_usize`]:
/// `PREALLOCATION_SIZE_LIMIT_USE_CONFIG` (the default) defers to the config, and
/// `PREALLOCATION_SIZE_LIMIT_DISABLED` turns the check off. Any other value is used as the
/// limit.
pub struct UseIntLen<T, const PREALLOCATION_SIZE_LIMIT: usize = PREALLOCATION_SIZE_LIMIT_USE_CONFIG>(
    PhantomData<T>,
);

// SAFETY: NOTE(review): `write_bytes_needed` reports either `T`'s static size or
// `T::size_of` of the same converted value that `write` hands to `T::write`. Confirm this
// matches the `SeqLen` contract.
unsafe impl<const PREALLOCATION_SIZE_LIMIT: usize, T, C: ConfigCore> SeqLen<C>
    for UseIntLen<T, PREALLOCATION_SIZE_LIMIT>
where
    T: SchemaWrite<C> + for<'de> SchemaRead<'de, C>,
    T::Src: TryFrom<usize>,
    usize: for<'de> TryFrom<<T as SchemaRead<'de, C>>::Dst>,
{
    const PREALLOCATION_SIZE_LIMIT_OVERRIDE: PreallocationLimitOverride =
        PreallocationLimitOverride::from_usize(PREALLOCATION_SIZE_LIMIT);

    /// Decodes a `T` and narrows it to `usize`.
    ///
    /// Values that do not fit (for example negative ones) are rejected with
    /// `pointer_sized_decode_error`.
    #[inline(always)]
    fn read<'de>(reader: impl Reader<'de>) -> ReadResult<usize> {
        usize::try_from(T::get(reader)?).map_err(|_| pointer_sized_decode_error())
    }

    /// Converts `len` to `T::Src` and writes it with `T`'s schema.
    #[inline(always)]
    fn write(writer: impl Writer, len: usize) -> WriteResult<()> {
        match T::Src::try_from(len) {
            Ok(encoded) => T::write(writer, &encoded),
            Err(_) => Err(write_length_encoding_overflow(type_name::<T::Src>())),
        }
    }

    /// Encoded size of `len`.
    ///
    /// For statically sized `T` this is the fixed size. In that case `len` is not
    /// range-checked here; an out-of-range length only fails later, in `write`.
    #[inline(always)]
    fn write_bytes_needed(len: usize) -> WriteResult<usize> {
        match <T as SchemaWrite<C>>::TYPE_META {
            TypeMeta::Static { size, .. } => Ok(size),
            _ => match T::Src::try_from(len) {
                Ok(encoded) => T::size_of(&encoded),
                Err(_) => Err(write_length_encoding_overflow(type_name::<T::Src>())),
            },
        }
    }
}
/// Implements [`SeqLen`] directly on primitive integers.
///
/// Each length is encoded with the integer's own `SchemaRead`/`SchemaWrite` for config `C`.
macro_rules! impl_use_int_primitive {
    ($($type:ty),+) => {
        $(
            // SAFETY: NOTE(review): `write_bytes_needed` reports either the static size or
            // `size_of` of the same narrowed value that `write` emits. Confirm this matches
            // the `SeqLen` contract.
            unsafe impl<C: ConfigCore> SeqLen<C> for $type {
                #[inline(always)]
                fn read<'de>(reader: impl Reader<'de>) -> ReadResult<usize> {
                    let decoded = <$type as SchemaRead<C>>::get(reader)?;
                    // Negative or too-wide values cannot be an in-memory length.
                    usize::try_from(decoded).map_err(|_| pointer_sized_decode_error())
                }

                #[inline(always)]
                fn write(writer: impl Writer, len: usize) -> WriteResult<()> {
                    let narrowed = <$type>::try_from(len)
                        .map_err(|_| write_length_encoding_overflow(type_name::<$type>()))?;
                    <$type as SchemaWrite<C>>::write(writer, &narrowed)
                }

                #[inline(always)]
                fn write_bytes_needed(len: usize) -> WriteResult<usize> {
                    match <$type as SchemaWrite<C>>::TYPE_META {
                        // Fixed-width encodings do not depend on the value.
                        TypeMeta::Static { size, .. } => Ok(size),
                        _ => {
                            let narrowed = <$type>::try_from(len).map_err(|_| {
                                write_length_encoding_overflow(type_name::<$type>())
                            })?;
                            <$type as SchemaWrite<C>>::size_of(&narrowed)
                        }
                    }
                }
            }
        )+
    };
}
impl_use_int_primitive!(u8, u16, u32, u64, u128, i8, i16, i32, i64, i128);
pub struct FixIntLen<T, const PREALLOCATION_SIZE_LIMIT: usize = PREALLOCATION_SIZE_LIMIT_USE_CONFIG>(
PhantomData<T>,
);
macro_rules! impl_fix_int {
($type:ty) => {
unsafe impl<const PREALLOCATION_SIZE_LIMIT: usize, C: ConfigCore> SeqLen<C>
for FixIntLen<$type, PREALLOCATION_SIZE_LIMIT>
{
const PREALLOCATION_SIZE_LIMIT_OVERRIDE: PreallocationLimitOverride =
PreallocationLimitOverride::from_usize(PREALLOCATION_SIZE_LIMIT);
#[inline(always)]
#[allow(irrefutable_let_patterns)]
fn read<'de>(mut reader: impl Reader<'de>) -> ReadResult<usize> {
let bytes = reader.take_array::<{ size_of::<$type>() }>()?;
let len = match C::ByteOrder::ENDIAN {
Endian::Big => <$type>::from_be_bytes(bytes),
Endian::Little => <$type>::from_le_bytes(bytes),
};
let Ok(len) = usize::try_from(len) else {
return Err(pointer_sized_decode_error());
};
Ok(len)
}
#[inline(always)]
fn write(mut writer: impl Writer, len: usize) -> WriteResult<()> {
let Ok(len) = <$type>::try_from(len) else {
return Err(write_length_encoding_overflow(type_name::<$type>()));
};
let bytes = match C::ByteOrder::ENDIAN {
Endian::Big => len.to_be_bytes(),
Endian::Little => len.to_le_bytes(),
};
writer.write(&bytes)?;
Ok(())
}
#[inline(always)]
fn write_bytes_needed(_: usize) -> WriteResult<usize> {
Ok(size_of::<$type>())
}
}
};
}
impl_fix_int!(u8);
impl_fix_int!(u16);
impl_fix_int!(u32);
impl_fix_int!(u64);
impl_fix_int!(u128);
impl_fix_int!(i8);
impl_fix_int!(i16);
impl_fix_int!(i32);
impl_fix_int!(i64);
impl_fix_int!(i128);
/// Length encoding that stores the length as a `u64` through `u64`'s own schema for the
/// active config. It is an alias for `UseIntLen<u64, _>`.
///
/// NOTE(review): the name suggests this mirrors bincode's default `u64` length prefix.
/// Confirm against bincode for non-default configs.
pub type BincodeLen<const PREALLOCATION_SIZE_LIMIT: usize = PREALLOCATION_SIZE_LIMIT_USE_CONFIG> =
UseIntLen<u64, PREALLOCATION_SIZE_LIMIT>;
#[cfg(feature = "solana-short-vec")]
pub mod short_vec {
pub use solana_short_vec::ShortU16;
use {
super::*,
crate::{
SchemaReadContext,
error::{ReadError, write_length_encoding_overflow},
},
core::mem::MaybeUninit,
};
unsafe impl<'de, C: ConfigCore> SchemaRead<'de, C> for ShortU16 {
type Dst = Self;
#[inline]
fn read(reader: impl Reader<'de>, dst: &mut MaybeUninit<Self::Dst>) -> ReadResult<()> {
let len = decode_short_u16_from_reader(reader)?;
let slot = unsafe { &mut *(&raw mut (*dst.as_mut_ptr()).0).cast::<MaybeUninit<u16>>() };
slot.write(len);
Ok(())
}
}
/// Decodes a short-u16, taking its leading bytes from `ctx` and the rest from `reader`.
///
/// Byte `i` of the encoding is `ctx[i]` when `i < N`; otherwise it is taken from `reader`.
/// Bytes are fetched only as the encoding needs them. The reader is not touched when
/// `ctx` already holds the whole encoding, and surplus `ctx` bytes are ignored.
///
/// # Errors
/// - `non_canonical_err()` if the second or third byte is zero, meaning a shorter encoding
///   exists.
/// - `overflow_err()` if the third byte is greater than 3, meaning the value exceeds
///   `u16::MAX`.
/// - Errors from `reader.take_byte()` are propagated.
#[inline]
pub fn decode_short_u16_with_ctx<'de, const N: usize>(
    ctx: [u8; N],
    reader: impl Reader<'de>,
) -> ReadResult<u16> {
    /// Byte source that serves `prefix` first, then falls back to `inner`.
    struct PrefixedReader<const N: usize, R> {
        prefix: [u8; N],
        inner: R,
    }

    impl<'a, const N: usize, R: Reader<'a>> PrefixedReader<N, R> {
        /// Returns byte `I` of the encoding.
        ///
        /// `I` is a const, so the prefix/reader choice is resolved per call site.
        #[inline(always)]
        fn byte<const I: usize>(&mut self) -> ReadResult<u8> {
            if I >= N {
                return Ok(self.inner.take_byte()?);
            }
            Ok(self.prefix[I])
        }
    }

    let mut src = PrefixedReader { prefix: ctx, inner: reader };

    // One-byte form: high bit clear.
    let first = src.byte::<0>()?;
    if first < 0x80 {
        return Ok(u16::from(first));
    }
    let low = u16::from(first & 0x7f);

    // Two-byte form: second byte carries bits 7..14 with its high bit clear.
    let second = src.byte::<1>()?;
    match second {
        0x00 => return Err(non_canonical_err()),
        0x01..=0x7f => return Ok(low | (u16::from(second) << 7)),
        _ => {}
    }
    let mid = u16::from(second & 0x7f) << 7;

    // Three-byte form: only two payload bits (bits 14..16) remain.
    let third = src.byte::<2>()?;
    match third {
        0x00 => Err(non_canonical_err()),
        0x01..=0x03 => Ok(low | mid | (u16::from(third) << 14)),
        _ => Err(overflow_err()),
    }
}
/// Reads a `ShortU16` whose first bytes were already consumed and are supplied as `ctx`.
// SAFETY: NOTE(review): on success `dst` is fully initialized by a single `dst.write`.
// Confirm this is what the `SchemaReadContext` contract requires.
unsafe impl<'de, const N: usize, C: ConfigCore> SchemaReadContext<'de, C, [u8; N]> for ShortU16 {
    type Dst = Self;

    /// Decodes using `ctx` as the encoding's leading bytes and `reader` for the rest.
    ///
    /// `dst` is written only on success.
    #[inline]
    fn read_with_context(
        ctx: [u8; N],
        reader: impl Reader<'de>,
        dst: &mut MaybeUninit<Self::Dst>,
    ) -> ReadResult<()> {
        decode_short_u16_with_ctx(ctx, reader).map(|value| {
            dst.write(ShortU16(value));
        })
    }
}
/// Writes a `ShortU16` in its 1–3 byte short-u16 encoding.
// SAFETY: NOTE(review): `size_of` returns `short_u16_bytes_needed(src.0)`, the same byte
// count that `encode_short_u16` produces for that value in `write`. Confirm this matches
// the `SchemaWrite` contract.
unsafe impl<C: ConfigCore> SchemaWrite<C> for ShortU16 {
    type Src = Self;

    /// Number of bytes (1, 2 or 3) that `write` emits for `src`.
    #[inline]
    fn size_of(src: &Self::Src) -> WriteResult<usize> {
        let bytes = short_u16_bytes_needed(src.0);
        Ok(bytes)
    }

    /// Encodes `src` into a stack scratch buffer and writes the used prefix.
    #[inline]
    fn write(mut writer: impl Writer, src: &Self::Src) -> WriteResult<()> {
        // Three bytes is the longest short-u16 encoding.
        let mut scratch = [MaybeUninit::<u8>::uninit(); 3];
        writer.write(encode_short_u16(&mut scratch, src.0))?;
        Ok(())
    }
}
/// Length in bytes of the short-u16 encoding of `len`.
///
/// Each byte carries 7 payload bits, so values below `0x80` take one byte, values below
/// `0x4000` take two, and everything else takes three.
#[inline(always)]
fn short_u16_bytes_needed(len: u16) -> usize {
    match len {
        0x0000..=0x007f => 1,
        0x0080..=0x3fff => 2,
        _ => 3,
    }
}

/// Like `short_u16_bytes_needed`, but accepts any integer.
///
/// Values that do not fit in a `u16` are rejected with a length-encoding-overflow error.
#[inline(always)]
fn try_short_u16_bytes_needed<T: TryInto<u16>>(len: T) -> WriteResult<usize> {
    len.try_into()
        .map(short_u16_bytes_needed)
        .map_err(|_| write_length_encoding_overflow("u16::MAX"))
}
/// Writes the short-u16 encoding of `len` into the front of `dst` and returns those bytes.
///
/// The encoding is base-128 with the least-significant group first. Every byte except the
/// last has its high bit (`0x80`) set. At most three bytes are produced.
///
/// # Panics
/// Panics if `dst` is shorter than the encoding of `len`, which is 1 to 3 bytes.
#[inline(always)]
fn encode_short_u16(dst: &mut [MaybeUninit<u8>], len: u16) -> &[u8] {
    const CONTINUE: u8 = 0x80;
    let low = (len & 0x7f) as u8;
    let written = if len < 0x80 {
        dst[0].write(low);
        1
    } else if len < 0x4000 {
        dst[0].write(low | CONTINUE);
        dst[1].write((len >> 7) as u8);
        2
    } else {
        dst[0].write(low | CONTINUE);
        dst[1].write((((len >> 7) & 0x7f) as u8) | CONTINUE);
        dst[2].write((len >> 14) as u8);
        3
    };
    // SAFETY: the branch above initialized `dst[0..written]`. Indexing would have panicked
    // had `dst` been shorter, so the pointer is valid for `written` reads. `MaybeUninit<u8>`
    // has the same layout as `u8`, and the returned slice borrows `dst`, so the bytes cannot
    // be mutated while it is alive.
    unsafe { core::slice::from_raw_parts(dst.as_ptr().cast::<u8>(), written) }
}
/// Error for a short-u16 whose three-byte form exceeds `u16::MAX`.
#[cold]
const fn overflow_err() -> ReadError {
    ReadError::LengthEncodingOverflow("u16::MAX")
}

/// Error for a short-u16 with a zero continuation byte, meaning a shorter encoding exists.
#[cold]
const fn non_canonical_err() -> ReadError {
    ReadError::InvalidValue("short u16: non-canonical encoding")
}

/// Error for input that ends before the short-u16 encoding is complete.
#[cold]
const fn incomplete_err() -> ReadError {
    ReadError::InvalidValue("short u16: unexpected end of input")
}

/// Decodes a short-u16 from the start of `bytes`.
///
/// Returns the value and the number of bytes consumed (1 to 3). Bytes after the encoding
/// are ignored.
///
/// # Errors
/// - `incomplete_err()` if `bytes` ends before the encoding does.
/// - `non_canonical_err()` if the second or third byte is zero.
/// - `overflow_err()` if the third byte is greater than 3.
#[inline]
pub const fn decode_short_u16(bytes: &[u8]) -> ReadResult<(u16, usize)> {
    // Arms are ordered so each length/content case is classified in the same order as a
    // byte-by-byte decoder would check it.
    match bytes {
        [] => Err(incomplete_err()),
        // One-byte form: high bit clear.
        [b0 @ 0x00..=0x7f, ..] => Ok((*b0 as u16, 1)),
        [_] => Err(incomplete_err()),
        [_, 0x00, ..] => Err(non_canonical_err()),
        // Two-byte form: second byte has its high bit clear.
        [b0, b1 @ 0x01..=0x7f, ..] => Ok((((*b0 & 0x7f) as u16) | ((*b1 as u16) << 7), 2)),
        [_, _] => Err(incomplete_err()),
        [_, _, 0x00, ..] => Err(non_canonical_err()),
        // Three-byte form: only two payload bits remain in the last byte.
        [b0, b1, b2 @ 0x01..=0x03, ..] => {
            let value =
                ((*b0 & 0x7f) as u16) | (((*b1 & 0x7f) as u16) << 7) | ((*b2 as u16) << 14);
            Ok((value, 3))
        }
        [_, _, _, ..] => Err(overflow_err()),
    }
}
/// Decodes a short-u16 using only `reader`; no bytes were consumed beforehand.
#[inline(always)]
fn decode_short_u16_from_reader<'de>(reader: impl Reader<'de>) -> ReadResult<u16> {
    let no_prefix: [u8; 0] = [];
    decode_short_u16_with_ctx(no_prefix, reader)
}

/// Uses the short-u16 encoding as a sequence length prefix.
///
/// Lengths above `u16::MAX` cannot be written.
// SAFETY: NOTE(review): `write_bytes_needed` returns `short_u16_bytes_needed` of the same
// `u16` that `write` encodes. Confirm this matches the `SeqLen` contract.
unsafe impl<C: ConfigCore> SeqLen<C> for ShortU16 {
    #[inline(always)]
    fn read<'de>(reader: impl Reader<'de>) -> ReadResult<usize> {
        decode_short_u16_from_reader(reader).map(usize::from)
    }

    #[inline(always)]
    fn write(writer: impl Writer, len: usize) -> WriteResult<()> {
        match u16::try_from(len) {
            Ok(short) => <ShortU16 as SchemaWrite<C>>::write(writer, &ShortU16(short)),
            Err(_) => Err(write_length_encoding_overflow("u16::MAX")),
        }
    }

    #[inline(always)]
    fn write_bytes_needed(len: usize) -> WriteResult<usize> {
        try_short_u16_bytes_needed(len)
    }
}
// Tests for the short-u16 length encoding. They check the decode-with-context edge cases
// and equivalence with `solana_short_vec` and `bincode`.
#[cfg(all(test, feature = "alloc", feature = "derive"))]
mod tests {
use {
super::*,
crate::{containers, io::Cursor, proptest_config::proptest_cfg},
alloc::vec::Vec,
proptest::prelude::*,
solana_short_vec::ShortU16,
wincode_derive::{SchemaRead, SchemaWrite},
};
// Encodes `len` with this module's `encode_short_u16` into a fresh `Vec`.
fn our_short_u16_encode(len: u16) -> Vec<u8> {
let mut buf = Vec::with_capacity(3);
let bytes = encode_short_u16(buf.spare_capacity_mut(), len);
let written = bytes.len();
// SAFETY: `buf` starts empty with capacity >= 3. `encode_short_u16` initialized the first
// `written` (<= 3) slots of its spare capacity, so those elements are now initialized.
unsafe { buf.set_len(written) }
buf
}
// Struct with short-vec-prefixed fields, serialized by both serde (bincode) and wincode.
#[derive(
serde::Serialize, serde::Deserialize, Debug, PartialEq, Eq, SchemaWrite, SchemaRead,
)]
#[wincode(internal)]
struct ShortVecStruct {
#[serde(with = "solana_short_vec")]
#[wincode(with = "containers::Vec<u8, ShortU16>")]
bytes: Vec<u8>,
#[serde(with = "solana_short_vec")]
#[wincode(with = "containers::Vec<[u8; 32], ShortU16>")]
ar: Vec<[u8; 32]>,
}
// Struct using `ShortU16` directly as a field schema.
#[derive(SchemaWrite, SchemaRead, serde::Serialize, serde::Deserialize)]
#[wincode(internal)]
struct ShortVecAsSchema {
short_u16: ShortU16,
}
// Strategy producing `ShortVecStruct`s with small random payloads.
fn strat_short_vec_struct() -> impl Strategy<Value = ShortVecStruct> {
(
proptest::collection::vec(any::<u8>(), 0..=100),
proptest::collection::vec(any::<[u8; 32]>(), 0..=16),
)
.prop_map(|(bytes, ar)| ShortVecStruct { bytes, ar })
}
// A complete encoding in `ctx` must not consume from the reader.
#[test]
fn decode_short_u16_with_ctx_uses_only_ctx_when_complete() {
let mut reader = Cursor::new(&[0xff][..]);
let decoded = decode_short_u16_with_ctx([0x80, 0x80, 0x01], &mut reader).unwrap();
assert_eq!(decoded, 0x4000);
assert_eq!(reader.position(), 0);
}
#[test]
fn decode_short_u16_with_ctx_uses_only_ctx_for_one_byte_encoding() {
let mut reader = Cursor::new(&[0xff][..]);
let decoded = decode_short_u16_with_ctx([0x7f], &mut reader).unwrap();
assert_eq!(decoded, 0x7f);
assert_eq!(reader.position(), 0);
}
#[test]
fn decode_short_u16_with_ctx_uses_only_ctx_for_two_byte_encoding() {
let mut reader = Cursor::new(&[0xff][..]);
let decoded = decode_short_u16_with_ctx([0x80, 0x01], &mut reader).unwrap();
assert_eq!(decoded, 0x80);
assert_eq!(reader.position(), 0);
}
// A two-byte encoding split across ctx and reader must read exactly one byte.
#[test]
fn decode_short_u16_with_ctx_stops_after_second_byte_from_reader() {
let mut reader = Cursor::new(&[0x01, 0xff][..]);
let decoded = decode_short_u16_with_ctx([0x80], &mut reader).unwrap();
assert_eq!(decoded, 0x80);
assert_eq!(reader.position(), 1);
}
#[test]
fn decode_short_u16_with_ctx_reads_remaining_bytes_from_reader() {
let mut reader = Cursor::new(&[0x80, 0x01, 0xff][..]);
let decoded = decode_short_u16_with_ctx([0x80], &mut reader).unwrap();
assert_eq!(decoded, 0x4000);
assert_eq!(reader.position(), 2);
}
#[test]
fn decode_short_u16_with_ctx_non_canonical_second_byte_from_reader() {
let mut reader = Cursor::new(&[0x00][..]);
let err = decode_short_u16_with_ctx([0x80], &mut reader).unwrap_err();
assert!(matches!(
err,
ReadError::InvalidValue("short u16: non-canonical encoding")
));
assert_eq!(reader.position(), 1);
}
// Running out of input surfaces the reader's own error, not `incomplete_err`.
#[test]
fn decode_short_u16_with_ctx_incomplete_second_byte_from_reader() {
let mut reader = Cursor::new(&[][..]);
let err = decode_short_u16_with_ctx([0x80], &mut reader).unwrap_err();
assert!(matches!(
err,
ReadError::Io(crate::io::ReadError::ReadSizeLimit(1))
));
assert_eq!(reader.position(), 0);
}
#[test]
fn decode_short_u16_with_ctx_non_canonical_third_byte_from_reader() {
let mut reader = Cursor::new(&[0x00][..]);
let err = decode_short_u16_with_ctx([0x80, 0x80], &mut reader).unwrap_err();
assert!(matches!(
err,
ReadError::InvalidValue("short u16: non-canonical encoding")
));
assert_eq!(reader.position(), 1);
}
#[test]
fn decode_short_u16_with_ctx_incomplete_third_byte_from_reader() {
let mut reader = Cursor::new(&[][..]);
let err = decode_short_u16_with_ctx([0x80, 0x80], &mut reader).unwrap_err();
assert!(matches!(
err,
ReadError::Io(crate::io::ReadError::ReadSizeLimit(1))
));
assert_eq!(reader.position(), 0);
}
#[test]
fn decode_short_u16_with_ctx_overflow_third_byte_from_reader() {
let mut reader = Cursor::new(&[0x04][..]);
let err = decode_short_u16_with_ctx([0x80, 0x80], &mut reader).unwrap_err();
assert!(matches!(err, ReadError::LengthEncodingOverflow("u16::MAX")));
assert_eq!(reader.position(), 1);
}
#[test]
fn decode_short_u16_with_ctx_non_canonical_second_byte_in_ctx() {
let mut reader = Cursor::new(&[0xff][..]);
let err = decode_short_u16_with_ctx([0x80, 0x00], &mut reader).unwrap_err();
assert!(matches!(
err,
ReadError::InvalidValue("short u16: non-canonical encoding")
));
assert_eq!(reader.position(), 0);
}
// Property tests: encoding/decoding must agree with `solana_short_vec` and `bincode`.
proptest! {
#![proptest_config(proptest_cfg())]
#[test]
fn encode_u16_equivalence(len in 0..=u16::MAX) {
let our = our_short_u16_encode(len);
let bincode = bincode::serialize(&ShortU16(len)).unwrap();
prop_assert_eq!(our, bincode);
}
#[test]
fn test_short_vec_struct(short_vec_struct in strat_short_vec_struct()) {
let bincode_serialized = bincode::serialize(&short_vec_struct).unwrap();
let schema_serialized = crate::serialize(&short_vec_struct).unwrap();
prop_assert_eq!(&bincode_serialized, &schema_serialized);
let bincode_deserialized: ShortVecStruct = bincode::deserialize(&bincode_serialized).unwrap();
let schema_deserialized: ShortVecStruct = crate::deserialize(&schema_serialized).unwrap();
prop_assert_eq!(&short_vec_struct, &bincode_deserialized);
prop_assert_eq!(short_vec_struct, schema_deserialized);
}
#[test]
fn encode_decode_short_u16_roundtrip(len in 0..=u16::MAX) {
let our = our_short_u16_encode(len);
let (decoded_len, read) = decode_short_u16(&our).unwrap();
let (sdk_decoded_len, sdk_read) = solana_short_vec::decode_shortu16_len(&our).unwrap();
let sdk_decoded_len = sdk_decoded_len as u16;
prop_assert_eq!(len, decoded_len);
prop_assert_eq!(len, sdk_decoded_len);
prop_assert_eq!(read, sdk_read);
}
// Only success/failure is compared; the error variants differ between the two crates.
#[test]
fn decode_short_u16_err_equivalence(bytes in prop::collection::vec(any::<u8>(), 0..=3)) {
let our_decode = decode_short_u16(&bytes);
let sdk_decode = solana_short_vec::decode_shortu16_len(&bytes);
prop_assert_eq!(our_decode.is_err(), sdk_decode.is_err());
prop_assert_eq!(our_decode.is_ok(), sdk_decode.is_ok());
}
#[test]
fn test_short_vec_as_schema(sv in any::<u16>()) {
let val = ShortVecAsSchema { short_u16: ShortU16(sv) };
let bincode_serialized = bincode::serialize(&val).unwrap();
let wincode_serialized = crate::serialize(&val).unwrap();
prop_assert_eq!(&bincode_serialized, &wincode_serialized);
let bincode_deserialized: ShortVecAsSchema = bincode::deserialize(&bincode_serialized).unwrap();
let wincode_deserialized: ShortVecAsSchema = crate::deserialize(&wincode_serialized).unwrap();
prop_assert_eq!(val.short_u16.0, bincode_deserialized.short_u16.0);
prop_assert_eq!(val.short_u16.0, wincode_deserialized.short_u16.0);
}
}
}
}
#[cfg(feature = "solana-short-vec")]
pub use short_vec::*;