1#![cfg_attr(
6 dylint_lib = "tfhe_lints",
7 allow(unknown_lints, serialize_without_versionize)
8)]
9
10use std::borrow::Cow;
11use std::fmt::Display;
12
13use bincode::Options;
14use serde::de::DeserializeOwned;
15use serde::{Deserialize, Serialize};
16use tfhe_versionable::{Unversionize, Versionize};
17
18mod traits;
19pub use crate::traits::{EnumSet, Named, ParameterSetConformant};
20
/// Version of the `SerializationHeader` layout itself. On deserialization, a header carrying any
/// other value is rejected by `SerializationHeader::validate`.
const SERIALIZATION_VERSION: &str = "0.5";

/// Version of the versioning scheme, recorded in headers written in versioned mode. It must match
/// exactly on deserialization.
const VERSIONING_VERSION: &str = "0.1";

/// `major.minor` version of this crate, taken from Cargo's package metadata at compile time.
/// It is recorded in headers written in unversioned mode. Such data is only accepted back by a
/// build with the same `major.minor`.
const CRATE_VERSION: &str = concat!(
    env!("CARGO_PKG_VERSION_MAJOR"),
    ".",
    env!("CARGO_PKG_VERSION_MINOR")
);
37
/// Describes how the payload that follows a `SerializationHeader` is encoded.
///
/// This enum is itself serialized as part of the header (bincode encodes the variant index), so
/// the variant order and field layout are part of the wire format.
#[derive(Serialize, Deserialize, Clone, PartialEq, Eq)]
enum SerializationVersioningMode {
    /// The payload is the serialized output of `Versionize::versionize`.
    Versioned {
        /// Version of the versioning scheme used to write the payload.
        versioning_version: Cow<'static, str>,
    },
    /// The payload is the object's plain `Serialize` output, without versioning information.
    Unversioned {
        /// `major.minor` version of the crate that wrote the payload.
        crate_version: Cow<'static, str>,
    },
}
52
53impl Display for SerializationVersioningMode {
54 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
55 match self {
56 Self::Versioned { .. } => write!(f, "versioned"),
57 Self::Unversioned { .. } => write!(f, "unversioned"),
58 }
59 }
60}
61
62impl SerializationVersioningMode {
63 fn versioned() -> Self {
64 Self::Versioned {
65 versioning_version: Cow::Borrowed(VERSIONING_VERSION),
66 }
67 }
68
69 fn unversioned() -> Self {
70 Self::Unversioned {
71 crate_version: Cow::Borrowed(CRATE_VERSION),
72 }
73 }
74}
75
/// Header written in front of every serialized object.
///
/// It is serialized with bincode (fixed-int encoding) before the payload. On deserialization it
/// is checked by `validate` when header validation is enabled.
#[derive(Serialize, Deserialize)]
struct SerializationHeader {
    /// Version of this header layout (`SERIALIZATION_VERSION` when written by this crate).
    header_version: Cow<'static, str>,
    /// How the payload after the header is encoded.
    versioning_mode: SerializationVersioningMode,
    /// `Named::NAME` of the type that was serialized.
    name: Cow<'static, str>,
}
84
85impl SerializationHeader {
86 fn new_versioned<T: Named>() -> Self {
88 Self {
89 header_version: Cow::Borrowed(SERIALIZATION_VERSION),
90 versioning_mode: SerializationVersioningMode::versioned(),
91 name: Cow::Borrowed(T::NAME),
92 }
93 }
94
95 fn new_unversioned<T: Named>() -> Self {
97 Self {
98 header_version: Cow::Borrowed(SERIALIZATION_VERSION),
99 versioning_mode: SerializationVersioningMode::unversioned(),
100 name: Cow::Borrowed(T::NAME),
101 }
102 }
103
104 fn validate<T: Named>(&self) -> Result<(), String> {
106 if self.header_version != SERIALIZATION_VERSION {
107 return Err(format!(
108 "On deserialization, expected serialization header version {SERIALIZATION_VERSION}, \
109got version {}",
110 self.header_version
111 ));
112 }
113
114 match &self.versioning_mode {
115 SerializationVersioningMode::Versioned { versioning_version } => {
116 if versioning_version != VERSIONING_VERSION {
120 return Err(format!(
121 "On deserialization, expected versioning scheme version {VERSIONING_VERSION}, \
122got version {versioning_version}"
123 ));
124 }
125 }
126 SerializationVersioningMode::Unversioned { crate_version } => {
127 if crate_version != CRATE_VERSION {
128 return Err(format!(
129 "This {} has been saved from TFHE-rs v{crate_version}, without versioning information. \
130Please use the versioned serialization mode for backward compatibility.",
131 self.name
132 ));
133 }
134 }
135 }
136
137 if self.name != T::NAME
138 && T::BACKWARD_COMPATIBILITY_ALIASES
139 .iter()
140 .all(|alias| self.name != *alias)
141 {
142 return Err(format!(
143 "On deserialization, expected type {}, got type {}",
144 T::NAME,
145 self.name
146 ));
147 }
148
149 Ok(())
150 }
151}
152
/// Options controlling how an object is serialized.
///
/// It selects versioned or unversioned encoding and an optional cap on the total number of bytes
/// written (header included).
#[derive(Clone)]
pub struct SerializationConfig {
    /// Versioning mode recorded in the header and used to encode the payload.
    versioned: SerializationVersioningMode,
    /// Maximum number of bytes written for header + payload; `None` means unlimited.
    serialized_size_limit: Option<u64>,
}
160
161impl SerializationConfig {
162 pub fn new(serialized_size_limit: u64) -> Self {
167 Self {
168 versioned: SerializationVersioningMode::versioned(),
169 serialized_size_limit: Some(serialized_size_limit),
170 }
171 }
172
173 pub fn new_with_unlimited_size() -> Self {
175 Self {
176 versioned: SerializationVersioningMode::versioned(),
177 serialized_size_limit: None,
178 }
179 }
180
181 pub fn disable_size_limit(self) -> Self {
183 Self {
184 serialized_size_limit: None,
185 ..self
186 }
187 }
188
189 pub fn disable_versioning(self) -> Self {
191 Self {
192 versioned: SerializationVersioningMode::unversioned(),
193 ..self
194 }
195 }
196
197 pub fn with_size_limit(self, size: u64) -> Self {
199 Self {
200 serialized_size_limit: Some(size),
201 ..self
202 }
203 }
204
205 fn create_header<T: Named>(&self) -> SerializationHeader {
207 match self.versioned {
208 SerializationVersioningMode::Versioned { .. } => {
209 SerializationHeader::new_versioned::<T>()
210 }
211 SerializationVersioningMode::Unversioned { .. } => {
212 SerializationHeader::new_unversioned::<T>()
213 }
214 }
215 }
216
217 pub fn serialized_size<T: Serialize + Versionize + Named>(
222 &self,
223 object: &T,
224 ) -> bincode::Result<u64> {
225 let options = bincode::DefaultOptions::new().with_fixint_encoding();
226
227 let header = self.create_header::<T>();
228
229 let header_size = options.serialized_size(&header)?;
230
231 let data_size = match self.versioned {
232 SerializationVersioningMode::Versioned { .. } => {
233 options.serialized_size(&object.versionize())?
234 }
235 SerializationVersioningMode::Unversioned { .. } => options.serialized_size(&object)?,
236 };
237
238 Ok(header_size + data_size)
239 }
240
241 pub fn serialize_into<T: Serialize + Versionize + Named>(
244 self,
245 object: &T,
246 mut writer: impl std::io::Write,
247 ) -> bincode::Result<()> {
248 let options = bincode::DefaultOptions::new()
249 .with_fixint_encoding()
250 .with_limit(0); let header = self.create_header::<T>();
253 let header_size = options.with_no_limit().serialized_size(&header)?;
254
255 if let Some(size_limit) = self.serialized_size_limit {
256 options
257 .with_limit(size_limit)
258 .serialize_into(&mut writer, &header)?;
259
260 let options = options.with_limit(size_limit - header_size);
261
262 match self.versioned {
263 SerializationVersioningMode::Versioned { .. } => {
264 options.serialize_into(&mut writer, &object.versionize())?
265 }
266 SerializationVersioningMode::Unversioned { .. } => {
267 options.serialize_into(&mut writer, &object)?
268 }
269 }
270 } else {
271 let options = options.with_no_limit();
272
273 options.serialize_into(&mut writer, &header)?;
274
275 match self.versioned {
276 SerializationVersioningMode::Versioned { .. } => {
277 options.serialize_into(&mut writer, &object.versionize())?
278 }
279 SerializationVersioningMode::Unversioned { .. } => {
280 options.serialize_into(&mut writer, &object)?
281 }
282 }
283 }
284
285 Ok(())
286 }
287}
288
/// Options for deserializing an object and then checking that it is conformant with a
/// parameter set (see `ParameterSetConformant`).
#[derive(Copy, Clone)]
pub struct DeserializationConfig {
    /// Maximum number of bytes read for header + payload; `None` means unlimited.
    serialized_size_limit: Option<u64>,
    /// Whether the header (format versions and type name) is checked before the payload is read.
    validate_header: bool,
}
296
/// Same options as `DeserializationConfig`, but deserialization skips the parameter-set
/// conformance check. Obtained through `DeserializationConfig::disable_conformance`.
#[derive(Copy, Clone)]
pub struct NonConformantDeserializationConfig {
    /// Maximum number of bytes read for header + payload; `None` means unlimited.
    serialized_size_limit: Option<u64>,
    /// Whether the header (format versions and type name) is checked before the payload is read.
    validate_header: bool,
}
306
307impl NonConformantDeserializationConfig {
308 fn deserialize_header(
310 &self,
311 reader: &mut impl std::io::Read,
312 ) -> Result<SerializationHeader, String> {
313 let options = bincode::DefaultOptions::new()
314 .with_fixint_encoding()
315 .with_limit(0);
316
317 if let Some(size_limit) = self.serialized_size_limit {
318 options
319 .with_limit(size_limit)
320 .deserialize_from(reader)
321 .map_err(|err| err.to_string())
322 } else {
323 options
324 .with_no_limit()
325 .deserialize_from(reader)
326 .map_err(|err| err.to_string())
327 }
328 }
329
330 pub fn deserialize_from<T: DeserializeOwned + Unversionize + Named>(
334 self,
335 mut reader: impl std::io::Read,
336 ) -> Result<T, String> {
337 let options = bincode::DefaultOptions::new()
338 .with_fixint_encoding()
339 .with_limit(0); let deserialized_header: SerializationHeader = self.deserialize_header(&mut reader)?;
342
343 let header_size = options
344 .with_no_limit()
345 .serialized_size(&deserialized_header)
346 .map_err(|err| err.to_string())?;
347
348 if self.validate_header {
349 deserialized_header.validate::<T>()?;
350 }
351
352 if let Some(size_limit) = self.serialized_size_limit {
353 let options = options.with_limit(size_limit - header_size);
354 match deserialized_header.versioning_mode {
355 SerializationVersioningMode::Versioned { .. } => {
356 let deser_versioned = options
357 .deserialize_from(&mut reader)
358 .map_err(|err| err.to_string())?;
359
360 T::unversionize(deser_versioned).map_err(|e| e.to_string())
361 }
362 SerializationVersioningMode::Unversioned { .. } => options
363 .deserialize_from(&mut reader)
364 .map_err(|err| err.to_string()),
365 }
366 } else {
367 let options = options.with_no_limit();
368 match deserialized_header.versioning_mode {
369 SerializationVersioningMode::Versioned { .. } => {
370 let deser_versioned = options
371 .deserialize_from(&mut reader)
372 .map_err(|err| err.to_string())?;
373
374 T::unversionize(deser_versioned).map_err(|e| e.to_string())
375 }
376 SerializationVersioningMode::Unversioned { .. } => options
377 .deserialize_from(&mut reader)
378 .map_err(|err| err.to_string()),
379 }
380 }
381 }
382
383 pub fn enable_conformance(self) -> DeserializationConfig {
385 DeserializationConfig {
386 serialized_size_limit: self.serialized_size_limit,
387 validate_header: self.validate_header,
388 }
389 }
390}
391
392impl DeserializationConfig {
393 pub fn new(deserialized_size_limit: u64) -> Self {
405 Self {
406 serialized_size_limit: Some(deserialized_size_limit),
407 validate_header: true,
408 }
409 }
410
411 pub fn new_with_unlimited_size() -> Self {
413 Self {
414 serialized_size_limit: None,
415 validate_header: true,
416 }
417 }
418
419 pub fn disable_size_limit(self) -> Self {
421 Self {
422 serialized_size_limit: None,
423 ..self
424 }
425 }
426
427 pub fn with_size_limit(self, size: u64) -> Self {
429 Self {
430 serialized_size_limit: Some(size),
431 ..self
432 }
433 }
434
435 pub fn disable_header_validation(self) -> Self {
439 Self {
440 validate_header: false,
441 ..self
442 }
443 }
444
445 pub fn disable_conformance(self) -> NonConformantDeserializationConfig {
447 NonConformantDeserializationConfig {
448 serialized_size_limit: self.serialized_size_limit,
449 validate_header: self.validate_header,
450 }
451 }
452
453 pub fn deserialize_from<T: DeserializeOwned + Unversionize + Named + ParameterSetConformant>(
462 self,
463 reader: impl std::io::Read,
464 parameter_set: &T::ParameterSet,
465 ) -> Result<T, String> {
466 let deser: T = self.disable_conformance().deserialize_from(reader)?;
467 if !deser.is_conformant(parameter_set) {
468 return Err(format!(
469 "Deserialized object of type {} not conformant with given parameter set",
470 T::NAME
471 ));
472 }
473
474 Ok(deser)
475 }
476}
477
478pub fn safe_serialize<T: Serialize + Versionize + Named>(
481 object: &T,
482 writer: impl std::io::Write,
483 serialized_size_limit: u64,
484) -> bincode::Result<()> {
485 SerializationConfig::new(serialized_size_limit).serialize_into(object, writer)
486}
487
488pub fn safe_serialized_size<T: Serialize + Versionize + Named>(object: &T) -> bincode::Result<u64> {
490 SerializationConfig::new_with_unlimited_size().serialized_size(object)
491}
492
493pub fn safe_deserialize<T: DeserializeOwned + Unversionize + Named>(
508 reader: impl std::io::Read,
509 deserialized_size_limit: u64,
510) -> Result<T, String> {
511 DeserializationConfig::new(deserialized_size_limit)
512 .disable_conformance()
513 .deserialize_from(reader)
514}
515
516pub fn safe_deserialize_conformant<
531 T: DeserializeOwned + Unversionize + Named + ParameterSetConformant,
532>(
533 reader: impl std::io::Read,
534 deserialized_size_limit: u64,
535 parameter_set: &T::ParameterSet,
536) -> Result<T, String> {
537 DeserializationConfig::new(deserialized_size_limit).deserialize_from(reader, parameter_set)
538}
539
#[cfg(test)]
mod tests {
    use super::*;
    use std::ops::RangeInclusive;

