1use crate::{
4 algorithm::Algorithm, error::CryptoError, metadata::PqcMetadata, Result, PQC_BINARY_VERSION,
5 PQC_MAGIC,
6};
7use serde::{Deserialize, Serialize};
8use sha2::{Digest, Sha256};
9
10#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
44pub struct PqcBinaryFormat {
45 pub magic: [u8; 4],
47 pub version: u8,
49 pub algorithm: Algorithm,
51 pub flags: u8,
53 pub metadata: PqcMetadata,
55 pub data: Vec<u8>,
57 pub checksum: [u8; 32],
59}
60
/// Typed view over the single flag byte stored in a container header.
///
/// Each capability occupies one bit of the wrapped `u8`. The builder-style
/// `with_*` methods return a copy with the corresponding bit set, and the
/// `has_*` methods test a bit. The default value has no bits set.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub struct FormatFlags(u8);

impl FormatFlags {
    /// Mask of the compression bit.
    const COMPRESSION: u8 = 0x01;
    /// Mask of the streaming bit.
    const STREAMING: u8 = 0x02;
    /// Mask of the additional-authentication bit.
    const ADDITIONAL_AUTH: u8 = 0x04;
    /// Mask of the experimental bit.
    const EXPERIMENTAL: u8 = 0x08;

    /// Creates a flag set with every bit cleared.
    #[must_use]
    pub const fn new() -> Self {
        Self(0)
    }

    /// Returns a copy with the compression bit (`0x01`) set.
    #[must_use]
    pub const fn with_compression(self) -> Self {
        self.set(Self::COMPRESSION)
    }

    /// Returns a copy with the streaming bit (`0x02`) set.
    #[must_use]
    pub const fn with_streaming(self) -> Self {
        self.set(Self::STREAMING)
    }

    /// Returns a copy with the additional-authentication bit (`0x04`) set.
    #[must_use]
    pub const fn with_additional_auth(self) -> Self {
        self.set(Self::ADDITIONAL_AUTH)
    }

    /// Returns a copy with the experimental bit (`0x08`) set.
    #[must_use]
    pub const fn with_experimental(self) -> Self {
        self.set(Self::EXPERIMENTAL)
    }

    /// Reports whether the compression bit (`0x01`) is set.
    #[must_use]
    pub const fn has_compression(self) -> bool {
        self.contains(Self::COMPRESSION)
    }

    /// Reports whether the streaming bit (`0x02`) is set.
    #[must_use]
    pub const fn has_streaming(self) -> bool {
        self.contains(Self::STREAMING)
    }

    /// Reports whether the additional-authentication bit (`0x04`) is set.
    #[must_use]
    pub const fn has_additional_auth(self) -> bool {
        self.contains(Self::ADDITIONAL_AUTH)
    }

    /// Reports whether the experimental bit (`0x08`) is set.
    #[must_use]
    pub const fn has_experimental(self) -> bool {
        self.contains(Self::EXPERIMENTAL)
    }

    /// Returns the raw flag byte.
    #[must_use]
    pub const fn as_u8(self) -> u8 {
        self.0
    }

    /// Wraps a raw flag byte without filtering; bits outside the four known
    /// masks are kept as they are.
    #[must_use]
    pub const fn from_u8(value: u8) -> Self {
        Self(value)
    }

    /// Returns a copy with every bit of `mask` set.
    const fn set(self, mask: u8) -> Self {
        Self(self.0 | mask)
    }

