#[cfg(feature = "alloc")]
use alloc::{
format,
vec,
vec::Vec,
};
use crate::error::HpkeError;
/// Overwrites sensitive bytes with zeros in a way the optimizer is not
/// allowed to remove.
pub trait SecureZeroize {
    /// Sets every byte of `self` to zero.
    fn secure_zeroize(&mut self);
}

impl SecureZeroize for [u8] {
    fn secure_zeroize(&mut self) {
        // Ordinary stores into memory that is freed or goes out of scope right
        // afterwards (e.g. inside `Drop`) are dead stores the compiler may
        // delete; an atomic fence does not forbid that. Volatile writes are
        // always emitted.
        for byte in self.iter_mut() {
            // SAFETY: `byte` comes from a `&mut u8`, so the pointer is
            // non-null, properly aligned and valid for a one-byte write.
            unsafe { core::ptr::write_volatile(byte, 0) };
        }
        // Prevent the compiler from reordering subsequent memory accesses
        // (such as the deallocation) before the wipe.
        core::sync::atomic::compiler_fence(core::sync::atomic::Ordering::SeqCst);
    }
}

impl SecureZeroize for Vec<u8> {
    /// Zeroes the initialized elements (`0..len()`); spare capacity is not
    /// touched and the length is left unchanged.
    fn secure_zeroize(&mut self) {
        self.as_mut_slice().secure_zeroize();
    }
}
/// Heap-allocated byte buffer for secret material that is wiped by
/// [`SecureBytes::zeroize`] and on drop.
///
/// Zeroization is terminal: afterwards every accessor reports an empty buffer
/// and [`SecureBytes::extend_from_slice`] does nothing.
pub struct SecureBytes {
    data: Vec<u8>,
    is_zeroized: bool,
}

impl SecureBytes {
    /// Wraps `data` without copying it; the allocation is wiped on
    /// zeroize/drop.
    pub fn new(data: Vec<u8>) -> Self {
        Self {
            data,
            is_zeroized: false,
        }
    }

    /// Creates an empty buffer with room for `capacity` bytes.
    pub fn with_capacity(capacity: usize) -> Self {
        Self::new(Vec::with_capacity(capacity))
    }

    /// Creates a buffer holding `len` zero bytes.
    pub fn zeros(len: usize) -> Self {
        Self::new(vec![0u8; len])
    }

    /// Borrows the contents; empty once zeroized.
    pub fn as_bytes(&self) -> &[u8] {
        if self.is_zeroized { &[] } else { &self.data }
    }

    /// Mutably borrows the contents; empty once zeroized.
    pub fn as_mut_bytes(&mut self) -> &mut [u8] {
        if self.is_zeroized {
            &mut []
        } else {
            &mut self.data
        }
    }

    /// Number of bytes held; 0 once zeroized.
    pub fn len(&self) -> usize {
        if self.is_zeroized { 0 } else { self.data.len() }
    }

    /// Returns `true` if no bytes are held (always `true` once zeroized).
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Appends `other`; does nothing once zeroized.
    ///
    /// When the current allocation is too small, the contents are moved to a
    /// larger allocation by hand and the old allocation is wiped before it is
    /// freed. A plain `Vec` reallocation would free the old buffer with the
    /// secret still in it.
    pub fn extend_from_slice(&mut self, other: &[u8]) {
        if self.is_zeroized {
            return;
        }
        // Both lengths are bounded by isize::MAX, so this cannot overflow.
        let required = self.data.len() + other.len();
        if required > self.data.capacity() {
            // Grow geometrically (like `Vec`) so repeated appends stay
            // amortized O(1) instead of reallocating on every call.
            let new_capacity = required.max(self.data.capacity().saturating_mul(2));
            let mut grown = Vec::with_capacity(new_capacity);
            grown.extend_from_slice(&self.data);
            self.wipe_storage();
            // The old allocation, already wiped, is freed by this assignment.
            self.data = grown;
        }
        self.data.extend_from_slice(other);
    }

    /// Wipes the whole allocation and marks the buffer as permanently
    /// zeroized. Calling it again is a no-op.
    pub fn zeroize(&mut self) {
        if !self.is_zeroized {
            self.wipe_storage();
            self.is_zeroized = true;
        }
    }

