use crate::{
algorithm::Algorithm, error::CryptoError, metadata::PqcMetadata, Result, PQC_BINARY_VERSION,
PQC_MAGIC,
};
use serde::{Deserialize, Serialize};
use sha2::{Digest, Sha256};
/// Fixed header width: magic (4) + version (1) + algorithm id (2) + flags (1) +
/// metadata length (4).
const HEADER_SIZE: usize = 4 + 1 + 2 + 1 + 4;
/// Width of the trailing SHA-256 digest (256 bits).
const CHECKSUM_SIZE: usize = 256 / 8;
/// Compares two 32-byte digests without exiting early on the first mismatch.
///
/// Every byte pair is XOR-ed and the results are OR-ed into one accumulator, so all 32
/// positions are always visited. The inputs are equal exactly when the accumulator is zero.
fn constant_time_eq(a: &[u8; 32], b: &[u8; 32]) -> bool {
    let accumulated = a
        .iter()
        .zip(b.iter())
        .fold(0u8, |acc, (left, right)| acc | (left ^ right));
    accumulated == 0
}
/// Binary container pairing a payload with its algorithm, feature flags and metadata.
///
/// Serialized by [`PqcBinaryFormat::to_bytes`] as (all integers little-endian):
///
/// ```text
/// magic (4) | version (1) | algorithm id (u16) | flags (1) | metadata length (u32) |
/// metadata (JSON) | data length (u64) | data | SHA-256 of all preceding bytes (32)
/// ```
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct PqcBinaryFormat {
    /// Format identifier; `validate` requires it to equal `PQC_MAGIC`.
    pub magic: [u8; 4],
    /// Format version; `validate` requires it to equal `PQC_BINARY_VERSION`.
    pub version: u8,
    /// Algorithm identifier, written to the wire as a `u16` via `Algorithm::as_id`.
    pub algorithm: Algorithm,
    /// Raw feature bits; `flags()` returns them as a typed [`FormatFlags`].
    pub flags: u8,
    /// Metadata, serialized via `PqcMetadata::to_json_bytes` and length-prefixed with a `u32`.
    pub metadata: PqcMetadata,
    /// Payload bytes, length-prefixed with a `u64`.
    pub data: Vec<u8>,
    /// SHA-256 over the serialized prefix (every byte before the checksum).
    ///
    /// Set by the constructors and `update_checksum`, and copied from the input by
    /// `from_bytes`. It is NOT refreshed automatically when other public fields are
    /// mutated, and `to_bytes` recomputes the checksum instead of reading this field.
    pub checksum: [u8; 32],
}
/// Typed view over the feature bits stored in a [`PqcBinaryFormat`].
///
/// | bit    | builder / query                                   |
/// |--------|---------------------------------------------------|
/// | `0x01` | `with_compression` / `has_compression`            |
/// | `0x02` | `with_streaming` / `has_streaming`                |
/// | `0x04` | `with_additional_auth` / `has_additional_auth`    |
/// | `0x08` | `with_experimental` / `has_experimental`          |
///
/// Bits outside this set are carried through unchanged by [`FormatFlags::from_u8`] and
/// [`FormatFlags::as_u8`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct FormatFlags(u8);

impl FormatFlags {
    const COMPRESSION: u8 = 0x01;
    const STREAMING: u8 = 0x02;
    const ADDITIONAL_AUTH: u8 = 0x04;
    const EXPERIMENTAL: u8 = 0x08;

    /// Returns a copy of `self` with `bit` turned on.
    const fn set(self, bit: u8) -> Self {
        Self(self.0 | bit)
    }

    /// Reports whether any bit of `bit` is on.
    const fn is_set(self, bit: u8) -> bool {
        (self.0 & bit) != 0
    }

    /// An empty flag set (no bits on).
    #[must_use]
    pub const fn new() -> Self {
        Self(0)
    }

