use {
crate::{
bytemuck::{
pod_from_bytes, pod_from_bytes_mut, pod_slice_from_bytes, pod_slice_from_bytes_mut,
},
error::PodSliceError,
list::{list_view_mut::ListViewMut, list_view_read_only::ListViewReadOnly},
pod_length::PodLength,
primitives::PodU32,
},
bytemuck::Pod,
solana_program_error::ProgramError,
std::{
marker::PhantomData,
mem::{align_of, size_of},
ops::Range,
},
};
/// Zero-sized marker type whose associated functions interpret a byte buffer
/// as a length-prefixed list of `T` elements.
///
/// The buffer starts with a length prefix of type `L` (defaults to
/// [`PodU32`]), followed by any padding required to align the elements,
/// followed by the element storage itself.
pub struct ListView<T, L = PodU32>(PhantomData<(T, L)>)
where
    T: Pod,
    L: PodLength;
/// Byte ranges of the two regions inside a list buffer, as computed by
/// `ListView::calculate_layout`. Any alignment padding lies between
/// `length_range.end` and `data_range.start`.
struct Layout {
    /// Bytes holding the element storage; this range runs to the end of the buffer.
    data_range: Range<usize>,
    /// Bytes holding the encoded length prefix; this range starts at offset 0.
    length_range: Range<usize>,
}
impl<T: Pod, L: PodLength> ListView<T, L> {
    /// Returns the number of bytes a buffer needs to hold a list of
    /// `num_items` elements: the length prefix, the alignment padding after
    /// it, and the element storage.
    ///
    /// # Errors
    ///
    /// * `ProgramError::InvalidArgument` if `L` is not byte-aligned.
    /// * `PodSliceError::CalculationFailure` (converted into `ProgramError`)
    ///   if the byte count does not fit in `usize`.
    pub fn size_of(num_items: usize) -> Result<usize, ProgramError> {
        let padding = Self::header_padding()?;
        // Non-capturing closure, so it is `Copy` and can be reused for every step.
        let overflow = || -> ProgramError { PodSliceError::CalculationFailure.into() };
        let item_bytes = size_of::<T>().checked_mul(num_items).ok_or_else(overflow)?;
        let with_prefix = item_bytes
            .checked_add(size_of::<L>())
            .ok_or_else(overflow)?;
        with_prefix.checked_add(padding).ok_or_else(overflow)
    }

    /// Interprets `buf` as a list and returns a read-only view over it.
    ///
    /// # Errors
    ///
    /// * `PodSliceError::BufferTooSmall` in either of these cases:
    ///   - `buf` cannot hold the length prefix plus padding;
    ///   - the stored length exceeds the number of elements that fit in the
    ///     data region.
    /// * `ProgramError::InvalidArgument` if `L` is not byte-aligned.
    /// * Any error returned by the pod byte-casting helpers, for example when
    ///   the data region is not a whole number of `T`s.
    pub fn unpack(buf: &[u8]) -> Result<ListViewReadOnly<T, L>, ProgramError> {
        let Layout {
            length_range,
            data_range,
        } = Self::calculate_layout(buf.len())?;
        // `calculate_layout` guarantees `data_range.start <= buf.len()`.
        let (header_bytes, data_bytes) = buf.split_at(data_range.start);
        let length = pod_from_bytes::<L>(&header_bytes[length_range])?;
        let data = pod_slice_from_bytes::<T>(data_bytes)?;
        let capacity = data.len();
        let stored_len: usize = (*length).into();
        if stored_len > capacity {
            return Err(PodSliceError::BufferTooSmall.into());
        }
        Ok(ListViewReadOnly {
            length,
            data,
            capacity,
        })
    }

    /// Interprets `buf` as a list and returns a mutable view over it.
    ///
    /// # Errors
    ///
    /// Same conditions as [`Self::unpack`]. In particular,
    /// `PodSliceError::BufferTooSmall` is returned when the stored length
    /// exceeds the capacity of the data region.
    pub fn unpack_mut(buf: &mut [u8]) -> Result<ListViewMut<T, L>, ProgramError> {
        let view = Self::build_mut_view(buf)?;
        let stored_len: usize = (*view.length).into();
        if stored_len > view.capacity {
            return Err(PodSliceError::BufferTooSmall.into());
        }
        Ok(view)
    }

