Skip to main content

candid/types/
bounded_vec.rs

1use data_size::DataSize;
2use serde::{Deserialize, Deserializer};
3use std::fmt;
4
5use crate::{types::TypeInner, CandidType};
6
/// Sentinel value for a [`BoundedVec`] const generic parameter (maximum length, maximum total
/// data size, or maximum per-element data size) meaning that no limit is enforced for it.
pub const UNBOUNDED: usize = usize::MAX;

/// A `Vec<T>` whose candid decoding is bounded by up to three const generic parameters:
/// - `MAX_ALLOWED_LEN`: maximum number of elements;
/// - `MAX_ALLOWED_TOTAL_DATA_SIZE`: maximum sum of the elements' data sizes, in bytes;
/// - `MAX_ALLOWED_ELEMENT_DATA_SIZE`: maximum data size of any single element, in bytes.
///
/// Pass [`UNBOUNDED`] for a limit that should not be enforced; [`BoundedVec::new`] panics if all
/// three are unbounded. Data sizes are the estimates returned by the element type's `DataSize`
/// implementation.
///
/// The limits are checked only while decoding. `new` wraps its data as-is, and the derived
/// `Default` yields an empty vector without going through `new` (so it does not check that at
/// least one limit is bounded). The value is encoded exactly like the wrapped `Vec<T>`.
///
/// ```
/// # use candid::{Decode, Encode};
/// # use candid::types::bounded_vec::{BoundedVec, UNBOUNDED};
/// // E.g., a user of your service sends candid-encoded bytes:
/// let too_long = vec![13u64; 11];
/// let bytes_too_long = Encode!(&too_long).unwrap();
/// // Your service should decode the untrusted bytes with care, by decoding to BoundedVec<10, _, _, _>.
/// // Since the user sent 11 elements and you allow at most 10, this will fail, keeping the service safe from certain exploits.
/// let error =
///     Decode!(&bytes_too_long, BoundedVec<10, UNBOUNDED, UNBOUNDED, u64>).unwrap_err();
/// assert!(format!("{error:?}").contains("exceeds maximum allowed"));
/// ```
#[derive(Clone, Eq, PartialEq, Debug, Default)]
pub struct BoundedVec<
    const MAX_ALLOWED_LEN: usize,
    const MAX_ALLOWED_TOTAL_DATA_SIZE: usize,
    const MAX_ALLOWED_ELEMENT_DATA_SIZE: usize,
    T,
