Skip to main content

kaigan/types/
remainder_vec.rs

1use std::fmt::Debug;
2#[cfg(not(feature = "anchor"))]
3use std::io::Write;
4use std::ops::{Deref, DerefMut};
5
6#[cfg(feature = "anchor")]
7use anchor_lang::prelude::{
8    AnchorDeserialize as CrateDeserialize, AnchorSerialize as CrateSerialize,
9};
10#[cfg(all(not(feature = "anchor"), not(feature = "borsh-v1")))]
11use borsh::{BorshDeserialize as CrateDeserialize, BorshSerialize as CrateSerialize};
12#[cfg(not(feature = "anchor"))]
13use borsh_1_5::io::Read;
14#[cfg(all(not(feature = "anchor"), feature = "borsh-v1"))]
15use borsh_1_5::{BorshDeserialize as CrateDeserialize, BorshSerialize as CrateSerialize};
16
17/// A vector that deserializes from a stream of bytes.
18///
19/// This is useful for deserializing a vector that does not have
20/// a length prefix. In order to determine how many elements to deserialize,
21/// the type of the elements must implement the trait `Sized`.
22#[derive(Clone, Eq, PartialEq)]
23#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
24#[cfg_attr(
25    feature = "anchor",
26    derive(
27        anchor_lang::prelude::AnchorSerialize,
28        anchor_lang::prelude::AnchorDeserialize
29    )
30)]
31pub struct RemainderVec<T: CrateSerialize + CrateDeserialize>(Vec<T>);
32
33/// Dereferences the inner `Vec` type.
34impl<T> Deref for RemainderVec<T>
35where
36    T: CrateSerialize + CrateDeserialize,
37{
38    type Target = Vec<T>;
39
40    fn deref(&self) -> &Self::Target {
41        &self.0
42    }
43}
44
45/// Dereferences the inner `Vec` type as mutable.
46impl<T> DerefMut for RemainderVec<T>
47where
48    T: CrateSerialize + CrateDeserialize,
49{
50    fn deref_mut(&mut self) -> &mut Self::Target {
51        &mut self.0
52    }
53}
54
55/// `Debug` implementation for `RemainderVec`.
56///
57/// This implementation simply forwards to the inner `Vec` type.
58impl<T> Debug for RemainderVec<T>
59where
60    T: CrateSerialize + CrateDeserialize + Debug,
61{
62    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
63        f.write_fmt(format_args!("{:?}", self.0))
64    }
65}
66
67#[cfg(not(feature = "anchor"))]
68impl<T> CrateDeserialize for RemainderVec<T>
69where
70    T: CrateSerialize + CrateDeserialize,
71{
72    fn deserialize_reader<R: Read>(reader: &mut R) -> borsh_1_5::io::Result<Self> {
73        let mut data = Vec::new();
74        reader.read_to_end(&mut data)?;
75
76        let mut items: Vec<T> = Vec::new();
77        let mut cursor = std::io::Cursor::new(&data);
78
79        while (cursor.position() as usize) < data.len() {
80            match T::deserialize_reader(&mut cursor) {
81                Ok(item) => items.push(item),
82                Err(_) => {
83                    return Err(borsh_1_5::io::Error::new(
84                        borsh_1_5::io::ErrorKind::InvalidData,
85                        "unexpected trailing bytes",
86                    ));
87                }
88            }
89        }
90
91        Ok(Self(items))
92    }
93}
94
95#[cfg(not(feature = "anchor"))]
96impl<T> CrateSerialize for RemainderVec<T>
97where
98    T: CrateSerialize + CrateDeserialize,
99{
100    fn serialize<W: Write>(&self, writer: &mut W) -> borsh_1_5::io::Result<()> {
101        // serialize each item without adding a prefix for the length
102        for item in self.0.iter() {
103            item.serialize(writer)?;
104        }
105
106        Ok(())
107    }
108}
109
#[cfg(test)]
mod tests {
    use super::*;

    /// Encodes `values` as consecutive little-endian `u64`s with no prefix.
    fn encode_u64s(values: &[u64]) -> Vec<u8> {
        values.iter().flat_map(|v| v.to_le_bytes()).collect()
    }

    #[test]
    fn deserialize_data() {
        // three u64 values, 24 bytes in total
        let bytes = encode_u64s(&[5, 15, 7]);

        let decoded = RemainderVec::<u64>::try_from_slice(&bytes).unwrap();

        assert_eq!(decoded.len(), 3);
        assert_eq!(decoded.as_slice(), &[5, 15, 7]);
    }

    #[test]
    fn serialize_data() {
        let original = RemainderVec::<u32>((0..10).collect());

        let mut bytes = Vec::new();
        original.serialize(&mut bytes).unwrap();
        let restored = RemainderVec::<u32>::try_from_slice(&bytes).unwrap();

        assert_eq!(restored.len(), original.len());
        assert_eq!(restored.as_slice(), original.as_slice());
    }

    #[test]
    fn fail_deserialize_invalid_data_length() {
        // three u64 values followed by 4 stray zero bytes
        let mut bytes = encode_u64s(&[5, 15, 7]);
        bytes.extend_from_slice(&[0u8; 4]);

        let error = RemainderVec::<u64>::try_from_slice(&bytes).unwrap_err();

        assert_eq!(error.kind(), borsh_1_5::io::ErrorKind::InvalidData);
    }
}