    /// Reports whether any bit of `mask` is set.
    const fn contains(self, mask: u8) -> bool {
        self.0 & mask != 0
    }
}
144
145impl PqcBinaryFormat {
146 #[must_use]
172 pub fn new(algorithm: Algorithm, metadata: PqcMetadata, data: Vec<u8>) -> Self {
173 let mut format = Self {
174 magic: PQC_MAGIC,
175 version: PQC_BINARY_VERSION,
176 algorithm,
177 flags: FormatFlags::new().as_u8(),
178 metadata,
179 data,
180 checksum: [0u8; 32],
181 };
182
183 format.checksum = format.calculate_checksum();
185 format
186 }
187
188 #[must_use]
218 pub fn with_flags(
219 algorithm: Algorithm,
220 flags: FormatFlags,
221 metadata: PqcMetadata,
222 data: Vec<u8>,
223 ) -> Self {
224 let mut format = Self {
225 magic: PQC_MAGIC,
226 version: PQC_BINARY_VERSION,
227 algorithm,
228 flags: flags.as_u8(),
229 metadata,
230 data,
231 checksum: [0u8; 32],
232 };
233
234 format.checksum = format.calculate_checksum();
235 format
236 }
237
238 pub fn to_bytes(&self) -> Result<Vec<u8>> {
265 self.validate()?;
267
268 bincode::serialize(self)
269 .map_err(|e| CryptoError::BinaryFormatError(format!("Serialization failed: {e}")))
270 }
271
272 pub fn from_bytes(data: &[u8]) -> Result<Self> {
301 let format: Self = bincode::deserialize(data)
302 .map_err(|e| CryptoError::BinaryFormatError(format!("Deserialization failed: {e}")))?;
303
304 format.validate()?;
306
307 let stored_checksum = format.checksum;
309 let mut format_copy = format;
310 format_copy.checksum = [0u8; 32]; let calculated_checksum = format_copy.calculate_checksum();
312
313 format_copy.checksum = stored_checksum;
315
316 if stored_checksum != calculated_checksum {
317 return Err(CryptoError::ChecksumMismatch);
318 }
319
320 Ok(format_copy)
321 }
322
323 pub fn validate(&self) -> Result<()> {
333 if self.magic != PQC_MAGIC {
335 return Err(CryptoError::InvalidMagic);
336 }
337
338 if self.version != PQC_BINARY_VERSION {
340 return Err(CryptoError::UnsupportedVersion(self.version));
341 }
342
343 if Algorithm::from_id(self.algorithm.as_id()).is_none() {
345 return Err(CryptoError::UnknownAlgorithm(format!(
346 "Invalid algorithm ID: {:#x}",
347 self.algorithm.as_id()
348 )));
349 }
350
351 self.metadata.validate()?;
353
354 Ok(())
355 }
356
357 pub fn update_checksum(&mut self) {
361 self.checksum = self.calculate_checksum();
362 }
363
364 fn calculate_checksum(&self) -> [u8; 32] {
368 let mut hasher = Sha256::new();
369
370 hasher.update(&self.magic);
372 hasher.update(&[self.version]);
373 hasher.update(&self.algorithm.as_id().to_le_bytes());
374 hasher.update(&[self.flags]);
375
376 self.hash_metadata_deterministic(&mut hasher);
378
379 hasher.update(&(self.data.len() as u64).to_le_bytes());
381 hasher.update(&self.data);
382
383 hasher.finalize().into()
384 }
385
386 fn hash_metadata_deterministic(&self, hasher: &mut Sha256) {
390 if let Some(ref kem_params) = self.metadata.kem_params {
392 hasher.update(&[1u8]); hasher.update(&(kem_params.public_key.len() as u32).to_le_bytes());
394 hasher.update(&kem_params.public_key);
395 hasher.update(&(kem_params.ciphertext.len() as u32).to_le_bytes());
396 hasher.update(&kem_params.ciphertext);
397 let mut sorted_params: Vec<_> = kem_params.params.iter().collect();
399 sorted_params.sort_by(|a, b| a.0.cmp(b.0));
400 hasher.update(&(sorted_params.len() as u32).to_le_bytes());
401 for (key, value) in sorted_params {
402 hasher.update(&(key.len() as u32).to_le_bytes());
403 hasher.update(key.as_bytes());
404 hasher.update(&(value.len() as u32).to_le_bytes());
405 hasher.update(value);
406 }
407 } else {
408 hasher.update(&[0u8]); }
410
411 if let Some(ref sig_params) = self.metadata.sig_params {
413 hasher.update(&[1u8]); hasher.update(&(sig_params.public_key.len() as u32).to_le_bytes());
415 hasher.update(&sig_params.public_key);
416 hasher.update(&(sig_params.signature.len() as u32).to_le_bytes());
417 hasher.update(&sig_params.signature);
418 let mut sorted_params: Vec<_> = sig_params.params.iter().collect();
420 sorted_params.sort_by(|a, b| a.0.cmp(b.0));
421 hasher.update(&(sorted_params.len() as u32).to_le_bytes());
422 for (key, value) in sorted_params {
423 hasher.update(&(key.len() as u32).to_le_bytes());
424 hasher.update(key.as_bytes());
425 hasher.update(&(value.len() as u32).to_le_bytes());
426 hasher.update(value);
427 }
428 } else {
429 hasher.update(&[0u8]); }
431
432 hasher.update(&(self.metadata.enc_params.iv.len() as u32).to_le_bytes());
434 hasher.update(&self.metadata.enc_params.iv);
435 hasher.update(&(self.metadata.enc_params.tag.len() as u32).to_le_bytes());
436 hasher.update(&self.metadata.enc_params.tag);
437 let mut sorted_params: Vec<_> = self.metadata.enc_params.params.iter().collect();
439 sorted_params.sort_by(|a, b| a.0.cmp(b.0));
440 hasher.update(&(sorted_params.len() as u32).to_le_bytes());
441 for (key, value) in sorted_params {
442 hasher.update(&(key.len() as u32).to_le_bytes());
443 hasher.update(key.as_bytes());
444 hasher.update(&(value.len() as u32).to_le_bytes());
445 hasher.update(value);
446 }
447
448 if let Some(ref comp_params) = self.metadata.compression_params {
450 hasher.update(&[1u8]); hasher.update(&(comp_params.algorithm.len() as u32).to_le_bytes());
452 hasher.update(comp_params.algorithm.as_bytes());
453 hasher.update(&comp_params.level.to_le_bytes());
454 hasher.update(&comp_params.original_size.to_le_bytes());
455 } else {
456 hasher.update(&[0u8]); }
458
459 let mut sorted_custom: Vec<_> = self.metadata.custom.iter().collect();
461 sorted_custom.sort_by(|a, b| a.0.cmp(b.0));
462 hasher.update(&(sorted_custom.len() as u32).to_le_bytes());
463 for (key, value) in sorted_custom {
464 hasher.update(&(key.len() as u32).to_le_bytes());
465 hasher.update(key.as_bytes());
466 hasher.update(&(value.len() as u32).to_le_bytes());
467 hasher.update(value);
468 }
469 }
470
471 #[must_use]
473 pub fn flags(&self) -> FormatFlags {
474 FormatFlags(self.flags)
475 }
476
477 #[must_use]
479 pub const fn algorithm(&self) -> Algorithm {
480 self.algorithm
481 }
482
483 #[must_use]
485 pub fn data(&self) -> &[u8] {
486 &self.data
487 }
488
489 #[must_use]
491 pub const fn metadata(&self) -> &PqcMetadata {
492 &self.metadata
493 }
494
495 #[must_use]
497 pub fn total_size(&self) -> usize {
498 self.to_bytes().map_or(0, |bytes| bytes.len())
499 }
500}
501
#[cfg(test)]
mod tests {
    use super::*;
    use crate::EncParameters;
    use std::collections::HashMap;

