use crate::{
align1::Align1,
bail,
data_types::PackedValue,
ensure, error,
errors::ErrorInfo as _,
unsize::{
init::{DefaultInit, UnsizedInit},
wrapper::ExclusiveRecurse,
FromOwned, RawSliceAdvance, UnsizedType, UnsizedTypePtr,
},
util::uninit_array_bytes,
ErrorCode, Result,
};
use advancer::Advance;
use bytemuck::{
bytes_of, cast_slice, cast_slice_mut, checked, CheckedBitPattern, NoUninit, Pod, Zeroable,
};
use itertools::Itertools;
use num_traits::{FromPrimitive, ToPrimitive, Zero};
use ptr_meta::Pointee;
use star_frame_proc::unsized_impl;
use std::{
any::type_name,
borrow::Borrow,
cmp::Ordering,
iter,
iter::FusedIterator,
marker::PhantomData,
mem::size_of,
ops::{Deref, DerefMut, Index, IndexMut, RangeBounds},
};
/// Numeric type usable as the length prefix of a [`List`].
///
/// Blanket-implemented for every `Pod` type that converts to and from primitive numbers.
pub trait ListLength: Pod + ToPrimitive + FromPrimitive {}
impl<N: Pod + ToPrimitive + FromPrimitive> ListLength for N {}
/// A length-prefixed, variable-length list of `T` stored directly in a byte buffer.
///
/// Layout: the length prefix (an `L`, stored as `PackedValue<L>`) is immediately followed by
/// `len * size_of::<T>()` bytes of element data.
#[derive(Debug, Pointee)]
#[repr(C)]
pub struct List<T, L = u32>
where
    L: ListLength,
    T: CheckedBitPattern + NoUninit + Align1,
{
    /// Number of elements in the list; these are the list's leading bytes.
    len: PackedValue<L>,
    /// Marks the element type without storing a `T`.
    phantom_t: PhantomData<fn() -> T>,
    /// Raw element bytes, `size_of::<T>()` bytes per element.
    bytes: [u8],
}
#[cfg(all(feature = "idl", not(target_os = "solana")))]
mod idl_impl {
    use super::*;
    use crate::{idl::TypeToIdl, prelude::System};
    use star_frame_idl::{ty::IdlTypeDef, IdlDefinition};