    /// Turns on the compression bit (`0x01`).
    #[must_use]
    pub const fn with_compression(self) -> Self {
        self.set(Self::COMPRESSION)
    }

    /// Turns on the streaming bit (`0x02`).
    #[must_use]
    pub const fn with_streaming(self) -> Self {
        self.set(Self::STREAMING)
    }

    /// Turns on the additional-auth bit (`0x04`).
    #[must_use]
    pub const fn with_additional_auth(self) -> Self {
        self.set(Self::ADDITIONAL_AUTH)
    }

    /// Turns on the experimental bit (`0x08`).
    #[must_use]
    pub const fn with_experimental(self) -> Self {
        self.set(Self::EXPERIMENTAL)
    }

    /// Whether the compression bit (`0x01`) is on.
    #[must_use]
    pub const fn has_compression(self) -> bool {
        self.is_set(Self::COMPRESSION)
    }

    /// Whether the streaming bit (`0x02`) is on.
    #[must_use]
    pub const fn has_streaming(self) -> bool {
        self.is_set(Self::STREAMING)
    }

    /// Whether the additional-auth bit (`0x04`) is on.
    #[must_use]
    pub const fn has_additional_auth(self) -> bool {
        self.is_set(Self::ADDITIONAL_AUTH)
    }

    /// Whether the experimental bit (`0x08`) is on.
    #[must_use]
    pub const fn has_experimental(self) -> bool {
        self.is_set(Self::EXPERIMENTAL)
    }

    /// The raw byte, including any bits not named above.
    #[must_use]
    pub const fn as_u8(self) -> u8 {
        self.0
    }

    /// Wraps a raw byte verbatim; unknown bits are kept, not rejected.
    #[must_use]
    pub const fn from_u8(value: u8) -> Self {
        Self(value)
    }
}

impl Default for FormatFlags {
    /// Same as [`FormatFlags::new`]: no bits on.
    fn default() -> Self {
        Self::new()
    }
}
impl PqcBinaryFormat {
    /// Builds a format with no feature flags set.
    ///
    /// Equivalent to [`PqcBinaryFormat::with_flags`] with [`FormatFlags::new`]. `magic` and
    /// `version` come from `PQC_MAGIC` / `PQC_BINARY_VERSION`, and the checksum is computed
    /// immediately.
    #[must_use]
    pub fn new(algorithm: Algorithm, metadata: PqcMetadata, data: Vec<u8>) -> Self {
        Self::with_flags(algorithm, FormatFlags::new(), metadata, data)
    }

    /// Builds a format carrying the given feature `flags`.
    ///
    /// `magic` and `version` come from `PQC_MAGIC` / `PQC_BINARY_VERSION`. The checksum is
    /// computed over the resulting serialized prefix.
    #[must_use]
    pub fn with_flags(
        algorithm: Algorithm,
        flags: FormatFlags,
        metadata: PqcMetadata,
        data: Vec<u8>,
    ) -> Self {
        let mut format = Self {
            magic: PQC_MAGIC,
            version: PQC_BINARY_VERSION,
            algorithm,
            flags: flags.as_u8(),
            metadata,
            data,
            checksum: [0u8; CHECKSUM_SIZE],
        };
        format.update_checksum();
        format
    }

    /// Serializes the format to its wire representation.
    ///
    /// Layout (integers little-endian):
    ///
    /// ```text
    /// magic (4) | version (1) | algorithm id (u16) | flags (1) | metadata length (u32) |
    /// metadata JSON | data length (u64) | data | SHA-256 of all preceding bytes (32)
    /// ```
    ///
    /// The trailing checksum is always recomputed from the current field values. The
    /// `checksum` field itself is not read.
    ///
    /// # Errors
    ///
    /// Returns any error from [`PqcBinaryFormat::validate`]. Returns
    /// `CryptoError::BinaryFormatError` if the serialized metadata is longer than
    /// `u32::MAX` bytes and therefore cannot be length-prefixed.
    pub fn to_bytes(&self) -> Result<Vec<u8>> {
        self.validate()?;
        let mut buf = self.serialize_prefix()?;
        let checksum: [u8; CHECKSUM_SIZE] = Sha256::digest(&buf).into();
        buf.extend_from_slice(&checksum);
        Ok(buf)
    }