    /// Splits `buf` into its length prefix and data region and wraps them in
    /// a mutable view. It does NOT check the stored length against the
    /// capacity; callers that need that check must perform it themselves.
    #[inline]
    fn build_mut_view(buf: &mut [u8]) -> Result<ListViewMut<T, L>, ProgramError> {
        let Layout {
            length_range,
            data_range,
        } = Self::calculate_layout(buf.len())?;
        // Split first so the length field and the data can be borrowed mutably at once.
        let (header_bytes, data_bytes) = buf.split_at_mut(data_range.start);
        let length = pod_from_bytes_mut::<L>(&mut header_bytes[length_range])?;
        let data = pod_slice_from_bytes_mut::<T>(data_bytes)?;
        Ok(ListViewMut {
            length,
            capacity: data.len(),
            data,
        })
    }

    /// Computes the byte ranges of the length prefix and the data region for
    /// a buffer of `buf_len` bytes.
    ///
    /// # Errors
    ///
    /// * `ProgramError::InvalidArgument` if `L` is not byte-aligned.
    /// * `PodSliceError::BufferTooSmall` if `buf_len` is shorter than the
    ///   length prefix plus padding.
    #[inline]
    fn calculate_layout(buf_len: usize) -> Result<Layout, ProgramError> {
        let length_end = size_of::<L>();
        let data_start = length_end.saturating_add(Self::header_padding()?);
        if data_start > buf_len {
            return Err(PodSliceError::BufferTooSmall.into());
        }
        Ok(Layout {
            length_range: 0..length_end,
            data_range: data_start..buf_len,
        })
    }

    /// Returns how many padding bytes follow the length prefix, so that the
    /// first element's offset is a multiple of `align_of::<T>()`.
    ///
    /// Offsets are relative to the start of the buffer. NOTE(review): this
    /// presumes the buffer itself starts at an address aligned for `T`;
    /// confirm that the pod casting helpers reject misaligned buffers.
    ///
    /// # Errors
    ///
    /// `ProgramError::InvalidArgument` if `L` is not byte-aligned.
    #[inline]
    fn header_padding() -> Result<usize, ProgramError> {
        if align_of::<L>() != 1 {
            return Err(ProgramError::InvalidArgument);
        }
        let item_align = align_of::<T>();
        // Byte-aligned elements can start right after the prefix. The `<= 1`
        // guard also keeps the remainder below from dividing by zero.
        if item_align <= 1 {
            return Ok(0);
        }
        #[allow(clippy::arithmetic_side_effects)]
        let overshoot = size_of::<L>().wrapping_rem(item_align);
        Ok(match overshoot {
            0 => 0,
            partial => item_align.wrapping_sub(partial),
        })
    }
}
impl<T, L> ListView<T, L>
where
    T: Pod,
    L: PodLength,
    PodSliceError: From<<L as TryFrom<usize>>::Error>,
{
    /// Initializes `buf` as an empty list: the length prefix is set to zero.
    /// Padding and data bytes are not written.
    ///
    /// # Errors
    ///
    /// * Layout and casting errors as for `ListView::unpack_mut`, e.g.
    ///   `PodSliceError::BufferTooSmall` when `buf` cannot hold the header.
    /// * The converted `TryFrom<usize>` error if `L` cannot represent zero.
    pub fn init(buf: &mut [u8]) -> Result<ListViewMut<T, L>, ProgramError> {
        let view = Self::build_mut_view(buf)?;
        let empty_len = L::try_from(0).map_err(PodSliceError::from)?;
        *view.length = empty_len;
        Ok(view)
    }
}
#[cfg(test)]
mod tests {
    #[cfg(not(target_arch = "bpf"))]
    use crate::primitives::PodU128;
    use {
        super::*,
        crate::{
            list::List,
            primitives::{PodU16, PodU32, PodU64},
        },
        bytemuck_derive::{Pod as DerivePod, Zeroable},
    };

    /// `size_of` for element types that need no header padding.
    #[test]
    fn test_size_of_no_padding() {
        assert_eq!(ListView::<u8, PodU32>::size_of(10).unwrap(), 14);
        assert_eq!(ListView::<u32>::size_of(10).unwrap(), 44);
        assert_eq!(ListView::<u32>::size_of(0).unwrap(), 4);
    }

    /// `size_of` includes the padding inserted after the length prefix.
    #[test]
    fn test_size_of_with_padding() {
        assert_eq!(ListView::<u64, PodU32>::size_of(10).unwrap(), 88);

        #[repr(C, align(16))]
        #[derive(DerivePod, Zeroable, Copy, Clone)]
        struct Align16(u128);

        assert_eq!(ListView::<Align16>::size_of(10).unwrap(), 176);
        assert_eq!(ListView::<u64, PodU32>::size_of(0).unwrap(), 8);
    }

