1use crate::error::ProgramError;
38
39pub trait TailCodec: Sized {
47 const MAX_ENCODED_LEN: usize;
52
53 fn encode(&self, out: &mut [u8]) -> Result<usize, ProgramError>;
57
58 fn decode(input: &[u8]) -> Result<(Self, usize), ProgramError>;
61}
62
63impl TailCodec for u8 {
66 const MAX_ENCODED_LEN: usize = 1;
67 #[inline]
68 fn encode(&self, out: &mut [u8]) -> Result<usize, ProgramError> {
69 if out.is_empty() {
70 return Err(ProgramError::AccountDataTooSmall);
71 }
72 out[0] = *self;
73 Ok(1)
74 }
75 #[inline]
76 fn decode(input: &[u8]) -> Result<(Self, usize), ProgramError> {
77 input
78 .first()
79 .copied()
80 .map(|b| (b, 1))
81 .ok_or(ProgramError::InvalidAccountData)
82 }
83}
84
/// Implements [`TailCodec`] for fixed-width integer types as raw little-endian bytes.
///
/// Each entry has the form `type: width`, where `width` is the encoded size in
/// bytes and becomes the type's `MAX_ENCODED_LEN`. The width is checked against
/// `core::mem::size_of::<type>()` at compile time, so a mistake such as
/// `u32: 8` fails the build instead of panicking inside `copy_from_slice`
/// (which requires equal lengths) the first time the impl runs.
macro_rules! tail_codec_int {
    ( $( $ty:ty : $n:expr ),+ $(,)? ) => {
        $(
            // Compile-time guard: the declared width must match the real size.
            const _: () = assert!(
                $n == ::core::mem::size_of::<$ty>(),
                "tail_codec_int!: declared width does not match the type's size"
            );

            impl TailCodec for $ty {
                const MAX_ENCODED_LEN: usize = $n;

                /// Writes `self` as `MAX_ENCODED_LEN` little-endian bytes at the
                /// start of `out`.
                ///
                /// # Errors
                /// `ProgramError::AccountDataTooSmall` when `out` is shorter than
                /// `MAX_ENCODED_LEN`; nothing is written in that case.
                #[inline]
                fn encode(&self, out: &mut [u8]) -> Result<usize, ProgramError> {
                    if out.len() < $n {
                        return Err(ProgramError::AccountDataTooSmall);
                    }
                    out[..$n].copy_from_slice(&self.to_le_bytes());
                    Ok($n)
                }

                /// Reads `MAX_ENCODED_LEN` little-endian bytes from the start of
                /// `input`; trailing bytes are ignored.
                ///
                /// # Errors
                /// `ProgramError::InvalidAccountData` when `input` is too short.
                #[inline]
                fn decode(input: &[u8]) -> Result<(Self, usize), ProgramError> {
                    if input.len() < $n {
                        return Err(ProgramError::InvalidAccountData);
                    }
                    let mut bytes = [0u8; $n];
                    bytes.copy_from_slice(&input[..$n]);
                    Ok((Self::from_le_bytes(bytes), $n))
                }
            }
        )+
    };
}
111
112tail_codec_int! {
113 u16: 2, u32: 4, u64: 8, u128: 16,
114 i16: 2, i32: 4, i64: 8, i128: 16,
115}
116
117impl TailCodec for bool {
119 const MAX_ENCODED_LEN: usize = 1;
120 #[inline]
121 fn encode(&self, out: &mut [u8]) -> Result<usize, ProgramError> {
122 if out.is_empty() {
123 return Err(ProgramError::AccountDataTooSmall);
124 }
125 out[0] = if *self { 1 } else { 0 };
126 Ok(1)
127 }
128 #[inline]
129 fn decode(input: &[u8]) -> Result<(Self, usize), ProgramError> {
130 match input.first().copied() {
131 Some(0) => Ok((false, 1)),
132 Some(1) => Ok((true, 1)),
133 _ => Err(ProgramError::InvalidAccountData),
134 }
135 }
136}
137
138impl<const N: usize> TailCodec for [u8; N] {
140 const MAX_ENCODED_LEN: usize = N;
141 #[inline]
142 fn encode(&self, out: &mut [u8]) -> Result<usize, ProgramError> {
143 if out.len() < N {
144 return Err(ProgramError::AccountDataTooSmall);
145 }
146 out[..N].copy_from_slice(self);
147 Ok(N)
148 }
149 #[inline]
150 fn decode(input: &[u8]) -> Result<(Self, usize), ProgramError> {
151 if input.len() < N {
152 return Err(ProgramError::InvalidAccountData);
153 }
154 let mut out = [0u8; N];
155 out.copy_from_slice(&input[..N]);
156 Ok((out, N))
157 }
158}
159
160impl<T: TailCodec> TailCodec for Option<T> {
162 const MAX_ENCODED_LEN: usize = 1 + T::MAX_ENCODED_LEN;
163 #[inline]
164 fn encode(&self, out: &mut [u8]) -> Result<usize, ProgramError> {
165 if out.is_empty() {
166 return Err(ProgramError::AccountDataTooSmall);
167 }
168 match self {
169 None => {
170 out[0] = 0;
171 Ok(1)
172 }
173 Some(inner) => {
174 out[0] = 1;
175 let written = inner.encode(&mut out[1..])?;
176 Ok(1 + written)
177 }
178 }
179 }
180 #[inline]
181 fn decode(input: &[u8]) -> Result<(Self, usize), ProgramError> {
182 match input.first().copied() {
183 Some(0) => Ok((None, 1)),
184 Some(1) => {
185 let (inner, n) = T::decode(&input[1..])?;
186 Ok((Some(inner), 1 + n))
187 }
188 _ => Err(ProgramError::InvalidAccountData),
189 }
190 }
191}
192
193#[inline]
203pub fn read_tail_len(data: &[u8], body_end: usize) -> Result<u32, ProgramError> {
204 let end = body_end
205 .checked_add(4)
206 .ok_or(ProgramError::AccountDataTooSmall)?;
207 if data.len() < end {
208 return Err(ProgramError::AccountDataTooSmall);
209 }
210 let mut bytes = [0u8; 4];
211 bytes.copy_from_slice(&data[body_end..end]);
212 Ok(u32::from_le_bytes(bytes))
213}
214
215#[inline]
218pub fn tail_payload(data: &[u8], body_end: usize) -> Result<&[u8], ProgramError> {
219 let len = read_tail_len(data, body_end)? as usize;
220 let start = body_end + 4;
221 let end = start.checked_add(len).ok_or(ProgramError::InvalidAccountData)?;
222 if data.len() < end {
223 return Err(ProgramError::InvalidAccountData);
224 }
225 Ok(&data[start..end])
226}
227
228#[inline]
232pub fn read_tail<T: TailCodec>(
233 data: &[u8],
234 body_end: usize,
235) -> Result<T, ProgramError> {
236 let payload = tail_payload(data, body_end)?;
237 let (value, consumed) = T::decode(payload)?;
238 if consumed != payload.len() {
239 return Err(ProgramError::InvalidAccountData);
240 }
241 Ok(value)
242}
243
244#[inline]
249pub fn write_tail<T: TailCodec>(
250 data: &mut [u8],
251 body_end: usize,
252 tail: &T,
253) -> Result<usize, ProgramError> {
254 let prefix_end = body_end
255 .checked_add(4)
256 .ok_or(ProgramError::AccountDataTooSmall)?;
257 if data.len() < prefix_end {
258 return Err(ProgramError::AccountDataTooSmall);
259 }
260 let written = tail.encode(&mut data[prefix_end..])?;
261 if written > u32::MAX as usize {
262 return Err(ProgramError::InvalidAccountData);
263 }
264 data[body_end..prefix_end].copy_from_slice(&(written as u32).to_le_bytes());
265 Ok(written)
266}
267
#[cfg(test)]
mod tests {
    //! Unit tests for the tail codecs and the length-prefixed tail helpers,
    //! covering both happy-path roundtrips and the bounds/validation failures.

