1use {
4 crate::{
5 bytemuck::{
6 pod_from_bytes, pod_from_bytes_mut, pod_slice_from_bytes, pod_slice_from_bytes_mut,
7 },
8 error::PodSliceError,
9 primitives::PodU32,
10 },
11 bytemuck::Pod,
12 clone_solana_program_error::ProgramError,
13};
14
15const LENGTH_SIZE: usize = std::mem::size_of::<PodU32>();
16pub struct PodSlice<'data, T: Pod> {
18 length: &'data PodU32,
19 data: &'data [T],
20}
21impl<'data, T: Pod> PodSlice<'data, T> {
22 pub fn unpack<'a>(data: &'a [u8]) -> Result<Self, ProgramError>
24 where
25 'a: 'data,
26 {
27 if data.len() < LENGTH_SIZE {
28 return Err(PodSliceError::BufferTooSmall.into());
29 }
30 let (length, data) = data.split_at(LENGTH_SIZE);
31 let length = pod_from_bytes::<PodU32>(length)?;
32 let _max_length = max_len_for_type::<T>(data.len())?;
33 let data = pod_slice_from_bytes(data)?;
34 Ok(Self { length, data })
35 }
36
37 pub fn data(&self) -> &[T] {
39 let length = u32::from(*self.length) as usize;
40 &self.data[..length]
41 }
42
43 pub fn size_of(num_items: usize) -> Result<usize, ProgramError> {
45 std::mem::size_of::<T>()
46 .checked_mul(num_items)
47 .and_then(|len| len.checked_add(LENGTH_SIZE))
48 .ok_or_else(|| PodSliceError::CalculationFailure.into())
49 }
50}
51
52pub struct PodSliceMut<'data, T: Pod> {
54 length: &'data mut PodU32,
55 data: &'data mut [T],
56 max_length: usize,
57}
58impl<'data, T: Pod> PodSliceMut<'data, T> {
59 fn unpack_internal<'a>(data: &'a mut [u8], init: bool) -> Result<Self, ProgramError>
62 where
63 'a: 'data,
64 {
65 if data.len() < LENGTH_SIZE {
66 return Err(PodSliceError::BufferTooSmall.into());
67 }
68 let (length, data) = data.split_at_mut(LENGTH_SIZE);
69 let length = pod_from_bytes_mut::<PodU32>(length)?;
70 if init {
71 *length = 0.into();
72 }
73 let max_length = max_len_for_type::<T>(data.len())?;
74 let data = pod_slice_from_bytes_mut(data)?;
75 Ok(Self {
76 length,
77 data,
78 max_length,
79 })
80 }
81
82 pub fn unpack<'a>(data: &'a mut [u8]) -> Result<Self, ProgramError>
84 where
85 'a: 'data,
86 {
87 Self::unpack_internal(data, false)
88 }
89
90 pub fn init<'a>(data: &'a mut [u8]) -> Result<Self, ProgramError>
93 where
94 'a: 'data,
95 {
96 Self::unpack_internal(data, true)
97 }
98
99 pub fn push(&mut self, t: T) -> Result<(), ProgramError> {
101 let length = u32::from(*self.length);
102 if length as usize == self.max_length {
103 Err(PodSliceError::BufferTooSmall.into())
104 } else {
105 self.data[length as usize] = t;
106 *self.length = length.saturating_add(1).into();
107 Ok(())
108 }
109 }
110}
111
112fn max_len_for_type<T>(data_len: usize) -> Result<usize, ProgramError> {
113 let size: usize = std::mem::size_of::<T>();
114 let max_len = data_len
115 .checked_div(size)
116 .ok_or(PodSliceError::CalculationFailure)?;
117 if max_len.saturating_mul(size) != data_len {
119 if max_len == 0 {
120 Err(PodSliceError::BufferTooSmall.into())
122 } else {
123 Err(PodSliceError::BufferTooLarge.into())
124 }
125 } else {
126 Ok(max_len)
127 }
128}
129
#[cfg(test)]
mod tests {
    use {
        super::*,
        crate::bytemuck::pod_slice_to_bytes,
        bytemuck_derive::{Pod, Zeroable},
    };

