use crate::account::{FixedLayout, Pod};
use hopper_runtime::error::ProgramError;
/// Size in bytes of the little-endian `u32` element-count header that
/// precedes the element storage.
const HEADER_SIZE: usize = 4;

/// A sorted, fixed-capacity vector of `T` laid directly over a byte buffer.
///
/// Layout: `[len: u32 LE][elem 0][elem 1]...`, each element occupying
/// `T::SIZE` bytes and stored without alignment. Elements `0..len` are kept in
/// ascending `Ord` order. Remaining bytes, up to `capacity()` elements, are
/// spare storage.
pub struct SortedVec<'a, T: Pod + FixedLayout + Ord> {
    /// Backing bytes: the length header followed by element storage.
    data: &'a mut [u8],
    _phantom: core::marker::PhantomData<T>,
}

impl<'a, T: Pod + FixedLayout + Ord> SortedVec<'a, T> {
    /// Wraps `data` as a sorted vector, validating its length header.
    ///
    /// # Errors
    /// Returns `ProgramError::AccountDataTooSmall` if `data` is shorter than
    /// the header. It also returns this error if the stored element count
    /// exceeds what the buffer can hold (a corrupted or foreign buffer).
    ///
    /// The second check is what makes the unchecked element reads in this
    /// type sound. Afterwards only this type writes the header, and it keeps
    /// `len() <= capacity()`.
    #[inline]
    pub fn from_bytes(data: &'a mut [u8]) -> Result<Self, ProgramError> {
        if data.len() < HEADER_SIZE {
            return Err(ProgramError::AccountDataTooSmall);
        }
        let vec = Self {
            data,
            _phantom: core::marker::PhantomData,
        };
        if vec.len() > vec.capacity() {
            return Err(ProgramError::AccountDataTooSmall);
        }
        Ok(vec)
    }

    /// Number of stored elements, read from the little-endian header.
    #[inline(always)]
    pub fn len(&self) -> usize {
        u32::from_le_bytes([self.data[0], self.data[1], self.data[2], self.data[3]]) as usize
    }

    /// Maximum number of elements the buffer can hold after the header.
    #[inline(always)]
    pub fn capacity(&self) -> usize {
        (self.data.len() - HEADER_SIZE) / T::SIZE
    }

    /// `true` when no elements are stored.
    #[inline(always)]
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// `true` when no further element can be inserted.
    #[inline(always)]
    pub fn is_full(&self) -> bool {
        self.len() >= self.capacity()
    }

    /// Writes `len` into the header as a little-endian `u32`.
    ///
    /// Callers only pass values `<= capacity()`. The `as u32` cast could only
    /// truncate for buffers holding more than `u32::MAX` elements.
    #[inline(always)]
    fn set_len(&mut self, len: usize) {
        debug_assert!(len <= self.capacity());
        self.data[..HEADER_SIZE].copy_from_slice(&(len as u32).to_le_bytes());
    }

    /// Byte offset of element `index` within `data`.
    #[inline(always)]
    fn element_offset(index: usize) -> usize {
        HEADER_SIZE + index * T::SIZE
    }

    /// Reads element `index` by value.
    ///
    /// Callers must ensure `index < capacity()`. Every caller passes
    /// `index < len()`, and `len() <= capacity()` is enforced by `from_bytes`
    /// and maintained by the mutators.
    #[inline]
    fn read_at(&self, index: usize) -> T {
        let offset = Self::element_offset(index);
        debug_assert!(offset + T::SIZE <= self.data.len());
        // SAFETY: `index < capacity()` keeps `offset + T::SIZE` inside `data`.
        // `read_unaligned` has no alignment requirement. `T: Pod` is assumed to
        // make every byte pattern a valid `T`.
        // NOTE(review): this also assumes `size_of::<T>() == T::SIZE`; confirm
        // that `FixedLayout` guarantees it.
        unsafe { core::ptr::read_unaligned(self.data.as_ptr().add(offset) as *const T) }
    }

    /// Writes `value` into slot `index`. Same preconditions as `read_at`.
    #[inline]
    fn write_at(&mut self, index: usize, value: T) {
        let offset = Self::element_offset(index);
        debug_assert!(offset + T::SIZE <= self.data.len());
        // SAFETY: same bounds argument as `read_at`. `write_unaligned` has no
        // alignment requirement.
        unsafe { core::ptr::write_unaligned(self.data.as_mut_ptr().add(offset) as *mut T, value) }
    }