    /// Overflowing byte counts are reported rather than wrapping.
    #[test]
    fn test_size_of_overflow() {
        let err = ListView::<u16, PodU32>::size_of(usize::MAX).unwrap_err();
        assert_eq!(err, PodSliceError::CalculationFailure.into());

        let err = ListView::<u8, PodU32>::size_of(usize::MAX).unwrap_err();
        assert_eq!(err, PodSliceError::CalculationFailure.into());
    }

    /// Length types that are not byte-aligned are rejected by every entry point.
    #[test]
    fn test_fails_with_non_aligned_length_type() {
        #[repr(C, align(4))]
        #[derive(Debug, Copy, Clone, Zeroable, DerivePod)]
        struct TestPodU32(u32);

        impl From<TestPodU32> for usize {
            fn from(val: TestPodU32) -> Self {
                val.0 as usize
            }
        }

        impl TryFrom<usize> for TestPodU32 {
            type Error = PodSliceError;
            fn try_from(val: usize) -> Result<Self, Self::Error> {
                Ok(Self(u32::try_from(val)?))
            }
        }

        let mut buf = [0u8; 100];

        let err_size_of = ListView::<u8, TestPodU32>::size_of(10).unwrap_err();
        assert_eq!(err_size_of, ProgramError::InvalidArgument);

        let err_unpack = ListView::<u8, TestPodU32>::unpack(&buf).unwrap_err();
        assert_eq!(err_unpack, ProgramError::InvalidArgument);

        let err_init = ListView::<u8, TestPodU32>::init(&mut buf).unwrap_err();
        assert_eq!(err_init, ProgramError::InvalidArgument);
    }

    /// Padding rounds the length-prefix size up to the element alignment.
    #[test]
    fn test_padding_calculation() {
        // Length size is already a multiple of the element alignment: no padding.
        assert_eq!(ListView::<u8, PodU32>::header_padding().unwrap(), 0);
        assert_eq!(ListView::<(), PodU64>::header_padding().unwrap(), 0);
        assert_eq!(ListView::<u16, PodU16>::header_padding().unwrap(), 0);
        assert_eq!(ListView::<u32, PodU32>::header_padding().unwrap(), 0);
        assert_eq!(ListView::<u64, PodU64>::header_padding().unwrap(), 0);
        assert_eq!(ListView::<u16, PodU64>::header_padding().unwrap(), 0);
        assert_eq!(ListView::<u32, PodU64>::header_padding().unwrap(), 0);

        // Element alignment exceeds the length size: pad to the next multiple.
        assert_eq!(ListView::<u32, PodU16>::header_padding().unwrap(), 2);
        assert_eq!(ListView::<u64, PodU16>::header_padding().unwrap(), 6);
        assert_eq!(ListView::<u64, PodU32>::header_padding().unwrap(), 4);

        #[repr(C, align(8))]
        #[derive(DerivePod, Zeroable, Copy, Clone)]
        struct Align8(u64);

        assert_eq!(ListView::<Align8, PodU16>::header_padding().unwrap(), 6);
        assert_eq!(ListView::<Align8, PodU32>::header_padding().unwrap(), 4);
        assert_eq!(ListView::<Align8, PodU64>::header_padding().unwrap(), 0);

        #[repr(C, align(16))]
        #[derive(DerivePod, Zeroable, Copy, Clone)]
        struct Align16(u128);

        assert_eq!(ListView::<Align16, PodU16>::header_padding().unwrap(), 14);
        assert_eq!(ListView::<Align16, PodU32>::header_padding().unwrap(), 12);
        assert_eq!(ListView::<Align16, PodU64>::header_padding().unwrap(), 8);
    }