    /// Returns `true` once [`SecureBytes::zeroize`] has run.
    pub fn is_zeroized(&self) -> bool {
        self.is_zeroized
    }

    /// Returns an independent copy of the contents (empty if zeroized).
    pub fn clone_data(&self) -> SecureBytes {
        if self.is_zeroized {
            SecureBytes::new(Vec::new())
        } else {
            SecureBytes::new(self.data.clone())
        }
    }

    /// Zeroes every byte of `self.data`'s allocation, including spare
    /// capacity past `len()` (which may still hold stale bytes). Leaves the
    /// vector empty with its capacity intact.
    fn wipe_storage(&mut self) {
        // Resizing to exactly the current capacity never reallocates. It only
        // makes the spare capacity addressable so the volatile wipe covers it.
        let capacity = self.data.capacity();
        self.data.resize(capacity, 0);
        self.data.secure_zeroize();
        self.data.clear();
    }
}

impl Drop for SecureBytes {
    /// Wipes the buffer before its memory is released.
    fn drop(&mut self) {
        self.zeroize();
    }
}

impl From<Vec<u8>> for SecureBytes {
    /// Takes ownership of `data` without copying it.
    fn from(data: Vec<u8>) -> Self {
        Self::new(data)
    }
}

impl From<&[u8]> for SecureBytes {
    /// Copies `data` into a new buffer. The caller's slice is not wiped.
    fn from(data: &[u8]) -> Self {
        Self::new(data.to_vec())
    }
}
/// Key material whose length has been validated against its [`KeyType`] and
/// which is wiped when zeroized or dropped.
pub struct SecureKey {
    key_data: SecureBytes,
    key_type: KeyType,
}

impl SecureKey {
    /// Validates `data` and wraps it as a key of type `key_type`.
    ///
    /// # Errors
    ///
    /// Returns `HpkeError::CryptoError` if `data.len()` differs from
    /// `key_type.expected_length()`, or if every byte of `data` is zero. On
    /// both error paths `data` is wiped before it is freed.
    pub fn new(data: Vec<u8>, key_type: KeyType) -> Result<Self, HpkeError> {
        // Move the material into a wiping wrapper *before* validating, so an
        // early `return Err(..)` zeroizes it instead of freeing it as-is.
        let key_data = SecureBytes::new(data);

        let expected_len = key_type.expected_length();
        if key_data.len() != expected_len {
            return Err(HpkeError::CryptoError(format!(
                "Invalid key length: expected {}, got {}",
                expected_len,
                key_data.len()
            )));
        }

        // OR-fold every byte rather than short-circuiting with `all`, so the
        // running time does not depend on where the first non-zero byte is.
        // (Best effort only; this is not a guaranteed constant-time primitive.)
        let accumulated = key_data.as_bytes().iter().fold(0u8, |acc, &b| acc | b);
        if accumulated == 0 {
            return Err(HpkeError::CryptoError(
                "Key material cannot be all zeros".into(),
            ));
        }

        Ok(Self { key_data, key_type })
    }

    /// The category this key was validated as.
    pub fn key_type(&self) -> KeyType {
        self.key_type
    }

    /// Borrows the key bytes; empty once zeroized.
    pub fn as_bytes(&self) -> &[u8] {
        self.key_data.as_bytes()
    }

    /// Number of key bytes; 0 once zeroized.
    pub fn len(&self) -> usize {
        self.key_data.len()
    }

    /// Returns `true` if no key bytes are held (only after zeroization).
    pub fn is_empty(&self) -> bool {
        self.key_data.is_empty()
    }

    /// Wipes the key material; the key is unusable afterwards.
    pub fn zeroize(&mut self) {
        self.key_data.zeroize();
    }

    /// Returns `true` once the key material has been wiped.
    pub fn is_zeroized(&self) -> bool {
        self.key_data.is_zeroized()
    }
}