    /// Describes a [`List`] in the IDL as its item type plus its length-prefix type.
    impl<T, L> TypeToIdl for List<T, L>
    where
        T: CheckedBitPattern + NoUninit + Align1 + TypeToIdl,
        L: ListLength + TypeToIdl,
    {
        type AssociatedProgram = System;

        fn type_to_idl(idl_definition: &mut IdlDefinition) -> crate::IdlResult<IdlTypeDef> {
            // Both calls take `idl_definition` mutably, so resolve the item type before the
            // length type.
            let item_ty = Box::new(T::type_to_idl(idl_definition)?);
            let len_ty = Box::new(L::type_to_idl(idl_definition)?);
            Ok(IdlTypeDef::List { item_ty, len_ty })
        }
    }
}
impl<T, L> List<T, L>
where
    T: CheckedBitPattern + NoUninit + Align1,
    L: ListLength,
{
    /// Returns the number of elements, as recorded in the list's length prefix.
    ///
    /// # Panics
    /// Panics if the stored length cannot be represented as a `usize`.
    #[inline]
    pub fn len(&self) -> usize {
        let len = self
            .len
            .to_usize()
            .expect("List size should convert to usize");
        // Multiply rather than divide. Dividing by `size_of::<T>()` panics for zero-sized `T`.
        // Multiplying also catches a byte span that is not an exact multiple of the element size.
        debug_assert_eq!(
            len.checked_mul(size_of::<T>()),
            Some(self.bytes.len()),
            "List length prefix disagrees with its byte length"
        );
        len
    }

    /// Returns `true` if the list has no elements.
    #[inline]
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Returns a reference to the element at `index`, or `None` if `index >= len()`.
    ///
    /// # Panics
    /// Panics if the element's bytes are not a valid bit pattern for `T`.
    #[inline]
    pub fn get(&self, index: usize) -> Option<&T> {
        if index < self.len() {
            Some(&self[index])
        } else {
            None
        }
    }

    /// Returns a mutable reference to the element at `index`, or `None` if `index >= len()`.
    ///
    /// # Panics
    /// Panics if the element's bytes are not a valid bit pattern for `T`.
    #[inline]
    pub fn get_mut(&mut self, index: usize) -> Option<&mut T> {
        if index < self.len() {
            Some(&mut self[index])
        } else {
            None
        }
    }

    /// Views the elements as a slice. Available when every bit pattern is a valid `T`.
    #[inline]
    pub fn as_slice(&self) -> &[T]
    where
        T: Pod,
    {
        cast_slice(&self.bytes)
    }

    /// Views the elements as a mutable slice. Available when every bit pattern is a valid `T`.
    #[inline]
    pub fn as_mut_slice(&mut self) -> &mut [T]
    where
        T: Pod,
    {
        cast_slice_mut(&mut self.bytes)
    }

    /// Views the elements as a slice after validating each element's bit pattern.
    ///
    /// # Errors
    /// Returns an error if any element's bytes are not a valid `T`.
    #[inline]
    pub fn as_checked_slice(&self) -> Result<&[T]> {
        checked::try_cast_slice(&self.bytes).map_err(Into::into)
    }

    /// Mutable counterpart of [`List::as_checked_slice`].
    ///
    /// # Errors
    /// Returns an error if any element's bytes are not a valid `T`.
    #[inline]
    pub fn as_checked_mut_slice(&mut self) -> Result<&mut [T]> {
        checked::try_cast_slice_mut(&mut self.bytes).map_err(Into::into)
    }

    /// Iterates over shared references to the elements, front to back.
    pub fn iter(&self) -> ListIter<'_, T, L> {
        ListIter {
            list: self,
            index: 0,
        }
    }

    /// Iterates over mutable references to the elements, front to back.
    pub fn iter_mut(&mut self) -> ListIterMut<'_, T, L> {
        ListIterMut {
            list_bytes_ptr: &raw mut self.bytes,
            remaining: self.len(),
            phantom_data: Default::default(),
        }
    }

    /// Binary-searches a list that is sorted by `Ord` for `x`.
    ///
    /// Returns `Ok(index)` of a matching element, or `Err(index)` where `x` could be inserted
    /// while keeping the list sorted.
    pub fn binary_search(&self, x: &T) -> std::result::Result<usize, usize>
    where
        T: Ord,
    {
        Self::binary_search_by(self, |p| p.cmp(x))
    }

    /// Binary-searches a list using the comparator `f`, which reports how an element orders
    /// relative to the target.
    ///
    /// Return value semantics match [`List::binary_search`].
    pub fn binary_search_by<F>(&self, mut f: F) -> std::result::Result<usize, usize>
    where
        F: FnMut(&T) -> Ordering,
    {
        let mut left = 0;
        let mut right = self.len();
        while left < right {
            // `left + (right - left) / 2` cannot overflow, unlike `(left + right) / 2`.
            let mid = left + (right - left) / 2;
            match f(&self[mid]) {
                Ordering::Less => left = mid + 1,
                Ordering::Equal => return Ok(mid),
                Ordering::Greater => right = mid,
            }
        }
        Err(left)
    }
}
impl<T, L> Deref for List<T, L>
where
L: ListLength,
T: Pod + Align1,
{
type Target = [T];
fn deref(&self) -> &Self::Target {
cast_slice(&self.bytes)
}
}
impl<T, L> DerefMut for List<T, L>
where
L: ListLength,
T: Pod + Align1,
{
fn deref_mut(&mut self) -> &mut Self::Target {
cast_slice_mut(&mut self.bytes)
}
}
// SAFETY: NOTE(review) `List` is `#[repr(C)]`. Its `phantom_t` (`PhantomData`) and `bytes`
// (`[u8]`) fields both have alignment 1. This impl additionally assumes `PackedValue<L>` has
// alignment 1 (as its name suggests) — confirm against its definition.
unsafe impl<T, L> Align1 for List<T, L>
where
    T: Align1 + CheckedBitPattern + NoUninit,
    L: ListLength,
{
}
/// Element access by position, validating the element's bit pattern.
impl<T, L> Index<usize> for List<T, L>
where
    T: CheckedBitPattern + NoUninit + Align1,
    L: ListLength,
{
    type Output = T;

    /// # Panics
    /// Panics if `index` is out of bounds or the element bytes are not a valid `T`.
    fn index(&self, index: usize) -> &Self::Output {
        let width = size_of::<T>();
        let element = &self.bytes[index * width..][..width];
        checked::try_from_bytes(element).expect("Invalid data for index")
    }
}
/// Mutable element access by position, validating the element's bit pattern.
impl<T, L> IndexMut<usize> for List<T, L>
where
    T: CheckedBitPattern + NoUninit + Align1,
    L: ListLength,
{
    /// # Panics
    /// Panics if `index` is out of bounds or the element bytes are not a valid `T`.
    fn index_mut(&mut self, index: usize) -> &mut Self::Output {
        let width = size_of::<T>();
        let element = &mut self.bytes[index * width..][..width];
        checked::try_from_bytes_mut(element).expect("Invalid data for index")
    }
}
/// Converts an element-index range into a `(start, end)` byte range within `list.bytes`.
///
/// An unbounded end resolves to the list's stored length.
///
/// # Panics
/// Panics if the stored length cannot be converted to `usize`.
fn get_bounds<T, L>(list: &List<T, L>, range: impl RangeBounds<usize>) -> (usize, usize)
where
    T: CheckedBitPattern + NoUninit + Align1,
    L: ListLength,
{
    use std::ops::Bound;

    let elem_size = size_of::<T>();
    let first_elem = match range.start_bound() {
        Bound::Unbounded => 0,
        Bound::Included(&i) => i,
        Bound::Excluded(&i) => i + 1,
    };
    let end_elem = match range.end_bound() {
        Bound::Unbounded => list.len.to_usize().expect("Invalid length"),
        Bound::Included(&i) => i + 1,
        Bound::Excluded(&i) => i,
    };
    (first_elem * elem_size, end_elem * elem_size)
}
/// Marker for the standard `usize` range types accepted by `List`'s range indexing.
///
/// Restricting to these types keeps the range `Index` impls from overlapping with
/// `Index<usize>`.
trait GetBounds: RangeBounds<usize> {}

