clone_spl_pod/
slice.rs

1//! Special types for working with slices of `Pod`s
2
3use {
4    crate::{
5        bytemuck::{
6            pod_from_bytes, pod_from_bytes_mut, pod_slice_from_bytes, pod_slice_from_bytes_mut,
7        },
8        error::PodSliceError,
9        primitives::PodU32,
10    },
11    bytemuck::Pod,
12    clone_solana_program_error::ProgramError,
13};
14
15const LENGTH_SIZE: usize = std::mem::size_of::<PodU32>();
16/// Special type for using a slice of `Pod`s in a zero-copy way
17pub struct PodSlice<'data, T: Pod> {
18    length: &'data PodU32,
19    data: &'data [T],
20}
21impl<'data, T: Pod> PodSlice<'data, T> {
22    /// Unpack the buffer into a slice
23    pub fn unpack<'a>(data: &'a [u8]) -> Result<Self, ProgramError>
24    where
25        'a: 'data,
26    {
27        if data.len() < LENGTH_SIZE {
28            return Err(PodSliceError::BufferTooSmall.into());
29        }
30        let (length, data) = data.split_at(LENGTH_SIZE);
31        let length = pod_from_bytes::<PodU32>(length)?;
32        let _max_length = max_len_for_type::<T>(data.len())?;
33        let data = pod_slice_from_bytes(data)?;
34        Ok(Self { length, data })
35    }
36
37    /// Get the slice data
38    pub fn data(&self) -> &[T] {
39        let length = u32::from(*self.length) as usize;
40        &self.data[..length]
41    }
42
43    /// Get the amount of bytes used by `num_items`
44    pub fn size_of(num_items: usize) -> Result<usize, ProgramError> {
45        std::mem::size_of::<T>()
46            .checked_mul(num_items)
47            .and_then(|len| len.checked_add(LENGTH_SIZE))
48            .ok_or_else(|| PodSliceError::CalculationFailure.into())
49    }
50}
51
52/// Special type for using a slice of mutable `Pod`s in a zero-copy way
53pub struct PodSliceMut<'data, T: Pod> {
54    length: &'data mut PodU32,
55    data: &'data mut [T],
56    max_length: usize,
57}
58impl<'data, T: Pod> PodSliceMut<'data, T> {
59    /// Unpack the mutable buffer into a mutable slice, with the option to
60    /// initialize the data
61    fn unpack_internal<'a>(data: &'a mut [u8], init: bool) -> Result<Self, ProgramError>
62    where
63        'a: 'data,
64    {
65        if data.len() < LENGTH_SIZE {
66            return Err(PodSliceError::BufferTooSmall.into());
67        }
68        let (length, data) = data.split_at_mut(LENGTH_SIZE);
69        let length = pod_from_bytes_mut::<PodU32>(length)?;
70        if init {
71            *length = 0.into();
72        }
73        let max_length = max_len_for_type::<T>(data.len())?;
74        let data = pod_slice_from_bytes_mut(data)?;
75        Ok(Self {
76            length,
77            data,
78            max_length,
79        })
80    }
81
82    /// Unpack the mutable buffer into a mutable slice
83    pub fn unpack<'a>(data: &'a mut [u8]) -> Result<Self, ProgramError>
84    where
85        'a: 'data,
86    {
87        Self::unpack_internal(data, /* init */ false)
88    }
89
90    /// Unpack the mutable buffer into a mutable slice, and initialize the
91    /// slice to 0-length
92    pub fn init<'a>(data: &'a mut [u8]) -> Result<Self, ProgramError>
93    where
94        'a: 'data,
95    {
96        Self::unpack_internal(data, /* init */ true)
97    }
98
99    /// Add another item to the slice
100    pub fn push(&mut self, t: T) -> Result<(), ProgramError> {
101        let length = u32::from(*self.length);
102        if length as usize == self.max_length {
103            Err(PodSliceError::BufferTooSmall.into())
104        } else {
105            self.data[length as usize] = t;
106            *self.length = length.saturating_add(1).into();
107            Ok(())
108        }
109    }
110}
111
112fn max_len_for_type<T>(data_len: usize) -> Result<usize, ProgramError> {
113    let size: usize = std::mem::size_of::<T>();
114    let max_len = data_len
115        .checked_div(size)
116        .ok_or(PodSliceError::CalculationFailure)?;
117    // check that it isn't over or under allocated
118    if max_len.saturating_mul(size) != data_len {
119        if max_len == 0 {
120            // Size of T is greater than buffer size
121            Err(PodSliceError::BufferTooSmall.into())
122        } else {
123            Err(PodSliceError::BufferTooLarge.into())
124        }
125    } else {
126        Ok(max_len)
127    }
128}
129
#[cfg(test)]
mod tests {
    use {
        super::*,
        crate::bytemuck::pod_slice_to_bytes,
        bytemuck_derive::{Pod, Zeroable},
    };

    #[repr(C)]
    #[derive(Clone, Copy, Debug, Default, PartialEq, Pod, Zeroable)]
    struct TestStruct {
        test_field: u8,
        test_pubkey: [u8; 32],
    }