    /// Serializes every field that precedes the checksum.
    ///
    /// The buffer is reserved for the full `to_bytes` output, including the 8-byte data
    /// length and the checksum the caller appends. No reallocation is therefore needed.
    fn serialize_prefix(&self) -> Result<Vec<u8>> {
        let metadata = self.metadata.to_json_bytes();
        // Reject rather than clamp: a clamped length prefix would not match the metadata
        // bytes actually written, yielding a buffer `from_bytes` cannot decode.
        let metadata_len = u32::try_from(metadata.len()).map_err(|_| {
            CryptoError::BinaryFormatError(format!(
                "Metadata too large: {} bytes exceeds u32::MAX",
                metadata.len()
            ))
        })?;
        let mut buf = Vec::with_capacity(
            HEADER_SIZE + metadata.len() + 8 + self.data.len() + CHECKSUM_SIZE,
        );
        buf.extend_from_slice(&self.magic);
        buf.push(self.version);
        buf.extend_from_slice(&self.algorithm.as_id().to_le_bytes());
        buf.push(self.flags);
        buf.extend_from_slice(&metadata_len.to_le_bytes());
        buf.extend_from_slice(&metadata);
        // usize -> u64 is lossless on every target whose pointers are at most 64 bits.
        buf.extend_from_slice(&(self.data.len() as u64).to_le_bytes());
        buf.extend_from_slice(&self.data);
        Ok(buf)
    }

