Skip to main content

tfhe_safe_serialize/
lib.rs

1//! Serialization utilities with some safety checks
2
3// Types in this file should never be versioned because they are a wrapper around the versioning
4// process
5#![cfg_attr(
6    dylint_lib = "tfhe_lints",
7    allow(unknown_lints, serialize_without_versionize)
8)]
9
10use std::borrow::Cow;
11use std::fmt::Display;
12
13use bincode::Options;
14use serde::de::DeserializeOwned;
15use serde::{Deserialize, Serialize};
16use tfhe_versionable::{Unversionize, Versionize};
17
18mod traits;
19pub use crate::traits::{EnumSet, Named, ParameterSetConformant};
20
/// This is the global version of the serialization scheme that is used. This should be updated when
/// the `SerializationHeader` is updated.
const SERIALIZATION_VERSION: &str = "0.5";

/// This is the version of the versioning scheme used to add backward compatibility on tfhe-rs
/// types. Similar to `SERIALIZATION_VERSION`, this number should be increased when the versioning
/// scheme is upgraded.
const VERSIONING_VERSION: &str = "0.1";
29
30/// This is the current version of this crate. This is used to be able to reject unversioned data
31/// if they come from a previous version.
32const CRATE_VERSION: &str = concat!(
33    env!("CARGO_PKG_VERSION_MAJOR"),
34    ".",
35    env!("CARGO_PKG_VERSION_MINOR")
36);
37
38/// Tells if this serialized object is versioned or not
39#[derive(Serialize, Deserialize, Clone, PartialEq, Eq)]
40enum SerializationVersioningMode {
41    /// Serialize with type versioning for backward compatibility
42    Versioned {
43        /// Version of the versioning scheme in use
44        versioning_version: Cow<'static, str>,
45    },
46    /// Serialize the type without versioning information
47    Unversioned {
48        /// Version of tfhe-rs where this data was generated
49        crate_version: Cow<'static, str>,
50    },
51}
52
53impl Display for SerializationVersioningMode {
54    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
55        match self {
56            Self::Versioned { .. } => write!(f, "versioned"),
57            Self::Unversioned { .. } => write!(f, "unversioned"),
58        }
59    }
60}
61
62impl SerializationVersioningMode {
63    fn versioned() -> Self {
64        Self::Versioned {
65            versioning_version: Cow::Borrowed(VERSIONING_VERSION),
66        }
67    }
68
69    fn unversioned() -> Self {
70        Self::Unversioned {
71            crate_version: Cow::Borrowed(CRATE_VERSION),
72        }
73    }
74}
75
76/// Header with global metadata about the serialized object. This help checking that we are not
77/// deserializing data that we can't handle.
78#[derive(Serialize, Deserialize)]
79struct SerializationHeader {
80    header_version: Cow<'static, str>,
81    versioning_mode: SerializationVersioningMode,
82    name: Cow<'static, str>,
83}
84
85impl SerializationHeader {
86    /// Creates a new header for a versioned message
87    fn new_versioned<T: Named>() -> Self {
88        Self {
89            header_version: Cow::Borrowed(SERIALIZATION_VERSION),
90            versioning_mode: SerializationVersioningMode::versioned(),
91            name: Cow::Borrowed(T::NAME),
92        }
93    }
94
95    /// Creates a new header for an unversioned message
96    fn new_unversioned<T: Named>() -> Self {
97        Self {
98            header_version: Cow::Borrowed(SERIALIZATION_VERSION),
99            versioning_mode: SerializationVersioningMode::unversioned(),
100            name: Cow::Borrowed(T::NAME),
101        }
102    }
103
104    /// Checks the validity of the header
105    fn validate<T: Named>(&self) -> Result<(), String> {
106        if self.header_version != SERIALIZATION_VERSION {
107            return Err(format!(
108                "On deserialization, expected serialization header version {SERIALIZATION_VERSION}, \
109got version {}",
110                self.header_version
111            ));
112        }
113
114        match &self.versioning_mode {
115            SerializationVersioningMode::Versioned { versioning_version } => {
116                // For the moment there is only one versioning scheme, so another value is
117                // a hard error. But maybe if we upgrade it we will be able to automatically convert
118                // it.
119                if versioning_version != VERSIONING_VERSION {
120                    return Err(format!(
121                        "On deserialization, expected versioning scheme version {VERSIONING_VERSION}, \
122got version {versioning_version}"
123                    ));
124                }
125            }
126            SerializationVersioningMode::Unversioned { crate_version } => {
127                if crate_version != CRATE_VERSION {
128                    return Err(format!(
129                        "This {} has been saved from TFHE-rs v{crate_version}, without versioning information. \
130Please use the versioned serialization mode for backward compatibility.",
131                        self.name
132                    ));
133                }
134            }
135        }
136
137        if self.name != T::NAME
138            && T::BACKWARD_COMPATIBILITY_ALIASES
139                .iter()
140                .all(|alias| self.name != *alias)
141        {
142            return Err(format!(
143                "On deserialization, expected type {}, got type {}",
144                T::NAME,
145                self.name
146            ));
147        }
148
149        Ok(())
150    }
151}
152
153/// A configuration used to Serialize *TFHE-rs* objects. This configuration decides
154/// if the object will be versioned and holds the max byte size of the written data.
155#[derive(Clone)]
156pub struct SerializationConfig {
157    versioned: SerializationVersioningMode,
158    serialized_size_limit: Option<u64>,
159}
160
161impl SerializationConfig {
162    /// Creates a new serialization config. The default configuration will serialize the object
163    /// with versioning information for backward compatibility.
164    /// `serialized_size_limit` is the size limit (in number of bytes) of the serialized object
165    /// (including the header).
166    pub fn new(serialized_size_limit: u64) -> Self {
167        Self {
168            versioned: SerializationVersioningMode::versioned(),
169            serialized_size_limit: Some(serialized_size_limit),
170        }
171    }
172
173    /// Creates a new serialization config without any size check.
174    pub fn new_with_unlimited_size() -> Self {
175        Self {
176            versioned: SerializationVersioningMode::versioned(),
177            serialized_size_limit: None,
178        }
179    }
180
181    /// Disables the size limit for serialized objects
182    pub fn disable_size_limit(self) -> Self {
183        Self {
184            serialized_size_limit: None,
185            ..self
186        }
187    }
188
189    /// Disable the versioning of serialized objects
190    pub fn disable_versioning(self) -> Self {
191        Self {
192            versioned: SerializationVersioningMode::unversioned(),
193            ..self
194        }
195    }
196
197    /// Sets the size limit for this serialization config
198    pub fn with_size_limit(self, size: u64) -> Self {
199        Self {
200            serialized_size_limit: Some(size),
201            ..self
202        }
203    }
204
205    /// Create a serialization header based on the current config
206    fn create_header<T: Named>(&self) -> SerializationHeader {
207        match self.versioned {
208            SerializationVersioningMode::Versioned { .. } => {
209                SerializationHeader::new_versioned::<T>()
210            }
211            SerializationVersioningMode::Unversioned { .. } => {
212                SerializationHeader::new_unversioned::<T>()
213            }
214        }
215    }
216
217    /// Returns the size the object would take if serialized using the current config
218    ///
219    /// The size is returned as a u64 to handle the serialization of large buffers under 32b
220    /// architectures.
221    pub fn serialized_size<T: Serialize + Versionize + Named>(
222        &self,
223        object: &T,
224    ) -> bincode::Result<u64> {
225        let options = bincode::DefaultOptions::new().with_fixint_encoding();
226
227        let header = self.create_header::<T>();
228
229        let header_size = options.serialized_size(&header)?;
230
231        let data_size = match self.versioned {
232            SerializationVersioningMode::Versioned { .. } => {
233                options.serialized_size(&object.versionize())?
234            }
235            SerializationVersioningMode::Unversioned { .. } => options.serialized_size(&object)?,
236        };
237
238        Ok(header_size + data_size)
239    }
240
241    /// Serializes an object into a [writer](std::io::Write), based on the current config.
242    /// The written bytes can be deserialized using [`DeserializationConfig::deserialize_from`].
243    pub fn serialize_into<T: Serialize + Versionize + Named>(
244        self,
245        object: &T,
246        mut writer: impl std::io::Write,
247    ) -> bincode::Result<()> {
248        let options = bincode::DefaultOptions::new()
249            .with_fixint_encoding()
250            .with_limit(0); // Force to explicitly set the limit for each serialization
251
252        let header = self.create_header::<T>();
253        let header_size = options.with_no_limit().serialized_size(&header)?;
254
255        if let Some(size_limit) = self.serialized_size_limit {
256            options
257                .with_limit(size_limit)
258                .serialize_into(&mut writer, &header)?;
259
260            let options = options.with_limit(size_limit - header_size);
261
262            match self.versioned {
263                SerializationVersioningMode::Versioned { .. } => {
264                    options.serialize_into(&mut writer, &object.versionize())?
265                }
266                SerializationVersioningMode::Unversioned { .. } => {
267                    options.serialize_into(&mut writer, &object)?
268                }
269            }
270        } else {
271            let options = options.with_no_limit();
272
273            options.serialize_into(&mut writer, &header)?;
274
275            match self.versioned {
276                SerializationVersioningMode::Versioned { .. } => {
277                    options.serialize_into(&mut writer, &object.versionize())?
278                }
279                SerializationVersioningMode::Unversioned { .. } => {
280                    options.serialize_into(&mut writer, &object)?
281                }
282            }
283        }
284
285        Ok(())
286    }
287}
288
/// A configuration used to Deserialize *TFHE-rs* objects. This configuration decides
/// the various sanity checks that will be performed during deserialization.
#[derive(Copy, Clone)]
pub struct DeserializationConfig {
    /// Byte budget for reading the serialized data (header included); `None` means unlimited.
    serialized_size_limit: Option<u64>,
    /// Whether the serialization header is checked against the expected type and versions.
    validate_header: bool,
}
296
297/// A configuration used to Serialize *TFHE-rs* objects. This is similar to
298/// [`DeserializationConfig`] but it will not require conformance parameters.
299///
300/// This type should be created with [`DeserializationConfig::disable_conformance`]
301#[derive(Copy, Clone)]
302pub struct NonConformantDeserializationConfig {
303    serialized_size_limit: Option<u64>,
304    validate_header: bool,
305}
306
307impl NonConformantDeserializationConfig {
308    /// Deserialize a header using the current config
309    fn deserialize_header(
310        &self,
311        reader: &mut impl std::io::Read,
312    ) -> Result<SerializationHeader, String> {
313        let options = bincode::DefaultOptions::new()
314            .with_fixint_encoding()
315            .with_limit(0);
316
317        if let Some(size_limit) = self.serialized_size_limit {
318            options
319                .with_limit(size_limit)
320                .deserialize_from(reader)
321                .map_err(|err| err.to_string())
322        } else {
323            options
324                .with_no_limit()
325                .deserialize_from(reader)
326                .map_err(|err| err.to_string())
327        }
328    }
329
330    /// Deserializes an object serialized by [`SerializationConfig::serialize_into`] from a
331    /// [reader](std::io::Read). Performs various sanity checks based on the deserialization config,
332    /// but skips conformance checks.
333    pub fn deserialize_from<T: DeserializeOwned + Unversionize + Named>(
334        self,
335        mut reader: impl std::io::Read,
336    ) -> Result<T, String> {
337        let options = bincode::DefaultOptions::new()
338            .with_fixint_encoding()
339            .with_limit(0); // Force to explicitly set the limit for each deserialization
340
341        let deserialized_header: SerializationHeader = self.deserialize_header(&mut reader)?;
342
343        let header_size = options
344            .with_no_limit()
345            .serialized_size(&deserialized_header)
346            .map_err(|err| err.to_string())?;
347
348        if self.validate_header {
349            deserialized_header.validate::<T>()?;
350        }
351
352        if let Some(size_limit) = self.serialized_size_limit {
353            let options = options.with_limit(size_limit - header_size);
354            match deserialized_header.versioning_mode {
355                SerializationVersioningMode::Versioned { .. } => {
356                    let deser_versioned = options
357                        .deserialize_from(&mut reader)
358                        .map_err(|err| err.to_string())?;
359
360                    T::unversionize(deser_versioned).map_err(|e| e.to_string())
361                }
362                SerializationVersioningMode::Unversioned { .. } => options
363                    .deserialize_from(&mut reader)
364                    .map_err(|err| err.to_string()),
365            }
366        } else {
367            let options = options.with_no_limit();
368            match deserialized_header.versioning_mode {
369                SerializationVersioningMode::Versioned { .. } => {
370                    let deser_versioned = options
371                        .deserialize_from(&mut reader)
372                        .map_err(|err| err.to_string())?;
373
374                    T::unversionize(deser_versioned).map_err(|e| e.to_string())
375                }
376                SerializationVersioningMode::Unversioned { .. } => options
377                    .deserialize_from(&mut reader)
378                    .map_err(|err| err.to_string()),
379            }
380        }
381    }
382
383    /// Enables the conformance check on an existing config.
384    pub fn enable_conformance(self) -> DeserializationConfig {
385        DeserializationConfig {
386            serialized_size_limit: self.serialized_size_limit,
387            validate_header: self.validate_header,
388        }
389    }
390}
391
392impl DeserializationConfig {
393    /// Creates a new deserialization config.
394    ///
395    /// By default, it will check that the serialization version and the name of the
396    /// deserialized type are correct.
397    /// `deserialized_size_limit` is the size limit (in number of bytes) of the deserialized object.
398    /// It should be set according to the expected size of the object and the maximum allocatable
399    /// size on your system.
400    ///
401    /// It will also check that the object is conformant with the parameter set given in
402    /// `conformance_params`. Finally, it will check the compatibility of the loaded data with
403    /// the current *TFHE-rs* version.
404    pub fn new(deserialized_size_limit: u64) -> Self {
405        Self {
406            serialized_size_limit: Some(deserialized_size_limit),
407            validate_header: true,
408        }
409    }
410
411    /// Creates a new config without any size limit for the deserialized objects.
412    pub fn new_with_unlimited_size() -> Self {
413        Self {
414            serialized_size_limit: None,
415            validate_header: true,
416        }
417    }
418
419    /// Disables the size limit for the serialized objects.
420    pub fn disable_size_limit(self) -> Self {
421        Self {
422            serialized_size_limit: None,
423            ..self
424        }
425    }
426
427    /// Sets the size limit for this deserialization config
428    pub fn with_size_limit(self, size: u64) -> Self {
429        Self {
430            serialized_size_limit: Some(size),
431            ..self
432        }
433    }
434
435    /// Disables the header validation on the object. This header validations
436    /// checks that the serialized object is the one that is supposed to be loaded
437    /// and is compatible with this version of *TFHE-rs*.
438    pub fn disable_header_validation(self) -> Self {
439        Self {
440            validate_header: false,
441            ..self
442        }
443    }
444
445    /// Disables the conformance check on an existing config.
446    pub fn disable_conformance(self) -> NonConformantDeserializationConfig {
447        NonConformantDeserializationConfig {
448            serialized_size_limit: self.serialized_size_limit,
449            validate_header: self.validate_header,
450        }
451    }
452
453    /// Deserializes an object serialized by [`SerializationConfig::serialize_into`] from a
454    /// [reader](std::io::Read). Performs various sanity checks based on the deserialization config.
455    ///
456    /// # Panics
457    /// This function may panic if `serialized_size_limit` is larger than what can be allocated by
458    /// the system. This may happen even if the size of the serialized data is short. An
459    /// attacker could manipulate the data to create a short serialized message with a huge
460    /// deserialized size.
461    pub fn deserialize_from<T: DeserializeOwned + Unversionize + Named + ParameterSetConformant>(
462        self,
463        reader: impl std::io::Read,
464        parameter_set: &T::ParameterSet,
465    ) -> Result<T, String> {
466        let deser: T = self.disable_conformance().deserialize_from(reader)?;
467        if !deser.is_conformant(parameter_set) {
468            return Err(format!(
469                "Deserialized object of type {} not conformant with given parameter set",
470                T::NAME
471            ));
472        }
473
474        Ok(deser)
475    }
476}
477
478/// Serialize an object with the default configuration (with size limit and versioning).
479/// This is an alias for `SerializationConfig::new(serialized_size_limit).serialize_into`
480pub fn safe_serialize<T: Serialize + Versionize + Named>(
481    object: &T,
482    writer: impl std::io::Write,
483    serialized_size_limit: u64,
484) -> bincode::Result<()> {
485    SerializationConfig::new(serialized_size_limit).serialize_into(object, writer)
486}
487
488/// Return the size the object would take if serialized using [`safe_serialize`]
489pub fn safe_serialized_size<T: Serialize + Versionize + Named>(object: &T) -> bincode::Result<u64> {
490    SerializationConfig::new_with_unlimited_size().serialized_size(object)
491}
492
493/// Serialize an object with the default configuration (with size limit, header check and
494/// versioning).
495///
496/// `deserialized_size_limit` is the size limit (in number of bytes) of the deserialized object.
497/// It should be set according to the expected size of the object and the maximum allocatable size
498/// on your system.
499///
500/// This is an alias for
501/// `DeserializationConfig::new(serialized_size_limit).disable_conformance().deserialize_from`
502///
503/// # Panics
504/// This function may panic if `serialized_size_limit` is larger than what can be allocated by the
505/// system. This may happen even if the size of the serialized data is short. An attacker could
506/// manipulate the data to create a short serialized message with a huge deserialized size.
507pub fn safe_deserialize<T: DeserializeOwned + Unversionize + Named>(
508    reader: impl std::io::Read,
509    deserialized_size_limit: u64,
510) -> Result<T, String> {
511    DeserializationConfig::new(deserialized_size_limit)
512        .disable_conformance()
513        .deserialize_from(reader)
514}
515
516/// Serialize an object with the default configuration and conformance checks (with size limit,
517/// header check and versioning).
518///
519/// `deserialized_size_limit` is the size limit (in number of bytes) of the deserialized object.
520/// It should be set according to the expected size of the object and the maximum allocatable size
521/// on your system.
522///
523/// This is an alias for
524/// `DeserializationConfig::new(serialized_size_limit).deserialize_from`
525///
526/// # Panics
527/// This function may panic if `serialized_size_limit` is larger than what can be allocated by the
528/// system. This may happen even if the size of the serialized data is short. An attacker could
529/// manipulate the data to create a short serialized message with a huge deserialized size.
530pub fn safe_deserialize_conformant<
531    T: DeserializeOwned + Unversionize + Named + ParameterSetConformant,
532>(
533    reader: impl std::io::Read,
534    deserialized_size_limit: u64,
535    parameter_set: &T::ParameterSet,
536) -> Result<T, String> {
537    DeserializationConfig::new(deserialized_size_limit).deserialize_from(reader, parameter_set)
538}
539
#[cfg(test)]
mod tests {
    // The glob brings in `Named`, `ParameterSetConformant`, the serde derives and the
    // versioning traits/derives from the crate root.
    use super::*;
    use std::ops::RangeInclusive;