    /// Round-trip through `unpack`/`unpack_mut` when no padding is needed.
    #[test]
    fn test_unpack_success_no_padding() {
        let length: u32 = 2;
        let capacity: usize = 3;
        let item_size = size_of::<u32>();
        let len_size = size_of::<PodU32>();
        let buf_size = len_size + capacity * item_size;
        let mut buf = vec![0u8; buf_size];

        let pod_len: PodU32 = length.into();
        buf[0..len_size].copy_from_slice(bytemuck::bytes_of(&pod_len));

        let data_start = len_size;
        let items = [100u32, 200u32];
        let items_bytes = bytemuck::cast_slice(&items);
        buf[data_start..(data_start + items_bytes.len())].copy_from_slice(items_bytes);

        let view_ro = ListView::<u32, PodU32>::unpack(&buf).unwrap();
        assert_eq!(view_ro.len(), length as usize);
        assert_eq!(view_ro.capacity(), capacity);
        assert_eq!(*view_ro, items[..]);

        let view_mut = ListView::<u32, PodU32>::unpack_mut(&mut buf).unwrap();
        assert_eq!(view_mut.len(), length as usize);
        assert_eq!(view_mut.capacity(), capacity);
        assert_eq!(*view_mut, items[..]);
    }

    /// Round-trip through `unpack`/`unpack_mut` when header padding is needed.
    #[test]
    fn test_unpack_success_with_padding() {
        let padding = ListView::<u64, PodU32>::header_padding().unwrap();
        assert_eq!(padding, 4);

        let length: u32 = 2;
        let capacity: usize = 2;
        let item_size = size_of::<u64>();
        let len_size = size_of::<PodU32>();
        let buf_size = len_size + padding + capacity * item_size;
        let mut buf = vec![0u8; buf_size];

        let pod_len: PodU32 = length.into();
        buf[0..len_size].copy_from_slice(bytemuck::bytes_of(&pod_len));

        let data_start = len_size + padding;
        let items = [100u64, 200u64];
        let items_bytes = bytemuck::cast_slice(&items);
        buf[data_start..(data_start + items_bytes.len())].copy_from_slice(items_bytes);

        let view_ro = ListView::<u64, PodU32>::unpack(&buf).unwrap();
        assert_eq!(view_ro.len(), length as usize);
        assert_eq!(view_ro.capacity(), capacity);
        assert_eq!(*view_ro, items[..]);

        let view_mut = ListView::<u64, PodU32>::unpack_mut(&mut buf).unwrap();
        assert_eq!(view_mut.len(), length as usize);
        assert_eq!(view_mut.capacity(), capacity);
        assert_eq!(*view_mut, items[..]);
    }

    /// A zero stored length yields an empty view with full capacity.
    #[test]
    fn test_unpack_success_zero_length() {
        let capacity: usize = 5;
        let item_size = size_of::<u32>();
        let len_size = size_of::<PodU32>();
        let buf_size = len_size + capacity * item_size;
        let mut buf = vec![0u8; buf_size];

        let pod_len: PodU32 = 0u32.into();
        buf[0..len_size].copy_from_slice(bytemuck::bytes_of(&pod_len));

        let view_ro = ListView::<u32, PodU32>::unpack(&buf).unwrap();
        assert_eq!(view_ro.len(), 0);
        assert_eq!(view_ro.capacity(), capacity);
        assert!(view_ro.is_empty());
        assert_eq!(&*view_ro, &[] as &[u32]);

        let view_mut = ListView::<u32, PodU32>::unpack_mut(&mut buf).unwrap();
        assert_eq!(view_mut.len(), 0);
        assert_eq!(view_mut.capacity(), capacity);
        assert!(view_mut.is_empty());
        assert_eq!(&*view_mut, &[] as &[u32]);
    }

    /// A stored length equal to the capacity is accepted.
    #[test]
    fn test_unpack_success_full_capacity() {
        let length: u64 = 3;
        let capacity: usize = 3;
        let item_size = size_of::<u64>();
        let len_size = size_of::<PodU64>();
        let buf_size = len_size + capacity * item_size;
        let mut buf = vec![0u8; buf_size];

        let pod_len: PodU64 = length.into();
        buf[0..len_size].copy_from_slice(bytemuck::bytes_of(&pod_len));

        let data_start = len_size;
        let items = [1u64, 2u64, 3u64];
        let items_bytes = bytemuck::cast_slice(&items);
        buf[data_start..].copy_from_slice(items_bytes);

        let view_ro = ListView::<u64>::unpack(&buf).unwrap();
        assert_eq!(view_ro.len(), length as usize);
        assert_eq!(view_ro.capacity(), capacity);
        assert_eq!(*view_ro, items[..]);

        let view_mut = ListView::<u64>::unpack_mut(&mut buf).unwrap();
        assert_eq!(view_mut.len(), length as usize);
        assert_eq!(view_mut.capacity(), capacity);
        assert_eq!(*view_mut, items[..]);
    }