    /// Parses and verifies a buffer produced by [`PqcBinaryFormat::to_bytes`].
    ///
    /// Checks run in this order:
    /// 1. minimum size, magic, version and algorithm id;
    /// 2. the declared metadata and data lengths, which must account for every byte of
    ///    `data`;
    /// 3. the trailing SHA-256 checksum;
    /// 4. the metadata is parsed and [`PqcBinaryFormat::validate`] is run.
    ///
    /// All length arithmetic is overflow-checked, so hostile length fields produce an
    /// error rather than an arithmetic panic. The checksum is a plain, unkeyed SHA-256. It
    /// detects accidental corruption but does not authenticate the buffer.
    ///
    /// # Errors
    ///
    /// - `CryptoError::BinaryFormatError` for a short, truncated, padded or inconsistently
    ///   length-prefixed buffer;
    /// - `CryptoError::InvalidMagic`, `CryptoError::UnsupportedVersion` or
    ///   `CryptoError::UnknownAlgorithm` for a bad header;
    /// - `CryptoError::ChecksumMismatch` if the trailing digest does not match;
    /// - any error from `PqcMetadata::from_json_bytes` or `validate`.
    pub fn from_bytes(data: &[u8]) -> Result<Self> {
        const MIN_SIZE: usize = HEADER_SIZE + 8 + CHECKSUM_SIZE;
        if data.len() < MIN_SIZE {
            return Err(CryptoError::BinaryFormatError(format!(
                "Buffer too small: {} bytes, minimum {MIN_SIZE}",
                data.len(),
            )));
        }

        let magic: [u8; 4] = [data[0], data[1], data[2], data[3]];
        if magic != PQC_MAGIC {
            return Err(CryptoError::InvalidMagic);
        }
        let version = data[4];
        if version != PQC_BINARY_VERSION {
            return Err(CryptoError::UnsupportedVersion(version));
        }
        let algorithm_id = u16::from_le_bytes([data[5], data[6]]);
        let algorithm = Algorithm::from_id(algorithm_id).ok_or_else(|| {
            CryptoError::UnknownAlgorithm(format!("Invalid algorithm ID: {algorithm_id:#x}"))
        })?;
        let flags = data[7];

        // Offset where the trailing checksum starts. The MIN_SIZE check guarantees this
        // cannot underflow. Every offset that passes a `<= body_end` check is therefore at
        // most `usize::MAX - CHECKSUM_SIZE`, so small fixed additions to it cannot overflow.
        let body_end = data.len() - CHECKSUM_SIZE;

        let raw_metadata_len = u32::from_le_bytes([data[8], data[9], data[10], data[11]]);
        let metadata_len = usize::try_from(raw_metadata_len)
            .map_err(|_| CryptoError::BinaryFormatError("Metadata length overflow".into()))?;
        let meta_start = HEADER_SIZE;
        let meta_end = meta_start
            .checked_add(metadata_len)
            .ok_or_else(|| CryptoError::BinaryFormatError("Metadata length overflow".into()))?;
        if meta_end > body_end {
            return Err(CryptoError::BinaryFormatError(
                "Metadata length exceeds buffer".into(),
            ));
        }

        // Safe: meta_end <= body_end <= usize::MAX - CHECKSUM_SIZE.
        let len_end = meta_end + 8;
        if len_end > body_end {
            return Err(CryptoError::BinaryFormatError(
                "Truncated before data length".into(),
            ));
        }
        let mut data_len_bytes = [0u8; 8];
        data_len_bytes.copy_from_slice(&data[meta_end..len_end]);
        let data_len = usize::try_from(u64::from_le_bytes(data_len_bytes))
            .map_err(|_| CryptoError::BinaryFormatError("Data length exceeds usize".into()))?;
        let data_end = len_end
            .checked_add(data_len)
            .ok_or_else(|| CryptoError::BinaryFormatError("Data length overflow".into()))?;
        // Compare against body_end instead of computing `data_end + CHECKSUM_SIZE`, which
        // could overflow for a hostile data length.
        if data_end != body_end {
            return Err(CryptoError::BinaryFormatError(format!(
                "Length mismatch: expected {} bytes, got {}",
                data_end.saturating_add(CHECKSUM_SIZE),
                data.len()
            )));
        }

        // Verify integrity before interpreting the metadata, so corruption is reported as
        // a checksum failure and the JSON parser only sees bytes that passed the check.
        let mut stored = [0u8; CHECKSUM_SIZE];
        stored.copy_from_slice(&data[body_end..]);
        let calculated: [u8; CHECKSUM_SIZE] = Sha256::digest(&data[..body_end]).into();
        if !constant_time_eq(&stored, &calculated) {
            return Err(CryptoError::ChecksumMismatch);
        }

        let metadata = PqcMetadata::from_json_bytes(&data[meta_start..meta_end])?;
        let format = Self {
            magic,
            version,
            algorithm,
            flags,
            metadata,
            data: data[len_end..body_end].to_vec(),
            checksum: stored,
        };
        format.validate()?;
        Ok(format)
    }

    /// Checks the header fields and delegates to `PqcMetadata::validate`.
    ///
    /// The `checksum` field is not examined here.
    ///
    /// # Errors
    ///
    /// Returns `CryptoError::InvalidMagic`, `CryptoError::UnsupportedVersion` or
    /// `CryptoError::UnknownAlgorithm` for a bad header, or any error from
    /// `PqcMetadata::validate`.
    pub fn validate(&self) -> Result<()> {
        if self.magic != PQC_MAGIC {
            return Err(CryptoError::InvalidMagic);
        }
        if self.version != PQC_BINARY_VERSION {
            return Err(CryptoError::UnsupportedVersion(self.version));
        }
        if Algorithm::from_id(self.algorithm.as_id()).is_none() {
            return Err(CryptoError::UnknownAlgorithm(format!(
                "Invalid algorithm ID: {:#x}",
                self.algorithm.as_id()
            )));
        }
        self.metadata.validate()?;
        Ok(())
    }

