1use crate::{
4 algorithm::Algorithm, error::CryptoError, metadata::PqcMetadata, Result, PQC_BINARY_VERSION,
5 PQC_MAGIC,
6};
7use serde::{Deserialize, Serialize};
8use sha2::{Digest, Sha256};
9
/// PQC binary container: a fixed header (`magic`, `version`, `algorithm`, `flags`),
/// metadata, an opaque payload and a SHA-256 checksum over all of the above.
///
/// The container is (de)serialized with `bincode` via `to_bytes` / `from_bytes`. The serde
/// derive encodes fields in declaration order, so this order IS the wire layout. Do not
/// reorder, insert or remove fields without bumping `PQC_BINARY_VERSION`.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct PqcBinaryFormat {
    /// Magic bytes identifying the format; `validate` requires them to equal `PQC_MAGIC`.
    pub magic: [u8; 4],
    /// Binary format version; `validate` requires it to equal `PQC_BINARY_VERSION`.
    pub version: u8,
    /// Algorithm identifier recorded in the header.
    pub algorithm: Algorithm,
    /// Raw feature-flag byte; read it through `flags()` to get a `FormatFlags` view.
    pub flags: u8,
    /// Algorithm parameters and auxiliary metadata (covered by the checksum).
    pub metadata: PqcMetadata,
    /// Payload bytes (covered by the checksum).
    pub data: Vec<u8>,
    /// SHA-256 over every other field. It is not recomputed automatically: call
    /// `update_checksum` after mutating any public field, or `from_bytes` will reject the
    /// serialized output.
    pub checksum: [u8; 32],
}
60
/// Bit set stored in the `flags` byte of a `PqcBinaryFormat`.
///
/// Only the low four bits have named accessors (compression, streaming, additional
/// authentication, experimental). Any other bits passed to [`FormatFlags::from_u8`] are kept
/// unchanged and returned by [`FormatFlags::as_u8`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct FormatFlags(u8);

impl FormatFlags {
    /// Bit behind [`FormatFlags::with_compression`] / [`FormatFlags::has_compression`].
    const COMPRESSION: u8 = 0x01;
    /// Bit behind [`FormatFlags::with_streaming`] / [`FormatFlags::has_streaming`].
    const STREAMING: u8 = 0x02;
    /// Bit behind [`FormatFlags::with_additional_auth`] / [`FormatFlags::has_additional_auth`].
    const ADDITIONAL_AUTH: u8 = 0x04;
    /// Bit behind [`FormatFlags::with_experimental`] / [`FormatFlags::has_experimental`].
    const EXPERIMENTAL: u8 = 0x08;

    /// Returns a flag set with no bits set.
    #[must_use]
    pub const fn new() -> Self {
        Self(0)
    }

    /// Returns a copy of `self` with the compression bit set.
    #[must_use]
    pub const fn with_compression(self) -> Self {
        self.with_bit(Self::COMPRESSION)
    }

    /// Returns a copy of `self` with the streaming bit set.
    #[must_use]
    pub const fn with_streaming(self) -> Self {
        self.with_bit(Self::STREAMING)
    }

    /// Returns a copy of `self` with the additional-authentication bit set.
    #[must_use]
    pub const fn with_additional_auth(self) -> Self {
        self.with_bit(Self::ADDITIONAL_AUTH)
    }

    /// Returns a copy of `self` with the experimental bit set.
    #[must_use]
    pub const fn with_experimental(self) -> Self {
        self.with_bit(Self::EXPERIMENTAL)
    }

    /// Returns `true` if the compression bit is set.
    #[must_use]
    pub const fn has_compression(self) -> bool {
        self.has_bit(Self::COMPRESSION)
    }

    /// Returns `true` if the streaming bit is set.
    #[must_use]
    pub const fn has_streaming(self) -> bool {
        self.has_bit(Self::STREAMING)
    }

    /// Returns `true` if the additional-authentication bit is set.
    #[must_use]
    pub const fn has_additional_auth(self) -> bool {
        self.has_bit(Self::ADDITIONAL_AUTH)
    }

    /// Returns `true` if the experimental bit is set.
    #[must_use]
    pub const fn has_experimental(self) -> bool {
        self.has_bit(Self::EXPERIMENTAL)
    }

    /// Returns the raw flag byte, including any bits without a named accessor.
    #[must_use]
    pub const fn as_u8(self) -> u8 {
        self.0
    }

