gemachain_program/
borsh.rs

1#![allow(clippy::integer_arithmetic)]
2//! Borsh utils
3use {
4    borsh::{
5        maybestd::io::{Error, Write},
6        schema::{BorshSchema, Declaration, Definition, Fields},
7        BorshDeserialize, BorshSerialize,
8    },
9    std::collections::HashMap,
10};
11
12/// Get packed length for the given BorchSchema Declaration
13fn get_declaration_packed_len(
14    declaration: &str,
15    definitions: &HashMap<Declaration, Definition>,
16) -> usize {
17    match definitions.get(declaration) {
18        Some(Definition::Array { length, elements }) => {
19            *length as usize * get_declaration_packed_len(elements, definitions)
20        }
21        Some(Definition::Enum { variants }) => {
22            1 + variants
23                .iter()
24                .map(|(_, declaration)| get_declaration_packed_len(declaration, definitions))
25                .max()
26                .unwrap_or(0)
27        }
28        Some(Definition::Struct { fields }) => match fields {
29            Fields::NamedFields(named_fields) => named_fields
30                .iter()
31                .map(|(_, declaration)| get_declaration_packed_len(declaration, definitions))
32                .sum(),
33            Fields::UnnamedFields(declarations) => declarations
34                .iter()
35                .map(|declaration| get_declaration_packed_len(declaration, definitions))
36                .sum(),
37            Fields::Empty => 0,
38        },
39        Some(Definition::Sequence {
40            elements: _elements,
41        }) => panic!("Missing support for Definition::Sequence"),
42        Some(Definition::Tuple { elements }) => elements
43            .iter()
44            .map(|element| get_declaration_packed_len(element, definitions))
45            .sum(),
46        None => match declaration {
47            "bool" | "u8" | "i8" => 1,
48            "u16" | "i16" => 2,
49            "u32" | "i32" => 4,
50            "u64" | "i64" => 8,
51            "u128" | "i128" => 16,
52            "nil" => 0,
53            _ => panic!("Missing primitive type: {}", declaration),
54        },
55    }
56}
57
58/// Get the worst-case packed length for the given BorshSchema
59///
60/// Note: due to the serializer currently used by Borsh, this function cannot
61/// be used on-chain in the Gemachain BPF execution environment.
62pub fn get_packed_len<S: BorshSchema>() -> usize {
63    let schema_container = S::schema_container();
64    get_declaration_packed_len(&schema_container.declaration, &schema_container.definitions)
65}
66
67/// Deserializes without checking that the entire slice has been consumed
68///
69/// Normally, `try_from_slice` checks the length of the final slice to ensure
70/// that the deserialization uses up all of the bytes in the slice.
71///
72/// Note that there is a potential issue with this function. Any buffer greater than
73/// or equal to the expected size will properly deserialize. For example, if the
74/// user passes a buffer destined for a different type, the error won't get caught
75/// as easily.
76pub fn try_from_slice_unchecked<T: BorshDeserialize>(data: &[u8]) -> Result<T, Error> {
77    let mut data_mut = data;
78    let result = T::deserialize(&mut data_mut)?;
79    Ok(result)
80}
81
/// Byte-counting sink. Implements `Write` by discarding every buffer it is
/// given and tallying how many bytes passed through.
#[derive(Default)]
struct WriteCounter {
    /// Running total of bytes accepted by `write`.
    count: usize,
}

impl Write for WriteCounter {
    /// Accept the whole buffer, add its length to the tally, and report every
    /// byte as written.
    fn write(&mut self, buf: &[u8]) -> Result<usize, Error> {
        self.count += buf.len();
        Ok(buf.len())
    }

    /// Nothing is buffered, so flushing always succeeds.
    fn flush(&mut self) -> Result<(), Error> {
        Ok(())
    }
}
99
100/// Get the packed length for the serialized form of this object instance.
101///
102/// Useful when working with instances of types that contain a variable-length
103/// sequence, such as a Vec or HashMap.  Since it is impossible to know the packed
104/// length only from the type's schema, this can be used when an instance already
105/// exists, to figure out how much space to allocate in an account.
106pub fn get_instance_packed_len<T: BorshSerialize>(instance: &T) -> Result<usize, Error> {
107    let mut counter = WriteCounter::default();
108    instance.serialize(&mut counter)?;
109    Ok(counter.count)
110}
111
#[cfg(test)]
mod tests {
    use {
        super::*,
        borsh::{maybestd::io::ErrorKind, BorshSchema, BorshSerialize},
        std::{collections::HashMap, mem::size_of},
    };

    #[derive(PartialEq, Clone, Debug, BorshSerialize, BorshDeserialize, BorshSchema)]
    enum TestEnum {
        NoValue,
        Value(u32),
        StructValue {
            #[allow(dead_code)]
            number: u64,
            #[allow(dead_code)]
            array: [u8; 8],
        },
    }

    // for test simplicity
    impl Default for TestEnum {
        fn default() -> Self {
            Self::NoValue
        }
    }