    use super::*;

    #[test]
    fn u32_roundtrip() {
        let mut buf = [0u8; 8];
        let n = 0xDEAD_BEEFu32.encode(&mut buf).unwrap();
        assert_eq!(n, 4);
        let (back, consumed) = u32::decode(&buf).unwrap();
        assert_eq!(consumed, 4);
        assert_eq!(back, 0xDEAD_BEEF);
    }

    #[test]
    fn u64_roundtrip() {
        let mut buf = [0u8; 8];
        0x0123_4567_89AB_CDEFu64.encode(&mut buf).unwrap();
        let (back, _) = u64::decode(&buf).unwrap();
        assert_eq!(back, 0x0123_4567_89AB_CDEF);
    }

    #[test]
    fn u8_and_signed_roundtrip() {
        let mut buf = [0u8; 16];
        assert_eq!(0xABu8.encode(&mut buf).unwrap(), 1);
        assert_eq!(u8::decode(&buf).unwrap(), (0xAB, 1));
        assert_eq!((-1i128).encode(&mut buf).unwrap(), 16);
        assert_eq!(i128::decode(&buf).unwrap(), (-1, 16));
        i16::MIN.encode(&mut buf).unwrap();
        assert_eq!(i16::decode(&buf).unwrap(), (i16::MIN, 2));
    }