    #[derive(Serialize, Deserialize, Versionize, Debug, PartialEq)]
    #[repr(transparent)]
    struct Foo(u64);

    impl Named for Foo {
        const NAME: &'static str = "Foo";
    }

    #[derive(Serialize, Deserialize, Versionize, Debug, PartialEq)]
    #[repr(transparent)]
    struct Bar(u64);

    impl Named for Bar {
        const NAME: &'static str = "Bar";
        const BACKWARD_COMPATIBILITY_ALIASES: &'static [&'static str] = &["Foo"];
    }

    #[derive(Serialize, Deserialize, Versionize, Debug, PartialEq)]
    #[repr(transparent)]
    struct Baz(u64);

    impl Named for Baz {
        const NAME: &'static str = "Baz";
    }

    #[derive(Serialize, Deserialize, Versionize, Debug, PartialEq)]
    #[repr(transparent)]
    struct Conformant(u64);

    impl Named for Conformant {
        const NAME: &'static str = "Conformant";
    }

    impl ParameterSetConformant for Conformant {
        type ParameterSet = RangeInclusive<u64>;

        fn is_conformant(&self, parameter_set: &Self::ParameterSet) -> bool {
            parameter_set.contains(&self.0)
        }
    }

    /// Serializes `obj` with the default (versioned) config and a 1 MiB size limit.
    fn serialize_versioned<T: Serialize + Versionize + Named>(obj: &T) -> Vec<u8> {
        let mut buf = Vec::new();
        SerializationConfig::new(1 << 20)
            .serialize_into(obj, &mut buf)
            .unwrap();
        buf
    }

    #[test]
    fn backward_compatibility_aliases() {
        let foo = Foo(3);
        let mut buf = Vec::new();
        safe_serialize(&foo, &mut buf, 0x1000).unwrap();

        let foo_deser: Foo = safe_deserialize(buf.as_slice(), 0x1000).unwrap();
        // Bar is backward compatible with Foo, this works
        let bar_deser: Bar = safe_deserialize(buf.as_slice(), 0x1000).unwrap();

        assert_eq!(foo_deser.0, bar_deser.0);
        // Baz is not backward compatible with Foo, this fails
        assert!(safe_deserialize::<Baz>(buf.as_slice(), 0x1000).is_err());
    }

    #[test]
    fn serialized_size_matches_actual_versioned() {
        let foo = Foo(123);
        let config = SerializationConfig::new(1 << 20);
        let size = config.serialized_size(&foo).unwrap();
        let buf = serialize_versioned(&foo);
        assert_eq!(size as usize, buf.len());
    }

    #[test]
    fn unversioned_roundtrip() {
        let foo = Foo(42);
        let config = SerializationConfig::new(1 << 20).disable_versioning();
        let expected_size = config.serialized_size(&foo).unwrap();

        let mut buf = Vec::new();
        config.serialize_into(&foo, &mut buf).unwrap();
        assert_eq!(buf.len() as u64, expected_size);

        // Same crate version on both sides, so header validation accepts unversioned data.
        let deser: Foo = safe_deserialize(buf.as_slice(), 1 << 20).unwrap();
        assert_eq!(deser, foo);
    }

    #[test]
    fn serialize_size_limit_works() {
        let foo = Foo(1);
        let exact_size = SerializationConfig::new(1 << 20)
            .serialized_size(&foo)
            .unwrap();

        let mut buf = Vec::new();
        let result = SerializationConfig::new(exact_size - 1).serialize_into(&foo, &mut buf);
        assert!(result.is_err());

        buf.clear();
        SerializationConfig::new(exact_size)
            .serialize_into(&foo, &mut buf)
            .unwrap();
        assert_eq!(buf.len(), exact_size as usize);
    }

    #[test]
    fn deserialize_size_limit_works() {
        let buf = serialize_versioned(&Conformant(1));
        let exact_size = buf.len() as u64;

        let result: Result<Conformant, _> =
            DeserializationConfig::new(exact_size - 1).deserialize_from(buf.as_slice(), &(0..=100));
        assert!(result.is_err());

        let result: Result<Conformant, _> =
            DeserializationConfig::new(exact_size).deserialize_from(buf.as_slice(), &(0..=100));
        assert!(result.is_ok());
    }

    #[test]
    fn header_validation_disabled() {
        let buf = serialize_versioned(&Foo(7));

        let result: Result<Baz, _> = DeserializationConfig::new(1 << 20)
            .disable_conformance()
            .deserialize_from(buf.as_slice());
        assert!(result.is_err());

        let deser: Baz = DeserializationConfig::new(1 << 20)
            .disable_header_validation()
            .disable_conformance()
            .deserialize_from(buf.as_slice())
            .unwrap();
        assert_eq!(deser.0, 7);
    }

    #[test]
    fn conformance_check() {
        let obj = Conformant(50);
        let buf = serialize_versioned(&obj);

        let result: Result<Conformant, _> =
            DeserializationConfig::new(1 << 20).deserialize_from(buf.as_slice(), &(0..=100));
        assert!(result.is_ok());

        let result: Result<Conformant, _> =
            DeserializationConfig::new(1 << 20).deserialize_from(buf.as_slice(), &(0..=10));
        assert!(result.is_err());

        let deser: Conformant = DeserializationConfig::new(1 << 20)
            .disable_conformance()
            .deserialize_from(buf.as_slice())
            .unwrap();
        assert_eq!(deser, obj);
    }

    #[test]
    fn unlimited_size_configs() {
        let obj = Conformant(999);

        let mut buf = Vec::new();
        SerializationConfig::new_with_unlimited_size()
            .serialize_into(&obj, &mut buf)
            .unwrap();

        let deser: Conformant = DeserializationConfig::new_with_unlimited_size()
            .disable_conformance()
            .deserialize_from(buf.as_slice())
            .unwrap();
        assert_eq!(deser, obj);

        let deser: Result<Conformant, _> =
            DeserializationConfig::new(1).deserialize_from(buf.as_slice(), &(0..=1234));
        assert!(deser.is_err());
    }
}