macro_rules! impl_get_bounds {
    ($($range:ty),* $(,)?) => {
        $(impl GetBounds for $range {})*
    };
}

impl_get_bounds!(
    std::ops::RangeFull,
    std::ops::Range<usize>,
    std::ops::RangeFrom<usize>,
    std::ops::RangeTo<usize>,
    std::ops::RangeInclusive<usize>,
    std::ops::RangeToInclusive<usize>,
);
/// Sub-slice access by element range, validating every element's bit pattern.
impl<T, L, R> Index<R> for List<T, L>
where
    T: CheckedBitPattern + NoUninit + Align1,
    L: ListLength,
    R: GetBounds,
{
    type Output = [T];

    /// # Panics
    /// Panics if the range is out of bounds or any element in it is not a valid `T`.
    fn index(&self, index: R) -> &Self::Output {
        let (from, to) = get_bounds(self, index);
        let window = &self.bytes[from..to];
        checked::try_cast_slice(window).expect("Invalid data for range")
    }
}
/// Mutable sub-slice access by element range, validating every element's bit pattern.
impl<T, L, R> IndexMut<R> for List<T, L>
where
    T: CheckedBitPattern + NoUninit + Align1,
    L: ListLength,
    R: GetBounds,
{
    /// # Panics
    /// Panics if the range is out of bounds or any element in it is not a valid `T`.
    fn index_mut(&mut self, index: R) -> &mut Self::Output {
        let (from, to) = get_bounds(self, index);
        let window = &mut self.bytes[from..to];
        checked::try_cast_slice_mut(window).expect("Invalid data for range")
    }
}
/// Raw pointer to a [`List`] inside a larger byte buffer; this is `List`'s `UnsizedType::Ptr`.
///
/// NOTE(review): the `Deref`/`DerefMut` impls for this type dereference the pointer without
/// any checks, so it presumably must stay valid (and unaliased while mutably borrowed) for as
/// long as it is used — confirm against the `UnsizedType` contract.
#[derive(Debug)]
pub struct ListPtr<T, L = u32>(*mut List<T, L>)
where
    L: ListLength,
    T: CheckedBitPattern + NoUninit + Align1;