    #[test]
    fn bool_encode_decode() {
        let mut buf = [0u8; 1];
        true.encode(&mut buf).unwrap();
        assert_eq!(buf[0], 1);
        assert_eq!(bool::decode(&buf).unwrap(), (true, 1));
        false.encode(&mut buf).unwrap();
        assert_eq!(buf[0], 0);
        assert_eq!(bool::decode(&buf).unwrap(), (false, 1));
    }

    #[test]
    fn bool_rejects_garbage() {
        let buf = [2u8];
        assert!(bool::decode(&buf).is_err());
    }

    #[test]
    fn byte_array_roundtrip() {
        let src: [u8; 8] = *b"HOPPER!!";
        let mut buf = [0u8; 16];
        let n = src.encode(&mut buf).unwrap();
        assert_eq!(n, 8);
        let (back, consumed) = <[u8; 8]>::decode(&buf).unwrap();
        assert_eq!(consumed, 8);
        assert_eq!(back, src);
    }

    #[test]
    fn option_none_encodes_to_one_byte() {
        let mut buf = [0u8; 16];
        let n = Option::<u64>::None.encode(&mut buf).unwrap();
        assert_eq!(n, 1);
        assert_eq!(buf[0], 0);
        let (back, c) = <Option<u64>>::decode(&buf).unwrap();
        assert_eq!(back, None);
        assert_eq!(c, 1);
    }

    #[test]
    fn option_some_includes_inner_payload() {
        let mut buf = [0u8; 16];
        let n = Option::<u64>::Some(0xAAAA_BBBB_CCCC_DDDD).encode(&mut buf).unwrap();
        assert_eq!(n, 9);
        assert_eq!(buf[0], 1);
        let (back, c) = <Option<u64>>::decode(&buf).unwrap();
        assert_eq!(back, Some(0xAAAA_BBBB_CCCC_DDDD));
        assert_eq!(c, 9);
    }

    #[test]
    fn option_rejects_invalid_tag() {
        let buf = [7u8, 0, 0, 0, 0, 0, 0, 0, 0];
        assert!(<Option<u64>>::decode(&buf).is_err());
    }

    #[test]
    fn decode_rejects_short_input() {
        let empty: [u8; 0] = [];
        assert!(u8::decode(&empty).is_err());
        assert!(bool::decode(&empty).is_err());
        assert!(u32::decode(&[1, 2, 3]).is_err());
        assert!(<[u8; 4]>::decode(&[1, 2, 3]).is_err());
        assert!(<Option<u64>>::decode(&empty).is_err());
        // Valid `Some` tag, but not enough bytes for the inner u64.
        assert!(<Option<u64>>::decode(&[1, 0, 0]).is_err());
    }

    #[test]
    fn encode_rejects_short_output() {
        let mut empty: [u8; 0] = [];
        assert!(7u8.encode(&mut empty).is_err());
        assert!(true.encode(&mut empty).is_err());
        assert!(Option::<u32>::None.encode(&mut empty).is_err());
        assert!(7u64.encode(&mut [0u8; 7]).is_err());
        assert!([1u8; 4].encode(&mut [0u8; 3]).is_err());
        // Room for the tag byte but not for the inner u32.
        assert!(Option::<u32>::Some(7).encode(&mut [0u8; 4]).is_err());
    }