    /// A buffer one byte short of the header (prefix plus padding) is rejected.
    #[test]
    fn test_unpack_fail_buffer_too_small_for_header() {
        let header_size = ListView::<u64, PodU32>::size_of(0).unwrap();
        assert_eq!(header_size, 8);

        let mut buf = vec![0u8; header_size - 1];

        let err = ListView::<u64, PodU32>::unpack(&buf).unwrap_err();
        assert_eq!(err, PodSliceError::BufferTooSmall.into());

        let err = ListView::<u64, PodU32>::unpack_mut(&mut buf).unwrap_err();
        assert_eq!(err, PodSliceError::BufferTooSmall.into());
    }

    /// A stored length larger than the capacity is rejected.
    #[test]
    fn test_unpack_fail_declared_length_exceeds_capacity() {
        let declared_length: u32 = 4;
        // One element fewer than the declared length.
        let capacity: usize = 3;
        let item_size = size_of::<u32>();
        let len_size = size_of::<PodU32>();
        let buf_size = len_size + capacity * item_size;
        let mut buf = vec![0u8; buf_size];

        let pod_len: PodU32 = declared_length.into();
        buf[0..len_size].copy_from_slice(bytemuck::bytes_of(&pod_len));

        let err = ListView::<u32, PodU32>::unpack(&buf).unwrap_err();
        assert_eq!(err, PodSliceError::BufferTooSmall.into());

        let err = ListView::<u32, PodU32>::unpack_mut(&mut buf).unwrap_err();
        assert_eq!(err, PodSliceError::BufferTooSmall.into());
    }

    /// A data region that is not a whole number of elements is rejected.
    #[test]
    fn test_unpack_fail_data_part_not_multiple_of_item_size() {
        let len_size = size_of::<PodU32>();
        let buf_size = len_size + 5;
        let mut buf = vec![0u8; buf_size];

        let err = ListView::<u32, PodU32>::unpack(&buf).unwrap_err();
        assert_eq!(err, ProgramError::InvalidArgument);

        let err = ListView::<u32, PodU32>::unpack_mut(&mut buf).unwrap_err();
        assert_eq!(err, ProgramError::InvalidArgument);
    }

    /// An empty buffer cannot even hold the length prefix.
    #[test]
    fn test_unpack_empty_buffer() {
        let mut buf = [];

        let err = ListView::<u32, PodU32>::unpack(&buf).unwrap_err();
        assert_eq!(err, PodSliceError::BufferTooSmall.into());

        let err = ListView::<u32, PodU32>::unpack_mut(&mut buf).unwrap_err();
        assert_eq!(err, PodSliceError::BufferTooSmall.into());
    }

    /// `init` zeroes the length prefix when no padding is needed.
    #[test]
    fn test_init_success_no_padding() {
        let capacity: usize = 5;
        let len_size = size_of::<PodU32>();
        let buf_size = ListView::<u32, PodU32>::size_of(capacity).unwrap();
        let mut buf = vec![0xFFu8; buf_size];

        let view = ListView::<u32, PodU32>::init(&mut buf).unwrap();
        assert_eq!(view.len(), 0);
        assert_eq!(view.capacity(), capacity);
        assert!(view.is_empty());

        let length_bytes = &buf[0..len_size];
        assert_eq!(length_bytes, &[0u8; 4]);
    }

    /// `init` zeroes the length prefix when header padding is needed.
    #[test]
    fn test_init_success_with_padding() {
        let capacity: usize = 3;
        let len_size = size_of::<PodU32>();
        let buf_size = ListView::<u64, PodU32>::size_of(capacity).unwrap();
        let mut buf = vec![0xFFu8; buf_size];

        let view = ListView::<u64, PodU32>::init(&mut buf).unwrap();
        assert_eq!(view.len(), 0);
        assert_eq!(view.capacity(), capacity);
        assert!(view.is_empty());

        let length_bytes = &buf[0..len_size];
        assert_eq!(length_bytes, &[0u8; 4]);
    }

    /// `init` works on a header-only buffer (zero capacity).
    #[test]
    fn test_init_success_zero_capacity() {
        let buf_size = ListView::<u64, PodU32>::size_of(0).unwrap();
        assert_eq!(buf_size, 8);
        let mut buf = vec![0xFFu8; buf_size];

        let view = ListView::<u64, PodU32>::init(&mut buf).unwrap();
        assert_eq!(view.len(), 0);
        assert_eq!(view.capacity(), 0);
        assert!(view.is_empty());

        let len_size = size_of::<PodU32>();
        let length_bytes = &buf[0..len_size];
        assert_eq!(length_bytes, &[0u8; 4]);
    }