    /// Recomputes `checksum` from the current field values.
    ///
    /// Call this after mutating any public field directly.
    pub fn update_checksum(&mut self) {
        self.checksum = self.calculate_checksum();
    }

    /// SHA-256 of the serialized prefix.
    ///
    /// Returns all zeroes if the prefix cannot be serialized (metadata over `u32::MAX`
    /// bytes). `to_bytes` rejects such a format anyway, so the value is never written.
    fn calculate_checksum(&self) -> [u8; CHECKSUM_SIZE] {
        self.serialize_prefix()
            .map_or([0u8; CHECKSUM_SIZE], |prefix| Sha256::digest(&prefix).into())
    }

    /// The feature flags as a typed [`FormatFlags`].
    #[must_use]
    pub const fn flags(&self) -> FormatFlags {
        FormatFlags(self.flags)
    }

    /// The algorithm identifier.
    #[must_use]
    pub const fn algorithm(&self) -> Algorithm {
        self.algorithm
    }

    /// The payload bytes.
    #[must_use]
    pub fn data(&self) -> &[u8] {
        &self.data
    }

    /// The attached metadata.
    #[must_use]
    pub const fn metadata(&self) -> &PqcMetadata {
        &self.metadata
    }

    /// Length in bytes of the [`PqcBinaryFormat::to_bytes`] output.
    ///
    /// Returns 0 when `to_bytes` would fail. The size is computed arithmetically, without
    /// copying or hashing the payload.
    #[must_use]
    pub fn total_size(&self) -> usize {
        if self.validate().is_err() {
            return 0;
        }
        let metadata_len = self.metadata.to_json_bytes().len();
        if u32::try_from(metadata_len).is_err() {
            return 0;
        }
        HEADER_SIZE + metadata_len + 8 + self.data.len() + CHECKSUM_SIZE
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::EncParameters;
    use std::collections::HashMap;

    /// Metadata with the given IV and tag and no extra parameters.
    fn metadata_with(iv: Vec<u8>, tag: Vec<u8>) -> PqcMetadata {
        PqcMetadata {
            enc_params: EncParameters {
                iv,
                tag,
                params: HashMap::new(),
            },
            ..Default::default()
        }
    }

    /// Metadata shared by most tests: IV `1..=12`, tag `1..=16`.
    fn sample_metadata() -> PqcMetadata {
        metadata_with((1..=12).collect(), (1..=16).collect())
    }

    /// Serialized bytes of a small valid format with a 5-byte payload.
    fn sample_bytes() -> Vec<u8> {
        PqcBinaryFormat::new(Algorithm::Hybrid, sample_metadata(), vec![1, 2, 3, 4, 5])
            .to_bytes()
            .unwrap()
    }

    #[test]
    fn test_format_flags() {
        let flags = FormatFlags::new().with_compression().with_streaming();
        assert!(flags.has_compression());
        assert!(flags.has_streaming());
        assert!(!flags.has_additional_auth());
        assert!(!flags.has_experimental());
    }

    #[test]
    fn test_binary_format_roundtrip() {
        let original =
            PqcBinaryFormat::new(Algorithm::Hybrid, sample_metadata(), vec![1, 2, 3, 4, 5]);
        let bytes = original.to_bytes().unwrap();
        let deserialized = PqcBinaryFormat::from_bytes(&bytes).unwrap();
        assert_eq!(original, deserialized);
    }

    #[test]
    fn test_empty_payload_roundtrip() {
        let original = PqcBinaryFormat::new(Algorithm::Hybrid, sample_metadata(), Vec::new());
        let bytes = original.to_bytes().unwrap();
        let deserialized = PqcBinaryFormat::from_bytes(&bytes).unwrap();
        assert_eq!(original, deserialized);
        assert!(deserialized.data().is_empty());
    }

    #[test]
    fn test_checksum_validation() {
        let format =
            PqcBinaryFormat::new(Algorithm::PostQuantum, sample_metadata(), vec![1, 2, 3, 4, 5]);
        let mut bytes = format.to_bytes().unwrap();
        if let Some(byte) = bytes.last_mut() {
            *byte = byte.wrapping_add(1);
        }
        assert!(matches!(
            PqcBinaryFormat::from_bytes(&bytes),
            Err(CryptoError::ChecksumMismatch)
        ));
    }

    #[test]
    fn test_flags_roundtrip() {
        let flags = FormatFlags::new()
            .with_compression()
            .with_streaming()
            .with_additional_auth();
        let format = PqcBinaryFormat::with_flags(
            Algorithm::QuadLayer,
            flags,
            metadata_with(vec![1; 12], vec![1; 16]),
            vec![1, 2, 3],
        );
        let bytes = format.to_bytes().unwrap();
        let recovered = PqcBinaryFormat::from_bytes(&bytes).unwrap();
        assert!(recovered.flags().has_compression());
        assert!(recovered.flags().has_streaming());
        assert!(recovered.flags().has_additional_auth());
        assert!(!recovered.flags().has_experimental());
    }

    #[test]
    fn test_rejects_short_buffer() {
        assert!(matches!(
            PqcBinaryFormat::from_bytes(&[]),
            Err(CryptoError::BinaryFormatError(_))
        ));
        let almost = vec![0u8; HEADER_SIZE + 8 + CHECKSUM_SIZE - 1];
        assert!(matches!(
            PqcBinaryFormat::from_bytes(&almost),
            Err(CryptoError::BinaryFormatError(_))
        ));
    }

    #[test]
    fn test_rejects_bad_magic() {
        let mut bytes = sample_bytes();
        bytes[0] ^= 0xFF;
        assert!(matches!(
            PqcBinaryFormat::from_bytes(&bytes),
            Err(CryptoError::InvalidMagic)
        ));
    }

    #[test]
    fn test_rejects_unsupported_version() {
        let mut bytes = sample_bytes();
        bytes[4] = bytes[4].wrapping_add(1);
        assert!(matches!(
            PqcBinaryFormat::from_bytes(&bytes),
            Err(CryptoError::UnsupportedVersion(_))
        ));
    }

    #[test]
    fn test_rejects_truncated_and_padded_buffers() {
        let bytes = sample_bytes();
        assert!(matches!(
            PqcBinaryFormat::from_bytes(&bytes[..bytes.len() - 1]),
            Err(CryptoError::BinaryFormatError(_))
        ));
        let mut padded = bytes.clone();
        padded.push(0);
        assert!(matches!(
            PqcBinaryFormat::from_bytes(&padded),
            Err(CryptoError::BinaryFormatError(_))
        ));
    }

    #[test]
    fn test_rejects_absurd_data_length() {
        let mut bytes = sample_bytes();
        let metadata_len = u32::from_le_bytes([bytes[8], bytes[9], bytes[10], bytes[11]]) as usize;
        let len_at = HEADER_SIZE + metadata_len;
        bytes[len_at..len_at + 8].copy_from_slice(&u64::MAX.to_le_bytes());
        assert!(matches!(
            PqcBinaryFormat::from_bytes(&bytes),
            Err(CryptoError::BinaryFormatError(_))
        ));
    }

    #[test]
    fn test_rejects_tampered_metadata() {
        let mut bytes = sample_bytes();
        bytes[HEADER_SIZE] ^= 0xFF;
        assert!(PqcBinaryFormat::from_bytes(&bytes).is_err());
    }

    #[test]
    fn test_total_size_matches_serialized_length() {
        let format =
            PqcBinaryFormat::new(Algorithm::Hybrid, sample_metadata(), vec![9; 37]);
        assert_eq!(format.total_size(), format.to_bytes().unwrap().len());
    }
}