1#![allow(clippy::integer_arithmetic)]
2use {
4 borsh::{
5 maybestd::io::{Error, Write},
6 schema::{BorshSchema, Declaration, Definition, Fields},
7 BorshDeserialize, BorshSerialize,
8 },
9 std::collections::HashMap,
10};
11
12fn get_declaration_packed_len(
14 declaration: &str,
15 definitions: &HashMap<Declaration, Definition>,
16) -> usize {
17 match definitions.get(declaration) {
18 Some(Definition::Array { length, elements }) => {
19 *length as usize * get_declaration_packed_len(elements, definitions)
20 }
21 Some(Definition::Enum { variants }) => {
22 1 + variants
23 .iter()
24 .map(|(_, declaration)| get_declaration_packed_len(declaration, definitions))
25 .max()
26 .unwrap_or(0)
27 }
28 Some(Definition::Struct { fields }) => match fields {
29 Fields::NamedFields(named_fields) => named_fields
30 .iter()
31 .map(|(_, declaration)| get_declaration_packed_len(declaration, definitions))
32 .sum(),
33 Fields::UnnamedFields(declarations) => declarations
34 .iter()
35 .map(|declaration| get_declaration_packed_len(declaration, definitions))
36 .sum(),
37 Fields::Empty => 0,
38 },
39 Some(Definition::Sequence {
40 elements: _elements,
41 }) => panic!("Missing support for Definition::Sequence"),
42 Some(Definition::Tuple { elements }) => elements
43 .iter()
44 .map(|element| get_declaration_packed_len(element, definitions))
45 .sum(),
46 None => match declaration {
47 "bool" | "u8" | "i8" => 1,
48 "u16" | "i16" => 2,
49 "u32" | "i32" => 4,
50 "u64" | "i64" => 8,
51 "u128" | "i128" => 16,
52 "nil" => 0,
53 _ => panic!("Missing primitive type: {}", declaration),
54 },
55 }
56}
57
58pub fn get_packed_len<S: BorshSchema>() -> usize {
63 let schema_container = S::schema_container();
64 get_declaration_packed_len(&schema_container.declaration, &schema_container.definitions)
65}
66
67pub fn try_from_slice_unchecked<T: BorshDeserialize>(data: &[u8]) -> Result<T, Error> {
77 let mut data_mut = data;
78 let result = T::deserialize(&mut data_mut)?;
79 Ok(result)
80}
81
/// A [`Write`] sink that discards every byte it receives and only tallies how many
/// were written. It is used to measure serialized size without an output buffer.
#[derive(Default)]
struct WriteCounter {
    /// Running total of bytes passed to `write`.
    count: usize,
}

impl Write for WriteCounter {
    /// Accepts the whole of `buf`, adding its length to the running total.
    fn write(&mut self, buf: &[u8]) -> Result<usize, Error> {
        self.count += buf.len();
        Ok(buf.len())
    }

    /// Nothing is buffered, so flushing is a no-op.
    fn flush(&mut self) -> Result<(), Error> {
        Ok(())
    }
}
99
100pub fn get_instance_packed_len<T: BorshSerialize>(instance: &T) -> Result<usize, Error> {
107 let mut counter = WriteCounter::default();
108 instance.serialize(&mut counter)?;
109 Ok(counter.count)
110}
111
#[cfg(test)]
mod tests {
    use {
        super::*,
        borsh::{maybestd::io::ErrorKind, BorshSchema, BorshSerialize},
        std::{collections::HashMap, mem::size_of},
    };

    /// Enum whose variants carry payloads of different sizes (nothing, a `u32`, and a
    /// `u64` plus `[u8; 8]`), used to check that an enum's packed length is its tag
    /// byte plus its largest variant.
    #[derive(PartialEq, Clone, Debug, BorshSerialize, BorshDeserialize, BorshSchema)]
    enum TestEnum {
        NoValue,
        Value(u32),
        StructValue {
            #[allow(dead_code)]
            number: u64,
            #[allow(dead_code)]
            array: [u8; 8],
        },
    }

    impl Default for TestEnum {
        fn default() -> Self {
            Self::NoValue
        }
    }

    /// Fixed-size struct mixing an array, integers, a tuple, an enum and a bool, so
    /// every schema branch except sequences contributes to its packed length.
    #[derive(Default, BorshSerialize, BorshDeserialize, BorshSchema)]
    struct TestStruct {
        pub array: [u64; 16],
        pub number_u128: u128,
        pub number_u32: u32,
        pub tuple: (u8, u16),
        pub enumeration: TestEnum,
        pub r#bool: bool,
    }

    /// Fixed-size element stored in `Parent`'s vector.
    #[derive(Debug, PartialEq, BorshSerialize, BorshDeserialize, BorshSchema)]
    struct Child {
        pub data: [u8; 64],
    }

    /// Variable-length container. The tests below size its encoding as a 4-byte
    /// prefix (`4 +`) followed by the children.
    #[derive(Debug, PartialEq, BorshSerialize, BorshDeserialize, BorshSchema)]
    struct Parent {
        pub data: Vec<Child>,
    }

