#[cfg(feature = "alloc")]
use alloc::format;
use subtle::ConstantTimeEq;
use crate::error::HpkeError;
pub fn constant_time_compare(a: &[u8], b: &[u8]) -> bool {
if a.len() != b.len() {
return false;
}
a.ct_eq(b).into()
}
pub fn constant_time_select(condition: bool, true_value: u8, false_value: u8) -> u8 {
let mask = if condition { 0xFF } else { 0x00 };
(true_value & mask) | (false_value & !mask)
}
pub fn constant_time_copy(condition: bool, src: &[u8], dst: &mut [u8]) {
if src.len() != dst.len() {
return;
}
let mask = if condition { 0xFF } else { 0x00 };
for (s, d) in src.iter().zip(dst.iter_mut()) {
*d = (*d & !mask) | (*s & mask);
}
}
pub fn validate_key_side_channel_resistant(
key: &[u8],
expected_len: usize,
) -> Result<(), HpkeError> {
let len_match = key.len() == expected_len;
let mut zero_check = 0u8;
for &byte in key {
zero_check |= byte;
}
let not_zero = zero_check != 0;
let mut ones_check = 0xFFu8;
for &byte in key {
ones_check &= byte;
}
let not_ones = ones_check != 0xFF;
if !len_match {
return Err(HpkeError::CryptoError(format!(
"Invalid key length: expected {}, got {}",
expected_len,
key.len()
)));
}
if !not_zero {
return Err(HpkeError::CryptoError(
"Key material cannot be all zeros".into(),
));
}
if !not_ones {
return Err(HpkeError::CryptoError(
"Key material cannot be all ones".into(),
));
}
Ok(())
}
pub fn validate_nonce_side_channel_resistant(
nonce: &[u8],
expected_len: usize,
) -> Result<(), HpkeError> {
let len_match = nonce.len() == expected_len;
if !len_match {
return Err(HpkeError::CryptoError(format!(
"Invalid nonce length: expected {}, got {}",
expected_len,
nonce.len()
)));
}
Ok(())
}
pub fn validate_ciphertext_side_channel_resistant(
ciphertext: &[u8],
min_len: usize,
) -> Result<(), HpkeError> {
let len_ok = ciphertext.len() >= min_len;
if !len_ok {
return Err(HpkeError::CryptoError(format!(
"Ciphertext too short: minimum {} bytes required",
min_len
)));
}
Ok(())
}
pub fn verify_auth_tag_constant_time(expected: &[u8], actual: &[u8]) -> Result<(), HpkeError> {
if !constant_time_compare(expected, actual) {
return Err(HpkeError::CryptoError(
"Authentication tag verification failed".into(),
));
}
Ok(())
}
pub fn check_entropy_side_channel_resistant(data: &[u8], min_unique_bytes: usize) -> bool {
if data.is_empty() {
return false;
}
let mut unique_count = 0usize;
for i in 0..256 {
let mut found = false;
for &byte in data {
if byte == i as u8 {
found = true;
break;
}
}
if found {
unique_count += 1;
}
}
unique_count >= min_unique_bytes
}
#[cfg(test)]
mod tests {
use alloc::vec;
use alloc::vec::Vec;
use super::*;
#[test]
fn test_constant_time_compare() {
let a = b"hello";
let b = b"hello";
let c = b"world";
let d = b"hell";
assert!(constant_time_compare(a, b));
assert!(!constant_time_compare(a, c));
assert!(!constant_time_compare(a, d));
}
#[test]
fn test_constant_time_select() {
assert_eq!(constant_time_select(true, 0xFF, 0x00), 0xFF);
assert_eq!(constant_time_select(false, 0xFF, 0x00), 0x00);
}
#[test]
fn test_validate_key_side_channel_resistant() {
let valid_key = vec![1u8; 32];
let invalid_len_key = vec![1u8; 16];
let zero_key = vec![0u8; 32];
let ones_key = vec![0xFFu8; 32];
assert!(validate_key_side_channel_resistant(&valid_key, 32).is_ok());
assert!(validate_key_side_channel_resistant(&invalid_len_key, 32).is_err());
assert!(validate_key_side_channel_resistant(&zero_key, 32).is_err());
assert!(validate_key_side_channel_resistant(&ones_key, 32).is_err());
}
#[test]
fn test_verify_auth_tag_constant_time() {
let tag1 = vec![1u8; 16];
let tag2 = vec![1u8; 16];
let tag3 = vec![2u8; 16];
assert!(verify_auth_tag_constant_time(&tag1, &tag2).is_ok());
assert!(verify_auth_tag_constant_time(&tag1, &tag3).is_err());
}
#[test]
fn test_check_entropy_side_channel_resistant() {
let low_entropy = vec![0u8; 32]; let high_entropy = (0..32).collect::<Vec<u8>>();
assert!(!check_entropy_side_channel_resistant(&low_entropy, 4));
assert!(check_entropy_side_channel_resistant(&high_entropy, 4));
}
}