impl Drop for SecureKey {
    /// Wipes the key material before the key is released.
    fn drop(&mut self) {
        self.zeroize();
    }
}
/// Category of key material; each category has one fixed byte length,
/// reported by [`KeyType::expected_length`].
///
/// NOTE(review): the KEM sizes (800-byte public key, 1632-byte secret key)
/// match ML-KEM-512 / Kyber512. Confirm that is the intended KEM.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum KeyType {
    /// AEAD key, 32 bytes.
    AeadKey,
    /// KEM secret key, 1632 bytes.
    KemSecretKey,
    /// KEM public key, 800 bytes.
    KemPublicKey,
    /// Shared secret, 32 bytes.
    SharedSecret,
    /// Exporter secret, 32 bytes.
    ExporterSecret,
}

impl KeyType {
    /// Returns the exact number of bytes a key of this type must have.
    pub fn expected_length(&self) -> usize {
        match *self {
            KeyType::KemSecretKey => 1632,
            KeyType::KemPublicKey => 800,
            KeyType::AeadKey | KeyType::SharedSecret | KeyType::ExporterSecret => 32,
        }
    }
}
/// Bounded cache of wiped heap buffers that [`SecureMemoryPool::allocate`]
/// hands out again instead of allocating fresh memory.
///
/// Buffers enter the cache only through [`SecureMemoryPool::release`], which
/// zeroizes them first, so the cache never holds live secret bytes. At most
/// `max_pool_size` buffers are retained.
pub struct SecureMemoryPool {
    pools: Vec<SecureBytes>,
    max_pool_size: usize,
}

impl SecureMemoryPool {
    /// Creates an empty pool that will retain at most `max_pool_size` buffers.
    pub fn new(max_pool_size: usize) -> Self {
        Self {
            pools: Vec::new(),
            max_pool_size,
        }
    }

    /// Returns a buffer holding exactly `size` zero bytes.
    ///
    /// If a retained buffer's allocation can already hold `size` bytes, it is
    /// removed from the pool and reused. Otherwise a new buffer is allocated.
    /// The returned buffer is owned by the caller; pass it to
    /// [`SecureMemoryPool::release`] to make it reusable.
    pub fn allocate(&mut self, size: usize) -> SecureBytes {
        // Compare against capacity: a zeroized buffer reports `len() == 0`,
        // so a length-based check could only ever match `size == 0`.
        let reusable = self
            .pools
            .iter()
            .position(|buffer| buffer.is_zeroized() && buffer.data.capacity() >= size);

        match reusable {
            Some(index) => {
                let mut buffer = self.pools.swap_remove(index);
                // Revive the wiped allocation as `size` zero bytes. This fits
                // in the existing capacity, so it does not reallocate.
                buffer.data.clear();
                buffer.data.resize(size, 0);
                buffer.is_zeroized = false;
                buffer
            }
            None => SecureBytes::zeros(size),
        }
    }

    /// Gives `buffer` back to the pool.
    ///
    /// The buffer is always zeroized first. Its allocation is then kept for
    /// reuse if fewer than `max_pool_size` buffers are retained; otherwise it
    /// is dropped.
    pub fn release(&mut self, mut buffer: SecureBytes) {
        buffer.zeroize();
        if self.pools.len() < self.max_pool_size {
            self.pools.push(buffer);
        }
    }

    /// Zeroizes and discards every retained buffer.
    pub fn clear(&mut self) {
        for buffer in &mut self.pools {
            buffer.zeroize();
        }
        self.pools.clear();
    }

    /// Reports the pool's current occupancy.
    pub fn stats(&self) -> MemoryPoolStats {
        MemoryPoolStats {
            total_pools: self.pools.len(),
            // Retained buffers are wiped, so their `len()` is 0. Report the
            // capacity they actually keep allocated instead.
            total_allocated_bytes: self.pools.iter().map(|buffer| buffer.data.capacity()).sum(),
            zeroized_pools: self.pools.iter().filter(|buffer| buffer.is_zeroized()).count(),
            max_pool_size: self.max_pool_size,
        }
    }
}