/// Gives shared access to the pointed-to [`List`].
impl<T, L> Deref for ListPtr<T, L>
where
    L: ListLength,
    T: CheckedBitPattern + NoUninit + Align1,
{
    type Target = List<T, L>;
    fn deref(&self) -> &Self::Target {
        // SAFETY: NOTE(review) this relies on the pointer being valid and not mutably aliased
        // for the duration of `&self`; nothing here checks that — confirm upheld by creators.
        unsafe { &*self.0 }
    }
}
/// Gives exclusive access to the pointed-to [`List`].
impl<T, L> DerefMut for ListPtr<T, L>
where
    L: ListLength,
    T: CheckedBitPattern + NoUninit + Align1,
{
    fn deref_mut(&mut self) -> &mut Self::Target {
        // SAFETY: NOTE(review) this relies on the pointer being valid and not otherwise
        // aliased for the duration of `&mut self`; nothing here checks that — confirm.
        unsafe { &mut *self.0 }
    }
}
unsafe impl<T, L> UnsizedTypePtr for ListPtr<T, L>
where
    L: ListLength,
    T: CheckedBitPattern + NoUninit + Align1,
{
    type UnsizedType = List<T, L>;

    /// Checks this list's start address against a running `cursor`.
    ///
    /// Moves `cursor` to this list's address. Returns `true` only when that address is not
    /// below the previous cursor value and lies inside `range`.
    fn check_pointers(&self, range: &std::ops::Range<usize>, cursor: &mut usize) -> bool {
        let here = self.0.addr();
        let previous = std::mem::replace(cursor, here);
        previous <= here && range.contains(&here)
    }
}
unsafe impl<T, L> UnsizedType for List<T, L>
where
    L: ListLength,
    T: Align1 + CheckedBitPattern + NoUninit,
{
    type Ptr = ListPtr<T, L>;
    type Owned = Vec<T>;
    // NOTE(review): `true` whenever the length prefix type `L` is not zero-sized; the exact
    // meaning of this flag is defined by `UnsizedType` — confirm there.
    const ZST_STATUS: bool = { size_of::<L>() != 0 };

    /// Reads the length prefix at the front of `data` and advances `data` past the prefix and
    /// the elements it describes. Returns a pointer to the resulting `List`.
    ///
    /// # Errors
    /// - `data` is too short for the length prefix or for the described elements;
    /// - the stored length does not fit in `usize`;
    /// - the elements' total byte size overflows `usize`.
    ///
    /// # Safety
    /// NOTE(review): same contract as `UnsizedType::get_ptr`. Presumably `data` must point to
    /// valid, exclusively accessible bytes — confirm against the trait documentation.
    unsafe fn get_ptr(data: &mut *mut [u8]) -> Result<Self::Ptr> {
        let len_ptr = data.try_advance(size_of::<L>()).with_ctx(|| {
            format!(
                "Failed to read length bytes of size {} for {}",
                size_of::<L>(),
                type_name::<Self>()
            )
        })?;
        // SAFETY: NOTE(review) `try_advance` succeeded, so `len_ptr` presumably covers
        // `size_of::<L>()` readable bytes of `data`; relies on this fn's caller contract.
        let len_l: L = bytemuck::try_pod_read_unaligned(unsafe { &*len_ptr })?;
        let length = len_l.to_usize().ok_or_else(|| {
            error!(
                ErrorCode::ToPrimitiveError,
                "Could not convert list size to usize"
            )
        })?;
        // The length is read straight from the buffer and is not otherwise validated. An
        // unchecked multiply could wrap in release builds and yield a list whose length prefix
        // disagrees with its byte span.
        let byte_len = size_of::<T>().checked_mul(length).ok_or_else(|| {
            error!(
                ErrorCode::ToPrimitiveError,
                "List byte size overflowed usize for {}",
                type_name::<Self>()
            )
        })?;
        data.try_advance(byte_len).with_ctx(|| {
            format!(
                "Failed to read mutable list elements of total size {} for {}",
                byte_len,
                type_name::<Self>()
            )
        })?;
        // The fat pointer starts at the length prefix. Its metadata is the byte length of
        // the trailing `bytes: [u8]` field.
        Ok(ListPtr(ptr_meta::from_raw_parts_mut(
            len_ptr.cast::<()>(),
            byte_len,
        )))
    }

    /// Total serialized size: the element bytes plus the length prefix.
    #[inline]
    fn data_len(m: &Self::Ptr) -> usize {
        m.bytes.len() + size_of::<L>()
    }

    /// Address where the list (its length prefix) begins.
    #[inline]
    fn start_ptr(m: &Self::Ptr) -> *mut () {
        m.0.cast::<()>()
    }

    /// Copies the elements into a `Vec`, validating each element's bit pattern.
    ///
    /// # Errors
    /// Returns an error if any element's bytes are not a valid `T`.
    fn owned_from_ptr(r: &Self::Ptr) -> Result<Self::Owned> {
        Ok(checked::try_cast_slice(&r.bytes)?.to_vec())
    }

    /// Adjusts this list's pointer after a resize elsewhere in the buffer. If the resize
    /// originated before this list's start, the pointer is shifted by `change` bytes.
    unsafe fn resize_notification(
        self_mut: &mut Self::Ptr,
        source_ptr: *const (),
        change: isize,
    ) -> Result<()> {
        let self_ptr = self_mut.0;
        if source_ptr < self_ptr.cast_const().cast() {
            self_mut.0 = self_ptr.wrapping_byte_offset(change);
        }
        Ok(())
    }
}
impl<T, L> List<T, L>
where
L: ListLength,
T: Align1 + CheckedBitPattern + NoUninit,
{
pub(super) fn byte_size_from_len(len: usize) -> usize {
size_of::<L>() + size_of::<T>() * len
}
pub(super) fn from_owned_from_iter<I>(items: I, bytes: &mut &mut [u8]) -> Result<usize>
where
I: IntoIterator<Item = T>,
I::IntoIter: ExactSizeIterator,
{
let items = items.into_iter();
let len = items.len();
bytes
.try_advance(size_of::<L>())?
.copy_from_slice(bytes_of(&L::from_usize(len).unwrap()));
for item in items {
bytes
.try_advance(size_of::<T>())?
.copy_from_slice(bytes_of(&item));
}
Ok(Self::byte_size_from_len(len))
}
}
/// Builds a `List` in place from an owned `Vec<T>`.
impl<T, L> FromOwned for List<T, L>
where
    L: ListLength,
    T: Align1 + CheckedBitPattern + NoUninit,
{
    /// Serialized size of `owned`: the length prefix plus one slot per element.
    fn byte_size(owned: &Self::Owned) -> usize {
        let element_count = owned.len();
        Self::byte_size_from_len(element_count)
    }

    /// Writes `owned` into the front of `bytes` via `from_owned_from_iter`.
    fn from_owned(owned: Self::Owned, bytes: &mut &mut [u8]) -> Result<usize> {
        let items = owned.into_iter();
        Self::from_owned_from_iter(items, bytes)
    }
}
// Resizing operations on `List`.
//
// `#[unsized_impl]` rewrites these methods. Inside them, `self` dereferences to a `List`
// (`self.len()`, `self.bytes`) and also exposes the raw fat pointer as `self.0`. It is also
// passed to `ExclusiveRecurse::{add_bytes, remove_bytes}`. NOTE(review): the exact generated
// receiver type is defined by the macro — confirm there.
#[unsized_impl]
impl<T, L> List<T, L>
where
    T: Align1 + NoUninit + CheckedBitPattern,
    L: ListLength,
{
    // Appends `item` to the end of the list. Errors as `insert_all` does.
    #[inline]
    pub fn push(&mut self, item: T) -> Result<()> {
        let len = self.len();
        self.insert(len, item)
    }
    // Appends every item of `items`, in order, to the end of the list.
    #[inline]
    pub fn push_all<I>(&mut self, items: I) -> Result<()>
    where
        I: IntoIterator<Item = T>,
        I::IntoIter: ExactSizeIterator,
    {
        self.insert_all(self.len(), items)
    }
    // Inserts `item` at `index`, shifting later elements back by one.
    #[inline]
    pub fn insert(&mut self, index: usize, item: T) -> Result<()> {
        self.insert_all(index, iter::once(item))
    }
    // Inserts all `items` starting at `index`, shifting later elements back.
    //
    // Errors:
    // - `ErrorCode::IndexOutOfBounds` if `index > len`;
    // - `ErrorCode::ToPrimitiveError` if the new length does not fit in `L`;
    // - any error from `ExclusiveRecurse::add_bytes`.
    //
    // Panics (via `zip_eq`) if the iterator yields a different number of items than its
    // `ExactSizeIterator::len` reported.
    pub fn insert_all<I>(&mut self, index: usize, items: I) -> Result<()>
    where
        I: IntoIterator,
        I::IntoIter: ExactSizeIterator,
        I::Item: Borrow<T>,
    {
        let iter = items.into_iter();
        let to_add = iter.len();
        // Byte offset of the insertion point within the element bytes.
        let byte_index = index * size_of::<T>();
        let (end_ptr, old_len, new_len, source_ptr) = {
            let list: &mut List<T, L> = self;
            let old_len = list.len();
            if index > old_len {
                bail!(
                    ErrorCode::IndexOutOfBounds,
                    "Index {index} is out of bounds for list of length {old_len}"
                );
            }
            let new_len = L::from_usize(old_len + to_add).ok_or_else(|| {
                error!(
                    ErrorCode::ToPrimitiveError,
                    "Failed to convert new len to L"
                )
            })?;
            // Address of the insertion point; the new bytes are requested there.
            let end_ptr = list.bytes.as_mut_ptr().wrapping_add(byte_index).cast();
            (end_ptr, old_len, new_len, self.0.cast_const().cast::<()>())
        };
        // NOTE(review): presumably `add_bytes` opens a gap of `size_of::<T>() * to_add` bytes at
        // `end_ptr` in the underlying buffer (notifying other pointers) — confirm.
        unsafe {
            ExclusiveRecurse::add_bytes(self, source_ptr, end_ptr, size_of::<T>() * to_add)?;
        };
        // Update the length prefix, then rebuild the fat pointer so its slice metadata covers
        // the enlarged element bytes.
        self.len = PackedValue(new_len);
        self.0 =
            ptr_meta::from_raw_parts_mut(self.0.cast::<()>(), (old_len + to_add) * size_of::<T>());
        // Fill the new gap with the items, in order.
        for ((i, value), _) in iter.enumerate().zip_eq(0..to_add) {
            let bytes = &mut self.bytes;
            bytes[byte_index + i * size_of::<T>()..][..size_of::<T>()]
                .copy_from_slice(bytes_of(value.borrow()));
        }
        Ok(())
    }
    // Removes the last element. Returns `Ok(None)` if the list was already empty and
    // `Ok(Some(()))` otherwise.
    #[inline]
    pub fn pop(&mut self) -> Result<Option<()>> {
        if self.is_empty() {
            return Ok(None);
        }
        self.remove(self.len() - 1).map(Some)
    }
    // Removes the element at `index`. Errors as `remove_range` does.
    #[inline]
    pub fn remove(&mut self, index: usize) -> Result<()> {
        self.remove_range(index..=index)
    }
    // Removes the elements whose indices fall in `indices`, shifting later elements forward.
    //
    // Errors:
    // - `ErrorCode::InvalidRange` if the range start is after its end;
    // - `ErrorCode::IndexOutOfBounds` if the range end is past `len`;
    // - any error from `ExclusiveRecurse::remove_bytes`;
    // - `ErrorCode::ToPrimitiveError` if the new length does not fit in `L`.
    pub fn remove_range(&mut self, indices: impl RangeBounds<usize>) -> Result<()> {
        // Normalize the bounds to a half-open element range `start..end`.
        let start = match indices.start_bound() {
            std::ops::Bound::Included(start) => *start,
            std::ops::Bound::Excluded(start) => start + 1,
            std::ops::Bound::Unbounded => 0,
        };
        let end = match indices.end_bound() {
            std::ops::Bound::Included(end) => *end + 1,
            std::ops::Bound::Excluded(end) => *end,
            std::ops::Bound::Unbounded => self.len(),
        };
        ensure!(start <= end, ErrorCode::InvalidRange);
        ensure!(
            end <= self.len(),
            ErrorCode::IndexOutOfBounds,
            "End index {end} for List of length {} out of bounds",
            self.len()
        );
        let to_remove = end - start;
        let old_len = self.len();
        let new_len = old_len - to_remove;
        let source_ptr: *const () = self.0.cast_const().cast();
        // Byte addresses delimiting the removed elements.
        let bytes_ptr = self.bytes.as_ptr();
        let start_ptr = bytes_ptr.wrapping_add(start * size_of::<T>()).cast::<()>();
        let end_ptr = bytes_ptr.wrapping_add(end * size_of::<T>()).cast::<()>();
        // NOTE(review): presumably `remove_bytes` deletes `start_ptr..end_ptr` from the
        // underlying buffer (notifying other pointers) — confirm.
        unsafe {
            ExclusiveRecurse::remove_bytes(self, source_ptr, start_ptr..end_ptr)?;
        };
        // Update the length prefix and shrink the fat pointer's slice metadata to match.
        {
            self.len = PackedValue(L::from_usize(new_len).ok_or_else(|| {
                error!(
                    ErrorCode::ToPrimitiveError,
                    "Failed to convert new list len to L"
                )
            })?);
            self.0 = ptr_meta::from_raw_parts_mut(self.0.cast::<()>(), new_len * size_of::<T>());
        }
        Ok(())
    }
    // Removes every element.
    #[inline]
    pub fn clear(&mut self) -> Result<()> {
        self.remove_range(..)
    }
}
impl<T, L> UnsizedInit<DefaultInit> for List<T, L>
where
L: ListLength,
T: CheckedBitPattern + NoUninit + Align1,
{
const INIT_BYTES: usize = size_of::<L>();
fn init(bytes: &mut &mut [u8], _arg: DefaultInit) -> Result<()> {
bytes
.try_advance(<Self as UnsizedInit<DefaultInit>>::INIT_BYTES)
.with_ctx(|| {
format!(
"Failed to advance {} bytes during default initialization of {}",
<Self as UnsizedInit<DefaultInit>>::INIT_BYTES,
std::any::type_name::<Self>()
)
})?
.copy_from_slice(bytes_of(&<PackedValue<L>>::zeroed()));
Ok(())
}
}
/// Initializes a list holding a copy of a borrowed array.
impl<const N: usize, T, L> UnsizedInit<&[T; N]> for List<T, L>
where
    L: ListLength + Zero,
    T: CheckedBitPattern + NoUninit + Align1,
{
    const INIT_BYTES: usize = size_of::<L>() + size_of::<T>() * N;

    /// Writes the length prefix `N` followed by the bytes of `array` at the front of `bytes`,
    /// advancing `bytes` past them.
    ///
    /// # Errors
    /// Returns an error if `N` does not fit in `L`, or if `bytes` is too short.
    fn init(bytes: &mut &mut [u8], array: &[T; N]) -> Result<()> {
        let total = <Self as UnsizedInit<&[T; N]>>::INIT_BYTES;
        let len_value = L::from_usize(N).ok_or_else(|| {
            error!(
                ErrorCode::ToPrimitiveError,
                "Init array length larger than max size of List length {}",
                type_name::<L>()
            )
        })?;
        let dest = bytes.try_advance(total).with_ctx(|| {
            format!(
                "Failed to advance {} bytes during array initialization of {}",
                total,
                std::any::type_name::<Self>()
            )
        })?;
        let (len_dest, items_dest) = dest.split_at_mut(size_of::<L>());
        len_dest.copy_from_slice(bytes_of(&len_value));
        items_dest.copy_from_slice(uninit_array_bytes(array));
        Ok(())
    }
}
/// Initializes a list from an owned array by delegating to the borrowed-array initializer.
impl<const N: usize, T, L> UnsizedInit<[T; N]> for List<T, L>
where
    L: ListLength + Zero,
    T: CheckedBitPattern + NoUninit + Align1,
{
    const INIT_BYTES: usize = <Self as UnsizedInit<&[T; N]>>::INIT_BYTES;

    #[inline]
    fn init(bytes: &mut &mut [u8], array: [T; N]) -> Result<()> {
        let by_ref: &[T; N] = &array;
        <Self as UnsizedInit<&[T; N]>>::init(bytes, by_ref)
    }
}
/// Iterator over shared references to a [`List`]'s elements, created by [`List::iter`].
#[derive(Debug, Clone)]
pub struct ListIter<'a, T, L>
where
    T: CheckedBitPattern + NoUninit + Align1,
    L: ListLength,
{
    /// The list being iterated.
    list: &'a List<T, L>,
    /// Index of the next element to yield.
    index: usize,
}
/// Iterator over mutable references to a [`List`]'s elements, created by [`List::iter_mut`].
#[derive(Debug)]
pub struct ListIterMut<'a, T, L>
where
    T: CheckedBitPattern + NoUninit + Align1,
    L: ListLength,
{
    /// Raw pointer to the element bytes not yet yielded.
    list_bytes_ptr: *mut [u8],
    /// Number of elements left to yield.
    remaining: usize,
    /// Ties the yielded `&'a mut T` references to the mutable borrow of the list.
    phantom_data: PhantomData<&'a mut (T, L)>,
}
impl<'a, T, L> Iterator for ListIter<'a, T, L>
where
    T: CheckedBitPattern + NoUninit + Align1,
    L: ListLength,
{
    type Item = &'a T;

    /// Yields the element at the current position and steps forward.
    ///
    /// # Panics
    /// Panics if the element's bytes are not a valid `T`.
    fn next(&mut self) -> Option<Self::Item> {
        let list = self.list;
        let position = self.index;
        if position < list.len() {
            let item = &list[position];
            self.index = position + 1;
            Some(item)
        } else {
            None
        }
    }

    /// Exact: the number of elements not yet yielded.
    fn size_hint(&self) -> (usize, Option<usize>) {
        let left = self.list.len() - self.index;
        (left, Some(left))
    }
}
impl<T, L> ExactSizeIterator for ListIter<'_, T, L>
where
    T: CheckedBitPattern + NoUninit + Align1,
    L: ListLength,
{
    /// Number of elements not yet yielded.
    fn len(&self) -> usize {
        let total = self.list.len();
        total - self.index
    }
}
// `next` returns `None` once `index` reaches the list's length. `index` only ever increases
// and the list is only borrowed shared, so it keeps returning `None` afterwards.
impl<T, L> FusedIterator for ListIter<'_, T, L>
where
    T: CheckedBitPattern + NoUninit + Align1,
    L: ListLength,
{
}
impl<'a, T, L> IntoIterator for &'a List<T, L>
where
T: CheckedBitPattern + NoUninit + Align1,
L: ListLength,
{
type Item = &'a T;
type IntoIter = ListIter<'a, T, L>;
fn into_iter(self) -> Self::IntoIter {
self.iter()
}
}
impl<'a, T, L> Iterator for ListIterMut<'a, T, L>
where
    T: CheckedBitPattern + NoUninit + Align1,
    L: ListLength,
{
    type Item = &'a mut T;
    /// Splits the next element off the front of the remaining bytes and yields it.
    ///
    /// # Panics
    /// `bytemuck::checked::from_bytes_mut` panics if the element's bytes are not a valid `T`.
    /// NOTE(review): `advance` presumably panics if fewer than `size_of::<T>()` bytes
    /// remain — confirm.
    fn next(&mut self) -> Option<Self::Item> {
        if self.remaining == 0 {
            return None;
        }
        // SAFETY: NOTE(review) `list_bytes_ptr` comes from `&raw mut` of the list's bytes in
        // `List::iter_mut`, which holds the list mutably borrowed for `'a`. Each call yields the
        // split-off front element and stores only the untouched tail back, so yielded
        // references should not alias — confirm.
        let mut list_bytes = unsafe { &mut *self.list_bytes_ptr };
        let item_data = list_bytes.advance(size_of::<T>());
        let item = checked::from_bytes_mut(item_data);
        self.remaining -= 1;
        // Keep only the bytes after the yielded element.
        self.list_bytes_ptr = list_bytes;
        Some(item)
    }
    /// Exact: the number of elements not yet yielded.
    fn size_hint(&self) -> (usize, Option<usize>) {
        (self.remaining, Some(self.remaining))
    }
}
impl<T, L> ExactSizeIterator for ListIterMut<'_, T, L>
where
    T: CheckedBitPattern + NoUninit + Align1,
    L: ListLength,
{
    /// Number of elements not yet yielded.
    fn len(&self) -> usize {
        let Self { remaining, .. } = self;
        *remaining
    }
}
// `next` returns `None` once `remaining` is 0, and `remaining` is only decremented while it is
// positive, so the iterator keeps returning `None` afterwards.
impl<T, L> FusedIterator for ListIterMut<'_, T, L>
where
    T: CheckedBitPattern + NoUninit + Align1,
    L: ListLength,
{
}
impl<'a, T, L> IntoIterator for &'a mut List<T, L>
where
T: CheckedBitPattern + NoUninit + Align1,
L: ListLength,
{
type Item = &'a mut T;
type IntoIter = ListIterMut<'a, T, L>;
fn into_iter(self) -> Self::IntoIter {
self.iter_mut()
}
}
#[cfg(all(test, feature = "test_helpers"))]
mod tests {
    use super::*;
    use crate::unsize::{unsized_type, NewByteSet};
    use pretty_assertions::assert_eq;