    /// Size limit large enough for every fixture in this module.
    const LARGE_LIMIT: u64 = 1 << 20;

    #[derive(Serialize, Deserialize, Versionize, Debug, PartialEq)]
    #[repr(transparent)]
    struct Foo(u64);

    impl Named for Foo {
        const NAME: &'static str = "Foo";
    }

    /// Accepts data recorded under the name `Foo` through a backward-compatibility alias.
    #[derive(Serialize, Deserialize, Versionize, Debug, PartialEq)]
    #[repr(transparent)]
    struct Bar(u64);

    impl Named for Bar {
        const NAME: &'static str = "Bar";
        const BACKWARD_COMPATIBILITY_ALIASES: &'static [&'static str] = &["Foo"];
    }

    /// Unrelated name: data written as `Foo` must be rejected when header validation is on.
    #[derive(Serialize, Deserialize, Versionize, Debug, PartialEq)]
    #[repr(transparent)]
    struct Baz(u64);

    impl Named for Baz {
        const NAME: &'static str = "Baz";
    }

    /// Conformant when its value lies inside the given inclusive range.
    #[derive(Serialize, Deserialize, Versionize, Debug, PartialEq)]
    #[repr(transparent)]
    struct Conformant(u64);

    impl Named for Conformant {
        const NAME: &'static str = "Conformant";
    }

    impl ParameterSetConformant for Conformant {
        type ParameterSet = RangeInclusive<u64>;

        fn is_conformant(&self, parameter_set: &Self::ParameterSet) -> bool {
            parameter_set.contains(&self.0)
        }
    }

    /// Serializes `obj` in versioned mode with a generous size limit.
    fn to_bytes<T: Serialize + Versionize + Named>(obj: &T) -> Vec<u8> {
        let mut out = Vec::new();
        SerializationConfig::new(LARGE_LIMIT)
            .serialize_into(obj, &mut out)
            .unwrap();
        out
    }

    #[test]
    fn backward_compatibility_aliases() {
        let mut out = Vec::new();
        safe_serialize(&Foo(3), &mut out, 0x1000).unwrap();
        let bytes = out.as_slice();

        let as_foo: Foo = safe_deserialize(bytes, 0x1000).unwrap();
        let as_bar: Bar = safe_deserialize(bytes, 0x1000).unwrap();
        assert_eq!(as_foo.0, as_bar.0);

        assert!(safe_deserialize::<Baz>(bytes, 0x1000).is_err());
    }

    #[test]
    fn serialized_size_matches_actual_versioned() {
        let foo = Foo(123);
        let predicted = SerializationConfig::new(LARGE_LIMIT)
            .serialized_size(&foo)
            .unwrap();
        assert_eq!(predicted as usize, to_bytes(&foo).len());
    }

    #[test]
    fn serialize_size_limit_works() {
        let foo = Foo(1);
        let exact_size = SerializationConfig::new(LARGE_LIMIT)
            .serialized_size(&foo)
            .unwrap();

        let mut too_small = Vec::new();
        let rejected = SerializationConfig::new(exact_size - 1).serialize_into(&foo, &mut too_small);
        assert!(rejected.is_err());

        let mut just_right = Vec::new();
        SerializationConfig::new(exact_size)
            .serialize_into(&foo, &mut just_right)
            .unwrap();
        assert_eq!(just_right.len(), exact_size as usize);
    }

    #[test]
    fn deserialize_size_limit_works() {
        let bytes = to_bytes(&Conformant(1));
        let exact_size = bytes.len() as u64;
        let accepted: RangeInclusive<u64> = 0..=100;

        let read_with_limit = |limit: u64| -> Result<Conformant, String> {
            DeserializationConfig::new(limit).deserialize_from(bytes.as_slice(), &accepted)
        };

        assert!(read_with_limit(exact_size - 1).is_err());
        assert!(read_with_limit(exact_size).is_ok());
    }

    #[test]
    fn header_validation_disabled() {
        let bytes = to_bytes(&Foo(7));

        let validated: Result<Baz, _> = DeserializationConfig::new(LARGE_LIMIT)
            .disable_conformance()
            .deserialize_from(bytes.as_slice());
        assert!(validated.is_err());

        let unvalidated: Baz = DeserializationConfig::new(LARGE_LIMIT)
            .disable_header_validation()
            .disable_conformance()
            .deserialize_from(bytes.as_slice())
            .unwrap();
        assert_eq!(unvalidated.0, 7);
    }

    #[test]
    fn conformance_check() {
        let obj = Conformant(50);
        let bytes = to_bytes(&obj);

        let read_within = |range: RangeInclusive<u64>| -> Result<Conformant, String> {
            DeserializationConfig::new(LARGE_LIMIT).deserialize_from(bytes.as_slice(), &range)
        };

        assert!(read_within(0..=100).is_ok());
        assert!(read_within(0..=10).is_err());

        let unchecked: Conformant = DeserializationConfig::new(LARGE_LIMIT)
            .disable_conformance()
            .deserialize_from(bytes.as_slice())
            .unwrap();
        assert_eq!(unchecked, obj);
    }

    #[test]
    fn unlimited_size_configs() {
        let obj = Conformant(999);

        let mut out = Vec::new();
        SerializationConfig::new_with_unlimited_size()
            .serialize_into(&obj, &mut out)
            .unwrap();

        let round_tripped: Conformant = DeserializationConfig::new_with_unlimited_size()
            .disable_conformance()
            .deserialize_from(out.as_slice())
            .unwrap();
        assert_eq!(round_tripped, obj);

        // A one-byte limit is far below the header size, so reading must fail.
        let starved: Result<Conformant, _> =
            DeserializationConfig::new(1).deserialize_from(out.as_slice(), &(0..=1234));
        assert!(starved.is_err());
    }
}
714}