    /// Wraps a raw flag byte as-is; unknown bits are not masked off.
    #[must_use]
    pub const fn from_u8(value: u8) -> Self {
        Self(value)
    }

    /// Returns a copy of `self` with every bit in `bit` set.
    const fn with_bit(self, bit: u8) -> Self {
        Self(self.0 | bit)
    }

    /// Returns `true` if any bit in `bit` is set in `self`.
    const fn has_bit(self, bit: u8) -> bool {
        (self.0 & bit) != 0
    }
}

impl Default for FormatFlags {
    /// Same as [`FormatFlags::new`]: no bits set.
    fn default() -> Self {
        Self::new()
    }
}
144
145impl PqcBinaryFormat {
146 #[must_use]
172 pub fn new(algorithm: Algorithm, metadata: PqcMetadata, data: Vec<u8>) -> Self {
173 let mut format = Self {
174 magic: PQC_MAGIC,
175 version: PQC_BINARY_VERSION,
176 algorithm,
177 flags: FormatFlags::new().as_u8(),
178 metadata,
179 data,
180 checksum: [0u8; 32],
181 };
182
183 format.checksum = format.calculate_checksum();
185 format
186 }
187
188 #[must_use]
218 pub fn with_flags(
219 algorithm: Algorithm,
220 flags: FormatFlags,
221 metadata: PqcMetadata,
222 data: Vec<u8>,
223 ) -> Self {
224 let mut format = Self {
225 magic: PQC_MAGIC,
226 version: PQC_BINARY_VERSION,
227 algorithm,
228 flags: flags.as_u8(),
229 metadata,
230 data,
231 checksum: [0u8; 32],
232 };
233
234 format.checksum = format.calculate_checksum();
235 format
236 }
237
238 pub fn to_bytes(&self) -> Result<Vec<u8>> {
265 self.validate()?;
267
268 bincode::serialize(self)
269 .map_err(|e| CryptoError::BinaryFormatError(format!("Serialization failed: {e}")))
270 }
271
272 pub fn from_bytes(data: &[u8]) -> Result<Self> {
301 let format: Self = bincode::deserialize(data)
302 .map_err(|e| CryptoError::BinaryFormatError(format!("Deserialization failed: {e}")))?;
303
304 format.validate()?;
306
307 let stored_checksum = format.checksum;
309 let mut format_copy = format;
310 format_copy.checksum = [0u8; 32]; let calculated_checksum = format_copy.calculate_checksum();
312
313 format_copy.checksum = stored_checksum;
315
316 if stored_checksum != calculated_checksum {
317 return Err(CryptoError::ChecksumMismatch);
318 }
319
320 Ok(format_copy)
321 }
322
323 pub fn validate(&self) -> Result<()> {
333 if self.magic != PQC_MAGIC {
335 return Err(CryptoError::InvalidMagic);
336 }
337
338 if self.version != PQC_BINARY_VERSION {
340 return Err(CryptoError::UnsupportedVersion(self.version));
341 }
342
343 if Algorithm::from_id(self.algorithm.as_id()).is_none() {
345 return Err(CryptoError::UnknownAlgorithm(format!(
346 "Invalid algorithm ID: {:#x}",
347 self.algorithm.as_id()
348 )));
349 }
350
351 self.metadata.validate()?;
353
354 Ok(())
355 }
356
357 pub fn update_checksum(&mut self) {
361 self.checksum = self.calculate_checksum();
362 }
363
364 fn calculate_checksum(&self) -> [u8; 32] {
368 let mut hasher = Sha256::new();
369
370 hasher.update(self.magic);
372 hasher.update([self.version]);
373 hasher.update(self.algorithm.as_id().to_le_bytes());
374 hasher.update([self.flags]);
375
376 self.hash_metadata_deterministic(&mut hasher);
378
379 hasher.update((self.data.len() as u64).to_le_bytes());
381 hasher.update(&self.data);
382
383 hasher.finalize().into()
384 }
385
386 #[allow(clippy::cast_possible_truncation)]
390 fn hash_metadata_deterministic(&self, hasher: &mut Sha256) {
391 if let Some(ref kem_params) = self.metadata.kem_params {
393 hasher.update([1u8]); hasher.update((kem_params.public_key.len() as u32).to_le_bytes());
395 hasher.update(&kem_params.public_key);
396 hasher.update((kem_params.ciphertext.len() as u32).to_le_bytes());
397 hasher.update(&kem_params.ciphertext);
398 let mut sorted_params: Vec<_> = kem_params.params.iter().collect();
400 sorted_params.sort_by(|a, b| a.0.cmp(b.0));
401 hasher.update((sorted_params.len() as u32).to_le_bytes());
402 for (key, value) in sorted_params {
403 hasher.update((key.len() as u32).to_le_bytes());
404 hasher.update(key.as_bytes());
405 hasher.update((value.len() as u32).to_le_bytes());
406 hasher.update(value);
407 }
408 } else {
409 hasher.update([0u8]); }
411
412 if let Some(ref sig_params) = self.metadata.sig_params {
414 hasher.update([1u8]); hasher.update((sig_params.public_key.len() as u32).to_le_bytes());
416 hasher.update(&sig_params.public_key);
417 hasher.update((sig_params.signature.len() as u32).to_le_bytes());
418 hasher.update(&sig_params.signature);
419 let mut sorted_params: Vec<_> = sig_params.params.iter().collect();
421 sorted_params.sort_by(|a, b| a.0.cmp(b.0));
422 hasher.update((sorted_params.len() as u32).to_le_bytes());
423 for (key, value) in sorted_params {
424 hasher.update((key.len() as u32).to_le_bytes());
425 hasher.update(key.as_bytes());
426 hasher.update((value.len() as u32).to_le_bytes());
427 hasher.update(value);
428 }
429 } else {
430 hasher.update([0u8]); }
432
433 hasher.update((self.metadata.enc_params.iv.len() as u32).to_le_bytes());
435 hasher.update(&self.metadata.enc_params.iv);
436 hasher.update((self.metadata.enc_params.tag.len() as u32).to_le_bytes());
437 hasher.update(&self.metadata.enc_params.tag);
438 let mut sorted_params: Vec<_> = self.metadata.enc_params.params.iter().collect();
440 sorted_params.sort_by(|a, b| a.0.cmp(b.0));
441 hasher.update((sorted_params.len() as u32).to_le_bytes());
442 for (key, value) in sorted_params {
443 hasher.update((key.len() as u32).to_le_bytes());
444 hasher.update(key.as_bytes());
445 hasher.update((value.len() as u32).to_le_bytes());
446 hasher.update(value);
447 }
448
449 if let Some(ref comp_params) = self.metadata.compression_params {
451 hasher.update([1u8]); hasher.update((comp_params.algorithm.len() as u32).to_le_bytes());
453 hasher.update(comp_params.algorithm.as_bytes());
454 hasher.update(comp_params.level.to_le_bytes());
455 hasher.update(comp_params.original_size.to_le_bytes());
456 } else {
457 hasher.update([0u8]); }
459
460 let mut sorted_custom: Vec<_> = self.metadata.custom.iter().collect();
462 sorted_custom.sort_by(|a, b| a.0.cmp(b.0));
463 hasher.update((sorted_custom.len() as u32).to_le_bytes());
464 for (key, value) in sorted_custom {
465 hasher.update((key.len() as u32).to_le_bytes());
466 hasher.update(key.as_bytes());
467 hasher.update((value.len() as u32).to_le_bytes());
468 hasher.update(value);
469 }
470 }
471
472 #[must_use]
474 pub fn flags(&self) -> FormatFlags {
475 FormatFlags(self.flags)
476 }
477
478 #[must_use]
480 pub const fn algorithm(&self) -> Algorithm {
481 self.algorithm
482 }
483
484 #[must_use]
486 pub fn data(&self) -> &[u8] {
487 &self.data
488 }
489
490 #[must_use]
492 pub const fn metadata(&self) -> &PqcMetadata {
493 &self.metadata
494 }
495
496 #[must_use]
498 pub fn total_size(&self) -> usize {
499 self.to_bytes().map_or(0, |bytes| bytes.len())
500 }
501}
502
#[cfg(test)]
mod tests {
    use super::*;
    use crate::EncParameters;
    use std::collections::HashMap;