    /// `init` only writes the length prefix; padding and data bytes keep their contents.
    #[test]
    fn test_init_leaves_padding_and_data_untouched() {
        let capacity: usize = 2;
        let len_size = size_of::<PodU32>();
        let buf_size = ListView::<u64, PodU32>::size_of(capacity).unwrap();
        let mut buf = vec![0xFFu8; buf_size];

        let view = ListView::<u64, PodU32>::init(&mut buf).unwrap();
        assert_eq!(view.capacity(), capacity);

        assert_eq!(&buf[0..len_size], &[0u8; 4]);
        assert!(buf[len_size..].iter().all(|&byte| byte == 0xFF));
    }

    /// `init` rejects buffers shorter than the header, with and without padding.
    #[test]
    fn test_init_fail_buffer_too_small() {
        let mut buf = vec![0u8; 3];
        let err = ListView::<u32, PodU32>::init(&mut buf).unwrap_err();
        assert_eq!(err, PodSliceError::BufferTooSmall.into());

        let mut buf_padded = vec![0u8; 7];
        let err_padded = ListView::<u64, PodU32>::init(&mut buf_padded).unwrap_err();
        assert_eq!(err_padded, PodSliceError::BufferTooSmall.into());
    }

    /// `init` with the default length type (`PodU32`).
    #[test]
    fn test_init_success_default_length_type() {
        let capacity = 5;
        let len_size = size_of::<PodU32>();
        let buf_size = ListView::<u32>::size_of(capacity).unwrap();
        let mut buf = vec![0xFFu8; buf_size];

        let view = ListView::<u32>::init(&mut buf).unwrap();
        assert_eq!(view.len(), 0);
        assert_eq!(view.capacity(), capacity);
        assert!(view.is_empty());

        let length_bytes = &buf[0..len_size];
        assert_eq!(length_bytes, &[0u8; 4]);
    }

    // Generates an unpack / unpack_mut / init round-trip test for a given length type.
    macro_rules! test_list_view_for_length_type {
        ($test_name:ident, $LengthType:ty) => {
            #[test]
            fn $test_name() {
                type T = u64;

                let padding = ListView::<T, $LengthType>::header_padding().unwrap();
                let length_usize = 2usize;
                let capacity = 3;

                let item_size = size_of::<T>();
                let len_size = size_of::<$LengthType>();
                let buf_size = len_size + padding + capacity * item_size;
                let mut buf = vec![0u8; buf_size];

                let pod_len = <$LengthType>::try_from(length_usize).unwrap();
                buf[0..len_size].copy_from_slice(bytemuck::bytes_of(&pod_len));

                let data_start = len_size + padding;
                let items = [1000 as T, 2000 as T];
                let items_bytes = bytemuck::cast_slice(&items);
                buf[data_start..(data_start + items_bytes.len())].copy_from_slice(items_bytes);

                let view_ro = ListView::<T, $LengthType>::unpack(&buf).unwrap();
                assert_eq!(view_ro.len(), length_usize);
                assert_eq!(view_ro.capacity(), capacity);
                assert_eq!(*view_ro, items[..]);

                let mut buf_mut = buf.clone();
                let view_mut = ListView::<T, $LengthType>::unpack_mut(&mut buf_mut).unwrap();
                assert_eq!(view_mut.len(), length_usize);
                assert_eq!(view_mut.capacity(), capacity);
                assert_eq!(*view_mut, items[..]);

                let mut init_buf = vec![0xFFu8; buf_size];
                let init_view = ListView::<T, $LengthType>::init(&mut init_buf).unwrap();
                assert_eq!(init_view.len(), 0);
                assert_eq!(init_view.capacity(), capacity);
                assert_eq!(<$LengthType>::try_from(0usize).unwrap(), *init_view.length);
            }
        };
    }

    test_list_view_for_length_type!(list_view_with_pod_u16, PodU16);
    test_list_view_for_length_type!(list_view_with_pod_u32, PodU32);
    test_list_view_for_length_type!(list_view_with_pod_u64, PodU64);
    // The PodU128 variant is compiled out on `bpf` targets.
    #[cfg(not(target_arch = "bpf"))]
    test_list_view_for_length_type!(list_view_with_pod_u128, PodU128);
}