Skip to main content

cdr/
size.rs

//! Measuring the size of (de)serialized data.
2
3use std;
4
5use serde::ser;
6
7use crate::error::{Error, Result};
8
9/// Limits on the number of bytes that can be read or written.
10pub trait SizeLimit {
11    fn add(&mut self, n: u64) -> Result<()>;
12    fn limit(&self) -> Option<u64>;
13}
14
/// A `SizeLimit` that restricts serialized or deserialized messages so that
/// they do not exceed a certain byte length.
///
/// The wrapped value is the number of bytes still allowed; it shrinks as bytes
/// are charged against it.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct Bounded(pub u64);
19
20impl SizeLimit for Bounded {
21    #[inline]
22    fn add(&mut self, n: u64) -> Result<()> {
23        if self.0 >= n {
24            self.0 -= n;
25            Ok(())
26        } else {
27            Err(Error::SizeLimit)
28        }
29    }
30
31    #[inline]
32    fn limit(&self) -> Option<u64> {
33        Some(self.0)
34    }
35}
36
/// A `SizeLimit` without a limit: every byte count is accepted.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct Infinite;
40
41impl SizeLimit for Infinite {
42    #[inline]
43    fn add(&mut self, _n: u64) -> Result<()> {
44        Ok(())
45    }
46
47    #[inline]
48    fn limit(&self) -> Option<u64> {
49        None
50    }
51}
52
/// Running byte count with an optional upper bound.
struct Counter {
    /// Bytes counted so far.
    total: u64,
    /// Largest total that is accepted; `None` means unbounded.
    limit: Option<u64>,
}
57
58impl SizeLimit for Counter {
59    fn add(&mut self, n: u64) -> Result<()> {
60        self.total += n;
61        if let Some(limit) = self.limit {
62            if self.total > limit {
63                return Err(Error::SizeLimit);
64            }
65        }
66        Ok(())
67    }
68
69    fn limit(&self) -> Option<u64> {
70        unreachable!();
71    }
72}
73
/// Serializer state that only measures: it tracks the current byte offset and
/// charges every counted byte to `counter`.
struct SizeChecker<S> {
    /// Limit or counter that each counted byte is charged to.
    counter: S,
    /// Current byte offset, used to compute alignment padding.
    pos: usize,
}
78
79impl<S> SizeChecker<S>
80where
81    S: SizeLimit,
82{
83    fn add_padding_of<T>(&mut self) -> Result<()> {
84        let alignment = std::mem::size_of::<T>();
85        let rem_mask = alignment - 1; // mask like 0x0, 0x1, 0x3, 0x7
86        match (self.pos as usize) & rem_mask {
87            0 => Ok(()),
88            n @ 1..=7 => {
89                let amt = alignment - n;
90                self.add_size(amt as u64)
91            }
92            _ => unreachable!(),
93        }
94    }
95
96    fn add_size(&mut self, size: u64) -> Result<()> {
97        self.pos += size as usize;
98        self.counter.add(size)
99    }
100
101    fn add_usize_as_u32(&mut self, v: usize) -> Result<()> {
102        if v > std::u32::MAX as usize {
103            return Err(Error::NumberOutOfRange);
104        }
105
106        ser::Serializer::serialize_u32(self, v as u32)
107    }
108
109    fn add_value<T>(&mut self, _v: T) -> Result<()> {
110        self.add_padding_of::<T>()?;
111        self.add_size(std::mem::size_of::<T>() as u64)
112    }
113}
114
115impl<'a, S> ser::Serializer for &'a mut SizeChecker<S>
116where
117    S: SizeLimit,
118{
119    type Ok = ();
120    type Error = Error;
121    type SerializeSeq = SizeCompound<'a, S>;
122    type SerializeTuple = SizeCompound<'a, S>;
123    type SerializeTupleStruct = SizeCompound<'a, S>;
124    type SerializeTupleVariant = SizeCompound<'a, S>;
125    type SerializeMap = SizeCompound<'a, S>;
126    type SerializeStruct = SizeCompound<'a, S>;
127    type SerializeStructVariant = SizeCompound<'a, S>;
128
129    fn serialize_bool(self, _v: bool) -> Result<Self::Ok> {
130        self.add_value(0u8)
131    }
132
133    fn serialize_u8(self, v: u8) -> Result<Self::Ok> {
134        self.add_value(v)
135    }
136
137    fn serialize_u16(self, v: u16) -> Result<Self::Ok> {
138        self.add_value(v)
139    }
140
141    fn serialize_u32(self, v: u32) -> Result<Self::Ok> {
142        self.add_value(v)
143    }
144
145    fn serialize_u64(self, v: u64) -> Result<Self::Ok> {
146        self.add_value(v)
147    }
148
149    fn serialize_i8(self, v: i8) -> Result<Self::Ok> {
150        self.add_value(v)
151    }
152
153    fn serialize_i16(self, v: i16) -> Result<Self::Ok> {
154        self.add_value(v)
155    }
156
157    fn serialize_i32(self, v: i32) -> Result<Self::Ok> {
158        self.add_value(v)
159    }
160
161    fn serialize_i64(self, v: i64) -> Result<Self::Ok> {
162        self.add_value(v)
163    }
164
165    fn serialize_f32(self, v: f32) -> Result<Self::Ok> {
166        self.add_value(v)
167    }
168
169    fn serialize_f64(self, v: f64) -> Result<Self::Ok> {
170        self.add_value(v)
171    }
172
173    fn serialize_char(self, v: char) -> Result<Self::Ok> {
174        self.add_size(v.len_utf8() as u64)
175    }
176
177    fn serialize_str(self, v: &str) -> Result<Self::Ok> {
178        self.add_value(0 as u32)?;
179        self.add_size(v.len() as u64 + 1) // adds the length 1 of a terminating character
180    }
181
182    fn serialize_bytes(self, v: &[u8]) -> Result<Self::Ok> {
183        self.add_value(0 as u32)?;
184        self.add_size(v.len() as u64)
185    }
186
187    fn serialize_none(self) -> Result<Self::Ok> {
188        self.add_value(0 as u8)
189    }
190
191    fn serialize_some<T: ?Sized>(self, v: &T) -> Result<Self::Ok>
192    where
193        T: ser::Serialize,
194    {
195        self.add_value(1 as u8)?;
196        v.serialize(self)
197    }
198
199    fn serialize_unit(self) -> Result<Self::Ok> {
200        Ok(())
201    }
202
203    fn serialize_unit_struct(self, _name: &'static str) -> Result<Self::Ok> {
204        Ok(())
205    }
206
207    fn serialize_unit_variant(
208        self,
209        _name: &'static str,
210        variant_index: u32,
211        _variant: &'static str,
212    ) -> Result<Self::Ok> {
213        self.serialize_u32(variant_index)
214    }
215
216    fn serialize_newtype_struct<T: ?Sized>(self, _name: &'static str, value: &T) -> Result<Self::Ok>
217    where
218        T: ser::Serialize,
219    {
220        value.serialize(self)
221    }
222
223    fn serialize_newtype_variant<T: ?Sized>(
224        self,
225        _name: &'static str,
226        variant_index: u32,
227        _variant: &'static str,
228        value: &T,
229    ) -> Result<Self::Ok>
230    where
231        T: ser::Serialize,
232    {
233        self.serialize_u32(variant_index)?;
234        value.serialize(self)
235    }
236
237    fn serialize_seq(self, len: Option<usize>) -> Result<Self::SerializeSeq> {
238        let len = len.ok_or(Error::SequenceMustHaveLength)?;
239        self.add_usize_as_u32(len)?;
240        Ok(SizeCompound { ser: self })
241    }
242
243    fn serialize_tuple(self, _len: usize) -> Result<Self::SerializeTuple> {
244        Ok(SizeCompound { ser: self })
245    }
246
247    fn serialize_tuple_struct(
248        self,
249        _name: &'static str,
250        _len: usize,
251    ) -> Result<Self::SerializeTupleStruct> {
252        Ok(SizeCompound { ser: self })
253    }
254
255    fn serialize_tuple_variant(
256        self,
257        _name: &'static str,
258        variant_index: u32,
259        _variant: &'static str,
260        _len: usize,
261    ) -> Result<Self::SerializeTupleVariant> {
262        self.serialize_u32(variant_index)?;
263        Ok(SizeCompound { ser: self })
264    }
265
266    fn serialize_map(self, _len: Option<usize>) -> Result<Self::SerializeMap> {
267        Err(Error::TypeNotSupported)
268    }
269
270    fn serialize_struct(self, _name: &'static str, _len: usize) -> Result<Self::SerializeStruct> {
271        Ok(SizeCompound { ser: self })
272    }
273
274    fn serialize_struct_variant(
275        self,
276        _name: &'static str,
277        variant_index: u32,
278        _variant: &'static str,
279        _len: usize,
280    ) -> Result<Self::SerializeStructVariant> {
281        self.serialize_u32(variant_index)?;
282        Ok(SizeCompound { ser: self })
283    }
284
285    fn is_human_readable(&self) -> bool {
286        false
287    }
288}
289
290#[doc(hidden)]
291pub struct SizeCompound<'a, S: 'a> {
292    ser: &'a mut SizeChecker<S>,
293}
294
295impl<'a, S> ser::SerializeSeq for SizeCompound<'a, S>
296where
297    S: SizeLimit,
298{
299    type Ok = ();
300    type Error = Error;
301
302    #[inline]
303    fn serialize_element<T: ?Sized>(&mut self, value: &T) -> Result<()>
304    where
305        T: ser::Serialize,
306    {
307        value.serialize(&mut *self.ser)
308    }
309
310    #[inline]
311    fn end(self) -> Result<()> {
312        Ok(())
313    }
314}
315
316impl<'a, S> ser::SerializeTuple for SizeCompound<'a, S>
317where
318    S: SizeLimit,
319{
320    type Ok = ();
321    type Error = Error;
322
323    #[inline]
324    fn serialize_element<T: ?Sized>(&mut self, value: &T) -> Result<()>
325    where
326        T: ser::Serialize,
327    {
328        value.serialize(&mut *self.ser)
329    }
330
331    #[inline]
332    fn end(self) -> Result<()> {
333        Ok(())
334    }
335}
336
337impl<'a, S> ser::SerializeTupleStruct for SizeCompound<'a, S>
338where
339    S: SizeLimit,
340{
341    type Ok = ();
342    type Error = Error;
343
344    #[inline]
345    fn serialize_field<T: ?Sized>(&mut self, value: &T) -> Result<()>
346    where
347        T: ser::Serialize,
348    {
349        value.serialize(&mut *self.ser)
350    }
351
352    #[inline]
353    fn end(self) -> Result<()> {
354        Ok(())
355    }
356}
357
358impl<'a, S> ser::SerializeTupleVariant for SizeCompound<'a, S>
359where
360    S: SizeLimit,
361{
362    type Ok = ();
363    type Error = Error;
364
365    #[inline]
366    fn serialize_field<T: ?Sized>(&mut self, value: &T) -> Result<()>
367    where
368        T: ser::Serialize,
369    {
370        value.serialize(&mut *self.ser)
371    }
372
373    #[inline]
374    fn end(self) -> Result<()> {
375        Ok(())
376    }
377}
378
379impl<'a, S> ser::SerializeMap for SizeCompound<'a, S>
380where
381    S: SizeLimit,
382{
383    type Ok = ();
384    type Error = Error;
385
386    #[inline]
387    fn serialize_key<T: ?Sized>(&mut self, key: &T) -> Result<()>
388    where
389        T: ser::Serialize,
390    {
391        key.serialize(&mut *self.ser)
392    }
393
394    #[inline]
395    fn serialize_value<T: ?Sized>(&mut self, value: &T) -> Result<()>
396    where
397        T: ser::Serialize,
398    {
399        value.serialize(&mut *self.ser)
400    }
401
402    #[inline]
403    fn end(self) -> Result<()> {
404        Ok(())
405    }
406}
407
408impl<'a, S> ser::SerializeStruct for SizeCompound<'a, S>
409where
410    S: SizeLimit,
411{
412    type Ok = ();
413    type Error = Error;
414
415    #[inline]
416    fn serialize_field<T: ?Sized>(&mut self, _key: &'static str, value: &T) -> Result<()>
417    where
418        T: ser::Serialize,
419    {
420        value.serialize(&mut *self.ser)
421    }
422
423    #[inline]
424    fn end(self) -> Result<()> {
425        Ok(())
426    }
427}
428
429impl<'a, S> ser::SerializeStructVariant for SizeCompound<'a, S>
430where
431    S: SizeLimit,
432{
433    type Ok = ();
434    type Error = Error;
435
436    #[inline]
437    fn serialize_field<T: ?Sized>(&mut self, _key: &'static str, value: &T) -> Result<()>
438    where
439        T: ser::Serialize,
440    {
441        value.serialize(&mut *self.ser)
442    }
443
444    #[inline]
445    fn end(self) -> Result<()> {
446        Ok(())
447    }
448}
449
450/// Returns the size that an object would be if serialized.
451pub fn calc_serialized_data_size<T: ?Sized>(value: &T) -> u64
452where
453    T: ser::Serialize,
454{
455    let mut checker = SizeChecker {
456        counter: Counter {
457            total: 0,
458            limit: None,
459        },
460        pos: 0,
461    };
462
463    value.serialize(&mut checker).ok();
464    checker.counter.total
465}
466
467/// Given a maximum size limit, check how large an object would be if it were
468/// to be serialized.
469pub fn calc_serialized_data_size_bounded<T: ?Sized>(value: &T, max: u64) -> Result<u64>
470where
471    T: ser::Serialize,
472{
473    let mut checker = SizeChecker {
474        counter: Bounded(max),
475        pos: 0,
476    };
477
478    match value.serialize(&mut checker) {
479        Ok(_) => Ok(max - checker.counter.0),
480        Err(e) => Err(e),
481    }
482}