    /// `repr(C)` fixture: a one-byte field followed by a 32-byte key
    /// (33 bytes per item).
    #[repr(C)]
    #[derive(Clone, Copy, Debug, Default, PartialEq, Pod, Zeroable)]
    struct TestStruct {
        test_field: u8,
        test_pubkey: [u8; 32],
    }

    /// Unpacks a two-item buffer and checks the header, the items and the
    /// size calculation.
    #[test]
    fn test_pod_slice() {
        let test_field_bytes = [0u8];
        let test_pubkey_bytes = [1u8; 32];

        // Two back-to-back `TestStruct` encodings.
        let mut data_bytes = [0u8; 66];
        for item in data_bytes.chunks_exact_mut(33) {
            let (field, pubkey) = item.split_at_mut(1);
            field.copy_from_slice(&test_field_bytes);
            pubkey.copy_from_slice(&test_pubkey_bytes);
        }

        // A 4-byte header holding 2, followed by the item bytes.
        let mut pod_slice_bytes = [0u8; 70];
        let (header, body) = pod_slice_bytes.split_at_mut(4);
        header.copy_from_slice(&2u32.to_le_bytes());
        body.copy_from_slice(&data_bytes);

        let pod_slice = PodSlice::<TestStruct>::unpack(&pod_slice_bytes).unwrap();
        let items = pod_slice.data();

        assert_eq!(*pod_slice.length, PodU32::from(2));
        assert_eq!(pod_slice_to_bytes(pod_slice.data()), data_bytes);
        let first = &items[0];
        assert_eq!(first.test_field, test_field_bytes[0]);
        assert_eq!(first.test_pubkey, test_pubkey_bytes);
        assert_eq!(PodSlice::<TestStruct>::size_of(1).unwrap(), 37);
    }

    /// A 34-byte payload is one full item plus one stray byte.
    #[test]
    fn test_pod_slice_buffer_too_large() {
        let pod_slice_bytes = [1; 38];
        let outcome = PodSlice::<TestStruct>::unpack(&pod_slice_bytes);
        let err = outcome.err().unwrap();
        assert_eq!(
            err,
            PodSliceError::BufferTooLarge.into(),
            "Expected an `PodSliceError::BufferTooLarge` error"
        );
    }

    /// A 32-byte payload cannot hold even one 33-byte item.
    #[test]
    fn test_pod_slice_buffer_too_small() {
        let pod_slice_bytes = [1; 36];
        let outcome = PodSlice::<TestStruct>::unpack(&pod_slice_bytes);
        let err = outcome.err().unwrap();
        assert_eq!(
            err,
            PodSliceError::BufferTooSmall.into(),
            "Expected an `PodSliceError::BufferTooSmall` error"
        );
    }

    /// Pushes until the two-slot buffer is full, then expects a failure.
    #[test]
    fn test_pod_slice_mut() {
        // The header records one item; the 66-byte payload has room for two.
        let mut pod_slice_bytes = [0u8; 70];
        pod_slice_bytes[..4].copy_from_slice(&1u32.to_le_bytes());

        let mut pod_slice = PodSliceMut::<TestStruct>::unpack(&mut pod_slice_bytes).unwrap();
        assert_eq!(*pod_slice.length, PodU32::from(1));

        pod_slice.push(TestStruct::default()).unwrap();
        assert_eq!(*pod_slice.length, PodU32::from(2));

        // Now full: a further push must be rejected.
        let err = pod_slice
            .push(TestStruct::default())
            .expect_err("Expected an `PodSliceError::BufferTooSmall` error");
        assert_eq!(err, PodSliceError::BufferTooSmall.into());
    }
}