use borsh::{BorshDeserialize, BorshSerialize};
use core::mem::MaybeUninit;
/// Strategy for encoding an element count ahead of a collection's items.
///
/// Implementors must be `Copy`; the two provided in this file are unit
/// structs that only select the width of the length field.
pub trait LengthPrefix: Copy {
    /// Largest length value this prefix can represent.
    const MAX_CAPACITY: usize;

    /// Writes `len` to `writer` using this prefix's encoding.
    ///
    /// # Errors
    /// Returns whatever `std::io::Error` the implementation produces.
    fn serialize_length<W>(len: usize, writer: &mut W) -> std::io::Result<()>
    where
        W: std::io::Write;

    /// Reads a length that was written with this prefix's encoding.
    ///
    /// # Errors
    /// Returns whatever `std::io::Error` the implementation produces.
    fn deserialize_length<R>(reader: &mut R) -> std::io::Result<usize>
    where
        R: std::io::Read;
}
/// Length prefix stored in a single byte.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct U8Prefix;

impl LengthPrefix for U8Prefix {
    /// One byte can describe at most `u8::MAX` elements.
    const MAX_CAPACITY: usize = u8::MAX as usize;

    /// Writes `len` as one byte.
    ///
    /// # Errors
    /// `InvalidData` when `len` is greater than `MAX_CAPACITY`; otherwise any
    /// error reported by `writer.write_all`.
    fn serialize_length<W: std::io::Write>(len: usize, writer: &mut W) -> std::io::Result<()> {
        if len <= Self::MAX_CAPACITY {
            // The guard above makes this narrowing cast lossless.
            let encoded = [len as u8];
            return writer.write_all(&encoded);
        }
        let message = format!("Length {} exceeds u8::MAX ({})", len, Self::MAX_CAPACITY);
        Err(std::io::Error::new(std::io::ErrorKind::InvalidData, message))
    }

    /// Reads exactly one byte and widens it to `usize`.
    ///
    /// # Errors
    /// Any error reported by `reader.read_exact`.
    fn deserialize_length<R: std::io::Read>(reader: &mut R) -> std::io::Result<usize> {
        let mut raw = [0u8; 1];
        reader.read_exact(&mut raw)?;
        let [byte] = raw;
        Ok(usize::from(byte))
    }
}
/// Length prefix stored as a little-endian `u16` (two bytes).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct U16Prefix;

impl LengthPrefix for U16Prefix {
    /// Two bytes can describe at most `u16::MAX` elements.
    const MAX_CAPACITY: usize = u16::MAX as usize;

    /// Writes `len` as two little-endian bytes.
    ///
    /// # Errors
    /// `InvalidData` when `len` is greater than `MAX_CAPACITY`; otherwise any
    /// error reported by `writer.write_all`.
    fn serialize_length<W: std::io::Write>(len: usize, writer: &mut W) -> std::io::Result<()> {
        if len <= Self::MAX_CAPACITY {
            // The guard above makes this narrowing cast lossless.
            let encoded = (len as u16).to_le_bytes();
            return writer.write_all(&encoded);
        }
        let message = format!("Length {} exceeds u16::MAX ({})", len, Self::MAX_CAPACITY);
        Err(std::io::Error::new(std::io::ErrorKind::InvalidData, message))
    }

    /// Reads exactly two bytes and decodes them as a little-endian `u16`.
    ///
    /// # Errors
    /// Any error reported by `reader.read_exact`.
    fn deserialize_length<R: std::io::Read>(reader: &mut R) -> std::io::Result<usize> {
        let mut raw = [0u8; 2];
        reader.read_exact(&mut raw)?;
        let decoded = u16::from_le_bytes(raw);
        Ok(usize::from(decoded))
    }
}
/// Fixed-capacity vector stored inline in an 8-slot array.
///
/// `P` selects the length-prefix encoding used by the serialization impls in
/// this file and defaults to [`U8Prefix`].
pub struct SmallVec<T, P = U8Prefix>
where
    P: LengthPrefix,
{
    /// Backing storage; only the first `len` slots are treated as initialized.
    data: [MaybeUninit<T>; 8],
    /// Number of initialized slots at the front of `data`.
    len: u8,
    /// Ties `P` to the type without storing a value of it.
    _phantom: core::marker::PhantomData<P>,
}
impl<T, P: LengthPrefix> SmallVec<T, P> {
    /// Number of inline slots; `push` panics once this many elements are stored.
    pub const MAX_CAPACITY: usize = 8;

