1use std;
4
5use serde::ser;
6
7use crate::error::{Error, Result};
8
9pub trait SizeLimit {
11 fn add(&mut self, n: u64) -> Result<()>;
12 fn limit(&self) -> Option<u64>;
13}
14
/// A finite byte budget. The wrapped value is the number of bytes still
/// available; it shrinks each time bytes are charged.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct Bounded(pub u64);
19
20impl SizeLimit for Bounded {
21 #[inline]
22 fn add(&mut self, n: u64) -> Result<()> {
23 if self.0 >= n {
24 self.0 -= n;
25 Ok(())
26 } else {
27 Err(Error::SizeLimit)
28 }
29 }
30
31 #[inline]
32 fn limit(&self) -> Option<u64> {
33 Some(self.0)
34 }
35}
36
/// An unbounded byte budget: charging bytes never fails.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct Infinite;
40
41impl SizeLimit for Infinite {
42 #[inline]
43 fn add(&mut self, _n: u64) -> Result<()> {
44 Ok(())
45 }
46
47 #[inline]
48 fn limit(&self) -> Option<u64> {
49 None
50 }
51}
52
/// Running byte tally with an optional ceiling. It is used when the caller
/// wants the total size rather than a remaining budget.
struct Counter {
    // Maximum total allowed; `None` means unbounded.
    limit: Option<u64>,
    // Bytes counted so far.
    total: u64,
}
57
58impl SizeLimit for Counter {
59 fn add(&mut self, n: u64) -> Result<()> {
60 self.total += n;
61 if let Some(limit) = self.limit {
62 if self.total > limit {
63 return Err(Error::SizeLimit);
64 }
65 }
66 Ok(())
67 }
68
69 fn limit(&self) -> Option<u64> {
70 unreachable!();
71 }
72}
73
/// Walks a value through serde without producing output. It tracks the
/// would-be write position (needed for alignment padding) and charges every
/// byte to `counter`.
struct SizeChecker<S> {
    // Byte offset the real serializer would be at.
    pos: usize,
    // Budget / tally that each counted byte is charged to.
    counter: S,
}
78
79impl<S> SizeChecker<S>
80where
81 S: SizeLimit,
82{
83 fn add_padding_of<T>(&mut self) -> Result<()> {
84 let alignment = std::mem::size_of::<T>();
85 let rem_mask = alignment - 1; match (self.pos as usize) & rem_mask {
87 0 => Ok(()),
88 n @ 1..=7 => {
89 let amt = alignment - n;
90 self.add_size(amt as u64)
91 }
92 _ => unreachable!(),
93 }
94 }
95
96 fn add_size(&mut self, size: u64) -> Result<()> {
97 self.pos += size as usize;
98 self.counter.add(size)
99 }
100
101 fn add_usize_as_u32(&mut self, v: usize) -> Result<()> {
102 if v > std::u32::MAX as usize {
103 return Err(Error::NumberOutOfRange);
104 }
105
106 ser::Serializer::serialize_u32(self, v as u32)
107 }
108
109 fn add_value<T>(&mut self, _v: T) -> Result<()> {
110 self.add_padding_of::<T>()?;
111 self.add_size(std::mem::size_of::<T>() as u64)
112 }
113}
114
115impl<'a, S> ser::Serializer for &'a mut SizeChecker<S>
116where
117 S: SizeLimit,
118{
119 type Ok = ();
120 type Error = Error;
121 type SerializeSeq = SizeCompound<'a, S>;
122 type SerializeTuple = SizeCompound<'a, S>;
123 type SerializeTupleStruct = SizeCompound<'a, S>;
124 type SerializeTupleVariant = SizeCompound<'a, S>;
125 type SerializeMap = SizeCompound<'a, S>;
126 type SerializeStruct = SizeCompound<'a, S>;
127 type SerializeStructVariant = SizeCompound<'a, S>;
128
129 fn serialize_bool(self, _v: bool) -> Result<Self::Ok> {
130 self.add_value(0u8)
131 }
132
133 fn serialize_u8(self, v: u8) -> Result<Self::Ok> {
134 self.add_value(v)
135 }
136
137 fn serialize_u16(self, v: u16) -> Result<Self::Ok> {
138 self.add_value(v)
139 }
140
141 fn serialize_u32(self, v: u32) -> Result<Self::Ok> {
142 self.add_value(v)
143 }
144
145 fn serialize_u64(self, v: u64) -> Result<Self::Ok> {
146 self.add_value(v)
147 }
148
149 fn serialize_i8(self, v: i8) -> Result<Self::Ok> {
150 self.add_value(v)
151 }
152
153 fn serialize_i16(self, v: i16) -> Result<Self::Ok> {
154 self.add_value(v)
155 }
156
157 fn serialize_i32(self, v: i32) -> Result<Self::Ok> {
158 self.add_value(v)
159 }
160
161 fn serialize_i64(self, v: i64) -> Result<Self::Ok> {
162 self.add_value(v)
163 }
164
165 fn serialize_f32(self, v: f32) -> Result<Self::Ok> {
166 self.add_value(v)
167 }
168
169 fn serialize_f64(self, v: f64) -> Result<Self::Ok> {
170 self.add_value(v)
171 }
172
173 fn serialize_char(self, v: char) -> Result<Self::Ok> {
174 self.add_size(v.len_utf8() as u64)
175 }
176
177 fn serialize_str(self, v: &str) -> Result<Self::Ok> {
178 self.add_value(0 as u32)?;
179 self.add_size(v.len() as u64 + 1) }
181
182 fn serialize_bytes(self, v: &[u8]) -> Result<Self::Ok> {
183 self.add_value(0 as u32)?;
184 self.add_size(v.len() as u64)
185 }
186
187 fn serialize_none(self) -> Result<Self::Ok> {
188 self.add_value(0 as u8)
189 }
190
191 fn serialize_some<T: ?Sized>(self, v: &T) -> Result<Self::Ok>
192 where
193 T: ser::Serialize,
194 {
195 self.add_value(1 as u8)?;
196 v.serialize(self)
197 }
198
199 fn serialize_unit(self) -> Result<Self::Ok> {
200 Ok(())
201 }
202
203 fn serialize_unit_struct(self, _name: &'static str) -> Result<Self::Ok> {
204 Ok(())
205 }
206
207 fn serialize_unit_variant(
208 self,
209 _name: &'static str,
210 variant_index: u32,
211 _variant: &'static str,
212 ) -> Result<Self::Ok> {
213 self.serialize_u32(variant_index)
214 }
215
216 fn serialize_newtype_struct<T: ?Sized>(self, _name: &'static str, value: &T) -> Result<Self::Ok>
217 where
218 T: ser::Serialize,
219 {
220 value.serialize(self)
221 }
222
223 fn serialize_newtype_variant<T: ?Sized>(
224 self,
225 _name: &'static str,
226 variant_index: u32,
227 _variant: &'static str,
228 value: &T,
229 ) -> Result<Self::Ok>
230 where
231 T: ser::Serialize,
232 {
233 self.serialize_u32(variant_index)?;
234 value.serialize(self)
235 }
236
237 fn serialize_seq(self, len: Option<usize>) -> Result<Self::SerializeSeq> {
238 let len = len.ok_or(Error::SequenceMustHaveLength)?;
239 self.add_usize_as_u32(len)?;
240 Ok(SizeCompound { ser: self })
241 }
242
243 fn serialize_tuple(self, _len: usize) -> Result<Self::SerializeTuple> {
244 Ok(SizeCompound { ser: self })
245 }
246
247 fn serialize_tuple_struct(
248 self,
249 _name: &'static str,
250 _len: usize,
251 ) -> Result<Self::SerializeTupleStruct> {
252 Ok(SizeCompound { ser: self })
253 }
254
255 fn serialize_tuple_variant(
256 self,
257 _name: &'static str,
258 variant_index: u32,
259 _variant: &'static str,
260 _len: usize,
261 ) -> Result<Self::SerializeTupleVariant> {
262 self.serialize_u32(variant_index)?;
263 Ok(SizeCompound { ser: self })
264 }
265
266 fn serialize_map(self, _len: Option<usize>) -> Result<Self::SerializeMap> {
267 Err(Error::TypeNotSupported)
268 }
269
270 fn serialize_struct(self, _name: &'static str, _len: usize) -> Result<Self::SerializeStruct> {
271 Ok(SizeCompound { ser: self })
272 }
273
274 fn serialize_struct_variant(
275 self,
276 _name: &'static str,
277 variant_index: u32,
278 _variant: &'static str,
279 _len: usize,
280 ) -> Result<Self::SerializeStructVariant> {
281 self.serialize_u32(variant_index)?;
282 Ok(SizeCompound { ser: self })
283 }
284
285 fn is_human_readable(&self) -> bool {
286 false
287 }
288}
289
290#[doc(hidden)]
291pub struct SizeCompound<'a, S: 'a> {
292 ser: &'a mut SizeChecker<S>,
293}
294
295impl<'a, S> ser::SerializeSeq for SizeCompound<'a, S>
296where
297 S: SizeLimit,
298{
299 type Ok = ();
300 type Error = Error;
301
302 #[inline]
303 fn serialize_element<T: ?Sized>(&mut self, value: &T) -> Result<()>
304 where
305 T: ser::Serialize,
306 {
307 value.serialize(&mut *self.ser)
308 }
309
310 #[inline]
311 fn end(self) -> Result<()> {
312 Ok(())
313 }
314}
315
316impl<'a, S> ser::SerializeTuple for SizeCompound<'a, S>
317where
318 S: SizeLimit,
319{
320 type Ok = ();
321 type Error = Error;
322
323 #[inline]
324 fn serialize_element<T: ?Sized>(&mut self, value: &T) -> Result<()>
325 where
326 T: ser::Serialize,
327 {
328 value.serialize(&mut *self.ser)
329 }
330
331 #[inline]
332 fn end(self) -> Result<()> {
333 Ok(())
334 }
335}
336
337impl<'a, S> ser::SerializeTupleStruct for SizeCompound<'a, S>
338where
339 S: SizeLimit,
340{
341 type Ok = ();
342 type Error = Error;
343
344 #[inline]
345 fn serialize_field<T: ?Sized>(&mut self, value: &T) -> Result<()>
346 where
347 T: ser::Serialize,
348 {
349 value.serialize(&mut *self.ser)
350 }
351
352 #[inline]
353 fn end(self) -> Result<()> {
354 Ok(())
355 }
356}
357
358impl<'a, S> ser::SerializeTupleVariant for SizeCompound<'a, S>
359where
360 S: SizeLimit,
361{
362 type Ok = ();
363 type Error = Error;
364
365 #[inline]
366 fn serialize_field<T: ?Sized>(&mut self, value: &T) -> Result<()>
367 where
368 T: ser::Serialize,
369 {
370 value.serialize(&mut *self.ser)
371 }
372
373 #[inline]
374 fn end(self) -> Result<()> {
375 Ok(())
376 }
377}
378
379impl<'a, S> ser::SerializeMap for SizeCompound<'a, S>
380where
381 S: SizeLimit,
382{
383 type Ok = ();
384 type Error = Error;
385
386 #[inline]
387 fn serialize_key<T: ?Sized>(&mut self, key: &T) -> Result<()>
388 where
389 T: ser::Serialize,
390 {
391 key.serialize(&mut *self.ser)
392 }
393
394 #[inline]
395 fn serialize_value<T: ?Sized>(&mut self, value: &T) -> Result<()>
396 where
397 T: ser::Serialize,
398 {
399 value.serialize(&mut *self.ser)
400 }
401
402 #[inline]
403 fn end(self) -> Result<()> {
404 Ok(())
405 }
406}
407
408impl<'a, S> ser::SerializeStruct for SizeCompound<'a, S>
409where
410 S: SizeLimit,
411{
412 type Ok = ();
413 type Error = Error;
414
415 #[inline]
416 fn serialize_field<T: ?Sized>(&mut self, _key: &'static str, value: &T) -> Result<()>
417 where
418 T: ser::Serialize,
419 {
420 value.serialize(&mut *self.ser)
421 }
422
423 #[inline]
424 fn end(self) -> Result<()> {
425 Ok(())
426 }
427}
428
429impl<'a, S> ser::SerializeStructVariant for SizeCompound<'a, S>
430where
431 S: SizeLimit,
432{
433 type Ok = ();
434 type Error = Error;
435
436 #[inline]
437 fn serialize_field<T: ?Sized>(&mut self, _key: &'static str, value: &T) -> Result<()>
438 where
439 T: ser::Serialize,
440 {
441 value.serialize(&mut *self.ser)
442 }
443
444 #[inline]
445 fn end(self) -> Result<()> {
446 Ok(())
447 }
448}
449
450pub fn calc_serialized_data_size<T: ?Sized>(value: &T) -> u64
452where
453 T: ser::Serialize,
454{
455 let mut checker = SizeChecker {
456 counter: Counter {
457 total: 0,
458 limit: None,
459 },
460 pos: 0,
461 };
462
463 value.serialize(&mut checker).ok();
464 checker.counter.total
465}
466
467pub fn calc_serialized_data_size_bounded<T: ?Sized>(value: &T, max: u64) -> Result<u64>
470where
471 T: ser::Serialize,
472{
473 let mut checker = SizeChecker {
474 counter: Bounded(max),
475 pos: 0,
476 };
477
478 match value.serialize(&mut checker) {
479 Ok(_) => Ok(max - checker.counter.0),
480 Err(e) => Err(e),
481 }
482}