impl Drop for SecureMemoryPool {
    /// Wipes all retained buffers before the pool is released.
    fn drop(&mut self) {
        self.clear();
    }
}
/// Snapshot of a [`SecureMemoryPool`]'s counters, as produced by
/// `SecureMemoryPool::stats`.
#[derive(Debug, Clone, PartialEq, Eq, Default)]
pub struct MemoryPoolStats {
    /// Number of buffers currently retained by the pool.
    pub total_pools: usize,
    /// Sum of the byte sizes `SecureMemoryPool::stats` reports for the
    /// retained buffers.
    pub total_allocated_bytes: usize,
    /// How many retained buffers are in the zeroized state.
    pub zeroized_pools: usize,
    /// Upper bound on the number of buffers the pool will retain.
    pub max_pool_size: usize,
}
pub struct SecureStackBuffer<const N: usize> {
data: [u8; N],
len: usize,
is_zeroized: bool,
}
impl<const N: usize> SecureStackBuffer<N> {
pub fn new() -> Self {
Self {
data: [0u8; N],
len: 0,
is_zeroized: false,
}
}
pub fn from_slice(data: &[u8]) -> Result<Self, HpkeError> {
if data.len() > N {
return Err(HpkeError::CryptoError(format!(
"Data too large for buffer: {} > {}",
data.len(),
N
)));
}
let mut buffer = Self::new();
buffer.data[..data.len()].copy_from_slice(data);
buffer.len = data.len();
Ok(buffer)
}
pub fn as_slice(&self) -> &[u8] {
if self.is_zeroized {
&[]
} else {
&self.data[..self.len]
}
}
pub fn as_mut_slice(&mut self) -> &mut [u8] {
if self.is_zeroized {
&mut []
} else {
&mut self.data[..self.len]
}
}
pub fn len(&self) -> usize {
if self.is_zeroized { 0 } else { self.len }
}
pub fn is_empty(&self) -> bool {
self.len() == 0
}
pub fn set_len(&mut self, new_len: usize) -> Result<(), HpkeError> {
if new_len > N {
return Err(HpkeError::CryptoError(format!(
"Length too large for buffer: {} > {}",
new_len, N
)));
}
if !self.is_zeroized {
self.len = new_len;
}
Ok(())
}
pub fn zeroize(&mut self) {
if !self.is_zeroized {
self.data.secure_zeroize();
self.len = 0;
self.is_zeroized = true;
}
}
pub fn is_zeroized(&self) -> bool {
self.is_zeroized
}
}
impl<const N: usize> Drop for SecureStackBuffer<N> {
fn drop(&mut self) {
self.zeroize();
}
}
impl<const N: usize> Default for SecureStackBuffer<N> {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_secure_zeroize() {
        let mut data = vec![1u8, 2u8, 3u8, 4u8];
        data.secure_zeroize();
        assert_eq!(data, vec![0u8; 4]);
    }

    #[test]
    fn test_secure_bytes() {
        let mut secure = SecureBytes::new(vec![1, 2, 3, 4]);
        assert_eq!(secure.len(), 4);
        assert_eq!(secure.as_bytes(), &[1, 2, 3, 4]);
        assert!(!secure.is_zeroized());
        secure.zeroize();
        assert!(secure.is_zeroized());
        assert_eq!(secure.len(), 0);
    }

    // Appending past the initial capacity forces a reallocation; the
    // contents must survive it, and zeroization must be terminal.
    #[test]
    fn test_secure_bytes_extend_clone_and_guards() {
        let mut secure = SecureBytes::with_capacity(2);
        assert!(secure.is_empty());
        secure.extend_from_slice(&[1, 2]);
        secure.extend_from_slice(&[3, 4, 5]);
        assert_eq!(secure.as_bytes(), &[1, 2, 3, 4, 5]);

        let copy = secure.clone_data();
        secure.zeroize();
        assert!(secure.as_bytes().is_empty());
        assert!(secure.as_mut_bytes().is_empty());
        secure.extend_from_slice(&[9]);
        assert!(secure.is_empty());
        assert!(secure.clone_data().is_empty());
        assert_eq!(copy.as_bytes(), &[1, 2, 3, 4, 5]);
    }

    #[test]
    fn test_secure_bytes_mutation_and_conversions() {
        let mut secure = SecureBytes::zeros(3);
        secure.as_mut_bytes()[1] = 9;
        assert_eq!(secure.as_bytes(), &[0, 9, 0]);
        assert!(SecureBytes::zeros(0).is_empty());

        let from_vec: SecureBytes = vec![4u8, 5].into();
        assert_eq!(from_vec.as_bytes(), &[4, 5]);
        let raw = [6u8, 7, 8];
        let from_slice = SecureBytes::from(&raw[..]);
        assert_eq!(from_slice.as_bytes(), &[6, 7, 8]);
    }

