use crate::{
algorithm::Algorithm, error::CryptoError, metadata::PqcMetadata, Result, PQC_BINARY_VERSION,
PQC_MAGIC,
};
use serde::{Deserialize, Serialize};
use sha2::{Digest, Sha256};
/// Versioned binary container: header (magic, version, algorithm, flags),
/// algorithm metadata, payload bytes, and an integrity checksum.
///
/// Values are encoded with bincode through the derived `Serialize` /
/// `Deserialize` impls (see `to_bytes` / `from_bytes`). Serde derives encode
/// fields in declaration order, so the field order below is the on-wire
/// order — do not reorder fields.
///
/// All fields are public. `to_bytes` writes `checksum` as stored without
/// recomputing it, so after mutating any field call `update_checksum`, or
/// `from_bytes` will reject the serialized output with a checksum mismatch.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct PqcBinaryFormat {
    /// Format identifier; `validate` requires it to equal `PQC_MAGIC`.
    pub magic: [u8; 4],
    /// Format version; `validate` requires it to equal `PQC_BINARY_VERSION`.
    pub version: u8,
    /// Algorithm tag; its numeric id (`as_id`) is part of the checksum input.
    pub algorithm: Algorithm,
    /// Raw `FormatFlags` bits; `flags()` returns the typed view.
    pub flags: u8,
    /// KEM, signature, encryption, compression and custom parameters.
    pub metadata: PqcMetadata,
    /// Payload bytes.
    pub data: Vec<u8>,
    /// SHA-256 over the header fields, metadata and data (see
    /// `calculate_checksum`); this field itself is not hashed.
    pub checksum: [u8; 32],
}
/// Bit set of optional container features, stored as a single byte in the
/// `flags` field of `PqcBinaryFormat`.
///
/// Four bits have named meanings: compression, streaming, additional
/// authentication and experimental. Other bits are not interpreted here, but
/// they are preserved unchanged by [`FormatFlags::from_u8`] and
/// [`FormatFlags::as_u8`].
///
/// `Default` is the empty set, identical to [`FormatFlags::new`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub struct FormatFlags(u8);

impl FormatFlags {
    /// Bit used by [`Self::with_compression`] / [`Self::has_compression`].
    const COMPRESSION: u8 = 0x01;
    /// Bit used by [`Self::with_streaming`] / [`Self::has_streaming`].
    const STREAMING: u8 = 0x02;
    /// Bit used by [`Self::with_additional_auth`] / [`Self::has_additional_auth`].
    const ADDITIONAL_AUTH: u8 = 0x04;
    /// Bit used by [`Self::with_experimental`] / [`Self::has_experimental`].
    const EXPERIMENTAL: u8 = 0x08;

    /// Returns an empty flag set (no bits set).
    #[must_use]
    pub const fn new() -> Self {
        Self(0)
    }

    /// Returns a copy of `self` with `bit` set; bits already set are kept.
    const fn with_bit(self, bit: u8) -> Self {
        Self(self.0 | bit)
    }

    /// Reports whether `bit` is set in `self`.
    const fn has_bit(self, bit: u8) -> bool {
        (self.0 & bit) != 0
    }

    /// Returns a copy with the compression bit set.
    #[must_use]
    pub const fn with_compression(self) -> Self {
        self.with_bit(Self::COMPRESSION)
    }

    /// Returns a copy with the streaming bit set.
    #[must_use]
    pub const fn with_streaming(self) -> Self {
        self.with_bit(Self::STREAMING)
    }

    /// Returns a copy with the additional-authentication bit set.
    #[must_use]
    pub const fn with_additional_auth(self) -> Self {
        self.with_bit(Self::ADDITIONAL_AUTH)
    }

    /// Returns a copy with the experimental bit set.
    #[must_use]
    pub const fn with_experimental(self) -> Self {
        self.with_bit(Self::EXPERIMENTAL)
    }

