1use data_size::DataSize;
2use serde::{Deserialize, Deserializer};
3use std::fmt;
4
5use crate::{types::TypeInner, CandidType};
6
/// Sentinel value for a `BoundedVec` limit parameter meaning "no limit".
///
/// It is `usize::MAX`, so the `>`-based data-size checks performed while
/// deserializing a `BoundedVec` can never fail for a parameter set to this
/// value, and the element-count check is effectively disabled.
pub const UNBOUNDED: usize = usize::MAX;
9
/// A `Vec<T>` wrapper whose `Deserialize` implementation rejects input that
/// exceeds compile-time limits:
///
/// * `MAX_ALLOWED_LEN` — maximum number of elements;
/// * `MAX_ALLOWED_TOTAL_DATA_SIZE` — maximum sum of `DataSize::data_size()`
///   over all elements;
/// * `MAX_ALLOWED_ELEMENT_DATA_SIZE` — maximum `DataSize::data_size()` of any
///   single element.
///
/// Passing [`UNBOUNDED`] for a parameter disables that limit. `BoundedVec::new`
/// panics if all three parameters are `UNBOUNDED`, but it does not check the
/// supplied data against the limits; only deserialization does. The Candid
/// serialization encodes the inner vector as a plain `vec`, with no checks.
///
/// NOTE(review): the derived `Default` builds an empty value without calling
/// `new`, so it bypasses the "at least one bound" assertion.
#[derive(Clone, Eq, PartialEq, Debug, Default)]
pub struct BoundedVec<
    const MAX_ALLOWED_LEN: usize,
    const MAX_ALLOWED_TOTAL_DATA_SIZE: usize,
    const MAX_ALLOWED_ELEMENT_DATA_SIZE: usize,
    T,