    /// Metadata with a 12-byte IV, a 16-byte tag, empty encryption params and every other
    /// field left at its `Default` value.
    fn sample_metadata() -> PqcMetadata {
        PqcMetadata {
            enc_params: EncParameters {
                iv: vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12],
                tag: vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16],
                params: HashMap::new(),
            },
            ..Default::default()
        }
    }

    #[test]
    fn test_format_flags() {
        let flags = FormatFlags::new().with_compression().with_streaming();

        assert!(flags.has_compression());
        assert!(flags.has_streaming());
        assert!(!flags.has_additional_auth());
        assert!(!flags.has_experimental());
    }

    #[test]
    fn test_binary_format_roundtrip() {
        let original =
            PqcBinaryFormat::new(Algorithm::Hybrid, sample_metadata(), vec![1, 2, 3, 4, 5]);

        let bytes = original.to_bytes().unwrap();
        let deserialized = PqcBinaryFormat::from_bytes(&bytes).unwrap();

        assert_eq!(original, deserialized);
    }

    #[test]
    fn test_checksum_validation() {
        let format =
            PqcBinaryFormat::new(Algorithm::PostQuantum, sample_metadata(), vec![1, 2, 3, 4, 5]);

        let mut bytes = format.to_bytes().unwrap();

        // `checksum` is the last field, so the final byte belongs to the stored checksum.
        if let Some(byte) = bytes.last_mut() {
            *byte = byte.wrapping_add(1);
        }

        // Must fail specifically on the checksum, not on some unrelated decode error.
        assert!(matches!(
            PqcBinaryFormat::from_bytes(&bytes),
            Err(CryptoError::ChecksumMismatch)
        ));
    }

    #[test]
    fn test_tampered_payload_detected() {
        let mut format =
            PqcBinaryFormat::new(Algorithm::PostQuantum, sample_metadata(), vec![1, 2, 3, 4, 5]);

        // Mutating a public field without `update_checksum` leaves a stale checksum.
        format.data[0] ^= 0xFF;
        let bytes = format.to_bytes().unwrap();
        assert!(matches!(
            PqcBinaryFormat::from_bytes(&bytes),
            Err(CryptoError::ChecksumMismatch)
        ));

        // Once the checksum is refreshed the container round-trips again.
        format.update_checksum();
        let bytes = format.to_bytes().unwrap();
        assert_eq!(PqcBinaryFormat::from_bytes(&bytes).unwrap(), format);
    }

    #[test]
    fn test_truncated_input_rejected() {
        let format = PqcBinaryFormat::new(Algorithm::Hybrid, sample_metadata(), vec![9; 32]);
        let bytes = format.to_bytes().unwrap();

        assert!(matches!(
            PqcBinaryFormat::from_bytes(&bytes[..bytes.len() / 2]),
            Err(CryptoError::BinaryFormatError(_))
        ));
        assert!(matches!(
            PqcBinaryFormat::from_bytes(&[]),
            Err(CryptoError::BinaryFormatError(_))
        ));
    }

    #[test]
    fn test_invalid_magic_rejected() {
        let mut format = PqcBinaryFormat::new(Algorithm::Hybrid, sample_metadata(), vec![1, 2, 3]);
        format.magic[0] ^= 0xFF;
        format.update_checksum();

        assert!(matches!(format.to_bytes(), Err(CryptoError::InvalidMagic)));
        assert_eq!(format.total_size(), 0);

        // Bypass `to_bytes` validation to confirm the decoder rejects it as well.
        let bytes = bincode::serialize(&format).unwrap();
        assert!(matches!(
            PqcBinaryFormat::from_bytes(&bytes),
            Err(CryptoError::InvalidMagic)
        ));
    }

    #[test]
    fn test_total_size_matches_serialized_len() {
        let format = PqcBinaryFormat::new(Algorithm::Hybrid, sample_metadata(), vec![7; 64]);

        assert_eq!(format.total_size(), format.to_bytes().unwrap().len());
    }

    #[test]
    fn test_flags_roundtrip() {
        let flags = FormatFlags::new()
            .with_compression()
            .with_streaming()
            .with_additional_auth();

        let format = PqcBinaryFormat::with_flags(
            Algorithm::QuadLayer,
            flags,
            sample_metadata(),
            vec![1, 2, 3],
        );

        let bytes = format.to_bytes().unwrap();
        let recovered = PqcBinaryFormat::from_bytes(&bytes).unwrap();

        assert!(recovered.flags().has_compression());
        assert!(recovered.flags().has_streaming());
        assert!(recovered.flags().has_additional_auth());
        assert!(!recovered.flags().has_experimental());
    }
}