    #[test]
    fn tail_length_prefix_roundtrip() {
        let mut data = [0u8; 64];
        let body_end = 24usize;
        let tail_value: u64 = 0x1234_5678_9ABC_DEF0;
        let written = write_tail(&mut data, body_end, &tail_value).unwrap();
        assert_eq!(written, 8);
        let read_len = read_tail_len(&data, body_end).unwrap();
        assert_eq!(read_len, 8);
        let back: u64 = read_tail::<u64>(&data, body_end).unwrap();
        assert_eq!(back, tail_value);
    }

    #[test]
    fn tail_decode_rejects_excess_payload() {
        let mut data = [0u8; 32];
        let body_end = 16usize;
        // Declare an 8-byte payload but store a 4-byte u32 followed by junk.
        data[body_end..body_end + 4].copy_from_slice(&8u32.to_le_bytes());
        data[body_end + 4..body_end + 8].copy_from_slice(&0x1122_3344u32.to_le_bytes());
        data[body_end + 8..body_end + 12].copy_from_slice(&0xFFu32.to_le_bytes());
        let result = read_tail::<u32>(&data, body_end);
        assert!(result.is_err());
    }

    #[test]
    fn tail_bounds_check_on_truncated_buffer() {
        let data = [0u8; 10];
        assert!(read_tail_len(&data, 16).is_err());
        assert!(tail_payload(&data, 16).is_err());
    }

    #[test]
    fn tail_offset_overflow_is_rejected() {
        let data = [0u8; 8];
        assert!(read_tail_len(&data, usize::MAX).is_err());
        assert!(tail_payload(&data, usize::MAX).is_err());
        let mut out = [0u8; 8];
        assert!(write_tail(&mut out, usize::MAX, &1u8).is_err());
    }

    #[test]
    fn tail_payload_rejects_declared_len_past_end() {
        let mut data = [0u8; 12];
        data[..4].copy_from_slice(&100u32.to_le_bytes());
        assert!(tail_payload(&data, 0).is_err());
        assert!(read_tail::<u64>(&data, 0).is_err());
    }

    #[test]
    fn write_tail_rejects_short_buffers() {
        // No room for the 4-byte length prefix.
        assert!(write_tail(&mut [0u8; 3], 0, &1u8).is_err());
        // The prefix fits, but the 8-byte payload does not.
        assert!(write_tail(&mut [0u8; 10], 0, &1u64).is_err());
    }

    #[test]
    fn option_tail_roundtrip() {
        let mut data = [0u8; 32];
        let written = write_tail(&mut data, 4, &Some(0xCAFEu16)).unwrap();
        assert_eq!(written, 3);
        assert_eq!(read_tail::<Option<u16>>(&data, 4).unwrap(), Some(0xCAFE));
        write_tail(&mut data, 4, &Option::<u16>::None).unwrap();
        assert_eq!(read_tail_len(&data, 4).unwrap(), 1);
        assert_eq!(read_tail::<Option<u16>>(&data, 4).unwrap(), None);
    }

    #[test]
    fn max_encoded_len_matches_actual_encode_size() {
        let mut buf = [0u8; 32];
        assert_eq!(0u32.encode(&mut buf).unwrap(), u32::MAX_ENCODED_LEN);
        assert_eq!(0u64.encode(&mut buf).unwrap(), u64::MAX_ENCODED_LEN);
        assert_eq!(true.encode(&mut buf).unwrap(), bool::MAX_ENCODED_LEN);
        assert_eq!([0u8; 7].encode(&mut buf).unwrap(), <[u8; 7]>::MAX_ENCODED_LEN);
        assert_eq!(Option::<u32>::None.encode(&mut buf).unwrap(), 1);
        assert_eq!(
            Option::<u32>::Some(0).encode(&mut buf).unwrap(),
            <Option<u32>>::MAX_ENCODED_LEN
        );
    }
}