>(Vec<T>);
35
36impl<
37 const MAX_ALLOWED_LEN: usize,
38 const MAX_ALLOWED_TOTAL_DATA_SIZE: usize,
39 const MAX_ALLOWED_ELEMENT_DATA_SIZE: usize,
40 T: CandidType,
41 > CandidType
42 for BoundedVec<MAX_ALLOWED_LEN, MAX_ALLOWED_TOTAL_DATA_SIZE, MAX_ALLOWED_ELEMENT_DATA_SIZE, T>
43{
44 fn _ty() -> super::Type {
45 TypeInner::Vec(T::_ty()).into()
46 }
47
48 fn idl_serialize<S>(&self, serializer: S) -> Result<(), S::Error>
49 where
50 S: super::Serializer,
51 {
52 self.0.idl_serialize(serializer)
53 }
54}
55
56impl<
57 const MAX_ALLOWED_LEN: usize,
58 const MAX_ALLOWED_TOTAL_DATA_SIZE: usize,
59 const MAX_ALLOWED_ELEMENT_DATA_SIZE: usize,
60 T,
61 > BoundedVec<MAX_ALLOWED_LEN, MAX_ALLOWED_TOTAL_DATA_SIZE, MAX_ALLOWED_ELEMENT_DATA_SIZE, T>
62{
63 pub fn new(data: Vec<T>) -> Self {
64 assert!(
65 MAX_ALLOWED_LEN != UNBOUNDED
66 || MAX_ALLOWED_TOTAL_DATA_SIZE != UNBOUNDED
67 || MAX_ALLOWED_ELEMENT_DATA_SIZE != UNBOUNDED,
68 "BoundedVec must be bounded by at least one parameter."
69 );
70
71 Self(data)
72 }
73
74 pub fn get(&self) -> &Vec<T> {
75 &self.0
76 }
77}
78
79impl<
80 'de,
81 const MAX_ALLOWED_LEN: usize,
82 const MAX_ALLOWED_TOTAL_DATA_SIZE: usize,
83 const MAX_ALLOWED_ELEMENT_DATA_SIZE: usize,
84 T: Deserialize<'de> + DataSize,
85 > Deserialize<'de>
86 for BoundedVec<MAX_ALLOWED_LEN, MAX_ALLOWED_TOTAL_DATA_SIZE, MAX_ALLOWED_ELEMENT_DATA_SIZE, T>
87{
88 fn deserialize<D: Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
89 struct SeqVisitor<
90 const MAX_ALLOWED_LEN: usize,
91 const MAX_ALLOWED_TOTAL_DATA_SIZE: usize,
92 const MAX_ALLOWED_ELEMENT_DATA_SIZE: usize,
93 T,
94 > {
95 _marker: std::marker::PhantomData<T>,
96 }
97
98 use serde::de::{SeqAccess, Visitor};
99
100 impl<
101 'de,
102 const MAX_ALLOWED_LEN: usize,
103 const MAX_ALLOWED_TOTAL_DATA_SIZE: usize,
104 const MAX_ALLOWED_ELEMENT_DATA_SIZE: usize,
105 T: Deserialize<'de> + DataSize,
106 > Visitor<'de>
107 for SeqVisitor<
108 MAX_ALLOWED_LEN,
109 MAX_ALLOWED_TOTAL_DATA_SIZE,
110 MAX_ALLOWED_ELEMENT_DATA_SIZE,
111 T,
112 >
113 {
114 type Value = BoundedVec<
115 MAX_ALLOWED_LEN,
116 MAX_ALLOWED_TOTAL_DATA_SIZE,
117 MAX_ALLOWED_ELEMENT_DATA_SIZE,
118 T,
119 >;
120
121 fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
122 write!(
123 formatter,
124 "{}",
125 describe_sequence(
126 MAX_ALLOWED_LEN,
127 MAX_ALLOWED_TOTAL_DATA_SIZE,
128 MAX_ALLOWED_ELEMENT_DATA_SIZE,
129 )
130 )
131 }
132
133 fn visit_seq<S>(self, mut seq: S) -> Result<Self::Value, S::Error>
134 where
135 S: SeqAccess<'de>,
136 {
137 let mut total_data_size = 0;
138 let mut elements = if MAX_ALLOWED_LEN == UNBOUNDED {
139 Vec::new()
140 } else {
141 Vec::with_capacity(MAX_ALLOWED_LEN)
142 };
143 while let Some(element) = seq.next_element::<T>()? {
144 if elements.len() >= MAX_ALLOWED_LEN {
145 return Err(serde::de::Error::custom(format!(
146 "The number of elements exceeds maximum allowed {MAX_ALLOWED_LEN}"
147 )));
148 }
149 let new_element_data_size = element.data_size();
151 if new_element_data_size > MAX_ALLOWED_ELEMENT_DATA_SIZE {
152 return Err(serde::de::Error::custom(format!(
153 "The single element data size exceeds maximum allowed {MAX_ALLOWED_ELEMENT_DATA_SIZE}"
154 )));
155 }
156 let new_total_data_size = total_data_size + new_element_data_size;
159 if new_total_data_size > MAX_ALLOWED_TOTAL_DATA_SIZE {
160 return Err(serde::de::Error::custom(format!(
161 "The total data size exceeds maximum allowed {MAX_ALLOWED_TOTAL_DATA_SIZE}"
162 )));
163 }
164 total_data_size = new_total_data_size;
165 elements.push(element);
166 }
167 Ok(BoundedVec::new(elements))
168 }
169 }
170
171 deserializer.deserialize_seq(SeqVisitor::<
172 MAX_ALLOWED_LEN,
173 MAX_ALLOWED_TOTAL_DATA_SIZE,
174 MAX_ALLOWED_ELEMENT_DATA_SIZE,
175 T,
176 > {
177 _marker: std::marker::PhantomData,
178 })
179 }
180}
181
182fn describe_sequence(
183 max_allowed_len: usize,
184 max_allowed_total_data_size: usize,
185 max_allowed_element_data_size: usize,
186) -> String {
187 let mut msg = String::new();
188 if max_allowed_len != UNBOUNDED {
189 msg.push_str(&format!("max {max_allowed_len} elements"));
190 };
191 if max_allowed_total_data_size != UNBOUNDED {
192 if !msg.is_empty() {
193 msg.push_str(", ");
194 }
195 msg.push_str(&format!("max {max_allowed_total_data_size} bytes total"));
196 };
197 if max_allowed_element_data_size != UNBOUNDED {
198 if !msg.is_empty() {
199 msg.push_str(", ");
200 }
201 msg.push_str(&format!(
202 "max {max_allowed_element_data_size} bytes per element"
203 ));
204 };
205 format!("a sequence with {msg}")
206}
207
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{Decode, Encode};

    /// `describe_sequence` mentions only the limits that are not `UNBOUNDED`,
    /// comma-separated, in the order element count, total bytes, bytes per element.
    #[test]
    fn test_describe_sequence() {
        assert_eq!(
            describe_sequence(42, UNBOUNDED, UNBOUNDED),
            "a sequence with max 42 elements".to_string()
        );
        assert_eq!(
            describe_sequence(UNBOUNDED, 256, UNBOUNDED),
            "a sequence with max 256 bytes total".to_string(),
        );
        assert_eq!(
            describe_sequence(UNBOUNDED, UNBOUNDED, 64),
            "a sequence with max 64 bytes per element".to_string(),
        );
        assert_eq!(
            describe_sequence(42, 256, UNBOUNDED),
            "a sequence with max 42 elements, max 256 bytes total".to_string(),
        );
        assert_eq!(
            describe_sequence(42, UNBOUNDED, 64),
            "a sequence with max 42 elements, max 64 bytes per element".to_string(),
        );
        assert_eq!(
            describe_sequence(UNBOUNDED, 256, 64),
            "a sequence with max 256 bytes total, max 64 bytes per element".to_string(),
        );
        assert_eq!(
            describe_sequence(42, 256, 64),
            "a sequence with max 42 elements, max 256 bytes total, max 64 bytes per element"
                .to_string(),
        );
    }

    /// A `BoundedVec` whose three limits are all `UNBOUNDED` is rejected by the
    /// assertion in `BoundedVec::new`.
    #[test]
    #[should_panic]
    fn test_not_bounded_vector_fails() {
        type NotBoundedVec = BoundedVec<UNBOUNDED, UNBOUNDED, UNBOUNDED, u8>;

        let _ = NotBoundedVec::new(vec![1, 2, 3]);
    }

    /// Round-trips vectors of `TEST_START..=TEST_END` elements through Candid:
    /// decoding succeeds up to `MAX_ALLOWED_LEN` elements and fails with the
    /// element-count error beyond it.
    #[test]
    fn test_bounded_vector_lengths() {
        type BoundedLen = BoundedVec<MAX_ALLOWED_LEN, UNBOUNDED, UNBOUNDED, u8>;

        const MAX_ALLOWED_LEN: usize = 30;
        const TEST_START: usize = 20;
        const TEST_END: usize = 40;
        for i in TEST_START..=TEST_END {
            // `new` does not check the limits, so oversized values can still be
            // built and encoded; only decoding enforces them.
            let data = BoundedLen::new(vec![42; i]);

            let bytes = Encode!(&data).unwrap();
            let result = Decode!(&bytes, BoundedLen);

            if i <= MAX_ALLOWED_LEN {
                assert!(result.is_ok());
                assert_eq!(result.unwrap(), data);
            } else {
                assert!(result.is_err());
                let error = result.unwrap_err();
                assert!(
                    format!("{error:?}").contains(&format!(
                        "Deserialize error: The number of elements exceeds maximum allowed {MAX_ALLOWED_LEN}"
                    )),
                    "Actual: {}",
                    error
                );
            }
        }
    }

    /// Builds vectors of `Vec<u8>` elements whose `data_size()` is exactly
    /// `ELEMENT_SIZE` (the `Vec<u8>` header plus the payload bytes) and checks
    /// that decoding succeeds only while the total stays within
    /// `MAX_ALLOWED_TOTAL_DATA_SIZE`.
    #[test]
    fn test_bounded_vector_total_data_sizes() {
        const MAX_ALLOWED_TOTAL_DATA_SIZE: usize = 100;
        const ELEMENT_SIZE: usize = 37;
        // The limit must not be a multiple of the element size, so no vector
        // lands exactly on the limit.
        assert_ne!(MAX_ALLOWED_TOTAL_DATA_SIZE % ELEMENT_SIZE, 0);
        for aimed_total_size in 64..=256 {
            type BoundedSize =
                BoundedVec<UNBOUNDED, MAX_ALLOWED_TOTAL_DATA_SIZE, UNBOUNDED, Vec<u8>>;
            let element = vec![b'a'; ELEMENT_SIZE - std::mem::size_of::<Vec<u8>>()];
            let elements_count = aimed_total_size / element.data_size();
            let data = BoundedSize::new(vec![element; elements_count]);
            // NOTE(review): this includes the outer `Vec`'s own header, while
            // the deserializer sums only the element sizes. The two agree for
            // these constants (at most 2 elements pass either way), but they
            // could diverge for other choices of limit and element size.
            let actual_total_size = data.get().data_size();

            let bytes = Encode!(&data).unwrap();
            let result = Decode!(&bytes, BoundedSize);

            if actual_total_size <= MAX_ALLOWED_TOTAL_DATA_SIZE {
                assert!(result.is_ok());
                assert_eq!(result.unwrap(), data);
            } else {
                assert!(result.is_err());
                let error = result.unwrap_err();
                assert!(
                    format!("{error:?}").contains(&format!(
                        "Deserialize error: The total data size exceeds maximum allowed {MAX_ALLOWED_TOTAL_DATA_SIZE}"
                    )),
                    "Actual: {}",
                    error
                );
            }
        }
    }

    /// Builds 42 `Vec<u8>` elements whose `data_size()` equals `element_size`
    /// (header plus payload) and checks that decoding fails with the
    /// per-element error once `element_size` exceeds the limit.
    #[test]
    fn test_bounded_vector_element_data_sizes() {
        const MAX_ALLOWED_ELEMENT_DATA_SIZE: usize = 100;
        for element_size in 64..=256 {
            type BoundedSize =
                BoundedVec<UNBOUNDED, UNBOUNDED, MAX_ALLOWED_ELEMENT_DATA_SIZE, Vec<u8>>;
            let element = vec![b'a'; element_size - std::mem::size_of::<Vec<u8>>()];
            let data = BoundedSize::new(vec![element; 42]);

            let bytes = Encode!(&data).unwrap();
            let result = Decode!(&bytes, BoundedSize);

            if element_size <= MAX_ALLOWED_ELEMENT_DATA_SIZE {
                assert!(result.is_ok());
                assert_eq!(result.unwrap(), data);
            } else {
                assert!(result.is_err());
                let error = result.unwrap_err();
                assert!(
                    format!("{error:?}").contains(&format!(
                        "Deserialize error: The single element data size exceeds maximum allowed {MAX_ALLOWED_ELEMENT_DATA_SIZE}"
                    )),
                    "Actual: {}",
                    error
                );
            }
        }
    }
}
368
369mod data_size {
370 pub trait DataSize {
374 fn data_size(&self) -> usize {
376 0
377 }
378 }
379
380 impl DataSize for u8 {
381 fn data_size(&self) -> usize {
382 std::mem::size_of::<u8>()
383 }
384 }
385
386 impl DataSize for [u8] {
387 fn data_size(&self) -> usize {
388 std::mem::size_of_val(self)
389 }
390 }
391
392 impl DataSize for u64 {
393 fn data_size(&self) -> usize {
394 std::mem::size_of::<u64>()
395 }
396 }
397
398 impl DataSize for &str {
399 fn data_size(&self) -> usize {
400 self.as_bytes().data_size()
401 }
402 }
403
404 impl DataSize for String {
405 fn data_size(&self) -> usize {
406 self.as_bytes().data_size()
407 }
408 }
409
410 impl<T: DataSize> DataSize for Vec<T> {
411 fn data_size(&self) -> usize {
412 std::mem::size_of::<Self>() + self.iter().map(|x| x.data_size()).sum::<usize>()
413 }
414 }
415
416 impl DataSize for ic_principal::Principal {
417 fn data_size(&self) -> usize {
418 self.as_slice().len()
419 }
420 }
421
422 #[cfg(test)]
423 mod tests {
424 use super::*;
425
426 #[test]
427 fn test_data_size_u8() {
428 assert_eq!(0_u8.data_size(), 1);
429 assert_eq!(42_u8.data_size(), 1);
430 }
431
432 #[test]
433 fn test_data_size_u8_slice() {
434 let a: [u8; 0] = [];
435 assert_eq!(a.data_size(), 0);
436 assert_eq!([1_u8].data_size(), 1);
437 assert_eq!([1_u8, 2_u8].data_size(), 2);
438 }
439
440 #[test]
441 fn test_data_size_u64() {
442 assert_eq!(0_u64.data_size(), 8);
443 assert_eq!(42_u64.data_size(), 8);
444 }
445
446 #[test]
447 fn test_data_size_u8_vec() {
448 let base = 24;
449 assert_eq!(Vec::<u8>::from([]).data_size(), base);
450 assert_eq!(Vec::<u8>::from([1]).data_size(), base + 1);
451 assert_eq!(Vec::<u8>::from([1, 2]).data_size(), base + 2);
452 }
453
454 #[test]
455 fn test_data_size_str() {
456 assert_eq!("a".data_size(), 1);
457 assert_eq!("ab".data_size(), 2);
458 }
459
460 #[test]
461 fn test_data_size_string() {
462 assert_eq!(String::from("a").data_size(), 1);
463 assert_eq!(String::from("ab").data_size(), 2);
464 for size_bytes in 0..1_024 {
465 assert_eq!(
466 String::from_utf8(vec![b'x'; size_bytes])
467 .unwrap()
468 .data_size(),
469 size_bytes
470 );
471 }
472 }
473 }
474}