    #[test]
    fn unchecked_deserialization() {
        let data = vec![
            Child { data: [0u8; 64] },
            Child { data: [1u8; 64] },
            Child { data: [2u8; 64] },
        ];
        let parent = Parent { data };

        // Buffer sized exactly for the encoded parent: strict and unchecked
        // deserialization must both round-trip it.
        let mut byte_vec = vec![0u8; 4 + get_packed_len::<Child>() * 3];
        let mut bytes = byte_vec.as_mut_slice();
        parent.serialize(&mut bytes).unwrap();
        let deserialized = Parent::try_from_slice(&byte_vec).unwrap();
        assert_eq!(deserialized, parent);
        let deserialized = try_from_slice_unchecked::<Parent>(&byte_vec).unwrap();
        assert_eq!(deserialized, parent);

        // Oversized buffer (room for 10 children, only 3 written): trailing zero
        // bytes remain, so strict `try_from_slice` rejects it with `InvalidData`
        // while `try_from_slice_unchecked` ignores the leftovers.
        let mut byte_vec = vec![0u8; 4 + get_packed_len::<Child>() * 10];
        let mut bytes = byte_vec.as_mut_slice();
        parent.serialize(&mut bytes).unwrap();
        let err = Parent::try_from_slice(&byte_vec).unwrap_err();
        assert_eq!(err.kind(), ErrorKind::InvalidData);
        let deserialized = try_from_slice_unchecked::<Parent>(&byte_vec).unwrap();
        assert_eq!(deserialized, parent);
    }

    #[test]
    fn packed_len() {
        // Tag byte + largest variant (`StructValue`: u64 + [u8; 8]).
        assert_eq!(
            get_packed_len::<TestEnum>(),
            size_of::<u8>() + size_of::<u64>() + size_of::<u8>() * 8
        );
        // Sum of every field; the enum field contributes its own packed length.
        assert_eq!(
            get_packed_len::<TestStruct>(),
            size_of::<u64>() * 16
                + size_of::<bool>()
                + size_of::<u128>()
                + size_of::<u32>()
                + size_of::<u8>()
                + size_of::<u16>()
                + get_packed_len::<TestEnum>()
        );
    }

    /// The schema-derived size must agree with the size of an actual serialization,
    /// using the largest enum variant so the two are directly comparable.
    #[test]
    fn instance_packed_len_matches_packed_len() {
        let enumeration = TestEnum::StructValue {
            number: u64::MAX,
            array: [255; 8],
        };
        assert_eq!(
            get_packed_len::<TestEnum>(),
            get_instance_packed_len(&enumeration).unwrap(),
        );
        let test_struct = TestStruct {
            enumeration,
            ..TestStruct::default()
        };
        assert_eq!(
            get_packed_len::<TestStruct>(),
            get_instance_packed_len(&test_struct).unwrap(),
        );
        assert_eq!(
            get_packed_len::<u8>(),
            get_instance_packed_len(&0u8).unwrap(),
        );
        assert_eq!(
            get_packed_len::<u16>(),
            get_instance_packed_len(&0u16).unwrap(),
        );
        assert_eq!(
            get_packed_len::<u32>(),
            get_instance_packed_len(&0u32).unwrap(),
        );
        assert_eq!(
            get_packed_len::<u64>(),
            get_instance_packed_len(&0u64).unwrap(),
        );
        assert_eq!(
            get_packed_len::<u128>(),
            get_instance_packed_len(&0u128).unwrap(),
        );
        assert_eq!(
            get_packed_len::<[u8; 10]>(),
            get_instance_packed_len(&[0u8; 10]).unwrap(),
        );
        assert_eq!(
            get_packed_len::<(i8, i16, i32, i64, i128)>(),
            get_instance_packed_len(&(i8::MAX, i16::MAX, i32::MAX, i64::MAX, i128::MAX)).unwrap(),
        );
    }

    /// `get_instance_packed_len` handles a `Vec`, which `get_packed_len` cannot:
    /// expected size is the 4-byte prefix plus one fixed-size `Child` per element.
    #[test]
    fn instance_packed_len_with_vec() {
        let data = vec![
            Child { data: [0u8; 64] },
            Child { data: [1u8; 64] },
            Child { data: [2u8; 64] },
            Child { data: [3u8; 64] },
            Child { data: [4u8; 64] },
            Child { data: [5u8; 64] },
        ];
        let parent = Parent { data };
        assert_eq!(
            get_instance_packed_len(&parent).unwrap(),
            4 + parent.data.len() * get_packed_len::<Child>()
        );
    }

    #[derive(Debug, PartialEq, BorshSerialize, BorshDeserialize, BorshSchema)]
    struct StructWithHashMap {
        data: HashMap<String, TestEnum>,
    }

    /// Map entries with differently sized keys and values: the total must be the
    /// 4-byte prefix plus the individually measured size of every key and value.
    #[test]
    fn instance_packed_len_with_varying_sizes_in_hashmap() {
        let mut data = HashMap::new();
        let string1 = "the first string, it's actually really really long".to_string();
        let enum1 = TestEnum::NoValue;
        let string2 = "second string, shorter".to_string();
        let enum2 = TestEnum::Value(u32::MAX);
        let string3 = "third".to_string();
        let enum3 = TestEnum::StructValue {
            number: 0,
            array: [0; 8],
        };
        data.insert(string1.clone(), enum1.clone());
        data.insert(string2.clone(), enum2.clone());
        data.insert(string3.clone(), enum3.clone());
        let instance = StructWithHashMap { data };
        assert_eq!(
            get_instance_packed_len(&instance).unwrap(),
            4 + get_instance_packed_len(&string1).unwrap()
                + get_instance_packed_len(&enum1).unwrap()
                + get_instance_packed_len(&string2).unwrap()
                + get_instance_packed_len(&enum2).unwrap()
                + get_instance_packed_len(&string3).unwrap()
                + get_instance_packed_len(&enum3).unwrap()
        );
    }
}