    #[test]
    fn test_secure_key() {
        let key_data = vec![1u8; 32];
        let key = SecureKey::new(key_data, KeyType::AeadKey).unwrap();
        assert_eq!(key.key_type(), KeyType::AeadKey);
        assert_eq!(key.len(), 32);
        assert!(!key.is_zeroized());
    }

    #[test]
    fn test_secure_key_validation() {
        let short_key = vec![1u8; 16];
        assert!(SecureKey::new(short_key, KeyType::AeadKey).is_err());
        let zero_key = vec![0u8; 32];
        assert!(SecureKey::new(zero_key, KeyType::AeadKey).is_err());
    }

    // A single non-zero byte at the very end must be enough to pass the
    // all-zeros check.
    #[test]
    fn test_secure_key_trailing_nonzero_and_zeroize() {
        let mut material = vec![0u8; 800];
        material[799] = 1;
        let mut key = SecureKey::new(material, KeyType::KemPublicKey).unwrap();
        assert_eq!(key.len(), 800);
        assert_eq!(key.as_bytes()[799], 1);
        key.zeroize();
        assert!(key.is_zeroized());
        assert!(key.is_empty());
        assert!(key.as_bytes().is_empty());
    }

    #[test]
    fn test_secure_memory_pool() {
        let mut pool = SecureMemoryPool::new(5);
        let mem1 = pool.allocate(32);
        let mem2 = pool.allocate(64);
        assert_eq!(mem1.len(), 32);
        assert_eq!(mem2.len(), 64);
        let stats = pool.stats();
        assert!(stats.total_pools <= 5);
    }

    #[test]
    fn test_secure_memory_pool_zero_size_and_clear() {
        let mut pool = SecureMemoryPool::new(2);
        let mem = pool.allocate(16);
        assert!(mem.as_bytes().iter().all(|&b| b == 0));
        assert!(pool.allocate(0).is_empty());
        pool.clear();
        let stats = pool.stats();
        assert_eq!(stats.total_pools, 0);
        assert_eq!(stats.max_pool_size, 2);
    }

    #[test]
    fn test_secure_stack_buffer() {
        let buffer = SecureStackBuffer::<32>::new();
        assert_eq!(buffer.len(), 0);
        assert!(buffer.is_empty());
        let data = [1u8, 2u8, 3u8, 4u8];
        let mut buffer = SecureStackBuffer::<32>::from_slice(&data).unwrap();
        assert_eq!(buffer.len(), 4);
        assert_eq!(buffer.as_slice(), &data);
        buffer.zeroize();
        assert!(buffer.is_zeroized());
        assert_eq!(buffer.len(), 0);
    }

    #[test]
    fn test_secure_stack_buffer_bounds() {
        assert!(SecureStackBuffer::<4>::from_slice(&[0u8; 5]).is_err());
        let mut buffer = SecureStackBuffer::<4>::from_slice(&[7u8; 2]).unwrap();
        assert!(buffer.set_len(5).is_err());
        buffer.set_len(4).unwrap();
        assert_eq!(buffer.as_slice(), &[7, 7, 0, 0]);
        buffer.as_mut_slice()[3] = 8;
        assert_eq!(buffer.as_slice(), &[7, 7, 0, 8]);

        buffer.zeroize();
        buffer.set_len(3).unwrap();
        assert_eq!(buffer.len(), 0);
        assert!(buffer.as_mut_slice().is_empty());

        let default_buffer = SecureStackBuffer::<8>::default();
        assert!(default_buffer.is_empty());
    }

    #[test]
    fn test_key_type_lengths() {
        assert_eq!(KeyType::AeadKey.expected_length(), 32);
        assert_eq!(KeyType::KemSecretKey.expected_length(), 1632);
        assert_eq!(KeyType::KemPublicKey.expected_length(), 800);
        assert_eq!(KeyType::SharedSecret.expected_length(), 32);
        assert_eq!(KeyType::ExporterSecret.expected_length(), 32);
    }
}