    #[derive(Default, BorshSerialize, BorshDeserialize, BorshSchema)]
    struct TestStruct {
        pub array: [u64; 16],
        pub number_u128: u128,
        pub number_u32: u32,
        pub tuple: (u8, u16),
        pub enumeration: TestEnum,
        pub r#bool: bool,
    }

    #[derive(Debug, PartialEq, BorshSerialize, BorshDeserialize, BorshSchema)]
    struct Child {
        pub data: [u8; 64],
    }

    #[derive(Debug, PartialEq, BorshSerialize, BorshDeserialize, BorshSchema)]
    struct Parent {
        pub data: Vec<Child>,
    }

    #[test]
    fn unchecked_deserialization() {
        let parent = Parent {
            data: (0u8..3).map(|fill| Child { data: [fill; 64] }).collect(),
        };
        let child_len = get_packed_len::<Child>();

        // A buffer of exactly the serialized size is accepted by both the
        // checked and the unchecked deserializer.
        let mut exact = vec![0u8; 4 + child_len * 3];
        parent.serialize(&mut exact.as_mut_slice()).unwrap();
        assert_eq!(Parent::try_from_slice(&exact).unwrap(), parent);
        assert_eq!(try_from_slice_unchecked::<Parent>(&exact).unwrap(), parent);

        // With trailing bytes, only the unchecked deserializer succeeds.
        let mut oversized = vec![0u8; 4 + child_len * 10];
        parent.serialize(&mut oversized.as_mut_slice()).unwrap();
        let err = Parent::try_from_slice(&oversized).unwrap_err();
        assert_eq!(err.kind(), ErrorKind::InvalidData);
        assert_eq!(
            try_from_slice_unchecked::<Parent>(&oversized).unwrap(),
            parent
        );
    }

    #[test]
    fn packed_len() {
        // Tag byte plus the largest variant (`StructValue { u64, [u8; 8] }`).
        let expected_enum = size_of::<u8>() + size_of::<u64>() + 8 * size_of::<u8>();
        assert_eq!(get_packed_len::<TestEnum>(), expected_enum);

        let expected_struct = 16 * size_of::<u64>()
            + size_of::<u128>()
            + size_of::<u32>()
            + size_of::<u8>()
            + size_of::<u16>()
            + get_packed_len::<TestEnum>()
            + size_of::<bool>();
        assert_eq!(get_packed_len::<TestStruct>(), expected_struct);
    }

    /// Asserts that the schema-derived length of `T` equals the number of
    /// bytes `value` actually serializes to.
    fn assert_schema_len_matches_instance<T: BorshSchema + BorshSerialize>(value: &T) {
        assert_eq!(
            get_packed_len::<T>(),
            get_instance_packed_len(value).unwrap(),
        );
    }

    #[test]
    fn instance_packed_len_matches_packed_len() {
        // The largest enum variant, so this instance hits the worst case.
        let enumeration = TestEnum::StructValue {
            number: u64::MAX,
            array: [255; 8],
        };
        assert_schema_len_matches_instance(&enumeration);
        assert_schema_len_matches_instance(&TestStruct {
            enumeration,
            ..TestStruct::default()
        });
        assert_schema_len_matches_instance(&0u8);
        assert_schema_len_matches_instance(&0u16);
        assert_schema_len_matches_instance(&0u32);
        assert_schema_len_matches_instance(&0u64);
        assert_schema_len_matches_instance(&0u128);
        assert_schema_len_matches_instance(&[0u8; 10]);
        assert_schema_len_matches_instance(&(i8::MAX, i16::MAX, i32::MAX, i64::MAX, i128::MAX));
    }

    #[test]
    fn instance_packed_len_with_vec() {
        let parent = Parent {
            data: (0u8..6).map(|fill| Child { data: [fill; 64] }).collect(),
        };
        // 4-byte length prefix followed by each fixed-size child.
        let expected = 4 + parent.data.len() * get_packed_len::<Child>();
        assert_eq!(get_instance_packed_len(&parent).unwrap(), expected);
    }

    #[derive(Debug, PartialEq, BorshSerialize, BorshDeserialize, BorshSchema)]
    struct StructWithHashMap {
        data: HashMap<String, TestEnum>,
    }

    #[test]
    fn instance_packed_len_with_varying_sizes_in_hashmap() {
        let entries = vec![
            (
                "the first string, it's actually really really long".to_string(),
                TestEnum::NoValue,
            ),
            (
                "second string, shorter".to_string(),
                TestEnum::Value(u32::MAX),
            ),
            (
                "third".to_string(),
                TestEnum::StructValue {
                    number: 0,
                    array: [0; 8],
                },
            ),
        ];
        let instance = StructWithHashMap {
            data: entries.iter().cloned().collect(),
        };
        // 4-byte length prefix plus the serialized size of every key and value.
        let expected = 4 + entries
            .iter()
            .map(|(key, value)| {
                get_instance_packed_len(key).unwrap() + get_instance_packed_len(value).unwrap()
            })
            .sum::<usize>();
        assert_eq!(get_instance_packed_len(&instance).unwrap(), expected);
    }
}