    /// Creates an empty vector; no slot is initialized.
    pub fn new() -> Self {
        // SAFETY: an array of `MaybeUninit<T>` places no validity requirement on
        // its contents, so an uninitialized value of the array type is valid.
        let data = unsafe { MaybeUninit::<[MaybeUninit<T>; 8]>::uninit().assume_init() };
        Self {
            data,
            len: 0,
            _phantom: core::marker::PhantomData,
        }
    }

    /// Same as [`SmallVec::new`]; the capacity argument is ignored because the
    /// storage is always the fixed inline array.
    pub fn with_capacity(_capacity: usize) -> Self {
        Self::new()
    }

    /// Appends `value` after the current last element.
    ///
    /// # Panics
    /// Panics if the vector already holds `MAX_CAPACITY` elements.
    pub fn push(&mut self, value: T) {
        let next = self.len as usize;
        assert!(
            next < Self::MAX_CAPACITY,
            "SmallVec exceeds max capacity: {} >= {}",
            self.len,
            Self::MAX_CAPACITY
        );
        // Plain assignment: the old slot is `MaybeUninit`, so nothing is dropped.
        let slot = &mut self.data[next];
        *slot = MaybeUninit::new(value);
        self.len += 1;
    }

    /// Number of elements currently stored.
    pub fn len(&self) -> usize {
        usize::from(self.len)
    }

    /// `true` when no element is stored.
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Borrows the stored elements as a slice.
    pub fn as_slice(&self) -> &[T] {
        let count = self.len as usize;
        let base = self.data.as_ptr() as *const T;
        // SAFETY: `MaybeUninit<T>` has the same layout as `T`, and within this
        // file `len` only grows in `push`, after the corresponding slot has been
        // written, so the first `count` slots hold initialized values.
        unsafe { core::slice::from_raw_parts(base, count) }
    }

    /// Mutably borrows the stored elements as a slice.
    pub fn as_mut_slice(&mut self) -> &mut [T] {
        let count = self.len as usize;
        let base = self.data.as_mut_ptr() as *mut T;
        // SAFETY: same reasoning as `as_slice`; `&mut self` makes the borrow unique.
        unsafe { core::slice::from_raw_parts_mut(base, count) }
    }