    /// Whether the compression bit is set.
    #[must_use]
    pub const fn has_compression(self) -> bool {
        self.has_bit(Self::COMPRESSION)
    }

    /// Whether the streaming bit is set.
    #[must_use]
    pub const fn has_streaming(self) -> bool {
        self.has_bit(Self::STREAMING)
    }

    /// Whether the additional-authentication bit is set.
    #[must_use]
    pub const fn has_additional_auth(self) -> bool {
        self.has_bit(Self::ADDITIONAL_AUTH)
    }

    /// Whether the experimental bit is set.
    #[must_use]
    pub const fn has_experimental(self) -> bool {
        self.has_bit(Self::EXPERIMENTAL)
    }

    /// Returns the raw byte, including any bits without a named meaning.
    #[must_use]
    pub const fn as_u8(self) -> u8 {
        self.0
    }

    /// Wraps a raw byte without validation; unknown bits are kept as-is.
    #[must_use]
    pub const fn from_u8(value: u8) -> Self {
        Self(value)
    }
}
impl PqcBinaryFormat {
#[must_use]
pub fn new(algorithm: Algorithm, metadata: PqcMetadata, data: Vec<u8>) -> Self {
let mut format = Self {
magic: PQC_MAGIC,
version: PQC_BINARY_VERSION,
algorithm,
flags: FormatFlags::new().as_u8(),
metadata,
data,
checksum: [0u8; 32],
};
format.checksum = format.calculate_checksum();
format
}
#[must_use]
pub fn with_flags(
algorithm: Algorithm,
flags: FormatFlags,
metadata: PqcMetadata,
data: Vec<u8>,
) -> Self {
let mut format = Self {
magic: PQC_MAGIC,
version: PQC_BINARY_VERSION,
algorithm,
flags: flags.as_u8(),
metadata,
data,
checksum: [0u8; 32],
};
format.checksum = format.calculate_checksum();
format
}
pub fn to_bytes(&self) -> Result<Vec<u8>> {
self.validate()?;
bincode::serialize(self)
.map_err(|e| CryptoError::BinaryFormatError(format!("Serialization failed: {e}")))
}
pub fn from_bytes(data: &[u8]) -> Result<Self> {
let format: Self = bincode::deserialize(data)
.map_err(|e| CryptoError::BinaryFormatError(format!("Deserialization failed: {e}")))?;
format.validate()?;
let stored_checksum = format.checksum;
let mut format_copy = format;
format_copy.checksum = [0u8; 32]; let calculated_checksum = format_copy.calculate_checksum();
format_copy.checksum = stored_checksum;
if stored_checksum != calculated_checksum {
return Err(CryptoError::ChecksumMismatch);
}
Ok(format_copy)
}
pub fn validate(&self) -> Result<()> {
if self.magic != PQC_MAGIC {
return Err(CryptoError::InvalidMagic);
}
if self.version != PQC_BINARY_VERSION {
return Err(CryptoError::UnsupportedVersion(self.version));
}
if Algorithm::from_id(self.algorithm.as_id()).is_none() {
return Err(CryptoError::UnknownAlgorithm(format!(
"Invalid algorithm ID: {:#x}",
self.algorithm.as_id()
)));
}
self.metadata.validate()?;
Ok(())
}
pub fn update_checksum(&mut self) {
self.checksum = self.calculate_checksum();
}
fn calculate_checksum(&self) -> [u8; 32] {
let mut hasher = Sha256::new();
hasher.update(self.magic);
hasher.update([self.version]);
hasher.update(self.algorithm.as_id().to_le_bytes());
hasher.update([self.flags]);
self.hash_metadata_deterministic(&mut hasher);
hasher.update((self.data.len() as u64).to_le_bytes());
hasher.update(&self.data);
hasher.finalize().into()
}
#[allow(clippy::cast_possible_truncation)]
fn hash_metadata_deterministic(&self, hasher: &mut Sha256) {
if let Some(ref kem_params) = self.metadata.kem_params {
hasher.update([1u8]); hasher.update((kem_params.public_key.len() as u32).to_le_bytes());
hasher.update(&kem_params.public_key);
hasher.update((kem_params.ciphertext.len() as u32).to_le_bytes());
hasher.update(&kem_params.ciphertext);
let mut sorted_params: Vec<_> = kem_params.params.iter().collect();
sorted_params.sort_by(|a, b| a.0.cmp(b.0));
hasher.update((sorted_params.len() as u32).to_le_bytes());
for (key, value) in sorted_params {
hasher.update((key.len() as u32).to_le_bytes());
hasher.update(key.as_bytes());
hasher.update((value.len() as u32).to_le_bytes());
hasher.update(value);
}
} else {
hasher.update([0u8]); }
if let Some(ref sig_params) = self.metadata.sig_params {
hasher.update([1u8]); hasher.update((sig_params.public_key.len() as u32).to_le_bytes());
hasher.update(&sig_params.public_key);
hasher.update((sig_params.signature.len() as u32).to_le_bytes());
hasher.update(&sig_params.signature);
let mut sorted_params: Vec<_> = sig_params.params.iter().collect();
sorted_params.sort_by(|a, b| a.0.cmp(b.0));
hasher.update((sorted_params.len() as u32).to_le_bytes());
for (key, value) in sorted_params {
hasher.update((key.len() as u32).to_le_bytes());
hasher.update(key.as_bytes());
hasher.update((value.len() as u32).to_le_bytes());
hasher.update(value);
}
} else {
hasher.update([0u8]); }
hasher.update((self.metadata.enc_params.iv.len() as u32).to_le_bytes());
hasher.update(&self.metadata.enc_params.iv);
hasher.update((self.metadata.enc_params.tag.len() as u32).to_le_bytes());
hasher.update(&self.metadata.enc_params.tag);
let mut sorted_params: Vec<_> = self.metadata.enc_params.params.iter().collect();
sorted_params.sort_by(|a, b| a.0.cmp(b.0));
hasher.update((sorted_params.len() as u32).to_le_bytes());
for (key, value) in sorted_params {
hasher.update((key.len() as u32).to_le_bytes());
hasher.update(key.as_bytes());
hasher.update((value.len() as u32).to_le_bytes());
hasher.update(value);
}
if let Some(ref comp_params) = self.metadata.compression_params {
hasher.update([1u8]); hasher.update((comp_params.algorithm.len() as u32).to_le_bytes());
hasher.update(comp_params.algorithm.as_bytes());
hasher.update(comp_params.level.to_le_bytes());
hasher.update(comp_params.original_size.to_le_bytes());
} else {
hasher.update([0u8]); }
let mut sorted_custom: Vec<_> = self.metadata.custom.iter().collect();
sorted_custom.sort_by(|a, b| a.0.cmp(b.0));
hasher.update((sorted_custom.len() as u32).to_le_bytes());
for (key, value) in sorted_custom {
hasher.update((key.len() as u32).to_le_bytes());
hasher.update(key.as_bytes());
hasher.update((value.len() as u32).to_le_bytes());
hasher.update(value);
}
}
#[must_use]
pub fn flags(&self) -> FormatFlags {
FormatFlags(self.flags)
}
#[must_use]
pub const fn algorithm(&self) -> Algorithm {
self.algorithm
}
#[must_use]
pub fn data(&self) -> &[u8] {
&self.data
}
#[must_use]
pub const fn metadata(&self) -> &PqcMetadata {
&self.metadata
}
#[must_use]
pub fn total_size(&self) -> usize {
self.to_bytes().map_or(0, |bytes| bytes.len())
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::EncParameters;
    use std::collections::HashMap;

    /// Minimal metadata shared by the tests: a 12-byte IV, a 16-byte tag,
    /// no extra encryption parameters, and defaults for everything else.
    fn sample_metadata() -> PqcMetadata {
        PqcMetadata {
            enc_params: EncParameters {
                iv: vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12],
                tag: vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16],
                params: HashMap::new(),
            },
            ..Default::default()
        }
    }

    #[test]
    fn test_format_flags() {
        let flags = FormatFlags::new().with_compression().with_streaming();
        assert!(flags.has_compression());
        assert!(flags.has_streaming());
        assert!(!flags.has_additional_auth());
        assert!(!flags.has_experimental());
    }

    #[test]
    fn test_binary_format_roundtrip() {
        let original =
            PqcBinaryFormat::new(Algorithm::Hybrid, sample_metadata(), vec![1, 2, 3, 4, 5]);
        let bytes = original.to_bytes().unwrap();
        let deserialized = PqcBinaryFormat::from_bytes(&bytes).unwrap();
        assert_eq!(original, deserialized);
    }

    /// Corrupting the serialized bytes (here, the final byte) must be rejected.
    #[test]
    fn test_checksum_validation() {
        let format =
            PqcBinaryFormat::new(Algorithm::PostQuantum, sample_metadata(), vec![1, 2, 3, 4, 5]);
        let mut bytes = format.to_bytes().unwrap();
        if let Some(byte) = bytes.last_mut() {
            *byte = byte.wrapping_add(1);
        }
        assert!(PqcBinaryFormat::from_bytes(&bytes).is_err());
    }

    /// Changing checksum-covered content without refreshing the checksum must
    /// be detected as a mismatch, and `update_checksum` must restore validity.
    #[test]
    fn test_payload_and_header_tampering_detected() {
        let mut format =
            PqcBinaryFormat::new(Algorithm::PostQuantum, sample_metadata(), vec![1, 2, 3, 4, 5]);

        format.data[0] ^= 0xFF;
        let bytes = format.to_bytes().unwrap();
        assert!(matches!(
            PqcBinaryFormat::from_bytes(&bytes),
            Err(CryptoError::ChecksumMismatch)
        ));

        format.update_checksum();
        let bytes = format.to_bytes().unwrap();
        assert_eq!(PqcBinaryFormat::from_bytes(&bytes).unwrap(), format);

        format.flags ^= 0x01;
        let bytes = format.to_bytes().unwrap();
        assert!(matches!(
            PqcBinaryFormat::from_bytes(&bytes),
            Err(CryptoError::ChecksumMismatch)
        ));
    }

    /// `to_bytes` must refuse to serialize a container with a bad header.
    #[test]
    fn test_header_validation() {
        let format = PqcBinaryFormat::new(Algorithm::Hybrid, sample_metadata(), vec![1, 2, 3]);

        let mut bad_magic = format.clone();
        bad_magic.magic[0] ^= 0xFF;
        assert!(matches!(bad_magic.to_bytes(), Err(CryptoError::InvalidMagic)));

        let mut bad_version = format;
        bad_version.version = bad_version.version.wrapping_add(1);
        let expected = bad_version.version;
        assert!(matches!(
            bad_version.to_bytes(),
            Err(CryptoError::UnsupportedVersion(v)) if v == expected
        ));
    }

    #[test]
    fn test_flags_roundtrip() {
        let flags = FormatFlags::new()
            .with_compression()
            .with_streaming()
            .with_additional_auth();
        let format = PqcBinaryFormat::with_flags(
            Algorithm::QuadLayer,
            flags,
            sample_metadata(),
            vec![1, 2, 3],
        );
        let bytes = format.to_bytes().unwrap();
        let recovered = PqcBinaryFormat::from_bytes(&bytes).unwrap();
        assert!(recovered.flags().has_compression());
        assert!(recovered.flags().has_streaming());
        assert!(recovered.flags().has_additional_auth());
        assert!(!recovered.flags().has_experimental());
    }
}