    /// Binary-searches the sorted elements for `target`.
    ///
    /// Returns `Ok(i)` with the index of *a* matching element. With
    /// duplicates, this is not necessarily the first match. Returns `Err(i)`
    /// with the index where `target` could be inserted while keeping order.
    #[inline]
    pub fn binary_search(&self, target: &T) -> Result<usize, usize> {
        let mut lo: usize = 0;
        let mut hi: usize = self.len();
        while lo < hi {
            let mid = lo + (hi - lo) / 2;
            match self.read_at(mid).cmp(target) {
                core::cmp::Ordering::Less => lo = mid + 1,
                core::cmp::Ordering::Equal => return Ok(mid),
                core::cmp::Ordering::Greater => hi = mid,
            }
        }
        Err(lo)
    }

    /// `true` if an element equal to `target` is stored.
    #[inline]
    pub fn contains(&self, target: &T) -> bool {
        self.binary_search(target).is_ok()
    }

    /// Returns a copy of element `index`.
    ///
    /// # Errors
    /// `ProgramError::InvalidArgument` if `index >= len()`.
    #[inline]
    pub fn get(&self, index: usize) -> Result<T, ProgramError> {
        if index >= self.len() {
            return Err(ProgramError::InvalidArgument);
        }
        Ok(self.read_at(index))
    }

    /// Shifts elements `idx..len` up one slot, writes `value` at `idx` and
    /// bumps the stored length.
    ///
    /// Caller must ensure `len == self.len()`, `len < capacity()` and
    /// `idx <= len`. These make the destination end
    /// `element_offset(len + 1)` fit in `data`.
    #[inline]
    fn insert_at(&mut self, idx: usize, len: usize, value: T) {
        let start = Self::element_offset(idx);
        let end = Self::element_offset(len);
        // Overlapping, bounds-checked move of the tail one slot to the right.
        // The range is empty when appending (`idx == len`).
        self.data.copy_within(start..end, start + T::SIZE);
        self.write_at(idx, value);
        self.set_len(len + 1);
    }

    /// Inserts `value` at its sorted position, allowing duplicates.
    ///
    /// Returns the index the value was written to.
    ///
    /// # Errors
    /// `ProgramError::AccountDataTooSmall` if the vector is full.
    #[inline]
    pub fn insert(&mut self, value: T) -> Result<usize, ProgramError> {
        let len = self.len();
        if len >= self.capacity() {
            return Err(ProgramError::AccountDataTooSmall);
        }
        // Both outcomes yield a position that keeps the elements sorted.
        let idx = match self.binary_search(&value) {
            Ok(i) | Err(i) => i,
        };
        self.insert_at(idx, len, value);
        Ok(idx)
    }

    /// Inserts `value` only if no equal element is already stored.
    ///
    /// Returns `Ok(index)` of the newly written element. Returns
    /// `Err(existing_index)` if an equal element exists; this is checked
    /// before capacity, so it also applies when the vector is full. Returns
    /// `Err(usize::MAX)` as a sentinel when the value is new but the vector
    /// is full.
    #[inline]
    pub fn insert_unique(&mut self, value: T) -> Result<usize, usize> {
        let idx = match self.binary_search(&value) {
            Ok(existing) => return Err(existing),
            Err(idx) => idx,
        };
        let len = self.len();
        if len >= self.capacity() {
            return Err(usize::MAX);
        }
        self.insert_at(idx, len, value);
        Ok(idx)
    }

    /// Removes and returns element `index`, shifting later elements down.
    ///
    /// The vacated last slot is zero-filled.
    ///
    /// # Errors
    /// `ProgramError::InvalidArgument` if `index >= len()`.
    #[inline]
    pub fn remove(&mut self, index: usize) -> Result<T, ProgramError> {
        let len = self.len();
        if index >= len {
            return Err(ProgramError::InvalidArgument);
        }
        let removed = self.read_at(index);
        let end = Self::element_offset(len);
        // Overlapping, bounds-checked move of the tail one slot to the left.
        // The range is empty when removing the last element.
        self.data
            .copy_within(Self::element_offset(index + 1)..end, Self::element_offset(index));
        // Clear the now-unused final slot so stale bytes do not linger.
        self.data[Self::element_offset(len - 1)..end].fill(0);
        self.set_len(len - 1);
        Ok(removed)
    }