    /// Builds metadata that carries only encryption parameters; every other
    /// field keeps its default value.
    fn metadata_with(iv: Vec<u8>, tag: Vec<u8>) -> PqcMetadata {
        PqcMetadata {
            enc_params: EncParameters {
                iv,
                tag,
                params: HashMap::new(),
            },
            ..Default::default()
        }
    }

    /// Metadata with a 12-byte IV of `1..=12` and a 16-byte tag of `1..=16`.
    fn sequential_metadata() -> PqcMetadata {
        metadata_with((1..=12).collect(), (1..=16).collect())
    }

    /// Setting two bits leaves the other two cleared.
    #[test]
    fn test_format_flags() {
        let flags = FormatFlags::new().with_compression().with_streaming();

        assert!(flags.has_compression());
        assert!(flags.has_streaming());
        assert!(!flags.has_additional_auth());
        assert!(!flags.has_experimental());
    }

    /// Encoding followed by decoding yields an equal container.
    #[test]
    fn test_binary_format_roundtrip() {
        let original =
            PqcBinaryFormat::new(Algorithm::Hybrid, sequential_metadata(), vec![1, 2, 3, 4, 5]);

        let bytes = original.to_bytes().unwrap();
        let deserialized = PqcBinaryFormat::from_bytes(&bytes).unwrap();

        assert_eq!(original, deserialized);
    }

    /// Corrupting the stored checksum is reported as a checksum mismatch.
    #[test]
    fn test_checksum_validation() {
        let format = PqcBinaryFormat::new(
            Algorithm::PostQuantum,
            sequential_metadata(),
            vec![1, 2, 3, 4, 5],
        );

        let mut bytes = format.to_bytes().unwrap();

        // `checksum` is the last field of the struct, so the final encoded
        // byte belongs to the stored digest.
        if let Some(byte) = bytes.last_mut() {
            *byte = byte.wrapping_add(1);
        }

        // Pin the exact failure so an unrelated decode error cannot make this
        // test pass by accident.
        let outcome = PqcBinaryFormat::from_bytes(&bytes);
        assert!(matches!(outcome, Err(CryptoError::ChecksumMismatch)));
    }

    /// Flags passed to `with_flags` survive an encode/decode cycle.
    #[test]
    fn test_flags_roundtrip() {
        let flags = FormatFlags::new()
            .with_compression()
            .with_streaming()
            .with_additional_auth();

        let format = PqcBinaryFormat::with_flags(
            Algorithm::QuadLayer,
            flags,
            metadata_with(vec![1; 12], vec![1; 16]),
            vec![1, 2, 3],
        );

        let bytes = format.to_bytes().unwrap();
        let recovered = PqcBinaryFormat::from_bytes(&bytes).unwrap();

        assert!(recovered.flags().has_compression());
        assert!(recovered.flags().has_streaming());
        assert!(recovered.flags().has_additional_auth());
        assert!(!recovered.flags().has_experimental());
    }
}