    #[test]
    fn test_pod_slice() {
        let test_field_bytes = [0];
        let test_pubkey_bytes = [1; 32];
        let len_bytes = [2, 0, 0, 0];

        // Slice will contain 2 `TestStruct`
        let mut data_bytes = [0; 66];
        data_bytes[0..1].copy_from_slice(&test_field_bytes);
        data_bytes[1..33].copy_from_slice(&test_pubkey_bytes);
        data_bytes[33..34].copy_from_slice(&test_field_bytes);
        data_bytes[34..66].copy_from_slice(&test_pubkey_bytes);

        let mut pod_slice_bytes = [0; 70];
        pod_slice_bytes[0..4].copy_from_slice(&len_bytes);
        pod_slice_bytes[4..70].copy_from_slice(&data_bytes);

        let pod_slice = PodSlice::<TestStruct>::unpack(&pod_slice_bytes).unwrap();
        let pod_slice_data = pod_slice.data();

        assert_eq!(*pod_slice.length, PodU32::from(2));
        assert_eq!(pod_slice_to_bytes(pod_slice.data()), data_bytes);
        assert_eq!(pod_slice_data[0].test_field, test_field_bytes[0]);
        assert_eq!(pod_slice_data[0].test_pubkey, test_pubkey_bytes);
        assert_eq!(PodSlice::<TestStruct>::size_of(1).unwrap(), 37);
    }

    #[test]
    fn test_pod_slice_data_respects_length() {
        // Buffer has room for 2 `TestStruct`, but the length prefix says 1
        let mut pod_slice_bytes = [0; 70];
        pod_slice_bytes[0..4].copy_from_slice(&[1, 0, 0, 0]);
        // `test_field` of the first item
        pod_slice_bytes[4] = 7;

        let pod_slice = PodSlice::<TestStruct>::unpack(&pod_slice_bytes).unwrap();
        assert_eq!(pod_slice.data().len(), 1);
        assert_eq!(pod_slice.data()[0].test_field, 7);
    }

    #[test]
    fn test_pod_slice_size_of() {
        // Only the length prefix for an empty slice
        assert_eq!(PodSlice::<TestStruct>::size_of(0).unwrap(), LENGTH_SIZE);
        assert_eq!(PodSlice::<TestStruct>::size_of(2).unwrap(), 70);
        // Byte count overflows `usize`
        let err = PodSlice::<TestStruct>::size_of(usize::MAX).unwrap_err();
        assert_eq!(err, PodSliceError::CalculationFailure.into());
    }

    #[test]
    fn test_pod_slice_buffer_too_large() {
        // 1 `TestStruct` + length = 37 bytes
        // we pass 38 to trigger BufferTooLarge
        let pod_slice_bytes = [1; 38];
        let err = PodSlice::<TestStruct>::unpack(&pod_slice_bytes)
            .err()
            .unwrap();
        assert_eq!(
            err,
            PodSliceError::BufferTooLarge.into(),
            "Expected an `PodSliceError::BufferTooLarge` error"
        );
    }

    #[test]
    fn test_pod_slice_buffer_too_small() {
        // 1 `TestStruct` + length = 37 bytes
        // we pass 36 to trigger BufferTooSmall
        let pod_slice_bytes = [1; 36];
        let err = PodSlice::<TestStruct>::unpack(&pod_slice_bytes)
            .err()
            .unwrap();
        assert_eq!(
            err,
            PodSliceError::BufferTooSmall.into(),
            "Expected an `PodSliceError::BufferTooSmall` error"
        );
    }

    #[test]
    fn test_pod_slice_mut() {
        // slice can fit 2 `TestStruct`
        let mut pod_slice_bytes = [0; 70];
        // set length to 1, so we have room to push 1 more item
        let len_bytes = [1, 0, 0, 0];
        pod_slice_bytes[0..4].copy_from_slice(&len_bytes);

        let mut pod_slice = PodSliceMut::<TestStruct>::unpack(&mut pod_slice_bytes).unwrap();

        assert_eq!(*pod_slice.length, PodU32::from(1));
        pod_slice.push(TestStruct::default()).unwrap();
        assert_eq!(*pod_slice.length, PodU32::from(2));
        let err = pod_slice
            .push(TestStruct::default())
            .expect_err("Expected an `PodSliceError::BufferTooSmall` error");
        assert_eq!(err, PodSliceError::BufferTooSmall.into());
    }

    #[test]
    fn test_pod_slice_mut_init() {
        // Stale length prefix of 2 that `init` must reset
        let mut pod_slice_bytes = [0; 70];
        pod_slice_bytes[0..4].copy_from_slice(&[2, 0, 0, 0]);

        let mut pod_slice = PodSliceMut::<TestStruct>::init(&mut pod_slice_bytes).unwrap();
        assert_eq!(*pod_slice.length, PodU32::from(0));
        assert_eq!(pod_slice.max_length, 2);
        pod_slice.push(TestStruct::default()).unwrap();
        assert_eq!(*pod_slice.length, PodU32::from(1));

        // The length update is written through to the underlying buffer
        assert_eq!(pod_slice_bytes[0..4], [1, 0, 0, 0]);
    }
}
220}