>(Vec<T>);
35
36impl<
37        const MAX_ALLOWED_LEN: usize,
38        const MAX_ALLOWED_TOTAL_DATA_SIZE: usize,
39        const MAX_ALLOWED_ELEMENT_DATA_SIZE: usize,
40        T: CandidType,
41    > CandidType
42    for BoundedVec<MAX_ALLOWED_LEN, MAX_ALLOWED_TOTAL_DATA_SIZE, MAX_ALLOWED_ELEMENT_DATA_SIZE, T>
43{
44    fn _ty() -> super::Type {
45        TypeInner::Vec(T::_ty()).into()
46    }
47
48    fn idl_serialize<S>(&self, serializer: S) -> Result<(), S::Error>
49    where
50        S: super::Serializer,
51    {
52        self.0.idl_serialize(serializer)
53    }
54}
55
56impl<
57        const MAX_ALLOWED_LEN: usize,
58        const MAX_ALLOWED_TOTAL_DATA_SIZE: usize,
59        const MAX_ALLOWED_ELEMENT_DATA_SIZE: usize,
60        T,
61    > BoundedVec<MAX_ALLOWED_LEN, MAX_ALLOWED_TOTAL_DATA_SIZE, MAX_ALLOWED_ELEMENT_DATA_SIZE, T>
62{
63    pub fn new(data: Vec<T>) -> Self {
64        assert!(
65            MAX_ALLOWED_LEN != UNBOUNDED
66                || MAX_ALLOWED_TOTAL_DATA_SIZE != UNBOUNDED
67                || MAX_ALLOWED_ELEMENT_DATA_SIZE != UNBOUNDED,
68            "BoundedVec must be bounded by at least one parameter."
69        );
70
71        Self(data)
72    }
73
74    pub fn get(&self) -> &Vec<T> {
75        &self.0
76    }
77}
78
79impl<
80        'de,
81        const MAX_ALLOWED_LEN: usize,
82        const MAX_ALLOWED_TOTAL_DATA_SIZE: usize,
83        const MAX_ALLOWED_ELEMENT_DATA_SIZE: usize,
84        T: Deserialize<'de> + DataSize,
85    > Deserialize<'de>
86    for BoundedVec<MAX_ALLOWED_LEN, MAX_ALLOWED_TOTAL_DATA_SIZE, MAX_ALLOWED_ELEMENT_DATA_SIZE, T>
87{
88    fn deserialize<D: Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
89        struct SeqVisitor<
90            const MAX_ALLOWED_LEN: usize,
91            const MAX_ALLOWED_TOTAL_DATA_SIZE: usize,
92            const MAX_ALLOWED_ELEMENT_DATA_SIZE: usize,
93            T,
94        > {
95            _marker: std::marker::PhantomData<T>,
96        }
97
98        use serde::de::{SeqAccess, Visitor};
99
100        impl<
101                'de,
102                const MAX_ALLOWED_LEN: usize,
103                const MAX_ALLOWED_TOTAL_DATA_SIZE: usize,
104                const MAX_ALLOWED_ELEMENT_DATA_SIZE: usize,
105                T: Deserialize<'de> + DataSize,
106            > Visitor<'de>
107            for SeqVisitor<
108                MAX_ALLOWED_LEN,
109                MAX_ALLOWED_TOTAL_DATA_SIZE,
110                MAX_ALLOWED_ELEMENT_DATA_SIZE,
111                T,
112            >
113        {
114            type Value = BoundedVec<
115                MAX_ALLOWED_LEN,
116                MAX_ALLOWED_TOTAL_DATA_SIZE,
117                MAX_ALLOWED_ELEMENT_DATA_SIZE,
118                T,
119            >;
120
121            fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
122                write!(
123                    formatter,
124                    "{}",
125                    describe_sequence(
126                        MAX_ALLOWED_LEN,
127                        MAX_ALLOWED_TOTAL_DATA_SIZE,
128                        MAX_ALLOWED_ELEMENT_DATA_SIZE,
129                    )
130                )
131            }
132
133            fn visit_seq<S>(self, mut seq: S) -> Result<Self::Value, S::Error>
134            where
135                S: SeqAccess<'de>,
136            {
137                let mut total_data_size = 0;
138                let mut elements = if MAX_ALLOWED_LEN == UNBOUNDED {
139                    Vec::new()
140                } else {
141                    Vec::with_capacity(MAX_ALLOWED_LEN)
142                };
143                while let Some(element) = seq.next_element::<T>()? {
144                    if elements.len() >= MAX_ALLOWED_LEN {
145                        return Err(serde::de::Error::custom(format!(
146                            "The number of elements exceeds maximum allowed {MAX_ALLOWED_LEN}"
147                        )));
148                    }
149                    // Check that the new element data size is below the maximum allowed limit.
150                    let new_element_data_size = element.data_size();
151                    if new_element_data_size > MAX_ALLOWED_ELEMENT_DATA_SIZE {
152                        return Err(serde::de::Error::custom(format!(
153                            "The single element data size exceeds maximum allowed {MAX_ALLOWED_ELEMENT_DATA_SIZE}"
154                        )));
155                    }
156                    // Check that the new total data size (including new element data size)
157                    // is below the maximum allowed limit.
158                    let new_total_data_size = total_data_size + new_element_data_size;
159                    if new_total_data_size > MAX_ALLOWED_TOTAL_DATA_SIZE {
160                        return Err(serde::de::Error::custom(format!(
161                            "The total data size exceeds maximum allowed {MAX_ALLOWED_TOTAL_DATA_SIZE}"
162                        )));
163                    }
164                    total_data_size = new_total_data_size;
165                    elements.push(element);
166                }
167                Ok(BoundedVec::new(elements))
168            }
169        }
170
171        deserializer.deserialize_seq(SeqVisitor::<
172            MAX_ALLOWED_LEN,
173            MAX_ALLOWED_TOTAL_DATA_SIZE,
174            MAX_ALLOWED_ELEMENT_DATA_SIZE,
175            T,
176        > {
177            _marker: std::marker::PhantomData,
178        })
179    }
180}
181
182fn describe_sequence(
183    max_allowed_len: usize,
184    max_allowed_total_data_size: usize,
185    max_allowed_element_data_size: usize,
186) -> String {
187    let mut msg = String::new();
188    if max_allowed_len != UNBOUNDED {
189        msg.push_str(&format!("max {max_allowed_len} elements"));
190    };
191    if max_allowed_total_data_size != UNBOUNDED {
192        if !msg.is_empty() {
193            msg.push_str(", ");
194        }
195        msg.push_str(&format!("max {max_allowed_total_data_size} bytes total"));
196    };
197    if max_allowed_element_data_size != UNBOUNDED {
198        if !msg.is_empty() {
199            msg.push_str(", ");
200        }
201        msg.push_str(&format!(
202            "max {max_allowed_element_data_size} bytes per element"
203        ));
204    };
205    format!("a sequence with {msg}")
206}
207
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{Decode, Encode};

    #[test]
    fn test_describe_sequence() {
        assert_eq!(
            describe_sequence(42, UNBOUNDED, UNBOUNDED),
            "a sequence with max 42 elements".to_string()
        );
        assert_eq!(
            describe_sequence(UNBOUNDED, 256, UNBOUNDED),
            "a sequence with max 256 bytes total".to_string(),
        );
        assert_eq!(
            describe_sequence(UNBOUNDED, UNBOUNDED, 64),
            "a sequence with max 64 bytes per element".to_string(),
        );
        assert_eq!(
            describe_sequence(42, 256, UNBOUNDED),
            "a sequence with max 42 elements, max 256 bytes total".to_string(),
        );
        assert_eq!(
            describe_sequence(42, UNBOUNDED, 64),
            "a sequence with max 42 elements, max 64 bytes per element".to_string(),
        );
        assert_eq!(
            describe_sequence(UNBOUNDED, 256, 64),
            "a sequence with max 256 bytes total, max 64 bytes per element".to_string(),
        );
        assert_eq!(
            describe_sequence(42, 256, 64),
            "a sequence with max 42 elements, max 256 bytes total, max 64 bytes per element"
                .to_string(),
        );
    }

    #[test]
    #[should_panic(expected = "BoundedVec must be bounded by at least one parameter.")]
    fn test_not_bounded_vector_fails() {
        type NotBoundedVec = BoundedVec<UNBOUNDED, UNBOUNDED, UNBOUNDED, u8>;

        let _ = NotBoundedVec::new(vec![1, 2, 3]);
    }

    #[test]
    fn test_bounded_vector_lengths() {
        // Decoding must fail exactly when the number of elements exceeds MAX_ALLOWED_LEN.
        const MAX_ALLOWED_LEN: usize = 30;
        const TEST_START: usize = 20;
        const TEST_END: usize = 40;
        type BoundedLen = BoundedVec<MAX_ALLOWED_LEN, UNBOUNDED, UNBOUNDED, u8>;

        for i in TEST_START..=TEST_END {
            // Arrange.
            let data = BoundedLen::new(vec![42; i]);

            // Act.
            let bytes = Encode!(&data).unwrap();
            let result = Decode!(&bytes, BoundedLen);

            // Assert.
            if i <= MAX_ALLOWED_LEN {
                // Allowed sizes decode to the original value (unwrap reports the error if not).
                assert_eq!(result.unwrap(), data);
            } else {
                // Disallowed sizes are rejected with the length error.
                let error = result.unwrap_err();
                assert!(
                    format!("{error:?}").contains(&format!(
                        "Deserialize error: The number of elements exceeds maximum allowed {MAX_ALLOWED_LEN}"
                    )),
                    "Actual: {}",
                    error
                );
            }
        }
    }

    #[test]
    fn test_bounded_vector_total_data_sizes() {
        // Decoding must fail exactly when the summed data size of the elements exceeds
        // MAX_ALLOWED_TOTAL_DATA_SIZE.
        const MAX_ALLOWED_TOTAL_DATA_SIZE: usize = 100;
        const ELEMENT_SIZE: usize = 37;
        type BoundedSize = BoundedVec<UNBOUNDED, MAX_ALLOWED_TOTAL_DATA_SIZE, UNBOUNDED, Vec<u8>>;
        // The limit must not be a multiple of the element size, so that it falls strictly
        // between two reachable totals and the boundary is actually exercised.
        assert_ne!(MAX_ALLOWED_TOTAL_DATA_SIZE % ELEMENT_SIZE, 0);

        for aimed_total_size in 64..=256 {
            // Arrange: every element is a Vec<u8> whose data_size() is exactly ELEMENT_SIZE.
            let element = vec![b'a'; ELEMENT_SIZE - std::mem::size_of::<Vec<u8>>()];
            let elements_count = aimed_total_size / element.data_size();
            let data = BoundedSize::new(vec![element; elements_count]);
            // The decoder sums the data sizes of the elements only; the outer Vec's own header
            // is not part of that total, so it must not be counted here either.
            let elements_total_size: usize =
                data.get().iter().map(|item| item.data_size()).sum();

            // Act.
            let bytes = Encode!(&data).unwrap();
            let result = Decode!(&bytes, BoundedSize);

            // Assert.
            if elements_total_size <= MAX_ALLOWED_TOTAL_DATA_SIZE {
                // Allowed sizes decode to the original value (unwrap reports the error if not).
                assert_eq!(result.unwrap(), data);
            } else {
                // Disallowed sizes are rejected with the total-size error.
                let error = result.unwrap_err();
                assert!(
                    format!("{error:?}").contains(&format!(
                        "Deserialize error: The total data size exceeds maximum allowed {MAX_ALLOWED_TOTAL_DATA_SIZE}"
                    )),
                    "Actual: {}",
                    error
                );
            }
        }
    }

    #[test]
    fn test_bounded_vector_element_data_sizes() {
        // Decoding must fail exactly when a single element's data size exceeds
        // MAX_ALLOWED_ELEMENT_DATA_SIZE.
        const MAX_ALLOWED_ELEMENT_DATA_SIZE: usize = 100;
        type BoundedSize = BoundedVec<UNBOUNDED, UNBOUNDED, MAX_ALLOWED_ELEMENT_DATA_SIZE, Vec<u8>>;

        for element_size in 64..=256 {
            // Arrange: every element is a Vec<u8> whose data_size() is exactly element_size.
            let element = vec![b'a'; element_size - std::mem::size_of::<Vec<u8>>()];
            let data = BoundedSize::new(vec![element; 42]);

            // Act.
            let bytes = Encode!(&data).unwrap();
            let result = Decode!(&bytes, BoundedSize);

            // Assert.
            if element_size <= MAX_ALLOWED_ELEMENT_DATA_SIZE {
                // Allowed sizes decode to the original value (unwrap reports the error if not).
                assert_eq!(result.unwrap(), data);
            } else {
                // Disallowed sizes are rejected with the per-element error.
                let error = result.unwrap_err();
                assert!(
                    format!("{error:?}").contains(&format!(
                        "Deserialize error: The single element data size exceeds maximum allowed {MAX_ALLOWED_ELEMENT_DATA_SIZE}"
                    )),
                    "Actual: {}",
                    error
                );
            }
        }
    }
}
368
369mod data_size {
370    /// Trait to reasonably estimate the memory usage of a value in bytes.
371    ///
372    /// Default implementation returns zero.
373    pub trait DataSize {
374        /// Default implementation returns zero.
375        fn data_size(&self) -> usize {
376            0
377        }
378    }
379
380    impl DataSize for u8 {
381        fn data_size(&self) -> usize {
382            std::mem::size_of::<u8>()
383        }
384    }
385
386    impl DataSize for [u8] {
387        fn data_size(&self) -> usize {
388            std::mem::size_of_val(self)
389        }
390    }
391
392    impl DataSize for u64 {
393        fn data_size(&self) -> usize {
394            std::mem::size_of::<u64>()
395        }
396    }
397
398    impl DataSize for &str {
399        fn data_size(&self) -> usize {
400            self.as_bytes().data_size()
401        }
402    }
403
404    impl DataSize for String {
405        fn data_size(&self) -> usize {
406            self.as_bytes().data_size()
407        }
408    }
409
410    impl<T: DataSize> DataSize for Vec<T> {
411        fn data_size(&self) -> usize {
412            std::mem::size_of::<Self>() + self.iter().map(|x| x.data_size()).sum::<usize>()
413        }
414    }
415
416    impl DataSize for ic_principal::Principal {
417        fn data_size(&self) -> usize {
418            self.as_slice().len()
419        }
420    }
421
422    #[cfg(test)]
423    mod tests {
424        use super::*;
425
426        #[test]
427        fn test_data_size_u8() {
428            assert_eq!(0_u8.data_size(), 1);
429            assert_eq!(42_u8.data_size(), 1);
430        }
431
432        #[test]
433        fn test_data_size_u8_slice() {
434            let a: [u8; 0] = [];
435            assert_eq!(a.data_size(), 0);
436            assert_eq!([1_u8].data_size(), 1);
437            assert_eq!([1_u8, 2_u8].data_size(), 2);
438        }
439
440        #[test]
441        fn test_data_size_u64() {
442            assert_eq!(0_u64.data_size(), 8);
443            assert_eq!(42_u64.data_size(), 8);
444        }
445
446        #[test]
447        fn test_data_size_u8_vec() {
448            let base = 24;
449            assert_eq!(Vec::<u8>::from([]).data_size(), base);
450            assert_eq!(Vec::<u8>::from([1]).data_size(), base + 1);
451            assert_eq!(Vec::<u8>::from([1, 2]).data_size(), base + 2);
452        }
453
454        #[test]
455        fn test_data_size_str() {
456            assert_eq!("a".data_size(), 1);
457            assert_eq!("ab".data_size(), 2);
458        }
459
460        #[test]
461        fn test_data_size_string() {
462            assert_eq!(String::from("a").data_size(), 1);
463            assert_eq!(String::from("ab").data_size(), 2);
464            for size_bytes in 0..1_024 {
465                assert_eq!(
466                    String::from_utf8(vec![b'x'; size_bytes])
467                        .unwrap()
468                        .data_size(),
469                    size_bytes
470                );
471            }
472        }
473    }
474}