    /// Removes one element equal to `value` and returns it.
    ///
    /// # Errors
    /// `ProgramError::InvalidArgument` if no equal element is stored.
    #[inline]
    pub fn remove_value(&mut self, value: &T) -> Result<T, ProgramError> {
        match self.binary_search(value) {
            Ok(idx) => self.remove(idx),
            Err(_) => Err(ProgramError::InvalidArgument),
        }
    }

    /// Returns the smallest element.
    ///
    /// # Errors
    /// `ProgramError::InvalidArgument` if the vector is empty.
    #[inline]
    pub fn min(&self) -> Result<T, ProgramError> {
        if self.is_empty() {
            return Err(ProgramError::InvalidArgument);
        }
        Ok(self.read_at(0))
    }

    /// Returns the largest element.
    ///
    /// # Errors
    /// `ProgramError::InvalidArgument` if the vector is empty.
    #[inline]
    pub fn max(&self) -> Result<T, ProgramError> {
        let len = self.len();
        if len == 0 {
            return Err(ProgramError::InvalidArgument);
        }
        Ok(self.read_at(len - 1))
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::abi::WireU64;

    /// Inserting out of order yields sorted storage, and lookups agree.
    #[test]
    fn sorted_vec_insert_and_search() {
        let mut buf = [0u8; 4 + 8 * 8];
        let mut sv = SortedVec::<WireU64>::from_bytes(&mut buf).unwrap();
        for &raw in [50, 10, 30, 20, 40].iter() {
            sv.insert(WireU64::new(raw)).unwrap();
        }
        assert_eq!(sv.len(), 5);
        for (i, &want) in [10, 20, 30, 40, 50].iter().enumerate() {
            assert_eq!(sv.get(i).unwrap().get(), want);
        }
        assert!(sv.contains(&WireU64::new(30)));
        assert!(!sv.contains(&WireU64::new(25)));
        assert_eq!(sv.binary_search(&WireU64::new(30)), Ok(2));
        assert_eq!(sv.binary_search(&WireU64::new(25)), Err(2));
    }

    /// `min`/`max` report the first and last sorted elements.
    #[test]
    fn sorted_vec_min_max() {
        let mut buf = [0u8; 4 + 8 * 4];
        let mut sv = SortedVec::<WireU64>::from_bytes(&mut buf).unwrap();
        for &raw in [100, 5, 42].iter() {
            sv.insert(WireU64::new(raw)).unwrap();
        }
        let lowest = sv.min().unwrap().get();
        let highest = sv.max().unwrap().get();
        assert_eq!(lowest, 5);
        assert_eq!(highest, 100);
    }

    /// Removing a middle value closes the gap and keeps order.
    #[test]
    fn sorted_vec_remove() {
        let mut buf = [0u8; 4 + 8 * 4];
        let mut sv = SortedVec::<WireU64>::from_bytes(&mut buf).unwrap();
        for &raw in [10, 20, 30].iter() {
            sv.insert(WireU64::new(raw)).unwrap();
        }
        sv.remove_value(&WireU64::new(20)).unwrap();
        assert_eq!(sv.len(), 2);
        for (i, &want) in [10, 30].iter().enumerate() {
            assert_eq!(sv.get(i).unwrap().get(), want);
        }
    }

    /// A duplicate is refused by `insert_unique` and leaves the length alone.
    #[test]
    fn sorted_vec_insert_unique() {
        let mut buf = [0u8; 4 + 8 * 4];
        let mut sv = SortedVec::<WireU64>::from_bytes(&mut buf).unwrap();
        let first = sv.insert_unique(WireU64::new(10));
        assert!(first.is_ok());
        let second = sv.insert_unique(WireU64::new(20));
        assert!(second.is_ok());
        let duplicate = sv.insert_unique(WireU64::new(10));
        assert!(duplicate.is_err());
        assert_eq!(sv.len(), 2);
    }
}