    /// Iterates over shared references to the stored elements.
    pub fn iter(&self) -> core::slice::Iter<'_, T> {
        let elements = self.as_slice();
        elements.iter()
    }

    /// Iterates over mutable references to the stored elements.
    pub fn iter_mut(&mut self) -> core::slice::IterMut<'_, T> {
        let elements = self.as_mut_slice();
        elements.iter_mut()
    }
}
impl<T, P: LengthPrefix> Drop for SmallVec<T, P> {
    /// Drops the initialized elements in index order; slots at or beyond `len`
    /// are never touched.
    fn drop(&mut self) {
        // Dropping the initialized prefix as one slice (instead of an indexed
        // loop) lets the slice drop glue keep dropping the remaining elements
        // while unwinding if one element's destructor panics, so they are not
        // leaked.
        let live: *mut [T] = self.as_mut_slice();
        // SAFETY: `live` covers exactly the first `len` slots, which
        // `as_mut_slice` exposes as initialized `T` values. `self` is being
        // destroyed and `data` is an array of `MaybeUninit<T>`, which has no drop
        // glue of its own, so no element is dropped a second time.
        unsafe { core::ptr::drop_in_place(live) };
    }
}
impl<T: Clone, P: LengthPrefix> Clone for SmallVec<T, P> {
    /// Builds a new vector by cloning every stored element in order.
    fn clone(&self) -> Self {
        let mut copy = Self::new();
        // The source holds at most `MAX_CAPACITY` elements, so `push` cannot overflow.
        self.as_slice()
            .iter()
            .for_each(|element| copy.push(element.clone()));
        copy
    }
}
impl<T: core::fmt::Debug, P: LengthPrefix> core::fmt::Debug for SmallVec<T, P> {
    /// Formats the stored elements as a list.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let mut list = f.debug_list();
        list.entries(self.as_slice().iter());
        list.finish()
    }
}
impl<T: PartialEq, P: LengthPrefix> PartialEq for SmallVec<T, P> {
    /// Two vectors are equal when their element slices compare equal.
    fn eq(&self, other: &Self) -> bool {
        let (lhs, rhs) = (self.as_slice(), other.as_slice());
        lhs == rhs
    }
}
/// Marker impl: available whenever the element type is `Eq`.
impl<T, P> Eq for SmallVec<T, P>
where
    T: Eq,
    P: LengthPrefix,
{
}
impl<T, P: LengthPrefix> Default for SmallVec<T, P> {
fn default() -> Self {
Self::new()
}
}
impl<T, P: LengthPrefix> core::ops::Deref for SmallVec<T, P> {
type Target = [T];
fn deref(&self) -> &Self::Target {
self.as_slice()
}
}
impl<T, P: LengthPrefix> core::ops::DerefMut for SmallVec<T, P> {
fn deref_mut(&mut self) -> &mut Self::Target {
self.as_mut_slice()
}
}
impl<T, P: LengthPrefix> From<Vec<T>> for SmallVec<T, P> {
    /// Moves every element of `vec` into a new `SmallVec`, preserving order.
    ///
    /// # Panics
    /// Panics if `vec` holds more than `MAX_CAPACITY` elements.
    fn from(vec: Vec<T>) -> Self {
        let count = vec.len();
        assert!(
            count <= Self::MAX_CAPACITY,
            "Vec length {} exceeds SmallVec max capacity {}",
            count,
            Self::MAX_CAPACITY
        );
        let mut filled = Self::new();
        vec.into_iter().for_each(|item| filled.push(item));
        filled
    }
}
impl<T: BorshSerialize, P: LengthPrefix> BorshSerialize for SmallVec<T, P> {
    /// Writes the element count using `P`'s encoding, then each element in order.
    ///
    /// # Errors
    /// Propagates the first error from the length write or from any element.
    fn serialize<W: std::io::Write>(&self, writer: &mut W) -> std::io::Result<()> {
        let items = self.as_slice();
        P::serialize_length(items.len(), writer)?;
        // Stops at the first failing element, like an early-returning loop.
        items
            .iter()
            .try_for_each(|item| item.serialize(&mut *writer))
    }
}
impl<T: BorshDeserialize, P: LengthPrefix> BorshDeserialize for SmallVec<T, P> {
    /// Reads a length using `P`'s encoding, then that many elements.
    ///
    /// # Errors
    /// `InvalidData` when the decoded length is greater than `MAX_CAPACITY`;
    /// otherwise the first error from reading the length or an element.
    fn deserialize_reader<R: std::io::Read>(reader: &mut R) -> std::io::Result<Self> {
        let count = P::deserialize_length(reader)?;
        if count <= Self::MAX_CAPACITY {
            let mut items = Self::new();
            for _ in 0..count {
                let element = T::deserialize_reader(reader)?;
                items.push(element);
            }
            return Ok(items);
        }
        // Reject oversized input before reading any element.
        let message = format!(
            "Deserialized length {} exceeds max capacity {}",
            count,
            Self::MAX_CAPACITY
        );
        Err(std::io::Error::new(std::io::ErrorKind::InvalidData, message))
    }
}
/// With the `anchor` feature enabled, Anchor deserialization forwards to the
/// Borsh implementation.
#[cfg(feature = "anchor")]
impl<T: BorshDeserialize, P: LengthPrefix> anchor_lang::AnchorDeserialize for SmallVec<T, P> {
    fn deserialize_reader<R: std::io::Read>(reader: &mut R) -> std::io::Result<Self> {
        <SmallVec<T, P> as BorshDeserialize>::deserialize_reader(reader)
    }
}
/// With the `anchor` feature enabled, Anchor serialization forwards to the
/// Borsh implementation.
#[cfg(feature = "anchor")]
impl<T: BorshSerialize, P: LengthPrefix> anchor_lang::AnchorSerialize for SmallVec<T, P> {
    fn serialize<W: std::io::Write>(&self, writer: &mut W) -> std::io::Result<()> {
        <SmallVec<T, P> as BorshSerialize>::serialize(self, writer)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    // Serializes `value` into a fresh buffer, panicking on any I/O error.
    fn encode<V: BorshSerialize>(value: &V) -> Vec<u8> {
        let mut bytes = Vec::new();
        value.serialize(&mut bytes).unwrap();
        bytes
    }

    // A `U8Prefix` vector writes a one-byte length and round-trips.
    #[test]
    fn test_u8_prefix_serialization() {
        let original = SmallVec::<u32, U8Prefix>::from(vec![1, 2, 3]);
        let bytes = encode(&original);
        assert_eq!(bytes[0], 3u8, "First byte should be length as u8");
        assert_eq!(bytes.len(), 1 + 3 * 4);

        let mut cursor = bytes.as_slice();
        let decoded = SmallVec::<u32, U8Prefix>::deserialize(&mut cursor).unwrap();
        assert_eq!(decoded.as_slice(), &[1, 2, 3]);
    }

    // A `U16Prefix` vector writes a two-byte little-endian length and round-trips.
    #[test]
    fn test_u16_prefix_serialization() {
        let original = SmallVec::<u32, U16Prefix>::from(vec![1, 2, 3]);
        let bytes = encode(&original);
        assert_eq!(bytes[0..2], [3u8, 0u8], "First 2 bytes should be length as u16 LE");
        assert_eq!(bytes.len(), 2 + 3 * 4);

        let mut cursor = bytes.as_slice();
        let decoded = SmallVec::<u32, U16Prefix>::deserialize(&mut cursor).unwrap();
        assert_eq!(decoded.as_slice(), &[1, 2, 3]);
    }

    // Compares encoded sizes of a plain `Vec` and both prefix variants.
    #[test]
    fn test_size_comparison() {
        let regular_vec = vec![1u32, 2u32, 3u32];
        let baseline = encode(&regular_vec);
        let via_u8 = encode(&SmallVec::<u32, U8Prefix>::from(vec![1, 2, 3]));
        let via_u16 = encode(&SmallVec::<u32, U16Prefix>::from(vec![1, 2, 3]));

        println!("Size comparison for Vec<u32> with 3 elements:");
        println!(" Standard Vec: {} bytes", baseline.len());
        println!(
            " SmallVec<_, U8Prefix>: {} bytes (saves {} bytes)",
            via_u8.len(),
            baseline.len() - via_u8.len()
        );
        println!(
            " SmallVec<_, U16Prefix>: {} bytes (saves {} bytes)",
            via_u16.len(),
            baseline.len() - via_u16.len()
        );

        assert_eq!(baseline.len(), 16);
        assert_eq!(via_u8.len(), 13);
        assert_eq!(via_u16.len(), 14);
    }

    // Pushing a ninth element must panic with the capacity message.
    #[test]
    #[should_panic(expected = "exceeds max capacity")]
    fn test_overflow() {
        let mut small_vec: SmallVec<u8, U8Prefix> = SmallVec::new();
        (0..=8).for_each(|i| small_vec.push(i));
    }

    // Exactly eight elements fit and come back in insertion order.
    #[test]
    fn test_max_capacity() {
        let mut small_vec: SmallVec<u8, U8Prefix> = SmallVec::new();
        (0..8).for_each(|i| small_vec.push(i));
        assert_eq!(small_vec.len(), 8);
        assert_eq!(small_vec.as_slice(), &[0, 1, 2, 3, 4, 5, 6, 7]);
    }
}