    /// Mirrors every list mutation on a `Vec` and checks the two agree after each step.
    #[test]
    fn test_list_crud() -> Result<()> {
        let mut vec = Vec::<PackedValue<u16>>::new();
        let list_byte_set = List::<PackedValue<u16>>::new_default_byte_set()?;
        let mut list = list_byte_set.data_mut()?;
        assert_eq!(&*vec, &***list);

        vec.extend_from_slice(&[10.into(), 20.into(), 30.into()]);
        list.push_all([10.into(), 20.into(), 30.into()])?;
        assert_eq!(&*vec, &***list);

        vec.insert(1, 12.into());
        vec.insert(2, 14.into());
        vec.insert(1, 13.into());
        list.insert_all(1, [PackedValue(12), 14.into()])?;
        list.insert(1, 13.into())?;
        // Check the inserts themselves before anything overwrites the list's contents.
        assert_eq!(&*vec, &***list);

        // Writing the same values through `as_mut_slice` must leave the list unchanged.
        list.as_mut_slice().copy_from_slice(&[
            10.into(),
            13.into(),
            12.into(),
            14.into(),
            20.into(),
            30.into(),
        ]);
        assert_eq!(&*vec, &***list);

        // Out-of-bounds operations must fail and leave the list untouched.
        let past_end = vec.len() + 1;
        assert!(list.insert(past_end, 0.into()).is_err());
        assert!(list.remove(vec.len()).is_err());
        assert_eq!(&*vec, &***list);

        vec.pop();
        list.pop()?;
        assert_eq!(&*vec, &***list);

        vec.remove(1);
        vec.remove(1);
        list.remove_range(1..3)?;
        assert_eq!(&*vec, &***list);

        vec.clear();
        list.clear()?;
        assert_eq!(&*vec, &***list);
        // Popping an empty list is not an error; it reports that nothing was removed.
        assert_eq!(list.pop()?, None);
        Ok(())
    }

    #[unsized_type(skip_idl)]
    struct InnerList {
        #[unsized_start]
        list: List<u8>,
    }

    #[unsized_type(skip_idl)]
    struct OuterList {
        #[unsized_start]
        inner_list: InnerList,
    }

    /// Pushing into a list nested two levels deep must leave the byte set readable.
    #[test]
    fn test_inner_list() -> Result<()> {
        let test_bytes = OuterList::new_default_byte_set()?;
        let mut bytes = test_bytes.data_mut()?;
        let mut inner_list = bytes.inner_list();
        inner_list.list().push(1)?;
        inner_list.list().push(2)?;
        drop(bytes);
        let _bytes = test_bytes.data()?;
        Ok(())
    }
}